Skip to content

Commit

Permalink
Call RewriteONNXToONNX from the driver
Browse files Browse the repository at this point in the history
Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld committed Sep 11, 2023
1 parent d432ee5 commit e55713b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
7 changes: 1 addition & 6 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
pm.addPass(mlir::createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
}
// Rewrite ONNX operators.
pm.addPass(onnx_mlir::createRewriteONNXToONNXPass());
// Convolution Optimization for CPU: enable when there are no accelerators.
if (targetCPU && enableConvOptPass) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
Expand All @@ -97,16 +95,14 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
// Statically add extra passes
for (int i = 0; i < repeatOnnxTransform; i++) {
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());
pm.addPass(onnx_mlir::createRewriteONNXToONNXPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createConstPropONNXToONNXPass());
}
}

// Simplify shape-related ops.
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());

// One more call to ONNX shape inference/canonicalization/... to update shape
// if possible.
if (enableONNXHybridPass) {
Expand All @@ -116,7 +112,6 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
} else {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(onnx_mlir::createRewriteONNXToONNXPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
}

Expand Down
1 change: 1 addition & 0 deletions src/Transform/ONNX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ add_onnx_mlir_library(OMOpTransform
OMONNXOps
MLIRPass
OMONNXRewrite
OMONNXSimplifyShapeRelatedOps
OMShapeInferencePass
MLIRTransforms
)
Expand Down
3 changes: 3 additions & 0 deletions src/Transform/ONNX/ONNXOpTransformPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ void ONNXOpTransformPass::runOnOperation() {
}
dynamicPM.addNestedPass<func::FuncOp>(
onnx_mlir::createConstPropONNXToONNXPass());
// Simplify shape-related ops.
dynamicPM.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());
// Rewrite ONNX operators.
dynamicPM.addPass(onnx_mlir::createRewriteONNXToONNXPass());
if (failed(runPipeline(dynamicPM, module)))
return signalPassFailure();
Expand Down

0 comments on commit e55713b

Please sign in to comment.