Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX to ZHigh guided by cost model #2507

Closed
7 changes: 7 additions & 0 deletions src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,11 @@ llvm::cl::opt<bool> nnpaEnableZHighToOnnx("enable-zhigh-to-onnx",
"level. Default is true."),
llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions));

// Command-line flag (--enable-zhigh-cost-model) that gates the use of a
// performance cost model when deciding whether to lower an eligible ONNX
// operation to its ZHigh equivalent. Defaults to off (llvm::cl::init(false)),
// so lowering decisions are unchanged unless the user opts in.
llvm::cl::opt<bool> nnpaEnableZHighCostModel("enable-zhigh-cost-model",
llvm::cl::desc(
"Enabling a performance cost model to estimate the benefit of "
"migrating an eligible onnx operation to a ZHigh operation. Default is "
"false."),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));

} // namespace onnx_mlir
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ extern llvm::cl::opt<onnx_mlir::NNPAEmissionTargetType> nnpaEmissionTarget;
extern llvm::cl::list<std::string> execNodesOnCpu;
extern llvm::cl::opt<bool> nnpaClipToDLFloatRange;
extern llvm::cl::opt<bool> nnpaEnableZHighToOnnx;
// When true, use a performance cost model to decide ONNX -> ZHigh lowering
// (set via --enable-zhigh-cost-model; defined in NNPACompilerOptions.cpp).
extern llvm::cl::opt<bool> nnpaEnableZHighCostModel;
extern llvm::cl::opt<bool> profileZHighIR;

} // namespace onnx_mlir
11 changes: 6 additions & 5 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ using namespace onnx_mlir;

namespace onnx_mlir {

void addONNXToZHighPasses(
mlir::PassManager &pm, ArrayRef<std::string> execNodesOnCpu) {
void addONNXToZHighPasses(mlir::PassManager &pm,
ArrayRef<std::string> execNodesOnCpu, bool useCostModel) {
for (unsigned i = 0; i < 3; i++) {
// Repeat this process so that shape-related ops such as Shape, Expand,
// Gather generated during RewriteONNXForZHigh will become constants.
pm.addPass(onnx_mlir::createRewriteONNXForZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createRewriteONNXForZHighPass(
execNodesOnCpu, false /*useCostModel*/));
// Simplify shape-related ops, including ShapeOp-to-DimOp replacement,
// constant propagation, shape inference and canonicalize.
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());
Expand All @@ -75,7 +76,7 @@ void addONNXToZHighPasses(
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));

pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu));
pm.addPass(onnx_mlir::createONNXToZHighPass(execNodesOnCpu, useCostModel));
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
// There are more opportunities for const propagation once all zhigh ops were
// generated.
Expand Down Expand Up @@ -155,7 +156,7 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,

if (emissionTarget >= EmitMLIR) {
// Lower zAIU-compatible ONNX ops to ZHigh dialect where possible.
addONNXToZHighPasses(pm, execNodesOnCpu);
addONNXToZHighPasses(pm, execNodesOnCpu, nnpaEnableZHighCostModel);

if (nnpaEmissionTarget >= EmitZHighIR)
emissionTarget = EmitMLIR;
Expand Down
2 changes: 2 additions & 0 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_onnx_mlir_library(OMONNXToZHigh
ONNXLegalityCheck.cpp
ONNXToZHigh.cpp
ONNXToZHighCommon.cpp
ZHighPerfModel.cpp

DEPENDS
OMONNXONNXToZHighIncGen
Expand All @@ -25,6 +26,7 @@ add_onnx_mlir_library(OMRewriteONNXForZHigh
ONNXLegalityCheck.cpp
RewriteONNXForZHigh.cpp
ONNXToZHighCommon.cpp
ZHighPerfModel.cpp

DEPENDS
OMONNXRewriteONNXForZHighIncGen
Expand Down
Loading
Loading