Skip to content

Commit

Permalink
Remove Level Constraint for Constant Propagation (#2537)
Browse files Browse the repository at this point in the history
* Remove level constraint
---------

Co-authored-by: Megan Hampton <[email protected]>
  • Loading branch information
hamptonm1 and MegoHam21 authored Sep 28, 2023
1 parent 06a10c0 commit 1a15f58
Show file tree
Hide file tree
Showing 14 changed files with 33 additions and 35 deletions.
13 changes: 6 additions & 7 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ std::string reportHeapBefore; // onnx-mlir only
std::string reportHeapAfter; // onnx-mlir only
std::string modelTag; // onnx-mlir only
bool enableConvOptPass; // onnx-mlir only
bool enableConstantProp; // onnx-mlir only
bool disableConstantProp; // onnx-mlir only
std::vector<std::string> extraLibPaths; // onnx-mlir only
std::vector<std::string> extraLibs; // onnx-mlir only
ProfileIRs profileIR; // onnx-mlir only
Expand Down Expand Up @@ -126,8 +126,7 @@ static llvm::cl::opt<OptLevel, true> OptimizationLevelOpt(
llvm::cl::values(clEnumVal(O0, "Optimization level 0 (default):"),
clEnumVal(O1, "Optimization level 1"),
clEnumVal(O2, "Optimization level 2"),
clEnumVal(O3,
"Optimization level 3, constant propagation and SIMD is enabled")),
clEnumVal(O3, "Optimization level 3, SIMD is enabled")),
llvm::cl::location(OptimizationLevel), llvm::cl::init(O0),
llvm::cl::cat(OnnxMlirCommonOptions));

