Skip to content

Commit

Permalink
Enhance edge and vertex weights for HyPar optimizer (#43)
Browse files Browse the repository at this point in the history
Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
jofrevalles and mofeing authored Nov 3, 2023
1 parent b78f86d commit 98cdef0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
6 changes: 4 additions & 2 deletions src/Optimizers/KaHyPar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ using KaHyPar
imbalance::Float32 = 0.03
stop::Function = <=(2) length Base.Fix2(getfield, :args)
configuration::Union{Nothing,Symbol,String} = nothing
edge_scaler::Function = Base.Fix1(*, 1000) Int round log2
vertex_scaler::Function = Base.Fix1(*, 1000) Int round log2
end

function EinExprs.einexpr(config::HyPar, path)
Expand All @@ -21,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path)
incidence_matrix = sparse(I, J, V)

# NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
edge_weights = map(Base.Fix1(size, path), inds)
vertex_weights = ones(Int, length(path.args))
edge_weights = map(config.edge_scaler Base.Fix1(size, path), inds)
vertex_weights = map(config.vertex_scaler length, path.args)

hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)

Expand Down
6 changes: 3 additions & 3 deletions test/KaHyPar_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
]

path = einexpr(HyPar, EinExpr(Symbol[], tensors))
path = einexpr(HyPar(imbalance=0.42), EinExpr(Symbol[], tensors))

@test path isa EinExpr

Expand Down Expand Up @@ -41,10 +41,10 @@
EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)),
]

path = einexpr(HyPar, EinExpr(Symbol[], tensors))
path = einexpr(HyPar(imbalance=0.45), EinExpr(Symbol[], tensors))

@test path isa EinExpr

@test mapreduce(flops, +, Branches(path)) == 31653164
@test mapreduce(flops, +, Branches(path)) == 19099592
end
end

0 comments on commit 98cdef0

Please sign in to comment.