From 6f0351ce02f960607df0067660305d86c2ea0b58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 27 Oct 2023 11:10:56 +0200 Subject: [PATCH] Update tests to recent changes --- test/Counters_test.jl | 26 +++++++++++++------------- test/EinExpr_test.jl | 27 ++++++++++++--------------- test/Exhaustive_test.jl | 4 ++-- test/Slicing_test.jl | 20 ++++++++++---------- 4 files changed, 37 insertions(+), 40 deletions(-) diff --git a/test/Counters_test.jl b/test/Counters_test.jl index b9327c3..6edc54e 100644 --- a/test/Counters_test.jl +++ b/test/Counters_test.jl @@ -2,8 +2,8 @@ using EinExprs: removedrank @testset "identity" begin - tensor = EinExpr((:i, :j), Dict(:i => 2, :j => 3)) - expr = EinExpr((:i, :j), [tensor]) + tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + expr = EinExpr([:i, :j], [tensor]) @test flops(expr) == 0 @test removedsize(expr) == 0 @@ -11,7 +11,7 @@ end @testset "transpose" begin - tensor = EinExpr((:i, :j), Dict(:i => 2, :j => 3)) + tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) expr = EinExpr([:j, :i], [tensor]) @test flops(expr) == 0 @@ -20,8 +20,8 @@ end @testset "axis sum" begin - tensor = EinExpr((:i, :j), Dict(:i => 2, :j => 3)) - expr = EinExpr((:i,), [tensor]) + tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) + expr = EinExpr([:i], [tensor]) @test flops(expr) == 6 @test removedsize(expr) == 4 @@ -29,8 +29,8 @@ end @testset "diagonal" begin - tensor = EinExpr((:i, :i), Dict(:i => 2)) - expr = EinExpr((:i,), [tensor]) + tensor = EinExpr([:i, :i], Dict(:i => 2)) + expr = EinExpr([:i], [tensor]) @test flops(expr) == 0 @test removedsize(expr) == 2 @@ -38,7 +38,7 @@ end @testset "trace" begin - tensor = EinExpr((:i, :i), Dict(:i => 2)) + tensor = EinExpr([:i, :i], Dict(:i => 2)) expr = EinExpr(Symbol[], [tensor]) @test flops(expr) == 2 @@ -47,8 +47,8 @@ end @testset "outer product" begin - tensors = [EinExpr((:i, :j), Dict(:i => 2, :j => 3)), EinExpr((:k, :l), Dict(:k => 4, :l => 5))] - expr = EinExpr((:i, :j, :k, :l), tensors) + tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:k, :l], Dict(:k => 4, :l => 5))] + expr = EinExpr([:i, :j, :k, :l], tensors) @test flops(expr) == prod(2:5) @test removedsize(expr) == -94 @@ -56,7 +56,7 @@ end @testset "inner product" begin - tensors = [EinExpr((:i,), Dict(:i => 2)), EinExpr((:i,), Dict(:i => 2))] + tensors = [EinExpr([:i], Dict(:i => 2)), EinExpr([:i], Dict(:i => 2))] expr = EinExpr(Symbol[], tensors) @test flops(expr) == 2 @@ -65,8 +65,8 @@ end @testset "matrix multiplication" begin - tensors = [EinExpr((:i, :k), Dict(:i => 2, :k => 3)), EinExpr((:k, :j), Dict(:k => 3, :j => 4))] - expr = EinExpr((:i, :j), tensors) + tensors = [EinExpr([:i, :k], Dict(:i => 2, :k => 3)), EinExpr([:k, :j], Dict(:k => 3, :j => 4))] + expr = EinExpr([:i, :j], tensors) @test flops(expr) == 2 * 3 * 4 @test removedsize(expr) == 10 diff --git a/test/EinExpr_test.jl b/test/EinExpr_test.jl index 65357d4..4d781fe 100644 --- a/test/EinExpr_test.jl +++ b/test/EinExpr_test.jl @@ -3,7 +3,7 @@ @testset "identity" begin tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) - expr = EinExpr((:i, :j), [tensor]) + expr = EinExpr([:i, :j], [tensor]) @test expr.head == head(tensor) @test expr.args == [tensor] @@ -28,7 +28,7 @@ @testset "transpose" begin tensor = EinExpr([:i, :j], Dict(:i => 2, :j => 3)) - expr = EinExpr((:j, :i), [tensor]) + expr = EinExpr([:j, :i], [tensor]) @test expr.head == reverse(inds(tensor)) @test expr.args == [tensor] @@ -59,7 +59,7 @@ @test expr.args == [tensor] @test all(splat(==), zip(head(expr), (:i,))) - @test all(splat(==), zip(inds(expr), (:i, :j))) + @test all(splat(==), zip(inds(expr), [:i, :j])) @test size(expr, :i) == 2 @test size(expr, :j) == 3 @@ -121,8 +121,8 @@ end @testset "outer product" begin - tensors = [EinExpr((:i, :j), Dict(:i => 2, :j => 3)), EinExpr((:k, :l), Dict(:k => 4, :l => 5))] - expr = EinExpr((:i, :j, :k, :l), tensors) + tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:k, :l], Dict(:k => 4, :l => 5))] + expr = EinExpr([:i, :j, :k, :l], tensors) @test all(splat(==), zip(expr.head, (:i, :j, :k, :l))) @test expr.args == tensors @@ -151,7 +151,7 @@ @testset "inner product" begin @testset "Vector" begin - tensors = [EinExpr((:i,), Dict(:i => 2)), EinExpr((:i,), Dict(:i => 2))] + tensors = [EinExpr([:i], Dict(:i => 2)), EinExpr([:i], Dict(:i => 2))] expr = EinExpr(Symbol[], tensors) @test isempty(expr.head) @@ -173,14 +173,14 @@ @test isempty(neighbours(expr, :i)) end @testset "Matrix" begin - tensors = [EinExpr((:i, :j), Dict(:i => 2, :j => 3)), EinExpr((:i, :j), Dict(:i => 2, :j => 3))] + tensors = [EinExpr([:i, :j], Dict(:i => 2, :j => 3)), EinExpr([:i, :j], Dict(:i => 2, :j => 3))] expr = EinExpr(Symbol[], tensors) @test isempty(expr.head) @test expr.args == tensors @test isempty(head(expr)) - @test all(splat(==), zip(inds(expr), (:i, :j))) + @test all(splat(==), zip(inds(expr), [:i, :j])) @test ndims(expr) == 0 @test size(expr, :i) == 2 @@ -199,13 +199,13 @@ end @testset "matrix multiplication" begin - tensors = [EinExpr((:i, :k), Dict(:i => 2, :k => 3)), EinExpr((:k, :j), Dict(:k => 3, :j => 4))] - expr = EinExpr((:i, :j), tensors) + tensors = [EinExpr([:i, :k], Dict(:i => 2, :k => 3)), EinExpr([:k, :j], Dict(:k => 3, :j => 4))] + expr = EinExpr([:i, :j], tensors) - @test all(splat(==), zip(expr.head, (:i, :j))) + @test all(splat(==), zip(expr.head, [:i, :j])) @test expr.args == tensors - @test all(splat(==), zip(head(expr), (:i, :j))) + @test all(splat(==), zip(head(expr), [:i, :j])) @test all(splat(==), zip(inds(expr), (:i, :k, :j))) @test ndims(expr) == 2 @@ -236,9 +236,6 @@ expr = sum(tensors, skip = [:β]) - @test all(splat(==), zip(expr.head, (:i, :j, :k, :l, :m, :β))) - @test expr.args == tensors - @test issetequal(head(expr), (:i, :j, :k, :l, :m, :β)) @test issetequal(inds(expr), (:i, :j, :k, :l, :m, :β)) @test ndims(expr) == 6 diff --git a/test/Exhaustive_test.jl b/test/Exhaustive_test.jl index 0b9ea66..a27316c 100644 --- a/test/Exhaustive_test.jl +++ b/test/Exhaustive_test.jl @@ -13,9 +13,9 @@ @test path isa EinExpr - @test mapreduce(flops, +, Branches(path)) == 92 + @test mapreduce(flops, +, Branches(path)) == 90 - @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:h, :i], [:b, :d]])) + @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]])) @testset "hyperedges" begin a = EinExpr([:i, :β, :j], Dict(i => 2 for i in [:i, :β, :j])) diff --git a/test/Slicing_test.jl b/test/Slicing_test.jl index 83e4280..720038d 100644 --- a/test/Slicing_test.jl +++ b/test/Slicing_test.jl @@ -44,43 +44,43 @@ [:m, :f, :g], [ EinExpr( - (:m, :f, :q), + [:m, :f, :q], Dict(i => sizes[i] for i in [:m, :f, :q]), ), EinExpr( - (:g, :q), + [:g, :q], Dict(i => sizes[i] for i in [:g, :q]), ), ], ), EinExpr( - (:o, :i, :m, :c), + [:o, :i, :m, :c], Dict(i => sizes[i] for i in [:o, :i, :m, :c]), ), ], ), - EinExpr((:f, :l, :i), Dict(i => sizes[i] for i in [:f, :l, :i])), + EinExpr([:f, :l, :i], Dict(i => sizes[i] for i in [:f, :l, :i])), ], ), - EinExpr((:g, :n, :l, :a), Dict(i => sizes[i] for i in [:g, :n, :l, :a])), + EinExpr([:g, :n, :l, :a], Dict(i => sizes[i] for i in [:g, :n, :l, :a])), ], ), EinExpr( [:e, :d, :o], [ - EinExpr((:b, :e), Dict(i => sizes[i] for i in [:b, :e])), - EinExpr((:d, :b, :o), Dict(i => sizes[i] for i in [:d, :b, :o])), + EinExpr([:b, :e], Dict(i => sizes[i] for i in [:b, :e])), + EinExpr([:d, :b, :o], Dict(i => sizes[i] for i in [:d, :b, :o])), ], ), ], ), - EinExpr((:c, :e, :h), Dict(i => sizes[i] for i in [:c, :e, :h])), + EinExpr([:c, :e, :h], Dict(i => sizes[i] for i in [:c, :e, :h])), ], ), - EinExpr((:k, :d, :h, :a, :n, :j), Dict(i => sizes[i] for i in [:k, :d, :h, :a, :n, :j])), + EinExpr([:k, :d, :h, :a, :n, :j], Dict(i => sizes[i] for i in [:k, :d, :h, :a, :n, :j])), ], ), - EinExpr((:p, :k), Dict(i => sizes[i] for i in [:p, :k])), + EinExpr([:p, :k], Dict(i => sizes[i] for i in [:p, :k])), ], )