Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: kernels #314

Merged
merged 23 commits into from
Dec 17, 2024
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353"
[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
Expand All @@ -31,6 +32,7 @@ path = "lib/ReactantCore"
[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"
Expand Down Expand Up @@ -58,4 +60,5 @@ julia = "1.10"
[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
84 changes: 84 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
context.loadDialect<mlir::stablehlo::StablehloDialect>();
context.loadDialect<mlir::chlo::ChloDialect>();
}

#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::DialectRegistry &registry = *unwrap(creg);

Expand Down Expand Up @@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::affine::registerAffinePasses();
mlir::registerReconcileUnrealizedCasts();

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);

mlir::LLVM::registerInlinerInterface(registry);

/*
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
Expand Down Expand Up @@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
}


/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
/// suffix in `lastUsedID`.
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
unsigned &lastUsedID,
mlir::ModuleOp source,
mlir::ModuleOp target) {
using namespace llvm;
using namespace mlir;
SmallString<64> newSymName(oldSymName);
newSymName.push_back('_');
while (true) {
auto possible = newSymName + Twine(++lastUsedID);
if (!SymbolTable::lookupSymbolIn(source, possible.str()) && !SymbolTable::lookupSymbolIn(target, possible.str())) {
return StringAttr::get(target.getContext(), possible);
}
}
}


/// Checks if a symbol with the same name as `op` already exists in `source`.
/// If so, renames `op` and updates all its references in `target`.
static mlir::LogicalResult
updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
unsigned &lastUsedID) {
using namespace llvm;
using namespace mlir;

auto opName = op.getName().str();

if (!SymbolTable::lookupSymbolIn(target, opName)) {
return success();
}

StringAttr newSymName =
renameSymbol(opName, lastUsedID, source, target);

if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
return op.emitError("unable to update all symbol uses for ")
<< opName << " to " << newSymName;

SymbolTable::setSymbolName(op, newSymName);
return success();
}

extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, const char* entryfn) {
auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
auto newMod = cast<ModuleOp>(*unwrap(newModC));

Operation* entryFn = nullptr;

unsigned lastUsedID = 0;

for (auto &op : *newMod.getBody()) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;

StringRef oldSymName = symbolOp.getName();

if (oldSymName == entryfn) {
entryFn = &op;
}

if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod,
lastUsedID))) {
assert(0 && "failed to update all uses");
}
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
}
prevMod.getBody()->getOperations().splice(prevMod.getBody()->getOperations().end(),
newMod.getBody()->getOperations());
return wrap(entryFn);
}

#pragma region xla::ifrt

#pragma region xla::ifrt::Value
Expand Down
5 changes: 5 additions & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ cc_library(
"-Wl,-exported_symbol,_BufferToHost",
"-Wl,-exported_symbol,_FreeClient",
"-Wl,-exported_symbol,_ClientCompile",
"-Wl,-exported_symbol,_LinkInModule",
"-Wl,-exported_symbol,_FreeFuture",
"-Wl,-exported_symbol,_FutureIsReady",
"-Wl,-exported_symbol,_FutureAwait",
Expand Down Expand Up @@ -451,6 +452,10 @@ cc_library(
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",

"@llvm-project//mlir:LLVMIRToLLVMTranslation",
"@llvm-project//mlir:LLVMIRToNVVMTranslation",
"@llvm-project//mlir:LLVMIRTransforms",

"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:AArch64AsmParser",
Expand Down
Loading
Loading