Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple size dictionary from EinExprs #45

Merged
merged 28 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b55d013
Decouple `size` dictionary from `EinExpr` struct
mofeing Oct 21, 2023
8d2d2b0
Fix only-head `EinExpr` constructor
mofeing Oct 21, 2023
03573e1
Microoptimize candidate cost caching in `Exhaustive` optimizer
mofeing Oct 21, 2023
8350351
Microoptimize resource counting by changing to pre-order in `Exhausti…
mofeing Oct 21, 2023
5f44cce
Move resource counting microoptimization to `Branches`
mofeing Oct 21, 2023
6896fba
Fix resource counting in `findslices`
mofeing Oct 21, 2023
fce47db
Curry resource counters on `Dict` argument
mofeing Oct 21, 2023
cbcb274
Microoptimize `mapreduce` call in `flops`
mofeing Oct 23, 2023
f556060
Microoptimize type-stability in `Exhaustive` optimizer
mofeing Oct 23, 2023
102e947
Fix recursive problem in generic `einexpr` call
mofeing Oct 23, 2023
68bfd82
Remove artifact `ImmutableVector`
mofeing Dec 4, 2023
1a7b0c2
Refactor code to use SizedEinExpr for size-aware expressions
mofeing Dec 19, 2023
fbd9104
Fix `removedsize` on `SizedEinExpr`
mofeing Dec 19, 2023
513f872
Fix `Greedy` optimizer on `SizedEinExpr`
mofeing Dec 19, 2023
c08dba7
Specialize on `metric` function to avoid recursive dynamic-dispatch
mofeing Dec 24, 2023
b975efd
Speedup `sum` function on `EinExpr`s
mofeing Dec 24, 2023
72db773
Format comment
mofeing Dec 24, 2023
5da6cd6
Refactor code to use sum function in make.jl
mofeing Dec 24, 2023
e5826b2
Remove comment
mofeing Dec 25, 2023
7b77076
Add size indexing to SizedEinExpr
mofeing Dec 27, 2023
af17760
Fix `EinExprs` tests
mofeing Dec 27, 2023
fbd1eb5
Fix `hashyperinds` default in `Exhaustive` optimizer
mofeing Dec 27, 2023
43b114d
Refactor code to use `SizedEinExpr` in Exhaustive_test.jl
mofeing Dec 27, 2023
e9945b0
Refactor `Greedy` tests for `SizedEinExpr`
mofeing Dec 27, 2023
383088b
Refactor `args(::SizedEinExpr)` to return `SizedEinExpr`s
mofeing Dec 28, 2023
f1e7d31
Fix KaHyPar tests
mofeing Dec 28, 2023
49448ae
Refactor `selectdim` and slicing methods for `SizedEinExpr`
mofeing Dec 28, 2023
47c64e5
Update `Makie.plot` functions to accept `SizedEinExpr` instead of `Ei…
mofeing Dec 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions benchmark/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@ suite["greedy"] = BenchmarkGroup([])
suite["kahypar"] = BenchmarkGroup([])

# BENCHMARK 1
expr = EinExpr(
Symbol[],
[
EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])),
EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])),
EinExpr([:j], Dict(i => 2 for i in [:j])),
EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])),
EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])),
EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])),
EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
],
)
expr = sum([
EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])),
EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])),
EinExpr([:j], Dict(i => 2 for i in [:j])),
EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])),
EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])),
EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])),
EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
])

suite["naive"][1] = @benchmarkable einexpr(EinExprs.Naive(), $expr)
suite["exhaustive"][1] = @benchmarkable einexpr(Exhaustive(), $expr)
Expand All @@ -41,7 +38,7 @@ D = EinExpr([:c, :h, :d, :i], Dict(:c => 2, :h => 2, :d => 2, :i => 2))
E = EinExpr([:f, :i, :g, :j], Dict(:f => 2, :i => 2, :g => 2, :j => 2))
F = EinExpr([:B, :h, :k, :l], Dict(:B => 2, :h => 2, :k => 2, :l => 2))
G = EinExpr([:j, :k, :l, :D], Dict(:j => 2, :k => 2, :l => 2, :D => 2))
expr = EinExpr([:A, :B, :C, :D], [A, B, C, D, E, F, G])
expr = sum([A, B, C, D, E, F, G], skip = [:A, :B, :C, :D])

