Skip to content

Commit

Permalink
Fix breakage in unit tests
Browse files Browse the repository at this point in the history
One unit test that had been relying on invalid shape propagation.
Another unit test that required constructed an ill-formed output to
test against.
  • Loading branch information
Lunderberg committed Jul 16, 2024
1 parent 234ddde commit 7f62f70
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
53 changes: 29 additions & 24 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,32 +578,37 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
body_sinfo = GetStructInfo(body);
}

if (ret_struct_info.defined()) {
// allow body to override ret if body is more fine-grained.
if (body_sinfo.defined()) {
if (IsBaseOf(ret_struct_info.value(), body_sinfo.value())) {
ret_struct_info = body_sinfo;
}
}
} else {
CHECK(body_sinfo.defined())
<< "Function do not have a return signature and body is not normalized";
ret_struct_info = body_sinfo;
CHECK(body_sinfo.defined() || ret_struct_info.defined())
<< "Function must be constructed with either "
<< "an explicit struct info for the return type, "
<< "or a normalized body with struct info.";

// Use the body's struct info if there is no explicit return type,
// or if the body may provide a more granular return type.
bool use_body_struct_info =
!ret_struct_info.defined() ||
(body_sinfo && ret_struct_info && IsBaseOf(ret_struct_info.value(), body_sinfo.value()));

if (use_body_struct_info) {
// MatchCast nodes within the body may introduce new symbolic
// variables. These are in-scope for the function body, but not
// for the function's return type. When hoisting the body's type
// to the function return type, symbolic variables may only be
// used if they were defined by the function's parameters.
auto f_shape_var_map = [&] {
auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo)));
std::unordered_set<tir::Var> lookup(tir_vars.begin(), tir_vars.end());
return [lookup = std::move(lookup)](const tir::Var& var) -> Optional<PrimExpr> {
if (lookup.count(var)) {
return var;
} else {
return NullOpt;
}
};
}();
ret_struct_info = EraseToWellDefined(body_sinfo.value(), f_shape_var_map);
}

auto f_shape_var_map = [&] {
auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo)));
std::unordered_set<tir::Var> lookup(tir_vars.begin(), tir_vars.end());
return [lookup = std::move(lookup)](const tir::Var& var) -> Optional<PrimExpr> {
if (lookup.count(var)) {
return var;
} else {
return NullOpt;
}
};
}();
ret_struct_info = EraseToWellDefined(ret_struct_info.value(), f_shape_var_map);

FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure);

// set the fields
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def reshape(
T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)]

@R.function
def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"):
def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(ndim=1, dtype="int64"):
x_1 = T.int64()
gv: R.Shape([3]) = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),))
y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1]))
Expand Down

0 comments on commit 7f62f70

Please sign in to comment.