diff --git a/Project.toml b/Project.toml
index 9af7dafef..9a1277d9d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -21,6 +21,7 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353"
 [weakdeps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
@@ -31,6 +32,7 @@ path = "lib/ReactantCore"
 [extensions]
 ReactantAbstractFFTsExt = "AbstractFFTs"
 ReactantArrayInterfaceExt = "ArrayInterface"
+ReactantCUDAExt = "CUDA"
 ReactantNNlibExt = "NNlib"
 ReactantStatisticsExt = "Statistics"
 ReactantYaoBlocksExt = "YaoBlocks"
@@ -58,4 +60,5 @@ julia = "1.10"
 [extras]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp
index f93b32ea4..3ae7a7ebf 100644
--- a/deps/ReactantExtra/API.cpp
+++ b/deps/ReactantExtra/API.cpp
@@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
   context.loadDialect();
   context.loadDialect();
 }
+
+#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
 
 extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::DialectRegistry &registry = *unwrap(creg);
@@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::affine::registerAffinePasses();
   mlir::registerReconcileUnrealizedCasts();
 
+  mlir::registerLLVMDialectImport(registry);
+  mlir::registerNVVMDialectImport(registry);
+
+  mlir::LLVM::registerInlinerInterface(registry);
+
   /*
   registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
     LLVM::LLVMFunctionType::attachInterface(*ctx);
@@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
 }
+
+/// Returns a symbol name unused in both `source` and `target`, formed from
+/// `oldSymName` by appending numeric suffixes generated via `lastUsedID`.
+static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
+                                     unsigned &lastUsedID,
+                                     mlir::ModuleOp source,
+                                     mlir::ModuleOp target) {
+  using namespace llvm;
+  using namespace mlir;
+  SmallString<64> newSymName(oldSymName);
+  newSymName.push_back('_');
+  while (true) {
+    auto possible = newSymName + Twine(++lastUsedID);
+    if (!SymbolTable::lookupSymbolIn(source, possible.str()) &&
+        !SymbolTable::lookupSymbolIn(target, possible.str())) {
+      return StringAttr::get(target.getContext(), possible);
+    }
+  }
+}
+
+/// Checks if a symbol with the same name as `op` already exists in `target`.
+/// If so, renames `op` and updates all of its references in `source`.
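+/// A sketch of the behavior (names hypothetical): if `target` already holds a
+/// function `@foo`, an incoming `@foo` from `source` is renamed to `@foo_1`
+/// (then `@foo_2`, ... until the name is free in both modules), and every use
+/// of the old name inside `source` is rewritten to match.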
+static mlir::LogicalResult
+updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source,
+                       mlir::ModuleOp target, unsigned &lastUsedID) {
+  using namespace llvm;
+  using namespace mlir;
+
+  auto opName = op.getName().str();
+
+  if (!SymbolTable::lookupSymbolIn(target, opName)) {
+    return success();
+  }
+
+  StringAttr newSymName = renameSymbol(opName, lastUsedID, source, target);
+
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
+    return op.emitError("unable to update all symbol uses for ")
+           << opName << " to " << newSymName;
+
+  SymbolTable::setSymbolName(op, newSymName);
+  return success();
+}
+
+extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, const char* entryfn) {
+  auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
+  auto newMod = cast<ModuleOp>(*unwrap(newModC));
+
+  Operation* entryFn = nullptr;
+
+  unsigned lastUsedID = 0;
+
+  for (auto &op : *newMod.getBody()) {
+    auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+    if (!symbolOp)
+      continue;
+
+    StringRef oldSymName = symbolOp.getName();
+
+    if (oldSymName == entryfn) {
+      entryFn = &op;
+    }
+
+    if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod,
+                                      lastUsedID))) {
+      assert(0 && "failed to update all uses");
+    }
+    SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
+  }
+  prevMod.getBody()->getOperations().splice(prevMod.getBody()->getOperations().end(),
+                                            newMod.getBody()->getOperations());
+  return wrap(entryFn);
+}
+
 #pragma region xla::ifrt
 
 #pragma region xla::ifrt::Value
diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index c718304bd..c538bbb8a 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -416,6 +416,7 @@ cc_library(
 "-Wl,-exported_symbol,_BufferToHost",
 "-Wl,-exported_symbol,_FreeClient",
 "-Wl,-exported_symbol,_ClientCompile",
+"-Wl,-exported_symbol,_LinkInModule",
 "-Wl,-exported_symbol,_FreeFuture",
 "-Wl,-exported_symbol,_FutureIsReady",
 "-Wl,-exported_symbol,_FutureAwait",
@@ -451,6 +452,10 @@ cc_library(
         "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:Transforms",
 
+        "@llvm-project//mlir:LLVMIRToLLVMTranslation",
+        "@llvm-project//mlir:LLVMIRToNVVMTranslation",
+        "@llvm-project//mlir:LLVMIRTransforms",
+        "@llvm-project//llvm:IRReader",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:AArch64AsmParser",
diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
new file mode 100644
index 000000000..b38b55001
--- /dev/null
+++ b/ext/ReactantCUDAExt.jl
@@ -0,0 +1,408 @@
+module ReactantCUDAExt
+
+using CUDA
+using Reactant:
+    Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
+using ReactantCore: @trace
+
+using Adapt
+
+struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
+    ptr::Core.LLVMPtr{T,A}
+end
+
+
+Base.show(io::IO, a::AT) where AT <: CuTracedArray =
+    CUDA.Printf.@printf(io, "%s cu traced array at %p", join(size(a), '×'), Int(pointer(a)))
+
+## array interface
+
+Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T)
+Base.size(g::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size
+Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x)
+Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(Core.LLVMPtr{T,A}, x)
+@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A}
+    Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i)
+end
+
+
+## conversions
+
+Base.unsafe_convert(::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A}) where {T,A} =
+    x.ptr
+
+
+## indexing intrinsics
+
+CUDA.@device_function @inline 
function arrayref(A::CuTracedArray{T}, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    if Base.isbitsunion(T)
+        arrayref_union(A, index)
+    else
+        arrayref_bits(A, index)
+    end
+end
+
+@inline function arrayref_bits(A::CuTracedArray{T}, index::Integer) where {T}
+    unsafe_load(pointer(A), index)
+end
+
+@inline @generated function arrayref_union(A::CuTracedArray{T,<:Any,AS}, index::Integer) where {T,AS}
+    typs = Base.uniontypes(T)
+
+    # generate code that conditionally loads a value based on the selector value.
+    # lacking noreturn, we return T to avoid inference thinking this can return Nothing.
+    ex = :(Base.llvmcall("unreachable", $T, Tuple{}))
+    for (sel, typ) in Iterators.reverse(enumerate(typs))
+        ex = quote
+            if selector == $(sel-1)
+                ptr = reinterpret(Core.LLVMPtr{$typ,AS}, data_ptr)
+                unsafe_load(ptr, 1)
+            else
+                $ex
+            end
+        end
+    end
+
+    quote
+        # (typetagdata mirrors CUDA.jl's selector-byte helper for union storage)
+        selector_ptr = typetagdata(A, index)
+        selector = unsafe_load(selector_ptr)
+
+        data_ptr = pointer(A, index)
+
+        return $ex
+    end
+end
+
+CUDA.@device_function @inline function arrayset(A::CuTracedArray{T}, x::T, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    if Base.isbitsunion(T)
+        arrayset_union(A, x, index)
+    else
+        arrayset_bits(A, x, index)
+    end
+    return A
+end
+
+@inline function arrayset_bits(A::CuTracedArray{T}, x::T, index::Integer) where {T}
+    unsafe_store!(pointer(A), x, index)
+end
+
+@inline @generated function arrayset_union(A::CuTracedArray{T,<:Any,AS}, x::T, index::Integer) where {T,AS}
+    typs = Base.uniontypes(T)
+    sel = findfirst(isequal(x), typs)
+
+    quote
+        selector_ptr = typetagdata(A, index)
+        unsafe_store!(selector_ptr, $(UInt8(sel-1)))
+
+        data_ptr = pointer(A, index)
+
+        unsafe_store!(reinterpret(Core.LLVMPtr{$x,AS}, data_ptr), x, 1)
+        return
+    end
+end
+
+CUDA.@device_function @inline function const_arrayref(A::CuTracedArray{T}, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    unsafe_cached_load(pointer(A), index)
+end
+
+
+## indexing
+
+Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear()
+
+Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} =
+    arrayref(A, i1)
+Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} =
+    arrayset(A, convert(T,x)::T, i1)
+
+# preserve the specific integer type when indexing device arrays,
+# to avoid extending 32-bit hardware indices to 64-bit.
+Base.to_index(::CuTracedArray, i::Integer) = i
+
+# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
+# See also: https://github.com/JuliaLang/julia/pull/42289
+Base.@propagate_inbounds Base.getindex(A::CuTracedArray,
+                                       I::Union{Integer, CartesianIndex}...) =
+    A[Base._to_linear_index(A, to_indices(A, I)...)]
+Base.@propagate_inbounds Base.setindex!(A::CuTracedArray, x,
+                                        I::Union{Integer, CartesianIndex}...) =
+    A[Base._to_linear_index(A, to_indices(A, I)...)] = x
+
+
+## const indexing
+
+"""
+    Const(A::CuTracedArray)
+
+Mark a CuTracedArray as constant/read-only. The invariant guaranteed is that you will not
+modify a CuTracedArray for the duration of the current kernel.
+
+This API can only be used on devices with compute capability 3.5 or higher.
+
+!!! warning
+    Experimental API. Subject to change without deprecation.
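+
+A hypothetical usage sketch (the kernel and names below are illustrative only):
+
+    function copy_kernel!(dst, src)
+        i = threadIdx().x
+        c = Base.Experimental.Const(src)
+        dst[i] = c[i]
+        return nothing
+    end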
+""" +struct Const{T,N,AS} <: DenseArray{T,N} + a::CuTracedArray{T,N,AS} +end +Base.Experimental.Const(A::CuTracedArray) = Const(A) + +Base.IndexStyle(::Type{<:Const}) = IndexLinear() +Base.size(C::Const) = size(C.a) +Base.axes(C::Const) = axes(C.a) +Base.@propagate_inbounds Base.getindex(A::Const, i1::Integer) = const_arrayref(A.a, i1) + +# deprecated +Base.@propagate_inbounds ldg(A::CuTracedArray, i1::Integer) = const_arrayref(A, i1) + + +## other + +@inline function Base.iterate(A::CuTracedArray, i=1) + if (i % UInt) - 1 < length(A) + (@inbounds A[i], i + 1) + else + nothing + end +end + +function Base.reinterpret(::Type{T}, a::CuTracedArray{S,N,A}) where {T,S,N,A} + err = GPUArrays._reinterpret_exception(T, a) + err === nothing || throw(err) + + if sizeof(T) == sizeof(S) # fast case + return CuTracedArray{T,N,A}(reinterpret(Core.LLVMPtr{T,A}, a.ptr), size(a), a.maxsize) + end + + isize = size(a) + size1 = div(isize[1]*sizeof(S), sizeof(T)) + osize = tuple(size1, Base.tail(isize)...) + return CuTracedArray{T,N,A}(reinterpret(Core.LLVMPtr{T,A}, a.ptr), osize, a.maxsize) +end + + +## reshape + +function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M,A} + if prod(dims) != length(a) + throw(DimensionMismatch("new dimensions (argument `dims`) must be consistent with array size (`size(a)`)")) + end + if N == M && dims == size(a) + return a + end + _derived_array(a, T, dims) +end + + + +function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} + res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))) + return res +end + +const _kernel_instances = Dict{Any, Any}() + +struct LLVMFunc{F,tt} + f::Union{F, Nothing} + entry::MLIR.IR.Operation +end + + +const GPUCompiler = CUDA.GPUCompiler +const LLVM = GPUCompiler.LLVM + + +GPULowerCPUFeaturesPass() = LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!) +GPULowerPTLSPass() = LLVM.NewPMModulePass("GPULowerPTLS", GPUCompiler.lower_ptls!) +GPULowerGCFramePass() = LLVM.NewPMFunctionPass("GPULowerGCFrame", GPUCompiler.lower_gc_frame!) +function noop_pass(x) + return false +end +function kern_pass(mod) + for fname in ("julia.gpu.state_getter",) + if LLVM.haskey(LLVM.functions(mod), fname) + fn = LLVM.functions(mod)[fname] + insts = LLVM.Instruction[] + for u in LLVM.uses(fn) + u = LLVM.user(u) + LLVM.replace_uses!(u, LLVM.UndefValue(LLVM.value_type(u))) + push!(insts, u) + end + for inst in insts + Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst) + end + Reactant.Enzyme.Compiler.eraseInst(mod, fn) + end + end + + return true +end +AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass) +LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass) +CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass) + +# compile to executable machine code +function compile(job) + # lower to PTX + # TODO: on 1.9, this actually creates a context. cache those. 
+    entry = GPUCompiler.JuliaContext() do ctx
+        mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
+
+        entryname = LLVM.name(meta.entry)
+
+        GPUCompiler.optimize_module!(job, mod)
+        opt_level = 2
+        tm = GPUCompiler.llvm_machine(job.config.target)
+        LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
+            LLVM.register!(pb, GPULowerCPUFeaturesPass())
+            LLVM.register!(pb, GPULowerPTLSPass())
+            LLVM.register!(pb, GPULowerGCFramePass())
+            LLVM.register!(pb, AddKernelStatePass())
+            LLVM.register!(pb, LowerKernelStatePass())
+            LLVM.register!(pb, CleanupKernelStatePass())
+
+            LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
+                GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level)
+            end
+            LLVM.run!(pb, mod, tm)
+        end
+        GPUCompiler.optimize_module!(job, mod)
+        LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm)
+
+
+        for fname in ("gpu_report_exception", "gpu_signal_exception")
+            if LLVM.haskey(LLVM.functions(mod), fname)
+                fn = LLVM.functions(mod)[fname]
+                insts = LLVM.Instruction[]
+                for u in LLVM.uses(fn)
+                    push!(insts, LLVM.user(u))
+                end
+                for inst in insts
+                    Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst)
+                end
+                Reactant.Enzyme.Compiler.eraseInst(mod, fn)
+            end
+        end
+
+        LLVM.strip_debuginfo!(mod)
+        modstr = string(mod)
+
+        # The module was produced by Julia's copy of LLVM, while Reactant links
+        # against its own. Rather than hand a pointer across the two, it is safer
+        # to round-trip the module through its textual form and reparse it with
+        # the receiving LLVM's API, which is what ConvertLLVMStrToMLIR does.
+        mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
+
+        linkRes = @ccall MLIR.API.mlir_c.LinkInModule(MLIR.IR.mmodule()::MLIR.API.MlirModule, mmod::MLIR.API.MlirModule, entryname::Cstring)::MLIR.API.MlirOperation
+
+        entry = MLIR.IR.Operation(linkRes)
+
+        entry
+    end
+    LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry)
+end
+
+# link into an executable kernel
+function link(job, compiled)
+    # load as an executable kernel object
+    return compiled
+end
+
+function transpose_val(val)
+    attr = MLIR.IR.DenseArrayAttribute(
+        Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...]
+    )
+    return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
+end
+
+Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
+                      cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
+    @show call_kwargs
+
+    blockdim = CUDA.CuDim3(blocks)
+    threaddim = CUDA.CuDim3(threads)
+
+    mlir_args = MLIR.IR.Value[]
+    restys = MLIR.IR.Type[]
+    aliases = MLIR.IR.Attribute[]
+    rarrays = TracedRArray[]
+    for (i, a) in enumerate(args)
+        @assert a isa CuTracedArray
+        ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
+        push!(rarrays, ta)
+        arg = ta.mlir_data
+        arg = transpose_val(arg)
+        push!(restys, MLIR.IR.type(arg))
+        push!(mlir_args, arg)
+        push!(aliases,
+            MLIR.IR.Attribute(MLIR.API.stablehloOutputOperandAliasGet(
+                MLIR.IR.context(),
+                length(args) == 1 ? 0 : 1,
+                length(args) == 1 ? 
C_NULL : Ref{Int64}(i-1), + i-1, + 0, + C_NULL + )) + ) + end + + output_operand_aliases=MLIR.IR.Attribute(aliases) + + fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name") + # Force public for now while we don't have real users + MLIR.IR.rmattr!(func.entry, "sym_visibility") + + call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname)) + # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod)) + for (i, res) in enumerate(rarrays) + res.mlir_data = transpose_val(MLIR.IR.result(call, i)) + end + + @show blockdim + @show threaddim + #CUDA.cuLaunchKernel(f, + # blockdim.x, blockdim.y, blockdim.z, + # threaddim.x, threaddim.y, threaddim.z, + # shmem, stream, kernelParams, C_NULL) +end + +# cache of compilation caches, per context +const _compiler_caches = Dict{MLIR.IR.Context, Dict{Any, LLVMFunc}}(); +function compiler_cache(ctx::MLIR.IR.Context) + cache = get(_compiler_caches, ctx, nothing) + if cache === nothing + cache = Dict{Any, LLVMFunc}() + _compiler_caches[ctx] = cache + end + return cache +end + +Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} + res = Base.@lock CUDA.cufunction_lock begin + # compile the function + cache = compiler_cache(MLIR.IR.context()) + source = CUDA.methodinstance(F, tt) + # cuda = CUDA.active_state() + device = nothing # cuda.device + # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig + cuda_cap=v"5.0" + cuda_ptx=v"6.3" + llvm_cap=v"5.0" + llvm_ptx=v"6.3" + kernel=true + always_inline=false + name=nothing + debuginfo=false + config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline) + CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) + end + res +end + +function __init__() + +end + +end # module ReactantCUDAExt diff --git a/src/utils.jl b/src/utils.jl index b65077c03..3fdd646e0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -307,78 +307,6 @@ function call_with_reactant_generator( # No method could be found (including in our method table), bail with an error if lookup_result == nothing return stub(world, source, method_error) - tmp_min_world = Ref{UInt}(typemin(UInt)) - tmp_max_world = Ref{UInt}(typemax(UInt)) - match = ccall( - :jl_gf_invoke_lookup_worlds, - Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - Tuple{typeof(throw_method_error),sig}, - nothing, - world, - tmp_min_world, - tmp_max_world, - ) #=mt=# - @assert match !== nothing - - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - - ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo - - src = copy(ci) - src.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] - - src.edges = Any[ - ccall(:jl_method_table_for, Any, (Any,), sig)::Core.MethodTable, sig - ] - src.min_world = min_world[] - src.max_world = max_world[] - - push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), 1))) - push!(overdubbed_codelocs, 0) - - expr_fn = Core.SSAValue(length(overdubbed_code)) - - push!(overdubbed_code, :($(Base.lastindex)($(Core.Argument(2))))) - 
push!(overdubbed_codelocs, 0)
-
-    expr_lastindex = Core.SSAValue(length(overdubbed_code))
-
-    push!(overdubbed_code, :(2:($expr_lastindex)))
-    push!(overdubbed_codelocs, 0)
-
-    expr_slice = Core.SSAValue(length(overdubbed_code))
-
-    push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), $expr_slice)))
-    push!(overdubbed_codelocs, 0)
-
-    expr_args = Core.SSAValue(length(overdubbed_code))
-
-    push!(overdubbed_code, :($(Base.MethodError)($expr_fn, $expr_args, $world)))
-    push!(overdubbed_codelocs, 0)
-
-    expr_method = Core.SSAValue(length(overdubbed_code))
-
-    push!(overdubbed_code, :($(Base.throw)($expr_method)))
-    push!(overdubbed_codelocs, 0)
-
-    push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code))))
-    push!(overdubbed_codelocs, 0)
-
-    src.code = overdubbed_code
-    src.codelocs = overdubbed_codelocs
-    src.ssavaluetypes = length(overdubbed_code)
-    src.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code
-
-    return src
     end
 
     match = lookup_result::Core.MethodMatch
@@ -438,17 +366,19 @@ function call_with_reactant_generator(
     # Also rewrite invoke (type stable call) to be :call, since otherwise apparently
     # screws up type inference after this (TODO this should be fixed).
     any_changed = false
-    for (i, inst) in enumerate(ir.stmts)
-        @static if VERSION < v"1.11"
-            changed, next = rewrite_inst(inst[:inst], ir, interp)
-            Core.Compiler.setindex!(ir.stmts[i], next, :inst)
-        else
-            changed, next = rewrite_inst(inst[:stmt], ir, interp)
-            Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
-        end
-        if changed
-            any_changed = true
-            Core.Compiler.setindex!(ir.stmts[i], Any, :type)
+    if should_rewrite_ft(args[1]) && !is_reactant_method(mi)
+        for (i, inst) in enumerate(ir.stmts)
+            @static if VERSION < v"1.11"
+                changed, next = rewrite_inst(inst[:inst], ir, interp)
+                Core.Compiler.setindex!(ir.stmts[i], next, :inst)
+            else
+                changed, next = rewrite_inst(inst[:stmt], ir, interp)
+                Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
+            end
+            if changed
+                any_changed = true
+                Core.Compiler.setindex!(ir.stmts[i], Any, :type)
+            end
         end
     end
diff --git a/test/Project.toml b/test/Project.toml
index 4b50a487f..9956337ea 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,6 +1,7 @@
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
diff --git a/test/cuda.jl b/test/cuda.jl
new file mode 100644
index 000000000..05d0777c5
--- /dev/null
+++ b/test/cuda.jl
@@ -0,0 +1,26 @@
+using Reactant
+using Test
+using CUDA
+
+function square_kernel!(x)
+    i = threadIdx().x
+    x[i] *= x[i]
+    # sync_threads()
+    return nothing
+end
+
+# basic squaring on GPU
+function square!(x)
+    @cuda blocks = 1 threads = length(x) square_kernel!(x)
+    return nothing
+end
+
+@testset "Square Kernel" begin
+    oA = collect(1:1:64)
+    A = Reactant.to_rarray(oA)
+    @show @code_hlo optimize=false square!(A)
+    @show @code_hlo square!(A)
+    func = @compile square!(A)
+    func(A)  # run the compiled module so the in-place update lands in A
+    @test all(Array(A) .≈ (oA .* oA))
+end
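For reviewers who want to exercise the extension end to end, here is a minimal sketch mirroring test/cuda.jl. It is hedged: it assumes a CUDA-capable toolchain on this branch, and the `add_kernel!`/`add!` names are illustrative rather than part of the PR. Note that the custom-call lowering above currently asserts that every kernel argument is a `CuTracedArray`, so the sketch sticks to array arguments.

```julia
using CUDA, Reactant, Test

# in-place elementwise add: one thread per element, array arguments only
function add_kernel!(y, x)
    i = threadIdx().x
    y[i] += x[i]
    return nothing
end

function add!(y, x)
    @cuda blocks = 1 threads = length(y) add_kernel!(y, x)
    return nothing
end

oY, oX = collect(1.0:16.0), fill(2.0, 16)
Y, X = Reactant.to_rarray(oY), Reactant.to_rarray(oX)

func = @compile add!(Y, X)  # traces the @cuda launch into a stablehlo.custom_call
func(Y, X)                  # executes the compiled module
@test Array(Y) ≈ oY .+ oX
```

As in the test above, `@code_hlo add!(Y, X)` can be used to inspect the module that the kernel is linked into.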