diff --git a/ext/EinExprsGraphMakieExt.jl b/ext/EinExprsGraphMakieExt.jl index 482b864..99234a1 100644 --- a/ext/EinExprsGraphMakieExt.jl +++ b/ext/EinExprsGraphMakieExt.jl @@ -15,13 +15,13 @@ const MAX_EDGE_WIDTH = 10.0 const MAX_ARROW_SIZE = 35.0 const MAX_NODE_SIZE = 40.0 -function Makie.plot(path::EinExpr; kwargs...) +function Makie.plot(path::SizedEinExpr; kwargs...) f = Figure() ax, p = plot!(f[1, 1], path; kwargs...) return Makie.FigureAxisPlot(f, ax, p) end -function Makie.plot!(f::Union{Figure,GridPosition}, path::EinExpr; kwargs...) +function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...) ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3 Axis3(f[1, 1]) else @@ -65,13 +65,13 @@ end # TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap function Makie.plot!( ax::Union{Axis,Axis3}, - path::EinExpr; + path::SizedEinExpr; colormap = to_colormap(:viridis)[begin:end-10], inds = false, kwargs..., ) - handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path))) - graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path) for from in to.args]) + handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path))) + graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args]) lin_size = length.(PostOrderDFS(path))[1:end-1] lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path))) diff --git a/test/Project.toml b/test/Project.toml index 93713ac..a8d1523 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"