Skip to content

Commit

Permalink
Run cuda module init
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 23, 2024
1 parent f1fc93d commit 5744380
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ llvm::StringMap<void *> kernels;
llvm::sys::SmartRWMutex<true> kernel_mutex;
std::unique_ptr<llvm::orc::LLJIT> JIT = nullptr;

void *CompileHostModule(std::string &key, mlir::ModuleOp modOp) {
void *CompileHostModule(std::string &key, mlir::ModuleOp modOp, bool run_init) {
llvm::errs() << " compiling host module: " << modOp << "\n";
if (!JIT) {
auto tJIT =
Expand Down Expand Up @@ -249,6 +249,17 @@ void *CompileHostModule(std::string &key, mlir::ModuleOp modOp) {
auto ptr = (void *)EntrySym->getValue();

kernels[key] = ptr;

auto NVSym = JIT->lookup(LibA.get(), "nv_func_init");
if (!NVSym) {
llvm::errs() << " lookupError " << NVSym.takeError() << "\n";
return nullptr;
}

auto nvptr = (void *)NVSym->getValue();

((void (*)())(nvptr))();

return ptr;
}

Expand Down Expand Up @@ -317,7 +328,7 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
int indexBitWidth, std::string cubinChip,
std::string cubinFeatures, size_t cuLaunchKernelPtr,
size_t cuModuleLoadDataPtr, size_t cuModuleGetFunctionPtr,
bool compileLaunch) {
bool compileLaunch, bool run_init) {

llvm::errs() << " Compiling kernel: " << gridx << "," << gridy << "," << gridz
<< "," << blockx << "," << blocky << "," << blockz << "\n";
Expand Down Expand Up @@ -648,7 +659,7 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
if (!compileLaunch)
return nullptr;

ptr = CompileHostModule(ss.str(), submod);
ptr = CompileHostModule(ss.str(), submod, run_init);

submod.erase();
}
Expand Down Expand Up @@ -747,7 +758,7 @@ struct LowerKernelPass : public LowerKernelPassBase<LowerKernelPass> {
data[5], data[6], data[7], toolkitPath.getValue(), linkFilesArray,
indexBitWidth.getValue(), cubinChip.getValue(),
cubinFeatures.getValue(), cuLaunchKernelPtr, cuModuleLoadDataPtr,
cuModuleGetFunctionPtr, compileLaunch);
cuModuleGetFunctionPtr, compileLaunch, run_init);

std::string backendinfo((char *)&data, sizeof(void *));

Expand Down
7 changes: 7 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ def LowerKernelPass : Pass<"lower-kernel"> {
/*default=*/"0",
/*description=*/"cuModuleGetFunctionPtr"
>,
Option<
/*C++ variable name=*/"run_init",
/*CLI argument=*/"run_init",
/*type=*/"bool",
/*default=*/"false",
/*description=*/"Run initialization of cuda module"
>,
];
}

Expand Down

0 comments on commit 5744380

Please sign in to comment.