Skip to content

Commit

Permalink
Use Compat.@compat to forward-support Julia features
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Feb 7, 2024
1 parent bb4b827 commit 98647fb
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 30 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.6.2"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
Expand All @@ -29,11 +30,13 @@ EinExprsMakieExt = "Makie"
[compat]
AbstractTrees = "0.4"
Combinatorics = "1.0"
Compat = "4"
DataStructures = "0.18"
FiniteDifferences = "0.12"
GraphMakie = "0.5"
Graphs = "1.6"
KaHyPar = "0.3.1"
Makie = "0.19"
PackageExtensionCompat = "1"
Suppressor = "0.2"
julia = "1.6"
3 changes: 2 additions & 1 deletion src/EinExpr.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Base: AbstractVecOrTuple
using DataStructures: DefaultDict
using AbstractTrees
using Compat

Base.@kwdef struct EinExpr{Label}
head::Vector{Label}
Expand Down Expand Up @@ -90,7 +91,7 @@ Base.ndims(path::EinExpr) = length(head(path))
Return the size of the resulting tensor from contracting `path`. If `index` is specified, return the size of such index.
"""
Base.size(path::EinExpr, sizedict) = (sizedict[i] for i in head(path)) |> splat(tuple)
Base.size(path::EinExpr, sizedict) = (sizedict[i] for i in head(path)) |> @compat(splat(tuple))
Base.length(path::EinExpr, sizedict) = (prod size)(path, sizedict)

"""
Expand Down
13 changes: 8 additions & 5 deletions src/Optimizers/Exhaustive.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Base: @kwdef
using Combinatorics
using LinearAlgebra: Symmetric
using Compat

