-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add LuxorTensorPlot as extension * update readme for the viz toll * add tests for dependency check * revise error message * update the interface of viz contraction * update readme * simplify * fix file path --------- Co-authored-by: GiggleLiu <[email protected]>
- Loading branch information
1 parent
2be8ee9
commit c491c55
Showing
13 changed files
with
463 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
using OMEinsumContractionOrders, LuxorGraphPlot | ||
|
||
eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a']) | ||
|
||
viz_eins(eincode, filename = "eins.png") | ||
|
||
nested_eins = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod()) | ||
viz_contraction(nested_eins) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
module LuxorTensorPlot | ||
|
||
include("LuxorTensorPlot/src/LuxorTensorPlot.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
using OMEinsumContractionOrders, LuxorGraphPlot | ||
|
||
using OMEinsumContractionOrders.SparseArrays | ||
using LuxorGraphPlot.Graphs | ||
using LuxorGraphPlot.Luxor | ||
using LuxorGraphPlot.Luxor.FFMPEG | ||
|
||
using OMEinsumContractionOrders: AbstractEinsum, NestedEinsum, SlicedEinsum | ||
using OMEinsumContractionOrders: getixsv, getiyv | ||
using OMEinsumContractionOrders: ein2hypergraph, ein2elimination | ||
|
||
include("hypergraph.jl") | ||
include("viz_eins.jl") | ||
include("viz_contraction.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
struct LabeledHyperGraph{TS, TV, TE} | ||
adjacency_matrix::SparseMatrixCSC{TS} | ||
vertex_labels::Vector{TV} | ||
edge_labels::Vector{TE} | ||
open_edges::Vector{TE} | ||
|
||
function LabeledHyperGraph(adjacency_matrix::SparseMatrixCSC{TS}; vl::Vector{TV} = [1:size(adjacency_matrix, 1)...], el::Vector{TE} = [1:size(adjacency_matrix, 2)...], oe::Vector = []) where{TS, TV, TE} | ||
if size(adjacency_matrix, 1) != length(vl) | ||
throw(ArgumentError("Number of vertices does not match number of vertex labels")) | ||
end | ||
if size(adjacency_matrix, 2) != length(el) | ||
throw(ArgumentError("Number of edges does not match number of edge labels")) | ||
end | ||
if !all(oei in el for oei in oe) | ||
throw(ArgumentError("Open edges must be in edge labels")) | ||
end | ||
if isempty(oe) | ||
oe = Vector{TE}() | ||
end | ||
new{TS, TV, TE}(adjacency_matrix, vl, el, oe) | ||
end | ||
end | ||
|
||
Base.show(io::IO, g::LabeledHyperGraph{TS, TV, TE}) where{TS,TV,TE} = print(io, "LabeledHyperGraph{$TS, $TV, $TE} \n adjacency_mat: $(g.adjacency_matrix) \n vertex: $(g.vertex_labels) \n edges: $(g.edge_labels)) \n open_edges: $(g.open_edges)") | ||
|
||
Base.:(==)(a::LabeledHyperGraph, b::LabeledHyperGraph) = a.adjacency_matrix == b.adjacency_matrix && a.vertex_labels == b.vertex_labels && a.edge_labels == b.edge_labels && a.open_edges == b.open_edges | ||
|
||
struct TensorNetworkGraph{TT, TI} | ||
graph::SimpleGraph | ||
tensors_labels::Dict{Int, TT} | ||
indices_labels::Dict{Int, TI} | ||
open_indices::Vector{TI} | ||
|
||
function TensorNetworkGraph(graph::SimpleGraph; tl::Dict{Int, TT} = Dict{Int, Int}(), il::Dict{Int, TI} = Dict{Int, Int}(), oi::Vector = []) where{TT, TI} | ||
if length(tl) + length(il) != nv(graph) | ||
throw(ArgumentError("Number of tensors + indices does not match number of vertices")) | ||
end | ||
if !all(oii in values(il) for oii in oi) | ||
throw(ArgumentError("Open indices must be in indices")) | ||
end | ||
if isempty(oi) | ||
oi = Vector{TI}() | ||
end | ||
new{TT, TI}(graph, tl, il, oi) | ||
end | ||
end | ||
|
||
Base.show(io::IO, g::TensorNetworkGraph{TT, TI}) where{TT, TI} = print(io, "TensorNetworkGraph{$TT, $TI} \n graph: {$(nv(g.graph)), $(ne(g.graph))} \n tensors: $(g.tensors_labels) \n indices: $(g.indices_labels)) \n open_indices: $(g.open_indices)") | ||
|
||
# convert the labeled hypergraph to a tensor network graph, where vertices and edges of the hypergraph are mapped as the vertices of the tensor network graph, and the open edges are recorded. | ||
function TensorNetworkGraph(lhg::LabeledHyperGraph{TS, TV, TE}) where{TS, TV, TE} | ||
graph = SimpleGraph(length(lhg.vertex_labels) + length(lhg.edge_labels)) | ||
tensors_labels = Dict{Int, TV}() | ||
indices_labels = Dict{Int, TE}() | ||
|
||
lv = length(lhg.vertex_labels) | ||
for i in 1:length(lhg.vertex_labels) | ||
tensors_labels[i] = lhg.vertex_labels[i] | ||
end | ||
for i in 1:length(lhg.edge_labels) | ||
indices_labels[i + lv] = lhg.edge_labels[i] | ||
end | ||
|
||
for i in 1:size(lhg.adjacency_matrix, 1) | ||
for j in findall(!iszero, lhg.adjacency_matrix[i, :]) | ||
add_edge!(graph, i, j + lv) | ||
end | ||
end | ||
|
||
TensorNetworkGraph(graph, tl=tensors_labels, il=indices_labels, oi=lhg.open_edges) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
function OMEinsumContractionOrders.ein2elimination(code::NestedEinsum{T}) where{T} | ||
elimination_order = Vector{T}() | ||
_ein2elimination!(code, elimination_order) | ||
return elimination_order | ||
end | ||
|
||
function OMEinsumContractionOrders.ein2elimination(code::SlicedEinsum{T, NestedEinsum{T}}) where{T} | ||
elimination_order = Vector{T}() | ||
_ein2elimination!(code.eins, elimination_order) | ||
# the slicing indices are eliminated at the end | ||
return vcat(elimination_order, code.slicing) | ||
end | ||
|
||
function _ein2elimination!(code::NestedEinsum{T}, elimination_order::Vector{T}) where{T} | ||
if code.tensorindex == -1 | ||
for arg in code.args | ||
_ein2elimination!(arg, elimination_order) | ||
end | ||
iy = unique(vcat(getiyv(code.eins)...)) | ||
for ix in unique(vcat(getixsv(code.eins)...)) | ||
if !(ix in iy) && !(ix in elimination_order) | ||
push!(elimination_order, ix) | ||
end | ||
end | ||
end | ||
return elimination_order | ||
end | ||
|
||
function elimination_frame(gviz::GraphViz, tng::TensorNetworkGraph{TG, TL}, elimination_order::Vector{TL}, i::Int; filename = nothing) where{TG, TL} | ||
gviz2 = deepcopy(gviz) | ||
for j in 1:i | ||
id = _get_key(tng.indices_labels, elimination_order[j]) | ||
gviz2.vertex_colors[id] = (0.5, 0.5, 0.5, 0.5) | ||
end | ||
return show_graph(gviz2, filename = filename) | ||
end | ||
|
||
function OMEinsumContractionOrders.viz_contraction(code::T, args...; kwargs...) where{T <: AbstractEinsum} | ||
throw(ArgumentError("Only NestedEinsum and SlicedEinsum{T, NestedEinsum{T}} have contraction order")) | ||
end | ||
|
||
""" | ||
viz_contraction(code::Union{NestedEinsum, SlicedEinsum}; locs=StressLayout(), framerate=10, filename=tempname() * ".mp4", show_progress=true) | ||
Visualize the contraction process of a tensor network. | ||
### Arguments | ||
- `code`: The tensor network to visualize. | ||
### Keyword Arguments | ||
- `locs`: The coordinates or layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`. | ||
- `framerate`: The frame rate of the animation. Default is `10`. | ||
- `filename`: The name of the output file, with `.gif` or `.mp4` extension. Default is a temporary file with `.mp4` extension. | ||
- `show_progress`: Whether to show progress information. Default is `true`. | ||
# Returns | ||
- the path of the generated file. | ||
""" | ||
function OMEinsumContractionOrders.viz_contraction( | ||
code::Union{NestedEinsum, SlicedEinsum}; | ||
locs=StressLayout(), | ||
framerate = 10, | ||
filename::String = tempname() * ".mp4", | ||
show_progress::Bool = true) | ||
|
||
# analyze the output format | ||
@assert endswith(filename, ".gif") || endswith(filename, ".mp4") "Unsupported file format: $filename, only :gif and :mp4 are supported" | ||
tempdirectory = mktempdir() | ||
|
||
# generate the frames | ||
elimination_order = ein2elimination(code) | ||
tng = TensorNetworkGraph(ein2hypergraph(code)) | ||
gviz = GraphViz(tng, locs) | ||
|
||
le = length(elimination_order) | ||
for i in 0:le | ||
show_progress && @info "Frame $(i + 1) of $(le + 1)" | ||
fig_name = joinpath(tempdirectory, "$(lpad(i+1, 10, "0")).png") | ||
elimination_frame(gviz, tng, elimination_order, i; filename = fig_name) | ||
end | ||
|
||
if endswith(filename, ".gif") | ||
Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $filename`) | ||
else | ||
Luxor.FFMPEG.ffmpeg_exe(` | ||
-loglevel panic | ||
-r $(framerate) | ||
-f image2 | ||
-i $(tempdirectory)/%10d.png | ||
-c:v libx264 | ||
-vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" | ||
-r $(framerate) | ||
-pix_fmt yuv420p | ||
-y $filename`) | ||
end | ||
show_progress && @info "Generated output at: $filename" | ||
return filename | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
function LuxorGraphPlot.GraphViz(tng::TensorNetworkGraph, locs=StressLayout(); highlight::Vector=[], highlight_color = (0.0, 0.0, 255.0, 0.5), kwargs...) | ||
|
||
white = (255.0, 255.0, 255.0, 0.8) | ||
black = (0.0, 0.0, 0.0, 1.0) | ||
r = (255.0, 0.0, 0.0, 0.8) | ||
g = (0.0, 255.0, 0.0, 0.8) | ||
|
||
colors = Vector{typeof(r)}() | ||
text = Vector{String}() | ||
sizes = Vector{Float64}() | ||
|
||
for i in 1:nv(tng.graph) | ||
if i in keys(tng.tensors_labels) | ||
push!(colors, white) | ||
push!(text, string(tng.tensors_labels[i])) | ||
push!(sizes, 20.0) | ||
else | ||
push!(colors, r) | ||
push!(text, string(tng.indices_labels[i])) | ||
push!(sizes, 10.0) | ||
end | ||
end | ||
|
||
for oi in tng.open_indices | ||
id = _get_key(tng.indices_labels, oi) | ||
colors[id] = g | ||
end | ||
|
||
for hl in highlight | ||
id = _get_key(tng.indices_labels, hl) | ||
colors[id] = highlight_color | ||
end | ||
|
||
return GraphViz(tng.graph, locs, texts = text, vertex_colors = colors, vertex_sizes = sizes, kwargs...) | ||
end | ||
|
||
function _get_key(dict::Dict, value) | ||
for (key, val) in dict | ||
if val == value | ||
return key | ||
end | ||
end | ||
@error "Value not found in dictionary" | ||
end | ||
|
||
function OMEinsumContractionOrders.ein2hypergraph(code::T) where{T <: AbstractEinsum} | ||
ixs = getixsv(code) | ||
iy = getiyv(code) | ||
|
||
edges = unique!([Iterators.flatten(ixs)...]) | ||
open_edges = [iy[i] for i in 1:length(iy) if iy[i] in edges] | ||
|
||
rows = Int[] | ||
cols = Int[] | ||
for (i,ix) in enumerate(ixs) | ||
push!(rows, map(x->i, ix)...) | ||
push!(cols, map(x->findfirst(==(x), edges), ix)...) | ||
end | ||
adj = sparse(rows, cols, ones(Int, length(rows))) | ||
|
||
return LabeledHyperGraph(adj, el = edges, oe = open_edges) | ||
end | ||
|
||
""" | ||
viz_eins(code::AbstractEinsum; locs=StressLayout(), filename = nothing, kwargs...) | ||
Visualizes an `AbstractEinsum` object by creating a tensor network graph and rendering it using GraphViz. | ||
### Arguments | ||
- `code::AbstractEinsum`: The `AbstractEinsum` object to visualize. | ||
### Keyword Arguments | ||
- `locs=StressLayout()`: The coordinates or layout algorithm to use for positioning the nodes in the graph. | ||
- `filename = nothing`: The name of the file to save the visualization to. If `nothing`, the visualization will be displayed on the screen instead of saving to a file. | ||
- `config = GraphDisplayConfig()`: The configuration for displaying the graph. Please refer to the documentation of [`GraphDisplayConfig`](https://giggleliu.github.io/LuxorGraphPlot.jl/dev/ref/#LuxorGraphPlot.GraphDisplayConfig) for more information. | ||
- `kwargs...`: Additional keyword arguments to be passed to the [`GraphViz`](https://giggleliu.github.io/LuxorGraphPlot.jl/dev/ref/#LuxorGraphPlot.GraphViz) constructor. | ||
""" | ||
function OMEinsumContractionOrders.viz_eins(code::AbstractEinsum; locs=StressLayout(), filename = nothing, config=LuxorTensorPlot.GraphDisplayConfig(), kwargs...) | ||
tng = TensorNetworkGraph(ein2hypergraph(code)) | ||
gviz = GraphViz(tng, locs; kwargs...) | ||
return show_graph(gviz; filename, config) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.