Skip to content

Commit

Permalink
Update tests to recent changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Oct 27, 2023
1 parent 8b58f75 commit 6f0351c
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 40 deletions.
26 changes: 13 additions & 13 deletions test/Counters_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
using EinExprs: removedrank

@testset "identity" begin
tensor = EinExpr((:i, :j), Dict(:i => 2, :j => 3))
expr = EinExpr((:i, :j), [tensor])
tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3))
expr = EinExpr([:i, :j], [tensor])

@test flops(expr) == 0
@test removedsize(expr) == 0
@test removedrank(expr) == 0
end

@testset "transpose" begin
tensor = EinExpr((:i, :j), Dict(:i => 2, :j => 3))
tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3))
expr = EinExpr([:j, :i], [tensor])

@test flops(expr) == 0
Expand All @@ -20,25 +20,25 @@
end

@testset "axis sum" begin
tensor = EinExpr((:i, :j), Dict(:i => 2, :j => 3))
expr = EinExpr((:i,), [tensor])
tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3))
expr = EinExpr([:i], [tensor])

@test flops(expr) == 6
@test removedsize(expr) == 4
@test removedrank(expr) == 1
end

@testset "diagonal" begin
tensor = EinExpr((:i, :i), Dict(:i => 2))
expr = EinExpr((:i,), [tensor])
tensor = EinExpr([:i, :i], Dict(:i => 2))
expr = EinExpr([:i], [tensor])

@test flops(expr) == 0
@test removedsize(expr) == 2
@test removedrank(expr) == 1
end

@testset "trace" begin
tensor = EinExpr((:i, :i), Dict(:i => 2))
tensor = EinExpr([:i, :i], Dict(:i => 2))
expr = EinExpr(Symbol[], [tensor])

@test flops(expr) == 2
Expand All @@ -47,16 +47,16 @@
end

@testset "outer product" begin
tensors = [EinExpr((:i, :j), Dict(:i => 2, :j => 3)), EinExpr((:k, :l), Dict(:k => 4, :l => 5))]
expr = EinExpr((:i, :j, :k, :l), tensors)
tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:k, :l], Dict(:k => 4, :l => 5))]
expr = EinExpr([:i, :j, :k, :l], tensors)

@test flops(expr) == prod(2:5)
@test removedsize(expr) == -94
@test removedrank(expr) == -2
end

@testset "inner product" begin
tensors = [EinExpr((:i,), Dict(:i => 2)), EinExpr((:i,), Dict(:i => 2))]
tensors = [EinExpr([:i], Dict(:i => 2)), EinExpr([:i], Dict(:i => 2))]
expr = EinExpr(Symbol[], tensors)

@test flops(expr) == 2
Expand All @@ -65,8 +65,8 @@
end

@testset "matrix multiplication" begin
tensors = [EinExpr((:i, :k), Dict(:i => 2, :k => 3)), EinExpr((:k, :j), Dict(:k => 3, :j => 4))]
expr = EinExpr((:i, :j), tensors)
tensors = [EinExpr([:i, :k], Dict(:i => 2, :k => 3)), EinExpr([:k, :j], Dict(:k => 3, :j => 4))]
expr = EinExpr([:i, :j], tensors)

@test flops(expr) == 2 * 3 * 4
@test removedsize(expr) == 10
Expand Down
27 changes: 12 additions & 15 deletions test/EinExpr_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

@testset "identity" begin
tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3))
expr = EinExpr((:i, :j), [tensor])
expr = EinExpr([:i, :j], [tensor])

@test expr.head == head(tensor)
@test expr.args == [tensor]
Expand All @@ -28,7 +28,7 @@

@testset "transpose" begin
tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3))
expr = EinExpr((:j, :i), [tensor])
expr = EinExpr([:j, :i], [tensor])

@test expr.head == reverse(inds(tensor))
@test expr.args == [tensor]
Expand Down Expand Up @@ -59,7 +59,7 @@
@test expr.args == [tensor]

@test all(splat(==), zip(head(expr), (:i,)))
@test all(splat(==), zip(inds(expr), (:i, :j)))
@test all(splat(==), zip(inds(expr), [:i, :j]))

@test size(expr, :i) == 2
@test size(expr, :j) == 3
Expand Down Expand Up @@ -121,8 +121,8 @@
end

@testset "outer product" begin
tensors = [EinExpr((:i, :j), Dict(:i => 2, :j => 3)), EinExpr((:k, :l), Dict(:k => 4, :l => 5))]
expr = EinExpr((:i, :j, :k, :l), tensors)
tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:k, :l], Dict(:k => 4, :l => 5))]
expr = EinExpr([:i, :j, :k, :l], tensors)

@test all(splat(==), zip(expr.head, (:i, :j, :k, :l)))
@test expr.args == tensors
Expand Down Expand Up @@ -151,7 +151,7 @@

