Skip to content

Commit

Permalink
more fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 9, 2025
1 parent fa2a1df commit 8127ed6
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ end

function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
return Reactant.make_tracer(seen, x, path, mode; kwargs...)
Reactant.make_tracer(seen, x, path, mode; kwargs...)
return prev
end

function get_field_offset(T::Type, path)
Expand Down Expand Up @@ -416,7 +417,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(

wrapper_tys = MLIR.IR.Type[]
ctx = MLIR.IR.context()
cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1))
cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1))

# linearize kernel arguments
seen = Reactant.OrderedIdDict()
Expand Down Expand Up @@ -615,7 +616,7 @@ end

function Reactant.traced_type(
::Type{A}, seen::ST, ::Val{mode}, track_numbers
) where {T,N,A<:CuTracedArray,ST,mode}
) where {A<:CuTracedArray,ST,mode}
return A
end

Expand Down

0 comments on commit 8127ed6

Please sign in to comment.