Skip to content

Commit

Permalink
[Relax] Provide well-formed output in transform.LazyGetInput (#16841)
Browse files Browse the repository at this point in the history
Prior to this commit, symbolic variables inferred from the parameters
were retained in the output function's `ret_struct_info`.  This is
ill-formed, as the parameters from which these symbolic variables are
inferred are no longer part of the function signature.

This commit updates `LazyGetInput` to use `EraseToWellDefined` to
remove any symbolic variables from `ret_struct_info` that cannot be
inferred from the remaining arguments.
  • Loading branch information
Lunderberg authored Apr 4, 2024
1 parent 61249b4 commit 6f74762
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/relax/transform/lazy_transform_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,22 @@ class LazyInputMutator : public ExprMutator {
Array<Var> new_params(func->params.begin(), func->params.begin() + num_input_params);
new_params.push_back(fget_param);

auto array_externally_visible_vars =
DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo)));
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> externally_visible_vars(
array_externally_visible_vars.begin(), array_externally_visible_vars.end());
StructInfo new_ret_struct_info =
EraseToWellDefined(func->ret_struct_info, [&](const tir::Var& var) -> Optional<PrimExpr> {
if (externally_visible_vars.count(var)) {
return var;
} else {
return NullOpt;
}
});

auto node = GetRef<Function>(func);
node.CopyOnWrite()->params = new_params;
node.CopyOnWrite()->ret_struct_info = new_ret_struct_info;
node = WithAttr(node, attr::kNumInput, Integer(num_input_params + 1));

plan_ = FunctionPlan{std::move(param_lookup), fget_param};
Expand Down
34 changes: 34 additions & 0 deletions tests/python/relax/test_transform_lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,40 @@ def transform_params(
tvm.ir.assert_structural_equal(After, Expected)


def test_get_item_callback_dynamic_shape():
@I.ir_module
class Before:
@R.function
def transform_params(
A: R.Tensor(["m", "n"], "float32"), B: R.Tensor(["m", "n"], "float32")
) -> R.Tuple(R.Tensor(["m", "n"], "float32"), R.Tensor(["m", "n"], "float32")):
C = R.multiply(A, R.const(2, "float32"))
D = R.add(C, B)
return (D, B)

@I.ir_module
class Expected:
@R.function
def transform_params(
fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object)
) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, dtype="float32")):
R.func_attr({"num_input": 1})
m = T.int64()
n = T.int64()

A = fget_param(R.prim_value(0), R.str("A"))
A = R.match_cast(A, R.Tensor([m, n], "float32"))
C = R.multiply(A, R.const(2, "float32"))

B = fget_param(R.prim_value(1), R.str("B"))
B = R.match_cast(B, R.Tensor([m, n], "float32"))
D = R.add(C, B)
return (D, B)

After = relax.transform.LazyGetInput()(Before)
tvm.ir.assert_structural_equal(After, Expected)


def test_set_output_callback():
"""fset_output is called for each element of the output tuple
Expand Down

0 comments on commit 6f74762

Please sign in to comment.