@testset "inner product" begin
@testset "Vector" begin
tensors = [EinExpr((:i,), Dict(:i => 2)), EinExpr((:i,), Dict(:i => 2))]
tensors = [EinExpr([:i], Dict(:i => 2)), EinExpr([:i], Dict(:i => 2))]
expr = EinExpr(Symbol[], tensors)

@test isempty(expr.head)
Expand All @@ -173,14 +173,14 @@
@test isempty(neighbours(expr, :i))
end
@testset "Matrix" begin
tensors = [EinExpr((:i, :j), Dict(:i => 2, :j => 3)), EinExpr((:i, :j), Dict(:i => 2, :j => 3))]
tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:i, :j], Dict(:i => 2, :j => 3))]
expr = EinExpr(Symbol[], tensors)

@test isempty(expr.head)
@test expr.args == tensors

@test isempty(head(expr))
@test all(splat(==), zip(inds(expr), (:i, :j)))
@test all(splat(==), zip(inds(expr), [:i, :j]))
@test ndims(expr) == 0

@test size(expr, :i) == 2
Expand All @@ -199,13 +199,13 @@
end

@testset "matrix multiplication" begin
tensors = [EinExpr((:i, :k), Dict(:i => 2, :k => 3)), EinExpr((:k, :j), Dict(:k => 3, :j => 4))]
expr = EinExpr((:i, :j), tensors)
tensors = [EinExpr([:i, :k], Dict(:i => 2, :k => 3)), EinExpr([:k, :j], Dict(:k => 3, :j => 4))]
expr = EinExpr([:i, :j], tensors)

@test all(splat(==), zip(expr.head, (:i, :j)))
@test all(splat(==), zip(expr.head, [:i, :j]))
@test expr.args == tensors

@test all(splat(==), zip(head(expr), (:i, :j)))
@test all(splat(==), zip(head(expr), [:i, :j]))
@test all(splat(==), zip(inds(expr), (:i, :k, :j)))
@test ndims(expr) == 2

Expand Down Expand Up @@ -236,9 +236,6 @@

expr = sum(tensors, skip = [])

@test all(splat(==), zip(expr.head, (:i, :j, :k, :l, :m, )))
@test expr.args == tensors

@test issetequal(head(expr), (:i, :j, :k, :l, :m, ))
@test issetequal(inds(expr), (:i, :j, :k, :l, :m, ))
@test ndims(expr) == 6
Expand Down
4 changes: 2 additions & 2 deletions test/Exhaustive_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

@test path isa EinExpr

@test mapreduce(flops, +, Branches(path)) == 92
@test mapreduce(flops, +, Branches(path)) == 90

@test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:h, :i], [:b, :d]]))
@test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]]))

@testset "hyperedges" begin
a = EinExpr([:i, , :j], Dict(i => 2 for i in [:i, , :j]))
Expand Down
20 changes: 10 additions & 10 deletions test/Slicing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,43 +44,43 @@
[:m, :f, :g],
[
EinExpr(
(:m, :f, :q),
[:m, :f, :q],
Dict(i => sizes[i] for i in [:m, :f, :q]),
),
EinExpr(
(:g, :q),
[:g, :q],
Dict(i => sizes[i] for i in [:g, :q]),
),
],
),
EinExpr(
(:o, :i, :m, :c),
[:o, :i, :m, :c],
Dict(i => sizes[i] for i in [:o, :i, :m, :c]),
),
],
),
EinExpr((:f, :l, :i), Dict(i => sizes[i] for i in [:f, :l, :i])),
EinExpr([:f, :l, :i], Dict(i => sizes[i] for i in [:f, :l, :i])),
],
),
EinExpr((:g, :n, :l, :a), Dict(i => sizes[i] for i in [:g, :n, :l, :a])),
EinExpr([:g, :n, :l, :a], Dict(i => sizes[i] for i in [:g, :n, :l, :a])),
],
),
EinExpr(
[:e, :d, :o],
[
EinExpr((:b, :e), Dict(i => sizes[i] for i in [:b, :e])),
EinExpr((:d, :b, :o), Dict(i => sizes[i] for i in [:d, :b, :o])),
EinExpr([:b, :e], Dict(i => sizes[i] for i in [:b, :e])),
EinExpr([:d, :b, :o], Dict(i => sizes[i] for i in [:d, :b, :o])),
],
),
],
),
EinExpr((:c, :e, :h), Dict(i => sizes[i] for i in [:c, :e, :h])),
EinExpr([:c, :e, :h], Dict(i => sizes[i] for i in [:c, :e, :h])),
],
),
EinExpr((:k, :d, :h, :a, :n, :j), Dict(i => sizes[i] for i in [:k, :d, :h, :a, :n, :j])),
EinExpr([:k, :d, :h, :a, :n, :j], Dict(i => sizes[i] for i in [:k, :d, :h, :a, :n, :j])),
],
),
EinExpr((:p, :k), Dict(i => sizes[i] for i in [:p, :k])),
EinExpr([:p, :k], Dict(i => sizes[i] for i in [:p, :k])),
],
)

Expand Down

0 comments on commit 6f0351c

Please sign in to comment.