suite["naive"][2] = @benchmarkable einexpr(EinExprs.Naive(), $expr)
suite["exhaustive"][2] = @benchmarkable einexpr(Exhaustive(), $expr)
Expand Down
10 changes: 5 additions & 5 deletions ext/EinExprsMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ const MAX_EDGE_WIDTH = 10.0
const MAX_ARROW_SIZE = 35.0
const MAX_NODE_SIZE = 40.0

function Makie.plot(path::EinExpr; kwargs...)
function Makie.plot(path::SizedEinExpr; kwargs...)
f = Figure()
ax, p = plot!(f[1, 1], path; kwargs...)
return Makie.FigureAxisPlot(f, ax, p)
end

function Makie.plot!(f::Union{Figure,GridPosition}, path::EinExpr; kwargs...)
function Makie.plot!(f::Union{Figure,GridPosition}, path::SizedEinExpr; kwargs...)
ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
Axis3(f[1, 1])
else
Expand Down Expand Up @@ -65,13 +65,13 @@ end
# TODO replace `to_colormap(:viridis)[begin:end-10]` with a custom colormap
function Makie.plot!(
ax::Union{Axis,Axis3},
path::EinExpr;
path::SizedEinExpr;
colormap = to_colormap(:viridis)[begin:end-10],
inds = false,
kwargs...,
)
handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path)))
graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path) for from in to.args])
handles = IdDict(obj => i for (i, obj) in enumerate(PostOrderDFS(path.path)))
graph = SimpleDiGraph([Edge(handles[from], handles[to]) for to in Branches(path.path) for from in to.args])

lin_size = length.(PostOrderDFS(path))[1:end-1]
lin_flops = map(max, Iterators.repeated(1), Iterators.map(flops, PostOrderDFS(path)))
Expand Down
26 changes: 22 additions & 4 deletions src/Counters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,41 @@

Count the number of mathematical operations that will be performed by the contraction of the root of the `path` tree.
"""
flops(expr::EinExpr) =
if length(expr.args) == 0 || length(expr.args) == 1 && isempty(suminds(expr))
flops(sexpr::SizedEinExpr) =
if nargs(sexpr) == 0 || nargs(sexpr) == 1 && isempty(suminds(sexpr))
0
else
mapreduce(Base.Fix1(size, expr), *, Iterators.flatten((head(expr), suminds(expr))), init = one(BigInt))
mapreduce(
Base.Fix1(getindex, sexpr.size),
*,
Iterators.flatten((head(sexpr), suminds(sexpr))),
init = one(BigInt),
)
end

flops(expr::EinExpr, size) = flops(SizedEinExpr(expr, size))

"""
removedsize(path::EinExpr)

Count the amount of memory that will be freed after performing the contraction of the root of the `path` tree.
"""
removedsize(expr::EinExpr) = mapreduce(prod ∘ size, +, expr.args) - prod(size(expr))
removedsize(sexpr::SizedEinExpr) = -length(sexpr) + mapreduce(+, sexpr.args) do arg
length(SizedEinExpr(arg, sexpr.size))
end

removedsize(expr::EinExpr, size) = removedsize(SizedEinExpr(expr, size))

"""
removedrank(path::EinExpr)

