Skip to content

Commit

Permalink
feat: overlay eltype conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 24, 2024
1 parent a02fd5b commit 29f8709
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
13 changes: 13 additions & 0 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,16 @@ for randfun in (:rand, :randn, :randexp)
# end
end
end

# Type Conversions
## We don't define it outside our interpreter since it is inconsistent with how it is
## defined in Base for other types like Complex
@reactant_overlay @noinline function (::Type{T})(x::TracedRNumber) where {T<:Number}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

@reactant_overlay @noinline function Base.convert(
::Type{T}, x::TracedRNumber
) where {T<:Number}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end
6 changes: 1 addition & 5 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, eps(T))
end

function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T}
function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, T(x))
end

Expand All @@ -49,10 +49,6 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
return TracedRNumber{Base.promote_type(T, S)}
end

function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x

function TracedRNumber{T}(x::Number) where {T}
Expand Down
12 changes: 12 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -676,3 +676,15 @@ end
fill!(z, 1.0)
@test all(==(1.0), Array(z))
end

@testset "eltype conversion inside interpreter" begin
function test_convert(x::AbstractArray{T}, eta) where {T}
eta = T(eta)
return x .* eta, eta
end

res = @jit test_convert(ConcreteRArray(rand(4, 2)), ConcreteRNumber(3.0f0))

@test res[1] isa ConcreteRArray{Float64,2}
@test res[2] isa ConcreteRNumber{Float64}
end

0 comments on commit 29f8709

Please sign in to comment.