diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 531971d3db5d..db70ef6a9cec 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -1158,17 +1158,51 @@ class ExprPatternRewriter : ExprMutator { Expr VisitExpr(const Expr& expr) override { auto node = ExprMutator::VisitExpr(expr); - if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) { - Expr rewritten_expr = rewriter_func_(node, matches_opt.value()); - if (!rewritten_expr.same_as(node)) { - return builder_->Normalize(rewritten_expr); - } + std::vector matches_top_level; + if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { + return builder_->Normalize(rewritten.value()); } return node; } private: + Optional TryRewrite(const Expr& expr, const DFPattern& pattern, + std::vector* matches_top_level) { + ICHECK(matches_top_level); + + // Special handling if the user-supplied pattern is a `OrPattern`. + // While the `ExtractMatchedExpr` can handle matching the + // `OrPattern`, it will return on the first match, even if the + // `rewriter_func_` doesn't apply a replacement. Unpacking the + // `OrPattern` here allows the match to be resumed if + // `rewriter_func_` returns the original function unmodified. + // This is only valid for a top-level match. + if (auto or_pattern = pattern.as()) { + matches_top_level->push_back(pattern); + Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); + if (!output.defined()) { + output = TryRewrite(expr, or_pattern->right, matches_top_level); + } + matches_top_level->pop_back(); + return output; + } + + if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { + auto matches = opt_matches.value(); + for (const auto& pat : *matches_top_level) { + matches.Set(pat, expr); + } + + Expr rewritten_expr = rewriter_func_(expr, matches); + if (!rewritten_expr.same_as(expr)) { + return builder_->Normalize(rewritten_expr); + } + } + + return NullOpt; + } + /*! \brief The pattern for rewriting call nodes */ DFPattern pattern_; /*! diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 583e2a8d0822..81cd8da7fe71 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1889,5 +1889,68 @@ def expected(): tvm.ir.assert_structural_equal(expected, after) +def test_backtrack_if_rewriter_returns_no_op(): + """Rewriter participates in the pattern matching + + Sometimes, the pattern-matching syntax is insufficient to check if + a replacement may be performed. In this case, the `rewriter` + function may perform additional validation. If this validation + fails, the `rewriter` function can return the original expression, + and no replacement is performed. + + In addition, when the `rewriter` returns the original expression, + the pattern match should backtrack to determine if another branch + of the match may have produced a replacement. + + This functionality allows pattern replacements to be composed. + """ + + pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard()) + + pat_arg = wildcard() + pat_zeros = is_op("relax.zeros")(wildcard()) + pat_add = is_op("relax.add")(pat_arg, pat_zeros) + + # OR conditions are checked in the order that they occur. Because + # `pat_match_no_rewrite` is a superset of `pat_add`, it will + # always match first. + pat = pat_match_no_rewrite | pat_add + + def rewriter(expr, matches): + if pat_match_no_rewrite in matches: + # This branch simulates a rewrite whose precondition has + # failed. If the pattern-matching treats this as a + # successful match with no replacemen required, then no + # rewrite would be performed. On the other hand, if the + # pattern-matching treats this as an unsuccessful match, + # then it can backtrack and attempt `pat_add` instead. + return expr + elif pat_add in matches: + return matches[pat_arg] + else: + raise RuntimeError("Pattern matched, but neither branch matched") + + @R.function(private=True) + def before(): + with R.dataflow(): + A = R.ones([64, 128], "int32") + B = R.zeros([64, 128], "int32") + C = R.add(A, B) + + R.output(C) + return C + + @R.function(private=True) + def expected(): + with R.dataflow(): + C = R.ones([64, 128], "int32") + + R.output(C) + return C + + after = rewrite_call(pat, rewriter, before) + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main()