Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xz/luxor tensor plot #44

Merged
merged 9 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"

[weakdeps]
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"

[extensions]
KaHyParExt = ["KaHyPar"]
LuxorTensorPlot = ["LuxorGraphPlot"]

[compat]
AbstractTrees = "0.3, 0.4"
JSON = "0.21"
KaHyPar = "0.3"
StatsBase = "0.34"
Suppressor = "0.2"
LuxorGraphPlot = "0.5.1"
julia = "1.9"

[extras]
Expand All @@ -30,6 +33,7 @@ OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
LuxorGraphPlot = "1f49bdf2-22a7-4bc4-978b-948dc219fbbc"

[targets]
test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar"]
test = ["Test", "Random", "Graphs", "TropicalNumbers", "OMEinsum", "KaHyPar", "LuxorGraphPlot"]
46 changes: 46 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,52 @@ SlicedEinsum{Char, NestedEinsum{DynamicEinCode{Char}}}(Char[], ki, ki ->
)
```

## Extensions

### LuxorTensorPlot

`LuxorTensorPlot` is an extension of the `OMEinsumContractionOrders` package that provides a visualization of the contraction order. It is designed to work with the `OMEinsumContractionOrders` package. To use `LuxorTensorPlot`, please follow these steps:
```julia
pkg> add OMEinsumContractionOrders, LuxorGraphPlot

julia> using OMEinsumContractionOrders, LuxorGraphPlot
```
and then the extension will be loaded automatically.

The extension provides the following to function, `viz_eins` and `viz_contraction`, where the former will plot the tensor network as a graph, and the latter will generate a video or gif of the contraction process.
Here is an example:
```julia
julia> using OMEinsumContractionOrders, LuxorGraphPlot

julia> eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])
ab, acd, bcef, e, df -> a

julia> viz_eins(eincode, filename = "eins.png")

julia> nested_eins = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
ab, ab -> a
├─ ab
└─ acf, bcf -> ab
├─ acd, df -> acf
│ ├─ acd
│ └─ df
└─ bcef, e -> bcf
├─ bcef
└─ e


