Skip to content

Commit

Permalink
traced type
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 8, 2025
1 parent a93e455 commit fa2a1df
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,12 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
return Core.Typeof(res)(f, res.entry)
end

function Reactant.traced_type(
::Type{A}, seen::ST, ::Val{mode}, track_numbers
) where {T,N,A<:CuTracedArray,ST,mode}
return A
end

function Reactant.traced_type(
::Type{A}, seen::ST, ::Val{mode}, track_numbers
) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}
Expand Down

0 comments on commit fa2a1df

Please sign in to comment.