diff --git a/Project.toml b/Project.toml index 72ad02a..657f1c8 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.6.2" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" @@ -29,11 +30,13 @@ EinExprsMakieExt = "Makie" [compat] AbstractTrees = "0.4" Combinatorics = "1.0" +Compat = "4" DataStructures = "0.18" FiniteDifferences = "0.12" GraphMakie = "0.5" Graphs = "1.6" KaHyPar = "0.3.1" Makie = "0.19" +PackageExtensionCompat = "1" Suppressor = "0.2" julia = "1.6" diff --git a/src/EinExpr.jl b/src/EinExpr.jl index 5fc78b3..3d8eaf2 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -1,6 +1,7 @@ using Base: AbstractVecOrTuple using DataStructures: DefaultDict using AbstractTrees +using Compat Base.@kwdef struct EinExpr{Label} head::Vector{Label} @@ -90,7 +91,7 @@ Base.ndims(path::EinExpr) = length(head(path)) Return the size of the resulting tensor from contracting `path`. If `index` is specified, return the size of such index. """ -Base.size(path::EinExpr, sizedict) = (sizedict[i] for i in head(path)) |> splat(tuple) +Base.size(path::EinExpr, sizedict) = (sizedict[i] for i in head(path)) |> @compat(splat(tuple)) Base.length(path::EinExpr, sizedict) = (prod ∘ size)(path, sizedict) """ diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 26d0a3b..ac24d77 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -1,6 +1,7 @@ using Base: @kwdef using Combinatorics using LinearAlgebra: Symmetric +using Compat @doc raw""" Exhaustive(; outer = false) @@ -46,10 +47,12 @@ function einexpr(config::Exhaustive, path::SizedEinExpr{L}; cost = BigInt(0)) wh return exhaustive_breadthfirst(Val(config.metric), path, settype, config.outer) elseif config.strategy === :depth init_path = einexpr(Naive(), path) - leader = Ref((; - path = init_path, - cost = mapreduce(config.metric, +, Branches(init_path, inverse = true), init = BigInt(0))::BigInt, - )) + leader = Ref( + @compat (; + path = init_path, + cost = mapreduce(config.metric, +, Branches(init_path, inverse = true), init = BigInt(0))::BigInt, + ) + ) exhaustive_depthfirst(Val(config.metric), path, cost, config.outer, leader) return leader[].path else @@ -67,7 +70,7 @@ function exhaustive_depthfirst( hashyperinds = !isempty(hyperinds(path)), ) where {L,Metric} if nargs(path) <= 2 - leader[] = (; path = path, cost = cost) + leader[] = @compat (; path = path, cost = cost) return end diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index 4da2f53..729e01e 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -2,6 +2,7 @@ using AbstractTrees using SparseArrays using KaHyPar using Suppressor +using Compat @kwdef struct HyPar <: Optimizer parts::Int = 2 @@ -17,7 +18,7 @@ function EinExprs.einexpr(config::HyPar, path) config.stop(path) && return path inds = mapreduce(head, ∪, path.args) - indexmap = Dict(Iterators.map(splat(Pair) ∘ reverse, enumerate(inds))) + indexmap = Dict(Iterators.map(@compat(splat(Pair)) ∘ reverse, enumerate(inds))) I = Iterators.flatmap(((i, tensor),) -> fill(i, ndims(tensor)), enumerate(path.args)) |> collect J = Iterators.flatmap(tensor -> Iterators.map(Base.Fix1(getindex, indexmap), head(tensor)), path.args) |> collect diff --git a/src/SizedEinExpr.jl b/src/SizedEinExpr.jl index 80fa329..a0c35ec 100644 --- a/src/SizedEinExpr.jl +++ b/src/SizedEinExpr.jl @@ -1,4 +1,5 @@ using AbstractTrees +using Compat struct SizedEinExpr{Label} path::EinExpr{Label} @@ -58,7 +59,7 @@ Base.sum(sexpr::SizedEinExpr, inds) = sum(sexpr.path, inds) function Base.sum(sexpr::Vector{SizedEinExpr{L}}; skip = L[]) where {L} path = sum(map(x -> x.path, sexpr); skip) - size = allequal(Iterators.map(x -> x.size, sexpr)) ? first(sexpr).size : merge(map(x -> x.size, sexpr)...) + size = @compat(allequal(Iterators.map(x -> x.size, sexpr))) ? first(sexpr).size : merge(map(x -> x.size, sexpr)...) # size = merge(map(x -> x.size, sexpr)...) SizedEinExpr(path, size) end diff --git a/src/Slicing.jl b/src/Slicing.jl index f33c727..cb98f1e 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -1,4 +1,5 @@ using AbstractTrees +using Compat """ selectdim(path::EinExpr, index, i) @@ -99,7 +100,7 @@ function findslices( candidates = Set(setdiff(mapreduce(head, ∪, PostOrderDFS(path)), skip)) solution = Set{L}() - current = (; slices = 1, size = maximum(length, PostOrderDFS(path)), overhead = 1.0) + current = @compat (; slices = 1, size = maximum(length, PostOrderDFS(path)), overhead = 1.0) original_flops = mapreduce(flops, +, Branches(path; inverse = true)) sliced_path = path @@ -118,7 +119,7 @@ function findslices( !isnothing(overhead) && cur_overhead > overhead && break push!(solution, winner) - current = (; + current = @compat (; slices = current.slices * Base.size(path, winner), size = maximum(length, PostOrderDFS(sliced_path)), overhead = cur_overhead, diff --git a/test/EinExpr_test.jl b/test/EinExpr_test.jl index a600213..cb9fe71 100644 --- a/test/EinExpr_test.jl +++ b/test/EinExpr_test.jl @@ -47,11 +47,11 @@ tensor = EinExpr([:i, :j]) expr = EinExpr([:i], [tensor]) - @test all(splat(==), zip(expr.head, [:i])) + @test all(@compat(splat(==)), zip(expr.head, [:i])) @test expr.args == [tensor] - @test all(splat(==), zip(head(expr), [:i])) - @test all(splat(==), zip(inds(expr), [:i, :j])) + @test all(@compat(splat(==)), zip(head(expr), [:i])) + @test all(@compat(splat(==)), zip(inds(expr), [:i, :j])) @test isempty(hyperinds(expr)) @test suminds(expr) == [:j] @@ -68,11 +68,11 @@ tensor = EinExpr([:i, :i]) expr = EinExpr([:i], [tensor]) - @test all(splat(==), zip(expr.head, [:i])) + @test all(@compat(splat(==)), zip(expr.head, [:i])) @test expr.args == [tensor] - @test all(splat(==), zip(head(expr), [:i])) - @test all(splat(==), zip(inds(expr), head(expr))) + @test all(@compat(splat(==)), zip(head(expr), [:i])) + @test all(@compat(splat(==)), zip(inds(expr), head(expr))) @test isempty(hyperinds(expr)) @test isempty(suminds(expr)) @@ -91,7 +91,7 @@ @test expr.args == [tensor] @test isempty(head(expr)) - @test all(splat(==), zip(inds(expr), [:i])) + @test all(@compat(splat(==)), zip(inds(expr), [:i])) @test isempty(hyperinds(expr)) @test suminds(expr) == [:i] @@ -106,11 +106,11 @@ tensors = [EinExpr([:i, :j]), EinExpr([:k, :l])] expr = EinExpr([:i, :j, :k, :l], tensors) - @test all(splat(==), zip(expr.head, [:i, :j, :k, :l])) + @test all(@compat(splat(==)), zip(expr.head, [:i, :j, :k, :l])) @test expr.args == tensors - @test all(splat(==), zip(head(expr), mapreduce(collect ∘ inds, vcat, tensors))) - @test all(splat(==), zip(inds(expr), head(expr))) + @test all(@compat(splat(==)), zip(head(expr), mapreduce(collect ∘ inds, vcat, tensors))) + @test all(@compat(splat(==)), zip(inds(expr), head(expr))) @test ndims(expr) == 4 @test isempty(hyperinds(expr)) @@ -135,7 +135,7 @@ @test expr.args == tensors @test isempty(head(expr)) - @test all(splat(==), zip(inds(expr), [:i])) + @test all(@compat(splat(==)), zip(inds(expr), [:i])) @test ndims(expr) == 0 @test isempty(hyperinds(expr)) @@ -154,7 +154,7 @@ @test expr.args == tensors @test isempty(head(expr)) - @test all(splat(==), zip(inds(expr), [:i, :j])) + @test all(@compat(splat(==)), zip(inds(expr), [:i, :j])) @test ndims(expr) == 0 @test isempty(hyperinds(expr)) @@ -172,11 +172,11 @@ tensors = [EinExpr([:i, :k]), EinExpr([:k, :j])] expr = EinExpr([:i, :j], tensors) - @test all(splat(==), zip(expr.head, [:i, :j])) + @test all(@compat(splat(==)), zip(expr.head, [:i, :j])) @test expr.args == tensors - @test all(splat(==), zip(head(expr), [:i, :j])) - @test all(splat(==), zip(inds(expr), [:i, :k, :j])) + @test all(@compat(splat(==)), zip(head(expr), [:i, :j])) + @test all(@compat(splat(==)), zip(inds(expr), [:i, :k, :j])) @test ndims(expr) == 2 @test isempty(hyperinds(expr)) @@ -217,7 +217,7 @@ tensors = [EinExpr([:i, :β, :j]), EinExpr([:k, :β]), EinExpr([:β, :l, :m])] expr = sum(tensors) - @test all(splat(==), zip(expr.head, [:i, :j, :k, :l, :m])) + @test all(@compat(splat(==)), zip(expr.head, [:i, :j, :k, :l, :m])) @test expr.args == tensors @test issetequal(head(expr), [:i, :j, :k, :l, :m]) diff --git a/test/Exhaustive_test.jl b/test/Exhaustive_test.jl index 57bef2a..020266f 100644 --- a/test/Exhaustive_test.jl +++ b/test/Exhaustive_test.jl @@ -19,7 +19,10 @@ @test mapreduce(flops, +, Branches(path)) == 92 - @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:i, :h], [:d, :b]])) + @test all( + @compat(splat(issetequal)), + zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:i, :h], [:d, :b]]), + ) @testset "hyperedges" begin sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, :β]) @@ -57,6 +60,9 @@ @test mapreduce(flops, +, Branches(path)) == 90 - @test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]])) + @test all( + @compat(splat(issetequal)), + zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]]), + ) end end diff --git a/test/Greedy_test.jl b/test/Greedy_test.jl index 7416719..c823a2f 100644 --- a/test/Greedy_test.jl +++ b/test/Greedy_test.jl @@ -17,7 +17,10 @@ @test mapreduce(flops, +, Branches(path)) == 100 - @test all(splat(issetequal), zip(contractorder(path), [[:i, :h], [:j], [:a, :e], [:g, :c], [:f], [:b, :d]])) + @test all( + @compat(splat(issetequal)), + zip(contractorder(path), [[:i, :h], [:j], [:a, :e], [:g, :c], [:f], [:b, :d]]), + ) @testset "example: let unchanged" begin sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m]) diff --git a/test/Naive_test.jl b/test/Naive_test.jl index 2b8f15c..97b1452 100644 --- a/test/Naive_test.jl +++ b/test/Naive_test.jl @@ -20,7 +20,7 @@ # FIXME non-determinist behaviour on order @test all( - splat(issetequal), + @compat(splat(issetequal)), zip(map(suminds, Branches(path)), [Symbol[], [:j], [:a, :e], [:f, :b], [:i, :h], [:d, :g, :c]]), ) diff --git a/test/runtests.jl b/test/runtests.jl index 9248f5f..cbfd110 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Test using EinExprs +using Compat @testset "Unit tests" verbose = true begin include("EinExpr_test.jl")