Skip to content

Commit

Permalink
Accelerate Exhaustive breadth-first search with One-Hot encoded sets
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 18, 2024
1 parent cf6f2af commit 949fcc7
Showing 1 changed file with 66 additions and 12 deletions.
78 changes: 66 additions & 12 deletions src/Optimizers/Exhaustive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,19 @@ end

function einexpr(config::Exhaustive, path::SizedEinExpr{L}; cost = BigInt(0)) where {L}
if config.strategy === :breadth
return exhaustive_breadthfirst(Val(config.metric), path; outer = config.outer)
ninds = length(inds(path))
settype = if ninds <= 8
UInt8
elseif ninds <= 16
UInt16
elseif ninds <= 32
UInt32
elseif ninds <= 64
UInt64
else
BitSet
end
return exhaustive_breadthfirst(Val(config.metric), path, settype; outer = config.outer)
elseif config.strategy === :depth
init_path = einexpr(Naive(), path)
leader = Ref((;
Expand Down Expand Up @@ -69,12 +81,51 @@ function exhaustive_depthfirst(
end
end

onehot_init(T::Type{<:Integer}) = zero(T)
onehot_init(::Type{BitSet}) = BitSet()

function onehot_in(i, set::T) where {T<:Integer}
i > sizeof(T) * 8 && return false
mask = one(T) << (i - 1)
return mask & set != zero(T)
end
onehot_in(i, set::BitSet) = in(i, set)

function onehot_push!(set::T, i) where {T<:Integer}
i > sizeof(T) * 8 && error("Index out of bounds")
mask = one(T) << (i - 1)
set |= mask
return set
end
onehot_push!(set::BitSet, i) = push!(set, i)

function onehot_pop!(set::T, i) where {T<:Integer}
i > sizeof(T) * 8 && error("Index out of bounds")
mask = one(T) << (i - 1)
set &= ~mask
return set
end
onehot_pop!(set::BitSet, i) = pop!(set, i)

onehot_isdisjoint(a::T, b::T) where {T<:Integer} = a & b == zero(T)
onehot_isdisjoint(a::BitSet, b::BitSet) = isdisjoint(a, b)

onehot_union(a::T, b::T) where {T<:Integer} = a | b
onehot_union(a::BitSet, b::BitSet) = union(a, b)

onehot_only(set::T) where {T<:Integer} = count_ones(set) == 1 ? trailing_zeros(set) + 1 : error("Expected 1 element")
onehot_only(set::BitSet) = only(set)

onehot_isempty(set::T) where {T<:Integer} = set == zero(T)
onehot_isempty(set::BitSet) = isempty(set)

function exhaustive_breadthfirst(
@specialize(metric::Val{Metric}),
expr::SizedEinExpr{L};
expr::SizedEinExpr{L},
::Type{SetType} = BitSet;
outer::Bool = false,
hashyperinds = !isempty(hyperinds(expr)),
) where {L,Metric}
) where {L,Metric,SetType}
hashyperinds && error("Hyperindices not supported yet")

cost_fac = maximum(values(expr.size))
Expand All @@ -87,22 +138,24 @@ function exhaustive_breadthfirst(
n = nargs(expr)

# S[c]: set of all objects made up by contracting together `c` unique tensors from S[1]
# NOTE BitSet contains identifiers (i.e. an `Integer`) of input tensors, so each set is a candidate "contracted" subgraph
# NOTE Set contains identifiers (i.e. an `Integer`) of input tensors, so each set is a candidate "contracted" subgraph
# NOTE it doesn't contain all combinations (as it's combinatorially big); it's filtered by `cost_max`
S = map(_ -> BitSet[], 1:n)
S = map(_ -> SetType[], 1:n)

# initialize S₁
S[1] = [sizehint!(BitSet([i]), n) for i in 1:n]
S[1] = map(1:n) do i
onehot_push!(onehot_init(SetType), i)
end

# caches the best-known cost for constructing each object in S[c]
# NOTE no cost because no contraction on S₁ (only input tensors)
costs = Dict{BitSet,BigInt}(s => zero(BigInt) for s in S[1])
costs = Dict{SetType,BigInt}(s => zero(BigInt) for s in S[1])

# contains the indices of the intermediate tensors in S
indices = Dict{BitSet,Vector{L}}(s => head(expr.args[only(s)]) for s in S[1])
indices = Dict{SetType,Vector{L}}(s => head(expr.args[onehot_only(s)]) for s in S[1])

# contains the best-known contraction tree for constructing each object in S[c]
trees = Dict{BitSet,Tuple{BitSet,BitSet}}(s => (BitSet(), BitSet()) for s in S[1])
trees = Dict{SetType,Tuple{SetType,SetType}}(s => (onehot_init(SetType), onehot_init(SetType)) for s in S[1])

cost_cur = cost_max
cost_prev = zero(cost_max)
Expand All @@ -116,13 +169,14 @@ function exhaustive_breadthfirst(
k == c - k && ia >= ib && continue

# if not disjoint, then ta and tb contain at least one common tensor
isdisjoint(ta, tb) || continue
onehot_isdisjoint(ta, tb) || continue

# outer products do not generally improve contraction path
!outer && isdisjoint(indices[ta], indices[tb]) && continue

# new candidate contraction
tc = ta tb # aka Q in the paper
tc = onehot_union(ta, tb) # aka Q in the paper
get(costs, tc, cost_cur) > cost_prev || continue

# compute cost of getting `tc` by contracting `ta` and `tb
shallow_expr_a = EinExpr(indices[ta])
Expand Down Expand Up @@ -152,7 +206,7 @@ function exhaustive_breadthfirst(
function recurse_construct(tc)
ta, tb = trees[tc]

if isempty(ta) && isempty(tb)
if onehot_isempty(ta) && onehot_isempty(tb)
return EinExpr(indices[tc]::Vector{L})
end

Expand Down

0 comments on commit 949fcc7

Please sign in to comment.