Create Constant Propagation flag (#2498)
* constant propagation flag

---------

Signed-off-by: Megan Hampton <[email protected]>
Co-authored-by: Megan Hampton <[email protected]>
hamptonm1 and MegoHam21 authored Sep 19, 2023
1 parent 425a01f commit 191a9e1
Showing 11 changed files with 62 additions and 36 deletions.
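
In short: this change adds an `--enable-constant-prop` flag (default false) and gates the ONNX constant-propagation rewrite patterns behind it, with -O3 implying the flag. A minimal stand-alone sketch of the resulting gate, with illustrative names rather than the actual onnx-mlir sources:

```cpp
// Sketch of the enablement rule introduced by this commit: constant
// propagation runs iff -O3 is given or --enable-constant-prop=true is set.
#include <iostream>

enum OptLevel { O0 = 0, O1, O2, O3 };

bool constPropEnabled(OptLevel level, bool enableConstantProp) {
  return level >= O3 || enableConstantProp;
}

int main() {
  std::cout << constPropEnabled(O0, false)  // 0: off by default
            << constPropEnabled(O3, false)  // 1: implied by -O3
            << constPropEnabled(O1, true)   // 1: forced by the flag
            << '\n';
}
```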
16 changes: 12 additions & 4 deletions src/Compiler/CompilerOptions.cpp
@@ -4,7 +4,7 @@

//===------------------------ CompilerOptions.cpp -------------------------===//
//
-// Copyright 2022 The IBM Research Authors.
+// Copyright 2022, 2023 The IBM Research Authors.
//
// =============================================================================
//
@@ -69,6 +69,7 @@ std::string reportHeapBefore; // onnx-mlir only
std::string reportHeapAfter; // onnx-mlir only
std::string modelTag; // onnx-mlir only
bool enableConvOptPass; // onnx-mlir only
+bool enableConstantProp; // onnx-mlir only
std::vector<std::string> extraLibPaths; // onnx-mlir only
std::vector<std::string> extraLibs; // onnx-mlir only
ProfileIRs profileIR; // onnx-mlir only
@@ -123,9 +124,10 @@ static llvm::cl::list<accel::Accelerator::Kind,
static llvm::cl::opt<OptLevel, true> OptimizationLevelOpt(
llvm::cl::desc("Levels:"),
llvm::cl::values(clEnumVal(O0, "Optimization level 0 (default):"),
-clEnumVal(O1, "Optimization level 1,"),
-clEnumVal(O2, "Optimization level 2,"),
-clEnumVal(O3, "Optimization level 3.")),
+clEnumVal(O1, "Optimization level 1"),
+clEnumVal(O2, "Optimization level 2"),
+clEnumVal(O3,
+"Optimization level 3, constant propagation and SIMD is enabled")),
llvm::cl::location(OptimizationLevel), llvm::cl::init(O0),
llvm::cl::cat(OnnxMlirCommonOptions));

@@ -460,6 +462,12 @@ static llvm::cl::opt<bool, true> enableConvOptPassOpt("enable-conv-opt-pass",
llvm::cl::location(enableConvOptPass), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirOptions));

+static llvm::cl::opt<bool, true> enableConstantPropOpt("enable-constant-prop",
+llvm::cl::desc("Enable Constant Propagation (default is false)\n"
+"Set to 'true' to enable Constant Propagation at Level O3."),
+llvm::cl::location(enableConstantProp), llvm::cl::init(false),
+llvm::cl::cat(OnnxMlirCommonOptions));
+
static llvm::cl::list<std::string, std::vector<std::string>> extraLibPathsOpt(
"L",
llvm::cl::desc("Specify extra directories for libraries when compiling"
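The option above uses LLVM's external-storage variant of `cl::opt`: the `true` template argument plus `llvm::cl::location` make the parser write straight into the global `enableConstantProp`, so other translation units only need the `extern` declaration from CompilerOptions.hpp. A distilled, compilable sketch of that LLVM CommandLine pattern (stand-alone, not the onnx-mlir build):

```cpp
// Minimal LLVM CommandLine example with external storage.
#include "llvm/Support/CommandLine.h"

bool enableConstantProp; // global storage; declared 'extern' elsewhere

static llvm::cl::opt<bool, true> enableConstantPropOpt("enable-constant-prop",
    llvm::cl::desc("Enable constant propagation (default false)."),
    llvm::cl::location(enableConstantProp), // parser writes here
    llvm::cl::init(false));

int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return enableConstantProp ? 0 : 1; // e.g. prog --enable-constant-prop=true
}
```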
3 changes: 2 additions & 1 deletion src/Compiler/CompilerOptions.hpp
@@ -4,7 +4,7 @@

//===------------------------ CompilerOptions.hpp -------------------------===//
//
-// Copyright 2022 The IBM Research Authors.
+// Copyright 2022, 2023 The IBM Research Authors.
//
// =============================================================================
//
@@ -111,6 +111,7 @@ extern std::string reportHeapBefore; // onnx-mlir only
extern std::string reportHeapAfter; // onnx-mlir only
extern std::string modelTag; // onnx-mlir only
extern bool enableConvOptPass; // onnx-mlir only
+extern bool enableConstantProp; // onnx-mlir only
extern std::vector<std::string> extraLibPaths; // onnx-mlir only
extern std::vector<std::string> extraLibs; // onnx-mlir only
extern ProfileIRs profileIR; // onnx-mlir only
31 changes: 16 additions & 15 deletions src/Compiler/CompilerPasses.cpp
@@ -45,8 +45,9 @@ namespace onnx_mlir {
void configurePasses() {
// Set global vector machine support.
VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
-configureConstPropONNXToONNXPass(
-onnxConstPropExpansionBound, onnxConstPropDisablePatterns);
+configureConstPropONNXToONNXPass(onnxConstPropExpansionBound,
+onnxConstPropDisablePatterns,
+OptimizationLevel >= 3 || enableConstantProp);
configureOnnxToKrnlLoweringPass(optReport == OptReport::Parallel,
enableParallel, optReport == OptReport::Simd, !disableSimdOption);
}
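
`configurePasses` now threads the gate into the pass configuration once at startup; the gate is the disjunction `OptimizationLevel >= 3 || enableConstantProp`. For illustration, a hypothetical driver could force the pass on regardless of -O level through the same entry point (the include path here is an assumption based on the onnx-mlir tree layout):

```cpp
// Hypothetical caller of the configuration hook from src/Pass/Passes.hpp.
#include "src/Pass/Passes.hpp"

void forceConstProp() {
  onnx_mlir::configureConstPropONNXToONNXPass(
      /*expansionBound=*/-1,           // -1 == no bound (see ConstProp.cpp)
      /*disabledPatterns=*/{},         // keep all patterns enabled
      /*constantPropIsEnabled=*/true); // open the gate unconditionally
}
```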
@@ -85,7 +86,6 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
// There are more opportunities for const propagation once all tensors have
// inferred shapes.
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
-
if (onnxOpTransformThreshold > 0) {
// Dynamic iterate in ONNXOpTransformPass
pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold,
@@ -104,11 +104,12 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
// Simplify shape-related ops.
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass());

-// One more call to ONNX shape inference/canonicalization/... to update shape
-// if possible.
+// One more call to ONNX shape inference/canonicalization/... to update
+// shape if possible.
if (enableONNXHybridPass) {
-// For starters only illustrating the new hybrid pass by replacing 3 passes
-// here. The plan is to replace most of the passes in addONNXToMLIRPasses.
+// For starters only illustrating the new hybrid pass by replacing 3
+// passes here. The plan is to replace most of the passes in
+// addONNXToMLIRPasses.
pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass());
} else {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
@@ -131,10 +132,10 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
if (profileIR == onnx_mlir::ProfileIRs::Onnx) {
instrumentStage = onnx_mlir::InstrumentStages::Onnx;
instrumentOps = "onnx.*";
-// Enable the first three bits for InstrumentBeforOp, InstrumentAfterOp and
-// InstrumentReportTime.
-// Disable the last bit for InstrumentReportMemory because of its big
-// overhead. Users can optionally enable the last bit by using
+// Enable the first three bits for InstrumentBeforOp, InstrumentAfterOp
+// and InstrumentReportTime. Disable the last bit for
+// InstrumentReportMemory because of its big overhead. Users can
+// optionally enable the last bit by using
// --InstrumentReportMemory option.
instrumentActions |= (1 << 3) - 1;
}
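
As a quick check of the bit arithmetic in that comment: `(1 << 3) - 1` is `0b111`, i.e. the three low instrumentation bits set and the fourth (report-memory) bit clear.

```cpp
// Worked example of the instrumentation mask; stand-alone, not onnx-mlir code.
#include <cassert>

int main() {
  unsigned instrumentActions = 0;
  instrumentActions |= (1u << 3) - 1;            // 0b111 == 7
  assert(instrumentActions == 0b0111u);          // three low bits on
  assert((instrumentActions & (1u << 3)) == 0u); // memory-report bit off
}
```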
@@ -168,8 +169,8 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
}
}

-// Print Signatures of each op at runtime if enabled. Should not run signature
-// and instrument passes at the same time.
+// Print Signatures of each op at runtime if enabled. Should not run
+// signature and instrument passes at the same time.
if (enableInstrumentONNXSignature)
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentONNXSignaturePass());
@@ -211,8 +212,8 @@ void addKrnlToLLVMPasses(

// Use MLIR buffer deallocation pass to emit buffer deallocs.
// Currently this has to be done *after* lowering the affine dialect because
-// operations in that dialect do not conform to the requirements explained in
-// https://mlir.llvm.org/docs/BufferDeallocationInternals.
+// operations in that dialect do not conform to the requirements explained
+// in https://mlir.llvm.org/docs/BufferDeallocationInternals.
pm.addNestedPass<func::FuncOp>(
mlir::bufferization::createBufferDeallocationPass());

5 changes: 3 additions & 2 deletions src/Pass/Passes.hpp
@@ -45,8 +45,9 @@ std::unique_ptr<mlir::Pass> createConvOptONNXToONNXPass(
std::unique_ptr<mlir::Pass> createShapeInferencePass();

// To configure ConstPropONNXToONNXPass at program start.
-void configureConstPropONNXToONNXPass(
-int expansionBound, llvm::ArrayRef<std::string> disabledPatterns = {});
+void configureConstPropONNXToONNXPass(int expansionBound,
+llvm::ArrayRef<std::string> disabledPatterns = {},
+bool constantPropIsEnabled = false);

std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();
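
Because the new parameter is defaulted, existing two-argument call sites keep compiling and get the conservative behavior (propagation off). A self-contained sketch of that compatibility, with mock names:

```cpp
// Defaulted-parameter compatibility, mirroring the signature change above.
#include <string>
#include <vector>

void configureMock(int expansionBound,
    const std::vector<std::string> &disabledPatterns = {},
    bool constantPropIsEnabled = false) { /* store into pass config */ }

int main() {
  configureMock(-1);           // old-style call: gate stays closed
  configureMock(-1, {}, true); // new call: gate forced open
}
```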

29 changes: 22 additions & 7 deletions src/Transform/ONNX/ConstProp.cpp
@@ -4,7 +4,7 @@

//===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===//
//
-// Copyright 2019-2020 The IBM Research Authors.
+// Copyright 2019-2023 The IBM Research Authors.
//
// =============================================================================
//
@@ -63,10 +63,12 @@ namespace {
struct ConstPropONNXToONNXPassConfiguration {
static int expansionBound;
static StringSet<> disabledPatterns;
+static bool constantPropIsEnabled;
};

int ConstPropONNXToONNXPassConfiguration::expansionBound = -1; // -1 == no bound
StringSet<> ConstPropONNXToONNXPassConfiguration::disabledPatterns;
+bool ConstPropONNXToONNXPassConfiguration::constantPropIsEnabled = false;

// Precondition: result has ranked tensor type with static shape and int or
// float element type.
@@ -86,6 +88,14 @@ bool satisfiesExpansionBound(Value result) {
getSizeInBytes(resultType);
}

+// We want to enable Constant Propagation only for Level O3 or when a user
+// manually specifies the "enable-constant-prop" flag.
+bool isConsantPropagationEnabled() {
+bool enable = (/*enableConstantProp*/ ConstPropONNXToONNXPassConfiguration::
+constantPropIsEnabled);
+return enable;
+}
+
bool isNotDisabled(StringRef name) {
bool ok =
!ConstPropONNXToONNXPassConfiguration::disabledPatterns.contains(name);
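
The pass reads its options from process-wide statics that are written once at startup (by `configureConstPropONNXToONNXPass`) and consulted by every pass instance. A stand-alone sketch of that pattern with mock types (the real code uses `llvm::StringSet`):

```cpp
#include <set>
#include <string>

// Process-wide pass configuration, written once at program start.
struct MockConfig {
  static int expansionBound;
  static std::set<std::string> disabledPatterns;
  static bool constantPropIsEnabled;
};
int MockConfig::expansionBound = -1; // -1 == no bound
std::set<std::string> MockConfig::disabledPatterns;
bool MockConfig::constantPropIsEnabled = false;

// Read-side helpers, mirroring isConsantPropagationEnabled/isNotDisabled.
bool isEnabledMock() { return MockConfig::constantPropIsEnabled; }
bool isNotDisabledMock(const std::string &name) {
  return MockConfig::disabledPatterns.count(name) == 0;
}
```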
@@ -1025,7 +1035,6 @@ struct ConstPropONNXToONNXPass
return "ConstProp ONNX operations into composition of "
"other ONNX operations.";
}
-
void runOnOperation() final;
};

@@ -1034,20 +1043,26 @@ void ConstPropONNXToONNXPass::runOnOperation() {
MLIRContext *context = &getContext();

RewritePatternSet patterns(context);
-populateWithGenerated(patterns);
-if (isNotDisabled("SplitOfConst"))
-patterns.insert<SplitOfConst>(context);
+if (isConsantPropagationEnabled()) {
+populateWithGenerated(patterns);
+if (isNotDisabled("SplitOfConst")) {
+patterns.insert<SplitOfConst>(context);
+}
+}

if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns))))
signalPassFailure();
}
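
When the gate is closed the pattern set stays empty, so `applyPatternsAndFoldGreedily` still runs (the greedy driver still folds ops) but none of the constant-propagation rewrites can fire. A sketch of that idiom, assuming only standard MLIR APIs rather than the onnx-mlir sources:

```cpp
// Gated pattern population; stand-alone against stock MLIR.
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

mlir::LogicalResult runGated(mlir::Operation *op, bool gateOpen) {
  mlir::RewritePatternSet patterns(op->getContext());
  if (gateOpen) {
    // populateWithGenerated(patterns); // TableGen'd const-prop patterns
    // patterns.insert<SplitOfConst>(op->getContext());
  }
  // With an empty set this only performs op folding / DCE.
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```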

} // end anonymous namespace.

-void onnx_mlir::configureConstPropONNXToONNXPass(
-int expansionBound, ArrayRef<std::string> disabledPatterns) {
+void onnx_mlir::configureConstPropONNXToONNXPass(int expansionBound,
+ArrayRef<std::string> disabledPatterns, bool constantPropIsEnabled) {
ConstPropONNXToONNXPassConfiguration::expansionBound = expansionBound;
ConstPropONNXToONNXPassConfiguration::disabledPatterns.insert(
disabledPatterns.begin(), disabledPatterns.end());
+ConstPropONNXToONNXPassConfiguration::constantPropIsEnabled =
+constantPropIsEnabled;
}

/*!
4 changes: 2 additions & 2 deletions test/backend/inference_backend.py
@@ -774,10 +774,10 @@ def get_test_models():
# ==LIM== do_not_keep_dim not supported.
#"test_reduce_log_sum_desc_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
#"test_reduce_log_sum_asc_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_reduce_log_sum_default_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
#"test_reduce_log_sum_default_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_reduce_log_sum_negative_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# Name changed in v13
"test_reduce_log_sum_default_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# "test_reduce_log_sum_default_expanded_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

# ==OP== ReduceL1
# ==MIN== 13
2 changes: 1 addition & 1 deletion (NNPA lit test; filename not captured)
@@ -1,5 +1,5 @@
// RUN: onnx-mlir-opt --maccel=NNPA --shape-inference --rewrite-onnx-for-zhigh %s -split-input-file | FileCheck %s
-// RUN: onnx-mlir-opt --maccel=NNPA --rewrite-onnx-for-zhigh --shape-inference --canonicalize --constprop-onnx --shape-inference %s --split-input-file | FileCheck --check-prefix=CONSTPROP %s
+// RUN: onnx-mlir-opt --maccel=NNPA --rewrite-onnx-for-zhigh --shape-inference --canonicalize --constprop-onnx --enable-constant-prop=true --shape-inference %s --split-input-file | FileCheck --check-prefix=CONSTPROP %s

// -----

2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_constprop.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --shape-inference --constprop-onnx --enable-constant-prop=true %s -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
// Common tests. Use ONNXAddOp as example.
2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_constprop_expansion_bound.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --constprop-onnx --onnx-const-prop-expansion-bound=2 %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --constprop-onnx --enable-constant-prop=true --onnx-const-prop-expansion-bound=2 %s -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
// Constant propagate ONNXAddOp only if expansion bound satisfied
2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_constprop_no_shape_inference.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --decompose-onnx --constprop-onnx %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --decompose-onnx --constprop-onnx --enable-constant-prop=true %s -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
/// Split tests
2 changes: 1 addition & 1 deletion test/mlir/onnx/onnx_simplify_shape_related_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --simplify-shape-related-ops-onnx %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --enable-constant-prop=true --simplify-shape-related-ops-onnx %s -split-input-file | FileCheck %s

func.func @test_shape_to_dim(%arg0: tensor<?x256xi64>) -> (tensor<2xi64>) {
%0 = "onnx.Shape"(%arg0) : (tensor<?x256xi64>) -> tensor<2xi64>
