Skip to content

Commit

Permalink
Require GraphMakie explicitly to plot (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing authored Jan 18, 2024
1 parent 32e21ac commit a29e672
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 141 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "0.5.10"
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -18,12 +17,14 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"

[extensions]
EinExprsChainRulesCoreExt = "ChainRulesCore"
EinExprsFiniteDifferencesExt = "FiniteDifferences"
EinExprsMakieExt = "Makie"
EinExprsGraphMakieExt = ["Makie", "GraphMakie"]

[compat]
AbstractTrees = "0.4"
Expand Down
146 changes: 146 additions & 0 deletions ext/EinExprsGraphMakieExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
module EinExprsGraphMakieExt

using EinExprs
using EinExprs: Branches
using Graphs
using Makie
using GraphMakie
using AbstractTrees

# NOTE this is a hack! removes NetworkLayout dependency but can be unstable
__networklayout_dim(x) = supertype(typeof(x)).parameters |> first

# TODO rework size calculating algorithm
const MAX_EDGE_WIDTH = 10.0
const MAX_ARROW_SIZE = 35.0
const MAX_NODE_SIZE = 40.0

function Makie.plot(path::SizedEinExpr; kwargs...)
f = Figure()
ax, p = plot!(f[1, 1], path; kwargs...)
return Makie.FigureAxisPlot(f, ax, p)
end

function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...)
ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
Axis3(f[1, 1])
else
ax = Axis(f[1, 1])
ax.aspect = DataAspect()
ax
end

hidedecorations!(ax)
hidespines!(ax)

p = plot!(ax, path; kwargs...)

# plot colorbars
# TODO configurable `labelsize`
# TODO configurable alignments
Colorbar(
f[1, 2],
get_edge_plot(p);
label = "SIZE",
flipaxis = true,
flip_vertical_label = true,
labelsize = 24,
height = Relative(5 // 6),
scale = log2,
)

Colorbar(
f[1, 0],
get_node_plot(p);
label = "FLOPS",
flipaxis = false,
labelsize = 24,
height = Relative(5 // 6),
scale = log10,
)

return Makie.AxisPlot(ax, p)
end

# TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap
function Makie.plot!(
ax::Union{Axis,Axis3},
path::SizedEinExpr;
colormap = to_colormap(:viridis)[begin:end-10],
inds = false,
kwargs...,
)
handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path)))
graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args])

lin_size = length.(PostOrderDFS(path))[1:end-1]
lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path)))

log_size = log2.(lin_size)
log_flops = log10.(lin_flops)

kwargs = Dict{Symbol,Any}(kwargs)

# configure graphics
get!(kwargs, :edge_width) do
map(log_size ./ maximum(log_size) .* MAX_EDGE_WIDTH) do x
iszero(x) ? 4.0 : x
end
end

get!(kwargs, :arrow_size) do
map(log_size ./ maximum(log_size) .* MAX_ARROW_SIZE) do x
iszero(x) ? 30.0 : x
end
end

get!(() -> log_flops ./ maximum(log_flops) .* MAX_NODE_SIZE, kwargs, :node_size)

get!(kwargs, :edge_color, lin_size)
get!(kwargs, :node_color, lin_flops)

get!(
kwargs,
:arrow_attr,
(
colorrange = extrema(lin_size),
colormap = colormap,
colorscale = log2,
highclip = Makie.Automatic(),
lowclip = Makie.Automatic(),
),
)
get!(
kwargs,
:edge_attr,
(
colorrange = extrema(lin_size),
colormap = colormap,
colorscale = log2,
highclip = Makie.Automatic(),
lowclip = Makie.Automatic(),
),
)
# TODO replace `to_colormap(:plasma)[begin:end-50]), kwargs...)` with a custom colormap
get!(
kwargs,
:node_attr,
(
colorrange = extrema(lin_flops),
colormap = to_colormap(:plasma)[begin:end-50],
colorscale = log10,
highclip = Makie.Automatic(),
lowclip = Makie.Automatic(),
),
)

# configure labels
inds == true && get!(() -> join.(head.(PostOrderDFS(path)))[1:end-1], kwargs, :elabels)
get!(() -> repeat([:black], ne(graph)), kwargs, :elabels_color)
get!(() -> log_size ./ maximum(log_size) .* 5 .+ 12, kwargs, :elabels_textsize)

