From 4fb146a587191df4b609b86f8ffb61f68c78ed02 Mon Sep 17 00:00:00 2001
From: Soren Lassen
Date: Fri, 22 Sep 2023 00:08:18 -0700
Subject: [PATCH] add decomposition to hybrid pass (#2496)

Signed-off-by: Soren Lassen
---
 src/Transform/ONNX/Decompose.cpp          | 120 ++++++++++--------
 src/Transform/ONNX/Decompose.hpp          |  15 +++
 src/Transform/ONNX/Decompose.td           |   9 +-
 .../ONNX/ONNXHybridTransformPass.cpp      |  19 ++-
 test/mlir/onnx/onnx_hybrid_transform.mlir | 103 ++++++++++++++-
 5 files changed, 199 insertions(+), 67 deletions(-)
 create mode 100644 src/Transform/ONNX/Decompose.hpp

diff --git a/src/Transform/ONNX/Decompose.cpp b/src/Transform/ONNX/Decompose.cpp
index 413674c845..a4fabb833e 100644
--- a/src/Transform/ONNX/Decompose.cpp
+++ b/src/Transform/ONNX/Decompose.cpp
@@ -20,6 +20,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/Transform/ONNX/Decompose.hpp"
+#include "src/Pass/Passes.hpp"
+
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -29,7 +32,6 @@
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
 #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
-#include "src/Pass/Passes.hpp"
 #include "src/Support/TypeUtilities.hpp"
 #include "src/Transform/ONNX/DecomposeEinsum.hpp"
 
@@ -326,6 +328,17 @@ bool hasStaticSpatialDims(Value v) {
   return !llvm::any_of(Ds, ShapedType::isDynamic);
 }
 
+bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
+#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE
+  ONNXConvTransposeOp op =
+      cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp());
+  return hasStaticSpatialDims(op.getX()) && hasStaticSpatialDims(op.getW());
+#else
+  // Disable the ONNXConvTransposeOp decomposition patterns.
+  return false;
+#endif
+}
+
 // Split on the specified axis. The length of each output is one.
 ValueRange emitSplitAxisOutputLength1(
     PatternRewriter &rewriter, Location loc, Value input, int64_t axis) {
@@ -445,6 +458,9 @@ Value insertAdditionalPadsConvTranspose(PatternRewriter &rewriter, Location loc,
 namespace {
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Transform/ONNX/ONNXDecompose.inc"
+
+#ifdef ONNX_MLIR_ENABLE_MHLO
+
 RankedTensorType createResultType(
     Type outputType, int64_t axisValue, bool keepDims) {
   RankedTensorType outputShapeType = outputType.dyn_cast<RankedTensorType>();
@@ -465,26 +481,18 @@ RankedTensorType createResultType(
   return resultType;
 }
 
-struct SoftmaxPattern : public ConversionPattern {
-  SoftmaxPattern(MLIRContext *context)
-      : ConversionPattern(ONNXSoftmaxOp::getOperationName(), 1, context) {}
-  LogicalResult matchAndRewrite(Operation *op0, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const final {
-    // Variables for capturing values and attributes used while creating ops.
-    IntegerAttr axis;
+struct SoftmaxPattern : public OpRewritePattern<ONNXSoftmaxOp> {
+  using OpRewritePattern<ONNXSoftmaxOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(
+      ONNXSoftmaxOp softmaxOp, PatternRewriter &rewriter) const final {
     // Match
-    ONNXSoftmaxOp softmaxOp = ::llvm::dyn_cast<ONNXSoftmaxOp>(op0);
     Value input = softmaxOp.getInput();
     Type inputType = input.getType();
-    axis = op0->getAttrOfType<IntegerAttr>("axis");
-    if (!axis)
-      axis = rewriter.getIntegerAttr(
-          rewriter.getIntegerType(64, /*isSigned=*/true), -1);
-    int64_t axisValue = axis.getSInt();
+    int64_t axisValue = softmaxOp.getAxis();
     // Rewrite
-    Location odsLoc = op0->getLoc();
+    Location odsLoc = softmaxOp.getLoc();
     onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
         rewriter, odsLoc);
@@ -506,16 +514,16 @@ struct SoftmaxPattern : public ConversionPattern {
         /*axis=*/axisOp, keepDimsAttr, noopWithEmptyAxes);
     Value divValue =
         rewriter.create<ONNXDivOp>(odsLoc, inputType, expValue, sumValue);
-    rewriter.replaceOp(op0, divValue);
+    rewriter.replaceOp(softmaxOp, divValue);
     return success();
   }
 };
 
-#ifdef ONNX_MLIR_ENABLE_MHLO
 void populateDecomposingONNXBeforeMhloPatterns(
     RewritePatternSet &patterns, MLIRContext *ctx) {
   patterns.add<SoftmaxPattern>(ctx);
 }
+
 #endif
 
 // Special Op fusion for the following pattern:
@@ -528,11 +536,11 @@ void populateDecomposingONNXBeforeMhloPatterns(
 
 // Helper function: is the ConcatOp matched to the fusion pattern?
 static bool isConcatFuseMatched(
-    Operation *op, ONNXShapeOp &shapeOp, ONNXTransposeOp &transposeOp) {
-  shapeOp = NULL;
-  transposeOp = NULL;
+    ONNXConcatOp concatOp, ONNXShapeOp &shapeOp, ONNXTransposeOp &transposeOp) {
+  shapeOp = nullptr;
+  transposeOp = nullptr;
   bool failed = false;
-  for (Operation *user : op->getUsers()) {
+  for (Operation *user : concatOp->getUsers()) {
     if (isa<ONNXShapeOp>(user) && !shapeOp)
      shapeOp = cast<ONNXShapeOp>(user);
    else if (isa<ONNXTransposeOp>(user) && !transposeOp)
@@ -543,18 +551,15 @@ static bool isConcatFuseMatched(
   return (shapeOp && transposeOp && !failed);
 }
 
-struct ConcatFusePattern : public ConversionPattern {
-  ConcatFusePattern(MLIRContext *context)
-      : ConversionPattern(ONNXConcatOp::getOperationName(), 4, context) {}
-  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const final {
-
-    ONNXConcatOp concatOp = ::llvm::dyn_cast<ONNXConcatOp>(op);
+struct ConcatFusePattern : public OpRewritePattern<ONNXConcatOp> {
+  using OpRewritePattern<ONNXConcatOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(
+      ONNXConcatOp concatOp, PatternRewriter &rewriter) const final {
     // Match
-    ONNXShapeOp shapeOp = NULL;
-    ONNXTransposeOp transposeOp = NULL;
-    if (!isConcatFuseMatched(op, shapeOp, transposeOp))
+    ONNXShapeOp shapeOp;
+    ONNXTransposeOp transposeOp;
+    if (!isConcatFuseMatched(concatOp, shapeOp, transposeOp))
       return failure();
 
     // Rewrite
@@ -562,12 +567,13 @@ struct ConcatFusePattern : public ConversionPattern {
     outputTypes.emplace_back(shapeOp.getResult().getType());
     outputTypes.emplace_back(transposeOp.getResult().getType());
 
-    auto fusedV = rewriter.create<ONNXConcatShapeTransposeOp>(op->getLoc(),
-        outputTypes, operands, concatOp.getAxisAttr(), shapeOp.getEndAttr(),
-        shapeOp.getStartAttr(), transposeOp.getPermAttr());
+    auto fusedV = rewriter.create<ONNXConcatShapeTransposeOp>(concatOp.getLoc(),
+        outputTypes, concatOp->getOperands(), concatOp.getAxisAttr(),
+        shapeOp.getEndAttr(), shapeOp.getStartAttr(),
+        transposeOp.getPermAttr());
     rewriter.replaceOp(shapeOp.getOperation(), fusedV.getResults()[0]);
     rewriter.replaceOp(transposeOp.getOperation(), fusedV.getResults()[1]);
-    rewriter.eraseOp(op);
+    rewriter.eraseOp(concatOp);
     return success();
   }
 };
@@ -594,12 +600,12 @@ struct ConcatFusePattern :
public ConversionPattern {
 //           transA = 0 : si64, transB = 1 : si64} :
 //           (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
 // ```
-struct CustomOpFuseMatMulPattern : public OpConversionPattern<ONNXCustomOp> {
-  CustomOpFuseMatMulPattern(MLIRContext *context)
-      : OpConversionPattern(context) {}
-  LogicalResult matchAndRewrite(ONNXCustomOp customOp,
-      ONNXCustomOp::Adaptor adaptor,
-      ConversionPatternRewriter &rewriter) const final {
+
+struct CustomOpFuseMatMulPattern : public OpRewritePattern<ONNXCustomOp> {
+  using OpRewritePattern<ONNXCustomOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
     using namespace onnx_mlir;
     Location loc = customOp.getLoc();
 
@@ -806,8 +812,8 @@ void DecomposeONNXToONNXPass::runOnOperation() {
   target.addIllegalOp();
   target.addIllegalOp();
   target.addDynamicallyLegalOp<ONNXConcatOp>([](ONNXConcatOp op) {
-    ONNXShapeOp shapeOp = NULL;
-    ONNXTransposeOp transposeOp = NULL;
+    ONNXShapeOp shapeOp;
+    ONNXTransposeOp transposeOp;
     return !isConcatFuseMatched(op, shapeOp, transposeOp);
   });
   // Decompose CustomOp FusedMatMul introduced by onnxruntime:
@@ -828,8 +834,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
 #endif
   target.addDynamicallyLegalOp<ONNXConvTransposeOp>(
       [](ONNXConvTransposeOp op) {
-        return !(onnx_mlir::hasStaticSpatialDims(op.getX()) &&
-            onnx_mlir::hasStaticSpatialDims(op.getW()));
+        return !onnx_mlir::shouldDecomposeConvTransposeOp(op);
       });
 #ifdef ONNX_MLIR_ENABLE_MHLO
   }
@@ -837,13 +842,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
 #endif
 
   RewritePatternSet patterns(context);
-  populateWithGenerated(patterns);
-  patterns.insert<onnx_mlir::DecomposeEinsumPattern>(&getContext());
-  patterns.insert<ConcatFusePattern>(&getContext());
-  // Decompose CustomOp FusedMatMul introduced by onnxruntime:
-  // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
-  patterns.insert<CustomOpFuseMatMulPattern>(&getContext());
-
+  onnx_mlir::getDecomposeONNXToONNXPatterns(patterns);
#ifdef ONNX_MLIR_ENABLE_MHLO
   if (this->target == "mhlo") {
     populateDecomposingONNXBeforeMhloPatterns(patterns, context);
@@ -857,14 +856,23 @@ void DecomposeONNXToONNXPass::runOnOperation() {
 }
 
 } // namespace
 
-namespace onnx_mlir {
+void onnx_mlir::getDecomposeONNXToONNXPatterns(
+    mlir::RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  populateWithGenerated(patterns);
+  patterns.insert<onnx_mlir::DecomposeEinsumPattern>(context);
+  patterns.insert<ConcatFusePattern>(context);
+  // Decompose CustomOp FusedMatMul introduced by onnxruntime:
+  // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
+  patterns.insert<CustomOpFuseMatMulPattern>(context);
+
+  // TODO: consider whether to include SoftmaxPattern here
+}
 
 /*!
  * Create a DecomposeONNX pass.
  */
-std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass(
+std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass(
     const std::string &target) {
   return std::make_unique<DecomposeONNXToONNXPass>(target);
 }
-
-} // namespace onnx_mlir
diff --git a/src/Transform/ONNX/Decompose.hpp b/src/Transform/ONNX/Decompose.hpp
new file mode 100644
index 0000000000..d7edb8256e
--- /dev/null
+++ b/src/Transform/ONNX/Decompose.hpp
@@ -0,0 +1,15 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace onnx_mlir {
+
+// Exports the DecomposeONNXToONNXPass patterns. They are all plain rewrite
+// patterns that can be used with any PatternRewriter, not conversion patterns.
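+//
+// A minimal usage sketch (a hypothetical caller, not part of this patch),
+// applying the exported patterns with MLIR's greedy rewrite driver, where
+// `module` stands for any ModuleOp:
+//
+//   RewritePatternSet patterns(module.getContext());
+//   onnx_mlir::getDecomposeONNXToONNXPatterns(patterns);
+//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));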
+void getDecomposeONNXToONNXPatterns(mlir::RewritePatternSet &patterns);
+
+} // namespace onnx_mlir
diff --git a/src/Transform/ONNX/Decompose.td b/src/Transform/ONNX/Decompose.td
index abd9049992..7cfd1b6cda 100644
--- a/src/Transform/ONNX/Decompose.td
+++ b/src/Transform/ONNX/Decompose.td
@@ -474,6 +474,11 @@ def HasUnitStrides: Constraint<
   "has unit strides"
 >;
 
+def ShouldDecomposeConvTransposeOp: Constraint<
+  CPred<"onnx_mlir::shouldDecomposeConvTransposeOp($_self)">,
+  "X and W have static spatial dims and ConvTransposeOp decomposition is enabled"
+>;
+
 def HasStaticSpatialDims: Constraint<
   CPred<"onnx_mlir::hasStaticSpatialDims($_self)">,
   "has static spatial dims"
@@ -496,7 +501,7 @@ def ConvTransposeOpPattern1: Pattern<
                 (GetNullStringAttr), $dilation, $group, $kernel_shape, $new_pads, $strides),
     (insertAdditionalPadsConvTranspose $conv_res, $res)
   ],
-  [(HasUnitStrides:$strides), (HasStaticSpatialDims:$x), (HasStaticSpatialDims:$w)], [],
+  [(ShouldDecomposeConvTransposeOp:$res), (HasUnitStrides:$strides)], [],
   (addBenefit 1)
 >;
 
@@ -513,7 +518,7 @@ def ConvTransposeOpPattern2: Pattern<
     (insertAdditionalPadsConvTranspose $conv_res, $res)
 
   ],
-  [(HasStaticSpatialDims:$x), (HasStaticSpatialDims:$w)], [],
+  [(ShouldDecomposeConvTransposeOp:$res)], [],
   (addBenefit 0)
 >;
 
diff --git a/src/Transform/ONNX/ONNXHybridTransformPass.cpp b/src/Transform/ONNX/ONNXHybridTransformPass.cpp
index 7201f81f6d..d57d5a26e0 100644
--- a/src/Transform/ONNX/ONNXHybridTransformPass.cpp
+++ b/src/Transform/ONNX/ONNXHybridTransformPass.cpp
@@ -5,9 +5,11 @@
 //===------------------ ONNXHybridTransformPass.cpp -----------------------===//
 //
 // Hybrid ONNX transformation pass that combines conversion patterns for
-// shape inference and canonicalization and constant propagation.
+// shape inference, canonicalization, constant propagation, and decomposition.
 //
-// TODO: add decomposition
+// Note that the decomposition patterns are applied "best effort" with a greedy
+// rewrite, not with a partial conversion whose "legalization" would guarantee
+// that every decomposable op is decomposed.
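+//
+// Decomposition can be switched off with the "decomposition" pass option
+// added below, e.g. -onnx-hybrid-transform=decomposition=false, as exercised
+// by the RUN lines in test/mlir/onnx/onnx_hybrid_transform.mlir.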
 //
 //===----------------------------------------------------------------------===//
 
@@ -21,6 +23,7 @@
 #include "src/Interface/ShapeInferenceOpInterface.hpp"
 #include "src/Pass/Passes.hpp"
 #include "src/Transform/ONNX/ConstProp.hpp"
+#include "src/Transform/ONNX/Decompose.hpp"
 #include "src/Transform/ONNX/ShapeInference.hpp"
 
 using namespace mlir;
@@ -60,15 +63,17 @@ struct ONNXHybridTransformPass
       llvm::cl::desc("Enable constant propagation in hybrid transform"),
       llvm::cl::init(true)};
 
+  Option<bool> decomposition{*this, "decomposition",
+      llvm::cl::desc("Enable decomposition in hybrid transform"),
+      llvm::cl::init(true)};
+
   FrozenRewritePatternSet patterns;
 
   ONNXHybridTransformPass() = default;
 
   ONNXHybridTransformPass(const ONNXHybridTransformPass &pass)
       : patterns(pass.patterns) {
-    shapeInference = pass.shapeInference;
-    canonicalization = pass.canonicalization;
-    constantPropagation = pass.constantPropagation;
+    copyOptionValuesFrom(&pass);
   }
 
   StringRef getArgument() const override { return "onnx-hybrid-transform"; }
@@ -92,7 +97,9 @@ struct ONNXHybridTransformPass
       getConstPropONNXToONNXPatterns(cumulativePatterns);
     }
 
-    // TODO: decomposition
+    if (decomposition) {
+      getDecomposeONNXToONNXPatterns(cumulativePatterns);
+    }
 
     patterns = FrozenRewritePatternSet(std::move(cumulativePatterns));
     return success();
diff --git a/test/mlir/onnx/onnx_hybrid_transform.mlir b/test/mlir/onnx/onnx_hybrid_transform.mlir
index d809b7db0b..7e5ca83ae5 100644
--- a/test/mlir/onnx/onnx_hybrid_transform.mlir
+++ b/test/mlir/onnx/onnx_hybrid_transform.mlir
@@ -1,5 +1,6 @@
-// RUN: onnx-mlir-opt --enable-constant-prop=true -onnx-hybrid-transform=constant-propagation=false %s | FileCheck %s
-// RUN: onnx-mlir-opt --enable-constant-prop=true -onnx-hybrid-transform %s | FileCheck --check-prefix=CONSTPROP %s
+// RUN: onnx-mlir-opt --enable-constant-prop=true -onnx-hybrid-transform="constant-propagation=false decomposition=false" %s | FileCheck %s
+// RUN: onnx-mlir-opt --enable-constant-prop=true -onnx-hybrid-transform=constant-propagation=false %s | FileCheck --check-prefix=DECOMPOSE %s
+// RUN: onnx-mlir-opt --enable-constant-prop=true -onnx-hybrid-transform=decomposition=false %s | FileCheck --check-prefix=CONSTPROP %s
 
 // Illustrates the back and forth between shape inference and the
 // BinaryOpBroadcastAxisPattern canonicalization pattern:
@@ -98,7 +99,6 @@ func.func @test_inception_v2_6_snippet(%arg0: tensor<1x3x224x224xf32>, %arg1: te
   %529 = "onnx.Relu"(%528) : (tensor<*xf32>) -> tensor<*xf32>
   return %529 : tensor<*xf32>
 }
-
 // CHECK-LABEL: func.func @test_inception_v2_6_snippet
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x224x224xf32>, [[PARAM_1_:%.+]]: tensor<64x3x7x7xf32>) -> tensor<1x64x28x28xf32> {
 // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2]> : tensor<2xi64>
@@ -196,6 +196,103 @@ func.func @test_inception_v2_6_snippet(%arg0: tensor<1x3x224x224xf32>, %arg1: te
 // CHECK: return [[VAR_87_]] : tensor<1x64x28x28xf32>
 // CHECK: }
 
+// DECOMPOSE-LABEL: func.func @test_inception_v2_6_snippet
+// DECOMPOSE-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x224x224xf32>, [[PARAM_1_:%.+]]: tensor<64x3x7x7xf32>) -> tensor<1x64x28x28xf32> {
+// DECOMPOSE-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2]> : tensor<2xi64>
+// DECOMPOSE-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<9.99999974E-6> : tensor<1xf32>
+// DECOMPOSE-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64>
+// DECOMPOSE-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1.000000e-01> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<2.000000e-01> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<3.000000e-01> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<4.000000e-01> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_7_:%.+]] = onnx.Constant dense<5.000000e-01> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_8_:%.+]] = onnx.Constant dense<6.000000e-01> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_9_:%.+]] = onnx.Constant dense<0.699999988> : tensor<64x64x1x1xf32>
+// DECOMPOSE-DAG: [[VAR_10_:%.+]] = onnx.Constant dense<8.000000e-01> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_11_:%.+]] = onnx.Constant dense<0.899999976> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_12_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_13_:%.+]] = onnx.Constant dense<1.100000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_14_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_15_:%.+]] = onnx.Constant dense<1.300000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_16_:%.+]] = onnx.Constant dense<1.400000e+00> : tensor<192x64x3x3xf32>
+// DECOMPOSE-DAG: [[VAR_17_:%.+]] = onnx.Constant dense<1.500000e+00> : tensor<192xf32>
+// DECOMPOSE-DAG: [[VAR_18_:%.+]] = onnx.Constant dense<1.600000e+00> : tensor<192xf32>
+// DECOMPOSE-DAG: [[VAR_19_:%.+]] = onnx.Constant dense<1.700000e+00> : tensor<192xf32>
+// DECOMPOSE-DAG: [[VAR_20_:%.+]] = onnx.Constant dense<1.800000e+00> : tensor<192xf32>
+// DECOMPOSE-DAG: [[VAR_21_:%.+]] = onnx.Constant dense<1.900000e+00> : tensor<192xf32>
+// DECOMPOSE-DAG: [[VAR_22_:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<192xf32>
+// DECOMPOSE-DAG: [[VAR_23_:%.+]] = onnx.Constant dense<4.200000e+00> : tensor<64x192x1x1xf32>
+// DECOMPOSE-DAG: [[VAR_24_:%.+]] = onnx.Constant dense<4.300000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_25_:%.+]] = onnx.Constant dense<4.400000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_26_:%.+]] = onnx.Constant dense<4.500000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_27_:%.+]] = onnx.Constant dense<4.600000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_28_:%.+]] = onnx.Constant dense<4.700000e+00> : tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_29_:%.+]] = onnx.Constant dense<4.800000e+00> : tensor<64xf32>
+// DECOMPOSE: [[VAR_30_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_1_]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_31_:%.+]] = "onnx.Sqrt"([[VAR_30_]]) : (tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_32_:%.+]] = "onnx.Div"([[VAR_3_]], [[VAR_3_]]1) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_33_:%.+]] = "onnx.Unsqueeze"([[VAR_32_]], [[VAR_2_]]) : (tensor<64xf32>, tensor<3xi64>) -> tensor<64x1x1x1xf32>
+// DECOMPOSE-DAG: [[VAR_34_:%.+]] = "onnx.Mul"([[PARAM_1_]], [[VAR_33_]]) : (tensor<64x3x7x7xf32>, tensor<64x1x1x1xf32>) -> tensor<64x3x7x7xf32>
+// DECOMPOSE-DAG: [[VAR_35_:%.+]] = "onnx.Neg"([[VAR_5_]]) : (tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_36_:%.+]] = "onnx.Mul"([[VAR_32_]], [[VAR_35_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_37_:%.+]] = "onnx.Add"([[VAR_36_]], [[VAR_4_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_38_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_34_]], [[VAR_37_]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
+// DECOMPOSE-DAG: [[VAR_39_:%.+]] = "onnx.Unsqueeze"([[VAR_7_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32>
+// DECOMPOSE-NOT: separator of consecutive DAGs
+// DECOMPOSE-DAG: [[VAR_40_:%.+]] = "onnx.Mul"([[VAR_38_]], [[VAR_39_]]) : (tensor<1x64x112x112xf32>, tensor<64x1x1xf32>) -> tensor<1x64x112x112xf32>
+// DECOMPOSE-DAG: [[VAR_41_:%.+]] = "onnx.Unsqueeze"([[VAR_8_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32>
+// DECOMPOSE: [[VAR_42_:%.+]] = "onnx.Add"([[VAR_40_]], [[VAR_41_]]) : (tensor<1x64x112x112xf32>, tensor<64x1x1xf32>) -> tensor<1x64x112x112xf32>
+// DECOMPOSE: [[VAR_43_:%.+]] = "onnx.Relu"([[VAR_42_]]) : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
+// DECOMPOSE-DAG: [[VAR_44_:%.+]] = "onnx.MaxPoolSingleOut"([[VAR_43_]]) {auto_pad = "NOTSET", ceil_mode = 0 : si64, kernel_shape = [3, 3], pads = [0, 0, 1, 1], storage_order = 0 : si64, strides = [2, 2]} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
+// DECOMPOSE-DAG: [[VAR_45_:%.+]] = "onnx.Add"([[VAR_13_]], [[VAR_1_]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_46_:%.+]] = "onnx.Sqrt"([[VAR_45_]]) : (tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_47_:%.+]] = "onnx.Div"([[VAR_10_]], [[VAR_46_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_48_:%.+]] = "onnx.Unsqueeze"([[VAR_47_]], [[VAR_2_]]) : (tensor<64xf32>, tensor<3xi64>) -> tensor<64x1x1x1xf32>
+// DECOMPOSE-DAG: [[VAR_49_:%.+]] = "onnx.Mul"([[VAR_48_]], [[VAR_9_]]) : (tensor<64x1x1x1xf32>, tensor<64x64x1x1xf32>) -> tensor<64x64x1x1xf32>
+// DECOMPOSE-DAG: [[VAR_50_:%.+]] = "onnx.Neg"([[VAR_12_]]) : (tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_51_:%.+]] = "onnx.Mul"([[VAR_47_]], [[VAR_50_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_52_:%.+]] = "onnx.Add"([[VAR_51_]], [[VAR_11_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_53_:%.+]] = "onnx.Conv"([[VAR_44_]], [[VAR_49_]], [[VAR_52_]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x64x56x56xf32>, tensor<64x64x1x1xf32>, tensor<64xf32>) -> tensor<1x64x56x56xf32>
+// DECOMPOSE-DAG: [[VAR_54_:%.+]] = "onnx.Unsqueeze"([[VAR_14_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32>
+// DECOMPOSE-NOT: separator of consecutive DAGs
+// DECOMPOSE-DAG: [[VAR_55_:%.+]] = "onnx.Mul"([[VAR_53_]], [[VAR_54_]]) : (tensor<1x64x56x56xf32>, tensor<64x1x1xf32>) -> tensor<1x64x56x56xf32>
+// DECOMPOSE-DAG: [[VAR_56_:%.+]] = "onnx.Unsqueeze"([[VAR_15_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32>
+// DECOMPOSE: [[VAR_57_:%.+]] = "onnx.Add"([[VAR_55_]], [[VAR_56_]]) : (tensor<1x64x56x56xf32>, tensor<64x1x1xf32>) -> tensor<1x64x56x56xf32>
+// DECOMPOSE-DAG: [[VAR_58_:%.+]] = "onnx.Relu"([[VAR_57_]]) : (tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+// DECOMPOSE-DAG: [[VAR_59_:%.+]] = "onnx.Add"([[VAR_20_]], [[VAR_1_]]) : (tensor<192xf32>, tensor<1xf32>) -> tensor<192xf32>
+// DECOMPOSE: [[VAR_60_:%.+]] = "onnx.Sqrt"([[VAR_59_]]) : (tensor<192xf32>) -> tensor<192xf32>
+// DECOMPOSE: [[VAR_61_:%.+]] = "onnx.Div"([[VAR_17_]], [[VAR_60_]]) : (tensor<192xf32>, tensor<192xf32>) -> tensor<192xf32>
+// DECOMPOSE: [[VAR_62_:%.+]] = "onnx.Unsqueeze"([[VAR_61_]], [[VAR_2_]]) : (tensor<192xf32>, tensor<3xi64>) -> tensor<192x1x1x1xf32>
+// DECOMPOSE-DAG: [[VAR_63_:%.+]] = "onnx.Mul"([[VAR_62_]], [[VAR_16_]]) : (tensor<192x1x1x1xf32>, tensor<192x64x3x3xf32>) -> tensor<192x64x3x3xf32>
+// DECOMPOSE-DAG: [[VAR_64_:%.+]] = "onnx.Neg"([[VAR_19_]]) : (tensor<192xf32>) -> tensor<192xf32>
+// DECOMPOSE: [[VAR_65_:%.+]] = "onnx.Mul"([[VAR_61_]], [[VAR_64_]]) : (tensor<192xf32>, tensor<192xf32>) -> tensor<192xf32>
+// DECOMPOSE: [[VAR_66_:%.+]] = "onnx.Add"([[VAR_65_]], [[VAR_18_]]) : (tensor<192xf32>, tensor<192xf32>) -> tensor<192xf32>
+// DECOMPOSE-DAG: [[VAR_67_:%.+]] = "onnx.Conv"([[VAR_58_]], [[VAR_63_]], [[VAR_66_]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]} : (tensor<1x64x56x56xf32>, tensor<192x64x3x3xf32>, tensor<192xf32>) -> tensor<1x192x56x56xf32>
+// DECOMPOSE-DAG: [[VAR_68_:%.+]] = "onnx.Unsqueeze"([[VAR_21_]], [[VAR_0_]]) : (tensor<192xf32>, tensor<2xi64>) -> tensor<192x1x1xf32>
+// DECOMPOSE-NOT: separator of consecutive DAGs
+// DECOMPOSE-DAG: [[VAR_69_:%.+]] = "onnx.Mul"([[VAR_67_]], [[VAR_68_]]) : (tensor<1x192x56x56xf32>, tensor<192x1x1xf32>) -> tensor<1x192x56x56xf32>
+// DECOMPOSE-DAG: [[VAR_70_:%.+]] = "onnx.Unsqueeze"([[VAR_22_]], [[VAR_0_]]) : (tensor<192xf32>, tensor<2xi64>) -> tensor<192x1x1xf32>
+// DECOMPOSE: [[VAR_71_:%.+]] = "onnx.Add"([[VAR_69_]], [[VAR_70_]]) : (tensor<1x192x56x56xf32>, tensor<192x1x1xf32>) -> tensor<1x192x56x56xf32>
+// DECOMPOSE: [[VAR_72_:%.+]] = "onnx.Relu"([[VAR_71_]]) : (tensor<1x192x56x56xf32>) -> tensor<1x192x56x56xf32>
+// DECOMPOSE-DAG: [[VAR_73_:%.+]] = "onnx.MaxPoolSingleOut"([[VAR_72_]]) {auto_pad = "NOTSET", ceil_mode = 0 : si64, kernel_shape = [3, 3], pads = [0, 0, 1, 1], storage_order = 0 : si64, strides = [2, 2]} : (tensor<1x192x56x56xf32>) -> tensor<1x192x28x28xf32>
+// DECOMPOSE-DAG: [[VAR_74_:%.+]] = "onnx.Add"([[VAR_27_]], [[VAR_1_]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_75_:%.+]] = "onnx.Sqrt"([[VAR_74_]]) : (tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_76_:%.+]] = "onnx.Div"([[VAR_24_]], [[VAR_75_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_77_:%.+]] = "onnx.Unsqueeze"([[VAR_76_]], [[VAR_2_]]) : (tensor<64xf32>, tensor<3xi64>) -> tensor<64x1x1x1xf32>
+// DECOMPOSE-DAG: [[VAR_78_:%.+]] = "onnx.Mul"([[VAR_77_]], [[VAR_23_]]) : (tensor<64x1x1x1xf32>, tensor<64x192x1x1xf32>) -> tensor<64x192x1x1xf32>
+// DECOMPOSE-DAG: [[VAR_79_:%.+]] = "onnx.Neg"([[VAR_26_]]) : (tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_80_:%.+]] = "onnx.Mul"([[VAR_76_]], [[VAR_79_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE: [[VAR_81_:%.+]] = "onnx.Add"([[VAR_80_]], [[VAR_25_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>
+// DECOMPOSE-DAG: [[VAR_82_:%.+]] = "onnx.Conv"([[VAR_73_]], [[VAR_78_]], [[VAR_81_]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x192x28x28xf32>, tensor<64x192x1x1xf32>, tensor<64xf32>) -> tensor<1x64x28x28xf32>
+// DECOMPOSE-DAG: [[VAR_83_:%.+]] = "onnx.Unsqueeze"([[VAR_28_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32>
+// DECOMPOSE-NOT: separator of consecutive DAGs
+// DECOMPOSE-DAG: [[VAR_84_:%.+]] = "onnx.Mul"([[VAR_82_]], [[VAR_83_]]) : (tensor<1x64x28x28xf32>, tensor<64x1x1xf32>) -> tensor<1x64x28x28xf32>
+// DECOMPOSE-DAG: [[VAR_85_:%.+]] = "onnx.Unsqueeze"([[VAR_29_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32>
+// DECOMPOSE: [[VAR_86_:%.+]] = "onnx.Add"([[VAR_84_]], [[VAR_85_]]) : (tensor<1x64x28x28xf32>, tensor<64x1x1xf32>) -> tensor<1x64x28x28xf32>
+// DECOMPOSE: [[VAR_87_:%.+]] = "onnx.Relu"([[VAR_86_]]) : (tensor<1x64x28x28xf32>) -> tensor<1x64x28x28xf32>
+// DECOMPOSE: return [[VAR_87_]] : tensor<1x64x28x28xf32>
+// DECOMPOSE: }
+
 // CONSTPROP-LABEL: func.func @test_inception_v2_6_snippet
 // CONSTPROP-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x224x224xf32>, [[PARAM_1_:%.+]]: tensor<64x3x7x7xf32>) -> tensor<1x64x28x28xf32> {
 // CONSTPROP-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<4.800000e+00> : tensor<64x1x1xf32>