Skip to content

Commit

Permalink
Speedup Greedy optimizer if no hyperindex is present
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Feb 9, 2024
1 parent b5df61a commit 7e953f4
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/Optimizers/Greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ function einexpr(config::Greedy, path::EinExpr{L}, sizedict::Dict{L}) where {L}
path = sumtraces(path)
metric = config.metric(sizedict)

hashyperinds = !isempty(hyperinds(path))

# generate initial candidate contractions
queue = MutableBinaryHeap{Tuple{Float64,EinExpr{L}}}(
Base.By(first, Base.Reverse),
map(
Iterators.filter(((a, b),) -> config.outer || !isdisjoint(a.head, b.head), combinations(path.args, 2)),
) do (a, b)
# TODO don't consider outer products
candidate = sum([a, b], skip = path.head hyperinds(path))
candidate = sum([a, b], skip = hashyperinds ? path.head hyperinds(path) : path.head)
weight = metric(candidate)
(weight, candidate)
end,
Expand All @@ -58,7 +60,7 @@ function einexpr(config::Greedy, path::EinExpr{L}, sizedict::Dict{L}) where {L}
# update candidate queue
for other in Iterators.filter(other -> config.outer || !isdisjoint(winner.head, other.head), path.args)
# TODO don't consider outer products
candidate = sum([winner, other], skip = path.head hyperinds(path))
candidate = sum([winner, other], skip = hashyperinds ? path.head hyperinds(path) : path.head)
weight = metric(candidate)
push!(queue, (weight, candidate))
end
Expand Down

0 comments on commit 7e953f4

Please sign in to comment.