Skip to content

Commit

Permalink
[Relax][Analysis] Handle recursive functions in CollectVarUsage (#17224)
Browse files Browse the repository at this point in the history
* [Relax][Analysis] Handle recursive functions in CollectVarUsage

Prior to this commit, the `relax::analysis::CollectVarUsage` utility
treated a local function definition as in-scope after visiting the
body of the local function.  As a result, recursive calls from a local
function were incorrectly identified as calls to an undefined
variable.

This commit updates the `CollectVarUsage` to treat a local function
definition as in-scope when inspecting the function body.  This change
is similar to the change made for structural equality in
#16756.

* lint fixes
  • Loading branch information
Lunderberg authored Aug 22, 2024
1 parent 32063b0 commit ed9aa56
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/relax/analysis/udchain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class UDChain : relax::ExprVisitor {

private:
Map<Var, Expr> bound_values;
std::unordered_set<Var> forward_declarations;
std::unordered_map<Var, support::OrderedSet<Var>> usage_map;
support::OrderedSet<Var> outputs;

Expand All @@ -71,9 +72,20 @@ class UDChain : relax::ExprVisitor {
cur_user_ = cache;
}

void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override {
// A local Relax function may be recursively defined. References to
// `binding->var` that appear within `func` are valid.
DefineVar(binding->var);
forward_declarations.insert(binding->var);
ExprVisitor::VisitBinding_(binding, func);
}

void VisitVarDef(const Var& var) override {
CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition";
usage_map[var] = {};
if (forward_declarations.count(var)) {
forward_declarations.erase(var);
} else {
DefineVar(var);
}
}
void VisitExpr_(const VarNode* op) override {
auto var = GetRef<Var>(op);
Expand All @@ -89,6 +101,11 @@ class UDChain : relax::ExprVisitor {
cur_user_ = nullptr;
ExprVisitor::VisitExpr_(op);
}

void DefineVar(const Var& var) {
CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition";
usage_map[var] = {};
}
};

std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef(
Expand Down
81 changes: 81 additions & 0 deletions tests/python/relax/test_transform_dead_code_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,5 +658,86 @@ def subsubroutine(A: R.Tensor) -> R.Tensor:
tvm.ir.assert_structural_equal(Expected, After)


def test_recursively_defined_lambda():
"""DCE may be applied to recursively-defined functions
While most expressions may only contain references to
previously-defined variables, local Relax function definitions may
contain references to themselves.
This is a regression test. In previous implementations, the
recursive use of `while_loop` resulted in an error, as
`while_loop` was not considered in-scope by the `CollectVarUsage`
utility until after the body of `while_loop` had been visited.
"""

@I.ir_module
class Before:
@R.function
def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
@R.function
def while_loop(
i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
) -> R.Tensor((2, 3), "float32"):
cond = R.call_pure_packed(
"test.vm.less", i, R.const(10), sinfo_args=R.Tensor((), dtype="bool")
)
c = R.const(1, dtype="int32")
if cond:
new_i = R.add(i, c)
new_s = R.add(s, x)
r = while_loop(new_i, new_s)
else:
r = s
return r

gv = while_loop(R.const(0), x)
return gv

Expected = Before

verify(Before, Expected)


def test_recursively_defined_closure():
"""DCE may be applied to recursively-defined closures
This test is identical to `test_recursively_defined_lambda`,
except that the threshold for recursion is defined in an enclosed
variable outside of the recursive function.
"""

@I.ir_module
class Before:
@R.function
def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
threshold = R.const(10)

@R.function
def while_loop(
i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
) -> R.Tensor((2, 3), "float32"):
cond = R.call_pure_packed(
"test.vm.less", i, threshold, sinfo_args=R.Tensor((), dtype="bool")
)
c = R.const(1, dtype="int32")
if cond:
new_i = R.add(i, c)
new_s = R.add(s, x)
r = while_loop(new_i, new_s)
else:
r = s
return r

gv = while_loop(R.const(0), x)
return gv

Expected = Before

verify(Before, Expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit ed9aa56

Please sign in to comment.