From 191a9e1b96d8076d892213098611b25f85f6f69c Mon Sep 17 00:00:00 2001 From: hamptonm1 <79232909+hamptonm1@users.noreply.github.com> Date: Tue, 19 Sep 2023 11:55:06 -0400 Subject: [PATCH] Create Constant Propagation flag (#2498) * constant propagation flag --------- Signed-off-by: Megan Hampton Co-authored-by: Megan Hampton --- src/Compiler/CompilerOptions.cpp | 16 +++++++--- src/Compiler/CompilerOptions.hpp | 3 +- src/Compiler/CompilerPasses.cpp | 31 ++++++++++--------- src/Pass/Passes.hpp | 5 +-- src/Transform/ONNX/ConstProp.cpp | 29 ++++++++++++----- test/backend/inference_backend.py | 4 +-- .../conversion/rewrite-onnx-for-zhigh.mlir | 2 +- test/mlir/onnx/onnx_constprop.mlir | 2 +- .../onnx/onnx_constprop_expansion_bound.mlir | 2 +- .../onnx_constprop_no_shape_inference.mlir | 2 +- .../onnx/onnx_simplify_shape_related_ops.mlir | 2 +- 11 files changed, 62 insertions(+), 36 deletions(-) diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index 21a248af8d..ec36ad2f2b 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -4,7 +4,7 @@ //===------------------------ CompilerOptions.cpp -------------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022, 2023 The IBM Research Authors. // // ============================================================================= // @@ -69,6 +69,7 @@ std::string reportHeapBefore; // onnx-mlir only std::string reportHeapAfter; // onnx-mlir only std::string modelTag; // onnx-mlir only bool enableConvOptPass; // onnx-mlir only +bool enableConstantProp; // onnx-mlir only std::vector<std::string> extraLibPaths; // onnx-mlir only std::vector<std::string> extraLibs; // onnx-mlir only ProfileIRs profileIR; // onnx-mlir only @@ -123,9 +124,10 @@ static llvm::cl::list OptimizationLevelOpt( llvm::cl::desc("Levels:"), llvm::cl::values(clEnumVal(O0, "Optimization level 0 (default):"), - clEnumVal(O1, "Optimization level 1,"), - clEnumVal(O2, "Optimization level 2,"), - clEnumVal(O3, "Optimization level 3.")), + clEnumVal(O1, "Optimization level 1"), + clEnumVal(O2, "Optimization level 2"), + clEnumVal(O3, + "Optimization level 3, constant propagation and SIMD are enabled")), llvm::cl::location(OptimizationLevel), llvm::cl::init(O0), llvm::cl::cat(OnnxMlirCommonOptions)); @@ -460,6 +462,12 @@ static llvm::cl::opt<bool, true> enableConvOptPassOpt("enable-conv-opt-pass", llvm::cl::location(enableConvOptPass), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); +static llvm::cl::opt<bool, true> enableConstantPropOpt("enable-constant-prop", + llvm::cl::desc("Enable Constant Propagation (default is false)\n" + "Set to 'true' to enable Constant Propagation; it is enabled automatically at Level O3."), + llvm::cl::location(enableConstantProp), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirCommonOptions)); + static llvm::cl::list<std::string, std::vector<std::string>> extraLibPathsOpt( "L", llvm::cl::desc("Specify extra directories for libraries when compiling" diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 5a26885717..431d0b3bd0 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -4,7 +4,7 @@ //===------------------------ CompilerOptions.hpp -------------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022, 2023 The IBM Research Authors. 
// // ============================================================================= // @@ -111,6 +111,7 @@ extern std::string reportHeapBefore; // onnx-mlir only extern std::string reportHeapAfter; // onnx-mlir only extern std::string modelTag; // onnx-mlir only extern bool enableConvOptPass; // onnx-mlir only +extern bool enableConstantProp; // onnx-mlir only extern std::vector<std::string> extraLibPaths; // onnx-mlir only extern std::vector<std::string> extraLibs; // onnx-mlir only extern ProfileIRs profileIR; // onnx-mlir only diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index f873302cc5..49c5b4719a 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -45,8 +45,9 @@ namespace onnx_mlir { void configurePasses() { // Set global vector machine support. VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, ""); - configureConstPropONNXToONNXPass( - onnxConstPropExpansionBound, onnxConstPropDisablePatterns); + configureConstPropONNXToONNXPass(onnxConstPropExpansionBound, + onnxConstPropDisablePatterns, + OptimizationLevel >= 3 || enableConstantProp); configureOnnxToKrnlLoweringPass(optReport == OptReport::Parallel, enableParallel, optReport == OptReport::Simd, !disableSimdOption); } @@ -85,7 +86,6 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) { // There are more opportunities for const propagation once all tensors have // inferred shapes. pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass()); - if (onnxOpTransformThreshold > 0) { // Dynamic iterate in ONNXOpTransformPass pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold, @@ -104,11 +104,12 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) { // Simplify shape-related ops. pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass()); - // One more call to ONNX shape inference/canonicalization/... to update shape - // if possible. + // One more call to ONNX shape inference/canonicalization/... to update + // shape if possible. if (enableONNXHybridPass) { - // For starters only illustrating the new hybrid pass by replacing 3 passes - // here. The plan is to replace most of the passes in addONNXToMLIRPasses. + // For starters only illustrating the new hybrid pass by replacing 3 + // passes here. The plan is to replace most of the passes in + // addONNXToMLIRPasses. pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass()); } else { pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass()); @@ -131,10 +132,10 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) { if (profileIR == onnx_mlir::ProfileIRs::Onnx) { instrumentStage = onnx_mlir::InstrumentStages::Onnx; instrumentOps = "onnx.*"; - // Enable the first three bits for InstrumentBeforOp, InstrumentAfterOp and - // InstrumentReportTime. - // Disable the last bit for InstrumentReportMemory because of its big - // overhead. Users can optionally enable the last bit by using + // Enable the first three bits for InstrumentBeforeOp, InstrumentAfterOp + // and InstrumentReportTime. Disable the last bit for + // InstrumentReportMemory because of its big overhead. Users can + // optionally enable the last bit by using // --InstrumentReportMemory option. instrumentActions |= (1 << 3) - 1; } @@ -168,8 +169,8 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, } } - // Print Signatures of each op at runtime if enabled. Should not run signature - // and instrument passes at the same time. + // Print Signatures of each op at runtime if enabled. 
Should not run + // signature and instrument passes at the same time. if (enableInstrumentONNXSignature) pm.addNestedPass<func::FuncOp>( onnx_mlir::createInstrumentONNXSignaturePass()); @@ -211,8 +212,8 @@ void addKrnlToLLVMPasses( // Use MLIR buffer deallocation pass to emit buffer deallocs. // Currently this has to be done *after* lowering the affine dialect because - // operations in that dialect do not conform to the requirements explained in - // https://mlir.llvm.org/docs/BufferDeallocationInternals. + // operations in that dialect do not conform to the requirements explained + // in https://mlir.llvm.org/docs/BufferDeallocationInternals. pm.addNestedPass<func::FuncOp>( mlir::bufferization::createBufferDeallocationPass()); diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 2597843ab4..7e312fc857 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -45,8 +45,9 @@ std::unique_ptr<mlir::Pass> createConvOptONNXToONNXPass( std::unique_ptr<mlir::Pass> createShapeInferencePass(); // To configure ConstPropONNXToONNXPass at program start. -void configureConstPropONNXToONNXPass( - int expansionBound, llvm::ArrayRef<std::string> disabledPatterns = {}); +void configureConstPropONNXToONNXPass(int expansionBound, + llvm::ArrayRef<std::string> disabledPatterns = {}, + bool constantPropIsEnabled = false); std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass(); diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp index 9a8078d24d..805efeb1e0 100644 --- a/src/Transform/ONNX/ConstProp.cpp +++ b/src/Transform/ONNX/ConstProp.cpp @@ -4,7 +4,7 @@ //===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2023 The IBM Research Authors. // // ============================================================================= // @@ -63,10 +63,12 @@ namespace { struct ConstPropONNXToONNXPassConfiguration { static int expansionBound; static StringSet<> disabledPatterns; + static bool constantPropIsEnabled; }; int ConstPropONNXToONNXPassConfiguration::expansionBound = -1; // -1 == no bound StringSet<> ConstPropONNXToONNXPassConfiguration::disabledPatterns; +bool ConstPropONNXToONNXPassConfiguration::constantPropIsEnabled = false; // Precondition: result has ranked tensor type with static shape and int or // float element type. @@ -86,6 +88,14 @@ bool satisfiesExpansionBound(Value result) { getSizeInBytes(resultType); } +// We want to enable Constant Propagation only at Level O3 or when a user +// manually specifies the "enable-constant-prop" flag. +bool isConstantPropagationEnabled() { + bool enable = (/*enableConstantProp*/ ConstPropONNXToONNXPassConfiguration:: + constantPropIsEnabled); + return enable; +} + bool isNotDisabled(StringRef name) { bool ok = !ConstPropONNXToONNXPassConfiguration::disabledPatterns.contains(name); @@ -1025,7 +1035,6 @@ struct ConstPropONNXToONNXPass return "ConstProp ONNX operations into composition of " "other ONNX operations."; } - void runOnOperation() final; }; @@ -1034,20 +1043,26 @@ void ConstPropONNXToONNXPass::runOnOperation() { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - populateWithGenerated(patterns); - if (isNotDisabled("SplitOfConst")) - patterns.insert(context); + if (isConstantPropagationEnabled()) { + populateWithGenerated(patterns); + if (isNotDisabled("SplitOfConst")) { + patterns.insert(context); + } + } + if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns)))) signalPassFailure(); } } // end anonymous namespace. 
-void onnx_mlir::configureConstPropONNXToONNXPass( - int expansionBound, ArrayRef<std::string> disabledPatterns) { +void onnx_mlir::configureConstPropONNXToONNXPass(int expansionBound, + ArrayRef<std::string> disabledPatterns, bool constantPropIsEnabled) { ConstPropONNXToONNXPassConfiguration::expansionBound = expansionBound; ConstPropONNXToONNXPassConfiguration::disabledPatterns.insert( disabledPatterns.begin(), disabledPatterns.end()); + ConstPropONNXToONNXPassConfiguration::constantPropIsEnabled = + constantPropIsEnabled; } /*! diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index a727e316b2..13f8e20cca 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -774,10 +774,10 @@ def get_test_models(): # ==LIM== do_not_keep_dim not supported. #"test_reduce_log_sum_desc_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, #"test_reduce_log_sum_asc_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, - "test_reduce_log_sum_default_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + #"test_reduce_log_sum_default_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_reduce_log_sum_negative_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # Name changed in v13 - "test_reduce_log_sum_default_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + # "test_reduce_log_sum_default_expanded_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # ==OP== ReduceL1 # ==MIN== 13 diff --git a/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir b/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir index 432ed312bd..5d5dd15caf 100644 --- a/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir +++ b/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir @@ -1,5 +1,5 @@ // RUN: onnx-mlir-opt --maccel=NNPA --shape-inference --rewrite-onnx-for-zhigh %s -split-input-file | FileCheck %s -// RUN: onnx-mlir-opt --maccel=NNPA --rewrite-onnx-for-zhigh --shape-inference --canonicalize --constprop-onnx --shape-inference %s --split-input-file | FileCheck --check-prefix=CONSTPROP %s +// RUN: onnx-mlir-opt --maccel=NNPA --rewrite-onnx-for-zhigh --shape-inference --canonicalize --constprop-onnx --enable-constant-prop=true --shape-inference %s --split-input-file | FileCheck --check-prefix=CONSTPROP %s // ----- diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index 7c68518db8..b47c625f1e 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --shape-inference --constprop-onnx --enable-constant-prop=true %s -split-input-file | FileCheck %s //===----------------------------------------------------------------------===// // Common tests. Use ONNXAddOp as example. 
diff --git a/test/mlir/onnx/onnx_constprop_expansion_bound.mlir b/test/mlir/onnx/onnx_constprop_expansion_bound.mlir index 0429243b71..257aa905ac 100644 --- a/test/mlir/onnx/onnx_constprop_expansion_bound.mlir +++ b/test/mlir/onnx/onnx_constprop_expansion_bound.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --constprop-onnx --onnx-const-prop-expansion-bound=2 %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --constprop-onnx --enable-constant-prop=true --onnx-const-prop-expansion-bound=2 %s -split-input-file | FileCheck %s //===----------------------------------------------------------------------===// // Constant propagate ONNXAddOp only if expansion bound satisfied diff --git a/test/mlir/onnx/onnx_constprop_no_shape_inference.mlir b/test/mlir/onnx/onnx_constprop_no_shape_inference.mlir index 9ef7544648..cc7ef91ffb 100644 --- a/test/mlir/onnx/onnx_constprop_no_shape_inference.mlir +++ b/test/mlir/onnx/onnx_constprop_no_shape_inference.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --decompose-onnx --constprop-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --decompose-onnx --constprop-onnx --enable-constant-prop=true %s -split-input-file | FileCheck %s //===----------------------------------------------------------------------===// /// Split tests diff --git a/test/mlir/onnx/onnx_simplify_shape_related_ops.mlir b/test/mlir/onnx/onnx_simplify_shape_related_ops.mlir index f5f0ede4d7..27ea734ab2 100644 --- a/test/mlir/onnx/onnx_simplify_shape_related_ops.mlir +++ b/test/mlir/onnx/onnx_simplify_shape_related_ops.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --simplify-shape-related-ops-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --enable-constant-prop=true --simplify-shape-related-ops-onnx %s -split-input-file | FileCheck %s func.func @test_shape_to_dim(%arg0: tensor) -> (tensor<2xi64>) { %0 = "onnx.Shape"(%arg0) : (tensor) -> tensor<2xi64>
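Usage sketch (illustrative, not part of the patch; the model and .mlir file names below are placeholders). The lit RUN lines above show the flag's syntax for onnx-mlir-opt, and configurePasses() enables the ONNX constant-propagation patterns when either -O3 is in effect or the flag is given explicitly; because the option is registered under OnnxMlirCommonOptions, the onnx-mlir driver is assumed to accept it as well:

    # drive the pass directly, as the lit tests do
    onnx-mlir-opt --shape-inference --constprop-onnx --enable-constant-prop=true model.mlir

    # with the driver: opt in explicitly below O3, or rely on -O3, which makes
    # "OptimizationLevel >= 3 || enableConstantProp" evaluate to true
    onnx-mlir --enable-constant-prop=true model.onnx
    onnx-mlir -O3 model.onnx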