Expand Down Expand Up @@ -462,10 +461,10 @@ static llvm::cl::opt<bool, true> enableConvOptPassOpt("enable-conv-opt-pass",
llvm::cl::location(enableConvOptPass), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> enableConstantPropOpt("enable-constant-prop",
llvm::cl::desc("Enable Constant Propagation (default is false)\n"
"Set to 'true' to enable Constant Propagation at Level O3."),
llvm::cl::location(enableConstantProp), llvm::cl::init(false),
static llvm::cl::opt<bool, true> disableConstantPropOpt("disable-constant-prop",
llvm::cl::desc("Disable Constant Propagation (default is false)\n"
"Set to 'true' to disable Constant Propagation."),
llvm::cl::location(disableConstantProp), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::list<std::string, std::vector<std::string>> extraLibPathsOpt(
Expand Down
2 changes: 1 addition & 1 deletion src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ extern std::string reportHeapBefore; // onnx-mlir only
extern std::string reportHeapAfter; // onnx-mlir only
extern std::string modelTag; // onnx-mlir only
extern bool enableConvOptPass; // onnx-mlir only
extern bool enableConstantProp; // onnx-mlir only
extern bool disableConstantProp; // onnx-mlir only
extern std::vector<std::string> extraLibPaths; // onnx-mlir only
extern std::vector<std::string> extraLibs; // onnx-mlir only
extern ProfileIRs profileIR; // onnx-mlir only
Expand Down
3 changes: 1 addition & 2 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ void configurePasses() {
// Set global vector machine support.
VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
configureConstPropONNXToONNXPass(onnxConstPropExpansionBound,
onnxConstPropDisablePatterns,
OptimizationLevel >= 3 || enableConstantProp);
onnxConstPropDisablePatterns, disableConstantProp);
configureOnnxToKrnlLoweringPass(optReport == OptReport::Parallel,
enableParallel, optReport == OptReport::Simd, !disableSimdOption);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ std::unique_ptr<mlir::Pass> createShapeInferencePass();

// To configure ConstPropONNXToONNXPass at program start.
void configureConstPropONNXToONNXPass(int expansionBound,
llvm::ArrayRef<std::string> disabledPatterns, bool constantPropIsEnabled);
llvm::ArrayRef<std::string> disabledPatterns, bool constantPropIsDisabled);

std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();

Expand Down
24 changes: 12 additions & 12 deletions src/Transform/ONNX/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ namespace {
struct ConstPropONNXToONNXPassConfiguration {
static int expansionBound;
static StringSet<> disabledPatterns;
static bool constantPropIsEnabled;
static bool constantPropIsDisabled;
};

int ConstPropONNXToONNXPassConfiguration::expansionBound = -1; // -1 == no bound
StringSet<> ConstPropONNXToONNXPassConfiguration::disabledPatterns = {};
bool ConstPropONNXToONNXPassConfiguration::constantPropIsEnabled = false;
bool ConstPropONNXToONNXPassConfiguration::constantPropIsDisabled = false;

// Precondition: result has ranked tensor type with static shape and int or
// float element type.
Expand All @@ -90,12 +90,12 @@ bool satisfiesExpansionBound(Value result) {
getSizeInBytes(resultType);
}

// We want to enable Constant Propagation only for Level O3 or when a user
// manually specifies the "enable-constant-prop" flag.
bool isConstantPropagationEnabled() {
bool enable = (/*enableConstantProp*/ ConstPropONNXToONNXPassConfiguration::
constantPropIsEnabled);
return enable;
// We want to disable Constant Propagation when a user
// manually specifies the "disable-constant-prop" flag.
bool isConstantPropagationDisabled() {
bool disable = (/*disableConstantProp*/ ConstPropONNXToONNXPassConfiguration::
constantPropIsDisabled);
return disable;
}

bool isNotDisabled(StringRef name) {
Expand Down Expand Up @@ -1053,20 +1053,20 @@ void ConstPropONNXToONNXPass::runOnOperation() {
} // end anonymous namespace.

void onnx_mlir::getConstPropONNXToONNXPatterns(RewritePatternSet &patterns) {
if (!isConstantPropagationEnabled())
if (isConstantPropagationDisabled())
return;
populateWithGenerated(patterns);
if (isNotDisabled("SplitOfConst"))
patterns.insert<SplitOfConst>(patterns.getContext());
}

void onnx_mlir::configureConstPropONNXToONNXPass(int expansionBound,
ArrayRef<std::string> disabledPatterns, bool constantPropIsEnabled) {
ArrayRef<std::string> disabledPatterns, bool constantPropIsDisabled) {
ConstPropONNXToONNXPassConfiguration::expansionBound = expansionBound;
ConstPropONNXToONNXPassConfiguration::disabledPatterns.insert(
disabledPatterns.begin(), disabledPatterns.end());
ConstPropONNXToONNXPassConfiguration::constantPropIsEnabled =
constantPropIsEnabled;
ConstPropONNXToONNXPassConfiguration::constantPropIsDisabled =
constantPropIsDisabled;
}

/*!
Expand Down
4 changes: 2 additions & 2 deletions test/backend/inference_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,10 +774,10 @@ def get_test_models():
# ==LIM== do_not_keep_dim not supported.
#"test_reduce_log_sum_desc_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
#"test_reduce_log_sum_asc_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
#"test_reduce_log_sum_default_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_reduce_log_sum_default_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_reduce_log_sum_negative_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# Name changed in v13
# "test_reduce_log_sum_default_expanded_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_reduce_log_sum_default_expanded_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ==OP== ReduceL1
# ==MIN== 13
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir --EmitONNXIR --maccel=NNPA --printIR %s | FileCheck %s
// RUN: onnx-mlir --EmitONNXIR --maccel=NNPA --disable-constant-prop=true --printIR %s | FileCheck %s

module attributes {llvm.data_layout = "E-m:e-i1:8:16-i8:8:16-i64:64-f128:64-v128:64-a:8:16-n32:64", llvm.target_triple = "s390x-ibm-linux", "onnx-mlir.symbol-postfix" = "model"} {
func.func @mnist(%arg0: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> attributes {input_names = ["Input3"], output_names = ["Plus214_Output_0"]} {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir --EmitZHighIR --maccel=NNPA --printIR %s | FileCheck %s
// RUN: onnx-mlir --EmitZHighIR --maccel=NNPA --disable-constant-prop=true --printIR %s | FileCheck %s

// Note that, we intentionally add `device=cpu` into onnx.Gemm to force it run on CPU.
module {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: onnx-mlir-opt --maccel=NNPA --shape-inference --rewrite-onnx-for-zhigh %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt --maccel=NNPA --rewrite-onnx-for-zhigh --shape-inference --canonicalize --constprop-onnx --enable-constant-prop=true --shape-inference %s --split-input-file | FileCheck --check-prefix=CONSTPROP %s
// RUN: onnx-mlir-opt --maccel=NNPA --rewrite-onnx-for-zhigh --shape-inference --canonicalize --constprop-onnx --shape-inference %s --split-input-file | FileCheck --check-prefix=CONSTPROP %s

// -----

Expand Down
2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_constprop.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir-opt --shape-inference --constprop-onnx --enable-constant-prop=true %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
// Common tests. Use ONNXAddOp as example.
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_constprop_expansion_bound.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir-opt --constprop-onnx --enable-constant-prop=true --onnx-const-prop-expansion-bound=2 %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt --constprop-onnx --onnx-const-prop-expansion-bound=2 %s -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
// Constant propagate ONNXAddOp only if expansion bound satisfied
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_constprop_no_shape_inference.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir-opt --decompose-onnx --constprop-onnx --enable-constant-prop=true %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt --decompose-onnx --constprop-onnx %s -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
/// Split tests
Expand Down
6 changes: 3 additions & 3 deletions test/mlir/onnx/onnx_hybrid_transform.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: onnx-mlir-opt --enable-constant-prop=true -onnx-hybrid-transform="constant-propagation=false decomposition=false" %s | FileCheck %s
// RUN: onnx-mlir-opt --enable-constant-prop=true -onnx-hybrid-transform=constant-propagation=false %s | FileCheck --check-prefix=DECOMPOSE %s
// RUN: onnx-mlir-opt --enable-constant-prop=true -onnx-hybrid-transform=decomposition=false %s | FileCheck --check-prefix=CONSTPROP %s
// RUN: onnx-mlir-opt -onnx-hybrid-transform="constant-propagation=false decomposition=false" %s | FileCheck %s
// RUN: onnx-mlir-opt -onnx-hybrid-transform=constant-propagation=false %s | FileCheck --check-prefix=DECOMPOSE %s
// RUN: onnx-mlir-opt -onnx-hybrid-transform=decomposition=false %s | FileCheck --check-prefix=CONSTPROP %s

// Illustrates the back and forth between shape inference and the
// BinaryOpBroadcastAxisPattern canonicalization pattern:
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_simplify_shape_related_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir-opt --enable-constant-prop=true --simplify-shape-related-ops-onnx %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt --simplify-shape-related-ops-onnx %s -split-input-file | FileCheck %s

func.func @test_shape_to_dim(%arg0: tensor<?x256xi64>) -> (tensor<2xi64>) {
%0 = "onnx.Shape"(%arg0) : (tensor<?x256xi64>) -> tensor<2xi64>
Expand Down

0 comments on commit 1a15f58

Please sign in to comment.