@doc raw"""
Exhaustive(; outer = false)
Expand Down Expand Up @@ -46,10 +47,12 @@ function einexpr(config::Exhaustive, path::SizedEinExpr{L}; cost = BigInt(0)) wh
return exhaustive_breadthfirst(Val(config.metric), path, settype, config.outer)
elseif config.strategy === :depth
init_path = einexpr(Naive(), path)
leader = Ref((;
path = init_path,
cost = mapreduce(config.metric, +, Branches(init_path, inverse = true), init = BigInt(0))::BigInt,
))
leader = Ref(
@compat (;
path = init_path,
cost = mapreduce(config.metric, +, Branches(init_path, inverse = true), init = BigInt(0))::BigInt,
)
)
exhaustive_depthfirst(Val(config.metric), path, cost, config.outer, leader)
return leader[].path
else
Expand All @@ -67,7 +70,7 @@ function exhaustive_depthfirst(
hashyperinds = !isempty(hyperinds(path)),
) where {L,Metric}
if nargs(path) <= 2
leader[] = (; path = path, cost = cost)
leader[] = @compat (; path = path, cost = cost)
return
end

Expand Down
3 changes: 2 additions & 1 deletion src/Optimizers/KaHyPar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using AbstractTrees
using SparseArrays
using KaHyPar
using Suppressor
using Compat

@kwdef struct HyPar <: Optimizer
parts::Int = 2
Expand All @@ -17,7 +18,7 @@ function EinExprs.einexpr(config::HyPar, path)
config.stop(path) && return path

inds = mapreduce(head, , path.args)
indexmap = Dict(Iterators.map(splat(Pair) reverse, enumerate(inds)))
indexmap = Dict(Iterators.map(@compat(splat(Pair)) reverse, enumerate(inds)))

I = Iterators.flatmap(((i, tensor),) -> fill(i, ndims(tensor)), enumerate(path.args)) |> collect
J = Iterators.flatmap(tensor -> Iterators.map(Base.Fix1(getindex, indexmap), head(tensor)), path.args) |> collect
Expand Down
3 changes: 2 additions & 1 deletion src/SizedEinExpr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using AbstractTrees
using Compat

struct SizedEinExpr{Label}
path::EinExpr{Label}
Expand Down Expand Up @@ -58,7 +59,7 @@ Base.sum(sexpr::SizedEinExpr, inds) = sum(sexpr.path, inds)

function Base.sum(sexpr::Vector{SizedEinExpr{L}}; skip = L[]) where {L}
path = sum(map(x -> x.path, sexpr); skip)
size = allequal(Iterators.map(x -> x.size, sexpr)) ? first(sexpr).size : merge(map(x -> x.size, sexpr)...)
size = @compat(allequal(Iterators.map(x -> x.size, sexpr))) ? first(sexpr).size : merge(map(x -> x.size, sexpr)...)
# size = merge(map(x -> x.size, sexpr)...)
SizedEinExpr(path, size)
end
Expand Down
5 changes: 3 additions & 2 deletions src/Slicing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using AbstractTrees
using Compat

"""
selectdim(path::EinExpr, index, i)
Expand Down Expand Up @@ -99,7 +100,7 @@ function findslices(

candidates = Set(setdiff(mapreduce(head, , PostOrderDFS(path)), skip))
solution = Set{L}()
current = (; slices = 1, size = maximum(length, PostOrderDFS(path)), overhead = 1.0)
current = @compat (; slices = 1, size = maximum(length, PostOrderDFS(path)), overhead = 1.0)
original_flops = mapreduce(flops, +, Branches(path; inverse = true))

sliced_path = path
Expand All @@ -118,7 +119,7 @@ function findslices(
!isnothing(overhead) && cur_overhead > overhead && break
push!(solution, winner)

current = (;
current = @compat (;
slices = current.slices * Base.size(path, winner),
size = maximum(length, PostOrderDFS(sliced_path)),
overhead = cur_overhead,
Expand Down
32 changes: 16 additions & 16 deletions test/EinExpr_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@
tensor = EinExpr([:i, :j])
expr = EinExpr([:i], [tensor])

@test all(splat(==), zip(expr.head, [:i]))
@test all(@compat(splat(==)), zip(expr.head, [:i]))
@test expr.args == [tensor]

@test all(splat(==), zip(head(expr), [:i]))
@test all(splat(==), zip(inds(expr), [:i, :j]))
@test all(@compat(splat(==)), zip(head(expr), [:i]))
@test all(@compat(splat(==)), zip(inds(expr), [:i, :j]))

@test isempty(hyperinds(expr))
@test suminds(expr) == [:j]
Expand All @@ -68,11 +68,11 @@
tensor = EinExpr([:i, :i])
expr = EinExpr([:i], [tensor])

@test all(splat(==), zip(expr.head, [:i]))
@test all(@compat(splat(==)), zip(expr.head, [:i]))
@test expr.args == [tensor]

@test all(splat(==), zip(head(expr), [:i]))
@test all(splat(==), zip(inds(expr), head(expr)))
@test all(@compat(splat(==)), zip(head(expr), [:i]))
@test all(@compat(splat(==)), zip(inds(expr), head(expr)))

@test isempty(hyperinds(expr))
@test isempty(suminds(expr))
Expand All @@ -91,7 +91,7 @@
@test expr.args == [tensor]

@test isempty(head(expr))
@test all(splat(==), zip(inds(expr), [:i]))
@test all(@compat(splat(==)), zip(inds(expr), [:i]))

@test isempty(hyperinds(expr))
@test suminds(expr) == [:i]
Expand All @@ -106,11 +106,11 @@
tensors = [EinExpr([:i, :j]), EinExpr([:k, :l])]
expr = EinExpr([:i, :j, :k, :l], tensors)

@test all(splat(==), zip(expr.head, [:i, :j, :k, :l]))
@test all(@compat(splat(==)), zip(expr.head, [:i, :j, :k, :l]))
@test expr.args == tensors

@test all(splat(==), zip(head(expr), mapreduce(collect inds, vcat, tensors)))
@test all(splat(==), zip(inds(expr), head(expr)))
@test all(@compat(splat(==)), zip(head(expr), mapreduce(collect inds, vcat, tensors)))
@test all(@compat(splat(==)), zip(inds(expr), head(expr)))
@test ndims(expr) == 4

@test isempty(hyperinds(expr))
Expand All @@ -135,7 +135,7 @@
@test expr.args == tensors

@test isempty(head(expr))
@test all(splat(==), zip(inds(expr), [:i]))
@test all(@compat(splat(==)), zip(inds(expr), [:i]))
@test ndims(expr) == 0

@test isempty(hyperinds(expr))
Expand All @@ -154,7 +154,7 @@
@test expr.args == tensors

@test isempty(head(expr))
@test all(splat(==), zip(inds(expr), [:i, :j]))
@test all(@compat(splat(==)), zip(inds(expr), [:i, :j]))
@test ndims(expr) == 0

@test isempty(hyperinds(expr))
Expand All @@ -172,11 +172,11 @@
tensors = [EinExpr([:i, :k]), EinExpr([:k, :j])]
expr = EinExpr([:i, :j], tensors)

@test all(splat(==), zip(expr.head, [:i, :j]))
@test all(@compat(splat(==)), zip(expr.head, [:i, :j]))
@test expr.args == tensors

@test all(splat(==), zip(head(expr), [:i, :j]))
@test all(splat(==), zip(inds(expr), [:i, :k, :j]))
@test all(@compat(splat(==)), zip(head(expr), [:i, :j]))
@test all(@compat(splat(==)), zip(inds(expr), [:i, :k, :j]))
@test ndims(expr) == 2

@test isempty(hyperinds(expr))
Expand Down Expand Up @@ -217,7 +217,7 @@
tensors = [EinExpr([:i, , :j]), EinExpr([:k, ]), EinExpr([, :l, :m])]
expr = sum(tensors)

@test all(splat(==), zip(expr.head, [:i, :j, :k, :l, :m]))
@test all(@compat(splat(==)), zip(expr.head, [:i, :j, :k, :l, :m]))
@test expr.args == tensors

@test issetequal(head(expr), [:i, :j, :k, :l, :m])
Expand Down
10 changes: 8 additions & 2 deletions test/Exhaustive_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

@test mapreduce(flops, +, Branches(path)) == 92

@test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:i, :h], [:d, :b]]))
@test all(
@compat(splat(issetequal)),
zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:j], [:i, :h], [:d, :b]]),
)

