Skip to content

Commit

Permalink
reuse dialect registration/loading (#2391)
Browse files Browse the repository at this point in the history
* CompilerDialects
* remove obsolete Accelerator::getOrLoadDialects()
* rely on registerDialects() to initialize accelerators
* remove non-functional accelerator support from onnx-mlir-reduce
* cleaned up CompilerUtils.hpp includes
* document loadDialects()
* clean up onnx-mlir-opt exit codes

Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen authored Jul 26, 2023
1 parent d579ce1 commit 760e2c9
Show file tree
Hide file tree
Showing 22 changed files with 164 additions and 148 deletions.
3 changes: 0 additions & 3 deletions docs/AddCustomAccelerators.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ We provide a base class [onnx_mlir::accel::Accelerator](../src/Accelerators/Acce
// Hooks for onnx-mlir driver
//===--------------------------------------------------------------------===//

/// Load the MLIR dialects necessary to generate code for an accelerator.
virtual void getOrLoadDialects(mlir::MLIRContext &context) const = 0;

/// Add the transformations necessary to support the accelerator.
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::PassManager &pm,
Expand Down
3 changes: 0 additions & 3 deletions src/Accelerators/Accelerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ class Accelerator {
// Hooks for onnx-mlir driver
//===--------------------------------------------------------------------===//

/// Load the MLIR dialects necessary to generate code for an accelerator.
virtual void getOrLoadDialects(mlir::MLIRContext &context) const = 0;

/// Add the transformations necessary to support the accelerator.
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
Expand Down
6 changes: 0 additions & 6 deletions src/Accelerators/NNPA/NNPAAccelerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ NNPAAccelerator::~NNPAAccelerator() { delete instance; }

uint64_t NNPAAccelerator::getVersionNumber() const { return ZDNN_VERNUM; }

void NNPAAccelerator::getOrLoadDialects(mlir::MLIRContext &context) const {
LLVM_DEBUG(llvm::dbgs() << "Loading dialects for NNPA accelerator\n");
context.getOrLoadDialect<zhigh::ZHighDialect>();
context.getOrLoadDialect<zlow::ZLowDialect>();
}

void NNPAAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
std::string outputNameNoExt) const {
Expand Down
1 change: 0 additions & 1 deletion src/Accelerators/NNPA/NNPAAccelerator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class NNPAAccelerator final : public Accelerator {
//===--------------------------------------------------------------------===//
// Hooks for onnx-mlir-opt driver
//===--------------------------------------------------------------------===//
virtual void getOrLoadDialects(mlir::MLIRContext &context) const final;
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
std::string outputNameNoExt) const final;
Expand Down
15 changes: 14 additions & 1 deletion src/Compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ add_onnx_mlir_library(OMCompilerOptions
OMAccelerator
)


add_onnx_mlir_library(OMCompilerDialects
CompilerDialects.cpp

LINK_LIBS PUBLIC
OMAccelerator
OMInitAccelerators
OMKrnlOps
OMONNXOps
MLIRIR
)

add_onnx_mlir_library(OMCompilerPasses
CompilerPasses.cpp
DisposableGarbageCollector.cpp
Expand Down Expand Up @@ -150,9 +162,9 @@ add_onnx_mlir_library(OMCompilerUtils

LINK_LIBS PUBLIC
${OMLibs}
OMCompilerDialects
OMCompilerPasses
OMAccelerator
OMInitAccelerators
OMVersion
MLIRIR

Expand Down Expand Up @@ -190,6 +202,7 @@ add_onnx_mlir_library(OMCompiler
EXCLUDE_FROM_OM_LIBS

LINK_LIBS PRIVATE
OMCompilerDialects
OMCompilerUtils
)

Expand Down
48 changes: 48 additions & 0 deletions src/Compiler/CompilerDialects.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------------------------ CompilerDialects.cpp ------------------------===//

#include "CompilerDialects.hpp"

#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/ONNX/ONNXDialect.hpp"

#include "mlir/InitAllDialects.h"

using namespace mlir;

namespace onnx_mlir {

DialectRegistry registerDialects(ArrayRef<accel::Accelerator::Kind> accels) {
DialectRegistry registry;

// Note that we cannot consult command line options because they have not yet
// been parsed when registerDialects() is called.

registry.insert<arith::ArithDialect>();
registry.insert<linalg::LinalgDialect>();
registry.insert<affine::AffineDialect>();
registry.insert<LLVM::LLVMDialect>();
registry.insert<scf::SCFDialect>();
registry.insert<func::FuncDialect>();
registry.insert<vector::VectorDialect>();
registry.insert<shape::ShapeDialect>();
registry.insert<math::MathDialect>();
registry.insert<memref::MemRefDialect>();
registry.insert<ONNXDialect>();
registry.insert<KrnlDialect>();
registry.insert<cf::ControlFlowDialect>();

// Initialize accelerator(s) if required.
accel::initAccelerators(accels);

// Register dialects for accelerators.
for (auto *accel : accel::Accelerator::getAccelerators())
accel->registerDialects(registry);

return registry;
}

} // namespace onnx_mlir
21 changes: 21 additions & 0 deletions src/Compiler/CompilerDialects.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------------------------ CompilerDialects.hpp ------------------------===//

#pragma once

#include "src/Accelerators/Accelerator.hpp"

#include "mlir/IR/DialectRegistry.h"
#include "llvm/ADT/ArrayRef.h"

namespace onnx_mlir {

// Adds the mlir and onnx-mlir dialects needed to compile end to end.
// Initializes accelerator(s) if required.
mlir::DialectRegistry registerDialects(
llvm::ArrayRef<accel::Accelerator::Kind> accels);

} // namespace onnx_mlir
47 changes: 14 additions & 33 deletions src/Compiler/CompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,47 @@
//
//===----------------------------------------------------------------------===//

#include "CompilerUtils.hpp"
#include "ExternalUtil.hpp"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"

#include "ExternalUtil.hpp"

#include "src/Accelerators/Accelerator.hpp"
#include "src/Builder/FrontendDialectTransformer.hpp"
#include "src/Compiler/CompilerDialects.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
#include "src/Compiler/HeapReporter.hpp"
#include "src/Dialect/ONNX/ONNXDialect.hpp"
#include "src/Version/Version.hpp"

#include <fstream>
#include <regex>

#define DEBUG_TYPE "compiler_utils"

using namespace mlir;
using namespace onnx_mlir;

const std::string OnnxMlirEnvOptionName = "ONNX_MLIR_FLAGS";

namespace onnx_mlir {

// Return the vendor name if specified during make processing or the default.
std::string getVendorName() {
#if defined(ONNX_MLIR_VENDOR)
return ONNX_MLIR_VENDOR;
#else
return "ONNX-MLIR";
#endif
}

std::optional<std::string> getEnvVar(std::string name) {
if (const char *envVerbose = std::getenv(name.c_str()))
return std::string(envVerbose);
Expand Down Expand Up @@ -641,18 +636,9 @@ static int compileModuleToJniJar(
return genJniJar(module, modelSharedLibPath, modelJniJarPath);
}

void registerDialects(mlir::MLIRContext &context) {
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::affine::AffineDialect>();
context.getOrLoadDialect<mlir::vector::VectorDialect>();
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
context.getOrLoadDialect<mlir::scf::SCFDialect>();
context.getOrLoadDialect<mlir::func::FuncDialect>();
context.getOrLoadDialect<mlir::shape::ShapeDialect>();
context.getOrLoadDialect<mlir::math::MathDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::ONNXDialect>();
context.getOrLoadDialect<mlir::KrnlDialect>();
void loadDialects(mlir::MLIRContext &context) {
context.appendDialectRegistry(registerDialects(maccel));
context.loadAllAvailableDialects();
}

namespace {
Expand Down Expand Up @@ -951,10 +937,6 @@ static int emitOutput(mlir::OwningOpRef<ModuleOp> &module,
int compileModule(mlir::OwningOpRef<ModuleOp> &module,
mlir::MLIRContext &context, std::string outputNameNoExt,
EmissionTargetType emissionTarget) {
// Initialize accelerator(s) if required.
if (!maccel.empty())
onnx_mlir::accel::initAccelerators(maccel);

int rc = setupModule(module, context, outputNameNoExt);
if (rc != CompilerSuccess)
return rc;
Expand All @@ -968,7 +950,6 @@ int compileModule(mlir::OwningOpRef<ModuleOp> &module,
bool hasAccel = false;
for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) {
hasAccel = true;
accel->getOrLoadDialects(context);
accel->addPasses(module, pm, emissionTarget, outputNameNoExt);
}
if (!hasAccel)
Expand Down
42 changes: 9 additions & 33 deletions src/Compiler/CompilerUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,19 @@

#pragma once

#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Support/FileUtilities.h"

#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Transforms/Passes.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Builder/FrontendDialectTransformer.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"

#include "src/Accelerators/Accelerator.hpp"
#include "src/Version/Version.hpp"
#include <optional>
#include <string>
#include <vector>

namespace onnx_mlir {

std::string getVendorName();

std::optional<std::string> getEnvVar(std::string name);

struct Command {
Expand All @@ -71,7 +45,9 @@ struct Command {
int exec(std::string wdir = "") const;
};

void registerDialects(mlir::MLIRContext &context);
// Registers and loads the mlir and onnx-mlir dialects needed to compile
// end to end. Initializes accelerator(s) if required.
void loadDialects(mlir::MLIRContext &context);

// Get Tool path, see comments in CompilerUtils.cpp for more details.
std::string getToolPath(
Expand Down
8 changes: 5 additions & 3 deletions src/Compiler/OnnxMlirCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

#include "include/OnnxMlirCompiler.h"
#include "ExternalUtil.hpp"
#include "src/Compiler/CompilerDialects.hpp"
#include "src/Compiler/CompilerUtils.hpp"
#include "llvm/Support/FileSystem.h"

using namespace mlir;
using namespace onnx_mlir;
Expand Down Expand Up @@ -140,9 +142,9 @@ ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer,
if (errorMessage)
*errorMessage = nullptr;

mlir::OwningOpRef<mlir::ModuleOp> module;
mlir::MLIRContext context;
registerDialects(context);
OwningOpRef<ModuleOp> module;
MLIRContext context;
loadDialects(context);

std::string internalErrorMessage;
int rc = processInputArray(
Expand Down
5 changes: 3 additions & 2 deletions src/Tools/onnx-mlir-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ add_onnx_mlir_executable(onnx-mlir-opt

LINK_LIBS PRIVATE
${OMLibs}
OMCompilerDialects
OMCompilerOptions
OMCompilerUtils
OMCompilerPasses
OMAccelerator
OMInitAccelerators
OMVersion
MLIRAffineTransforms
MLIRLinalgTransforms
MLIRMemRefTransforms
Expand Down
Loading

0 comments on commit 760e2c9

Please sign in to comment.