From 2f536016312d6391d581fdb78296cabd4c9211c5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 Dec 2024 21:05:02 +0530 Subject: [PATCH] fix: overload the main methods --- src/Overlay.jl | 23 ----------------------- src/TracedRNumber.jl | 12 +++++++----- 2 files changed, 7 insertions(+), 28 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index 789b393ff..b9785b7fa 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -115,26 +115,3 @@ for randfun in (:rand, :randn, :randexp) # end end end - -# Type Conversions -## We don't define it outside our interpreter since it is inconsistent with how it is -## defined in Base for other types like Complex -@reactant_overlay @noinline function (::Type{T})(x::TracedRNumber) where {T<:Number} - return TracedUtils.promote_to(TracedRNumber{T}, x) -end - -@reactant_overlay @noinline function (::Type{TracedRNumber{T}})(x::TracedRNumber) where {T} - return TracedUtils.promote_to(TracedRNumber{T}, x) -end - -@reactant_overlay @noinline function Base.convert( - ::Type{T}, x::TracedRNumber -) where {T<:Number} - return TracedUtils.promote_to(TracedRNumber{T}, x) -end - -@reactant_overlay @noinline function Base.convert( - ::Type{TracedRNumber{T}}, x::TracedRNumber -) where {T<:Number} - return TracedUtils.promote_to(TracedRNumber{T}, x) -end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 712cca5d8..8f1a97015 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -22,10 +22,6 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end -function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} - return TracedUtils.promote_to(TracedRNumber{T}, T(x)) -end - function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} return print(io, "TracedRNumber{", T, "}(", X.paths, ")") end @@ -49,9 +45,15 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end +# NOTE: This is inconsistent with the behavior of `convert` but we do it since it is a very +# common usecase TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x +TracedRNumber{T}(x::TracedRNumber) where {T} = TracedUtils.promote_to(TracedRNumber{T}, x) +TracedRNumber{T}(x::Number) where {T} = TracedUtils.promote_to(TracedRNumber{T}, x) -function TracedRNumber{T}(x::Number) where {T} +(T::Type{<:Number})(x::TracedRNumber) = TracedUtils.promote_to(TracedRNumber{T}, x) + +function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} return TracedUtils.promote_to(TracedRNumber{T}, x) end