Skip to content

Commit

Permalink
[Relax] Allow composition of DFPattern replacements (#16732)
Browse files Browse the repository at this point in the history
[Relax] Allow composition of DFPattern replacements

The `rewrite_call` function accepts a `DFPattern`, and a function to
rewrite expressions matching that pattern.  Often, the rewriting
function will perform additional validation that cannot be expressed
within the `DFPattern` itself.  If this additional validation fails,
the rewriter function will return the matched expression unmodified.

Prior to this commit, an `OrPattern` that matches on the first branch,
but whose rewriter function does not apply a modification, would
prevent the second branch from being checked.  This commit updates the
`ExprPatternRewriter` to check both branches of a `OrPattern`, if the
rewriter function of the first branch does not modify the result.
  • Loading branch information
Lunderberg authored Mar 27, 2024
1 parent 726a141 commit 86b5a13
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 5 deletions.
44 changes: 39 additions & 5 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1158,17 +1158,51 @@ class ExprPatternRewriter : ExprMutator {
Expr VisitExpr(const Expr& expr) override {
auto node = ExprMutator::VisitExpr(expr);

if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) {
Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
if (!rewritten_expr.same_as(node)) {
return builder_->Normalize(rewritten_expr);
}
std::vector<DFPattern> matches_top_level;
if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) {
return builder_->Normalize(rewritten.value());
}

return node;
}

private:
Optional<Expr> TryRewrite(const Expr& expr, const DFPattern& pattern,
std::vector<DFPattern>* matches_top_level) {
ICHECK(matches_top_level);

// Special handling if the user-supplied pattern is a `OrPattern`.
// While the `ExtractMatchedExpr` can handle matching the
// `OrPattern`, it will return on the first match, even if the
// `rewriter_func_` doesn't apply a replacement. Unpacking the
// `OrPattern` here allows the match to be resumed if
// `rewriter_func_` returns the original function unmodified.
// This is only valid for a top-level match.
if (auto or_pattern = pattern.as<OrPatternNode>()) {
matches_top_level->push_back(pattern);
Optional<Expr> output = TryRewrite(expr, or_pattern->left, matches_top_level);
if (!output.defined()) {
output = TryRewrite(expr, or_pattern->right, matches_top_level);
}
matches_top_level->pop_back();
return output;
}

if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
auto matches = opt_matches.value();
for (const auto& pat : *matches_top_level) {
matches.Set(pat, expr);
}

Expr rewritten_expr = rewriter_func_(expr, matches);
if (!rewritten_expr.same_as(expr)) {
return builder_->Normalize(rewritten_expr);
}
}

return NullOpt;
}

/*! \brief The pattern for rewriting call nodes */
DFPattern pattern_;
/*!
Expand Down
63 changes: 63 additions & 0 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,5 +1889,68 @@ def expected():
tvm.ir.assert_structural_equal(expected, after)


def test_backtrack_if_rewriter_returns_no_op():
"""Rewriter participates in the pattern matching
Sometimes, the pattern-matching syntax is insufficient to check if
a replacement may be performed. In this case, the `rewriter`
function may perform additional validation. If this validation
fails, the `rewriter` function can return the original expression,
and no replacement is performed.
In addition, when the `rewriter` returns the original expression,
the pattern match should backtrack to determine if another branch
of the match may have produced a replacement.
This functionality allows pattern replacements to be composed.
"""

pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())

pat_arg = wildcard()
pat_zeros = is_op("relax.zeros")(wildcard())
pat_add = is_op("relax.add")(pat_arg, pat_zeros)

# OR conditions are checked in the order that they occur. Because
# `pat_match_no_rewrite` is a superset of `pat_add`, it will
# always match first.
pat = pat_match_no_rewrite | pat_add

def rewriter(expr, matches):
if pat_match_no_rewrite in matches:
# This branch simulates a rewrite whose precondition has
# failed. If the pattern-matching treats this as a
# successful match with no replacemen required, then no
# rewrite would be performed. On the other hand, if the
# pattern-matching treats this as an unsuccessful match,
# then it can backtrack and attempt `pat_add` instead.
return expr
elif pat_add in matches:
return matches[pat_arg]
else:
raise RuntimeError("Pattern matched, but neither branch matched")

@R.function(private=True)
def before():
with R.dataflow():
A = R.ones([64, 128], "int32")
B = R.zeros([64, 128], "int32")
C = R.add(A, B)

R.output(C)
return C

@R.function(private=True)
def expected():
with R.dataflow():
C = R.ones([64, 128], "int32")

R.output(C)
return C

after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 86b5a13

Please sign in to comment.