Skip to content

Commit

Permalink
host and device IR
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 15, 2024
1 parent 0c61f5d commit d0e5195
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 168 deletions.
84 changes: 84 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
context.loadDialect<mlir::stablehlo::StablehloDialect>();
context.loadDialect<mlir::chlo::ChloDialect>();
}

#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::DialectRegistry &registry = *unwrap(creg);

Expand Down Expand Up @@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::affine::registerAffinePasses();
mlir::registerReconcileUnrealizedCasts();

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);

mlir::LLVM::registerInlinerInterface(registry);

/*
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
Expand Down Expand Up @@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
}


/// Returns a symbol name derived from `oldSymName` that is not defined in
/// either `source` or `target`.
///
/// Candidates have the form `<oldSymName>_<N>`. `lastUsedID` is incremented
/// before every probe, so the first suffix tried is `lastUsedID + 1`, and on
/// return it holds the suffix that was chosen. Passing the same counter to
/// successive calls therefore continues counting instead of re-probing
/// suffixes that were already handed out. The returned attribute is created in
/// `target`'s context.
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
                                     unsigned &lastUsedID,
                                     mlir::ModuleOp source,
                                     mlir::ModuleOp target) {
  using namespace llvm;
  using namespace mlir;

  // Owned buffer holding "<oldSymName>_"; the numeric suffix is appended to it
  // on every iteration and trimmed off again before the next attempt.
  SmallString<64> candidate(oldSymName);
  candidate.push_back('_');
  const size_t prefixLen = candidate.size();

  while (true) {
    candidate.resize(prefixLen);
    // Render the suffix straight into the buffer. A Twine is only used as a
    // temporary here; it is never kept in a named variable, because a Twine
    // merely references its operands and must not outlive the expression
    // that created it.
    Twine(++lastUsedID).toVector(candidate);

    StringRef name = candidate.str();
    if (!SymbolTable::lookupSymbolIn(source, name) &&
        !SymbolTable::lookupSymbolIn(target, name)) {
      return StringAttr::get(target.getContext(), name);
    }
  }
}


/// Resolves a name clash between `op` and the symbols already present in
/// `target`.
///
/// If `target` has no symbol named like `op`, nothing is changed. Otherwise a
/// fresh name (unused in both `source` and `target`, see `renameSymbol`) is
/// picked, every use of `op` nested under `source` is redirected to the new
/// name, and finally `op` itself is renamed.
///
/// Returns failure, after emitting an error on `op`, when the uses under
/// `source` could not all be rewritten; `op` keeps its old name in that case.
/// `lastUsedID` is forwarded to `renameSymbol` and advanced by it.
static mlir::LogicalResult
updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
                       unsigned &lastUsedID) {
  using namespace llvm;
  using namespace mlir;

  // Keep an owned copy of the current name: it is still needed for the
  // diagnostic below and for deriving the replacement name.
  std::string currentName = op.getName().str();

  // Fast path: the name is free in the destination module.
  bool clashes = SymbolTable::lookupSymbolIn(target, currentName) != nullptr;
  if (!clashes)
    return success();

  StringAttr freshName = renameSymbol(currentName, lastUsedID, source, target);

  // Uses are rewritten first; the definition is only renamed once that
  // succeeded.
  if (failed(SymbolTable::replaceAllSymbolUses(op, freshName, source))) {
    return op.emitError("unable to update all symbol uses for ")
           << currentName << " to " << freshName;
  }

  SymbolTable::setSymbolName(op, freshName);
  return success();
}

/// Moves every top-level operation of `newModC` to the end of `prevModC`.
///
/// Before the move, each symbol operation in the incoming module is
///   * renamed (together with its uses inside the incoming module) if its name
///     is already taken in the destination module, and
///   * given private visibility.
/// Afterwards the body of the incoming module is left without operations.
///
/// `entryfn` is compared against the symbol names as they are *before* any
/// renaming. Returns the operation that matched, or a null handle when no
/// symbol matched or when a clashing symbol could not be renamed. In the
/// latter case an error has been emitted on the offending operation and
/// nothing has been moved into the destination module, although symbols
/// visited earlier may already have been renamed / made private in the
/// incoming module.
extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, const char* entryfn) {
  auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
  auto newMod = cast<ModuleOp>(*unwrap(newModC));

  Operation* entryFn = nullptr;

  // Shared across all renames so each fresh suffix continues where the
  // previous one stopped.
  unsigned lastUsedID = 0;

  for (auto &op : *newMod.getBody()) {
    auto symbolOp = dyn_cast<SymbolOpInterface>(op);
    if (!symbolOp)
      continue;

    // Match on the pre-rename name: that is the name the caller knows.
    StringRef oldSymName = symbolOp.getName();

    if (oldSymName == entryfn) {
      entryFn = &op;
    }

    if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod,
                                      lastUsedID))) {
      // Fail in every build mode. An assert alone disappears under NDEBUG and
      // would let a clashing symbol be spliced into the destination module.
      return wrap(static_cast<Operation *>(nullptr));
    }
    SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
  }

  // Transfer ownership of all operations in one step; the ops are moved, not
  // cloned, so `entryFn` stays valid and now lives in the destination module.
  prevMod.getBody()->getOperations().splice(prevMod.getBody()->getOperations().end(),
                                            newMod.getBody()->getOperations());
  return wrap(entryFn);
}

#pragma region xla::ifrt

#pragma region xla::ifrt::Value
Expand Down
5 changes: 5 additions & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ cc_library(
"-Wl,-exported_symbol,_BufferToHost",
"-Wl,-exported_symbol,_FreeClient",
"-Wl,-exported_symbol,_ClientCompile",
"-Wl,-exported_symbol,_LinkInModule",
"-Wl,-exported_symbol,_FreeFuture",
"-Wl,-exported_symbol,_FutureIsReady",
"-Wl,-exported_symbol,_FutureAwait",
Expand Down Expand Up @@ -451,6 +452,10 @@ cc_library(
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",

"@llvm-project//mlir:LLVMIRToLLVMTranslation",
"@llvm-project//mlir:LLVMIRToNVVMTranslation",
"@llvm-project//mlir:LLVMIRTransforms",

"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:AArch64AsmParser",
Expand Down
188 changes: 21 additions & 167 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,14 @@ end

function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)))
@show res, xs
return res
end

const _kernel_instances = Dict{Any, Any}()

struct LLVMFunc{F,tt}
f::Union{F, Nothing}
mod::String
image
entry::String
entry::MLIR.IR.Operation
end


Expand Down Expand Up @@ -249,11 +246,13 @@ CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_p

# compile to executable machine code
function compile(job)

# lower to PTX
# TODO: on 1.9, this actually creates a context. cache those.
modstr, image, entry = GPUCompiler.JuliaContext() do ctx
entry = GPUCompiler.JuliaContext() do ctx
mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)

entryname = LLVM.name(meta.entry)

GPUCompiler.optimize_module!(job, mod)
opt_level = 2
tm = GPUCompiler.llvm_machine(job.config.target)
Expand Down Expand Up @@ -294,162 +293,15 @@ function compile(job)
# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
# it is probably safer to reparse a string using the right llvm module api, so we will do that.

println(string(modstr))
mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
@show mmod

# check if we'll need the device runtime
undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
end
intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
"__nvvm_reflect" #= TODO: should have been optimized away =#]
needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns))

# prepare invocations of CUDA compiler tools
ptxas_opts = String[]
nvlink_opts = String[]
## debug flags
if Base.JLOptions().debug_level == 1
push!(ptxas_opts, "--generate-line-info")
elseif Base.JLOptions().debug_level >= 2
push!(ptxas_opts, "--device-debug")
push!(nvlink_opts, "--debug")
end
## relocatable device code
if needs_cudadevrt
push!(ptxas_opts, "--compile-only")
end

ptx = job.config.params.ptx
cap = job.config.params.cap
arch = "sm_$(cap.major)$(cap.minor)"

# validate use of parameter memory
argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt
!CUDA.isghosttype(dt) && !Core.Compiler.isconstType(dt)
end
param_usage = sum(sizeof, argtypes)
param_limit = 4096
if cap >= v"7.0" && ptx >= v"8.1"
param_limit = 32764
end
if param_usage > param_limit
msg = """Kernel invocation uses too much parameter memory.
$(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor)."""

try
details = "\n\nRelevant parameters:"

source_types = job.source.specTypes.parameters
source_argnames = Base.method_argnames(job.source.def)
while length(source_argnames) < length(source_types)
# this is probably due to a trailing vararg; repeat its name
push!(source_argnames, source_argnames[end])
end

for (i, typ) in enumerate(source_types)
if CUDA.isghosttype(typ) || Core.Compiler.isconstType(typ)
continue
end
name = source_argnames[i]
details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))"
end
details *= "\n"

if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764
details *= "\nNote: use a newer CUDA to support more parameters on your device.\n"
end

msg *= details
catch err
@error "Failed to analyze kernel parameter usage; please file an issue with a reproducer."
end
error(msg)
end

# compile to machine code
# NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow
ptx_input = tempname(cleanup=false) * ".ptx"
ptxas_output = tempname(cleanup=false) * ".cubin"
write(ptx_input, asm)

# we could use the driver's embedded JIT compiler, but that has several disadvantages:
# 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to
# upgrade the toolkit to get a newer compiler;
# 2. version checking is simpler, we otherwise need to use NVML to query the driver
# version, which is hard to correlate to PTX JIT improvements;
# 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an
# older driver, we should use the newer compiler to ensure compatibility.
append!(ptxas_opts, [
"--verbose",
"--gpu-name", arch,
"--output-file", ptxas_output,
ptx_input
])
proc, log = CUDA.run_and_collect(`$(CUDA.ptxas()) $ptxas_opts`)
log = strip(log)
if !success(proc)
reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
"ptxas exited with code $(proc.exitcode)"
msg = "Failed to compile PTX code ($reason)"
msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))"
if !isempty(log)
msg *= "\n" * log
end
msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)"
if parse(Bool, get(ENV, "BUILDKITE", "false"))
run(`buildkite-agent artifact upload $(ptx_input)`)
end
error(msg)
elseif !isempty(log)
@debug "PTX compiler log:\n" * log
end
rm(ptx_input)

# link device libraries, if necessary
#
# this requires relocatable device code, which prevents certain optimizations and
# hurts performance. as such, we only do so when absolutely necessary.
# TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`.
# fails with `Ignoring -lto option because no LTO objects found`
if needs_cudadevrt
nvlink_output = tempname(cleanup=false) * ".cubin"
append!(nvlink_opts, [
"--verbose", "--extra-warnings",
"--arch", arch,
"--library-path", dirname(libcudadevrt),
"--library", "cudadevrt",
"--output-file", nvlink_output,
ptxas_output
])
proc, log = run_and_collect(`$(CUDA.nvlink()) $nvlink_opts`)
log = strip(log)
if !success(proc)
reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" :
"nvlink exited with code $(proc.exitcode)"
msg = "Failed to link PTX code ($reason)"
msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))"
if !isempty(log)
msg *= "\n" * log
end
msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)"
error(msg)
elseif !isempty(log)
@debug "PTX linker info log:\n" * log
end
rm(ptxas_output)

image = read(nvlink_output)
rm(nvlink_output)
else
image = read(ptxas_output)
rm(ptxas_output)
end

modstr, image, meta.entry

linkRes = @ccall MLIR.API.mlir_c.LinkInModule(MLIR.IR.mmodule()::MLIR.API.MlirModule, mmod::MLIR.API.MlirModule, entryname::Cstring)::MLIR.API.MlirOperation

entry = MLIR.IR.Operation(linkRes)

entry
end
LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, modstr, image, CUDA.LLVM.name(entry))
LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry)
end

# link into an executable kernel
Expand All @@ -467,7 +319,6 @@ end

Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
@show args
@show call_kwargs

blockdim = CUDA.CuDim3(blocks)
Expand All @@ -478,13 +329,11 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c
aliases = MLIR.IR.Attribute[]
rarrays = TracedRArray[]
for (i, a) in enumerate(args)
@show a
@assert a isa CuTracedArray
ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
push!(rarrays, ta)
arg = ta.mlir_data
arg = transpose_val(arg)
@show arg
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)
push!(aliases,
Expand All @@ -500,11 +349,19 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c
end

output_operand_aliases=MLIR.IR.Attribute(aliases)
call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr"))

fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name")
# Force public for now while we don't have real users
MLIR.IR.rmattr!(func.entry, "sym_visibility")

call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname))
# call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
for (i, res) in enumerate(rarrays)
res.mlir_data = transpose_val(MLIR.IR.result(call, i))
end

@show blockdim
@show threaddim
#CUDA.cuLaunchKernel(f,
# blockdim.x, blockdim.y, blockdim.z,
# threaddim.x, threaddim.y, threaddim.z,
Expand All @@ -523,12 +380,10 @@ function compiler_cache(ctx::MLIR.IR.Context)
end

Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
@show "recufunction", f, tt
res = Base.@lock CUDA.cufunction_lock begin
# compile the function
cache = compiler_cache(MLIR.IR.context())
source = CUDA.methodinstance(F, tt)

# cuda = CUDA.active_state()
device = nothing # cuda.device
# config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
Expand All @@ -543,7 +398,6 @@ Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tupl
config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
end
@show res
res
end

Expand Down
Loading

0 comments on commit d0e5195

Please sign in to comment.