Skip to content

Commit

Permalink
Minimize specialization of make_tracer methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jul 24, 2024
1 parent 9804dee commit 4e1824b
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,14 @@ end

append_path(path, i) = (path..., i)

function make_tracer(seen, prev::RT, path, mode; toscalar=false, tobatch=nothing) where {RT}
function make_tracer(
seen,
@nospecialize(prev::RT),
@nospecialize(path),
mode;
toscalar=false,
tobatch=nothing,
) where {RT}
if haskey(seen, prev)
return seen[prev]
end
Expand Down Expand Up @@ -944,7 +951,9 @@ function make_tracer(seen, prev::RT, path, mode; toscalar=false, tobatch=nothing
return y
end

function make_tracer(seen, prev::ConcreteRArray{T,N}, path, mode; kwargs...) where {T,N}
function make_tracer(
seen, @nospecialize(prev::ConcreteRArray{T,N}), @nospecialize(path), mode; kwargs...
) where {T,N}
if mode == ArrayToConcrete
return prev
end
Expand All @@ -961,7 +970,12 @@ function make_tracer(seen, prev::ConcreteRArray{T,N}, path, mode; kwargs...) whe
end

function make_tracer(
seen, prev::TracedRArray{T,N}, path, mode; toscalar=false, tobatch=nothing
seen,
@nospecialize(prev::TracedRArray{T,N}),
@nospecialize(path),
mode;
toscalar=false,
tobatch=nothing,
) where {T,N}
if mode == ConcreteToTraced
throw("Cannot trace existing trace type")
Expand Down Expand Up @@ -1000,20 +1014,31 @@ function make_tracer(
throw("Cannot Unknown trace mode $mode")
end

make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:AbstractFloat} = prev
function make_tracer(
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
) where {RT<:AbstractFloat}
return prev
end

make_tracer(seen, prev::Symbol, path, mode; kwargs...) = prev
make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) = prev

function make_tracer(
seen, prev::Complex{RT}, path, mode; toscalar=false, tobatch=nothing
seen,
@nospecialize(prev::Complex{RT}),
@nospecialize(path),
mode;
toscalar=false,
tobatch=nothing,
) where {RT}
return Complex(
make_tracer(seen, prev.re, append_path(path, :re), mode; toscalar, tobatch),
make_tracer(seen, prev.im, append_path(path, :im), mode; toscalar, tobatch),
)
end

function make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:Array}
function make_tracer(
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
) where {RT<:Array}
if haskey(seen, prev)
return seen[prev]
end
Expand Down Expand Up @@ -1041,7 +1066,9 @@ function make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:Array}
return newa
end

function make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:Tuple}
function make_tracer(
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
) where {RT<:Tuple}
return (
(
make_tracer(seen, v, append_path(path, i), mode; kwargs...) for
Expand All @@ -1050,7 +1077,9 @@ function make_tracer(seen, prev::RT, path, mode; kwargs...) where {RT<:Tuple}
)
end

function make_tracer(seen, prev::NamedTuple{A,RT}, path, mode; kwargs...) where {A,RT}
function make_tracer(
seen, @nospecialize(prev::NamedTuple{A,RT}), @nospecialize(path), mode; kwargs...
) where {A,RT}
return NamedTuple{A,traced_type(RT, (), Val(mode))}((
(
make_tracer(
Expand All @@ -1060,7 +1089,7 @@ function make_tracer(seen, prev::NamedTuple{A,RT}, path, mode; kwargs...) where
))
end

function make_tracer(seen, prev::Core.Box, path, mode; kwargs...)
function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
if haskey(seen, prev)
return seen[prev]
end
Expand Down

0 comments on commit 4e1824b

Please sign in to comment.