Skip to content

Commit

Permalink
Xz/luxor tensor plot (#44)
Browse files Browse the repository at this point in the history
* add LuxorTensorPlot as extension

* update readme for the viz toll

* add tests for dependency check

* revise error message

* update the interface of viz contraction

* update readme

* simplify

* fix file path

---------

Co-authored-by: GiggleLiu <[email protected]>
  • Loading branch information
ArrogantGao and GiggleLiu authored Aug 4, 2024
1 parent 2be8ee9 commit c491c55
Show file tree
Hide file tree
Showing 13 changed files with 463 additions and 1 deletion.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ TreeWidthSolver = "7d267fc5-9ace-409f-a54c-cd2374872a55"

[weakdeps]
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"

[extensions]
KaHyParExt = ["KaHyPar"]
LuxorTensorPlot = ["LuxorGraphPlot"]

[compat]
AbstractTrees = "0.3, 0.4"
JSON = "0.21"
KaHyPar = "0.3"
StatsBase = "0.34"
Suppressor = "0.2"
LuxorGraphPlot = "0.5.1"
TreeWidthSolver = "0.1.0"
julia = "1.9"

Expand All @@ -32,6 +35,7 @@ OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"

[targets]
test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar"]
test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar", "LuxorGraphPlot"]
46 changes: 46 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,52 @@ SlicedEinsum{Char, NestedEinsum{DynamicEinCode{Char}}}(Char[], ki, ki ->
)
```

## Extensions

### LuxorTensorPlot

`LuxorTensorPlot` is an extension of the `OMEinsumContractionOrders` package that provides a visualization of the contraction order. It is designed to work with the `OMEinsumContractionOrders` package. To use `LuxorTensorPlot`, please follow these steps:
```julia
pkg> add OMEinsumContractionOrders, LuxorGraphPlot

julia> using OMEinsumContractionOrders, LuxorGraphPlot
```
and then the extension will be loaded automatically.

The extension provides the following to function, `viz_eins` and `viz_contraction`, where the former will plot the tensor network as a graph, and the latter will generate a video or gif of the contraction process.
Here is an example:
```julia
julia> using OMEinsumContractionOrders, LuxorGraphPlot

julia> eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
ab, acd, bcef, e, df -> a

julia> viz_eins(eincode, filename = "eins.png")

julia> nested_eins = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
ab, ab -> a
├─ ab
└─ acf, bcf -> ab
├─ acd, df -> acf
│ ├─ acd
│ └─ df
└─ bcef, e -> bcf
├─ bcef
└─ e


julia> viz_contraction(nested_code)
[ Info: Generating frames, 7 frames in total
[ Info: Creating video at: /var/folders/3y/xl2h1bxj4ql27p01nl5hrrnc0000gn/T/jl_SiSvrH/contraction.mp4
"/var/folders/3y/xl2h1bxj4ql27p01nl5hrrnc0000gn/T/jl_SiSvrH/contraction.mp4"
```
The resulting image and video will be saved in the current working directory, and the image is shown below:
<div style="text-align:center">
<img src="examples/eins.png" alt="Image" width="40%" />
</div>
The large white nodes represent the tensors, and the small colored nodes represent the indices, red for closed indices and green for open indices.
## References
If you find this package useful in your research, please cite the *relevant* papers in [CITATION.bib](CITATION.bib).
Expand Down
Binary file added examples/eins.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions examples/visualization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using OMEinsumContractionOrders, LuxorGraphPlot

eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])

viz_eins(eincode, filename = "eins.png")

nested_eins = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
viz_contraction(nested_eins)
5 changes: 5 additions & 0 deletions ext/LuxorTensorPlot.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module LuxorTensorPlot

include("LuxorTensorPlot/src/LuxorTensorPlot.jl")

end
14 changes: 14 additions & 0 deletions ext/LuxorTensorPlot/src/LuxorTensorPlot.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using OMEinsumContractionOrders, LuxorGraphPlot

using OMEinsumContractionOrders.SparseArrays
using LuxorGraphPlot.Graphs
using LuxorGraphPlot.Luxor
using LuxorGraphPlot.Luxor.FFMPEG

using OMEinsumContractionOrders: AbstractEinsum, NestedEinsum, SlicedEinsum
using OMEinsumContractionOrders: getixsv, getiyv
using OMEinsumContractionOrders: ein2hypergraph, ein2elimination

include("hypergraph.jl")
include("viz_eins.jl")
include("viz_contraction.jl")
71 changes: 71 additions & 0 deletions ext/LuxorTensorPlot/src/hypergraph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
struct LabeledHyperGraph{TS, TV, TE}
adjacency_matrix::SparseMatrixCSC{TS}
vertex_labels::Vector{TV}
edge_labels::Vector{TE}
open_edges::Vector{TE}

function LabeledHyperGraph(adjacency_matrix::SparseMatrixCSC{TS}; vl::Vector{TV} = [1:size(adjacency_matrix, 1)...], el::Vector{TE} = [1:size(adjacency_matrix, 2)...], oe::Vector = []) where{TS, TV, TE}
if size(adjacency_matrix, 1) != length(vl)
throw(ArgumentError("Number of vertices does not match number of vertex labels"))
end
if size(adjacency_matrix, 2) != length(el)
throw(ArgumentError("Number of edges does not match number of edge labels"))
end
if !all(oei in el for oei in oe)
throw(ArgumentError("Open edges must be in edge labels"))
end
if isempty(oe)
oe = Vector{TE}()
end
new{TS, TV, TE}(adjacency_matrix, vl, el, oe)
end
end

Base.show(io::IO, g::LabeledHyperGraph{TS, TV, TE}) where{TS,TV,TE} = print(io, "LabeledHyperGraph{$TS, $TV, $TE} \n adjacency_mat: $(g.adjacency_matrix) \n vertex: $(g.vertex_labels) \n edges: $(g.edge_labels)) \n open_edges: $(g.open_edges)")

Base.:(==)(a::LabeledHyperGraph, b::LabeledHyperGraph) = a.adjacency_matrix == b.adjacency_matrix && a.vertex_labels == b.vertex_labels && a.edge_labels == b.edge_labels && a.open_edges == b.open_edges

struct TensorNetworkGraph{TT, TI}
graph::SimpleGraph
tensors_labels::Dict{Int, TT}
indices_labels::Dict{Int, TI}
open_indices::Vector{TI}

function TensorNetworkGraph(graph::SimpleGraph; tl::Dict{Int, TT} = Dict{Int, Int}(), il::Dict{Int, TI} = Dict{Int, Int}(), oi::Vector = []) where{TT, TI}
if length(tl) + length(il) != nv(graph)
throw(ArgumentError("Number of tensors + indices does not match number of vertices"))
end
if !all(oii in values(il) for oii in oi)
throw(ArgumentError("Open indices must be in indices"))
end
if isempty(oi)
oi = Vector{TI}()
end
new{TT, TI}(graph, tl, il, oi)
end
end

Base.show(io::IO, g::TensorNetworkGraph{TT, TI}) where{TT, TI} = print(io, "TensorNetworkGraph{$TT, $TI} \n graph: {$(nv(g.graph)), $(ne(g.graph))} \n tensors: $(g.tensors_labels) \n indices: $(g.indices_labels)) \n open_indices: $(g.open_indices)")

# convert the labeled hypergraph to a tensor network graph, where vertices and edges of the hypergraph are mapped as the vertices of the tensor network graph, and the open edges are recorded.
function TensorNetworkGraph(lhg::LabeledHyperGraph{TS, TV, TE}) where{TS, TV, TE}
graph = SimpleGraph(length(lhg.vertex_labels) + length(lhg.edge_labels))
tensors_labels = Dict{Int, TV}()
indices_labels = Dict{Int, TE}()

lv = length(lhg.vertex_labels)
for i in 1:length(lhg.vertex_labels)
tensors_labels[i] = lhg.vertex_labels[i]
end
for i in 1:length(lhg.edge_labels)
indices_labels[i + lv] = lhg.edge_labels[i]
end

for i in 1:size(lhg.adjacency_matrix, 1)
for j in findall(!iszero, lhg.adjacency_matrix[i, :])
add_edge!(graph, i, j + lv)
end
end

TensorNetworkGraph(graph, tl=tensors_labels, il=indices_labels, oi=lhg.open_edges)
end
98 changes: 98 additions & 0 deletions ext/LuxorTensorPlot/src/viz_contraction.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
function OMEinsumContractionOrders.ein2elimination(code::NestedEinsum{T}) where{T}
elimination_order = Vector{T}()
_ein2elimination!(code, elimination_order)
return elimination_order
end

function OMEinsumContractionOrders.ein2elimination(code::SlicedEinsum{T, NestedEinsum{T}}) where{T}
elimination_order = Vector{T}()
_ein2elimination!(code.eins, elimination_order)
# the slicing indices are eliminated at the end
return vcat(elimination_order, code.slicing)
end

function _ein2elimination!(code::NestedEinsum{T}, elimination_order::Vector{T}) where{T}
if code.tensorindex == -1
for arg in code.args
_ein2elimination!(arg, elimination_order)
end
iy = unique(vcat(getiyv(code.eins)...))
for ix in unique(vcat(getixsv(code.eins)...))
if !(ix in iy) && !(ix in elimination_order)
push!(elimination_order, ix)
end
end
end
return elimination_order
end

function elimination_frame(gviz::GraphViz, tng::TensorNetworkGraph{TG, TL}, elimination_order::Vector{TL}, i::Int; filename = nothing) where{TG, TL}
gviz2 = deepcopy(gviz)
for j in 1:i
id = _get_key(tng.indices_labels, elimination_order[j])
gviz2.vertex_colors[id] = (0.5, 0.5, 0.5, 0.5)
end
return show_graph(gviz2, filename = filename)
end

function OMEinsumContractionOrders.viz_contraction(code::T, args...; kwargs...) where{T <: AbstractEinsum}
throw(ArgumentError("Only NestedEinsum and SlicedEinsum{T, NestedEinsum{T}} have contraction order"))
end

"""
viz_contraction(code::Union{NestedEinsum, SlicedEinsum}; locs=StressLayout(), framerate=10, filename=tempname() * ".mp4", show_progress=true)
Visualize the contraction process of a tensor network.
### Arguments
- `code`: The tensor network to visualize.
### Keyword Arguments
- `locs`: The coordinates or layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`.
- `framerate`: The frame rate of the animation. Default is `10`.
- `filename`: The name of the output file, with `.gif` or `.mp4` extension. Default is a temporary file with `.mp4` extension.
- `show_progress`: Whether to show progress information. Default is `true`.
# Returns
- the path of the generated file.
"""
function OMEinsumContractionOrders.viz_contraction(
code::Union{NestedEinsum, SlicedEinsum};
locs=StressLayout(),
framerate = 10,
filename::String = tempname() * ".mp4",
show_progress::Bool = true)

# analyze the output format
@assert endswith(filename, ".gif") || endswith(filename, ".mp4") "Unsupported file format: $filename, only :gif and :mp4 are supported"
tempdirectory = mktempdir()

# generate the frames
elimination_order = ein2elimination(code)
tng = TensorNetworkGraph(ein2hypergraph(code))
gviz = GraphViz(tng, locs)

le = length(elimination_order)
for i in 0:le
show_progress && @info "Frame $(i + 1) of $(le + 1)"
fig_name = joinpath(tempdirectory, "$(lpad(i+1, 10, "0")).png")
elimination_frame(gviz, tng, elimination_order, i; filename = fig_name)
end

if endswith(filename, ".gif")
Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $filename`)
else
Luxor.FFMPEG.ffmpeg_exe(`
-loglevel panic
-r $(framerate)
-f image2
-i $(tempdirectory)/%10d.png
-c:v libx264
-vf "pad=ceil(iw/2)*2:ceil(ih/2)*2"
-r $(framerate)
-pix_fmt yuv420p
-y $filename`)
end
show_progress && @info "Generated output at: $filename"
return filename
end
82 changes: 82 additions & 0 deletions ext/LuxorTensorPlot/src/viz_eins.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
function LuxorGraphPlot.GraphViz(tng::TensorNetworkGraph, locs=StressLayout(); highlight::Vector=[], highlight_color = (0.0, 0.0, 255.0, 0.5), kwargs...)

white = (255.0, 255.0, 255.0, 0.8)
black = (0.0, 0.0, 0.0, 1.0)
r = (255.0, 0.0, 0.0, 0.8)
g = (0.0, 255.0, 0.0, 0.8)

colors = Vector{typeof(r)}()
text = Vector{String}()
sizes = Vector{Float64}()

for i in 1:nv(tng.graph)
if i in keys(tng.tensors_labels)
push!(colors, white)
push!(text, string(tng.tensors_labels[i]))
push!(sizes, 20.0)
else
push!(colors, r)
push!(text, string(tng.indices_labels[i]))
push!(sizes, 10.0)
end
end

for oi in tng.open_indices
id = _get_key(tng.indices_labels, oi)
colors[id] = g
end

for hl in highlight
id = _get_key(tng.indices_labels, hl)
colors[id] = highlight_color
end

return GraphViz(tng.graph, locs, texts = text, vertex_colors = colors, vertex_sizes = sizes, kwargs...)
end

function _get_key(dict::Dict, value)
for (key, val) in dict
if val == value
return key
end
end
@error "Value not found in dictionary"
end

function OMEinsumContractionOrders.ein2hypergraph(code::T) where{T <: AbstractEinsum}
ixs = getixsv(code)
iy = getiyv(code)

edges = unique!([Iterators.flatten(ixs)...])
open_edges = [iy[i] for i in 1:length(iy) if iy[i] in edges]

rows = Int[]
cols = Int[]
for (i,ix) in enumerate(ixs)
push!(rows, map(x->i, ix)...)
push!(cols, map(x->findfirst(==(x), edges), ix)...)
end
adj = sparse(rows, cols, ones(Int, length(rows)))

return LabeledHyperGraph(adj, el = edges, oe = open_edges)
end

"""
viz_eins(code::AbstractEinsum; locs=StressLayout(), filename = nothing, kwargs...)
Visualizes an `AbstractEinsum` object by creating a tensor network graph and rendering it using GraphViz.
### Arguments
- `code::AbstractEinsum`: The `AbstractEinsum` object to visualize.
### Keyword Arguments
- `locs=StressLayout()`: The coordinates or layout algorithm to use for positioning the nodes in the graph.
- `filename = nothing`: The name of the file to save the visualization to. If `nothing`, the visualization will be displayed on the screen instead of saving to a file.
- `config = GraphDisplayConfig()`: The configuration for displaying the graph. Please refer to the documentation of [`GraphDisplayConfig`](https://giggleliu.github.io/LuxorGraphPlot.jl/dev/ref/#LuxorGraphPlot.GraphDisplayConfig) for more information.
- `kwargs...`: Additional keyword arguments to be passed to the [`GraphViz`](https://giggleliu.github.io/LuxorGraphPlot.jl/dev/ref/#LuxorGraphPlot.GraphViz) constructor.
"""
function OMEinsumContractionOrders.viz_eins(code::AbstractEinsum; locs=StressLayout(), filename = nothing, config=LuxorTensorPlot.GraphDisplayConfig(), kwargs...)
tng = TensorNetworkGraph(ein2hypergraph(code))
gviz = GraphViz(tng, locs; kwargs...)
return show_graph(gviz; filename, config)
end
6 changes: 6 additions & 0 deletions src/OMEinsumContractionOrders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ export CodeOptimizer, CodeSimplifier,
label_elimination_order
# writejson, readjson are not exported to avoid namespace conflict

# visiualization tools provided by extension `LuxorTensorPlot`
export viz_eins, viz_contraction

include("Core.jl")
include("utils.jl")

Expand Down Expand Up @@ -46,6 +49,9 @@ include("interfaces.jl")
# saveload
include("json.jl")

# extension for visiualization
include("visualization.jl")

@deprecate timespacereadwrite_complexity(code, size_dict::Dict) (contraction_complexity(code, size_dict)...,)
@deprecate timespace_complexity(code, size_dict::Dict) (contraction_complexity(code, size_dict)...,)[1:2]

Expand Down
Loading

0 comments on commit c491c55

Please sign in to comment.