Skip to content

Commit

Permalink
Implement KaHyPar optimizer (#35)
Browse files Browse the repository at this point in the history
Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2023
1 parent d0cfcd5 commit d81db6f
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/EinExprs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ export findslices, FlopsScorer, SizeScorer

include("Optimizers/Optimizers.jl")
export Optimizer, einexpr
export Exhaustive, Greedy
export Exhaustive, Greedy, HyPar

end
48 changes: 48 additions & 0 deletions src/Optimizers/KaHyPar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using AbstractTrees
using SparseArrays
using KaHyPar

@kwdef struct HyPar <: Optimizer
parts = 2
imbalance = 0.03
cutoff = 2
configuration::Union{Nothing,Symbol,String} = nothing
end

function EinExprs.einexpr(config::HyPar, path)
inds = mapreduce(head, , path.args)
indexmap = Dict(Iterators.map(splat(Pair) reverse, enumerate(inds)))

num_columns = maximum(values(indexmap))
num_rows = length(path.args)
incidence_matrix = spzeros(Int, num_rows, num_columns)

# Iterate through each tensor and its associated indices, and update the incidence matrix.
for (i, tensor) in enumerate(path.args)
tensor_indices = [i] # Current tensor is represented as a row in the matrix.
edge_indices = [indexmap[idx] for idx in head(tensor)] # Map indices via 'indexmap'.

# Create a subview for the current tensor and associated hyperedges.
incidence_subview = view(incidence_matrix, tensor_indices, edge_indices)
incidence_subview .= 1 # Update the subview directly. This step modifies the original sparse matrix.
end

# NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
edge_weights = map(Base.Fix1(size, path), inds)
vertex_weights = ones(Int, length(path.args))

hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)

# stop on cutoff
hypergraph.n_vertices <= config.cutoff && return path

partitions =
KaHyPar.partition(hypergraph, config.parts; imbalance = config.imbalance, configuration = config.configuration)

args = map(unique(partitions)) do partition
expr = sum(path.args[partitions.==partition], skip = path.head)
einexpr(config, expr)
end

return EinExpr(path.head, args)
end
1 change: 1 addition & 0 deletions src/Optimizers/Optimizers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ einexpr(config::Optimizer, expr) = einexpr(config, expr)
include("Naive.jl")
include("Exhaustive.jl")
include("Greedy.jl")
include("KaHyPar.jl")
50 changes: 50 additions & 0 deletions test/KaHyPar_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
@testset "KaHyPar" begin
@testset begin
tensors = [
EinExpr([:j, :b, :i, :h], Dict(i => 2 for i in [:j, :b, :i, :h])),
EinExpr([:a, :c, :e, :f], Dict(i => 2 for i in [:a, :c, :e, :f])),
EinExpr([:j], Dict(i => 2 for i in [:j])),
EinExpr([:e, :a, :g], Dict(i => 2 for i in [:e, :a, :g])),
EinExpr([:f, :b], Dict(i => 2 for i in [:f, :b])),
EinExpr([:i, :h, :d], Dict(i => 2 for i in [:i, :h, :d])),
EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
]

path = einexpr(HyPar, EinExpr(Symbol[], tensors))

@test path isa EinExpr

@test mapreduce(flops, +, Branches(path)) == 108
end

@testset begin
tensors = [
EinExpr([:F, :P, :V], Dict(:P => 5, :F => 8, :V => 5)),
EinExpr([:T, :Y, :V, :X, :N, :B], Dict(:T => 5, :N => 2, :B => 5, :V => 5, :Y => 7, :X => 8)),
EinExpr([:L, :K, :S], Dict(:K => 8, :L => 7, :S => 5)),
EinExpr([:M, :J, :Q, :O], Dict(:M => 5, :J => 7, :Q => 7, :O => 6)),
EinExpr([:c], Dict(:c => 2)),
EinExpr([:I, :U, :E], Dict(:U => 8, :I => 5, :E => 6)),
EinExpr([:N, :C], Dict(:N => 2, :C => 4)),
EinExpr([:a, :K], Dict(:a => 3, :K => 8)),
EinExpr([:d, :E, :M], Dict(:M => 5, :d => 6, :E => 6)),
EinExpr([:B, :b, :D, :H, :L], Dict(:b => 5, :H => 7, :D => 8, :B => 5, :L => 7)),
EinExpr([:c, :P, :X, :Q], Dict(:P => 5, :Q => 7, :c => 2, :X => 8)),
EinExpr([:G], Dict(:G => 6)),
EinExpr([:Z, :W], Dict(:Z => 9, :W => 6)),
EinExpr([:Y, :H, :S], Dict(:H => 7, :S => 5, :Y => 7)),
EinExpr([:O, :F, :b, :I], Dict(:b => 5, :I => 5, :F => 8, :O => 6)),
EinExpr([:A, :J, :T, :G], Dict(:T => 5, :A => 6, :J => 7, :G => 6)),
EinExpr([:Z, :D, :R], Dict(:Z => 9, :R => 8, :D => 8)),
EinExpr([:R, :U], Dict(:U => 8, :R => 8)),
EinExpr([:A, :W], Dict(:A => 6, :W => 6)),
EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)),
]

path = einexpr(HyPar, EinExpr(Symbol[], tensors))

@test path isa EinExpr

@test mapreduce(flops, +, Branches(path)) == 31653164
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using EinExprs
include("Naive_test.jl")
include("Exhaustive_test.jl")
include("Greedy_test.jl")
include("KaHyPar_test.jl")
end
include("Slicing_test.jl")
end
Expand Down

0 comments on commit d81db6f

Please sign in to comment.