Skip to content

Commit

Permalink
[Relax][Bugfix] Provide the full Expr to pattern-match rewriter (#16828)
Browse files Browse the repository at this point in the history
* [Relax][Bugfix] Provide the full Expr to pattern-match rewriter

This resolves a bug that was introduced in
#16732.  If a rewriter function
returned a no-op, and the pattern-match continued, then the `matches`
provided to the rewriter function in subsequent calls would contain
a variable to which the matched expression was bound, not the matched
expression itself.  (e.g. For a match of `C = R.add(A,B)`, passing `C`
to the rewriter instead of `R.add(A,B)`.)

This bug was caused by incorrect re-wrapping of `OrPattern` in
`ExprPatternRewriter`.  Prior to
#16732, all pattern-match results
were populated by `ExtractMatchExpr`, and contained the result after
applying `TryGetValOfVar`.  When re-wrapping the result of an
`OrPattern`, #16732 populated the
additional matches with the result before applying `TryGetValOfVar`.
This commit fixes the bug by applying `TryGetValOfVar`.

* Update with PR link of bugfix
  • Loading branch information
Lunderberg authored Apr 1, 2024
1 parent 3f615dc commit 00395ae
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1190,8 +1190,17 @@ class ExprPatternRewriter : ExprMutator {

if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
auto matches = opt_matches.value();
for (const auto& pat : *matches_top_level) {
matches.Set(pat, expr);

// Append any additional matches that from the unwrapped
// `OrPattern`. When matching against `pat = pat_lhs |
// pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and
// `pat_rhs` separately. The top-level `pat` is never seen by
// `ExtractMatchedExpr`, and must be re-added afterward.
if (matches_top_level->size()) {
auto matched_expr = TryGetValOfVar(expr, bindings_);
for (const auto& pat : *matches_top_level) {
matches.Set(pat, matched_expr);
}
}

Expr rewritten_expr = rewriter_func_(expr, matches);
Expand Down
33 changes: 33 additions & 0 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,5 +1952,38 @@ def expected():
tvm.ir.assert_structural_equal(expected, after)


def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
"""The matches should always contain the bound value
This is a regression test. In versions from
https://github.com/apache/tvm/pull/16732 to
https://github.com/apache/tvm/pull/16828, the `rewrite_call`
function could erroneously call the rewriter with `expr` and
`matches[pat]` set to a variable (`C`) instead of the value to
which it is bound (`R.add(A,B)`).
"""
pat_a = is_op("relax.add")(wildcard(), wildcard())
pat_b = is_op("relax.add")(wildcard(), wildcard())
pat = pat_a | pat_b

def rewriter(expr, matches):
assert isinstance(matches[pat], rx.Call)
return expr

@R.function(private=True)
def before():
with R.dataflow():
A = R.ones([64, 128], "int32")
B = R.zeros([64, 128], "int32")
C = R.add(A, B)

R.output(C)
return C

expected = before
after = rewrite_call(pat, rewriter, before)
tvm.ir.assert_structural_equal(expected, after)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 00395ae

Please sign in to comment.