Skip to content

Commit

Permalink
[Relax] Avoid wrapping TupleStructInfo into a Tuple for R.call_tir (a…
Browse files Browse the repository at this point in the history
…pache#17243)

* [Relax] Avoid wrapping TupleStructInfo into a Tuple for R.call_tir

Prior to this commit, the different `R.call_tir*` variations would
wrap the arguments into an in-line `relax.Tuple`, if it is not
already a `relax.Tuple`.  While this allows a tensor to be passed into
these functions as a single argument (`R.call_tir(func, arg, ...)`
instead of `R.call_tir(func, [arg], ...)`), the wrapped Relax variable
may already refer to a tuple.

This use of a variable to refer to an argument tuple rather than an
in-line argument tuple is not allowed by Relax.  (See discussion on
apache#15916 for details.)  However, by
wrapping a variable `args: R.Tuple(R.Tensor, R.Tensor, ...)` into a
tuple-of-tuples, the error occurs after the expression has already
been generated, and refers to an expression `R.Tuple(R.Tuple(R.Tensor,
R.Tensor, ...))` that doesn't appear anywhere in the user's input.
This can make debugging difficult (see
apache#17239 for an example).

This commit updates the argument-handling in `R.call_tir` to only
generate an in-line `relax.Tuple` if the arguments do not already have
`relax.TupleStructInfo`.  If the argument was provided as a Relax
variable bound to a tuple of arguments, it will still produce an
error.  However, that error will occur much earlier, and will
explicitly state that the argument must be a `relax.Tuple` instead of
a `relax.Var`.

* lint fixes
  • Loading branch information
Lunderberg authored Aug 26, 2024
1 parent d5d5ebb commit c4acc79
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
37 changes: 28 additions & 9 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# pylint: disable=redefined-builtin
"""The base Relax operators."""

from typing import Dict, Union, List, Tuple, Optional, Callable


Expand All @@ -25,7 +26,6 @@

from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
from ..utils import args_converter
Expand Down Expand Up @@ -67,6 +67,29 @@ def null_value() -> Call:
return _ffi_api.null_value() # type: ignore


def _wrap_inline_arg_tuple(args) -> Expr:
"""Helper function to wrap argument tuple
Normalize the arguments provided the functions that accept a tuple
of arguments, and require the tuple of arguments to be written
in-line. If the arguments provided are a single relax expression,
and are not a reference to a relax tuple, then wrap them into an
in-line relax Tuple.
"""
if (
isinstance(args, Expr)
and not isinstance(args, tvm.relax.Tuple)
and (
args.struct_info_ is None
or not isinstance(args.struct_info_, tvm.relax.TupleStructInfo)
)
):
return tvm.relax.Tuple([args])
else:
return args


@args_converter.auto
def call_tir(
gvar: GlobalVar,
Expand Down Expand Up @@ -98,8 +121,7 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _wrap_inline_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
Expand Down Expand Up @@ -153,8 +175,7 @@ def call_tir_with_grad(
ret: Call
A call node for the call_tir_with_grad operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _wrap_inline_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
Expand Down Expand Up @@ -221,8 +242,7 @@ def call_tir_inplace(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _wrap_inline_arg_tuple(args)

if not isinstance(inplace_indices, list):
inplace_indices = [inplace_indices]
Expand Down Expand Up @@ -276,8 +296,7 @@ def call_dps_packed(
if isinstance(func, str):
func = ExternFunc(func)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _wrap_inline_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
Expand Down
36 changes: 36 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,42 @@ def main(
_check(Module)


def test_call_tir_inplace_with_tuple_var_raises_error():

with pytest.raises(tvm.error.DiagnosticError):

@tvm.script.ir_module
class Module:
@R.function
def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")):
cls = Module
args = (x, y)
res = R.call_tir_inplace(
cls.copy,
# The `args` tuple must be an in-line tuple, not a
# reference to a tuple. This error should be
# caught and raised during parsing.
args,
inplace_indices=[0, -1],
out_sinfo=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")],
)
return res

@T.prim_func
def copy(
A: T.Buffer((2, 3), "int32"),
B: T.Buffer((2, 3), "int32"),
out1: T.Buffer((2, 3), "int32"),
):
# copies the contents of B into A and out1
T.func_attr({"tir.noalias": True})
for iters in T.grid(T.int64(2), T.int64(3)):
with T.block("T_zeros"):
i, j = T.axis.remap("SS", iters)
A[i, j] = B[i, j]
out1[i, j] = B[i, j]


def test_local_function():
@R.function
def main(
Expand Down

0 comments on commit c4acc79

Please sign in to comment.