diff --git a/Project.toml b/Project.toml index 84aab57..9132082 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ version = "0.5.10" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -18,12 +17,14 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" [extensions] EinExprsChainRulesCoreExt = "ChainRulesCore" EinExprsFiniteDifferencesExt = "FiniteDifferences" EinExprsMakieExt = "Makie" +EinExprsGraphMakieExt = ["Makie", "GraphMakie"] [compat] AbstractTrees = "0.4" diff --git a/ext/EinExprsGraphMakieExt.jl b/ext/EinExprsGraphMakieExt.jl new file mode 100644 index 0000000..99234a1 --- /dev/null +++ b/ext/EinExprsGraphMakieExt.jl @@ -0,0 +1,146 @@ +module EinExprsGraphMakieExt + +using EinExprs +using EinExprs: Branches +using Graphs +using Makie +using GraphMakie +using AbstractTrees + +# NOTE this is a hack! removes NetworkLayout dependency but can be unstable +__networklayout_dim(x) = supertype(typeof(x)).parameters |> first + +# TODO rework size calculating algorithm +const MAX_EDGE_WIDTH = 10.0 +const MAX_ARROW_SIZE = 35.0 +const MAX_NODE_SIZE = 40.0 + +function Makie.plot(path::SizedEinExpr; kwargs...) + f = Figure() + ax, p = plot!(f[1, 1], path; kwargs...) + return Makie.FigureAxisPlot(f, ax, p) +end + +function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...) + ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3 + Axis3(f[1, 1]) + else + ax = Axis(f[1, 1]) + ax.aspect = DataAspect() + ax + end + + hidedecorations!(ax) + hidespines!(ax) + + p = plot!(ax, path; kwargs...) + + # plot colorbars + # TODO configurable `labelsize` + # TODO configurable alignments + Colorbar( + f[1, 2], + get_edge_plot(p); + label = "SIZE", + flipaxis = true, + flip_vertical_label = true, + labelsize = 24, + height = Relative(5 // 6), + scale = log2, + ) + + Colorbar( + f[1, 0], + get_node_plot(p); + label = "FLOPS", + flipaxis = false, + labelsize = 24, + height = Relative(5 // 6), + scale = log10, + ) + + return Makie.AxisPlot(ax, p) +end + +# TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap +function Makie.plot!( + ax::Union{Axis,Axis3}, + path::SizedEinExpr; + colormap = to_colormap(:viridis)[begin:end-10], + inds = false, + kwargs..., +) + handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path))) + graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args]) + + lin_size = length.(PostOrderDFS(path))[1:end-1] + lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path))) + + log_size = log2.(lin_size) + log_flops = log10.(lin_flops) + + kwargs = Dict{Symbol,Any}(kwargs) + + # configure graphics + get!(kwargs, :edge_width) do + map(log_size ./ maximum(log_size) .* MAX_EDGE_WIDTH) do x + iszero(x) ? 4.0 : x + end + end + + get!(kwargs, :arrow_size) do + map(log_size ./ maximum(log_size) .* MAX_ARROW_SIZE) do x + iszero(x) ? 30.0 : x + end + end + + get!(() -> log_flops ./ maximum(log_flops) .* MAX_NODE_SIZE, kwargs, :node_size) + + get!(kwargs, :edge_color, lin_size) + get!(kwargs, :node_color, lin_flops) + + get!( + kwargs, + :arrow_attr, + ( + colorrange = extrema(lin_size), + colormap = colormap, + colorscale = log2, + highclip = Makie.Automatic(), + lowclip = Makie.Automatic(), + ), + ) + get!( + kwargs, + :edge_attr, + ( + colorrange = extrema(lin_size), + colormap = colormap, + colorscale = log2, + highclip = Makie.Automatic(), + lowclip = Makie.Automatic(), + ), + ) + # TODO replace `to_colormap(:plasma)[begin:end-50]), kwargs...)` with a custom colormap + get!( + kwargs, + :node_attr, + ( + colorrange = extrema(lin_flops), + colormap = to_colormap(:plasma)[begin:end-50], + colorscale = log10, + highclip = Makie.Automatic(), + lowclip = Makie.Automatic(), + ), + ) + + # configure labels + inds == true && get!(() -> join.(head.(PostOrderDFS(path)))[1:end-1], kwargs, :elabels) + get!(() -> repeat([:black], ne(graph)), kwargs, :elabels_color) + get!(() -> log_size ./ maximum(log_size) .* 5 .+ 12, kwargs, :elabels_textsize) + + # plot graph + graphplot!(ax, graph; kwargs...) +end + +end diff --git a/ext/EinExprsMakieExt.jl b/ext/EinExprsMakieExt.jl index 79e52ed..989f762 100644 --- a/ext/EinExprsMakieExt.jl +++ b/ext/EinExprsMakieExt.jl @@ -1,146 +1,13 @@ module EinExprsMakieExt -using EinExprs -using EinExprs: Branches -using Graphs -using Makie -using GraphMakie -using AbstractTrees - -# NOTE this is a hack! removes NetworkLayout dependency but can be unstable -__networklayout_dim(x) = supertype(typeof(x)).parameters |> first - -# TODO rework size calculating algorithm -const MAX_EDGE_WIDTH = 10.0 -const MAX_ARROW_SIZE = 35.0 -const MAX_NODE_SIZE = 40.0 - -function Makie.plot(path::SizedEinExpr; kwargs...) - f = Figure() - ax, p = plot!(f[1, 1], path; kwargs...) - return Makie.FigureAxisPlot(f, ax, p) -end - -function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...) - ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3 - Axis3(f[1, 1]) - else - ax = Axis(f[1, 1]) - ax.aspect = DataAspect() - ax - end - - hidedecorations!(ax) - hidespines!(ax) - - p = plot!(ax, path; kwargs...) - - # plot colorbars - # TODO configurable `labelsize` - # TODO configurable alignments - Colorbar( - f[1, 2], - get_edge_plot(p); - label = "SIZE", - flipaxis = true, - flip_vertical_label = true, - labelsize = 24, - height = Relative(5 // 6), - scale = log2, - ) - - Colorbar( - f[1, 0], - get_node_plot(p); - label = "FLOPS", - flipaxis = false, - labelsize = 24, - height = Relative(5 // 6), - scale = log10, - ) - - return Makie.AxisPlot(ax, p) -end - -# TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap -function Makie.plot!( - ax::Union{Axis,Axis3}, - path::SizedEinExpr; - colormap = to_colormap(:viridis)[begin:end-10], - inds = false, - kwargs..., -) - handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path))) - graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args]) - - lin_size = length.(PostOrderDFS(path))[1:end-1] - lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path))) - - log_size = log2.(lin_size) - log_flops = log10.(lin_flops) - - kwargs = Dict{Symbol,Any}(kwargs) - - # configure graphics - get!(kwargs, :edge_width) do - map(log_size ./ maximum(log_size) .* MAX_EDGE_WIDTH) do x - iszero(x) ? 4.0 : x - end - end - - get!(kwargs, :arrow_size) do - map(log_size ./ maximum(log_size) .* MAX_ARROW_SIZE) do x - iszero(x) ? 30.0 : x - end +function __init__() + try + Base.require(Main, :GraphMakie) + catch + @warn """Package GraphMakie not found in current path. It is needed to plot `EinExpr`s with `Makie`. + - Run `import Pkg; Pkg.add(\"GraphMakie\")` or `]add GraphMakie` to install the GraphMakie package, then restart julia. + """ end - - get!(() -> log_flops ./ maximum(log_flops) .* MAX_NODE_SIZE, kwargs, :node_size) - - get!(kwargs, :edge_color, lin_size) - get!(kwargs, :node_color, lin_flops) - - get!( - kwargs, - :arrow_attr, - ( - colorrange = extrema(lin_size), - colormap = colormap, - colorscale = log2, - highclip = Makie.Automatic(), - lowclip = Makie.Automatic(), - ), - ) - get!( - kwargs, - :edge_attr, - ( - colorrange = extrema(lin_size), - colormap = colormap, - colorscale = log2, - highclip = Makie.Automatic(), - lowclip = Makie.Automatic(), - ), - ) - # TODO replace `to_colormap(:plasma)[begin:end-50]), kwargs...)` with a custom colormap - get!( - kwargs, - :node_attr, - ( - colorrange = extrema(lin_flops), - colormap = to_colormap(:plasma)[begin:end-50], - colorscale = log10, - highclip = Makie.Automatic(), - lowclip = Makie.Automatic(), - ), - ) - - # configure labels - inds == true && get!(() -> join.(head.(PostOrderDFS(path)))[1:end-1], kwargs, :elabels) - get!(() -> repeat([:black], ne(graph)), kwargs, :elabels_color) - get!(() -> log_size ./ maximum(log_size) .* 5 .+ 12, kwargs, :elabels_textsize) - - # plot graph - graphplot!(ax, graph; kwargs...) end end diff --git a/test/Project.toml b/test/Project.toml index 93713ac..a8d1523 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"