diff --git a/src/Overlay.jl b/src/Overlay.jl index b9785b7fa..1c35ad188 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -115,3 +115,16 @@ for randfun in (:rand, :randn, :randexp) # end end end + +# Type Conversions +## We don't define it outside our interpreter since it is inconsistent with how it is +## defined in Base for other types like Complex +@reactant_overlay @noinline function (::Type{T})(x::TracedRNumber) where {T<:Number} + return TracedUtils.promote_to(TracedRNumber{T}, x) +end + +@reactant_overlay @noinline function Base.convert( + ::Type{T}, x::TracedRNumber +) where {T<:Number} + return TracedUtils.promote_to(TracedRNumber{T}, x) +end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index df664031e..712cca5d8 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -22,7 +22,7 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end -function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} +function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} return TracedUtils.promote_to(TracedRNumber{T}, T(x)) end @@ -49,10 +49,6 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end -function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} - return TracedUtils.promote_to(TracedRNumber{T}, x) -end - TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x function TracedRNumber{T}(x::Number) where {T} diff --git a/test/basic.jl b/test/basic.jl index 31fd841b4..45805484b 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -676,3 +676,15 @@ end fill!(z, 1.0) @test all(==(1.0), Array(z)) end + +@testset "eltype conversion inside interpreter" begin + function test_convert(x::AbstractArray{T}, eta) where {T} + eta = T(eta) + return x .* eta, eta + end + + res = @jit test_convert(ConcreteRArray(rand(4, 2)), ConcreteRNumber(3.0f0)) + + @test res[1] isa ConcreteRArray{Float64,2} + @test res[2] isa ConcreteRNumber{Float64} +end