Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement KaHyPar optimizer #35

Merged
merged 15 commits into from
Oct 30, 2023
Merged
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/EinExprs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ export findslices, FlopsScorer, SizeScorer

include("Optimizers/Optimizers.jl")
export Optimizer, einexpr
export Exhaustive, Greedy
export Exhaustive, Greedy, HyPar

end
48 changes: 48 additions & 0 deletions src/Optimizers/KaHyPar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using AbstractTrees
using SparseArrays
using KaHyPar

@kwdef struct HyPar <: Optimizer
parts = 2
imbalance = 0.03
cutoff = 2
configuration::Union{Nothing,Symbol,String} = nothing
end

function EinExprs.einexpr(config::HyPar, path)
inds = mapreduce(head, , path.args)
indexmap = Dict(Iterators.map(splat(Pair) reverse, enumerate(inds)))

num_columns = maximum(values(indexmap))
num_rows = length(path.args)
incidence_matrix = spzeros(Int, num_rows, num_columns)

# Iterate through each tensor and its associated indices, and update the incidence matrix.
for (i, tensor) in enumerate(path.args)
tensor_indices = [i] # Current tensor is represented as a row in the matrix.
edge_indices = [indexmap[idx] for idx in head(tensor)] # Map indices via 'indexmap'.

# Create a subview for the current tensor and associated hyperedges.
incidence_subview = view(incidence_matrix, tensor_indices, edge_indices)
incidence_subview .= 1 # Update the subview directly. This step modifies the original sparse matrix.
end

# NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
edge_weights = map(Base.Fix1(size, path), inds)
vertex_weights = ones(Int, length(path.args))

hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)

# stop on cutoff
hypergraph.n_vertices <= config.cutoff && return path

partitions =
KaHyPar.partition(hypergraph, config.parts; imbalance = config.imbalance, configuration = config.configuration)

args = map(unique(partitions)) do partition
expr = sum(path.args[partitions.==partition], skip = path.head)
einexpr(config, expr)
end

return EinExpr(path.head, args)
end
1 change: 1 addition & 0 deletions src/Optimizers/Optimizers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ einexpr(config::Optimizer, expr) = einexpr(config, expr)
include("Naive.jl")
include("Exhaustive.jl")
include("Greedy.jl")
include("KaHyPar.jl")
50 changes: 50 additions & 0 deletions test/KaHyPar_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
@testset "KaHyPar" begin
@testset begin
tensors = [
EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])),
EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])),
EinExpr([:j], Dict(i => 2 for i in [:j])),
EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])),
EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])),
EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])),
EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
]

path = einexpr(HyPar, EinExpr(Symbol[], tensors))

@test path isa EinExpr

@test mapreduce(flops, +, Branches(path)) == 108
end

@testset begin
tensors = [
EinExpr([:F, :P, :V], Dict(:P => 5, :F => 8, :V => 5)),
EinExpr([:T, :Y, :V, :X, :N, :B], Dict(:T => 5, :N => 2, :B => 5, :V => 5, :Y => 7, :X => 8)),
EinExpr([:L, :K, :S], Dict(:K => 8, :L => 7, :S => 5)),
EinExpr([:M, :J, :Q, :O], Dict(:M => 5, :J => 7, :Q => 7, :O => 6)),
EinExpr([:c], Dict(:c => 2)),
EinExpr([:I, :U, :E], Dict(:U => 8, :I => 5, :E => 6)),
EinExpr([:N, :C], Dict(:N => 2, :C => 4)),
EinExpr([:a, :K], Dict(:a => 3, :K => 8)),
EinExpr([:d, :E, :M], Dict(:M => 5, :d => 6, :E => 6)),
EinExpr([:B, :b, :D, :H, :L], Dict(:b => 5, :H => 7, :D => 8, :B => 5, :L => 7)),
EinExpr([:c, :P, :X, :Q], Dict(:P => 5, :Q => 7, :c => 2, :X => 8)),
EinExpr([:G], Dict(:G => 6)),
EinExpr([:Z, :W], Dict(:Z => 9, :W => 6)),
EinExpr([:Y, :H, :S], Dict(:H => 7, :S => 5, :Y => 7)),
EinExpr([:O, :F, :b, :I], Dict(:b => 5, :I => 5, :F => 8, :O => 6)),
EinExpr([:A, :J, :T, :G], Dict(:T => 5, :A => 6, :J => 7, :G => 6)),
EinExpr([:Z, :D, :R], Dict(:Z => 9, :R => 8, :D => 8)),
EinExpr([:R, :U], Dict(:U => 8, :R => 8)),
EinExpr([:A, :W], Dict(:A => 6, :W => 6)),
EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)),
]

path = einexpr(HyPar, EinExpr(Symbol[], tensors))

@test path isa EinExpr

@test mapreduce(flops, +, Branches(path)) == 31653164
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using EinExprs
include("Naive_test.jl")
include("Exhaustive_test.jl")
include("Greedy_test.jl")
include("KaHyPar_test.jl")
end
include("Slicing_test.jl")
end
Expand Down
Loading