# plot graph
graphplot!(ax, graph; kwargs...)
end

end
147 changes: 7 additions & 140 deletions ext/EinExprsMakieExt.jl
Original file line number Diff line number Diff line change
@@ -1,146 +1,13 @@
module EinExprsMakieExt

using EinExprs
using EinExprs: Branches
using Graphs
using Makie
using GraphMakie
using AbstractTrees

# NOTE this is a hack! removes NetworkLayout dependency but can be unstable
__networklayout_dim(x) = supertype(typeof(x)).parameters |> first

# TODO rework size calculating algorithm
const MAX_EDGE_WIDTH = 10.0
const MAX_ARROW_SIZE = 35.0
const MAX_NODE_SIZE = 40.0

function Makie.plot(path::SizedEinExpr; kwargs...)
f = Figure()
ax, p = plot!(f[1, 1], path; kwargs...)
return Makie.FigureAxisPlot(f, ax, p)
end

function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...)
ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
Axis3(f[1, 1])
else
ax = Axis(f[1, 1])
ax.aspect = DataAspect()
ax
end

hidedecorations!(ax)
hidespines!(ax)

p = plot!(ax, path; kwargs...)

# plot colorbars
# TODO configurable `labelsize`
# TODO configurable alignments
Colorbar(
f[1, 2],
get_edge_plot(p);
label = "SIZE",
flipaxis = true,
flip_vertical_label = true,
labelsize = 24,
height = Relative(5 // 6),
scale = log2,
)

Colorbar(
f[1, 0],
get_node_plot(p);
label = "FLOPS",
flipaxis = false,
labelsize = 24,
height = Relative(5 // 6),
scale = log10,
)

return Makie.AxisPlot(ax, p)
end

# TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap
function Makie.plot!(
ax::Union{Axis,Axis3},
path::SizedEinExpr;
colormap = to_colormap(:viridis)[begin:end-10],
inds = false,
kwargs...,
)
handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path)))
graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args])

lin_size = length.(PostOrderDFS(path))[1:end-1]
lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path)))

log_size = log2.(lin_size)
log_flops = log10.(lin_flops)

kwargs = Dict{Symbol,Any}(kwargs)

# configure graphics
get!(kwargs, :edge_width) do
map(log_size ./ maximum(log_size) .* MAX_EDGE_WIDTH) do x
iszero(x) ? 4.0 : x
end
end

get!(kwargs, :arrow_size) do
map(log_size ./ maximum(log_size) .* MAX_ARROW_SIZE) do x
iszero(x) ? 30.0 : x
end
function __init__()
try
Base.require(Main, :GraphMakie)
catch
@warn """Package GraphMakie not found in current path. It is needed to plot `EinExpr`s with `Makie`.
- Run `import Pkg; Pkg.add(\"GraphMakie\")` or `]add GraphMakie` to install the GraphMakie package, then restart julia.
"""
end

get!(() -> log_flops ./ maximum(log_flops) .* MAX_NODE_SIZE, kwargs, :node_size)

get!(kwargs, :edge_color, lin_size)
get!(kwargs, :node_color, lin_flops)

get!(
kwargs,
:arrow_attr,
(
colorrange = extrema(lin_size),
colormap = colormap,
colorscale = log2,
highclip = Makie.Automatic(),
lowclip = Makie.Automatic(),
),
)
get!(
kwargs,
:edge_attr,
(
colorrange = extrema(lin_size),
colormap = colormap,
colorscale = log2,
highclip = Makie.Automatic(),
lowclip = Makie.Automatic(),
),
)
# TODO replace `to_colormap(:plasma)[begin:end-50]), kwargs...)` with a custom colormap
get!(
kwargs,
:node_attr,
(
colorrange = extrema(lin_flops),
colormap = to_colormap(:plasma)[begin:end-50],
colorscale = log10,
highclip = Makie.Automatic(),
lowclip = Makie.Automatic(),
),
)

# configure labels
inds == true && get!(() -> join.(head.(PostOrderDFS(path)))[1:end-1], kwargs, :elabels)
get!(() -> repeat([:black], ne(graph)), kwargs, :elabels_color)
get!(() -> log_size ./ maximum(log_size) .* 5 .+ 12, kwargs, :elabels_textsize)

# plot graph
graphplot!(ax, graph; kwargs...)
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 comments on commit a29e672

Please sign in to comment.