diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp
index df7d388fa9..da45846348 100644
--- a/src/Transform/ONNX/ConstProp.cpp
+++ b/src/Transform/ONNX/ConstProp.cpp
@@ -1023,6 +1023,42 @@ class SplitOfConst : public OpRewritePattern<ONNXSplitOp> {
   }
 };
 
+class IfOfConst : public OpRewritePattern<ONNXIfOp> {
+public:
+  using OpRewritePattern<ONNXIfOp>::OpRewritePattern;
+
+  LogicalResult match(ONNXIfOp ifOp) const override {
+    if (!isDenseONNXConstant(ifOp.getCond()))
+      return failure();
+    return success();
+  }
+
+  void rewrite(ONNXIfOp ifOp, PatternRewriter &rewriter) const override {
+    Value cond = ifOp.getCond();
+    ElementsAttr condElements = getConstValueElements(cond);
+    auto splitValues = condElements.getValues<bool>();
+    Region *region;
+    if (splitValues[0] == 0) {
+      region = &ifOp.getElseBranch();
+    } else {
+      region = &ifOp.getThenBranch();
+    }
+
+    assert(
+        region->hasOneBlock() && "Then/Else region should have only one block");
+
+    Operation *yieldOp = region->front().getTerminator();
+    ValueRange yields = yieldOp->getOperands();
+    SmallVector<Value> outputs(yields.begin(), yields.end());
+    Block *newBlock =
+        rewriter.splitBlock(&region->front(), region->front().begin());
+
+    rewriter.eraseOp(yieldOp);
+    rewriter.inlineBlockBefore(newBlock, ifOp);
+    rewriter.replaceOp(ifOp, outputs);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Code to manage the pass.
 //===----------------------------------------------------------------------===//
@@ -1058,6 +1094,7 @@ void onnx_mlir::getConstPropONNXToONNXPatterns(RewritePatternSet &patterns) {
   populateWithGenerated(patterns);
   if (isNotDisabled("SplitOfConst"))
     patterns.insert<SplitOfConst>(patterns.getContext());
+  patterns.insert<IfOfConst>(patterns.getContext());
 }
 
 void onnx_mlir::configureConstPropONNXToONNXPass(int expansionBound,
diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir
index 7c68518db8..19517c08c6 100644
--- a/test/mlir/onnx/onnx_constprop.mlir
+++ b/test/mlir/onnx/onnx_constprop.mlir
@@ -1916,3 +1916,47 @@ func.func @test_sum_3_inputs() -> tensor<2x2xf32> {
 // CHECK:           onnx.Return [[VAR_0_]] : tensor<2x2xf32>
 // CHECK:         }
 }
+
+// -----
+
+func.func @test_if_true(%arg0 : tensor<*xf16>, %arg1 : tensor<1xi64>, %arg2 : tensor<*xf16>) -> tensor<?xi64> {
+  %487 = onnx.Constant dense<true> : tensor<1xi1>
+  %488 = "onnx.If"(%487) ({
+    %6277 = onnx.Constant dense<1> : tensor<1xi64>
+    %6278 = "onnx.Squeeze"(%arg0, %arg1) : (tensor<*xf16>, tensor<1xi64>) -> tensor<?xf16>
+    onnx.Yield %6278 : tensor<?xf16>
+  }, {
+    %6277 = "onnx.Identity"(%arg2) : (tensor<*xf16>) -> tensor<?xf16>
+    onnx.Yield %6277 : tensor<?xf16>
+  }) : (tensor<1xi1>) -> tensor<*xf16>
+  %490 = "onnx.Shape"(%488) { start = 0 : si64} : (tensor<*xf16>) -> tensor<?xi64>
+  onnx.Return %490 : tensor<?xi64>
+}
+// CHECK-LABEL:  func.func @test_if_true
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<*xf16>, [[PARAM_1_:%.+]]: tensor<1xi64>, [[PARAM_2_:%.+]]: tensor<*xf16>) -> tensor<?xi64> {
+// CHECK:           [[VAR_0_:%.+]] = "onnx.Squeeze"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<*xf16>, tensor<1xi64>) -> tensor<?xf16>
+// CHECK:           [[VAR_1_:%.+]] = "onnx.Shape"([[VAR_0_]]) {start = 0 : si64} : (tensor<?xf16>) -> tensor<?xi64>
+// CHECK:           onnx.Return [[VAR_1_]] : tensor<?xi64>
+// CHECK:         }
+
+// -----
+
+func.func @test_if_false(%arg0 : tensor<*xf16>, %arg1 : tensor<1xi64>, %arg2 : tensor<*xf16>) -> tensor<?xi64> {
+  %487 = onnx.Constant dense<false> : tensor<1xi1>
+  %488 = "onnx.If"(%487) ({
+    %6277 = onnx.Constant dense<1> : tensor<1xi64>
+    %6278 = "onnx.Squeeze"(%arg0, %arg1) : (tensor<*xf16>, tensor<1xi64>) -> tensor<?xf16>
+    onnx.Yield %6278 : tensor<?xf16>
+  }, {
+    %6277 = "onnx.Identity"(%arg2) : (tensor<*xf16>) -> tensor<?xf16>
+    onnx.Yield %6277 : tensor<?xf16>
+  }) : (tensor<1xi1>) -> tensor<*xf16>
+  %490 = "onnx.Shape"(%488) { start = 0 : si64} : (tensor<*xf16>) -> tensor<?xi64>
+  onnx.Return %490 : tensor<?xi64>
+}
+// CHECK-LABEL:  func.func @test_if_false
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<*xf16>, [[PARAM_1_:%.+]]: tensor<1xi64>, [[PARAM_2_:%.+]]: tensor<*xf16>) -> tensor<?xi64> {
+// CHECK:           [[VAR_0_:%.+]] = "onnx.Identity"([[PARAM_2_]]) : (tensor<*xf16>) -> tensor<?xf16>
+// CHECK:           [[VAR_1_:%.+]] = "onnx.Shape"([[VAR_0_]]) {start = 0 : si64} : (tensor<?xf16>) -> tensor<?xi64>
+// CHECK:           onnx.Return [[VAR_1_]] : tensor<?xi64>
+// CHECK:         }