Skip to content

Commit

Permalink
replaced -onnx-hybrid-pass-disable-class by -onnx-hybrid-transform su…
Browse files Browse the repository at this point in the history
…boptions

Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen committed Sep 19, 2023
1 parent 1dbe45a commit 256e3fb
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 50 deletions.
11 changes: 0 additions & 11 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ InstrumentStages instrumentStage; // common for both
int onnxConstPropExpansionBound; // common for both
std::vector<std::string> onnxConstPropDisablePatterns; // common for both
bool enableONNXHybridPass; // common for both
std::vector<std::string> onnxHybridPassDisableClasses; // common for both
std::vector<std::string> functionsToDecompose; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
bool invokeOnnxVersionConverter; // onnx-mlir only
Expand Down Expand Up @@ -178,16 +177,6 @@ static llvm::cl::opt<bool, true> enableONNXHybridPassOpt("onnx-hybrid-pass",
llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::list<std::string, std::vector<std::string>>
onnxHybridPassDisableClassesOpt("onnx-hybrid-pass-disable-class",
llvm::cl::desc("Class of hybrid pass patterns to disable.\n"
"Can be shape-inference, canonicalization, "
"constant-propagation, decomposition.\n"
"Repeat the flag to disable multiple classes."),
llvm::cl::value_desc("class of hybrid pass patterns to disable"),
llvm::cl::location(onnxHybridPassDisableClasses),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::list<std::string, std::vector<std::string>>
functionsToDecomposeOpt("functions-to-decompose",
llvm::cl::desc("Specify ONNX functions to decompose"),
Expand Down
1 change: 0 additions & 1 deletion src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ extern InstrumentStages instrumentStage; // common for both
extern int onnxConstPropExpansionBound; // common for both
extern std::vector<std::string> onnxConstPropDisablePatterns; // common for both
extern bool enableONNXHybridPass; // common for both
extern std::vector<std::string> onnxHybridPassDisableClasses; // common for both
extern std::vector<std::string> functionsToDecompose; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
extern bool invokeOnnxVersionConverter; // onnx-mlir only
Expand Down
1 change: 0 additions & 1 deletion src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ void configurePasses() {
VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
configureConstPropONNXToONNXPass(
onnxConstPropExpansionBound, onnxConstPropDisablePatterns);
configureONNXHybridTransformPass(onnxHybridPassDisableClasses);
configureOnnxToKrnlLoweringPass(optReport == OptReport::Parallel,
enableParallel, optReport == OptReport::Simd, !disableSimdOption);
}
Expand Down
4 changes: 0 additions & 4 deletions src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ std::unique_ptr<mlir::Pass> createSimplifyShapeRelatedOpsPass();
/// Pass for replacing ONNXReturnOp with func::ReturnOp.
std::unique_ptr<mlir::Pass> createStandardFuncReturnPass();

// To configure ConstPropONNXToONNXPass at program start.
void configureONNXHybridTransformPass(
llvm::ArrayRef<std::string> disabledClasses);

/// Pass that combines multiple ONNX dialect transformations,
/// including shape inference.
std::unique_ptr<mlir::Pass> createONNXHybridTransformPass();
Expand Down
57 changes: 26 additions & 31 deletions src/Transform/ONNX/ONNXHybridTransformPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,6 @@ using namespace onnx_mlir;

namespace {

// Populated by configureONNXHybridTransformPass().
struct ONNXHybridTransformPassConfiguration {
static StringSet<> disabledClasses;
};

StringSet<> ONNXHybridTransformPassConfiguration::disabledClasses;

constexpr char kShapeInference[] = "shape-inference";
constexpr char kCanonicalization[] = "canonicalization";
constexpr char kConstantPropagation[] = "constant-propagation";
bool isClass(StringRef s) {
static StringSet<> all{
kShapeInference, kCanonicalization, kConstantPropagation};
return all.contains(s);
}

// The pass combines patterns for shape inference and other ONNX-to-ONNX
// transforms.
//
Expand All @@ -61,27 +45,47 @@ struct ONNXHybridTransformPass
: public PassWrapper<ONNXHybridTransformPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXHybridTransformPass)

Option<bool> shapeInference{*this, "shape-inference",
llvm::cl::desc("Enable shape inference in hybrid transform"),
llvm::cl::init(true)};

Option<bool> canonicalization{*this, "canonicalization",
llvm::cl::desc("Enable canonicalization in hybrid transform"),
llvm::cl::init(true)};

Option<bool> constantPropagation{*this, "constant-propagation",
llvm::cl::desc("Enable constant propagation in hybrid transform"),
llvm::cl::init(true)};

FrozenRewritePatternSet patterns;

ONNXHybridTransformPass() = default;

ONNXHybridTransformPass(const ONNXHybridTransformPass &pass)
: patterns(pass.patterns) {
shapeInference = pass.shapeInference;
canonicalization = pass.canonicalization;
constantPropagation = pass.constantPropagation;
}

StringRef getArgument() const override { return "onnx-hybrid-transform"; }

LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet cumulativePatterns(context);

if (!ONNXHybridTransformPassConfiguration::disabledClasses.contains(
kShapeInference)) {
if (shapeInference) {
getShapeInferencePatterns(cumulativePatterns);
}

if (!ONNXHybridTransformPassConfiguration::disabledClasses.contains(
kCanonicalization)) {
if (canonicalization) {
// canonicalization (copied from mlir/lib/Transforms/Canonicalizer.cpp)
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(cumulativePatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(cumulativePatterns, context);
}

if (!ONNXHybridTransformPassConfiguration::disabledClasses.contains(
kConstantPropagation)) {
if (constantPropagation) {
getConstPropONNXToONNXPatterns(cumulativePatterns);
}

Expand All @@ -103,19 +107,10 @@ struct ONNXHybridTransformPass

inferFunctionReturnShapes(f);
}

FrozenRewritePatternSet patterns;
};

} // namespace

void onnx_mlir::configureONNXHybridTransformPass(
ArrayRef<std::string> disabledClasses) {
assert(llvm::all_of(disabledClasses, isClass));
ONNXHybridTransformPassConfiguration::disabledClasses.insert(
disabledClasses.begin(), disabledClasses.end());
}

std::unique_ptr<mlir::Pass> onnx_mlir::createONNXHybridTransformPass() {
return std::make_unique<ONNXHybridTransformPass>();
}
4 changes: 2 additions & 2 deletions test/mlir/onnx/onnx_hybrid_transform.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: onnx-mlir-opt -onnx-hybrid-transform -onnx-hybrid-pass-disable-class=constant-propagation %s -split-input-file | FileCheck %s
// RUN: onnx-mlir-opt -onnx-hybrid-transform %s -split-input-file | FileCheck --check-prefix=CONSTPROP %s
// RUN: onnx-mlir-opt -onnx-hybrid-transform=constant-propagation=false %s | FileCheck %s
// RUN: onnx-mlir-opt -onnx-hybrid-transform %s | FileCheck --check-prefix=CONSTPROP %s

// Illustrates the back and forth between shape inference and the
// BinaryOpBroadcastAxisPattern canonicalization pattern:
Expand Down

0 comments on commit 256e3fb

Please sign in to comment.