Skip to content

Commit

Permalink
fix: overload the main methods
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 24, 2024
1 parent 190148b commit 2f53601
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 28 deletions.
23 changes: 0 additions & 23 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,3 @@ for randfun in (:rand, :randn, :randexp)
# end
end
end

# Type Conversions
## We don't define it outside our interpreter since it is inconsistent with how it is
## defined in Base for other types like Complex
@reactant_overlay @noinline function (::Type{T})(x::TracedRNumber) where {T<:Number}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

@reactant_overlay @noinline function (::Type{TracedRNumber{T}})(x::TracedRNumber) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

@reactant_overlay @noinline function Base.convert(
::Type{T}, x::TracedRNumber
) where {T<:Number}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

@reactant_overlay @noinline function Base.convert(
::Type{TracedRNumber{T}}, x::TracedRNumber
) where {T<:Number}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end
12 changes: 7 additions & 5 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ function Base.eps(::Type{TracedRNumber{T}}) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, eps(T))
end

function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, T(x))
end

function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}}
return print(io, "TracedRNumber{", T, "}(", X.paths, ")")
end
Expand All @@ -49,9 +45,15 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
return TracedRNumber{Base.promote_type(T, S)}
end

# NOTE: This is inconsistent with the behavior of `convert` but we do it since it is a very
# common usecase
TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x
TracedRNumber{T}(x::TracedRNumber) where {T} = TracedUtils.promote_to(TracedRNumber{T}, x)
TracedRNumber{T}(x::Number) where {T} = TracedUtils.promote_to(TracedRNumber{T}, x)

function TracedRNumber{T}(x::Number) where {T}
(T::Type{<:Number})(x::TracedRNumber) = TracedUtils.promote_to(TracedRNumber{T}, x)

function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

Expand Down

0 comments on commit 2f53601

Please sign in to comment.