small refactoring for mlir backend (#57)
PhilippGrulich authored Oct 8, 2024
1 parent 7adef5c commit 1ff4204
Showing 8 changed files with 23 additions and 212 deletions.
6 changes: 0 additions & 6 deletions nautilus/src/nautilus/compiler/backends/mlir/JITCompiler.cpp
@@ -30,11 +30,6 @@ JITCompiler::jitCompileModule(::mlir::OwningOpRef<::mlir::ModuleOp>& mlirModule,
LLVMInitializeNativeTarget();
LLVMInitializeNativeAsmPrinter();

//(void) dumpHelper;
// if (compilerOptions.isDumpToConsole() || compilerOptions.isDumpToFile()) {
// dumpLLVMIR(mlirModule.get(), compilerOptions, dumpHelper);
// }

// Create MLIR execution engine (wrapper around LLVM ExecutionEngine).
::mlir::ExecutionEngineOptions options;
options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
@@ -46,7 +41,6 @@ JITCompiler::jitCompileModule(::mlir::OwningOpRef<::mlir::ModuleOp>& mlirModule,
// We register all external functions (symbols) that we do not inline.
const auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
auto symbolMap = llvm::orc::SymbolMap();

for (int i = 0; i < (int) jitProxyFunctionSymbols.size(); ++i) {
auto address = jitProxyFunctionTargetAddresses.at(i);
symbolMap[interner(jitProxyFunctionSymbols.at(i))] = {llvm::orc::ExecutorAddr::fromPtr(address),
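
The symbol-map registration shown above is untouched by this commit; only the commented-out dump call was removed. For orientation, below is a minimal sketch of the same ORC pattern outside of Nautilus. The proxy function my_runtime_add and its registration are illustrative assumptions, not code from this repository, and the sketch assumes a recent LLVM (17+) where SymbolMap entries are ExecutorSymbolDef values.

#include <cstdint>
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"

// A host function that JIT-compiled code should be able to call by name.
extern "C" int64_t my_runtime_add(int64_t a, int64_t b) { return a + b; }

// Build a symbol map that resolves the proxy symbol to its host address,
// mirroring the shape of the runtimeSymbolMap lambda in JITCompiler.cpp.
static llvm::orc::SymbolMap makeRuntimeSymbolMap(llvm::orc::MangleAndInterner interner) {
    llvm::orc::SymbolMap symbolMap;
    symbolMap[interner("my_runtime_add")] = {
        llvm::orc::ExecutorAddr::fromPtr((void*) &my_runtime_add), llvm::JITSymbolFlags::Callable};
    return symbolMap;
}

A map like this is typically what mlir::ExecutionEngine::registerSymbols consumes once the engine has been created.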
16 changes: 3 additions & 13 deletions nautilus/src/nautilus/compiler/backends/mlir/LLVMIROptimizer.cpp
@@ -15,6 +15,9 @@ int getOptimizationLevel(const engine::Options& options) {
return options.getOptionOrDefault("mlir.optimizationLevel", 3);
}

LLVMIROptimizer::LLVMIROptimizer() = default;
LLVMIROptimizer::~LLVMIROptimizer() = default;

std::function<llvm::Error(llvm::Module*)> LLVMIROptimizer::getLLVMOptimizerPipeline(const engine::Options& options,
const DumpHandler& handler) {
// Return LLVM optimizer pipeline.
@@ -24,8 +27,6 @@ std::function<llvm::Error(llvm::Module*)> LLVMIROptimizer::getLLVMOptimizerPipel
constexpr int SIZE_LEVEL = 0;
// Create A target-specific target machine for the host
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
// NES_ASSERT2_FMT(tmBuilderOrError, "Failed to create a
// JITTargetMachineBuilder for the host");
auto targetMachine = tmBuilderOrError->createTargetMachine();
llvm::TargetMachine* targetMachinePtr = targetMachine->get();
targetMachinePtr->setOptLevel(llvm::CodeGenOptLevel::Aggressive);
@@ -40,17 +41,6 @@ std::function<llvm::Error(llvm::Module*)> LLVMIROptimizer::getLLVMOptimizerPipel
~0, llvm::Attribute::get(llvmIRModule->getContext(), "tune-cpu", targetMachinePtr->getTargetCPU()));
llvm::SMDiagnostic Err;

// Load LLVM IR module from proxy inlining input path (We assert that it
// exists in CompilationOptions). if (options.isProxyInlining()) {
// auto proxyFunctionsIR =
// llvm::parseIRFile(options.getProxyInliningInputPath(), Err,
// llvmIRModule->getContext());
// Link the module with our generated LLVM IR module and optimize the linked
// LLVM IR module (inlining happens during optimization).
// llvm::Linker::linkModules(*llvmIRModule, std::move(proxyFunctionsIR),
// llvm::Linker::Flags::OverrideFromSrc);
// }

auto optPipeline =
::mlir::makeOptimizingTransformer(getOptimizationLevel(options), SIZE_LEVEL, targetMachinePtr);
auto optimizedModule = optPipeline(llvmIRModule);
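
For context on what getLLVMOptimizerPipeline returns: mlir::makeOptimizingTransformer produces a std::function that runs the standard LLVM optimization pipeline over a module. A minimal, hedged sketch of applying such a transformer is shown below; the optimization and size levels and the null target machine are illustrative assumptions, not values taken from this commit.

#include "mlir/ExecutionEngine/OptUtils.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Error.h"

// Run the O3 pipeline over an already-constructed LLVM module.
// Passing a null TargetMachine falls back to target-independent defaults.
llvm::Error optimizeModuleAtO3(llvm::Module& module) {
    auto transformer = mlir::makeOptimizingTransformer(/*optLevel=*/3, /*sizeLevel=*/0,
                                                       /*targetMachine=*/nullptr);
    return transformer(&module);
}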
nautilus/src/nautilus/compiler/backends/mlir/MLIRLoweringProvider.cpp
@@ -51,11 +51,10 @@ ::mlir::Type MLIRLoweringProvider::getMLIRType(Type type) {
case Type::ptr:
return mlir::LLVM::LLVMPointerType::get(context);
}

throw NotImplementedException("No matching type for stamp ");
}

std::vector<mlir::Type> MLIRLoweringProvider::getMLIRType(std::vector<ir::Operation*> types) {
std::vector<mlir::Type> MLIRLoweringProvider::getMLIRType(const std::vector<ir::Operation*>& types) {
std::vector<mlir::Type> resultTypes;
for (auto& type : types) {
resultTypes.push_back(getMLIRType(type->getStamp()));
@@ -73,7 +72,6 @@ mlir::Value MLIRLoweringProvider::getConstBool(const std::string& location, bool
builder->getIntegerAttr(builder->getIndexType(), value));
}

// Todo Issue #3004: Currently, we are simply adding 'Query_1' as the
// FileLineLoc name. Moreover,
// the provided 'name' often is not meaningful either.
mlir::Location MLIRLoweringProvider::getNameLoc(const std::string& name) {
@@ -184,8 +182,9 @@ mlir::arith::CmpIPredicate convertToBooleanMLIRComparison(ir::CompareOperation::
}

mlir::FlatSymbolRefAttr MLIRLoweringProvider::insertExternalFunction(const std::string& name, void* functionPtr,
mlir::Type resultType,
std::vector<mlir::Type> argTypes, bool varArgs) {
const mlir::Type& resultType,
const std::vector<mlir::Type>& argTypes,
bool varArgs) {
// Create function arg & result types (currently only int for result).
mlir::LLVM::LLVMFunctionType llvmFnType = mlir::LLVM::LLVMFunctionType::get(resultType, argTypes, varArgs);

@@ -244,7 +243,6 @@ void MLIRLoweringProvider::generateMLIR(const ir::BasicBlock* basicBlock, ValueF
void MLIRLoweringProvider::generateMLIR(const std::unique_ptr<ir::Operation>& operation, ValueFrame& frame) {
switch (operation->getOperationType()) {
case ir::Operation::OperationType::FunctionOp:
// generateMLIR(as<ir::FunctionOperation>(operation), frame);
break;
case ir::Operation::OperationType::ConstIntOp:
generateMLIR(as<ir::ConstIntOperation>(operation), frame);
@@ -340,24 +338,14 @@ void MLIRLoweringProvider::generateMLIR(ir::OrOperation* orOperation, ValueFrame
auto leftInput = frame.getValue(orOperation->getLeftInput()->getIdentifier());
auto rightInput = frame.getValue(orOperation->getRightInput()->getIdentifier());
auto mlirOrOp = builder->create<mlir::LLVM::OrOp>(getNameLoc("binOpResult"), leftInput, rightInput);
frame.setValue(orOperation->

getIdentifier(),
mlirOrOp

);
frame.setValue(orOperation->getIdentifier(), mlirOrOp);
}

void MLIRLoweringProvider::generateMLIR(ir::AndOperation* andOperation, ValueFrame& frame) {
auto leftInput = frame.getValue(andOperation->getLeftInput()->getIdentifier());
auto rightInput = frame.getValue(andOperation->getRightInput()->getIdentifier());
auto mlirAndOp = builder->create<mlir::LLVM::AndOp>(getNameLoc("binOpResult"), leftInput, rightInput);
frame.setValue(andOperation->

getIdentifier(),
mlirAndOp

);
frame.setValue(andOperation->getIdentifier(), mlirAndOp);
}

void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp, ValueFrame& frame) {
@@ -378,30 +366,16 @@ void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp,
} else if (isSignedInteger(functionOp.getStamp())) {
mlirFunction.setResultAttr(0, "llvm.signext", mlir::UnitAttr::get(context));
}
// mlirFunction.setArgAttr(0, "llvm.signext", mlir::UnitAttr::get(context));

mlirFunction.

addEntryBlock();
mlirFunction.addEntryBlock();

// Set InsertPoint to beginning of the execute function.
builder->setInsertionPointToStart(&mlirFunction
.

getBody()

.

front()

);
builder->setInsertionPointToStart(&mlirFunction.getBody().front());

// Store references to function args in the valueMap map.
auto valueMapIterator = mlirFunction.args_begin();
for (int i = 0; i < (int) functionOp.getFunctionBasicBlock().getArguments().size(); ++i) {
frame.setValue(functionOp.getFunctionBasicBlock().getArguments().at(i)->getIdentifier(), valueMapIterator[i]

);
frame.setValue(functionOp.getFunctionBasicBlock().getArguments().at(i)->getIdentifier(), valueMapIterator[i]);
}

// Generate MLIR for operations in function body (BasicBlock).
@@ -411,30 +385,19 @@ void MLIRLoweringProvider::generateMLIR(const ir::FunctionOperation& functionOp,
}

void MLIRLoweringProvider::generateMLIR(ir::LoadOperation* loadOp, ValueFrame& frame) {

auto address = frame.getValue(loadOp->getAddress()->getIdentifier());

// auto bitcast = builder->create<mlir::LLVM::BitcastOp>(getNameLoc("Bitcasted
// address"),
// mlir::LLVM::LLVMPointerType::get(context),
// address);
auto mlirLoadOp =
builder->create<mlir::LLVM::LoadOp>(getNameLoc("loadedValue"), getMLIRType(loadOp->getStamp()), address);
frame.setValue(loadOp->getIdentifier(), mlirLoadOp);
}

void MLIRLoweringProvider::generateMLIR(ir::ConstIntOperation* constIntOp, ValueFrame& frame) {
if (!frame.contains(constIntOp->getIdentifier())) {
frame.setValue(constIntOp->getIdentifier(),
getConstInt("ConstantOp", constIntOp->getStamp(), constIntOp->getValue()));
} else {
frame.setValue(constIntOp->getIdentifier(),
getConstInt("ConstantOp", constIntOp->getStamp(), constIntOp->getValue()));
}
frame.setValue(constIntOp->getIdentifier(),
getConstInt("ConstantOp", constIntOp->getStamp(), constIntOp->getValue()));
}

void MLIRLoweringProvider::generateMLIR(ir::ConstPtrOperation* constPtr, ValueFrame& frame) {
int64_t val = (int64_t) constPtr->getValue();
auto val = (int64_t) constPtr->getValue();
auto constInt = builder->create<mlir::arith::ConstantOp>(getNameLoc("location"), builder->getI64Type(),
builder->getIntegerAttr(builder->getI64Type(), val));
auto elementAddress = builder->create<mlir::LLVM::IntToPtrOp>(getNameLoc("fieldAccess"),
@@ -463,7 +426,6 @@ void MLIRLoweringProvider::generateMLIR(ir::AddOperation* addOp, ValueFrame& fra
getNameLoc("fieldAccess"), mlir::LLVM::LLVMPointerType::get(context), builder->getI8Type(), leftInput,
mlir::ArrayRef<mlir::Value>({rightInput}));
frame.setValue(addOp->getIdentifier(), elementAddress);

} else if (isFloat(addOp->getStamp())) {
auto mlirAddOp = builder->create<mlir::LLVM::FAddOp>(getNameLoc("binOpResult"), leftInput.getType(), leftInput,
rightInput, mlir::LLVM::FastmathFlags::fast);
@@ -490,7 +452,6 @@ void MLIRLoweringProvider::generateMLIR(ir::SubOperation* subIntOp, ValueFrame&
getNameLoc("fieldAccess"), mlir::LLVM::LLVMPointerType::get(context), builder->getI8Type(), leftInput,
mlir::ArrayRef<mlir::Value>({rightInput}));
frame.setValue(subIntOp->getIdentifier(), elementAddress);

} else if (isFloat(subIntOp->getStamp())) {
auto mlirSubOp = builder->create<mlir::LLVM::FSubOp>(
getNameLoc("binOpResult"), leftInput, rightInput,
@@ -605,20 +566,6 @@ void MLIRLoweringProvider::generateMLIR(ir::CompareOperation* compareOp, ValueFr
if ((isInteger(leftStamp) && isFloat(rightStamp)) || ((isInteger(rightStamp) && isFloat(leftStamp)))) {
// Avoid comparing integer to float
throw NotImplementedException("Type missmatch: cannot compare");
} else if (compareOp->getComparator() == ir::CompareOperation::EQ &&
compareOp->getLeftInput()->getStamp() == Type::ptr &&
isInteger(compareOp->getRightInput()->getStamp())) {
// add null check
throw NotImplementedException("Null check is not implemented");
// auto null =
// builder->create<mlir::LLVM::NullOp>(getNameLoc("null"),
// mlir::LLVM::LLVMPointerType::get(context));
// auto cmpOp =
// builder->create<mlir::LLVM::ICmpOp>(getNameLoc("comparison"),
// mlir::LLVM::ICmpPredicate::eq,
// frame.getValue(compareOp->getLeftInput()->getIdentifier()),
// null);
// frame.setValue(compareOp->getIdentifier(), cmpOp);
} else if (isInteger(leftStamp) && isInteger(rightStamp)) {
// handle integer
auto cmpOp = builder->create<mlir::arith::CmpIOp>(
Expand Down Expand Up @@ -814,7 +761,6 @@ void MLIRLoweringProvider::generateMLIR(ir::BinaryCompOperation* binaryCompOpera
nautilus::compiler::mlir::MLIRLoweringProvider::ValueFrame& frame) {
auto leftInput = frame.getValue(binaryCompOperation->getLeftInput()->getIdentifier());
auto rightInput = frame.getValue(binaryCompOperation->getRightInput()->getIdentifier());

mlir::Value op;
switch (binaryCompOperation->getType()) {
case ir::BinaryCompOperation::BAND:
Expand All @@ -834,7 +780,6 @@ void MLIRLoweringProvider::generateMLIR(ir::ShiftOperation* shiftOperation,
nautilus::compiler::mlir::MLIRLoweringProvider::ValueFrame& frame) {
auto leftInput = frame.getValue(shiftOperation->getLeftInput()->getIdentifier());
auto rightInput = frame.getValue(shiftOperation->getRightInput()->getIdentifier());

mlir::Value op;
switch (shiftOperation->getType()) {
case ir::ShiftOperation::LS:
nautilus/src/nautilus/compiler/backends/mlir/MLIRLoweringProvider.hpp
@@ -147,7 +147,7 @@ class MLIRLoweringProvider {
* @param varArgs: Include variable arguments.
* @return FlatSymbolRefAttr: Reference to function used in CallOps.
*/
::mlir::FlatSymbolRefAttr insertExternalFunction(const std::string& name, void* functionPtr, ::mlir::Type resultType, std::vector<::mlir::Type> argTypes, bool varArgs);
::mlir::FlatSymbolRefAttr insertExternalFunction(const std::string& name, void* functionPtr, const ::mlir::Type& resultType, const std::vector<::mlir::Type>& argTypes, bool varArgs);

/**
* @brief Generates a Name(d)Loc(ation) that is attached to the operation.
@@ -167,7 +167,7 @@ class MLIRLoweringProvider {
* @param types: Vector of basic types.
* @return mlir::Type: Vector of MLIR types.
*/
std::vector<::mlir::Type> getMLIRType(std::vector<ir::Operation*> types);
std::vector<::mlir::Type> getMLIRType(const std::vector<ir::Operation*>& types);

/**
* @brief Get a constant MLIR Integer.
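
The signature changes above only tighten parameter passing (const references instead of by-value copies); the behavior of insertExternalFunction is unchanged. As a rough, hedged sketch of what such a registration amounts to in plain MLIR terms, i.e. declaring an external llvm.func and returning a symbol reference for later call operations, consider the free function below. Its name, parameters, and builder handling are assumptions for illustration, not the provider's actual implementation.

#include <string>
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

// Declare an external function at module scope and return a FlatSymbolRefAttr
// that llvm.call operations can use to reference it by name.
::mlir::FlatSymbolRefAttr declareExternalFunction(::mlir::ModuleOp module, ::mlir::OpBuilder& builder,
                                                  const std::string& name, ::mlir::Type resultType,
                                                  llvm::ArrayRef<::mlir::Type> argTypes, bool varArgs) {
    auto fnType = ::mlir::LLVM::LLVMFunctionType::get(resultType, argTypes, varArgs);
    // Insert the declaration at module scope so calls anywhere in the module resolve the symbol.
    ::mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(module.getBody());
    builder.create<::mlir::LLVM::LLVMFuncOp>(module.getLoc(), name, fnType);
    return ::mlir::SymbolRefAttr::get(builder.getContext(), name);
}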
nautilus/src/nautilus/compiler/backends/mlir/MLIRPassManager.hpp
@@ -10,10 +10,6 @@ namespace nautilus::compiler::mlir {
class MLIRPassManager {
public:
enum class OptimizationPass : uint8_t { Inline };

MLIRPassManager(); // Disable default constructor
~MLIRPassManager(); // Disable default destructor

static int lowerAndOptimizeMLIRModule(::mlir::OwningOpRef<::mlir::ModuleOp>& module, const std::vector<OptimizationPass>& optimizationPasses);
};
} // namespace nautilus::compiler::mlir
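
With the unused constructor and destructor declarations removed, MLIRPassManager is a pure static utility. A small usage sketch follows; the diff does not show the return convention of lowerAndOptimizeMLIRModule, so treating 0 as success below is an assumption, as is the extra include for the MLIR builtin ops.

#include "mlir/IR/BuiltinOps.h"
#include "nautilus/compiler/backends/mlir/MLIRPassManager.hpp"

// Lower and optimize a module coming out of MLIRLoweringProvider.
bool lowerAndOptimize(::mlir::OwningOpRef<::mlir::ModuleOp>& module) {
    using nautilus::compiler::mlir::MLIRPassManager;
    return MLIRPassManager::lowerAndOptimizeMLIRModule(
               module, {MLIRPassManager::OptimizationPass::Inline}) == 0;
}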
70 changes: 0 additions & 70 deletions nautilus/src/nautilus/compiler/backends/mlir/MLIRUtility.cpp
@@ -1,70 +0,0 @@


#include "nautilus/compiler/backends/mlir/MLIRUtility.hpp>
#include "nautilus/compiler/backends/mlir/JITCompiler.hpp>
#include "nautilus/compiler/backends/mlir/LLVMIROptimizer.hpp>
#include "nautilus/compiler/backends/mlir/MLIRLoweringProvider.hpp>
#include "nautilus/compiler/backends/mlir/MLIRPassManager.hpp>
#include <Util/Logger/Logger.hpp>
#include <mlir/AsmParser/AsmParser.h>
#include <mlir/Parser/Parser.h>
namespace nautilus::compiler::mlir {
void MLIRUtility::writeMLIRModuleToFile(mlir::OwningOpRef<mlir::ModuleOp>& mlirModule, std::string mlirFilepath) {
std::string mlirString;
llvm::raw_string_ostream llvmStringStream(mlirString);
auto* basicError = new std::error_code();
llvm::raw_fd_ostream fileStream(mlirFilepath, *basicError);
auto* opPrintFlags = new mlir::OpPrintingFlags();
mlirModule->print(llvmStringStream, *opPrintFlags);
if (!mlirFilepath.empty()) {
fileStream.write(mlirString.c_str(), mlirString.length());
}
NES_DEBUG(mlirString.c_str());
}
int MLIRUtility::loadAndExecuteModuleFromString(const std::string& mlirString, const std::string& moduleString) {
mlir::MLIRContext context;
mlir::ParserConfig config(&context);
auto mlirModule = mlir::parseSourceString<mlir::ModuleOp>(mlirString, config);
// Take the MLIR module from the MLIRLoweringProvider and apply lowering and
// optimization passes.
if (!MLIR::MLIRPassManager::lowerAndOptimizeMLIRModule(mlirModule, {}, {})) {
NES_FATAL_ERROR("Could not lower and optimize MLIR");
}
// Lower MLIR module to LLVM IR and create LLVM IR optimization pipeline.
auto optPipeline = MLIR::LLVMIROptimizer::getLLVMOptimizerPipeline(/*inlining*/ false);
// JIT compile LLVM IR module and return engine that provides access compiled
// execute function.
auto engine = MLIR::JITCompiler::jitCompileModule(mlirModule, optPipeline, {}, {});
if (!engine->invoke(moduleString)) {
return -1;
} else
return 0;
}
std::unique_ptr<mlir::ExecutionEngine>
MLIRUtility::compileNESIRToMachineCode(std::shared_ptr<NES::Nautilus::IR::IRGraph> ir) {
mlir::MLIRContext context;
auto loweringProvider = std::make_unique<MLIR::MLIRLoweringProvider>(context);
auto module = loweringProvider->generateModuleFromIR(ir);
// Take the MLIR module from the MLIRLoweringProvider and apply lowering and
// optimization passes.
if (MLIR::MLIRPassManager::lowerAndOptimizeMLIRModule(module, {}, {})) {
NES_FATAL_ERROR("Could not lower and optimize MLIR");
}
// Lower MLIR module to LLVM IR and create LLVM IR optimization pipeline.
auto optPipeline = MLIR::LLVMIROptimizer::getLLVMOptimizerPipeline(/*inlining*/ false);
// JIT compile LLVM IR module and return engine that provides access compiled
// execute function.
return MLIR::JITCompiler::jitCompileModule(module, optPipeline, loweringProvider->getJitProxyFunctionSymbols(),
loweringProvider->getJitProxyTargetAddresses());
}
} // namespace nautilus::compiler::mlir
44 changes: 0 additions & 44 deletions nautilus/src/nautilus/compiler/backends/mlir/MLIRUtility.hpp

This file was deleted.

(1 additional changed file not loaded in this view)