julia> viz_contraction(nested_eins)
[ Info: Generating frames, 5 frames in total
[ Info: Creating video at: ./contraction.mp4
"./contraction.mp4"
```

The resulting image and video will be saved in the current working directory, and the image is shown below:
<div style="text-align:center">
<img src="examples/eins.png" alt="Image" width="40%" />
</div>
The large white nodes represent the tensors, and the small colored nodes represent the indices, red for closed indices and green for open indices.

## References

If you find this package useful in your research, please cite the *relevant* papers in [CITATION.bib](CITATION.bib).
Expand Down
Binary file added examples/eins.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions examples/visualization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using OMEinsumContractionOrders, LuxorGraphPlot

eincode = OMEinsumContractionOrders.EinCode([['a', 'b'], ['a', 'c', 'd'], ['b', 'c', 'e', 'f'], ['e'], ['d', 'f']], ['a'])

viz_eins(eincode, filename = "eins.png")

nested_eins = optimize_code(eincode, uniformsize(eincode, 2), GreedyMethod())
viz_contraction(nested_eins)
5 changes: 5 additions & 0 deletions ext/LuxorTensorPlot.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module LuxorTensorPlot

include("LuxorTensorPlot/src/LuxorTensorPlot.jl")

end
14 changes: 14 additions & 0 deletions ext/LuxorTensorPlot/src/LuxorTensorPlot.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using OMEinsumContractionOrders, LuxorGraphPlot

using OMEinsumContractionOrders.SparseArrays
using LuxorGraphPlot.Graphs
using LuxorGraphPlot.Luxor
using LuxorGraphPlot.Luxor.FFMPEG

using OMEinsumContractionOrders: AbstractEinsum, NestedEinsum, SlicedEinsum
using OMEinsumContractionOrders: getixsv, getiyv
using OMEinsumContractionOrders: ein2hypergraph, ein2elimination

include("hypergraph.jl")
include("viz_eins.jl")
include("viz_contraction.jl")
71 changes: 71 additions & 0 deletions ext/LuxorTensorPlot/src/hypergraph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
struct LabeledHyperGraph{TS, TV, TE}
adjacency_matrix::SparseMatrixCSC{TS}
vertex_labels::Vector{TV}
edge_labels::Vector{TE}
open_edges::Vector{TE}

function LabeledHyperGraph(adjacency_matrix::SparseMatrixCSC{TS}; vl::Vector{TV} = [1:size(adjacency_matrix, 1)...], el::Vector{TE} = [1:size(adjacency_matrix, 2)...], oe::Vector = []) where{TS, TV, TE}
if size(adjacency_matrix, 1) != length(vl)
throw(ArgumentError("Number of vertices does not match number of vertex labels"))
end
if size(adjacency_matrix, 2) != length(el)
throw(ArgumentError("Number of edges does not match number of edge labels"))
end
if !all(oei in el for oei in oe)
throw(ArgumentError("Open edges must be in edge labels"))
end
if isempty(oe)
oe = Vector{TE}()
end
new{TS, TV, TE}(adjacency_matrix, vl, el, oe)
end
end

Base.show(io::IO, g::LabeledHyperGraph{TS, TV, TE}) where{TS,TV,TE} = print(io, "LabeledHyperGraph{$TS, $TV, $TE} \n adjacency_mat: $(g.adjacency_matrix) \n vertex: $(g.vertex_labels) \n edges: $(g.edge_labels)) \n open_edges: $(g.open_edges)")

Base.:(==)(a::LabeledHyperGraph, b::LabeledHyperGraph) = a.adjacency_matrix == b.adjacency_matrix && a.vertex_labels == b.vertex_labels && a.edge_labels == b.edge_labels && a.open_edges == b.open_edges

struct TensorNetworkGraph{TT, TI}
graph::SimpleGraph
tensors_labels::Dict{Int, TT}
indices_labels::Dict{Int, TI}
open_indices::Vector{TI}

function TensorNetworkGraph(graph::SimpleGraph; tl::Dict{Int, TT} = Dict{Int, Int}(), il::Dict{Int, TI} = Dict{Int, Int}(), oi::Vector = []) where{TT, TI}
if length(tl) + length(il) != nv(graph)
throw(ArgumentError("Number of tensors + indices does not match number of vertices"))
end
if !all(oii in values(il) for oii in oi)
throw(ArgumentError("Open indices must be in indices"))
end
if isempty(oi)
oi = Vector{TI}()
end
new{TT, TI}(graph, tl, il, oi)
end
end

Base.show(io::IO, g::TensorNetworkGraph{TT, TI}) where{TT, TI} = print(io, "TensorNetworkGraph{$TT, $TI} \n graph: {$(nv(g.graph)), $(ne(g.graph))} \n tensors: $(g.tensors_labels) \n indices: $(g.indices_labels)) \n open_indices: $(g.open_indices)")

# convert the labeled hypergraph to a tensor network graph, where vertices and edges of the hypergraph are mapped as the vertices of the tensor network graph, and the open edges are recorded.
function TensorNetworkGraph(lhg::LabeledHyperGraph{TS, TV, TE}) where{TS, TV, TE}
graph = SimpleGraph(length(lhg.vertex_labels) + length(lhg.edge_labels))
tensors_labels = Dict{Int, TV}()
indices_labels = Dict{Int, TE}()

lv = length(lhg.vertex_labels)
for i in 1:length(lhg.vertex_labels)
tensors_labels[i] = lhg.vertex_labels[i]
end
for i in 1:length(lhg.edge_labels)
indices_labels[i + lv] = lhg.edge_labels[i]
end

for i in 1:size(lhg.adjacency_matrix, 1)
for j in findall(!iszero, lhg.adjacency_matrix[i, :])
add_edge!(graph, i, j + lv)
end
end

TensorNetworkGraph(graph, tl=tensors_labels, il=indices_labels, oi=lhg.open_edges)
end
138 changes: 138 additions & 0 deletions ext/LuxorTensorPlot/src/viz_contraction.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
function OMEinsumContractionOrders.ein2elimination(ein::NestedEinsum{T}) where{T}
elimination_order = Vector{T}()
_ein2elimination!(ein, elimination_order)
return elimination_order
end

function OMEinsumContractionOrders.ein2elimination(ein::SlicedEinsum{T, NestedEinsum{T}}) where{T}
elimination_order = Vector{T}()
_ein2elimination!(ein.eins, elimination_order)
# the slicing indices are eliminated at the end
return vcat(elimination_order, ein.slicing)
end

function _ein2elimination!(ein::NestedEinsum{T}, elimination_order::Vector{T}) where{T}
if ein.tensorindex == -1
for arg in ein.args
_ein2elimination!(arg, elimination_order)
end
iy = unique(vcat(getiyv(ein.eins)...))
for ix in unique(vcat(getixsv(ein.eins)...))
if !(ix in iy) && !(ix in elimination_order)
push!(elimination_order, ix)
end
end
end
return elimination_order
end

function elimination_frame(GViz, tng::TensorNetworkGraph{TG, TL}, elimination_order::Vector{TL}, i::Int; filename = nothing, color = (0.5, 0.5, 0.5, 0.5)) where{TG, TL}
GViz2 = deepcopy(GViz)
for j in 1:i
id = _get_key(tng.indices_labels, elimination_order[j])
GViz2.vertex_colors[id] = color
end
return show_graph(GViz2, filename = filename)
end

function OMEinsumContractionOrders.viz_contraction(ein::T, args...; kwargs...) where{T <: AbstractEinsum}
throw(ArgumentError("Only NestedEinsum and SlicedEinsum{T, NestedEinsum{T}} have contraction order"))
end

"""
viz_contraction(ein::ET; locs=StressLayout(), framerate=30, filename="contraction", pathname=".", create_gif=false, create_video=true, color=(0.5, 0.5, 0.5, 0.5), show_progress=false) where {ET <: Union{NestedEinsum, SlicedEinsum}}

Visualize the contraction process of a tensor network.

# Arguments
- `ein::ET`: The tensor network to visualize.
- `locs`: The layout algorithm to use for positioning the nodes in the graph. Default is `StressLayout()`.
- `framerate`: The frame rate of the animation. Default is 30.
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
- `filename`: The base name of the output files. Default is "contraction".
- `pathname`: The directory path to save the output files. Default is the current directory.
- `create_gif`: Whether to create a GIF animation. Default is `false`.
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
- `create_video`: Whether to create a video. Default is `true`.
- `color`: The color of the contraction lines. Default is `(0.5, 0.5, 0.5, 0.5)`.
- `show_progress`: Whether to show progress information. Default is `false`.
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved

# Returns
- If `create_gif` is `true`, returns the path to the generated GIF animation.
- If `create_video` is `true`, returns the path to the generated video.
"""
function OMEinsumContractionOrders.viz_contraction(
ein::ET;
locs=StressLayout(),
framerate = 30,
filename = "contraction",
pathname = ".",
create_gif = false,
create_video = true,
color = (0.5, 0.5, 0.5, 0.5),
show_progress::Bool = false
) where{ET <: Union{NestedEinsum, SlicedEinsum}}

elimination_order = ein2elimination(ein)
tng = TensorNetworkGraph(ein2hypergraph(ein))
GViz = GraphViz(tng, locs)

tempdirectory = mktempdir()
# @info("Frames for animation \"$(filename)\" are being stored in directory: \n\t $(tempdirectory)")
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved

filecounter = 1
le = length(elimination_order)
@info "Generating frames, $(le + 1) frames in total"
for i in 0:le
if show_progress
@info "Frame $(i + 1) of $(le + 1)"
end
fig_name = "$(tempdirectory)/$(lpad(filecounter, 10, "0")).png"
elimination_frame(GViz, tng, elimination_order, i; filename = fig_name, color = color)
filecounter += 1
end

if create_gif
Luxor.FFMPEG.exe(`-loglevel panic -r $(framerate) -f image2 -i $(tempdirectory)/%10d.png -filter_complex "[0:v] split [a][b]; [a] palettegen=stats_mode=full:reserve_transparent=on:transparency_color=FFFFFF [p]; [b][p] paletteuse=new=1:alpha_threshold=128" -y $(tempdirectory)/$(filename).gif`)

if !isempty(pathname)
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
if !isdir(pathname)
@error "$pathname is not a directory."
end
fig_path = joinpath(pathname, "$filename.gif")
mv("$(tempdirectory)/$(filename).gif", fig_path, force = true)
@info("GIF is: $fig_path")
giffn = fig_path
else
@info("GIF is: $(tempdirectory)/$(filename).gif")
giffn = tempdirectory * "/" * filename * ".gif"
end

return giffn
elseif create_video
movieformat = ".mp4"

if !isempty(pathname)
if !isdir(pathname)
@error "$pathname is not a directory."
end
pathname = joinpath(pathname, "$(filename)$(movieformat)")
else
pathname = joinpath("$(tempdirectory)", "$(filename)$(movieformat)")
end

@info "Creating video at: $pathname"
FFMPEG.ffmpeg_exe(`
-loglevel panic
-r $(framerate)
-f image2
-i $(tempdirectory)/%10d.png
-c:v libx264
-vf "pad=ceil(iw/2)*2:ceil(ih/2)*2"
-r $(framerate)
-pix_fmt yuv420p
-y $(pathname)`)

return pathname
else
return tempdirectory
end
end
Loading
Loading