From b55d013ab4d2809ec0cc8819a2a5112b86ede214 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 21 Oct 2023 14:01:15 +0200 Subject: [PATCH 01/28] Decouple `size` dictionary from `EinExpr` struct --- src/Counters.jl | 10 +++-- src/EinExpr.jl | 32 ++++----------- src/Optimizers/Exhaustive.jl | 38 +++++++++++------- src/Optimizers/Greedy.jl | 8 ++-- src/Optimizers/Naive.jl | 2 + src/Slicing.jl | 15 +++---- test/Counters_test.jl | 76 ++++++++++++++++++------------------ test/Exhaustive_test.jl | 28 +++++++------ test/Greedy_test.jl | 36 ++++++++--------- test/Naive_test.jl | 30 +++++++------- test/Slicing_test.jl | 36 +++++------------ 11 files changed, 148 insertions(+), 163 deletions(-) diff --git a/src/Counters.jl b/src/Counters.jl index 1ec72fd..ad726e7 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -3,11 +3,11 @@ Count the number of mathematical operations will be performed by the contraction of the root of the `path` tree. """ -flops(expr::EinExpr) = +flops(expr::EinExpr, sizedict) = if length(expr.args) == 0 || length(expr.args) == 1 && isempty(suminds(expr)) 0 else - mapreduce(Base.Fix1(size, expr), *, Iterators.flatten((head(expr), suminds(expr))), init = one(BigInt)) + mapreduce(i -> sizedict[i], *, Iterators.flatten((head(expr), suminds(expr))), init = one(BigInt)) end """ @@ -15,11 +15,13 @@ flops(expr::EinExpr) = Count the amount of memory that will be freed after performing the contraction of the root of the `path` tree. """ -removedsize(expr::EinExpr) = mapreduce(prod ∘ size, +, expr.args) - prod(size(expr)) +function removedsize(expr::EinExpr, sizedict) + mapreduce(prod ∘ Base.Fix2(size, sizedict), +, expr.args) - prod(size(expr, sizedict)) +end """ removedrank(path::EinExpr) Count the rank reduction after performing the contraction of the root of the `path` tree. """ -removedrank(expr::EinExpr) = mapreduce(ndims, max, expr.args) - ndims(expr) +removedrank(expr::EinExpr, _) = mapreduce(ndims, max, expr.args) - ndims(expr) diff --git a/src/EinExpr.jl b/src/EinExpr.jl index cc65304..71e1efd 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -2,32 +2,17 @@ using Base: AbstractVecOrTuple using DataStructures: DefaultDict using AbstractTrees -struct EinExpr - head::Vector{Symbol} - args::Vector{EinExpr} - size::Dict{Symbol,Int} - - # TODO checks: same dim for index, valid indices - EinExpr(head, args) = new(head, args, Dict{Symbol,EinExpr}()) - - function EinExpr(head::AbstractVector{Symbol}, size::AbstractDict{Symbol,Int}) - head ⊆ keys(size) || throw(ArgumentError("Missing sizes for indices $(setdiff(head, keys(size)))")) - new(head, EinExpr[], size) - end +Base.@kwdef struct EinExpr + head::ImmutableVector{Symbol,Vector{Symbol}} + args::Vector{EinExpr} = EinExpr[] end +EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}) = EinExpr(head, map(EinExpr, args)) + EinExpr(head::NTuple, args) = EinExpr(collect(head), args) EinExpr(head, args::NTuple) = EinExpr(head, collect(args)) EinExpr(head::NTuple, args::NTuple) = EinExpr(collect(head), collect(args)) -function EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}, sizes) - args = map(args) do arg - sizedict = filter(∈(arg) ∘ first, sizes) - EinExpr(arg, sizedict) - end - EinExpr(head, args) -end - """ head(path::EinExpr) @@ -100,11 +85,8 @@ Base.ndims(path::EinExpr) = length(head(path)) Return the size of the resulting tensor from contracting `path`. If `index` is specified, return the size of such index. """ -Base.size(path::EinExpr) = (size(path, i) for i in head(path)) |> splat(tuple) -Base.size(path::EinExpr, i::Symbol) = - Iterators.filter(∋(i) ∘ head, Leaves(path)) |> first |> Base.Fix2(getproperty, :size) |> Base.Fix2(getindex, i) - -Base.length(path::EinExpr) = (prod ∘ size)(path) +Base.size(path::EinExpr, sizedict) = (sizedict[i] for i in head(path)) |> splat(tuple) +Base.length(path::EinExpr, sizedict) = (prod ∘ size)(path, sizedict) """ collapse!(path::EinExpr) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index ea3591c..4c3d7f4 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -22,23 +22,31 @@ The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and `` outer::Bool = false end -function einexpr(config::Exhaustive, path; cost = BigInt(0)) - leader = Ref{NamedTuple{(:path, :cost),Tuple{EinExpr,BigInt}}}((; - path = einexpr(Naive(), path), - cost = mapreduce(config.metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt, - )) - cache = Dict{Vector{Symbol},BigInt}() - __einexpr_exhaustive_it(path, cost, config.metric, config.outer, leader, cache) - return leader[].path -end +function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) + metric = Base.Fix2(config.metric, sizedict) + + leader = (; path = einexpr(Naive(), path), cost = mapreduce(config.metric, +, Branches(einexpr(Naive(), path)))) + cache = Dict{Vector{ImmutableVector{Symbol,Vector{Symbol}}},BigInt}() + + function __einexpr_iterate(path, cost) + if length(path.args) <= 2 + leader = (; path = path, cost = mapreduce(metric, +, Branches(path))) + return + end + + for (i, j) in combinations(args(path), 2) + !config.outer && isdisjoint(head(i), head(j)) && continue + candidate = sum([i, j], skip = path.head ∪ hyperinds(path)) -function __einexpr_exhaustive_it(path, cost, metric, outer, leader, cache) - if length(path.args) == 1 - # remove identity einsum (i.e. "i...->i...") - path = path.args[1] + # prune paths based on metric + new_cost = cost + get!(cache, head.(candidate.args)) do + metric(candidate) + end + new_cost >= leader.cost && continue - leader[] = (; path, cost = mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))::BigInt) - return + new_path = EinExpr(head(path), [candidate, filter(∉([i, j]), args(path))...]) + __einexpr_iterate(new_path, new_cost) + end end for (i, j) in combinations(args(path), 2) diff --git a/src/Optimizers/Greedy.jl b/src/Optimizers/Greedy.jl index 24a27bc..f9559d2 100644 --- a/src/Optimizers/Greedy.jl +++ b/src/Optimizers/Greedy.jl @@ -27,7 +27,9 @@ The implementation uses a binary heaptree to sort candidate pairwise tensor cont outer::Bool = false end -function einexpr(config::Greedy, path) +function einexpr(config::Greedy, path, sizedict) + metric = Base.Fix2(config.metric, sizedict) + # generate initial candidate contractions queue = MutableBinaryHeap{Tuple{Float64,EinExpr}}( Base.By(first, Base.Reverse), @@ -36,7 +38,7 @@ function einexpr(config::Greedy, path) ) do (a, b) # TODO don't consider outer products candidate = sum([a, b], skip = path.head ∪ hyperinds(path)) - weight = config.metric(candidate) + weight = metric(candidate) (weight, candidate) end, ) @@ -55,7 +57,7 @@ function einexpr(config::Greedy, path) for other in Iterators.filter(other -> config.outer || !isdisjoint(winner.head, other.head), path.args) # TODO don't consider outer products candidate = sum([winner, other], skip = path.head ∪ hyperinds(path)) - weight = config.metric(candidate) + weight = metric(candidate) push!(queue, (weight, candidate)) end diff --git a/src/Optimizers/Naive.jl b/src/Optimizers/Naive.jl index 88d9a5c..407b500 100644 --- a/src/Optimizers/Naive.jl +++ b/src/Optimizers/Naive.jl @@ -1,5 +1,7 @@ struct Naive <: Optimizer end +einexpr(::Naive, path, _) = einexpr(Naive(), path) + function einexpr(::Naive, path) hist = Dict(i => count(∋(i) ∘ head, path.args) for i in hyperinds(path)) diff --git a/src/Slicing.jl b/src/Slicing.jl index 1330ec9..fb9934d 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -88,7 +88,8 @@ Reimplementation based on [`contengra`](https://github.com/jcmgray/cotengra)'s ` """ function findslices( scorer, - path::EinExpr; + path::EinExpr, + sizedict; size = nothing, overhead = nothing, slices = nothing, @@ -100,8 +101,8 @@ function findslices( candidates = Set(setdiff(mapreduce(head, ∪, PostOrderDFS(path)), skip)) solution = Set{Symbol}() - current = (; slices = 1, size = maximum(prod ∘ Base.size, PostOrderDFS(path)), overhead = 1.0) - original_flops = mapreduce(flops, +, Branches(path; inverse = true)) + current = (; slices = 1, size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(path)), overhead = 1.0) + original_flops = mapreduce(flops, +, Branches(path)) sliced_path = path while !isempty(candidates) @@ -113,15 +114,15 @@ function findslices( sliced_path = selectdim(sliced_path, winner, 1) cur_overhead = - prod(i -> Base.size(path, i), [solution..., winner]) * - mapreduce(flops, +, Branches(sliced_path; inverse = true)) / original_flops + prod(i -> sizedict[i], [solution..., winner]) * + mapreduce(Base.Fix2(flops, sizedict), +, Branches(sliced_path)) / original_flops !isnothing(overhead) && cur_overhead > overhead && break push!(solution, winner) current = (; - slices = current.slices * (prod ∘ Base.size)(path, winner), - size = maximum(prod ∘ Base.size, PostOrderDFS(sliced_path)), + slices = current.slices * Base.Fix2(length, sizedict)(path, winner), + size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(sliced_path)), overhead = cur_overhead, ) diff --git a/test/Counters_test.jl b/test/Counters_test.jl index 6edc54e..8675ea5 100644 --- a/test/Counters_test.jl +++ b/test/Counters_test.jl @@ -1,75 +1,77 @@ @testset "Counters" begin using EinExprs: removedrank + sizedict = Dict(:i => 2, :j => 3, :k => 4, :l => 5) + @testset "identity" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) - expr = EinExpr([:i, :j], [tensor]) + tensor = EinExpr((:i, :j)) + expr = EinExpr((:i, :j), [tensor]) - @test flops(expr) == 0 - @test removedsize(expr) == 0 - @test removedrank(expr) == 0 + @test flops(expr, sizedict) == 0 + @test removedsize(expr, sizedict) == 0 + @test removedrank(expr, sizedict) == 0 end @testset "transpose" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + tensor = EinExpr((:i, :j)) expr = EinExpr([:j, :i], [tensor]) - @test flops(expr) == 0 - @test removedsize(expr) == 0 - @test removedrank(expr) == 0 + @test flops(expr, sizedict) == 0 + @test removedsize(expr, sizedict) == 0 + @test removedrank(expr, sizedict) == 0 end @testset "axis sum" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) - expr = EinExpr([:i], [tensor]) + tensor = EinExpr((:i, :j)) + expr = EinExpr((:i,), [tensor]) - @test flops(expr) == 6 - @test removedsize(expr) == 4 - @test removedrank(expr) == 1 + @test flops(expr, sizedict) == 6 + @test removedsize(expr, sizedict) == 4 + @test removedrank(expr, sizedict) == 1 end @testset "diagonal" begin - tensor = EinExpr([:i, :i], Dict(:i => 2)) - expr = EinExpr([:i], [tensor]) + tensor = EinExpr((:i, :i)) + expr = EinExpr((:i,), [tensor]) - @test flops(expr) == 0 - @test removedsize(expr) == 2 - @test removedrank(expr) == 1 + @test flops(expr, sizedict) == 0 + @test removedsize(expr, sizedict) == 2 + @test removedrank(expr, sizedict) == 1 end @testset "trace" begin - tensor = EinExpr([:i, :i], Dict(:i => 2)) + tensor = EinExpr((:i, :i)) expr = EinExpr(Symbol[], [tensor]) - @test flops(expr) == 2 - @test removedsize(expr) == 3 - @test removedrank(expr) == 2 + @test flops(expr, sizedict) == 2 + @test removedsize(expr, sizedict) == 3 + @test removedrank(expr, sizedict) == 2 end @testset "outer product" begin - tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:k, :l], Dict(:k => 4, :l => 5))] - expr = EinExpr([:i, :j, :k, :l], tensors) + tensors = [EinExpr((:i, :j)), EinExpr((:k, :l))] + expr = EinExpr((:i, :j, :k, :l), tensors) - @test flops(expr) == prod(2:5) - @test removedsize(expr) == -94 - @test removedrank(expr) == -2 + @test flops(expr, sizedict) == prod(2:5) + @test removedsize(expr, sizedict) == -94 + @test removedrank(expr, sizedict) == -2 end @testset "inner product" begin - tensors = [EinExpr([:i], Dict(:i => 2)), EinExpr([:i], Dict(:i => 2))] + tensors = [EinExpr((:i,)), EinExpr((:i,))] expr = EinExpr(Symbol[], tensors) - @test flops(expr) == 2 - @test removedsize(expr) == 3 - @test removedrank(expr) == 1 + @test flops(expr, sizedict) == 2 + @test removedsize(expr, sizedict) == 3 + @test removedrank(expr, sizedict) == 1 end @testset "matrix multiplication" begin - tensors = [EinExpr([:i, :k], Dict(:i => 2, :k => 3)), EinExpr([:k, :j], Dict(:k => 3, :j => 4))] - expr = EinExpr([:i, :j], tensors) + tensors = [EinExpr((:i, :j)), EinExpr((:j, :k))] + expr = EinExpr((:i, :k), tensors) - @test flops(expr) == 2 * 3 * 4 - @test removedsize(expr) == 10 - @test removedrank(expr) == 0 + @test flops(expr, sizedict) == 2 * 3 * 4 + @test removedsize(expr, sizedict) == 10 + @test removedrank(expr, sizedict) == 0 end end diff --git a/test/Exhaustive_test.jl b/test/Exhaustive_test.jl index a27316c..0282562 100644 --- a/test/Exhaustive_test.jl +++ b/test/Exhaustive_test.jl @@ -1,15 +1,16 @@ @testset "Exhaustive" begin + sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j]) tensors = [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), + EinExpr([:j, :b, :i, :h]), + EinExpr([:a, :c, :e, :f]), + EinExpr([:j]), + EinExpr([:e, :a, :g]), + EinExpr([:f, :b]), + EinExpr([:i, :h, :d]), + EinExpr([:d, :g, :c]), ] - path = einexpr(Exhaustive, EinExpr(Symbol[], tensors)) + path = einexpr(Exhaustive, EinExpr(Symbol[], tensors), sizedict) @test path isa EinExpr @@ -18,14 +19,15 @@ @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]])) @testset "hyperedges" begin - a = EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])) - b = EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])) - c = EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])) + sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, :β]) + a = EinExpr([:i, :β, :j]) + b = EinExpr([:k, :β]) + c = EinExpr([:β, :l, :m]) - path = einexpr(EinExprs.Exhaustive(), sum([a, b, c], skip = [:β])) + path = einexpr(EinExprs.Exhaustive(), sum([a, b, c], skip = [:β]), sizedict) @test all(∋(:β) ∘ head, branches(path)) - path = einexpr(EinExprs.Exhaustive(), sum([a, b, c], skip = Symbol[])) + path = einexpr(EinExprs.Exhaustive(), sum([a, b, c], skip = Symbol[]), sizedict) @test all(∋(:β) ∘ head, branches(path)[1:end-1]) @test all(!∋(:β) ∘ head, branches(path)[end:end]) end diff --git a/test/Greedy_test.jl b/test/Greedy_test.jl index f010dfb..412f3b8 100644 --- a/test/Greedy_test.jl +++ b/test/Greedy_test.jl @@ -1,27 +1,26 @@ @testset "Greedy" begin + sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j]) tensors = [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), + EinExpr([:j, :b, :i, :h]), + EinExpr([:a, :c, :e, :f]), + EinExpr([:j]), + EinExpr([:e, :a, :g]), + EinExpr([:f, :b]), + EinExpr([:i, :h, :d]), + EinExpr([:d, :g, :c]), ] - path = einexpr(Greedy(), EinExpr(Symbol[], tensors)) + path = einexpr(Greedy(), EinExpr(Symbol[], tensors), sizedict) @test path isa EinExpr - @test mapreduce(flops, +, Branches(path)) == 100 + @test mapreduce(Base.Fix2(flops, sizedict), +, Branches(path)) == 100 @test all(splat(issetequal), zip(contractorder(path), [[:i, :h], [:j], [:a, :e], [:g, :c], [:f], [:b, :d]])) @testset "example: let unchanged" begin - tensors = [ - EinExpr([:i, :j, :k], Dict(:i => 2, :j => 2, :k => 2)), - EinExpr([:k, :l, :m], Dict(:k => 2, :l => 2, :m => 2)), - ] + sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m]) + tensors = [EinExpr([:i, :j, :k]), EinExpr([:k, :l, :m])] path = einexpr(Greedy(), EinExpr(Symbol[:i, :j, :l, :m], tensors)) @@ -29,14 +28,15 @@ end @testset "hyperedges" begin - a = EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])) - b = EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])) - c = EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])) + sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, :β]) + a = EinExpr([:i, :β, :j]) + b = EinExpr([:k, :β]) + c = EinExpr([:β, :l, :m]) - path = einexpr(EinExprs.Greedy(), sum([a, b, c], skip = [:β])) + path = einexpr(EinExprs.Greedy(), sum([a, b, c], skip = [:β]), sizedict) @test all(∋(:β) ∘ head, branches(path)) - path = einexpr(EinExprs.Greedy(), sum([a, b, c], skip = Symbol[])) + path = einexpr(EinExprs.Greedy(), sum([a, b, c], skip = Symbol[]), sizedict) @test all(∋(:β) ∘ head, branches(path)[1:end-1]) @test all(!∋(:β) ∘ head, branches(path)[end:end]) end diff --git a/test/Naive_test.jl b/test/Naive_test.jl index 3e01037..4a9d9da 100644 --- a/test/Naive_test.jl +++ b/test/Naive_test.jl @@ -1,21 +1,22 @@ @testset "Naive" begin + sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j]) tensors = [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), + EinExpr([:j, :b, :i, :h]), + EinExpr([:a, :c, :e, :f]), + EinExpr([:j]), + EinExpr([:e, :a, :g]), + EinExpr([:f, :b]), + EinExpr([:i, :h, :d]), + EinExpr([:d, :g, :c]), ] - path = einexpr(EinExprs.Naive(), EinExpr(Symbol[], tensors)) + path = einexpr(EinExprs.Naive(), EinExpr(Symbol[], tensors), sizedict) @test path isa EinExpr @test foldl((a, b) -> sum([a, b]), tensors) == path # TODO traverse through the tree and check everything is ok - @test mapreduce(flops, +, Branches(path)) == 872 + @test mapreduce(Base.Fix2(flops, sizedict), +, Branches(path)) == 872 # FIXME non-determinist behaviour on order @test all( @@ -24,14 +25,15 @@ ) @testset "hyperedges" begin - a = EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])) - b = EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])) - c = EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])) + sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, :β]) + a = EinExpr([:i, :β, :j]) + b = EinExpr([:k, :β]) + c = EinExpr([:β, :l, :m]) - path = einexpr(EinExprs.Naive(), sum([a, b, c], skip = [:β])) + path = einexpr(EinExprs.Naive(), sum([a, b, c], skip = [:β]), sizedict) @test all(∋(:β) ∘ head, branches(path)) - path = einexpr(EinExprs.Naive(), sum([a, b, c], skip = Symbol[])) + path = einexpr(EinExprs.Naive(), sum([a, b, c], skip = Symbol[]), sizedict) @test all(∋(:β) ∘ head, branches(path)[1:end-1]) @test all(!∋(:β) ∘ head, branches(path)[end:end]) end diff --git a/test/Slicing_test.jl b/test/Slicing_test.jl index 720038d..6e9f6ef 100644 --- a/test/Slicing_test.jl +++ b/test/Slicing_test.jl @@ -42,49 +42,31 @@ [ EinExpr( [:m, :f, :g], - [ - EinExpr( - [:m, :f, :q], - Dict(i => sizes[i] for i in [:m, :f, :q]), - ), - EinExpr( - [:g, :q], - Dict(i => sizes[i] for i in [:g, :q]), - ), - ], - ), - EinExpr( - [:o, :i, :m, :c], - Dict(i => sizes[i] for i in [:o, :i, :m, :c]), + [EinExpr((:m, :f, :q),), EinExpr((:g, :q),)], ), + EinExpr((:o, :i, :m, :c),), ], ), - EinExpr([:f, :l, :i], Dict(i => sizes[i] for i in [:f, :l, :i])), + EinExpr((:f, :l, :i)), ], ), - EinExpr([:g, :n, :l, :a], Dict(i => sizes[i] for i in [:g, :n, :l, :a])), - ], - ), - EinExpr( - [:e, :d, :o], - [ - EinExpr([:b, :e], Dict(i => sizes[i] for i in [:b, :e])), - EinExpr([:d, :b, :o], Dict(i => sizes[i] for i in [:d, :b, :o])), + EinExpr((:g, :n, :l, :a)), ], ), + EinExpr([:e, :d, :o], [EinExpr((:b, :e)), EinExpr((:d, :b, :o))]), ], ), - EinExpr([:c, :e, :h], Dict(i => sizes[i] for i in [:c, :e, :h])), + EinExpr((:c, :e, :h)), ], ), - EinExpr([:k, :d, :h, :a, :n, :j], Dict(i => sizes[i] for i in [:k, :d, :h, :a, :n, :j])), + EinExpr((:k, :d, :h, :a, :n, :j)), ], ), - EinExpr([:p, :k], Dict(i => sizes[i] for i in [:p, :k])), + EinExpr((:p, :k)), ], ) cuttings = findslices(FlopsScorer(), expr, slices = 1000) - @test prod(i -> size(expr, i), cuttings) >= 1000 + @test prod(i -> sizedict[i], cuttings) >= 1000 end From 8d2d2b08950e8c8a5a33a57cdff201d4be6b3e11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 21 Oct 2023 15:47:33 +0200 Subject: [PATCH 02/28] Fix only-head `EinExpr` constructor --- src/EinExpr.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/EinExpr.jl b/src/EinExpr.jl index 71e1efd..76ae9aa 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -7,6 +7,7 @@ Base.@kwdef struct EinExpr args::Vector{EinExpr} = EinExpr[] end +EinExpr(head) = EinExpr(head, EinExpr[]) EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}) = EinExpr(head, map(EinExpr, args)) EinExpr(head::NTuple, args) = EinExpr(collect(head), args) From 03573e12cb76733f9da6db5b48323c4a53c0affa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 21 Oct 2023 19:07:53 +0200 Subject: [PATCH 03/28] Microoptimize candidate cost caching in `Exhaustive` optimizer --- src/Optimizers/Exhaustive.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 4c3d7f4..1df1261 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -25,8 +25,8 @@ end function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) metric = Base.Fix2(config.metric, sizedict) - leader = (; path = einexpr(Naive(), path), cost = mapreduce(config.metric, +, Branches(einexpr(Naive(), path)))) - cache = Dict{Vector{ImmutableVector{Symbol,Vector{Symbol}}},BigInt}() + leader = (; path = einexpr(Naive(), path), cost = mapreduce(metric, +, Branches(einexpr(Naive(), path)))) + cache = Dict{ImmutableVector{Symbol,Vector{Symbol}},BigInt}() function __einexpr_iterate(path, cost) if length(path.args) <= 2 @@ -39,7 +39,7 @@ function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) candidate = sum([i, j], skip = path.head ∪ hyperinds(path)) # prune paths based on metric - new_cost = cost + get!(cache, head.(candidate.args)) do + new_cost = cost + get!(cache, head(candidate)) do metric(candidate) end new_cost >= leader.cost && continue From 8350351518ecd636e7d4a88d9ce3b7c9f4ba580c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 21 Oct 2023 19:22:21 +0200 Subject: [PATCH 04/28] Microoptimize resource counting by changing to pre-order in `Exhaustive` optimizer --- src/Optimizers/Exhaustive.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 1df1261..88292a8 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -25,12 +25,12 @@ end function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) metric = Base.Fix2(config.metric, sizedict) - leader = (; path = einexpr(Naive(), path), cost = mapreduce(metric, +, Branches(einexpr(Naive(), path)))) + leader = (; path = einexpr(Naive(), path), cost = mapreduce(metric, +, PreOrderDFS(einexpr(Naive(), path)))) cache = Dict{ImmutableVector{Symbol,Vector{Symbol}},BigInt}() function __einexpr_iterate(path, cost) if length(path.args) <= 2 - leader = (; path = path, cost = mapreduce(metric, +, Branches(path))) + leader = (; path = path, cost = mapreduce(metric, +, PreOrderDFS(path))) return end From 5f44cce91aefecfe8edd60452167c149b84d1cc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 21 Oct 2023 19:36:03 +0200 Subject: [PATCH 05/28] Move resource counting microoptimization to `Branches` --- src/Optimizers/Exhaustive.jl | 5 +++-- src/Slicing.jl | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 88292a8..3b29ce9 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -25,12 +25,13 @@ end function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) metric = Base.Fix2(config.metric, sizedict) - leader = (; path = einexpr(Naive(), path), cost = mapreduce(metric, +, PreOrderDFS(einexpr(Naive(), path)))) + leader = + (; path = einexpr(Naive(), path), cost = mapreduce(metric, +, Branches(einexpr(Naive(), path), inverse = true))) cache = Dict{ImmutableVector{Symbol,Vector{Symbol}},BigInt}() function __einexpr_iterate(path, cost) if length(path.args) <= 2 - leader = (; path = path, cost = mapreduce(metric, +, PreOrderDFS(path))) + leader = (; path = path, cost = mapreduce(metric, +, Branches(path, inverse = true))) return end diff --git a/src/Slicing.jl b/src/Slicing.jl index fb9934d..f7f2366 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -102,7 +102,7 @@ function findslices( candidates = Set(setdiff(mapreduce(head, ∪, PostOrderDFS(path)), skip)) solution = Set{Symbol}() current = (; slices = 1, size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(path)), overhead = 1.0) - original_flops = mapreduce(flops, +, Branches(path)) + original_flops = mapreduce(flops, +, Branches(path; inverse = true)) sliced_path = path while !isempty(candidates) @@ -115,7 +115,7 @@ function findslices( sliced_path = selectdim(sliced_path, winner, 1) cur_overhead = prod(i -> sizedict[i], [solution..., winner]) * - mapreduce(Base.Fix2(flops, sizedict), +, Branches(sliced_path)) / original_flops + mapreduce(Base.Fix2(flops, sizedict), +, Branches(sliced_path; inverse = true)) / original_flops !isnothing(overhead) && cur_overhead > overhead && break push!(solution, winner) From 6896fba5bc59a8752abe00ffa9d7ed6aca162c78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 21 Oct 2023 19:36:42 +0200 Subject: [PATCH 06/28] Fix resource counting in `findslices` --- src/Slicing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Slicing.jl b/src/Slicing.jl index f7f2366..9f15bef 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -102,7 +102,7 @@ function findslices( candidates = Set(setdiff(mapreduce(head, ∪, PostOrderDFS(path)), skip)) solution = Set{Symbol}() current = (; slices = 1, size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(path)), overhead = 1.0) - original_flops = mapreduce(flops, +, Branches(path; inverse = true)) + original_flops = mapreduce(Base.Fix2(flops, sizedict), +, Branches(path; inverse = true)) sliced_path = path while !isempty(candidates) From fce47db9e8e0ba226ec7773b05300a8400d70418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 21 Oct 2023 19:46:40 +0200 Subject: [PATCH 07/28] Curry resource counters on `Dict` argument --- src/Counters.jl | 5 +++++ src/Optimizers/Exhaustive.jl | 2 +- src/Optimizers/Greedy.jl | 2 +- src/Slicing.jl | 4 ++-- test/Greedy_test.jl | 2 +- test/Naive_test.jl | 2 +- 6 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/Counters.jl b/src/Counters.jl index ad726e7..9dad562 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -25,3 +25,8 @@ end Count the rank reduction after performing the contraction of the root of the `path` tree. """ removedrank(expr::EinExpr, _) = mapreduce(ndims, max, expr.args) - ndims(expr) + +for f in [:flops, :removedsize] + @eval $f(sizedict::Dict{Symbol}) = Base.Fix2($f, sizedict) +end +removedrank(::Dict) = Base.Fix2(removedrank, nothing) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 3b29ce9..209413f 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -23,7 +23,7 @@ The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and `` end function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) - metric = Base.Fix2(config.metric, sizedict) + metric = config.metric(sizedict) leader = (; path = einexpr(Naive(), path), cost = mapreduce(metric, +, Branches(einexpr(Naive(), path), inverse = true))) diff --git a/src/Optimizers/Greedy.jl b/src/Optimizers/Greedy.jl index f9559d2..6ee5a18 100644 --- a/src/Optimizers/Greedy.jl +++ b/src/Optimizers/Greedy.jl @@ -28,7 +28,7 @@ The implementation uses a binary heaptree to sort candidate pairwise tensor cont end function einexpr(config::Greedy, path, sizedict) - metric = Base.Fix2(config.metric, sizedict) + metric = config.metric(sizedict) # generate initial candidate contractions queue = MutableBinaryHeap{Tuple{Float64,EinExpr}}( diff --git a/src/Slicing.jl b/src/Slicing.jl index 9f15bef..82e6b07 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -102,7 +102,7 @@ function findslices( candidates = Set(setdiff(mapreduce(head, ∪, PostOrderDFS(path)), skip)) solution = Set{Symbol}() current = (; slices = 1, size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(path)), overhead = 1.0) - original_flops = mapreduce(Base.Fix2(flops, sizedict), +, Branches(path; inverse = true)) + original_flops = mapreduce(flops(sizedict), +, Branches(path; inverse = true)) sliced_path = path while !isempty(candidates) @@ -115,7 +115,7 @@ function findslices( sliced_path = selectdim(sliced_path, winner, 1) cur_overhead = prod(i -> sizedict[i], [solution..., winner]) * - mapreduce(Base.Fix2(flops, sizedict), +, Branches(sliced_path; inverse = true)) / original_flops + mapreduce(flops(sizedict), +, Branches(sliced_path; inverse = true)) / original_flops !isnothing(overhead) && cur_overhead > overhead && break push!(solution, winner) diff --git a/test/Greedy_test.jl b/test/Greedy_test.jl index 412f3b8..07d2586 100644 --- a/test/Greedy_test.jl +++ b/test/Greedy_test.jl @@ -14,7 +14,7 @@ @test path isa EinExpr - @test mapreduce(Base.Fix2(flops, sizedict), +, Branches(path)) == 100 + @test mapreduce(flops(sizedict), +, Branches(path)) == 100 @test all(splat(issetequal), zip(contractorder(path), [[:i, :h], [:j], [:a, :e], [:g, :c], [:f], [:b, :d]])) diff --git a/test/Naive_test.jl b/test/Naive_test.jl index 4a9d9da..2b8f15c 100644 --- a/test/Naive_test.jl +++ b/test/Naive_test.jl @@ -16,7 +16,7 @@ @test foldl((a, b) -> sum([a, b]), tensors) == path # TODO traverse through the tree and check everything is ok - @test mapreduce(Base.Fix2(flops, sizedict), +, Branches(path)) == 872 + @test mapreduce(flops(sizedict), +, Branches(path)) == 872 # FIXME non-determinist behaviour on order @test all( From cbcb2740e408c2ce385409ac15f0fd147b9f2933 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 16:09:02 +0200 Subject: [PATCH 08/28] Microoptimize `mapreduce` call in `flops` --- src/Counters.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Counters.jl b/src/Counters.jl index 9dad562..130bb8b 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -7,7 +7,7 @@ flops(expr::EinExpr, sizedict) = if length(expr.args) == 0 || length(expr.args) == 1 && isempty(suminds(expr)) 0 else - mapreduce(i -> sizedict[i], *, Iterators.flatten((head(expr), suminds(expr))), init = one(BigInt)) + mapreduce(Base.Fix1(getindex, sizedict), *, Iterators.flatten((head(expr), suminds(expr))), init = one(BigInt)) end """ From f55606054670fad513a337f7b12394ca80839e4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 17:30:45 +0200 Subject: [PATCH 09/28] Microoptimize type-stability in `Exhaustive` optimizer --- src/Optimizers/Exhaustive.jl | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 209413f..cce2223 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -24,30 +24,20 @@ end function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) metric = config.metric(sizedict) - - leader = - (; path = einexpr(Naive(), path), cost = mapreduce(metric, +, Branches(einexpr(Naive(), path), inverse = true))) + leader = Ref{NamedTuple{(:path, :cost),Tuple{EinExpr,BigInt}}}((; + path = einexpr(Naive(), path), + cost = mapreduce(metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt, + )) cache = Dict{ImmutableVector{Symbol,Vector{Symbol}},BigInt}() + __einexpr_exhaustive_it(path, cost, metric, config.outer, leader, cache) + return leader[].path +end - function __einexpr_iterate(path, cost) - if length(path.args) <= 2 - leader = (; path = path, cost = mapreduce(metric, +, Branches(path, inverse = true))) - return - end - - for (i, j) in combinations(args(path), 2) - !config.outer && isdisjoint(head(i), head(j)) && continue - candidate = sum([i, j], skip = path.head ∪ hyperinds(path)) - - # prune paths based on metric - new_cost = cost + get!(cache, head(candidate)) do - metric(candidate) - end - new_cost >= leader.cost && continue - - new_path = EinExpr(head(path), [candidate, filter(∉([i, j]), args(path))...]) - __einexpr_iterate(new_path, new_cost) - end +function __einexpr_exhaustive_it(path, cost, metric, outer, leader, cache) + if length(path.args) <= 2 + leader[] = + (; path = path, cost = mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))::BigInt) + return end for (i, j) in combinations(args(path), 2) From 102e94747c5abfa59985710abe338b0bec1d40ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 17:31:34 +0200 Subject: [PATCH 10/28] Fix recursive problem in generic `einexpr` call --- src/Optimizers/Optimizers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Optimizers/Optimizers.jl b/src/Optimizers/Optimizers.jl index d4fb537..516b8ce 100644 --- a/src/Optimizers/Optimizers.jl +++ b/src/Optimizers/Optimizers.jl @@ -3,7 +3,7 @@ abstract type Optimizer end function einexpr end einexpr(T::Type{<:Optimizer}, args...; kwargs...) = einexpr(T(; kwargs...), args...) -einexpr(config::Optimizer, expr) = einexpr(config, expr) +einexpr(config::Optimizer, expr, sizedict) = einexpr(config, expr, sizedict) include("Naive.jl") include("Exhaustive.jl") From 68bfd82e293da41b363d76ca244c6a206b17302a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 17:39:28 +0100 Subject: [PATCH 11/28] Remove artifact `ImmutableVector` --- src/EinExpr.jl | 2 +- src/Optimizers/Exhaustive.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/EinExpr.jl b/src/EinExpr.jl index 76ae9aa..f24eff7 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -3,7 +3,7 @@ using DataStructures: DefaultDict using AbstractTrees Base.@kwdef struct EinExpr - head::ImmutableVector{Symbol,Vector{Symbol}} + head::Vector{Symbol} args::Vector{EinExpr} = EinExpr[] end diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index cce2223..d220934 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -28,7 +28,7 @@ function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) path = einexpr(Naive(), path), cost = mapreduce(metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt, )) - cache = Dict{ImmutableVector{Symbol,Vector{Symbol}},BigInt}() + cache = Dict{Vector{Symbol},BigInt}() __einexpr_exhaustive_it(path, cost, metric, config.outer, leader, cache) return leader[].path end From 1a7b0c2ed1a81da8bf1e408d2cb0ca442b326a82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 19 Dec 2023 19:24:02 +0100 Subject: [PATCH 12/28] Refactor code to use SizedEinExpr for size-aware expressions --- src/Counters.jl | 25 +++++++++---- src/EinExpr.jl | 2 ++ src/EinExprs.jl | 3 ++ src/Optimizers/Exhaustive.jl | 34 ++++++++++-------- src/Optimizers/Naive.jl | 10 ++++-- src/SizedEinExpr.jl | 70 ++++++++++++++++++++++++++++++++++++ 6 files changed, 120 insertions(+), 24 deletions(-) create mode 100644 src/SizedEinExpr.jl diff --git a/src/Counters.jl b/src/Counters.jl index 130bb8b..b284cca 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -3,30 +3,41 @@ Count the number of mathematical operations will be performed by the contraction of the root of the `path` tree. """ -flops(expr::EinExpr, sizedict) = - if length(expr.args) == 0 || length(expr.args) == 1 && isempty(suminds(expr)) +flops(sexpr::SizedEinExpr) = + if nargs(sexpr) == 0 || nargs(sexpr) == 1 && isempty(suminds(sexpr)) 0 else - mapreduce(Base.Fix1(getindex, sizedict), *, Iterators.flatten((head(expr), suminds(expr))), init = one(BigInt)) + mapreduce( + Base.Fix1(getindex, sexpr.size), + *, + Iterators.flatten((head(sexpr), suminds(sexpr))), + init = one(BigInt), + ) end +flops(expr::EinExpr, size) = flops(SizedEinExpr(expr, size)) + """ removedsize(path::EinExpr) Count the amount of memory that will be freed after performing the contraction of the root of the `path` tree. """ -function removedsize(expr::EinExpr, sizedict) - mapreduce(prod ∘ Base.Fix2(size, sizedict), +, expr.args) - prod(size(expr, sizedict)) +function removedsize(sexpr::SizedEinExpr) + mapreduce(prod ∘ Base.Fix2(size, sexpr.size), +, sexpr.args) - prod(size(sexpr, sexpr.size)) end +removedsize(expr::EinExpr, size) = removedsize(SizedEinExpr(expr, size)) + """ removedrank(path::EinExpr) Count the rank reduction after performing the contraction of the root of the `path` tree. """ -removedrank(expr::EinExpr, _) = mapreduce(ndims, max, expr.args) - ndims(expr) +removedrank(expr::EinExpr) = mapreduce(ndims, max, expr.args) - ndims(expr) +removedrank(expr::EinExpr, _) = removedrank(expr) +removedrank(sexpr::SizedEinExpr, _) = removedrank(sexpr.path) for f in [:flops, :removedsize] @eval $f(sizedict::Dict{Symbol}) = Base.Fix2($f, sizedict) end -removedrank(::Dict) = Base.Fix2(removedrank, nothing) +removedrank(::Dict) = removedrank diff --git a/src/EinExpr.jl b/src/EinExpr.jl index f24eff7..3838527 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -32,6 +32,8 @@ See also: [`head`](@ref). """ args(path::EinExpr) = path.args +nargs(path::EinExpr) = length(path.args) + """ inds(path) diff --git a/src/EinExprs.jl b/src/EinExprs.jl index 589af7b..c3d52f8 100644 --- a/src/EinExprs.jl +++ b/src/EinExprs.jl @@ -5,6 +5,9 @@ export EinExpr export head, args, inds, hyperinds, suminds, parsuminds, collapse!, contractorder, select, neighbours export Branches, branches, leaves +include("SizedEinExpr.jl") +export SizedEinExpr + include("Counters.jl") export flops, removedsize diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index d220934..12410b3 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -22,35 +22,41 @@ The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and `` outer::Bool = false end -function einexpr(config::Exhaustive, path, sizedict; cost = BigInt(0)) - metric = config.metric(sizedict) - leader = Ref{NamedTuple{(:path, :cost),Tuple{EinExpr,BigInt}}}((; +function einexpr(config::Exhaustive, path; cost = BigInt(0)) + # metric = Base.Fix2(config.metric, path.size) + leader = Ref((; path = einexpr(Naive(), path), - cost = mapreduce(metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt, + cost = mapreduce(config.metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt, )) - cache = Dict{Vector{Symbol},BigInt}() - __einexpr_exhaustive_it(path, cost, metric, config.outer, leader, cache) + __einexpr_exhaustive_it(path, cost, config.metric, config.outer, leader) return leader[].path end -function __einexpr_exhaustive_it(path, cost, metric, outer, leader, cache) - if length(path.args) <= 2 - leader[] = - (; path = path, cost = mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))::BigInt) +function __einexpr_exhaustive_it( + path, + cost, + metric, + outer, + leader; + cache = Dict{Vector{Symbol},BigInt}(), + hashyperinds = any(hyperinds(path)), +) + if nargs(path) <= 2 + leader[] = (; path = path, cost = cost) #= mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))) =# return end for (i, j) in combinations(args(path), 2) !outer && isdisjoint(head(i), head(j)) && continue - candidate = sum([i, j], skip = path.head ∪ hyperinds(path)) + candidate = sum([i, j]; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) # prune paths based on metric new_cost = cost + get!(cache, head(candidate)) do - metric(candidate) + metric(SizedEinExpr(candidate, path.size)) end new_cost >= leader[].cost && continue - new_path = EinExpr(head(path), [candidate, filter(∉([i, j]), args(path))...]) - __einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader, cache) + new_path = SizedEinExpr(EinExpr(head(path), [candidate, filter(∉([i, j]), args(path))...]), path.size) # sum([candidate, filter(∉([i, j]), args(path))...], skip = path.head) + __einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader; cache, hashyperinds) end end diff --git a/src/Optimizers/Naive.jl b/src/Optimizers/Naive.jl index 407b500..9b31d93 100644 --- a/src/Optimizers/Naive.jl +++ b/src/Optimizers/Naive.jl @@ -1,12 +1,14 @@ +using AbstractTrees + struct Naive <: Optimizer end einexpr(::Naive, path, _) = einexpr(Naive(), path) function einexpr(::Naive, path) - hist = Dict(i => count(∋(i) ∘ head, path.args) for i in hyperinds(path)) + hist = Dict(i => count(∋(i) ∘ head, args(path)) for i in hyperinds(path)) - foldl(path.args) do a, b - expr = sum([a, b], skip = path.head ∪ collect(keys(hist))) + foldl(args(path)) do a, b + expr = sum([a, b], skip = head(path) ∪ collect(keys(hist))) for i in Iterators.filter(∈(keys(hist)), ∩(head(a), head(b))) hist[i] -= 1 @@ -16,3 +18,5 @@ function einexpr(::Naive, path) return expr end end + +einexpr(::Naive, path::SizedEinExpr) = SizedEinExpr(einexpr(Naive(), path.path), path.size) diff --git a/src/SizedEinExpr.jl b/src/SizedEinExpr.jl new file mode 100644 index 0000000..15ab83e --- /dev/null +++ b/src/SizedEinExpr.jl @@ -0,0 +1,70 @@ +using AbstractTrees + +struct SizedEinExpr + path::EinExpr + size::Dict{Symbol,Int} + + function SizedEinExpr(path, size) + # inds(path) ⊆ keys(size) || throw(ArgumentError("")) + new(path, size) + end +end + +EinExpr(path::Vector{Symbol}, size::Dict{Symbol}) = SizedEinExpr(EinExpr(path), size) + +head(sexpr::SizedEinExpr) = head(sexpr.path) +args(sexpr::SizedEinExpr) = sexpr.path.args # map(Base.Fix2(SizedEinExpr, sexpr.size), sexpr.path.args) +nargs(sexpr::SizedEinExpr) = nargs(sexpr.path) +inds(sexpr::SizedEinExpr) = inds(sexpr.path) + +function Base.getproperty(sexpr::SizedEinExpr, name::Symbol) + name === :head && return getfield(sexpr, :path).head + name === :args && return getfield(sexpr, :path).args + return getfield(sexpr, name) +end + +Base.:(==)(a::SizedEinExpr, b::SizedEinExpr) = a.path == b.path && a.size == b.size + +Base.ndims(sexpr::SizedEinExpr) = ndims(sexpr.path) + +Base.size(sexpr::SizedEinExpr) = size(sexpr.path, sexpr.size) +Base.length(sexpr::SizedEinExpr) = length(sexpr.path, sexpr.size) + +collapse!(sexpr::SizedEinExpr) = collapse!(sexpr.path) + +select(sexpr::SizedEinExpr, i) = map(Base.Fix2(SizedEinExpr, sexpr.size), select(sexpr.path, i)) + +neighbours(sexpr::SizedEinExpr, i) = map(Base.Fix2(SizedEinExpr, sexpr.size), neighbours(sexpr.path, i)) + +contractorder(sexpr::SizedEinExpr) = contractorder(sexpr.path) + +hyperinds(sexpr::SizedEinExpr) = hyperinds(sexpr.path) + +suminds(sexpr::SizedEinExpr) = suminds(sexpr.path) +parsuminds(sexpr::SizedEinExpr) = parsuminds(sexpr.path) + +Base.sum!(sexpr::SizedEinExpr, inds) = sum!(sexpr.path, inds) +Base.sum(sexpr::SizedEinExpr, inds) = sum(sexpr.path, inds) + +function Base.sum(sexpr::Vector{SizedEinExpr}; skip = Symbol[]) + path = sum(map(x -> x.path, sexpr); skip) + size = allequal(Iterators.map(x -> x.size, sexpr)) ? first(sexpr).size : merge(map(x -> x.size, sexpr)...) + # size = merge(map(x -> x.size, sexpr)...) + SizedEinExpr(path, size) +end + +# Iteration interface +Base.IteratorEltype(::Type{<:TreeIterator{SizedEinExpr}}) = Base.HasEltype() +Base.eltype(::Type{<:TreeIterator{SizedEinExpr}}) = SizedEinExpr + +# AbstractTrees interface and traits +AbstractTrees.children(sexpr::SizedEinExpr) = map(Base.Fix2(SizedEinExpr, sexpr.size), args(sexpr)) +AbstractTrees.childtype(::Type{SizedEinExpr}) = SizedEinExpr +AbstractTrees.childrentype(::Type{SizedEinExpr}) = Vector{SizedEinExpr} +AbstractTrees.childstatetype(::Type{SizedEinExpr}) = Int +AbstractTrees.nodetype(::Type{SizedEinExpr}) = SizedEinExpr + +AbstractTrees.ParentLinks(::Type{SizedEinExpr}) = ImplicitParents() +AbstractTrees.SiblingLinks(::Type{SizedEinExpr}) = ImplicitSiblings() +AbstractTrees.ChildIndexing(::Type{SizedEinExpr}) = IndexedChildren() +AbstractTrees.NodeType(::Type{SizedEinExpr}) = HasNodeType() From fbd9104c0efae5469aa1cda21cba55285d08ca29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 19 Dec 2023 19:38:15 +0100 Subject: [PATCH 13/28] Fix `removedsize` on `SizedEinExpr` --- src/Counters.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Counters.jl b/src/Counters.jl index b284cca..eca2b8e 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -22,8 +22,8 @@ flops(expr::EinExpr, size) = flops(SizedEinExpr(expr, size)) Count the amount of memory that will be freed after performing the contraction of the root of the `path` tree. """ -function removedsize(sexpr::SizedEinExpr) - mapreduce(prod ∘ Base.Fix2(size, sexpr.size), +, sexpr.args) - prod(size(sexpr, sexpr.size)) +removedsize(sexpr::SizedEinExpr) = -length(sexpr) + mapreduce(+, sexpr.args) do arg + length(SizedEinExpr(arg, sexpr.size)) end removedsize(expr::EinExpr, size) = removedsize(SizedEinExpr(expr, size)) From 513f8722ac2d806931d0ffa4d2145603ef9ca09e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 19 Dec 2023 19:38:49 +0100 Subject: [PATCH 14/28] Fix `Greedy` optimizer on `SizedEinExpr` --- src/Optimizers/Greedy.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Optimizers/Greedy.jl b/src/Optimizers/Greedy.jl index 6ee5a18..ec4982c 100644 --- a/src/Optimizers/Greedy.jl +++ b/src/Optimizers/Greedy.jl @@ -43,7 +43,7 @@ function einexpr(config::Greedy, path, sizedict) end, ) - while length(path.args) > 2 && length(queue) > 1 + while nargs(path) > 2 && length(queue) > 1 # choose winner _, winner = config.choose(queue) @@ -67,3 +67,7 @@ function einexpr(config::Greedy, path, sizedict) return path end + +function einexpr(config::Greedy, path::SizedEinExpr) + return einexpr(config, path.path, path.size) +end From c08dba78de091be878b7be5058ae6cd4868ea3a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 24 Dec 2023 17:25:29 +0100 Subject: [PATCH 15/28] Specialize on `metric` function to avoid recursive dynamic-dispatch --- src/Optimizers/Exhaustive.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 12410b3..f4ef1d9 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -28,19 +28,19 @@ function einexpr(config::Exhaustive, path; cost = BigInt(0)) path = einexpr(Naive(), path), cost = mapreduce(config.metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt, )) - __einexpr_exhaustive_it(path, cost, config.metric, config.outer, leader) + __einexpr_exhaustive_it(path, cost, Val(config.metric), config.outer, leader) return leader[].path end function __einexpr_exhaustive_it( path, cost, - metric, + @specialize(metric::Val{Metric}), outer, leader; cache = Dict{Vector{Symbol},BigInt}(), hashyperinds = any(hyperinds(path)), -) +) where {Metric} if nargs(path) <= 2 leader[] = (; path = path, cost = cost) #= mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))) =# return @@ -52,7 +52,7 @@ function __einexpr_exhaustive_it( # prune paths based on metric new_cost = cost + get!(cache, head(candidate)) do - metric(SizedEinExpr(candidate, path.size)) + Metric(SizedEinExpr(candidate, path.size)) end new_cost >= leader[].cost && continue From b975efda95cca6a4662b94a5be1d0245bd216210 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 24 Dec 2023 18:20:50 +0100 Subject: [PATCH 16/28] Speedup `sum` function on `EinExpr`s --- src/EinExpr.jl | 29 +++++++++++++++++++++++++---- src/Optimizers/Exhaustive.jl | 3 ++- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/EinExpr.jl b/src/EinExpr.jl index 3838527..e4976a4 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -226,6 +226,7 @@ Create an `EinExpr` from other `EinExpr`s. function Base.sum(args::Vector{EinExpr}; skip = Symbol[]) _head = Symbol[] _counts = Int[] + for arg in args for index in head(arg) i = findfirst(Base.Fix1(===, index), _head) @@ -233,17 +234,37 @@ function Base.sum(args::Vector{EinExpr}; skip = Symbol[]) push!(_head, index) push!(_counts, 1) else - _counts[i] += 1 + @inbounds _counts[i] += 1 end end end - _head = map(first, Iterators.filter(zip(_head, _counts)) do (index, count) - count == 1 || index ∈ skip - end) + # NOTE `map` with `Iterators.filter` induces many heap grows; allocating once and deleting is faster + for i in Iterators.reverse(eachindex(_head, _counts)) + (_counts[i] == 1 || _head[i] ∈ skip) && continue + deleteat!(_head, i) + end + EinExpr(_head, args) end +function Base.sum(a::EinExpr, b::EinExpr; skip = Symbol[]) + _head = copy(head(a)) + + for index in head(b) + i = findfirst(Base.Fix1(===, index), _head) + if isnothing(i) + push!(_head, index) + elseif index ∈ skip + continue + else + deleteat!(_head, i) + end + end + + EinExpr(_head, [a, b]) +end + function Base.string(path::EinExpr; recursive::Bool = false) !recursive && return "$(join(map(x -> string.(head(x)) |> join, args(path)), ","))->$(string.(head(path)) |> join)" map(string, Branches(path)) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index f4ef1d9..45679d5 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -48,7 +48,8 @@ function __einexpr_exhaustive_it( for (i, j) in combinations(args(path), 2) !outer && isdisjoint(head(i), head(j)) && continue - candidate = sum([i, j]; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) + # candidate = sum([i, j]; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) + candidate = sum(i, j; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) # prune paths based on metric new_cost = cost + get!(cache, head(candidate)) do From 72db773823c6ca5e420deb1e6bc00a0b50c31b18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 24 Dec 2023 18:21:23 +0100 Subject: [PATCH 17/28] Format comment --- src/Optimizers/Exhaustive.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 45679d5..33ae6be 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -42,7 +42,8 @@ function __einexpr_exhaustive_it( hashyperinds = any(hyperinds(path)), ) where {Metric} if nargs(path) <= 2 - leader[] = (; path = path, cost = cost) #= mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))) =# + #= mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))) =# + leader[] = (; path = path, cost = cost) return end From 5da6cd6bfbf54aa096acdd380ec1a5d630dca511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 24 Dec 2023 18:21:35 +0100 Subject: [PATCH 18/28] Refactor code to use sum function in make.jl --- benchmark/make.jl | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/benchmark/make.jl b/benchmark/make.jl index d7002fc..d9ffa8e 100644 --- a/benchmark/make.jl +++ b/benchmark/make.jl @@ -15,18 +15,15 @@ suite["greedy"] = BenchmarkGroup([]) suite["kahypar"] = BenchmarkGroup([]) # BENCHMARK 1 -expr = EinExpr( - Symbol[], - [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), - ], -) +expr = sum([ + EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), + EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), + EinExpr([:j], Dict(i => 2 for i in [:j])), + EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), + EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), + EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), + EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), +]) suite["naive"][1] = @benchmarkable einexpr(EinExprs.Naive(), $expr) suite["exhaustive"][1] = @benchmarkable einexpr(Exhaustive(), $expr) @@ -41,7 +38,7 @@ D = EinExpr([:c, :h, :d, :i], Dict(:c => 2, :h => 2, :d => 2, :i => 2)) E = EinExpr([:f, :i, :g, :j], Dict(:f => 2, :i => 2, :g => 2, :j => 2)) F = EinExpr([:B, :h, :k, :l], Dict(:B => 2, :h => 2, :k => 2, :l => 2)) G = EinExpr([:j, :k, :l, :D], Dict(:j => 2, :k => 2, :l => 2, :D => 2)) -expr = EinExpr([:A, :B, :C, :D], [A, B, C, D, E, F, G]) +expr = sum([A, B, C, D, E, F, G], skip = [:A, :B, :C, :D]) suite["naive"][2] = @benchmarkable einexpr(EinExprs.Naive(), $expr) suite["exhaustive"][2] = @benchmarkable einexpr(Exhaustive(), $expr) From e5826b24d60cb1c3a24094c0f716081105a09b59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 25 Dec 2023 02:37:51 +0100 Subject: [PATCH 19/28] Remove comment --- src/Optimizers/Exhaustive.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 33ae6be..264600d 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -49,7 +49,6 @@ function __einexpr_exhaustive_it( for (i, j) in combinations(args(path), 2) !outer && isdisjoint(head(i), head(j)) && continue - # candidate = sum([i, j]; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) candidate = sum(i, j; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) # prune paths based on metric From 7b7707636761ccc670d5cc5a209618c5bbfde737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 16:41:21 +0100 Subject: [PATCH 20/28] Add size indexing to SizedEinExpr --- src/SizedEinExpr.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/SizedEinExpr.jl b/src/SizedEinExpr.jl index 15ab83e..801717a 100644 --- a/src/SizedEinExpr.jl +++ b/src/SizedEinExpr.jl @@ -28,6 +28,7 @@ Base.:(==)(a::SizedEinExpr, b::SizedEinExpr) = a.path == b.path && a.size == b.s Base.ndims(sexpr::SizedEinExpr) = ndims(sexpr.path) Base.size(sexpr::SizedEinExpr) = size(sexpr.path, sexpr.size) +Base.size(sexpr::SizedEinExpr, i) = sexpr.size[i] Base.length(sexpr::SizedEinExpr) = length(sexpr.path, sexpr.size) collapse!(sexpr::SizedEinExpr) = collapse!(sexpr.path) From af17760857cd105b8890bc0e6382b50c1788dad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 16:42:32 +0100 Subject: [PATCH 21/28] Fix `EinExprs` tests --- test/EinExpr_test.jl | 87 ++++++++------------------------------- test/SizedEinExpr_test.jl | 22 ++++++++++ test/runtests.jl | 1 + 3 files changed, 41 insertions(+), 69 deletions(-) create mode 100644 test/SizedEinExpr_test.jl diff --git a/test/EinExpr_test.jl b/test/EinExpr_test.jl index 76c1d79..114e133 100644 --- a/test/EinExpr_test.jl +++ b/test/EinExpr_test.jl @@ -2,7 +2,7 @@ using LinearAlgebra @testset "identity" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + tensor = EinExpr([:i, :j]) expr = EinExpr([:i, :j], [tensor]) @test expr.head == head(tensor) @@ -11,10 +11,6 @@ @test head(expr) == head(tensor) @test ndims(expr) == 2 - @test size(expr, :i) == 2 - @test size(expr, :j) == 3 - @test size(expr) == (2, 3) - @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @test isempty(parsuminds(expr)) @@ -27,7 +23,7 @@ end @testset "transpose" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + tensor = EinExpr([:i, :j]) expr = EinExpr([:j, :i], [tensor]) @test expr.head == reverse(inds(tensor)) @@ -36,10 +32,6 @@ @test head(expr) == reverse(inds(tensor)) @test ndims(expr) == 2 - @test size(expr, :i) == 2 - @test size(expr, :j) == 3 - @test size(expr) == (3, 2) - @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @test isempty(parsuminds(expr)) @@ -52,7 +44,7 @@ end @testset "axis sum" begin - tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + tensor = EinExpr([:i, :j]) expr = EinExpr((:i,), [tensor]) @test all(splat(==), zip(expr.head, [:i])) @@ -61,10 +53,6 @@ @test all(splat(==), zip(head(expr), (:i,))) @test all(splat(==), zip(inds(expr), [:i, :j])) - @test size(expr, :i) == 2 - @test size(expr, :j) == 3 - @test size(expr) == (2,) - @test isempty(hyperinds(expr)) @test suminds(expr) == [:j] @test isempty(parsuminds(expr)) @@ -77,7 +65,7 @@ end @testset "diagonal" begin - tensor = EinExpr([:i, :i], Dict(:i => 2)) + tensor = EinExpr([:i, :i]) expr = EinExpr((:i,), [tensor]) @test all(splat(==), zip(expr.head, (:i,))) @@ -86,9 +74,6 @@ @test all(splat(==), zip(head(expr), (:i,))) @test all(splat(==), zip(inds(expr), head(expr))) - @test size(expr, :i) == 2 - @test size(expr) == (2,) - @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @test isempty(parsuminds(expr)) @@ -99,7 +84,7 @@ end @testset "trace" begin - tensor = EinExpr([:i, :i], Dict(:i => 2)) + tensor = EinExpr([:i, :i]) expr = EinExpr(Symbol[], [tensor]) @test isempty(expr.head) @@ -108,9 +93,6 @@ @test isempty(head(expr)) @test all(splat(==), zip(inds(expr), (:i,))) - @test size(expr, :i) == 2 - @test size(expr) == () - @test isempty(hyperinds(expr)) @test suminds(expr) == [:i] @test isempty(parsuminds(expr)) @@ -121,7 +103,7 @@ end @testset "outer product" begin - tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:k, :l], Dict(:k => 4, :l => 5))] + tensors = [EinExpr([:i, :j]), EinExpr([:k, :l])] expr = EinExpr([:i, :j, :k, :l], tensors) @test all(splat(==), zip(expr.head, (:i, :j, :k, :l))) @@ -131,11 +113,6 @@ @test all(splat(==), zip(inds(expr), head(expr))) @test ndims(expr) == 4 - for (i, d) in zip([:i, :j, :k, :l], [2, 3, 4, 5]) - @test size(expr, i) == d - end - @test size(expr) == (2, 3, 4, 5) - @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @test isempty(parsuminds(expr)) @@ -151,7 +128,7 @@ @testset "inner product" begin @testset "Vector" begin - tensors = [EinExpr([:i], Dict(:i => 2)), EinExpr([:i], Dict(:i => 2))] + tensors = [EinExpr([:i]), EinExpr([:i])] expr = EinExpr(Symbol[], tensors) @test isempty(expr.head) @@ -161,9 +138,6 @@ @test all(splat(==), zip(inds(expr), (:i,))) @test ndims(expr) == 0 - @test size(expr, :i) == 2 - @test size(expr) == () - @test isempty(hyperinds(expr)) @test suminds(expr) == [:i] @test parsuminds(expr) == [[:i]] @@ -173,7 +147,7 @@ @test isempty(neighbours(expr, :i)) end @testset "Matrix" begin - tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:i, :j], Dict(:i => 2, :j => 3))] + tensors = [EinExpr([:i, :j]), EinExpr([:i, :j])] expr = EinExpr(Symbol[], tensors) @test isempty(expr.head) @@ -183,10 +157,6 @@ @test all(splat(==), zip(inds(expr), [:i, :j])) @test ndims(expr) == 0 - @test size(expr, :i) == 2 - @test size(expr, :j) == 3 - @test size(expr) == () - @test isempty(hyperinds(expr)) @test issetequal(suminds(expr), [:i, :j]) @test Set(Set.(parsuminds(expr))) == Set([Set([:i, :j])]) @@ -199,7 +169,7 @@ end @testset "matrix multiplication" begin - tensors = [EinExpr([:i, :k], Dict(:i => 2, :k => 3)), EinExpr([:k, :j], Dict(:k => 3, :j => 4))] + tensors = [EinExpr([:i, :k]), EinExpr([:k, :j])] expr = EinExpr([:i, :j], tensors) @test all(splat(==), zip(expr.head, [:i, :j])) @@ -209,11 +179,6 @@ @test all(splat(==), zip(inds(expr), (:i, :k, :j))) @test ndims(expr) == 2 - @test size(expr, :i) == 2 - @test size(expr, :j) == 4 - @test size(expr, :k) == 3 - @test size(expr) == (2, 4) - @test isempty(hyperinds(expr)) @test suminds(expr) == [:k] @test parsuminds(expr) == [[:k]] @@ -228,21 +193,13 @@ @testset "hyperindex contraction" begin @testset "hyperindex is not summed" begin - tensors = [ - EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])), - EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])), - EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])), - ] - + tensors = [EinExpr([:i, :β, :j]), EinExpr([:k, :β]), EinExpr([:β, :l, :m])] expr = sum(tensors, skip = [:β]) @test issetequal(head(expr), (:i, :j, :k, :l, :m, :β)) @test issetequal(inds(expr), (:i, :j, :k, :l, :m, :β)) @test ndims(expr) == 6 - @test all(i -> size(expr, i) == 2, inds(expr)) - @test size(expr) == tuple(fill(2, 6)...) - @test issetequal(hyperinds(expr), [:β]) @test isempty(suminds(expr)) @test_broken isempty(parsuminds(expr)) @@ -257,12 +214,7 @@ end @testset "hyperindex is summed" begin - tensors = [ - EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])), - EinExpr([:k, :β], Dict(i => 2 for i in [:k, :β])), - EinExpr([:β, :l, :m], Dict(i => 2 for i in [:β, :l, :m])), - ] - + tensors = [EinExpr([:i, :β, :j]), EinExpr([:k, :β]), EinExpr([:β, :l, :m])] expr = sum(tensors) @test all(splat(==), zip(expr.head, (:i, :j, :k, :l, :m))) @@ -272,9 +224,6 @@ @test issetequal(inds(expr), (:i, :j, :k, :l, :m, :β)) @test ndims(expr) == 5 - @test all(i -> size(expr, i) == 2, inds(expr)) - @test size(expr) == tuple(fill(2, 5)...) - @test issetequal(hyperinds(expr), [:β]) @test issetequal(suminds(expr), [:β]) @test issetequal(parsuminds(expr), [[:β]]) @@ -291,13 +240,13 @@ @testset "manual path" begin tensors = [ - EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])), - EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])), - EinExpr([:j], Dict(i => 2 for i in [:j])), - EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])), - EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])), - EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), - EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), + EinExpr([:j, :b, :i, :h]), + EinExpr([:a, :c, :e, :f]), + EinExpr([:j]), + EinExpr([:e, :a, :g]), + EinExpr([:f, :b]), + EinExpr([:i, :h, :d]), + EinExpr([:d, :g, :c]), ] path = EinExpr(Symbol[], tensors) diff --git a/test/SizedEinExpr_test.jl b/test/SizedEinExpr_test.jl new file mode 100644 index 0000000..dd0fcf7 --- /dev/null +++ b/test/SizedEinExpr_test.jl @@ -0,0 +1,22 @@ +@testset "SizedEinExpr" begin + using LinearAlgebra + + tensor = EinExpr([:i, :j]) + expr = EinExpr([:i, :j], [tensor]) + sexpr = SizedEinExpr(expr, Dict(:i => 2, :j => 3)) + + @test head(sexpr) === head(expr) === sexpr.head + @test args(sexpr) === args(expr) === sexpr.args + @test EinExprs.nargs(sexpr) == EinExprs.nargs(expr) + + @test inds(sexpr) == inds(expr) + @test ndims(sexpr) == ndims(expr) + @test length(sexpr) == 6 + + @test size(sexpr, :i) == 2 + @test size(sexpr, :j) == 3 + @test size(sexpr) == (2, 3) + + @test select(sexpr, :i) == SizedEinExpr[sexpr, SizedEinExpr(tensor, Dict(:i => 2, :j => 3))] + @test select(sexpr, :j) == SizedEinExpr[sexpr, SizedEinExpr(tensor, Dict(:i => 2, :j => 3))] +end diff --git a/test/runtests.jl b/test/runtests.jl index ead32fd..9248f5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using EinExprs @testset "Unit tests" verbose = true begin include("EinExpr_test.jl") + include("SizedEinExpr_test.jl") include("Counters_test.jl") @testset "Optimizers" begin include("Naive_test.jl") From fbd1eb51c3f3d437da68f24452d6b1eed5bc1613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 18:38:28 +0100 Subject: [PATCH 22/28] Fix `hashyperinds` default in `Exhaustive` optimizer --- src/Optimizers/Exhaustive.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 264600d..85285ec 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -39,7 +39,7 @@ function __einexpr_exhaustive_it( outer, leader; cache = Dict{Vector{Symbol},BigInt}(), - hashyperinds = any(hyperinds(path)), + hashyperinds = !isempty(hyperinds(path)), ) where {Metric} if nargs(path) <= 2 #= mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))) =# From 43b114d9d330ccad44716082723928dfce3e5ba5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 18:39:34 +0100 Subject: [PATCH 23/28] Refactor code to use `SizedEinExpr` in Exhaustive_test.jl --- test/Exhaustive_test.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/Exhaustive_test.jl b/test/Exhaustive_test.jl index 0282562..d4954ed 100644 --- a/test/Exhaustive_test.jl +++ b/test/Exhaustive_test.jl @@ -9,14 +9,16 @@ EinExpr([:i, :h, :d]), EinExpr([:d, :g, :c]), ] + expr = EinExpr(Symbol[], tensors) + sexpr = SizedEinExpr(expr, sizedict) - path = einexpr(Exhaustive, EinExpr(Symbol[], tensors), sizedict) + path = einexpr(Exhaustive, sexpr) - @test path isa EinExpr + @test path isa SizedEinExpr - @test mapreduce(flops, +, Branches(path)) == 90 + @test mapreduce(flops, +, Branches(path)) == 92 - @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]])) + @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:i, :h], [:d, :b]])) @testset "hyperedges" begin sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, :β]) @@ -24,10 +26,10 @@ b = EinExpr([:k, :β]) c = EinExpr([:β, :l, :m]) - path = einexpr(EinExprs.Exhaustive(), sum([a, b, c], skip = [:β]), sizedict) + path = einexpr(EinExprs.Exhaustive(), SizedEinExpr(sum([a, b, c], skip = [:β]), sizedict)) @test all(∋(:β) ∘ head, branches(path)) - path = einexpr(EinExprs.Exhaustive(), sum([a, b, c], skip = Symbol[]), sizedict) + path = einexpr(EinExprs.Exhaustive(), SizedEinExpr(sum([a, b, c], skip = Symbol[]), sizedict)) @test all(∋(:β) ∘ head, branches(path)[1:end-1]) @test all(!∋(:β) ∘ head, branches(path)[end:end]) end From e9945b005e17b2834d97e313bfe790c3c71d9fea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 18:45:00 +0100 Subject: [PATCH 24/28] Refactor `Greedy` tests for `SizedEinExpr` --- src/Optimizers/Greedy.jl | 4 +--- test/Greedy_test.jl | 15 +++++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/Optimizers/Greedy.jl b/src/Optimizers/Greedy.jl index ec4982c..11c53a9 100644 --- a/src/Optimizers/Greedy.jl +++ b/src/Optimizers/Greedy.jl @@ -68,6 +68,4 @@ function einexpr(config::Greedy, path, sizedict) return path end -function einexpr(config::Greedy, path::SizedEinExpr) - return einexpr(config, path.path, path.size) -end +einexpr(config::Greedy, path::SizedEinExpr) = SizedEinExpr(einexpr(config, path.path, path.size), path.size) diff --git a/test/Greedy_test.jl b/test/Greedy_test.jl index 07d2586..7416719 100644 --- a/test/Greedy_test.jl +++ b/test/Greedy_test.jl @@ -9,20 +9,23 @@ EinExpr([:i, :h, :d]), EinExpr([:d, :g, :c]), ] + expr = sum(tensors) - path = einexpr(Greedy(), EinExpr(Symbol[], tensors), sizedict) + path = einexpr(Greedy(), SizedEinExpr(expr, sizedict)) - @test path isa EinExpr + @test path isa SizedEinExpr - @test mapreduce(flops(sizedict), +, Branches(path)) == 100 + @test mapreduce(flops, +, Branches(path)) == 100 @test all(splat(issetequal), zip(contractorder(path), [[:i, :h], [:j], [:a, :e], [:g, :c], [:f], [:b, :d]])) @testset "example: let unchanged" begin sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m]) tensors = [EinExpr([:i, :j, :k]), EinExpr([:k, :l, :m])] + expr = sum(tensors, skip = [:i, :j, :l, :m]) + sexpr = SizedEinExpr(expr, sizedict) - path = einexpr(Greedy(), EinExpr(Symbol[:i, :j, :l, :m], tensors)) + path = einexpr(Greedy(), sexpr) @test suminds(path) == [:k] end @@ -33,10 +36,10 @@ b = EinExpr([:k, :β]) c = EinExpr([:β, :l, :m]) - path = einexpr(EinExprs.Greedy(), sum([a, b, c], skip = [:β]), sizedict) + path = einexpr(EinExprs.Greedy(), SizedEinExpr(sum([a, b, c], skip = [:β]), sizedict)) @test all(∋(:β) ∘ head, branches(path)) - path = einexpr(EinExprs.Greedy(), sum([a, b, c], skip = Symbol[]), sizedict) + path = einexpr(EinExprs.Greedy(), SizedEinExpr(sum([a, b, c], skip = Symbol[]), sizedict)) @test all(∋(:β) ∘ head, branches(path)[1:end-1]) @test all(!∋(:β) ∘ head, branches(path)[end:end]) end From 383088b28517a578a09cfeb3ca518140eddb8ee5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 28 Dec 2023 11:30:44 +0100 Subject: [PATCH 25/28] Refactor `args(::SizedEinExpr)` to return `SizedEinExpr`s Use `getproperty(x, :args)` to return the `EinExpr`s underneath. --- src/Optimizers/Exhaustive.jl | 4 ++-- src/Optimizers/KaHyPar.jl | 12 ++++++------ src/SizedEinExpr.jl | 13 +++++++++++-- test/SizedEinExpr_test.jl | 3 ++- 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 85285ec..ef99a19 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -47,7 +47,7 @@ function __einexpr_exhaustive_it( return end - for (i, j) in combinations(args(path), 2) + for (i, j) in combinations(path.args, 2) !outer && isdisjoint(head(i), head(j)) && continue candidate = sum(i, j; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) @@ -57,7 +57,7 @@ function __einexpr_exhaustive_it( end new_cost >= leader[].cost && continue - new_path = SizedEinExpr(EinExpr(head(path), [candidate, filter(∉([i, j]), args(path))...]), path.size) # sum([candidate, filter(∉([i, j]), args(path))...], skip = path.head) + new_path = SizedEinExpr(EinExpr(head(path), [candidate, filter(∉([i, j]), path.args)...]), path.size) # sum([candidate, filter(∉([i, j]), args(path))...], skip = path.head) __einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader; cache, hashyperinds) end end diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index b3c040c..96dc4b4 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -6,7 +6,7 @@ using Suppressor @kwdef struct HyPar <: Optimizer parts::Int = 2 imbalance::Float32 = 0.03 - stop::Function = <=(2) ∘ length ∘ Base.Fix2(getfield, :args) + stop::Function = <=(2) ∘ length ∘ Base.Fix2(getproperty, :args) configuration::Union{Nothing,Symbol,String} = nothing edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 @@ -25,7 +25,7 @@ function EinExprs.einexpr(config::HyPar, path) # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds) - vertex_weights = map(config.vertex_scaler ∘ length, path.args) + vertex_weights = map(config.vertex_scaler ∘ length, args(path)) hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights) @@ -36,13 +36,13 @@ function EinExprs.einexpr(config::HyPar, path) configuration = config.configuration, ) - args = map(unique(partitions)) do partition + _args = map(unique(partitions)) do partition selection = partitions .== partition - count(selection) == 1 && return only(path.args[selection]) + count(selection) == 1 && return only(args(path)[selection]) - expr = sum(path.args[selection], skip = path.head) + expr = sum(args(path)[selection], skip = path.head) einexpr(config, expr) end - return EinExpr(path.head, args) + return sum(_args, skip = path.head) end diff --git a/src/SizedEinExpr.jl b/src/SizedEinExpr.jl index 801717a..c63e555 100644 --- a/src/SizedEinExpr.jl +++ b/src/SizedEinExpr.jl @@ -13,7 +13,16 @@ end EinExpr(path::Vector{Symbol}, size::Dict{Symbol}) = SizedEinExpr(EinExpr(path), size) head(sexpr::SizedEinExpr) = head(sexpr.path) -args(sexpr::SizedEinExpr) = sexpr.path.args # map(Base.Fix2(SizedEinExpr, sexpr.size), sexpr.path.args) + +""" + args(sexpr::SizedEinExpr) + +# Note + +Unlike `args(::EinExpr)`, this function returns `SizedEinExpr` objects. +""" +args(sexpr::SizedEinExpr) = map(Base.Fix2(SizedEinExpr, sexpr.size), sexpr.path.args) # sexpr.path.args + nargs(sexpr::SizedEinExpr) = nargs(sexpr.path) inds(sexpr::SizedEinExpr) = inds(sexpr.path) @@ -59,7 +68,7 @@ Base.IteratorEltype(::Type{<:TreeIterator{SizedEinExpr}}) = Base.HasEltype() Base.eltype(::Type{<:TreeIterator{SizedEinExpr}}) = SizedEinExpr # AbstractTrees interface and traits -AbstractTrees.children(sexpr::SizedEinExpr) = map(Base.Fix2(SizedEinExpr, sexpr.size), args(sexpr)) +AbstractTrees.children(sexpr::SizedEinExpr) = args(sexpr) AbstractTrees.childtype(::Type{SizedEinExpr}) = SizedEinExpr AbstractTrees.childrentype(::Type{SizedEinExpr}) = Vector{SizedEinExpr} AbstractTrees.childstatetype(::Type{SizedEinExpr}) = Int diff --git a/test/SizedEinExpr_test.jl b/test/SizedEinExpr_test.jl index dd0fcf7..3dccf4d 100644 --- a/test/SizedEinExpr_test.jl +++ b/test/SizedEinExpr_test.jl @@ -6,7 +6,8 @@ sexpr = SizedEinExpr(expr, Dict(:i => 2, :j => 3)) @test head(sexpr) === head(expr) === sexpr.head - @test args(sexpr) === args(expr) === sexpr.args + @test args(expr) === sexpr.args + @test args(sexpr) == map(Base.Fix2(SizedEinExpr, Dict(:i => 2, :j => 3)), args(expr)) @test EinExprs.nargs(sexpr) == EinExprs.nargs(expr) @test inds(sexpr) == inds(expr) From f1e7d3110eecfb69cd29581f184332d857ba71bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 28 Dec 2023 11:31:29 +0100 Subject: [PATCH 26/28] Fix KaHyPar tests --- test/KaHyPar_test.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl index f4cdc84..0794b45 100644 --- a/test/KaHyPar_test.jl +++ b/test/KaHyPar_test.jl @@ -9,10 +9,11 @@ EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])), EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), ] + sexpr = sum(tensors) - path = einexpr(HyPar(imbalance=0.42), EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance = 0.42), sexpr) - @test path isa EinExpr + @test path isa SizedEinExpr @test mapreduce(flops, +, Branches(path)) == 108 end @@ -40,10 +41,11 @@ EinExpr([:A, :W], Dict(:A => 6, :W => 6)), EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)), ] + sexpr = sum(tensors) - path = einexpr(HyPar(imbalance=0.45), EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance = 0.45), sexpr) - @test path isa EinExpr + @test path isa SizedEinExpr @test mapreduce(flops, +, Branches(path)) == 19099592 end From 49448aea8f5b4aa0dbc3e9effc856fa04ab4de26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 28 Dec 2023 16:19:30 +0100 Subject: [PATCH 27/28] Refactor `selectdim` and slicing methods for `SizedEinExpr` --- src/Slicing.jl | 56 +++++++++++++++++++++----------------------- test/Slicing_test.jl | 5 ++-- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/Slicing.jl b/src/Slicing.jl index 82e6b07..281ecf3 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -13,34 +13,33 @@ Project `index` to dimension `i` in a EinExpr. This is equivalent to tensor cutt See also: [`view`](@ref). """ -function Base.selectdim(path::EinExpr, index::Symbol, i) +Base.selectdim(path::EinExpr, ::Symbol, i) = path + +function Base.selectdim(path::EinExpr, index::Symbol, i::Integer) path = deepcopy(path) - for leave in Iterators.filter(∋(index) ∘ head, Leaves(path)) - leave.size[index] = length(i) + for expr in PreOrderDFS(path) + filter!(!=(index), expr.head) end return path end -function Base.selectdim(path::EinExpr, index::Symbol, _::Integer) - path = deepcopy(path) +function Base.selectdim(sexpr::SizedEinExpr, index::Symbol, i) + path = selectdim(sexpr.path, index, i) - index ∈ head(path) && (path = EinExpr(filter(!=(index), path.head), path.args)) - - for branch in Branches(path) - for arg in Iterators.filter(∋(index) ∘ head, branch.args) - replace!( - branch.args, - arg => EinExpr( - filter(!=(index), arg.head), - isempty(arg.args) ? filter(p -> p.first != index, arg.size) : arg.args, - ), - ) - end - end + size = copy(sexpr.size) + size[index] = length(i) - return path + return SizedEinExpr(path, size) +end + +function Base.selectdim(sexpr::SizedEinExpr, index::Symbol, i::Integer) + path = selectdim(sexpr.path, index, i) + + size = filter(!=(index) ∘ first, sexpr.size) + + return SizedEinExpr(path, size) end """ @@ -88,8 +87,7 @@ Reimplementation based on [`contengra`](https://github.com/jcmgray/cotengra)'s ` """ function findslices( scorer, - path::EinExpr, - sizedict; + path; size = nothing, overhead = nothing, slices = nothing, @@ -101,8 +99,8 @@ function findslices( candidates = Set(setdiff(mapreduce(head, ∪, PostOrderDFS(path)), skip)) solution = Set{Symbol}() - current = (; slices = 1, size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(path)), overhead = 1.0) - original_flops = mapreduce(flops(sizedict), +, Branches(path; inverse = true)) + current = (; slices = 1, size = maximum(length, PostOrderDFS(path)), overhead = 1.0) + original_flops = mapreduce(flops, +, Branches(path; inverse = true)) sliced_path = path while !isempty(candidates) @@ -114,15 +112,15 @@ function findslices( sliced_path = selectdim(sliced_path, winner, 1) cur_overhead = - prod(i -> sizedict[i], [solution..., winner]) * - mapreduce(flops(sizedict), +, Branches(sliced_path; inverse = true)) / original_flops + prod(i -> Base.size(path, i), [solution..., winner]) * + mapreduce(flops, +, Branches(sliced_path; inverse = true)) / original_flops !isnothing(overhead) && cur_overhead > overhead && break push!(solution, winner) current = (; - slices = current.slices * Base.Fix2(length, sizedict)(path, winner), - size = maximum(Base.Fix2(length, sizedict), PostOrderDFS(sliced_path)), + slices = current.slices * Base.size(path, winner), + size = maximum(length, PostOrderDFS(sliced_path)), overhead = cur_overhead, ) @@ -150,7 +148,7 @@ function (cb::FlopsScorer)(path, index) slice = selectdim(path, index, 1) flops_reduction = mapreduce(flops, +, PostOrderDFS(path)) - mapreduce(flops, +, PostOrderDFS(slice)) - write_reduction = mapreduce(prod ∘ size, +, PostOrderDFS(path)) - mapreduce(prod ∘ size, +, PostOrderDFS(slice)) + write_reduction = mapreduce(length, +, PostOrderDFS(path)) - mapreduce(length, +, PostOrderDFS(slice)) log(flops_reduction + write_reduction * cb.weight + 1) end @@ -170,7 +168,7 @@ function (cb::SizeScorer)(path, index) slice = selectdim(path, index, 1) flops_reduction = mapreduce(flops, +, PostOrderDFS(path)) - mapreduce(flops, +, PostOrderDFS(slice)) - write_reduction = mapreduce(prod ∘ size, +, PostOrderDFS(path)) - mapreduce(prod ∘ size, +, PostOrderDFS(slice)) + write_reduction = mapreduce(length, +, PostOrderDFS(path)) - mapreduce(length, +, PostOrderDFS(slice)) log(write_reduction + flops_reduction * cb.weight + 1) end diff --git a/test/Slicing_test.jl b/test/Slicing_test.jl index 6e9f6ef..5980863 100644 --- a/test/Slicing_test.jl +++ b/test/Slicing_test.jl @@ -65,8 +65,9 @@ EinExpr((:p, :k)), ], ) + sexpr = SizedEinExpr(expr, sizes) - cuttings = findslices(FlopsScorer(), expr, slices = 1000) + cuttings = findslices(FlopsScorer(), sexpr, slices = 1000) - @test prod(i -> sizedict[i], cuttings) >= 1000 + @test prod(i -> sizes[i], cuttings) >= 1000 end From 47c64e5e2eecc278d8716cb4872aeadffd67e5ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 28 Dec 2023 22:12:41 +0100 Subject: [PATCH 28/28] Update `Makie.plot` functions to accept `SizedEinExpr` instead of `EinExpr` --- ext/EinExprsMakieExt.jl | 10 +++++----- test/ext/Makie_test.jl | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ext/EinExprsMakieExt.jl b/ext/EinExprsMakieExt.jl index 7d3a36c..79e52ed 100644 --- a/ext/EinExprsMakieExt.jl +++ b/ext/EinExprsMakieExt.jl @@ -15,13 +15,13 @@ const MAX_EDGE_WIDTH = 10.0 const MAX_ARROW_SIZE = 35.0 const MAX_NODE_SIZE = 40.0 -function Makie.plot(path::EinExpr; kwargs...) +function Makie.plot(path::SizedEinExpr; kwargs...) f = Figure() ax, p = plot!(f[1, 1], path; kwargs...) return Makie.FigureAxisPlot(f, ax, p) end -function Makie.plot!(f::Union{Figure,GridPosition}, path::EinExpr; kwargs...) +function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...) ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3 Axis3(f[1, 1]) else @@ -65,13 +65,13 @@ end # TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap function Makie.plot!( ax::Union{Axis,Axis3}, - path::EinExpr; + path::SizedEinExpr; colormap = to_colormap(:viridis)[begin:end-10], inds = false, kwargs..., ) - handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path))) - graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path) for from in to.args]) + handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path))) + graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args]) lin_size = length.(PostOrderDFS(path))[1:end-1] lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path))) diff --git a/test/ext/Makie_test.jl b/test/ext/Makie_test.jl index 74cccdb..98543a7 100644 --- a/test/ext/Makie_test.jl +++ b/test/ext/Makie_test.jl @@ -34,8 +34,8 @@ EinExpr([:g, :q], filter(p -> p.first ∈ [:g, :q], sizes)), EinExpr([:d, :b, :o], filter(p -> p.first ∈ [:d, :b, :o], sizes)), ] - - path = einexpr(Exhaustive(), EinExpr([:p, :j], tensors)) + expr = sum(tensors, skip = [:p, :j]) + path = einexpr(Exhaustive(), expr) @testset "plot!" begin f = Figure()