Skip to content

Commit

Permalink
constant fold for IfOp (#2544)
Browse files Browse the repository at this point in the history
* code

Signed-off-by: chentong319 <[email protected]>

* test

Signed-off-by: chentong319 <[email protected]>

---------

Signed-off-by: chentong319 <[email protected]>
  • Loading branch information
chentong319 authored Oct 2, 2023
1 parent ded4d47 commit 554acae
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/Transform/ONNX/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,42 @@ class SplitOfConst : public OpRewritePattern<ONNXSplitOp> {
}
};

// Constant-folds ONNXIfOp when its condition is a compile-time dense
// constant: the branch that would execute (then/else) is inlined in front of
// the IfOp, and the IfOp's results are replaced by that branch's yielded
// values. The untaken branch is discarded along with the IfOp itself.
class IfOfConst : public OpRewritePattern<ONNXIfOp> {
public:
  using OpRewritePattern<ONNXIfOp>::OpRewritePattern;

  // Match only when the condition value is a dense ONNX constant whose
  // contents we can read at compile time.
  LogicalResult match(ONNXIfOp ifOp) const override {
    if (!isDenseONNXConstant(ifOp.getCond()))
      return failure();
    return success();
  }

  void rewrite(ONNXIfOp ifOp, PatternRewriter &rewriter) const override {
    Value cond = ifOp.getCond();
    ElementsAttr condElements = getConstValueElements(cond);
    // cond is a boolean tensor; only its first element selects the branch.
    auto condValues = condElements.getValues<bool>();
    Region *region =
        condValues[0] ? &ifOp.getThenBranch() : &ifOp.getElseBranch();

    assert(
        region->hasOneBlock() && "Then/Else region should have only one block");

    // Capture the yielded values before erasing the terminator; they become
    // the replacements for the IfOp's results.
    Operation *yieldOp = region->front().getTerminator();
    ValueRange yields = yieldOp->getOperands();
    SmallVector<Value, 4> outputs(yields.begin(), yields.end());
    // Split at begin() so newBlock owns the whole branch body, then move that
    // body in front of the IfOp.
    Block *newBlock =
        rewriter.splitBlock(&region->front(), region->front().begin());

    rewriter.eraseOp(yieldOp);
    rewriter.inlineBlockBefore(newBlock, ifOp);
    rewriter.replaceOp(ifOp, outputs);
  }
};

//===----------------------------------------------------------------------===//
// Code to manage the pass.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1058,6 +1094,7 @@ void onnx_mlir::getConstPropONNXToONNXPatterns(RewritePatternSet &patterns) {
populateWithGenerated(patterns);
if (isNotDisabled("SplitOfConst"))
patterns.insert<SplitOfConst>(patterns.getContext());
patterns.insert<IfOfConst>(patterns.getContext());
}

void onnx_mlir::configureConstPropONNXToONNXPass(int expansionBound,
Expand Down
44 changes: 44 additions & 0 deletions test/mlir/onnx/onnx_constprop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1916,3 +1916,47 @@ func.func @test_sum_3_inputs() -> tensor<2x2xf32> {
// CHECK: onnx.Return [[VAR_0_]] : tensor<2x2xf32>
// CHECK: }
}

// -----

// Constant-true condition: the If must fold away, inlining the then branch
// (the Squeeze) and dropping the else branch entirely.
func.func @test_if_true(%arg0 : tensor<*xf16>, %arg1 : tensor<1xi64>, %arg2 : tensor<*xf16>) -> tensor<?xi64> {
  %487 = onnx.Constant dense<true> : tensor<1xi1>
  %488 = "onnx.If"(%487) ({
    %6277 = onnx.Constant dense<1> : tensor<1xi64>
    %6278 = "onnx.Squeeze"(%arg0, %arg1) : (tensor<*xf16>, tensor<1xi64>) -> tensor<?x?x?xf16>
    onnx.Yield %6278 : tensor<?x?x?xf16>
  }, {
    %6277 = "onnx.Identity"(%arg2) : (tensor<*xf16>) -> tensor<?x?x?x?xf16>
    onnx.Yield %6277 : tensor<?x?x?x?xf16>
  }) : (tensor<1xi1>) -> tensor<*xf16>
  %490 = "onnx.Shape"(%488) { start = 0 : si64} : (tensor<*xf16>) -> tensor<?xi64>
  onnx.Return %490 : tensor<?xi64>
}
// CHECK-LABEL: func.func @test_if_true
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf16>, [[PARAM_1_:%.+]]: tensor<1xi64>, [[PARAM_2_:%.+]]: tensor<*xf16>) -> tensor<?xi64> {
// CHECK: [[VAR_0_:%.+]] = "onnx.Squeeze"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<*xf16>, tensor<1xi64>) -> tensor<?x?x?xf16>
// CHECK: [[VAR_1_:%.+]] = "onnx.Shape"([[VAR_0_]]) {start = 0 : si64} : (tensor<?x?x?xf16>) -> tensor<?xi64>
// CHECK: onnx.Return [[VAR_1_]] : tensor<?xi64>
// CHECK: }

// -----

// Constant-false condition: the If must fold away, inlining the else branch
// (the Identity) and dropping the then branch entirely.
func.func @test_if_false(%arg0 : tensor<*xf16>, %arg1 : tensor<1xi64>, %arg2 : tensor<*xf16>) -> tensor<?xi64> {
  %487 = onnx.Constant dense<false> : tensor<1xi1>
  %488 = "onnx.If"(%487) ({
    %6277 = onnx.Constant dense<1> : tensor<1xi64>
    %6278 = "onnx.Squeeze"(%arg0, %arg1) : (tensor<*xf16>, tensor<1xi64>) -> tensor<?x?x?xf16>
    onnx.Yield %6278 : tensor<?x?x?xf16>
  }, {
    %6277 = "onnx.Identity"(%arg2) : (tensor<*xf16>) -> tensor<?x?x?x?xf16>
    onnx.Yield %6277 : tensor<?x?x?x?xf16>
  }) : (tensor<1xi1>) -> tensor<*xf16>
  %490 = "onnx.Shape"(%488) { start = 0 : si64} : (tensor<*xf16>) -> tensor<?xi64>
  onnx.Return %490 : tensor<?xi64>
}
// CHECK-LABEL: func.func @test_if_false
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf16>, [[PARAM_1_:%.+]]: tensor<1xi64>, [[PARAM_2_:%.+]]: tensor<*xf16>) -> tensor<?xi64> {
// CHECK: [[VAR_0_:%.+]] = "onnx.Identity"([[PARAM_2_]]) : (tensor<*xf16>) -> tensor<?x?x?x?xf16>
// CHECK: [[VAR_1_:%.+]] = "onnx.Shape"([[VAR_0_]]) {start = 0 : si64} : (tensor<?x?x?x?xf16>) -> tensor<?xi64>
// CHECK: onnx.Return [[VAR_1_]] : tensor<?xi64>
// CHECK: }

0 comments on commit 554acae

Please sign in to comment.