diff --git a/src/Tracing.jl b/src/Tracing.jl index 76710c58d..6e16e424d 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -867,7 +867,14 @@ end append_path(path, i) = (path..., i) -function make_tracer(seen, prev::RT, path, mode; toscalar=false, tobatch=nothing) where {RT} +function make_tracer( + seen, + @nospecialize(prev::RT), + @nospecialize(path), + mode; + toscalar=false, + tobatch=nothing, +) where {RT} if haskey(seen, prev) return seen[prev] end @@ -944,7 +951,9 @@ function make_tracer(seen, prev::RT, path, mode; toscalar=false, tobatch=nothing return y end -function make_tracer(seen, prev::ConcreteRArray{T,N}, path, mode; kwargs...) where {T,N} +function make_tracer( + seen, @nospecialize(prev::ConcreteRArray{T,N}), @nospecialize(path), mode; kwargs... +) where {T,N} if mode == ArrayToConcrete return prev end @@ -961,7 +970,12 @@ function make_tracer(seen, prev::ConcreteRArray{T,N}, path, mode; kwargs...) whe end function make_tracer( - seen, prev::TracedRArray{T,N}, path, mode; toscalar=false, tobatch=nothing + seen, + @nospecialize(prev::TracedRArray{T,N}), + @nospecialize(path), + mode; + toscalar=false, + tobatch=nothing, ) where {T,N} if mode == ConcreteToTraced throw("Cannot trace existing trace type") @@ -1000,12 +1014,21 @@ function make_tracer( throw("Cannot Unknown trace mode $mode") end -make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:AbstractFloat} = prev +function make_tracer( + seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs... +) where {RT<:AbstractFloat} + return prev +end -make_tracer(seen, prev::Symbol, path, mode; kwargs...) = prev +make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) = prev function make_tracer( - seen, prev::Complex{RT}, path, mode; toscalar=false, tobatch=nothing + seen, + @nospecialize(prev::Complex{RT}), + @nospecialize(path), + mode; + toscalar=false, + tobatch=nothing, ) where {RT} return Complex( make_tracer(seen, prev.re, append_path(path, :re), mode; toscalar, tobatch), @@ -1013,7 +1036,9 @@ function make_tracer( ) end -function make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:Array} +function make_tracer( + seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs... +) where {RT<:Array} if haskey(seen, prev) return seen[prev] end @@ -1041,7 +1066,9 @@ function make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:Array} return newa end -function make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:Tuple} +function make_tracer( + seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs... +) where {RT<:Tuple} return ( ( make_tracer(seen, v, append_path(path, i), mode; kwargs...) for @@ -1050,7 +1077,9 @@ function make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:Tuple} ) end -function make_tracer(seen, prev::NamedTuple{A,RT}, path, mode; kwargs...) where {A,RT} +function make_tracer( + seen, @nospecialize(prev::NamedTuple{A,RT}), @nospecialize(path), mode; kwargs... +) where {A,RT} return NamedTuple{A,traced_type(RT, (), Val(mode))}(( ( make_tracer( @@ -1060,7 +1089,7 @@ function make_tracer(seen, prev::NamedTuple{A,RT}, path, mode; kwargs...) where )) end -function make_tracer(seen, prev::Core.Box, path, mode; kwargs...) +function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) if haskey(seen, prev) return seen[prev] end