From 256e3fb7be1ba2c67defbe333247251fd043d408 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Mon, 18 Sep 2023 21:23:10 -0700 Subject: [PATCH] replaced -onnx-hybrid-pass-disable-class by -onnx-hybrid-transform suboptions Signed-off-by: Soren Lassen --- src/Compiler/CompilerOptions.cpp | 11 ---- src/Compiler/CompilerOptions.hpp | 1 - src/Compiler/CompilerPasses.cpp | 1 - src/Pass/Passes.hpp | 4 -- .../ONNX/ONNXHybridTransformPass.cpp | 57 +++++++++---------- test/mlir/onnx/onnx_hybrid_transform.mlir | 4 +- 6 files changed, 28 insertions(+), 50 deletions(-) diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index f6f07a4a82..48ef99b4ad 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -35,7 +35,6 @@ InstrumentStages instrumentStage; // common for both int onnxConstPropExpansionBound; // common for both std::vector onnxConstPropDisablePatterns; // common for both bool enableONNXHybridPass; // common for both -std::vector onnxHybridPassDisableClasses; // common for both std::vector functionsToDecompose; // common for both EmissionTargetType emissionTarget; // onnx-mlir only bool invokeOnnxVersionConverter; // onnx-mlir only @@ -178,16 +177,6 @@ static llvm::cl::opt enableONNXHybridPassOpt("onnx-hybrid-pass", llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true), llvm::cl::cat(OnnxMlirCommonOptions)); -static llvm::cl::list> - onnxHybridPassDisableClassesOpt("onnx-hybrid-pass-disable-class", - llvm::cl::desc("Class of hybrid pass patterns to disable.\n" - "Can be shape-inference, canonicalization, " - "constant-propagation, decomposition.\n" - "Repeat the flag to disable multiple classes."), - llvm::cl::value_desc("class of hybrid pass patterns to disable"), - llvm::cl::location(onnxHybridPassDisableClasses), - llvm::cl::cat(OnnxMlirCommonOptions)); - static llvm::cl::list> functionsToDecomposeOpt("functions-to-decompose", llvm::cl::desc("Specify ONNX functions to decompose"), diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 004102d44a..5a26885717 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -77,7 +77,6 @@ extern InstrumentStages instrumentStage; // common for both extern int onnxConstPropExpansionBound; // common for both extern std::vector onnxConstPropDisablePatterns; // common for both extern bool enableONNXHybridPass; // common for both -extern std::vector onnxHybridPassDisableClasses; // common for both extern std::vector functionsToDecompose; // common for both extern EmissionTargetType emissionTarget; // onnx-mlir only extern bool invokeOnnxVersionConverter; // onnx-mlir only diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index c577b22e93..f873302cc5 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -47,7 +47,6 @@ void configurePasses() { VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, ""); configureConstPropONNXToONNXPass( onnxConstPropExpansionBound, onnxConstPropDisablePatterns); - configureONNXHybridTransformPass(onnxHybridPassDisableClasses); configureOnnxToKrnlLoweringPass(optReport == OptReport::Parallel, enableParallel, optReport == OptReport::Simd, !disableSimdOption); } diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index b22dc7039f..591c6174bd 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -65,10 +65,6 @@ std::unique_ptr createSimplifyShapeRelatedOpsPass(); /// Pass for replacing ONNXReturnOp with func::ReturnOp. std::unique_ptr createStandardFuncReturnPass(); -// To configure ConstPropONNXToONNXPass at program start. -void configureONNXHybridTransformPass( - llvm::ArrayRef disabledClasses); - /// Pass that combines multiple ONNX dialect transformations, /// including shape inference. std::unique_ptr createONNXHybridTransformPass(); diff --git a/src/Transform/ONNX/ONNXHybridTransformPass.cpp b/src/Transform/ONNX/ONNXHybridTransformPass.cpp index 680cfd17e6..0172d41718 100644 --- a/src/Transform/ONNX/ONNXHybridTransformPass.cpp +++ b/src/Transform/ONNX/ONNXHybridTransformPass.cpp @@ -28,22 +28,6 @@ using namespace onnx_mlir; namespace { -// Populated by configureONNXHybridTransformPass(). -struct ONNXHybridTransformPassConfiguration { - static StringSet<> disabledClasses; -}; - -StringSet<> ONNXHybridTransformPassConfiguration::disabledClasses; - -constexpr char kShapeInference[] = "shape-inference"; -constexpr char kCanonicalization[] = "canonicalization"; -constexpr char kConstantPropagation[] = "constant-propagation"; -bool isClass(StringRef s) { - static StringSet<> all{ - kShapeInference, kCanonicalization, kConstantPropagation}; - return all.contains(s); -} - // The pass combines patterns for shape inference and other ONNX-to-ONNX // transforms. // @@ -61,18 +45,39 @@ struct ONNXHybridTransformPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXHybridTransformPass) + Option shapeInference{*this, "shape-inference", + llvm::cl::desc("Enable shape inference in hybrid transform"), + llvm::cl::init(true)}; + + Option canonicalization{*this, "canonicalization", + llvm::cl::desc("Enable canonicalization in hybrid transform"), + llvm::cl::init(true)}; + + Option constantPropagation{*this, "constant-propagation", + llvm::cl::desc("Enable constant propagation in hybrid transform"), + llvm::cl::init(true)}; + + FrozenRewritePatternSet patterns; + + ONNXHybridTransformPass() = default; + + ONNXHybridTransformPass(const ONNXHybridTransformPass &pass) + : patterns(pass.patterns) { + shapeInference = pass.shapeInference; + canonicalization = pass.canonicalization; + constantPropagation = pass.constantPropagation; + } + StringRef getArgument() const override { return "onnx-hybrid-transform"; } LogicalResult initialize(MLIRContext *context) override { RewritePatternSet cumulativePatterns(context); - if (!ONNXHybridTransformPassConfiguration::disabledClasses.contains( - kShapeInference)) { + if (shapeInference) { getShapeInferencePatterns(cumulativePatterns); } - if (!ONNXHybridTransformPassConfiguration::disabledClasses.contains( - kCanonicalization)) { + if (canonicalization) { // canonicalization (copied from mlir/lib/Transforms/Canonicalizer.cpp) for (auto *dialect : context->getLoadedDialects()) dialect->getCanonicalizationPatterns(cumulativePatterns); @@ -80,8 +85,7 @@ struct ONNXHybridTransformPass op.getCanonicalizationPatterns(cumulativePatterns, context); } - if (!ONNXHybridTransformPassConfiguration::disabledClasses.contains( - kConstantPropagation)) { + if (constantPropagation) { getConstPropONNXToONNXPatterns(cumulativePatterns); } @@ -103,19 +107,10 @@ struct ONNXHybridTransformPass inferFunctionReturnShapes(f); } - - FrozenRewritePatternSet patterns; }; } // namespace -void onnx_mlir::configureONNXHybridTransformPass( - ArrayRef disabledClasses) { - assert(llvm::all_of(disabledClasses, isClass)); - ONNXHybridTransformPassConfiguration::disabledClasses.insert( - disabledClasses.begin(), disabledClasses.end()); -} - std::unique_ptr onnx_mlir::createONNXHybridTransformPass() { return std::make_unique(); } diff --git a/test/mlir/onnx/onnx_hybrid_transform.mlir b/test/mlir/onnx/onnx_hybrid_transform.mlir index 65449ee5cd..a1ad2b538f 100644 --- a/test/mlir/onnx/onnx_hybrid_transform.mlir +++ b/test/mlir/onnx/onnx_hybrid_transform.mlir @@ -1,5 +1,5 @@ -// RUN: onnx-mlir-opt -onnx-hybrid-transform -onnx-hybrid-pass-disable-class=constant-propagation %s -split-input-file | FileCheck %s -// RUN: onnx-mlir-opt -onnx-hybrid-transform %s -split-input-file | FileCheck --check-prefix=CONSTPROP %s +// RUN: onnx-mlir-opt -onnx-hybrid-transform=constant-propagation=false %s | FileCheck %s +// RUN: onnx-mlir-opt -onnx-hybrid-transform %s | FileCheck --check-prefix=CONSTPROP %s // Illustrates the back and forth between shape inference and the // BinaryOpBroadcastAxisPattern canonicalization pattern: