Skip to content

Commit

Permalink
[Relax] Normalize use of void-type variable to inline R.tuple() (#16658)
Browse files Browse the repository at this point in the history
* [Relax] Normalize use of void-type variable to inline R.tuple()

This is a follow-up commit to
#16641.  While parsing of relax
expressions without a variable binding could be implemented at that
point (e.g. `R.assert_op(condition)` instead of `dummy_var =
R.assert_op(condition)`), the corresponding printing changes could
not.  This was because a variable that satisfies
`relax::HasVoidStructInfo(var)` could still be used later in the
function, and removing its binding would result in use of an undefined
variable.

This commit normalizes use of void-type variables to an in-line
`R.tuple()`.  This simplifies the relax function, and also allows the
binding of void-type variables to be hidden.

* Fix breakage in unit tests
  • Loading branch information
Lunderberg authored Mar 13, 2024
1 parent dffdc3e commit 8023a98
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 33 deletions.
9 changes: 8 additions & 1 deletion src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,14 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
return GetRef<Var>(var);
}

Expr VisitExpr_(const VarNode* var) final { return VisitVar_<Var>(var); }
Expr VisitExpr_(const VarNode* var_ptr) final {
auto var = VisitVar_<Var>(var_ptr);
if (HasVoidStructInfo(var)) {
return VisitExpr(Tuple(Array<Expr>{}));
} else {
return var;
}
}

Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_<DataflowVar>(var); }

Expand Down
11 changes: 6 additions & 5 deletions src/script/ir_builder/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_

#include <tvm/relax/struct_info_functor.h>
#include <tvm/relax/utils.h>
#include <tvm/script/ir_builder/relax/frame.h>
#include <tvm/script/ir_builder/relax/ir.h>

Expand Down Expand Up @@ -109,12 +110,12 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String
GetStructInfo(last_binding->var));
tvm::relax::Expr body;

if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>();
var_binding && var_binding->value->IsInstance<tvm::relax::VarNode>()) {
const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>();

if (var_binding && tvm::relax::IsLeafOrTuple(var_binding->value)) {
body = var_binding->value;
} else if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>()) {
last_block_bindings.push_back(last_binding =
tvm::relax::VarBinding(new_var, var_binding->value));
} else if (var_binding) {
last_block_bindings.push_back(tvm::relax::VarBinding(new_var, var_binding->value));
body = new_var;
} else if (const auto* match_cast = last_binding.as<tvm::relax::MatchCastNode>()) {
last_block_bindings.push_back(
Expand Down
22 changes: 4 additions & 18 deletions src/script/printer/relax/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Doc ret = d->AsDoc(n->value, n_p->Attr("value"));
d->cfg->binding_names.pop_back();
return ret;

// Uncommenting this section hides the variable binding
// when the StructInfo is void. For example, printing
// `R.assert_op(expr)` instead of `_ = R.assert_op(expr)`.
// However, Relax represents void values as an empty
// tuple, and a void-type variable may still be used later
// in the function. Hiding bindings of these void-type
// variables would result in use of an undefined variable.
//
// TODO(Lunderberg): Inline void-type variable to use
// `R.tuple()` during normalization. This will avoid the
// cases that trigger the undefined variables, and allow
// this syntax sugar to be enabled.
//
// } else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) &&
// relax::HasVoidStructInfo(n->var)) {
// ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
// return ExprStmtDoc(rhs);
} else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) &&
relax::HasVoidStructInfo(n->var)) {
ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
return ExprStmtDoc(rhs);
} else {
ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
Expand Down
10 changes: 6 additions & 4 deletions tests/python/relax/test_transform_lift_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,10 @@ def main_transform_params(params: R.Tuple) -> R.Tuple:
R.func_attr({"num_input": 0})
with R.dataflow():
gv: R.Tuple = R.tuple()
R.output(gv)
return gv
R.output()
# All instance of the empty tuple are normalized to be
# in-line.
return R.tuple()

@R.function
def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
Expand Down Expand Up @@ -612,8 +614,8 @@ def main_transform_params(params: R.Tuple) -> R.Tuple:
R.func_attr({"num_input": 0})
with R.dataflow():
gv: R.Tuple = R.tuple()
R.output(gv)
return gv
R.output()
return R.tuple()

@R.function
def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
Expand Down
29 changes: 29 additions & 0 deletions tests/python/relax/test_transform_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,5 +552,34 @@ def test_nesting_non_dataflow_in_dataflow_error():
# should fail due to a normal binding block being inside a dataflowblock


def test_remove_usage_of_void_type_variables():
"""All empty tuples should be constructed in-line
For readability, TVMScript hides the variable binding if the
variable has a void type. For example, `R.assert_op(condition)`
instead of `void_var: R.Tuple([]) = R.assert_op(condition)`.
However, Relax follows standard convention of functional
languages, and uses an empty tuple to represent void. Since an
empty tuple may be legally used later in the function, the
`void_var` may require a binding.
This is avoided by normalizing all usage of a void-type
variable with an in-line `R.tuple()`.
"""
x = relax.Var("x", R.Tuple([]))
bindings = [
relax.VarBinding(x, R.assert_op(R.const(True, "bool"))),
]
seq = relax.SeqExpr([relax.BindingBlock(bindings)], x)
before = relax.Function([], seq, ret_struct_info=R.Tuple([]))

after = relax.transform.Normalize()(tvm.IRModule({"main": before}))["main"]

@R.function(private=True)
def expected():
x = R.assert_op(R.const(True, "bool"))
return R.tuple()


if __name__ == "__main__":
tvm.testing.main()
5 changes: 0 additions & 5 deletions tests/python/relax/test_tvmscript_printer_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# under the License.
# pylint: disable=missing-docstring

import pytest

import tvm
import tvm.testing
from tvm import IRModule, relax, tir
Expand Down Expand Up @@ -636,7 +634,6 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_assert_op():
@I.ir_module
class AssertOpMod:
Expand All @@ -661,7 +658,6 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_print():
@I.ir_module
class PrintMod:
Expand Down Expand Up @@ -710,7 +706,6 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_directly_construct_private_funcs():
# public
@R.function
Expand Down

0 comments on commit 8023a98

Please sign in to comment.