From 5744380f6c1bb3ee0c078aec67ad96a022b056ba Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 22 Dec 2024 23:35:27 -0500 Subject: [PATCH] Run cuda module init --- src/enzyme_ad/jax/Passes/LowerKernel.cpp | 19 +++++++++++++++---- src/enzyme_ad/jax/Passes/Passes.td | 7 +++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index dc34abb9a..7406085cf 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -175,7 +175,7 @@ llvm::StringMap kernels; llvm::sys::SmartRWMutex kernel_mutex; std::unique_ptr JIT = nullptr; -void *CompileHostModule(std::string &key, mlir::ModuleOp modOp) { +void *CompileHostModule(std::string &key, mlir::ModuleOp modOp, bool run_init) { llvm::errs() << " compiling host module: " << modOp << "\n"; if (!JIT) { auto tJIT = @@ -249,6 +249,17 @@ void *CompileHostModule(std::string &key, mlir::ModuleOp modOp) { auto ptr = (void *)EntrySym->getValue(); kernels[key] = ptr; + + auto NVSym = JIT->lookup(LibA.get(), "nv_func_init"); + if (!NVSym) { + llvm::errs() << " lookupError " << NVSym.takeError() << "\n"; + return nullptr; + } + + auto nvptr = (void *)NVSym->getValue(); + + ((void (*)())(nvptr))(); + return ptr; } @@ -317,7 +328,7 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, int indexBitWidth, std::string cubinChip, std::string cubinFeatures, size_t cuLaunchKernelPtr, size_t cuModuleLoadDataPtr, size_t cuModuleGetFunctionPtr, - bool compileLaunch) { + bool compileLaunch, bool run_init) { llvm::errs() << " Compiling kernel: " << gridx << "," << gridy << "," << gridz << "," << blockx << "," << blocky << "," << blockz << "\n"; @@ -648,7 +659,7 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, if (!compileLaunch) return nullptr; - ptr = CompileHostModule(ss.str(), submod); + ptr = CompileHostModule(ss.str(), submod, run_init); submod.erase(); } @@ -747,7 +758,7 @@ struct LowerKernelPass : public LowerKernelPassBase { data[5], data[6], data[7], toolkitPath.getValue(), linkFilesArray, indexBitWidth.getValue(), cubinChip.getValue(), cubinFeatures.getValue(), cuLaunchKernelPtr, cuModuleLoadDataPtr, - cuModuleGetFunctionPtr, compileLaunch); + cuModuleGetFunctionPtr, compileLaunch, run_init); std::string backendinfo((char *)&data, sizeof(void *)); diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index bdbb26f10..21ecd8cb1 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -207,6 +207,13 @@ def LowerKernelPass : Pass<"lower-kernel"> { /*default=*/"0", /*description=*/"cuModuleGetFunctionPtr" >, + Option< + /*C++ variable name=*/"run_init", + /*CLI argument=*/"run_init", + /*type=*/"bool", + /*default=*/"false", + /*description=*/"Run initialization of cuda module" + >, ]; }