From d0e51956e1447ff3808f8b7543f313e2475aec5c Mon Sep 17 00:00:00 2001
From: "William S. Moses"
Date: Sun, 15 Dec 2024 02:39:18 -0500
Subject: [PATCH] host and device IR

---
 deps/ReactantExtra/API.cpp |  84 +++++++++++++++++
 deps/ReactantExtra/BUILD   |   5 +
 ext/ReactantCUDAExt.jl     | 188 +++++-------------------------------
 test/cuda.jl               |   2 +-
 4 files changed, 111 insertions(+), 168 deletions(-)

diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp
index f93b32ea4..3ae7a7ebf 100644
--- a/deps/ReactantExtra/API.cpp
+++ b/deps/ReactantExtra/API.cpp
@@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
   context.loadDialect();
   context.loadDialect();
 }
+
+#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
 extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::DialectRegistry &registry = *unwrap(creg);
@@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::affine::registerAffinePasses();
   mlir::registerReconcileUnrealizedCasts();
+  mlir::registerLLVMDialectImport(registry);
+  mlir::registerNVVMDialectImport(registry);
+
+  mlir::LLVM::registerInlinerInterface(registry);
+
   /*
   registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
     LLVM::LLVMFunctionType::attachInterface(*ctx);
@@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
 
   mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
 }
+
+/// Returns a symbol name derived from `oldSymName` that is unused in both
+/// `source` and `target`, by appending numeric suffixes tracked in `lastUsedID`.
+static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
+                                     unsigned &lastUsedID,
+                                     mlir::ModuleOp source,
+                                     mlir::ModuleOp target) {
+  using namespace llvm;
+  using namespace mlir;
+  SmallString<64> newSymName(oldSymName);
+  newSymName.push_back('_');
+  while (true) {
+    auto possible = newSymName + Twine(++lastUsedID);
+    if (!SymbolTable::lookupSymbolIn(source, possible.str()) &&
+        !SymbolTable::lookupSymbolIn(target, possible.str())) {
+      return StringAttr::get(target.getContext(), possible);
+    }
+  }
+}
+
+/// Checks if a symbol with the same name as `op` already exists in `target`.
+/// If so, renames `op` and updates all its references in `source`.
+static mlir::LogicalResult
+updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
+                       unsigned &lastUsedID) {
+  using namespace llvm;
+  using namespace mlir;
+
+  auto opName = op.getName().str();
+
+  if (!SymbolTable::lookupSymbolIn(target, opName)) {
+    return success();
+  }
+
+  StringAttr newSymName = renameSymbol(opName, lastUsedID, source, target);
+
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
+    return op.emitError("unable to update all symbol uses for ")
+           << opName << " to " << newSymName;
+
+  SymbolTable::setSymbolName(op, newSymName);
+  return success();
+}
+
+extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, const char* entryfn) {
+  auto prevMod = cast<mlir::ModuleOp>(*unwrap(prevModC));
+  auto newMod = cast<mlir::ModuleOp>(*unwrap(newModC));
+
+  Operation* entryFn = nullptr;
+
+  unsigned lastUsedID = 0;
+
+  for (auto &op : *newMod.getBody()) {
+    auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+    if (!symbolOp)
+      continue;
+
+    StringRef oldSymName = symbolOp.getName();
+
+    if (oldSymName == entryfn) {
+      entryFn = &op;
+    }
+
+    if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod,
+                                      lastUsedID))) {
+      assert(0 && "failed to update all uses");
+    }
+    SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
+  }
+  prevMod.getBody()->getOperations().splice(prevMod.getBody()->getOperations().end(),
+                                            newMod.getBody()->getOperations());
+  return wrap(entryFn);
+}
+
 #pragma region xla::ifrt
 
 #pragma region xla::ifrt::Value

diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index c718304bd..c538bbb8a 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -416,6 +416,7 @@ cc_library(
 "-Wl,-exported_symbol,_BufferToHost",
 "-Wl,-exported_symbol,_FreeClient",
 "-Wl,-exported_symbol,_ClientCompile",
+"-Wl,-exported_symbol,_LinkInModule",
 "-Wl,-exported_symbol,_FreeFuture",
 "-Wl,-exported_symbol,_FutureIsReady",
 "-Wl,-exported_symbol,_FutureAwait",
@@ -451,6 +452,10 @@ cc_library(
         "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:Transforms",
 
+        "@llvm-project//mlir:LLVMIRToLLVMTranslation",
+        "@llvm-project//mlir:LLVMIRToNVVMTranslation",
+        "@llvm-project//mlir:LLVMIRTransforms",
+
         "@llvm-project//llvm:IRReader",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:AArch64AsmParser",

diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index adf35e0aa..b38b55001 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -200,7 +200,6 @@ end
 
 function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
     res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)))
-    @show res, xs
     return res
 end
 
@@ -208,9 +207,7 @@ const _kernel_instances = Dict{Any, Any}()
 
 struct LLVMFunc{F,tt}
     f::Union{F, Nothing}
-    mod::String
-    image
-    entry::String
+    entry::MLIR.IR.Operation
 end
 
 
@@ -249,11 +246,13 @@ CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_p
 
 # compile to executable machine code
 function compile(job)
-
     # lower to PTX
     # TODO: on 1.9, this actually creates a context. cache those.
-    modstr, image, entry = GPUCompiler.JuliaContext() do ctx
+    entry = GPUCompiler.JuliaContext() do ctx
         mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
+
+        entryname = LLVM.name(meta.entry)
+
         GPUCompiler.optimize_module!(job, mod)
         opt_level = 2
         tm = GPUCompiler.llvm_machine(job.config.target)
@@ -294,162 +293,15 @@ function compile(job)
         # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
         # it is probably safer to reparse a string using the right llvm module api, so we will do that.
-        println(string(modstr))
         mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
-        @show mmod
-
-        # check if we'll need the device runtime
-        undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
-            CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
-        end
-        intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
-                         "__nvvm_reflect" #= TODO: should have been optimized away =#]
-        needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns))
-
-        # prepare invocations of CUDA compiler tools
-        ptxas_opts = String[]
-        nvlink_opts = String[]
-        ## debug flags
-        if Base.JLOptions().debug_level == 1
-            push!(ptxas_opts, "--generate-line-info")
-        elseif Base.JLOptions().debug_level >= 2
-            push!(ptxas_opts, "--device-debug")
-            push!(nvlink_opts, "--debug")
-        end
-        ## relocatable device code
-        if needs_cudadevrt
-            push!(ptxas_opts, "--compile-only")
-        end
-
-        ptx = job.config.params.ptx
-        cap = job.config.params.cap
-        arch = "sm_$(cap.major)$(cap.minor)"
-
-        # validate use of parameter memory
-        argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt
-            !CUDA.isghosttype(dt) && !Core.Compiler.isconstType(dt)
-        end
-        param_usage = sum(sizeof, argtypes)
-        param_limit = 4096
-        if cap >= v"7.0" && ptx >= v"8.1"
-            param_limit = 32764
-        end
-        if param_usage > param_limit
-            msg = """Kernel invocation uses too much parameter memory.
-                     $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor)."""
-
-            try
-                details = "\n\nRelevant parameters:"
-
-                source_types = job.source.specTypes.parameters
-                source_argnames = Base.method_argnames(job.source.def)
-                while length(source_argnames) < length(source_types)
-                    # this is probably due to a trailing vararg; repeat its name
-                    push!(source_argnames, source_argnames[end])
-                end
-
-                for (i, typ) in enumerate(source_types)
-                    if CUDA.isghosttype(typ) || Core.Compiler.isconstType(typ)
-                        continue
-                    end
-                    name = source_argnames[i]
-                    details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))"
-                end
-                details *= "\n"
-
-                if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764
-                    details *= "\nNote: use a newer CUDA to support more parameters on your device.\n"
-                end
-
-                msg *= details
-            catch err
-                @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer."
-            end
-            error(msg)
-        end
-
-        # compile to machine code
-        # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow
-        ptx_input = tempname(cleanup=false) * ".ptx"
-        ptxas_output = tempname(cleanup=false) * ".cubin"
-        write(ptx_input, asm)
-
-        # we could use the driver's embedded JIT compiler, but that has several disadvantages:
-        # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to
-        #    upgrade the toolkit to get a newer compiler;
-        # 2. version checking is simpler, we otherwise need to use NVML to query the driver
-        #    version, which is hard to correlate to PTX JIT improvements;
-        # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an
-        #    older driver, we should use the newer compiler to ensure compatibility.
-        append!(ptxas_opts, [
-            "--verbose",
-            "--gpu-name", arch,
-            "--output-file", ptxas_output,
-            ptx_input
-        ])
-        proc, log = CUDA.run_and_collect(`$(CUDA.ptxas()) $ptxas_opts`)
-        log = strip(log)
-        if !success(proc)
-            reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
-                                           "ptxas exited with code $(proc.exitcode)"
-            msg = "Failed to compile PTX code ($reason)"
-            msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))"
-            if !isempty(log)
-                msg *= "\n" * log
-            end
-            msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)"
-            if parse(Bool, get(ENV, "BUILDKITE", "false"))
-                run(`buildkite-agent artifact upload $(ptx_input)`)
-            end
-            error(msg)
-        elseif !isempty(log)
-            @debug "PTX compiler log:\n" * log
-        end
-        rm(ptx_input)
-
-        # link device libraries, if necessary
-        #
-        # this requires relocatable device code, which prevents certain optimizations and
-        # hurts performance. as such, we only do so when absolutely necessary.
-        # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`.
-        # fails with `Ignoring -lto option because no LTO objects found`
-        if needs_cudadevrt
-            nvlink_output = tempname(cleanup=false) * ".cubin"
-            append!(nvlink_opts, [
-                "--verbose", "--extra-warnings",
-                "--arch", arch,
-                "--library-path", dirname(libcudadevrt),
-                "--library", "cudadevrt",
-                "--output-file", nvlink_output,
-                ptxas_output
-            ])
-            proc, log = run_and_collect(`$(CUDA.nvlink()) $nvlink_opts`)
-            log = strip(log)
-            if !success(proc)
-                reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" :
-                                               "nvlink exited with code $(proc.exitcode)"
-                msg = "Failed to link PTX code ($reason)"
-                msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))"
-                if !isempty(log)
-                    msg *= "\n" * log
-                end
-                msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)"
-                error(msg)
-            elseif !isempty(log)
-                @debug "PTX linker info log:\n" * log
-            end
-            rm(ptxas_output)
-
-            image = read(nvlink_output)
-            rm(nvlink_output)
-        else
-            image = read(ptxas_output)
-            rm(ptxas_output)
-        end
-
-        modstr, image, meta.entry
+
+        linkRes = @ccall MLIR.API.mlir_c.LinkInModule(MLIR.IR.mmodule()::MLIR.API.MlirModule, mmod::MLIR.API.MlirModule, entryname::Cstring)::MLIR.API.MlirOperation
+
+        entry = MLIR.IR.Operation(linkRes)
+
+        entry
     end
-    LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, modstr, image, CUDA.LLVM.name(entry))
+    LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry)
 end
 
 # link into an executable kernel
@@ -467,7 +319,6 @@ end
 Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
-    @show args
     @show call_kwargs
 
     blockdim = CUDA.CuDim3(blocks)
@@ -478,13 +329,11 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c
     aliases = MLIR.IR.Attribute[]
     rarrays = TracedRArray[]
     for (i, a) in enumerate(args)
-        @show a
         @assert a isa CuTracedArray
         ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
         push!(rarrays, ta)
         arg = ta.mlir_data
         arg = transpose_val(arg)
-        @show arg
         push!(restys, MLIR.IR.type(arg))
         push!(mlir_args, arg)
         push!(aliases,
@@ -500,11 +349,19 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c
     end
 
     output_operand_aliases=MLIR.IR.Attribute(aliases)
-    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr"))
+
+    fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name")
+    # Force public for now while we don't have real users
+    MLIR.IR.rmattr!(func.entry, "sym_visibility")
+
+    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname))
     # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
     for (i, res) in enumerate(rarrays)
         res.mlir_data = transpose_val(MLIR.IR.result(call, i))
     end
+
+    @show blockdim
+    @show threaddim
     #CUDA.cuLaunchKernel(f,
     #    blockdim.x, blockdim.y, blockdim.z,
     #    threaddim.x, threaddim.y, threaddim.z,
@@ -523,12 +380,10 @@ function compiler_cache(ctx::MLIR.IR.Context)
 end
 
 Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
-    @show "recufunction", f, tt
     res = Base.@lock CUDA.cufunction_lock begin
         # compile the function
         cache = compiler_cache(MLIR.IR.context())
        source = CUDA.methodinstance(F, tt)
-
         # cuda = CUDA.active_state()
         device = nothing # cuda.device
         # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
@@ -543,7 +398,6 @@ Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tupl
         config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
         CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
     end
-    @show res
     res
 end

diff --git a/test/cuda.jl b/test/cuda.jl
index ae1b473f6..05d0777c5 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -5,7 +5,7 @@ using CUDA
 function square_kernel!(x)
     i = threadIdx().x
     x[i] *= x[i]
-    sync_threads()
+    # sync_threads()
     return nothing
 end
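
A minimal usage sketch, not part of the patch and assuming the existing `Reactant.ConcreteRArray` and `Reactant.@compile` entry points plus CUDA.jl's `@cuda`: tracing `square!` hits the overridden `CUDA.cufunction`/`compile` path above, and `LinkInModule` splices the kernel's LLVM/NVVM IR into the traced host MLIR module.

using CUDA, Reactant

function square_kernel!(x)
    i = threadIdx().x
    x[i] *= x[i]
    return nothing
end

# Host wrapper that launches the kernel; the launch is intercepted while tracing.
function square!(x)
    @cuda blocks = 1 threads = length(x) square_kernel!(x)
    return nothing
end

A = Reactant.ConcreteRArray(collect(1.0:64.0))
square_traced = Reactant.@compile square!(A)  # device IR is linked into the host module here
square_traced(A)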