Count the rank reduction after performing the contraction of the root of the `path` tree.
"""
# Largest rank among the arguments of the root minus the rank of the root itself.
removedrank(expr::EinExpr) = mapreduce(ndims, max, expr.args) - ndims(expr)

# The second argument is ignored: the rank is computed from `expr` alone.
removedrank(expr::EinExpr, _) = removedrank(expr)

# Unwrap the size-aware expression and count on the plain `EinExpr` it carries.
# The one-argument method is needed by callers that invoke the metric directly on a
# `SizedEinExpr` (the `Exhaustive` optimizer calls `Metric(SizedEinExpr(candidate, path.size))`);
# without it that call raises a `MethodError` when the metric is `removedrank`.
removedrank(sexpr::SizedEinExpr) = removedrank(sexpr.path)

# Kept for backward compatibility; the second argument is ignored.
removedrank(sexpr::SizedEinExpr, _) = removedrank(sexpr)

Check warning on line 38 in src/Counters.jl

View check run for this annotation

Codecov / codecov/patch

src/Counters.jl#L38

Added line #L38 was not covered by tests

# Curried forms: calling a counter with only a size dictionary returns a one-argument
# callable that applies the counter with that dictionary fixed as its second argument.
flops(sizedict::Dict{Symbol}) = Base.Fix2(flops, sizedict)
removedsize(sizedict::Dict{Symbol}) = Base.Fix2(removedsize, sizedict)

# For `removedrank` the dictionary is dropped and the function itself is returned.
removedrank(::Dict) = removedrank

Check warning on line 43 in src/Counters.jl

View check run for this annotation

Codecov / codecov/patch

src/Counters.jl#L43

Added line #L43 was not covered by tests
62 changes: 34 additions & 28 deletions src/EinExpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,18 @@
using DataStructures: DefaultDict
using AbstractTrees

struct EinExpr
Base.@kwdef struct EinExpr
head::Vector{Symbol}
args::Vector{EinExpr}
size::Dict{Symbol,Int}

# TODO checks: same dim for index, valid indices
EinExpr(head, args) = new(head, args, Dict{Symbol,EinExpr}())

function EinExpr(head::AbstractVector{Symbol}, size::AbstractDict{Symbol,Int})
head ⊆ keys(size) || throw(ArgumentError("Missing sizes for indices $(setdiff(head, keys(size)))"))
new(head, EinExpr[], size)
end
args::Vector{EinExpr} = EinExpr[]
end

EinExpr(head) = EinExpr(head, EinExpr[])
EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}) = EinExpr(head, map(EinExpr, args))

Check warning on line 11 in src/EinExpr.jl

View check run for this annotation

Codecov / codecov/patch

src/EinExpr.jl#L11

Added line #L11 was not covered by tests

EinExpr(head::NTuple, args) = EinExpr(collect(head), args)
EinExpr(head, args::NTuple) = EinExpr(head, collect(args))
EinExpr(head::NTuple, args::NTuple) = EinExpr(collect(head), collect(args))

function EinExpr(head, args::AbstractVecOrTuple{<:AbstractVecOrTuple{Symbol}}, sizes)
args = map(args) do arg
sizedict = filter(∈(arg) ∘ first, sizes)
EinExpr(arg, sizedict)
end
EinExpr(head, args)
end

"""
head(path::EinExpr)

Expand All @@ -46,6 +32,8 @@
"""
args(path::EinExpr) = path.args

nargs(path::EinExpr) = length(path.args)

"""
inds(path)

Expand Down Expand Up @@ -100,11 +88,8 @@

Return the size of the resulting tensor from contracting `path`. If `index` is specified, return the size of such index.
"""
Base.size(path::EinExpr) = (size(path, i) for i in head(path)) |> splat(tuple)
Base.size(path::EinExpr, i::Symbol) =
Iterators.filter(∋(i) ∘ head, Leaves(path)) |> first |> Base.Fix2(getproperty, :size) |> Base.Fix2(getindex, i)

Base.length(path::EinExpr) = (prod ∘ size)(path)
Base.size(path::EinExpr, sizedict) = (sizedict[i] for i in head(path)) |> splat(tuple)
Base.length(path::EinExpr, sizedict) = (prod ∘ size)(path, sizedict)

"""
collapse!(path::EinExpr)
Expand Down Expand Up @@ -241,24 +226,45 @@
function Base.sum(args::Vector{EinExpr}; skip = Symbol[])
_head = Symbol[]
_counts = Int[]

for arg in args
for index in head(arg)
i = findfirst(Base.Fix1(===, index), _head)
if isnothing(i)
push!(_head, index)
push!(_counts, 1)
else
_counts[i] += 1
@inbounds _counts[i] += 1
end
end
end

