[Relax] Handle binary operations between Tensor and PrimValue #16827

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
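For orientation, here is a hedged sketch of the kind of Relax program this change is meant to admit, where a binary operator mixes a Tensor operand with a PrimValue operand. It is not taken from the PR or its tests; the module, operator, and variable names are illustrative.

from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((16,), "float32"), s: R.Prim("float32")):
        # Before this change, struct-info inference for binary operators
        # required TensorStructInfo on both operands; here the RHS is a PrimValue.
        y = R.add(x, s)
        return y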
100 changes: 59 additions & 41 deletions python/tvm/relax/utils.py
@@ -14,13 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name,too-many-locals

"""Utility functions for Relax"""

import functools
import inspect
import itertools

from typing import Tuple as typing_Tuple
from typing import Any, Callable, List, Dict, Optional, TypeVar

import tvm
from .. import tir
from ..tir import PrimExpr
from ..runtime import String, convert_to_object
@@ -302,9 +308,22 @@ def gen_call_tir_inputs(
out_sinfo, and tir_vars.
"""

def _convert_te_arg(
te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr]
) -> typing_Tuple[Any, List[te_Tensor]]:
tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}

call_tir_args = []
# extra list of tir expression arguments
# that are not covered by Tensor
extra_tir_args_list = []

def _copy_undefined_var(expr: tir.PrimExpr):
def _visit_expr(e: tir.PrimExpr):
if isinstance(e, tir.Var) and e not in tir_var_map:
new_var = tir.Var(e.name, e.dtype)
tir_var_map[e] = new_var

tir.stmt_functor.post_order_visit(expr, _visit_expr)

def _convert_te_arg(te_args: Any) -> Any:
"""Helper function used to convert Relax expressions to TE tensor.

In the common case, the type of te_args is a Relax expression and is converted
Expand Down Expand Up @@ -335,18 +354,6 @@ def _convert_te_arg(
A tuple of the converted te_args, and a list of te tensors for each converted
Relax expression
"""
te_args_list = []
# extra list of tir expression arguments
# that are not covered by Tensor
extra_tir_args_list = []

def _copy_undefined_var(expr: tir.PrimExpr):
def _visit_expr(e: tir.PrimExpr):
if isinstance(e, tir.Var) and e not in tir_var_map:
new_var = tir.Var(e.name, e.dtype)
tir_var_map[e] = new_var

tir.stmt_functor.post_order_visit(expr, _visit_expr)

n_tensor = 0

@@ -363,18 +370,21 @@ def _convert_te_arg_helper(arg):
name = chr(ord("A") + n_tensor) if n_tensor < 26 else f"input{n_tensor}"
arg = te_tensor(arg, tir_var_map, name)
n_tensor += 1
te_args_list.append(arg)
call_tir_args.append(arg)
return arg
if isinstance(arg.struct_info, ShapeStructInfo):
assert isinstance(
arg, ShapeExpr
), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
return [_convert_te_arg_helper(val) for val in arg.values]
if (
isinstance(arg.struct_info, PrimStructInfo)
and arg.struct_info.value is not None
):
return _convert_te_arg_helper(arg.struct_info.value)
if isinstance(arg.struct_info, PrimStructInfo):
if arg.struct_info.value is None:
name = arg.name_hint if isinstance(arg, tvm.relax.Var) else "prim_arg"
call_tir_args.append(arg)
return tir.Var(name, arg.struct_info.dtype)
else:
return _convert_te_arg_helper(arg.struct_info.value)

elif isinstance(arg, (list, Array)):
return [_convert_te_arg_helper(x) for x in arg]
elif isinstance(arg, tuple):
@@ -395,28 +405,36 @@ def _convert_te_arg_helper(arg):
raise TypeError("not supported type in emit_te: {}".format(type(arg)))

new_arg = _convert_te_arg_helper(te_args)
return new_arg, te_args_list, extra_tir_args_list
return new_arg

def _get_unbound_tir_vars(
args: List[te_Tensor], extra_tir_args: List[PrimExpr]
) -> List[tir.Var]:
"""get unbound TIR vars (i.e TIR vars used in the shape but is not
itself a dimension of a shape)"""

bound_vars = set()
used_vars = set()

def _populate_bound_vars(expr):
if isinstance(expr, te_Tensor):
for dim in expr.shape:
_populate_bound_vars(dim)
elif isinstance(expr, tir.Var):
bound_vars.add(expr)

def _populate_used_vars(expr):
if isinstance(expr, tir.Var):
used_vars.add(expr)
if isinstance(expr, te_Tensor):
for dim in expr.shape:
_populate_used_vars(dim)
elif isinstance(expr, tir.PrimExpr):
used_vars.update(tir.analysis.undefined_vars(expr))

for val in extra_tir_args:
tir.stmt_functor.post_order_visit(val, _populate_used_vars)
for arg in itertools.chain(args, extra_tir_args):
_populate_used_vars(arg)

for x in args:
for s in x.shape:
tir.stmt_functor.post_order_visit(s, _populate_used_vars)
if isinstance(s, tir.Var):
bound_vars.add(s)
for arg in args:
_populate_bound_vars(arg)

diff = used_vars - bound_vars
return list(diff)
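To make the docstring above concrete, here is a hedged, illustrative snippet (names invented, not from this PR): a var that appears only inside a shape expression, never as a whole dimension, is what this helper reports as unbound.

from tvm import te, tir

n = tir.Var("n", "int64")
A = te.placeholder((n * 2,), "float32", name="A")
B = te.compute((n * 2,), lambda i: A[i] + 1.0, name="B")
# The shapes of A and B *use* n (inside n * 2), but no dimension is n itself,
# so n is reported as unbound and must be passed to the PrimFunc explicitly.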
@@ -448,29 +466,29 @@ def _shape_with_old_tir_var(

primfunc_attrs = kwargs.pop("primfunc_attrs", None)

tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
new_args, te_arg_list, tir_arg_list = _convert_te_arg(args, tir_var_map)
new_kwargs, te_kwarg_list, tir_kwarg_list = _convert_te_arg(kwargs, tir_var_map)

te_args = te_arg_list + te_kwarg_list
te_args = _convert_te_arg(args)
te_kwargs = _convert_te_arg(kwargs)

te_out = func(*new_args, **new_kwargs)
te_out = func(*te_args, **te_kwargs)
assert isinstance(te_out, te_Tensor) or (
isinstance(te_out, (tuple, list, Array)) and all(isinstance(t, te_Tensor) for t in te_out)
), "only support te.tensor or tuple/list/Array of te.tensor as function output"

outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list + tir_kwarg_list)
unbound_tir_vars = _get_unbound_tir_vars(
[*call_tir_args, *outs],
extra_tir_args_list,
)

inputs = [*te_args] + outs + unbound_tir_vars
tir_func = create_prim_func(inputs, "int64")
prim_func_args = [*call_tir_args, *outs, *unbound_tir_vars]
tir_func = create_prim_func(prim_func_args, "int64")

if primfunc_attrs:
tir_func = tir_func.with_attrs(primfunc_attrs)

tir_func = tir_func.without_attr("global_symbol")

call_tir_args = [x.op.value for x in te_args]
call_tir_args = [arg.op.value if isinstance(arg, te_Tensor) else arg for arg in call_tir_args]

# Invert the TIR variable mapping, to convert the output shape back
# with old set of variables.
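As an illustration of the utils.py changes above, the following hedged sketch (not taken from the PR's tests; the operator and names are assumptions) hands emit_te a Relax var whose PrimStructInfo has no known value. With this change, such an argument is converted into a fresh tir.Var for the TE computation and forwarded as an extra call_tir argument instead of being rejected.

from tvm import relax, topi

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((16,), "float32"))
s = relax.Var("s", relax.PrimStructInfo("float32"))  # value unknown at compile time

with bb.function("main", [x, s]):
    # `s` becomes a tir.Var inside the generated PrimFunc and is passed
    # through call_tir alongside the tensor argument.
    y = bb.emit_te(topi.add, x, s)
    bb.emit_func_output(y)

mod = bb.get()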
81 changes: 60 additions & 21 deletions src/relax/op/op_common.h
@@ -239,52 +239,91 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
const Map<String, Array<String>>& desired_layouts,
const VarLayoutMap& var_layout_map);

/*!
* \brief Get the element dtype from StructInfo
*
* \param sinfo The StructInfo to inspect
* \return The inferred element dtype.
* \throw Throw exception if the StructInfo doesn't have an element type.
*/
inline DataType GetElementDType(const StructInfo& sinfo) {
if (const auto* prim = sinfo.as<PrimStructInfoNode>()) {
return prim->dtype;
} else if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
return tensor->dtype;
} else if (sinfo.as<ObjectStructInfoNode>()) {
return DataType::Void();
Contributor:
Would this necessarily be expected behavior? An Object could be anything, including things that dtype does not make sense for at all.

Contributor Author:
Yeah, I went back and forth on it. There isn't currently a standard for whether FInferStructInfo should raise an error when the arguments are provably invalid, or if it should raise an error when the arguments are not provably valid. On the one hand, StructInfoLCA returns ObjectStructInfo as the common base class of TensorStructInfo and PrimStructInfo, so an ObjectStructInfo could contain a valid instance of either. On the other hand, the current struct inference requires that the input be validated as TensorStructInfo.

Overall, I'm not sure which is the better behavior. For now, I'm updating this PR to explicitly require either TensorStructInfo or PrimStructInfo, and to raise an exception for ObjectStructInfo, since allowing ObjectStructInfo would be an independent change.

Contributor:
Personally, I'm in favor of asking for a MatchCast if we can't draw a conclusion. Down the line, inserting MatchCasts via normalization rules would be a good policy.

Contributor Author:
True. Currently, FInferStructInfo is called prior to FNormalize, so inference could be inspecting an expression that hasn't yet been normalized. This was useful for providing FNormalize for R.Prim (if PrimStructInfo contains a known value, in-line that value), but I'm wondering if we should re-visit that.

} else {
LOG(FATAL) << "TypeError: "
Member:
Originally our error message would ask for TensorStructInfo. In this particular case, would this error message be less informative than before? Given this is a global change across all binary ops, it would be good to cross-check the usages here and make the error more informative.

Contributor Author:
Good point, and this no longer tells the user which operation it was. Updated.

<< "Cannot determine element type of " << sinfo;
}
}
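For readers of the Python API, a hedged paraphrase of the helper above, for illustration only; the actual implementation is the C++ shown here. Note that, per the review thread, the ObjectStructInfo branch was subsequently changed to raise an error.

from tvm import relax


def element_dtype(sinfo: relax.StructInfo):
    """Return the element dtype of a Prim/Tensor StructInfo, or None if unknown."""
    if isinstance(sinfo, (relax.PrimStructInfo, relax.TensorStructInfo)):
        return sinfo.dtype  # may itself be an unknown ("void") dtype
    if isinstance(sinfo, relax.ObjectStructInfo):
        # Mirrors the DataType::Void() branch above; later revisions raise here.
        return None
    raise TypeError(f"Cannot determine element type of {sinfo}")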

/*!
* \brief Infer the output datatype for binary arithmetic operators.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
* \param x1_sinfo The struct info of the first operand
* \param x2_sinfo The struct info of the second operand
* \param lhs_sinfo The struct info of the first operand
* \param rhs_sinfo The struct info of the second operand
* \return The inferred output dtype.
* \throw Throw exception if the dtypes of the two operands don't match
*/
inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx,
const TensorStructInfo& x1_sinfo,
const TensorStructInfo& x2_sinfo) {
if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) {
const StructInfo& lhs_sinfo,
const StructInfo& rhs_sinfo) {
auto lhs_dtype = GetElementDType(lhs_sinfo);
auto rhs_dtype = GetElementDType(rhs_sinfo);
if (lhs_dtype.is_void() || rhs_dtype.is_void()) {
return DataType::Void();
} else if (x1_sinfo->dtype != x2_sinfo->dtype) {
} else if (lhs_dtype != rhs_dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype
<< " must be equal for binary operators");
<< "TypeErorr: "
<< "Binary operators must have the same datatype for both operands. "
<< "However, " << call << " uses datatype " << lhs_dtype
<< " on the LHS (StructInfo of " << lhs_sinfo << "), and datatype "
<< rhs_dtype << " on the RHS (StructInfo of " << rhs_sinfo << ").");
}
return x1_sinfo->dtype;
return lhs_dtype;
}
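A hedged paraphrase of the dtype rule implemented above, for illustration only; plain strings stand in for DataType, with "" meaning an unknown dtype.

def binary_out_dtype(lhs_dtype: str, rhs_dtype: str) -> str:
    """Output dtype rule for binary operators; "" stands for an unknown dtype."""
    if lhs_dtype == "" or rhs_dtype == "":
        return ""  # an unknown operand dtype makes the result dtype unknown
    if lhs_dtype != rhs_dtype:
        raise TypeError(
            f"Binary operators require matching operand dtypes, "
            f"got {lhs_dtype} and {rhs_dtype}"
        )
    return lhs_dtype


assert binary_out_dtype("float32", "float32") == "float32"
assert binary_out_dtype("float32", "") == ""  # e.g. a PrimValue of unknown dtype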

/*!
* \brief Infer the output virtual device for binary arithmetic operators.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
* \param x1_sinfo The struct info of the first operand
* \param x2_sinfo The struct info of the second operand
* \param lhs_sinfo The struct info of the first operand
* \param rhs_sinfo The struct info of the second operand
* \return The inferred output vdevice.
* \throw Throw exception if the vdevices of the two operands don't match
*/
inline Optional<VDevice> InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx,
const TensorStructInfo& x1_sinfo,
const TensorStructInfo& x2_sinfo) {
if (!x1_sinfo->vdevice.defined() || !x1_sinfo->vdevice.value()->target.defined()) {
return x2_sinfo->vdevice;
const StructInfo& lhs_sinfo,
const StructInfo& rhs_sinfo) {
auto get_vdevice = [&](const StructInfo& sinfo) -> Optional<VDevice> {
if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
return tensor->vdevice;
} else {
return NullOpt;
}
};

auto lhs_vdevice = get_vdevice(lhs_sinfo);
auto rhs_vdevice = get_vdevice(rhs_sinfo);

if (!lhs_vdevice.defined() || !lhs_vdevice.value()->target.defined()) {
return rhs_vdevice;
}
if (!x2_sinfo->vdevice.defined() || !x2_sinfo->vdevice.value()->target.defined()) {
return x1_sinfo->vdevice;
if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) {
return lhs_vdevice;
}
if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) {
if (lhs_vdevice.value() != rhs_vdevice.value()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "VDevice " << x1_sinfo->vdevice.value() << " and "
<< x2_sinfo->vdevice.value() << " must be equal for binary operators");
<< "TypeErorr: "
<< "Binary operators with Tensor arguments "
<< "must have the same VDevice for both operands. "
<< "However, " << call << " has a LHS on VDevice " << lhs_vdevice
<< " and a RHS on VDevice " << rhs_vdevice);
}
return x1_sinfo->vdevice;
return lhs_vdevice;
}
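And a hedged paraphrase of the VDevice rule above, again illustration only; strings stand in for VDevice, with None for an operand that carries no vdevice, such as a PrimValue.

from typing import Optional


def binary_out_vdevice(lhs: Optional[str], rhs: Optional[str]) -> Optional[str]:
    """Output VDevice rule for binary operators."""
    if lhs is None:
        return rhs
    if rhs is None:
        return lhs
    if lhs != rhs:
        raise ValueError(f"Operands placed on different VDevices: {lhs} vs {rhs}")
    return lhs


assert binary_out_vdevice("cuda:0", None) == "cuda:0"  # Tensor op PrimValue
assert binary_out_vdevice(None, None) is None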

/*!