diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index b7b3f411ed41..820acd235d8c 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -591,7 +591,22 @@ def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]: The computed result. """ if not isinstance(expr, rx.DataflowVar): - expr = BlockBuilder.current().emit(expr, name) + block_builder = BlockBuilder.current() + if block_builder is None: + # Normalize to make sure we have valid StructInfo, but + # wait until we are actually building the function to + # flatten nested expressions. + # + # TODO(Lunderberg): Make this easier to call. Infering + # struct info for a nested expression should be doable in + # a free function, without requiring an active + # BlockBuilder and an active FunctionFrame. + builder = BlockBuilder() + with builder.function("dummy_scope", params=[]): + expr = builder.normalize(expr) + builder.emit_func_output([]) + else: + expr = BlockBuilder.current().emit(expr, name) if isinstance(expr.struct_info_, TensorStructInfo): return Tensor(_expr=expr) if isinstance(expr.struct_info_, TupleStructInfo): diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py index 1a7dcd6a648b..525d689f4995 100644 --- a/python/tvm/relax/frontend/nn/exporter.py +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -111,7 +111,8 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: return result # pylint: enable=protected-access - params = None + + params = _params() effects = _effects() ext_mods = self.extern_mods with self: @@ -121,7 +122,6 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: outputs = _emit_effect_init(self.builder, effects) self.builder.emit_func_output(outputs, params=[]) for method_name, method_spec in zip(spec.method_names, spec.method_specs): - params = _params() # Re-initialize so symbolic shapes not shared across methods len_args = len(method_spec.arg_specs) len_effects = { "packed": 1, @@ -135,9 +135,18 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: with self.builder.dataflow(): outputs, inputs = _emit_method(self.builder, method_spec, params, effects) self.builder.emit_func_output(outputs, inputs) + + # TODO(Lunderberg): Make a `ir.transform.ConvertSSA`, + # similar to the existing `tir.transform.ConvertSSA`, + # that converts an entire module to SSA, including TIR + # variable definitions used in either TIR or Relax. + mod = self.builder.get() + mod[method_name] = rx.utils.copy_with_new_vars(mod[method_name]) + mod = self.builder.finalize() assert rx.analysis.well_formed(mod) + mod = rx.transform.CanonicalizeBindings()(mod) return mod, params, ext_mods @@ -161,8 +170,6 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many- effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]], ): # pylint: disable=protected-access - # symbolic shape's name mapping to its tir.Var for reuse - str2var_params: typing.Dict[str, tir.Var] = {} def _unwrap_ret(expr: typing.Any) -> typing.Any: if isinstance(expr, (core.Tensor, core.Object)): @@ -176,35 +183,26 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any: def _convert_input(arg): if isinstance(arg, tir.Var): return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg])) - if isinstance(arg, (core.Tensor, core.Object)): + elif isinstance(arg, (core.Tensor, core.Object)): return arg._expr # pylint: disable=protected-access - if isinstance(arg, _spec.Tuple): + elif isinstance(arg, _spec.Tuple): return rx.Var( arg.name, struct_info=TupleStructInfo( [_convert_input(arg_i).struct_info for arg_i in arg.elements] ), ) - raise TypeError(f"Unsupported input type: {type(arg)}") + elif isinstance(arg, rx.Expr): + return arg + else: + raise TypeError(f"Unsupported input type: {type(arg)}") def _params(mode: str) -> typing.List[rx.Var]: inputs: typing.List[rx.Var] = [] - def _get_var(shape_var: tir.Var) -> tir.Var: - name = shape_var.name - if name in str2var_params: - return str2var_params[name] - var = tir.Var(name, "int64") - str2var_params[name] = var - return var - for name, param in params: - # Make sure the a symbolic shape is not re-registered (same as _method_spec_to_inputs) - # e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens` - new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape] - var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr - inputs.append(var) - param._expr = var + inputs.append(param._expr) + if mode == "none": return [] if mode == "plain": diff --git a/tests/python/relax/test_frontend_nn_exporter.py b/tests/python/relax/test_frontend_nn_exporter.py new file mode 100644 index 000000000000..de8900238bb6 --- /dev/null +++ b/tests/python/relax/test_frontend_nn_exporter.py @@ -0,0 +1,443 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +import tvm.testing + +from tvm import relax, tir +from tvm.ir import assert_structural_equal +from tvm.relax.frontend import nn +from tvm.script import ir as I, relax as R, tir as T + + +def test_simple(): + """A module may be exported from nn.Module to Relax""" + + slm_mod = nn.modules.ReLU() + exported_mod, _ = slm_mod.export_tvm( + spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}}, + debug=False, + ) + + @I.ir_module + class Expected: + @R.function + def forward(x: R.Tensor([3, 3], dtype="float32")): + R.func_attr({"num_input": 1}) + with R.dataflow(): + relu = R.nn.relu(x) + R.output(relu) + return relu + + assert_structural_equal(exported_mod, Expected) + + +def test_custom_module(): + """A module may be exported from nn.Module to Relax""" + + class Before(nn.Module): + def forward(self, x: R.Tensor): + return nn.op.relu(x) + + slm_mod = Before() + exported_mod, _ = slm_mod.export_tvm( + spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}}, + debug=False, + ) + + @I.ir_module + class Expected: + @R.function + def forward(x: R.Tensor([3, 3], dtype="float32")): + R.func_attr({"num_input": 1}) + with R.dataflow(): + relu = R.nn.relu(x) + R.output(relu) + return relu + + assert_structural_equal(exported_mod, Expected) + + +def test_debug_effect(): + """Passing debug=True provides an argument for IO effect""" + + slm_mod = nn.modules.ReLU() + exported_mod, _ = slm_mod.export_tvm( + spec={"forward": {"x": nn.spec.Tensor((3, 3), "float32")}}, + debug=True, + ) + + @I.ir_module + class Expected: + @R.function + def forward( + x: R.Tensor([3, 3], dtype="float32"), + _io: R.Object, + ): + R.func_attr({"num_input": 2}) + with R.dataflow(): + relu = R.nn.relu(x) + output = relu, (_io,) + R.output(output) + return output + + @R.function + def _initialize_effect(): + with R.dataflow(): + _io = R.null_value() + output = (_io,) + R.output(output) + return output + + assert_structural_equal(exported_mod, Expected) + + +def test_dynamic_shape(): + """An argument may have a dynamic shape""" + + slm_mod = nn.modules.ReLU() + exported_mod, _ = slm_mod.export_tvm( + spec={"forward": {"x": nn.spec.Tensor([tir.Var("batch_size", "int64"), 8], "float32")}}, + debug=False, + ) + + @I.ir_module + class Expected: + @R.function + def forward(x: R.Tensor(["batch_size", 8], dtype="float32")): + R.func_attr({"num_input": 1}) + with R.dataflow(): + relu = R.nn.relu(x) + R.output(relu) + return relu + + assert_structural_equal(exported_mod, Expected) + + +def test_dynamic_shape_in_multiple_functions(): + """A dynamic shape may be used in multiple functions""" + + class Before(nn.Module): + def forward_relu(self, x: nn.Tensor): + return nn.relu(x) + + def forward_silu(self, x: nn.Tensor): + return nn.silu(x) + + slm_mod = Before() + exported_mod, _ = slm_mod.export_tvm( + spec={ + "forward_relu": {"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 8), "float32")}, + "forward_silu": {"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 8), "float32")}, + }, + debug=False, + ) + + @I.ir_module + class Expected: + @R.function + def forward_relu(x: R.Tensor(["batch_size", 8], dtype="float32")): + R.func_attr({"num_input": 1}) + with R.dataflow(): + relu = R.nn.relu(x) + R.output(relu) + return relu + + @R.function + def forward_silu(x: R.Tensor(["batch_size", 8], dtype="float32")): + R.func_attr({"num_input": 1}) + with R.dataflow(): + silu = R.nn.silu(x) + R.output(silu) + return silu + + assert_structural_equal(exported_mod, Expected) + + +def test_export_nested_module(): + """nn.Module instances may contain other nn.Module + + When exporting to a Relax IRModule, all `nn.Parameter` instances + within the `nn.Module` become Relax function parameters. + """ + + class LlamaMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.gate_proj = nn.Linear( + in_features=hidden_size, + out_features=intermediate_size, + dtype="float16", + bias=False, + ) + self.up_proj = nn.Linear( + in_features=hidden_size, + out_features=intermediate_size, + dtype="float16", + bias=False, + ) + self.down_proj = nn.Linear( + intermediate_size, + hidden_size, + dtype="float16", + bias=False, + ) + + def forward(self, x: nn.Tensor): + gate = self.gate_proj(x) + up = self.up_proj(x) + return self.down_proj(nn.op.silu(gate) * up) + + hidden_size = 4096 + intermediate_size = 11008 + slm_mod = LlamaMLP(hidden_size=hidden_size, intermediate_size=intermediate_size) + exported_mod, _ = slm_mod.export_tvm( + spec={ + "forward": { + "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16") + }, + }, + debug=False, + ) + + @I.ir_module + class Expected: + @R.function + def forward( + x: R.Tensor(["batch_size", hidden_size], "float16"), + gate_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"), + up_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"), + down_proj_weights: R.Tensor([hidden_size, intermediate_size], "float16"), + ): + R.func_attr({"num_input": 1}) + batch_size = T.int64() + with R.dataflow(): + gate: R.Tensor([batch_size, intermediate_size]) = R.matmul( + x, R.permute_dims(gate_proj_weights) + ) + up: R.Tensor([batch_size, intermediate_size]) = R.matmul( + x, R.permute_dims(up_proj_weights) + ) + down: R.Tensor([batch_size, hidden_size]) = R.matmul( + R.nn.silu(gate) * up, R.permute_dims(down_proj_weights) + ) + R.output(down) + return down + + assert_structural_equal(exported_mod, Expected) + + +def test_generate_parameters(): + """Weights may be expressions in terms of other parameters + + Optimizations often require preprocessing of the model weights. + + 1. Declare the `nn.Module` members that contain the original model + weights. These are used to define the parameter names when + reading from a Pytorch or Safetensors file. + + 2. Declare the `nn.Module` members, with the `weight` field + in terms of the un-optimized weights. These `nn.Module` + do not generate any parameters in the Relax function. + + 3. Define the `forward` function in terms of the `nn.Module` + members for the updated weight tensors. + + The exported Relax function accepts the original model parameters, + computes the pre-processed weights, and then performs computations + using the pre-processed weights. + + In this example, the `LiftTransformParams` transform is applied + immediately, splitting the Relax function into a pre-processing + step and an execution step. In practice, this transform would be + applied much later in an optimization pipeline, to allow optimized + compute kernels to be recognized. For example, in some cases + `R.matmul(x, R.permute_dims(weight))` may be computed more + efficiently than `R.matmul(x, weight_transpose)`. For this + reason, we do *not* apply `LiftTransformParams` as part of the + export from `nn.Module` to Relax. + + """ + + class LlamaMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + # The nn.Linear for the original parameters are present in + # the model definition, and are still found when + # collecting a function's parameters. + self.gate_proj = nn.Linear( + in_features=hidden_size, + out_features=intermediate_size, + dtype="float16", + bias=False, + ) + self.up_proj = nn.Linear( + in_features=hidden_size, + out_features=intermediate_size, + dtype="float16", + bias=False, + ) + self.down_proj = nn.Linear( + intermediate_size, + hidden_size, + dtype="float16", + bias=False, + ) + + # At runtime, we'd like to have a single concatenated + # tensor containing both the gate and up projection + # weights. We also want to use it in the `forward` + # function as if it owned its own weights. + self.gate_up_proj = nn.Linear( + in_features=hidden_size, + out_features=intermediate_size, + dtype="float16", + bias=False, + ) + + # The weight tensor of `gate_up_proj` can be overwritten + # in terms of the original `gate_proj` and `up_proj` + # tensors. + self.gate_up_proj.weight = nn.op.concat( + [self.gate_proj.weight, self.up_proj.weight], dim=0, name="gate_up_proj_weights" + ) + + def forward(self, x: nn.Tensor): + # Even though the `gate_up_proj` weights are defined as an + # expression rather than a `nn.Parameter`, the `forward` + # function does not require any special handling for it. + concat_gate_up = self.gate_up_proj(x) + gate, up = nn.op.split(concat_gate_up, 2, axis=-1) + return self.down_proj(nn.op.silu(gate) * up) + + hidden_size = 4096 + intermediate_size = 11008 + slm_mod = LlamaMLP(hidden_size=hidden_size, intermediate_size=intermediate_size) + exported_mod, _ = slm_mod.export_tvm( + spec={ + "forward": { + "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16") + }, + }, + debug=False, + ) + + @I.ir_module + class Expected: + @R.function + def forward( + x: R.Tensor(["batch_size", hidden_size], "float16"), + # The function's parameters are defined by the + # `nn.Parameter` instances, and still reference the + # original `gate_proj` and `up_proj` weights. This + # maintains compatibility with named model weights in a + # Pytorch or Safetensors file. + gate_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"), + up_proj_weights: R.Tensor([intermediate_size, hidden_size], "float16"), + down_proj_weights: R.Tensor([hidden_size, intermediate_size], "float16"), + ): + R.func_attr({"num_input": 1}) + batch_size = T.int64() + with R.dataflow(): + # At this stage of compilation, the concatenation is + # written within the body of the function. This will + # later be extracted into a pre-processing step using + # `relax.transform.LiftTransformParams`. + gate_up_proj_weights: R.Tensor( + [intermediate_size * 2, hidden_size], "float16" + ) = R.concat([gate_proj_weights, up_proj_weights], axis=0) + gate_up: R.Tensor([batch_size, intermediate_size * 2], "float16") = R.matmul( + x, R.permute_dims(gate_up_proj_weights) + ) + gate_up_split = R.split(gate_up, 2, axis=-1) + gate = gate_up_split[0] + up = gate_up_split[1] + down: R.Tensor([batch_size, hidden_size], "float16") = R.matmul( + R.nn.silu(gate) * up, R.permute_dims(down_proj_weights) + ) + R.output(down) + return down + + assert_structural_equal(exported_mod, Expected) + + @I.ir_module + class ExpectedAfterLift: + @R.function + def forward( + x: R.Tensor(["batch_size", hidden_size], "float16"), + # After `relax.transform.LiftTransformParams`, the + # `gate_proj` and `up_proj` weights have been concatenated + # together. + gate_up_proj_weights_transpose: R.Tensor( + [hidden_size, intermediate_size * 2], "float16" + ), + down_proj_weights_transpose: R.Tensor([intermediate_size, hidden_size], "float16"), + ): + R.func_attr({"num_input": 1}) + batch_size = T.int64() + with R.dataflow(): + gate_up: R.Tensor([batch_size, intermediate_size * 2], "float16") = R.matmul( + x, gate_up_proj_weights_transpose + ) + gate_up_split = R.split(gate_up, 2, axis=-1) + gate = gate_up_split[0] + up = gate_up_split[1] + down: R.Tensor([batch_size, hidden_size], "float16") = R.matmul( + R.nn.silu(gate) * up, down_proj_weights_transpose + ) + R.output(down) + return down + + @R.function + def transform_params( + model_params: R.Tuple( + R.Tensor([intermediate_size, hidden_size], "float16"), + R.Tensor([intermediate_size, hidden_size], "float16"), + R.Tensor([hidden_size, intermediate_size], "float16"), + ) + ): + R.func_attr({"num_input": 0}) + with R.dataflow(): + gate_proj_weights: R.Tensor( + [intermediate_size, hidden_size], "float16" + ) = model_params[0] + up_proj_weights: R.Tensor( + [intermediate_size, hidden_size], "float16" + ) = model_params[1] + gate_up_proj_weights: R.Tensor( + [intermediate_size * 2, hidden_size], "float16" + ) = R.concat([gate_proj_weights, up_proj_weights], axis=0) + gate_up_proj_weights_transpose: R.Tensor( + [hidden_size, intermediate_size * 2], "float16" + ) = R.permute_dims(gate_up_proj_weights) + down_proj_weights: R.Tensor( + [hidden_size, intermediate_size], "float16" + ) = model_params[2] + down_proj_weights_transpose: R.Tensor( + [intermediate_size, hidden_size], "float16" + ) = R.permute_dims(down_proj_weights) + output = (gate_up_proj_weights_transpose, down_proj_weights_transpose) + R.output(output) + return output + + lifted_mod = relax.transform.LiftTransformParams(shared_transform=True)(exported_mod) + assert_structural_equal(lifted_mod, ExpectedAfterLift) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py index 6eaf1fbfc805..6ca774242274 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -94,9 +94,8 @@ def scalar_add( ext_scalar_add = R.call_dps_packed( "ext_scalar_add", (a, b), out_sinfo=R.Tensor((), dtype="float32") ) - gv: R.Tensor((), dtype="float32") = ext_scalar_add - R.output(gv) - return gv + R.output(ext_scalar_add) + return ext_scalar_add @R.function def test_sym( @@ -110,9 +109,8 @@ def test_sym( ext_test_sym = R.call_dps_packed( "ext_test_sym", (a, b), out_sinfo=R.Tensor((x, y, z, 9), dtype="float32") ) - gv1: R.Tensor((x, y, z, 9), dtype="float32") = ext_test_sym - R.output(gv1) - return gv1 + R.output(ext_test_sym) + return ext_test_sym tvm.ir.assert_structural_equal(ExpectedModule, mod) diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index 5ddc10505591..45128749e23d 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -493,8 +493,7 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object): R.prim_value(0), sinfo_args=[R.Object()], ) - lv1 = _io, cache - gv = lv1 + gv = _io, cache R.output(gv) return gv diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 7d78e47c945b..68f86bba50e8 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -538,8 +538,7 @@ def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer( def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - lv: R.Tuple(R.Object) = (_io,) - gv: R.Tuple(R.Object) = lv + gv = (_io,) R.output(gv) return gv @@ -611,8 +610,7 @@ def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k: def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - lv: R.Tuple(R.Object) = (_io,) - gv: R.Tuple(R.Object) = lv + gv = (_io,) R.output(gv) return gv @@ -699,8 +697,7 @@ def inplace_take( def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - lv: R.Tuple(R.Object) = (_io,) - gv: R.Tuple(R.Object) = lv + gv = (_io,) R.output(gv) return gv @@ -717,13 +714,12 @@ def test( R.func_attr({"num_input": 4}) cls = Expected with R.dataflow(): - lv1 = R.call_tir( + gv1 = R.call_tir( cls.inplace_take, (embedding_table, input_ids, embedding_dst), out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype), tir_vars=R.shape([offset_1]), ) - gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1 R.output(gv1) return gv1 @@ -772,8 +768,7 @@ def test(A: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl R.func_attr({"num_input": 1}) cls = Expected with R.dataflow(): - lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32")) - gv: R.Tensor((16, 16), dtype="float32") = lv + gv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32")) R.output(gv) return gv @@ -800,8 +795,7 @@ class Expected: def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - lv: R.Tuple(R.Object) = (_io,) - gv: R.Tuple(R.Object) = lv + gv = (_io,) R.output(gv) return gv @@ -888,8 +882,7 @@ def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle): def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - lv: R.Tuple(R.Object) = (_io,) - gv: R.Tuple(R.Object) = lv + gv = (_io,) R.output(gv) return gv @@ -1015,8 +1008,7 @@ def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - lv: R.Tuple(R.Object) = (_io,) - gv: R.Tuple(R.Object) = lv + gv: R.Tuple(R.Object) = (_io,) R.output(gv) return gv @@ -1130,8 +1122,7 @@ def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.h def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - lv: R.Tuple(R.Object) = (_io,) - gv: R.Tuple(R.Object) = lv + gv: R.Tuple(R.Object) = (_io,) R.output(gv) return gv diff --git a/tests/python/relax/test_frontend_nn_packing.py b/tests/python/relax/test_frontend_nn_packing.py index 56b614a807b8..c2cc22c17d40 100644 --- a/tests/python/relax/test_frontend_nn_packing.py +++ b/tests/python/relax/test_frontend_nn_packing.py @@ -59,8 +59,7 @@ def forward( matmul = R.matmul(x, matmul_1_weight) matmul_2_weight = R.permute_dims(linear_2_weight) matmul1 = R.matmul(x, matmul_2_weight) - add = R.add(matmul, matmul1) - gv = add + gv = R.add(matmul, matmul1) R.output(gv) return gv diff --git a/tests/python/relax/test_frontend_nn_subroutines.py b/tests/python/relax/test_frontend_nn_subroutines.py index 6bbf57aeadde..32ae967916a8 100644 --- a/tests/python/relax/test_frontend_nn_subroutines.py +++ b/tests/python/relax/test_frontend_nn_subroutines.py @@ -61,8 +61,7 @@ def forward( def _initialize_effect() -> R.Tuple(R.Object): with R.dataflow(): _io: R.Object = R.null_value() - lv: R.Tuple(R.Object) = (_io,) - gv: R.Tuple(R.Object) = lv + gv = (_io,) R.output(gv) return gv @@ -75,9 +74,8 @@ def layer( with R.dataflow(): state = R.matmul(state, weights) state = Expected.activation(state) - dataflow_output = state - R.output(dataflow_output) - return dataflow_output + R.output(state) + return state @R.function(private=True) def activation( @@ -85,9 +83,8 @@ def activation( ) -> R.Tensor(("batch_size", 32), dtype="float32"): with R.dataflow(): state = R.nn.silu(state) - dataflow_output = state - R.output(dataflow_output) - return dataflow_output + R.output(state) + return state mod = Layer(64, 32) batch_size = tvm.tir.Var("batch_size", "int64")