_head = map(first, Iterators.filter(zip(_head, _counts)) do (index, count)
count == 1 || index ∈ skip
end)
# NOTE `map` with `Iterators.filter` triggers repeated heap growth; allocating once and deleting in place is faster
for i in Iterators.reverse(eachindex(_head, _counts))
(_counts[i] == 1 || _head[i] ∈ skip) && continue
deleteat!(_head, i)
end

EinExpr(_head, args)
end

"""
    sum(a::EinExpr, b::EinExpr; skip = Symbol[])

Create the `EinExpr` whose arguments are `a` and `b`.

The head of the result starts as a copy of `head(a)`. Each index of `head(b)` is then
processed in order: an index not yet present is appended, while an index already present
is removed unless it is listed in `skip`, in which case it is kept.
"""
function Base.sum(a::EinExpr, b::EinExpr; skip = Symbol[])
    kept = copy(head(a))

    for ind in head(b)
        pos = findfirst(other -> ind === other, kept)
        if pos === nothing
            # index only seen in `b` so far: it stays open
            push!(kept, ind)
        elseif !(ind ∈ skip)
            # index already present and not protected by `skip`: drop it from the head
            deleteat!(kept, pos)
        end
    end

    return EinExpr(kept, [a, b])
end

function Base.string(path::EinExpr; recursive::Bool = false)
!recursive && return "$(join(map(x -> string.(head(x)) |> join, args(path)), ","))->$(string.(head(path)) |> join)"
map(string, Branches(path))
Expand Down
3 changes: 3 additions & 0 deletions src/EinExprs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ export EinExpr
export head, args, inds, hyperinds, suminds, parsuminds, collapse!, contractorder, select, neighbours
export Branches, branches, leaves

include("SizedEinExpr.jl")
export SizedEinExpr

include("Counters.jl")
export flops, removedsize

Expand Down
34 changes: 20 additions & 14 deletions src/Optimizers/Exhaustive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,41 @@ The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and ``
end

function einexpr(config::Exhaustive, path; cost = BigInt(0))
leader = Ref{NamedTuple{(:path, :cost),Tuple{EinExpr,BigInt}}}((;
# metric = Base.Fix2(config.metric, path.size)
leader = Ref((;
path = einexpr(Naive(), path),
cost = mapreduce(config.metric, +, Branches(einexpr(Naive(), path), inverse = true), init = BigInt(0))::BigInt,
))
cache = Dict{Vector{Symbol},BigInt}()
__einexpr_exhaustive_it(path, cost, config.metric, config.outer, leader, cache)
__einexpr_exhaustive_it(path, cost, Val(config.metric), config.outer, leader)
return leader[].path
end

function __einexpr_exhaustive_it(path, cost, metric, outer, leader, cache)
if length(path.args) == 1
# remove identity einsum (i.e. "i...->i...")
path = path.args[1]

leader[] = (; path, cost = mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))::BigInt)
function __einexpr_exhaustive_it(
path,
cost,
@specialize(metric::Val{Metric}),
outer,
leader;
cache = Dict{Vector{Symbol},BigInt}(),
hashyperinds = !isempty(hyperinds(path)),
) where {Metric}
if nargs(path) <= 2
#= mapreduce(metric, +, Branches(path, inverse = true), init = BigInt(0))) =#
leader[] = (; path = path, cost = cost)
return
end

for (i, j) in combinations(args(path), 2)
for (i, j) in combinations(path.args, 2)
!outer && isdisjoint(head(i), head(j)) && continue
candidate = sum([i, j], skip = path.head ∪ hyperinds(path))
candidate = sum(i, j; skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head)

# prune paths based on metric
new_cost = cost + get!(cache, head(candidate)) do
metric(candidate)
Metric(SizedEinExpr(candidate, path.size))
end
new_cost >= leader[].cost && continue

