Skip to content

Commit

Permalink
Fix tests for Exhaustive strategies
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 18, 2024
1 parent 0936be6 commit 0766f76
Showing 1 changed file with 59 additions and 33 deletions.
92 changes: 59 additions & 33 deletions test/Exhaustive_test.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,62 @@
@testset "Exhaustive" begin
sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j])
tensors = [
EinExpr([:j, :b, :i, :h]),
EinExpr([:a, :c, :e, :f]),
EinExpr([:j]),
EinExpr([:e, :a, :g]),
EinExpr([:f, :b]),
EinExpr([:i, :h, :d]),
EinExpr([:d, :g, :c]),
]
expr = EinExpr(Symbol[], tensors)
sexpr = SizedEinExpr(expr, sizedict)

path = einexpr(Exhaustive, sexpr)

@test path isa SizedEinExpr

@test mapreduce(flops, +, Branches(path)) == 92

@test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:i, :h], [:d, :b]]))

@testset "hyperedges" begin
sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, ])
a = EinExpr([:i, , :j])
b = EinExpr([:k, ])
c = EinExpr([, :l, :m])

path = einexpr(EinExprs.Exhaustive(), SizedEinExpr(sum([a, b, c], skip = []), sizedict))
@test all(() head, branches(path))

path = einexpr(EinExprs.Exhaustive(), SizedEinExpr(sum([a, b, c], skip = Symbol[]), sizedict))
@test all(() head, branches(path)[1:end-1])
@test all(!() head, branches(path)[end:end])
@testset "Depth-first" begin
sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j])
tensors = [
EinExpr([:j, :b, :i, :h]),
EinExpr([:a, :c, :e, :f]),
EinExpr([:j]),
EinExpr([:e, :a, :g]),
EinExpr([:f, :b]),
EinExpr([:i, :h, :d]),
EinExpr([:d, :g, :c]),
]
expr = EinExpr(Symbol[], tensors)
sexpr = SizedEinExpr(expr, sizedict)

path = einexpr(Exhaustive(strategy = :depth), sexpr)

@test path isa SizedEinExpr

@test mapreduce(flops, +, Branches(path)) == 92

@test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:i, :h], [:d, :b]]))

@testset "hyperedges" begin
sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, ])
a = EinExpr([:i, , :j])
b = EinExpr([:k, ])
c = EinExpr([, :l, :m])

path = einexpr(EinExprs.Exhaustive(strategy = :depth), SizedEinExpr(sum([a, b, c], skip = []), sizedict))
@test all(() head, branches(path))

path =
einexpr(EinExprs.Exhaustive(strategy = :depth), SizedEinExpr(sum([a, b, c], skip = Symbol[]), sizedict))
@test all(() head, branches(path)[1:end-1])
@test all(!() head, branches(path)[end:end])
end
end

@testset "Breadth-first" begin
sizedict = Dict(i => 2 for i in [:a, :b, :c, :d, :e, :f, :g, :h, :i, :j])
tensors = [
EinExpr([:j, :b, :i, :h]),
EinExpr([:a, :c, :e, :f]),
EinExpr([:j]),
EinExpr([:e, :a, :g]),
EinExpr([:f, :b]),
EinExpr([:i, :h, :d]),
EinExpr([:d, :g, :c]),
]
expr = EinExpr(Symbol[], tensors)
sexpr = SizedEinExpr(expr, sizedict)

path = einexpr(Exhaustive(strategy = :breadth), sexpr)

@test path isa SizedEinExpr

@test mapreduce(flops, +, Branches(path)) == 90

@test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]]))
end
end

0 comments on commit 0766f76

Please sign in to comment.