@testset "hyperedges" begin
sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m, ])
Expand Down Expand Up @@ -57,6 +60,9 @@

@test mapreduce(flops, +, Branches(path)) == 90

@test all(splat(issetequal), zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]]))
@test all(
@compat(splat(issetequal)),
zip(contractorder(path), [[:a, :e], [:c, :g], [:f], [:d], [:b, :i, :h], [:j]]),
)
end
end
5 changes: 4 additions & 1 deletion test/Greedy_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

@test mapreduce(flops, +, Branches(path)) == 100

@test all(splat(issetequal), zip(contractorder(path), [[:i, :h], [:j], [:a, :e], [:g, :c], [:f], [:b, :d]]))
@test all(
@compat(splat(issetequal)),
zip(contractorder(path), [[:i, :h], [:j], [:a, :e], [:g, :c], [:f], [:b, :d]]),
)

@testset "example: let unchanged" begin
sizedict = Dict(i => 2 for i in [:i, :j, :k, :l, :m])
Expand Down
2 changes: 1 addition & 1 deletion test/Naive_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

# FIXME non-determinist behaviour on order
@test all(
splat(issetequal),
@compat(splat(issetequal)),
zip(map(suminds, Branches(path)), [Symbol[], [:j], [:a, :e], [:f, :b], [:i, :h], [:d, :g, :c]]),
)

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using EinExprs
using Compat

@testset "Unit tests" verbose = true begin
include("EinExpr_test.jl")
Expand Down

0 comments on commit 98647fb

Please sign in to comment.