new_path = EinExpr(head(path), [candidate, filter(∉([i, j]), args(path))...])
__einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader, cache)
new_path = SizedEinExpr(EinExpr(head(path), [candidate, filter(∉([i, j]), path.args)...]), path.size) # sum([candidate, filter(∉([i, j]), args(path))...], skip = path.head)
__einexpr_exhaustive_it(new_path, new_cost, metric, outer, leader; cache, hashyperinds)
end
end
12 changes: 8 additions & 4 deletions src/Optimizers/Greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ The implementation uses a binary heaptree to sort candidate pairwise tensor cont
outer::Bool = false
end

function einexpr(config::Greedy, path)
function einexpr(config::Greedy, path, sizedict)
metric = config.metric(sizedict)

# generate initial candidate contractions
queue = MutableBinaryHeap{Tuple{Float64,EinExpr}}(
Base.By(first, Base.Reverse),
Expand All @@ -36,12 +38,12 @@ function einexpr(config::Greedy, path)
) do (a, b)
# TODO don't consider outer products
candidate = sum([a, b], skip = path.head ∪ hyperinds(path))
weight = config.metric(candidate)
weight = metric(candidate)
(weight, candidate)
end,
)

while length(path.args) > 2 && length(queue) > 1
while nargs(path) > 2 && length(queue) > 1
# choose winner
_, winner = config.choose(queue)

Expand All @@ -55,7 +57,7 @@ function einexpr(config::Greedy, path)
for other in Iterators.filter(other -> config.outer || !isdisjoint(winner.head, other.head), path.args)
# TODO don't consider outer products
candidate = sum([winner, other], skip = path.head ∪ hyperinds(path))
weight = config.metric(candidate)
weight = metric(candidate)
push!(queue, (weight, candidate))
end

Expand All @@ -65,3 +67,5 @@ function einexpr(config::Greedy, path)

return path
end

einexpr(config::Greedy, path::SizedEinExpr) = SizedEinExpr(einexpr(config, path.path, path.size), path.size)
12 changes: 6 additions & 6 deletions src/Optimizers/KaHyPar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Suppressor
@kwdef struct HyPar <: Optimizer
parts::Int = 2
imbalance::Float32 = 0.03
stop::Function = <=(2) ∘ length ∘ Base.Fix2(getfield, :args)
stop::Function = <=(2) ∘ length ∘ Base.Fix2(getproperty, :args)
configuration::Union{Nothing,Symbol,String} = nothing
edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2
vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2
Expand All @@ -25,7 +25,7 @@ function EinExprs.einexpr(config::HyPar, path)

# NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds)
vertex_weights = map(config.vertex_scaler ∘ length, path.args)
vertex_weights = map(config.vertex_scaler ∘ length, args(path))

hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)

Expand All @@ -36,13 +36,13 @@ function EinExprs.einexpr(config::HyPar, path)
configuration = config.configuration,
)

args = map(unique(partitions)) do partition
_args = map(unique(partitions)) do partition
selection = partitions .== partition
count(selection) == 1 && return only(path.args[selection])
count(selection) == 1 && return only(args(path)[selection])

expr = sum(path.args[selection], skip = path.head)
expr = sum(args(path)[selection], skip = path.head)
einexpr(config, expr)
end

return EinExpr(path.head, args)
return sum(_args, skip = path.head)
end
12 changes: 9 additions & 3 deletions src/Optimizers/Naive.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
using AbstractTrees

struct Naive <: Optimizer end

einexpr(::Naive, path, _) = einexpr(Naive(), path)

function einexpr(::Naive, path)
hist = Dict(i => count(∋(i) ∘ head, path.args) for i in hyperinds(path))
hist = Dict(i => count(∋(i) ∘ head, args(path)) for i in hyperinds(path))

foldl(path.args) do a, b
expr = sum([a, b], skip = path.head ∪ collect(keys(hist)))
foldl(args(path)) do a, b
expr = sum([a, b], skip = head(path) ∪ collect(keys(hist)))

for i in Iterators.filter(∈(keys(hist)), ∩(head(a), head(b)))
hist[i] -= 1
Expand All @@ -14,3 +18,5 @@ function einexpr(::Naive, path)
return expr
end
end

einexpr(::Naive, path::SizedEinExpr) = SizedEinExpr(einexpr(Naive(), path.path), path.size)
Loading
Loading