diff --git a/Project.toml b/Project.toml
index 3ac8010..407d7d9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,9 +13,11 @@ TreeWidthSolver = "7d267fc5-9ace-409f-a54c-cd2374872a55"
[weakdeps]
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
+LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"
[extensions]
KaHyParExt = ["KaHyPar"]
+LuxorTensorPlot = ["LuxorGraphPlot"]
[compat]
AbstractTrees = "0.3, 0.4"
@@ -23,6 +25,7 @@ JSON = "0.21"
KaHyPar = "0.3"
StatsBase = "0.34"
Suppressor = "0.2"
+LuxorGraphPlot = "0.5.1"
TreeWidthSolver = "0.1.0"
julia = "1.9"
@@ -32,6 +35,7 @@ OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
+LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"
[targets]
-test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar"]
+test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar", "LuxorGraphPlot"]
diff --git a/README.md b/README.md
index 1b51d51..dfef18c 100644
--- a/README.md
+++ b/README.md
@@ -87,6 +87,52 @@ SlicedEinsum{Char, NestedEinsum{DynamicEinCode{Char}}}(Char[], ki, ki ->
)
```
+## Extensions
+
+### LuxorTensorPlot
+
+`LuxorTensorPlot` is an extension of the `OMEinsumContractionOrders` package that provides a visualization of the contraction order. It is designed to work with the `OMEinsumContractionOrders` package. To use `LuxorTensorPlot`, please follow these steps:
+```julia
+pkg> add OMEinsumContractionOrders, LuxorGraphPlot
+
+julia> using OMEinsumContractionOrders, LuxorGraphPlot
+```
+and then the extension will be loaded automatically.
+
+The extension provides the following to function, `viz_eins` and `viz_contraction`, where the former will plot the tensor network as a graph, and the latter will generate a video or gif of the contraction process.
+Here is an example:
+```julia
+julia> using OMEinsumContractionOrders, LuxorGraphPlot
+
+julia> eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
+ab, acd, bcef, e, df -> a
+
+julia> viz_eins(eincode, filename = "eins.png")
+
+julia> nested_eins = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
+ab, ab -> a
+├─ ab
+└─ acf, bcf -> ab
+ ├─ acd, df -> acf
+ │ ├─ acd
+ │ └─ df
+ └─ bcef, e -> bcf
+ ├─ bcef
+ └─ e
+
+
+julia> viz_contraction(nested_code)
+[ Info: Generating frames, 7 frames in total
+[ Info: Creating video at: /var/folders/3y/xl2h1bxj4ql27p01nl5hrrnc0000gn/T/jl_SiSvrH/contraction.mp4
+"/var/folders/3y/xl2h1bxj4ql27p01nl5hrrnc0000gn/T/jl_SiSvrH/contraction.mp4"
+```
+
+The resulting image and video will be saved in the current working directory, and the image is shown below:
+
+
+
+The large white nodes represent the tensors, and the small colored nodes represent the indices, red for closed indices and green for open indices.
+
## References
If you find this package useful in your research, please cite the *relevant* papers in [CITATION.bib](CITATION.bib).
diff --git a/examples/eins.png b/examples/eins.png
new file mode 100644
index 0000000..a02363f
Binary files /dev/null and b/examples/eins.png differ
diff --git a/examples/visualization.jl b/examples/visualization.jl
new file mode 100644
index 0000000..d8b37d7
--- /dev/null
+++ b/examples/visualization.jl
@@ -0,0 +1,8 @@
+using OMEinsumContractionOrders, LuxorGraphPlot
+
+eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
+
+viz_eins(eincode, filename = "eins.png")
+
+nested_eins = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
+viz_contraction(nested_eins)
\ No newline at end of file
diff --git a/ext/LuxorTensorPlot.jl b/ext/LuxorTensorPlot.jl
new file mode 100644
index 0000000..0d46624
--- /dev/null
+++ b/ext/LuxorTensorPlot.jl
@@ -0,0 +1,5 @@
+module LuxorTensorPlot
+
+include("LuxorTensorPlot/src/LuxorTensorPlot.jl")
+
+end
\ No newline at end of file
diff --git a/ext/LuxorTensorPlot/src/LuxorTensorPlot.jl b/ext/LuxorTensorPlot/src/LuxorTensorPlot.jl
new file mode 100644
index 0000000..09961e1
--- /dev/null
+++ b/ext/LuxorTensorPlot/src/LuxorTensorPlot.jl
@@ -0,0 +1,14 @@
+using OMEinsumContractionOrders, LuxorGraphPlot
+
+using OMEinsumContractionOrders.SparseArrays
+using LuxorGraphPlot.Graphs
+using LuxorGraphPlot.Luxor
+using LuxorGraphPlot.Luxor.FFMPEG
+
+using OMEinsumContractionOrders: AbstractEinsum, NestedEinsum, SlicedEinsum
+using OMEinsumContractionOrders: getixsv, getiyv
+using OMEinsumContractionOrders: ein2hypergraph, ein2elimination
+
+include("hypergraph.jl")
+include("viz_eins.jl")
+include("viz_contraction.jl")
\ No newline at end of file
diff --git a/ext/LuxorTensorPlot/src/hypergraph.jl b/ext/LuxorTensorPlot/src/hypergraph.jl
new file mode 100644
index 0000000..3cfe87d
--- /dev/null
+++ b/ext/LuxorTensorPlot/src/hypergraph.jl
@@ -0,0 +1,71 @@
+struct LabeledHyperGraph{TS, TV, TE}
+ adjacency_matrix::SparseMatrixCSC{TS}
+ vertex_labels::Vector{TV}
+ edge_labels::Vector{TE}
+ open_edges::Vector{TE}
+
+ function LabeledHyperGraph(adjacency_matrix::SparseMatrixCSC{TS}; vl::Vector{TV} = [1:size(adjacency_matrix, 1)...], el::Vector{TE} = [1:size(adjacency_matrix, 2)...], oe::Vector = []) where{TS, TV, TE}
+ if size(adjacency_matrix, 1) != length(vl)
+ throw(ArgumentError("Number of vertices does not match number of vertex labels"))
+ end
+ if size(adjacency_matrix, 2) != length(el)
+ throw(ArgumentError("Number of edges does not match number of edge labels"))
+ end
+ if !all(oei in el for oei in oe)
+ throw(ArgumentError("Open edges must be in edge labels"))
+ end
+ if isempty(oe)
+ oe = Vector{TE}()
+ end
+ new{TS, TV, TE}(adjacency_matrix, vl, el, oe)
+ end
+end
+
+Base.show(io::IO, g::LabeledHyperGraph{TS, TV, TE}) where{TS,TV,TE} = print(io, "LabeledHyperGraph{$TS, $TV, $TE} \n adjacency_mat: $(g.adjacency_matrix) \n vertex: $(g.vertex_labels) \n edges: $(g.edge_labels)) \n open_edges: $(g.open_edges)")
+
+Base.:(==)(a::LabeledHyperGraph, b::LabeledHyperGraph) = a.adjacency_matrix == b.adjacency_matrix && a.vertex_labels == b.vertex_labels && a.edge_labels == b.edge_labels && a.open_edges == b.open_edges
+
+struct TensorNetworkGraph{TT, TI}
+ graph::SimpleGraph
+ tensors_labels::Dict{Int, TT}
+ indices_labels::Dict{Int, TI}
+ open_indices::Vector{TI}
+
+ function TensorNetworkGraph(graph::SimpleGraph; tl::Dict{Int, TT} = Dict{Int, Int}(), il::Dict{Int, TI} = Dict{Int, Int}(), oi::Vector = []) where{TT, TI}
+ if length(tl) + length(il) != nv(graph)
+ throw(ArgumentError("Number of tensors + indices does not match number of vertices"))
+ end
+ if !all(oii in values(il) for oii in oi)
+ throw(ArgumentError("Open indices must be in indices"))
+ end
+ if isempty(oi)
+ oi = Vector{TI}()
+ end
+ new{TT, TI}(graph, tl, il, oi)
+ end
+end
+
+Base.show(io::IO, g::TensorNetworkGraph{TT, TI}) where{TT, TI} = print(io, "TensorNetworkGraph{$TT, $TI} \n graph: {$(nv(g.graph)), $(ne(g.graph))} \n tensors: $(g.tensors_labels) \n indices: $(g.indices_labels)) \n open_indices: $(g.open_indices)")
+
+# convert the labeled hypergraph to a tensor network graph, where vertices and edges of the hypergraph are mapped as the vertices of the tensor network graph, and the open edges are recorded.
+function TensorNetworkGraph(lhg::LabeledHyperGraph{TS, TV, TE}) where{TS, TV, TE}
+ graph = SimpleGraph(length(lhg.vertex_labels) + length(lhg.edge_labels))
+ tensors_labels = Dict{Int, TV}()
+ indices_labels = Dict{Int, TE}()
+
+ lv = length(lhg.vertex_labels)
+ for i in 1:length(lhg.vertex_labels)
+ tensors_labels[i] = lhg.vertex_labels[i]
+ end
+ for i in 1:length(lhg.edge_labels)
+ indices_labels[i + lv] = lhg.edge_labels[i]
+ end
+
+ for i in 1:size(lhg.adjacency_matrix, 1)
+ for j in findall(!iszero, lhg.adjacency_matrix[i, :])
+ add_edge!(graph, i, j + lv)
+ end
+ end
+
+ TensorNetworkGraph(graph, tl=tensors_labels, il=indices_labels, oi=lhg.open_edges)
+end
\ No newline at end of file
diff --git a/ext/LuxorTensorPlot/src/viz_contraction.jl b/ext/LuxorTensorPlot/src/viz_contraction.jl
new file mode 100644
index 0000000..12af417
--- /dev/null
+++ b/ext/LuxorTensorPlot/src/viz_contraction.jl
@@ -0,0 +1,98 @@
+function OMEinsumContractionOrders.ein2elimination(code::NestedEinsum{T}) where{T}
+ elimination_order = Vector{T}()
+ _ein2elimination!(code, elimination_order)
+ return elimination_order
+end
+
+function OMEinsumContractionOrders.ein2elimination(code::SlicedEinsum{T, NestedEinsum{T}}) where{T}
+ elimination_order = Vector{T}()
+ _ein2elimination!(code.eins, elimination_order)
+ # the slicing indices are eliminated at the end
+ return vcat(elimination_order, code.slicing)
+end
+
+function _ein2elimination!(code::NestedEinsum{T}, elimination_order::Vector{T}) where{T}
+ if code.tensorindex == -1
+ for arg in code.args
+ _ein2elimination!(arg, elimination_order)
+ end
+ iy = unique(vcat(getiyv(code.eins)...))
+ for ix in unique(vcat(getixsv(code.eins)...))
+ if !(ix in iy) && !(ix in elimination_order)
+ push!(elimination_order, ix)
+ end
+ end
+ end
+ return elimination_order
+end
+
+function elimination_frame(gviz::GraphViz, tng::TensorNetworkGraph{TG, TL}, elimination_order::Vector{TL}, i::Int; filename = nothing) where{TG, TL}
+ gviz2 = deepcopy(gviz)
+ for j in 1:i
+ id = _get_key(tng.indices_labels, elimination_order[j])
+ gviz2.vertex_colors[id] = (0.5, 0.5, 0.5, 0.5)
+ end
+ return show_graph(gviz2, filename = filename)
+end
+
+function OMEinsumContractionOrders.viz_contraction(code::T, args...; kwargs...) where{T <: AbstractEinsum}
+ throw(ArgumentError("Only NestedEinsum and SlicedEinsum{T, NestedEinsum{T}} have contraction order"))
+end
+
+"""
+ viz_contraction(code::Union{NestedEinsum, SlicedEinsum}; locs=StressLayout(), framerate=10, filename=tempname() * ".mp4", show_progress=true)
+
+Visualize the contraction process of a tensor network.
+
+### Arguments
+- `code`: The tensor network to visualize.
+
+### Keyword Arguments
+- `locs`: The coordinates or layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`.
+- `framerate`: The frame rate of the animation. Default is `10`.
+- `filename`: The name of the output file, with `.gif` or `.mp4` extension. Default is a temporary file with `.mp4` extension.
+- `show_progress`: Whether to show progress information. Default is `true`.
+
+# Returns
+- the path of the generated file.
+"""
+function OMEinsumContractionOrders.viz_contraction(
+ code::Union{NestedEinsum, SlicedEinsum};
+ locs=StressLayout(),
+ framerate = 10,
+ filename::String = tempname() * ".mp4",
+ show_progress::Bool = true)
+
+ # analyze the output format
+ @assert endswith(filename, ".gif") || endswith(filename, ".mp4") "Unsupported file format: $filename, only :gif and :mp4 are supported"
+ tempdirectory = mktempdir()
+
+ # generate the frames
+ elimination_order = ein2elimination(code)
+ tng = TensorNetworkGraph(ein2hypergraph(code))
+ gviz = GraphViz(tng, locs)
+
+ le = length(elimination_order)
+ for i in 0:le
+ show_progress && @info "Frame $(i + 1) of $(le + 1)"
+ fig_name = joinpath(tempdirectory, "$(lpad(i+1, 10, "0")).png")
+ elimination_frame(gviz, tng, elimination_order, i; filename = fig_name)
+ end
+
+ if endswith(filename, ".gif")
+ Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $filename`)
+ else
+ Luxor.FFMPEG.ffmpeg_exe(`
+ -loglevel panic
+ -r $(framerate)
+ -f image2
+ -i $(tempdirectory)/%10d.png
+ -c:v libx264
+ -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2"
+ -r $(framerate)
+ -pix_fmt yuv420p
+ -y $filename`)
+ end
+ show_progress && @info "Generated output at: $filename"
+ return filename
+end
\ No newline at end of file
diff --git a/ext/LuxorTensorPlot/src/viz_eins.jl b/ext/LuxorTensorPlot/src/viz_eins.jl
new file mode 100644
index 0000000..43c6ff6
--- /dev/null
+++ b/ext/LuxorTensorPlot/src/viz_eins.jl
@@ -0,0 +1,82 @@
+function LuxorGraphPlot.GraphViz(tng::TensorNetworkGraph, locs=StressLayout(); highlight::Vector=[], highlight_color = (0.0, 0.0, 255.0, 0.5), kwargs...)
+
+ white = (255.0, 255.0, 255.0, 0.8)
+ black = (0.0, 0.0, 0.0, 1.0)
+ r = (255.0, 0.0, 0.0, 0.8)
+ g = (0.0, 255.0, 0.0, 0.8)
+
+ colors = Vector{typeof(r)}()
+ text = Vector{String}()
+ sizes = Vector{Float64}()
+
+ for i in 1:nv(tng.graph)
+ if i in keys(tng.tensors_labels)
+ push!(colors, white)
+ push!(text, string(tng.tensors_labels[i]))
+ push!(sizes, 20.0)
+ else
+ push!(colors, r)
+ push!(text, string(tng.indices_labels[i]))
+ push!(sizes, 10.0)
+ end
+ end
+
+ for oi in tng.open_indices
+ id = _get_key(tng.indices_labels, oi)
+ colors[id] = g
+ end
+
+ for hl in highlight
+ id = _get_key(tng.indices_labels, hl)
+ colors[id] = highlight_color
+ end
+
+ return GraphViz(tng.graph, locs, texts = text, vertex_colors = colors, vertex_sizes = sizes, kwargs...)
+end
+
+function _get_key(dict::Dict, value)
+ for (key, val) in dict
+ if val == value
+ return key
+ end
+ end
+ @error "Value not found in dictionary"
+end
+
+function OMEinsumContractionOrders.ein2hypergraph(code::T) where{T <: AbstractEinsum}
+ ixs = getixsv(code)
+ iy = getiyv(code)
+
+ edges = unique!([Iterators.flatten(ixs)...])
+ open_edges = [iy[i] for i in 1:length(iy) if iy[i] in edges]
+
+ rows = Int[]
+ cols = Int[]
+ for (i,ix) in enumerate(ixs)
+ push!(rows, map(x->i, ix)...)
+ push!(cols, map(x->findfirst(==(x), edges), ix)...)
+ end
+ adj = sparse(rows, cols, ones(Int, length(rows)))
+
+ return LabeledHyperGraph(adj, el = edges, oe = open_edges)
+end
+
+"""
+ viz_eins(code::AbstractEinsum; locs=StressLayout(), filename = nothing, kwargs...)
+
+Visualizes an `AbstractEinsum` object by creating a tensor network graph and rendering it using GraphViz.
+
+### Arguments
+- `code::AbstractEinsum`: The `AbstractEinsum` object to visualize.
+
+### Keyword Arguments
+- `locs=StressLayout()`: The coordinates or layout algorithm to use for positioning the nodes in the graph.
+- `filename = nothing`: The name of the file to save the visualization to. If `nothing`, the visualization will be displayed on the screen instead of saving to a file.
+- `config = GraphDisplayConfig()`: The configuration for displaying the graph. Please refer to the documentation of [`GraphDisplayConfig`](https://giggleliu.github.io/LuxorGraphPlot.jl/dev/ref/#LuxorGraphPlot.GraphDisplayConfig) for more information.
+- `kwargs...`: Additional keyword arguments to be passed to the [`GraphViz`](https://giggleliu.github.io/LuxorGraphPlot.jl/dev/ref/#LuxorGraphPlot.GraphViz) constructor.
+"""
+function OMEinsumContractionOrders.viz_eins(code::AbstractEinsum; locs=StressLayout(), filename = nothing, config=LuxorTensorPlot.GraphDisplayConfig(), kwargs...)
+ tng = TensorNetworkGraph(ein2hypergraph(code))
+ gviz = GraphViz(tng, locs; kwargs...)
+ return show_graph(gviz; filename, config)
+end
\ No newline at end of file
diff --git a/src/OMEinsumContractionOrders.jl b/src/OMEinsumContractionOrders.jl
index 2bcc141..95bd99b 100644
--- a/src/OMEinsumContractionOrders.jl
+++ b/src/OMEinsumContractionOrders.jl
@@ -19,6 +19,9 @@ export CodeOptimizer, CodeSimplifier,
label_elimination_order
# writejson, readjson are not exported to avoid namespace conflict
+# visiualization tools provided by extension `LuxorTensorPlot`
+export viz_eins, viz_contraction
+
include("Core.jl")
include("utils.jl")
@@ -46,6 +49,9 @@ include("interfaces.jl")
# saveload
include("json.jl")
+# extension for visiualization
+include("visualization.jl")
+
@deprecate timespacereadwrite_complexity(code, size_dict::Dict) (contraction_complexity(code, size_dict)...,)
@deprecate timespace_complexity(code, size_dict::Dict) (contraction_complexity(code, size_dict)...,)[1:2]
diff --git a/src/visualization.jl b/src/visualization.jl
new file mode 100644
index 0000000..bcff1a1
--- /dev/null
+++ b/src/visualization.jl
@@ -0,0 +1,15 @@
+function ein2hypergraph(args...; kwargs...)
+ throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`."))
+end
+
+function ein2elimination(args...; kwargs...)
+ throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`."))
+end
+
+function viz_eins(args...; kwargs...)
+ throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`."))
+end
+
+function viz_contraction(args...; kwargs...)
+ throw(ArgumentError("Extension `LuxorTensorPlot` not loaeded, please load it first by `using LuxorGraphPlot`."))
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 9e9bfcd..ebb3c28 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -36,3 +36,8 @@ end
@testset "json" begin
include("json.jl")
end
+
+# testing the extension `LuxorTensorPlot` for visualization
+@testset "visualization" begin
+ include("visualization.jl")
+end
\ No newline at end of file
diff --git a/test/visualization.jl b/test/visualization.jl
new file mode 100644
index 0000000..eda62aa
--- /dev/null
+++ b/test/visualization.jl
@@ -0,0 +1,108 @@
+using OMEinsum
+using OMEinsumContractionOrders: ein2hypergraph, ein2elimination
+using Test, OMEinsumContractionOrders
+
+# tests before the extension loaded
+@testset "luxor tensor plot dependency check" begin
+ @test_throws ArgumentError begin
+ eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
+ ein2hypergraph(eincode)
+ end
+
+ @test_throws ArgumentError begin
+ eincode = OMEinsum.rawcode(ein"((ij, jk), kl), lm -> im")
+ ein2elimination(eincode)
+ end
+
+ @test_throws ArgumentError begin
+ eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}())
+ viz_eins(eincode)
+ end
+
+ @test_throws ArgumentError begin
+ eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}())
+ nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
+ viz_contraction(nested_code, pathname = "")
+ end
+end
+
+using LuxorGraphPlot
+using LuxorGraphPlot.Luxor
+
+@testset "eincode to hypergraph" begin
+ eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
+ g1 = ein2hypergraph(eincode)
+
+ nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
+ g2 = ein2hypergraph(nested_code)
+
+ sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA(nslices = 1))
+ g3 = ein2hypergraph(sliced_code)
+
+ @test g1 == g2 == g3
+ @test size(g1.adjacency_matrix, 1) == 5
+ @test size(g1.adjacency_matrix, 2) == 6
+end
+
+@testset "eincode to elimination order" begin
+ eincode = OMEinsum.rawcode(ein"((ij, jk), kl), lm -> im")
+ elimination_order = ein2elimination(eincode)
+ @test elimination_order == ['j', 'k', 'l']
+end
+
+@testset "visualize eincode" begin
+ eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}())
+ t = viz_eins(eincode)
+ @test t isa Luxor.Drawing
+
+ nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
+ t = viz_eins(nested_code)
+ @test t isa Luxor.Drawing
+
+ sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA())
+ t = viz_eins(sliced_code)
+ @test t isa Luxor.Drawing
+
+ open_eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
+ t = viz_eins(open_eincode)
+ @test t isa Luxor.Drawing
+
+ # filename and location specified
+ eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}())
+ filename = tempname() * ".png"
+ viz_eins(eincode; filename, locs=vcat([(randn() * 60, 0.0) for i=1:5], [(randn() * 60, 320.0) for i=1:6]))
+ @test isfile(filename)
+end
+
+@testset "visualize contraction" begin
+ eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], Vector{Char}())
+ nested_code = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
+ t_mp4 = viz_contraction(nested_code)
+ tempmp4 = tempname() * ".mp4"
+ tempgif = tempname() * ".gif"
+ t_mp4_2 = viz_contraction(nested_code, filename = tempmp4)
+ @test t_mp4 isa String
+ @test t_mp4_2 isa String
+ t_gif = viz_contraction(nested_code, filename = tempgif)
+ @test t_gif isa String
+
+ @test_throws AssertionError begin
+ viz_contraction(nested_code, filename = "test.avi")
+ end
+
+ sliced_code = optimize_code(eincode, uniformsize(eincode, 2), TreeSA())
+ t_mp4 = viz_contraction(sliced_code)
+ t_mp4_2 = viz_contraction(sliced_code, filename = tempmp4)
+ @test t_mp4 isa String
+ @test t_mp4_2 isa String
+ t_gif = viz_contraction(sliced_code, filename = tempgif)
+ @test t_gif isa String
+
+ sliced_code2 = optimize_code(eincode, uniformsize(eincode, 2), TreeSA(nslices = 1))
+ t_mp4 = viz_contraction(sliced_code2)
+ t_mp4_2 = viz_contraction(sliced_code2, filename = tempmp4)
+ @test t_mp4 isa String
+ @test t_mp4_2 isa String
+ t_gif = viz_contraction(sliced_code2, filename = tempgif)
+ @test t_gif isa String
+end
\ No newline at end of file