From fa2a1dfa9ae2ac599762b30d5c0c1efa550d7247 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 8 Jan 2025 18:48:46 -0500 Subject: [PATCH] traced type --- ext/ReactantCUDAExt.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 6d7570b68..aa2cbb413 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -613,6 +613,12 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( return Core.Typeof(res)(f, res.entry) end +function Reactant.traced_type( + ::Type{A}, seen::ST, ::Val{mode}, track_numbers +) where {T,N,A<:CuTracedArray,ST,mode} + return A +end + function Reactant.traced_type( ::Type{A}, seen::ST, ::Val{mode}, track_numbers ) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}