diff --git a/src/Reactant.jl b/src/Reactant.jl index 9eb3d564d..9bc604179 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -167,20 +167,6 @@ end include("Tracing.jl") -@inline val_value(::Val{T}) where {T} = T -@inline val_value(::Type{Val{T}}) where {T} = T - -@inline getmap(::Val{T}) where {T} = nothing -@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...) -@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T,T2} = T2 - -@inline is_concrete_tuple(x::T2) where {T2} = - (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) - -function append_path(path, i) - return (path..., i) -end - struct MakeConcreteRArray{T} end struct MakeArray{AT,Vals} end struct MakeString{AT,Val} end diff --git a/src/Tracing.jl b/src/Tracing.jl index 6608e1394..0d75dd138 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -695,6 +695,9 @@ function traced_type(::Type{T}, seen, mode) where {T<:Function} return Core.apply_type(T.name.wrapper, traced_fieldtypes...) end +@inline is_concrete_tuple(x::T2) where {T2} = + (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) + function traced_type(::Type{T}, seen, mode) where {T<:Tuple} if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll throw(AssertionError("Type $T is not concrete type or concrete tuple")) @@ -714,6 +717,10 @@ function traced_type(::Type{T}, seen, mode) where {K,V,T<:AbstractDict{K,V}} return dictty{K,traced_type(V, seen, mode)} end +@inline getmap(::Val{T}) where {T} = nothing +@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...) +@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T,T2} = T2 + function traced_type(::Type{T}, seen, mode) where {T} if T === Any return T @@ -856,6 +863,8 @@ function traced_type(::Type{Val{T}}, seen, mode) where {T} throw("Val type $T cannot be traced") end +append_path(path, i) = (path..., i) + function make_tracer(seen, prev::RT, path, mode; toscalar=false, tobatch=nothing) where {RT} if haskey(seen, prev) return seen[prev]