diff --git a/src/Optimizers/Greedy.jl b/src/Optimizers/Greedy.jl index 38e913f..819bd88 100644 --- a/src/Optimizers/Greedy.jl +++ b/src/Optimizers/Greedy.jl @@ -32,6 +32,8 @@ function einexpr(config::Greedy, path::EinExpr{L}, sizedict::Dict{L}) where {L} path = sumtraces(path) metric = config.metric(sizedict) + hashyperinds = !isempty(hyperinds(path)) + # generate initial candidate contractions queue = MutableBinaryHeap{Tuple{Float64,EinExpr{L}}}( Base.By(first, Base.Reverse), @@ -39,7 +41,7 @@ function einexpr(config::Greedy, path::EinExpr{L}, sizedict::Dict{L}) where {L} Iterators.filter(((a, b),) -> config.outer || !isdisjoint(a.head, b.head), combinations(path.args, 2)), ) do (a, b) # TODO don't consider outer products - candidate = sum([a, b], skip = path.head ∪ hyperinds(path)) + candidate = sum([a, b], skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) weight = metric(candidate) (weight, candidate) end, @@ -58,7 +60,7 @@ function einexpr(config::Greedy, path::EinExpr{L}, sizedict::Dict{L}) where {L} # update candidate queue for other in Iterators.filter(other -> config.outer || !isdisjoint(winner.head, other.head), path.args) # TODO don't consider outer products - candidate = sum([winner, other], skip = path.head ∪ hyperinds(path)) + candidate = sum([winner, other], skip = hashyperinds ? path.head ∪ hyperinds(path) : path.head) weight = metric(candidate) push!(queue, (weight, candidate)) end