Skip to content

Commit

Permalink
Some minor refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio Sánchez Ramírez committed Jul 20, 2024
1 parent c3d4367 commit 2b3b277
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
14 changes: 0 additions & 14 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,20 +167,6 @@ end

include("Tracing.jl")

@inline val_value(::Val{T}) where {T} = T
@inline val_value(::Type{Val{T}}) where {T} = T

@inline getmap(::Val{T}) where {T} = nothing
@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...)
@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T,T2} = T2

@inline is_concrete_tuple(x::T2) where {T2} =
(x <: Tuple) && !(x === Tuple) && !(x isa UnionAll)

function append_path(path, i)
return (path..., i)
end

struct MakeConcreteRArray{T} end
struct MakeArray{AT,Vals} end
struct MakeString{AT,Val} end
Expand Down
9 changes: 9 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,9 @@ function traced_type(::Type{T}, seen, mode) where {T<:Function}
return Core.apply_type(T.name.wrapper, traced_fieldtypes...)
end

@inline is_concrete_tuple(x::T2) where {T2} =
(x <: Tuple) && !(x === Tuple) && !(x isa UnionAll)

function traced_type(::Type{T}, seen, mode) where {T<:Tuple}
if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll
throw(AssertionError("Type $T is not concrete type or concrete tuple"))
Expand All @@ -714,6 +717,10 @@ function traced_type(::Type{T}, seen, mode) where {K,V,T<:AbstractDict{K,V}}
return dictty{K,traced_type(V, seen, mode)}
end

@inline getmap(::Val{T}) where {T} = nothing
@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...)
@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T,T2} = T2

function traced_type(::Type{T}, seen, mode) where {T}
if T === Any
return T
Expand Down Expand Up @@ -856,6 +863,8 @@ function traced_type(::Type{Val{T}}, seen, mode) where {T}
throw("Val type $T cannot be traced")
end

append_path(path, i) = (path..., i)

function make_tracer(seen, prev::RT, path, mode; toscalar=false, tobatch=nothing) where {RT}
if haskey(seen, prev)
return seen[prev]
Expand Down

0 comments on commit 2b3b277

Please sign in to comment.