From 687950e6010b5e632c8e449c6230b8cddd7010f5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 31 Mar 2024 20:01:19 -0500 Subject: [PATCH 1/9] [Relax] Handle binary operations between Tensor and PrimValue Prior to this commit, binary operations were only defined between two tensors. This commit allows binary operations to apply between a tensor and a `relax::PrimValue`. When inferring the output `StructInfo`, binary operations with a `PrimValue` produce the same output as using a 0-d tensor. When legalizing operations containing a `PrimValue`, they are lowered to primitive TIR arguments. --- python/tvm/relax/utils.py | 97 ++-- src/relax/op/op_common.h | 81 ++- src/relax/op/tensor/binary.cc | 103 +++- src/script/printer/relax/tir.cc | 3 - src/te/operation/create_primfunc.cc | 11 +- tests/python/relax/test_op_binary.py | 110 +++- .../test_transform_legalize_ops_binary.py | 534 +++++++++++++++++- 7 files changed, 820 insertions(+), 119 deletions(-) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index a58b65477cee..0489adcde178 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -14,13 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + # pylint: disable=invalid-name,too-many-locals + """Utility functions for Relax""" + import functools import inspect +import itertools + from typing import Tuple as typing_Tuple from typing import Any, Callable, List, Dict, Optional, TypeVar +import tvm from .. import tir from ..tir import PrimExpr from ..runtime import String, convert_to_object @@ -302,9 +308,22 @@ def gen_call_tir_inputs( out_sinfo, and tir_vars. 
""" - def _convert_te_arg( - te_args: Any, tir_var_map: Dict[tir.Var, tir.PrimExpr] - ) -> typing_Tuple[Any, List[te_Tensor]]: + tir_var_map: Dict[tir.Var, tir.PrimExpr] = {} + + call_tir_args = [] + # extra list of tir expression arguments + # that are not covered by Tensor + extra_tir_args_list = [] + + def _copy_undefined_var(expr: tir.PrimExpr): + def _visit_expr(e: tir.PrimExpr): + if isinstance(e, tir.Var) and e not in tir_var_map: + new_var = tir.Var(e.name, e.dtype) + tir_var_map[e] = new_var + + tir.stmt_functor.post_order_visit(expr, _visit_expr) + + def _convert_te_arg(te_args: Any) -> Any: """Helper function used to convert Relax expressions to TE tensor. In the common case, the type of te_args is a Relax expression and is converted @@ -335,18 +354,6 @@ def _convert_te_arg( A tuple of the converted te_args, and a list of te tensors for each converted Relax expression """ - te_args_list = [] - # extra list of tir expression arguments - # that are not covered by Tensor - extra_tir_args_list = [] - - def _copy_undefined_var(expr: tir.PrimExpr): - def _visit_expr(e: tir.PrimExpr): - if isinstance(e, tir.Var) and e not in tir_var_map: - new_var = tir.Var(e.name, e.dtype) - tir_var_map[e] = new_var - - tir.stmt_functor.post_order_visit(expr, _visit_expr) n_tensor = 0 @@ -363,18 +370,23 @@ def _convert_te_arg_helper(arg): name = chr(ord("A") + n_tensor) if n_tensor < 26 else f"input{n_tensor}" arg = te_tensor(arg, tir_var_map, name) n_tensor += 1 - te_args_list.append(arg) + call_tir_args.append(arg) return arg if isinstance(arg.struct_info, ShapeStructInfo): assert isinstance( arg, ShapeExpr ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" return [_convert_te_arg_helper(val) for val in arg.values] - if ( - isinstance(arg.struct_info, PrimStructInfo) - and arg.struct_info.value is not None - ): - return _convert_te_arg_helper(arg.struct_info.value) + if isinstance(arg.struct_info, PrimStructInfo): + if arg.struct_info.value is None: + 
name = arg.name_hint if isinstance(arg, tvm.relax.Var) else "prim_arg" + call_tir_args.append(arg) + return tir.Var(name, arg.struct_info.dtype) + # call_tir_args.append(tir.Var(name, arg.struct_info.dtype)) + # return arg + else: + return _convert_te_arg_helper(arg.struct_info.value) + elif isinstance(arg, (list, Array)): return [_convert_te_arg_helper(x) for x in arg] elif isinstance(arg, tuple): @@ -388,35 +400,43 @@ def _convert_te_arg_helper(arg): elif isinstance(arg, tir.PrimExpr): _copy_undefined_var(arg) new_arg = tir.stmt_functor.substitute(arg, tir_var_map) - extra_tir_args_list.append(new_arg) + extra_tir_args_list.append(arg) return new_arg elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is None: return arg raise TypeError("not supported type in emit_te: {}".format(type(arg))) new_arg = _convert_te_arg_helper(te_args) - return new_arg, te_args_list, extra_tir_args_list + return new_arg def _get_unbound_tir_vars( args: List[te_Tensor], extra_tir_args: List[PrimExpr] ) -> List[tir.Var]: """get unbound TIR vars (i.e TIR vars used in the shape but is not itself a dimension of a shape)""" + bound_vars = set() used_vars = set() + def _populate_bound_vars(expr): + if isinstance(expr, te_Tensor): + for dim in expr.shape: + _populate_bound_vars(dim) + elif isinstance(expr, tir.Var): + bound_vars.add(expr) + def _populate_used_vars(expr): - if isinstance(expr, tir.Var): - used_vars.add(expr) + if isinstance(expr, te_Tensor): + for dim in expr.shape: + _populate_used_vars(dim) + elif isinstance(expr, tir.PrimExpr): + used_vars.update(tir.analysis.undefined_vars(expr)) - for val in extra_tir_args: - tir.stmt_functor.post_order_visit(val, _populate_used_vars) + for arg in itertools.chain(args, extra_tir_args): + _populate_used_vars(arg) - for x in args: - for s in x.shape: - tir.stmt_functor.post_order_visit(s, _populate_used_vars) - if isinstance(s, tir.Var): - bound_vars.add(s) + for arg in args: + _populate_bound_vars(arg) diff = used_vars - 
bound_vars return list(diff) @@ -448,19 +468,16 @@ def _shape_with_old_tir_var( primfunc_attrs = kwargs.pop("primfunc_attrs", None) - tir_var_map: Dict[tir.Var, tir.PrimExpr] = {} - new_args, te_arg_list, tir_arg_list = _convert_te_arg(args, tir_var_map) - new_kwargs, te_kwarg_list, tir_kwarg_list = _convert_te_arg(kwargs, tir_var_map) - - te_args = te_arg_list + te_kwarg_list + te_args = _convert_te_arg(args) + te_kwargs = _convert_te_arg(kwargs) - te_out = func(*new_args, **new_kwargs) + te_out = func(*te_args, **te_kwargs) assert isinstance(te_out, te_Tensor) or ( isinstance(te_out, (tuple, list, Array)) and all(isinstance(t, te_Tensor) for t in te_out) ), "only support te.tensor or tuple/list/Array of te.tensor as function output" outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out) - unbound_tir_vars = _get_unbound_tir_vars(te_args + outs, tir_arg_list + tir_kwarg_list) + unbound_tir_vars = _get_unbound_tir_vars([*te_args, *outs], extra_tir_args_list) inputs = [*te_args] + outs + unbound_tir_vars tir_func = create_prim_func(inputs, "int64") @@ -470,7 +487,7 @@ def _shape_with_old_tir_var( tir_func = tir_func.without_attr("global_symbol") - call_tir_args = [x.op.value for x in te_args] + call_tir_args = [arg.op.value if isinstance(arg, te_Tensor) else arg for arg in call_tir_args] # Invert the TIR variable mapping, to convert the output shape back # with old set of variables. diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index f5eed7af0698..354bf773a9ac 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -239,52 +239,91 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map); +/*! + * \brief Get the element dtype from StructInfo + * + * \param sinfo The StructInfo to expect + * \return The inferred element dtype. + * \throw Throw exception if the StructInfo doesn't have an element type. 
+ */ +inline DataType GetElementDType(const StructInfo& sinfo) { + if (const auto* prim = sinfo.as()) { + return prim->dtype; + } else if (const auto* tensor = sinfo.as()) { + return tensor->dtype; + } else if (sinfo.as()) { + return DataType::Void(); + } else { + LOG(FATAL) << "TypeError: " + << "Cannot determine element type of " << sinfo; + } +} + /*! * \brief Infer the output datatype for binary arithmetic operators. * \param call The context Call to the operator. * \param ctx The error reporting context. - * \param x1_sinfo The struct info of the first operand - * \param x2_sinfo The struct info of the second operand + * \param lhs_sinfo The struct info of the first operand + * \param rhs_sinfo The struct info of the second operand * \return The inferred output dtype. * \throw Throw exception if the dtype of two input TensorStructInfo don’t match */ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, - const TensorStructInfo& x1_sinfo, - const TensorStructInfo& x2_sinfo) { - if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) { + const StructInfo& lhs_sinfo, + const StructInfo& rhs_sinfo) { + auto lhs_dtype = GetElementDType(lhs_sinfo); + auto rhs_dtype = GetElementDType(rhs_sinfo); + if (lhs_dtype.is_void() || rhs_dtype.is_void()) { return DataType::Void(); - } else if (x1_sinfo->dtype != x2_sinfo->dtype) { + } else if (lhs_dtype != rhs_dtype) { ctx->ReportFatal(Diagnostic::Error(call) - << "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype - << " must be equal for binary operators"); + << "TypeError: " + << "Binary operators must have the same datatype for both operands. " + << "However, " << call << " uses datatype " << lhs_dtype + << " on the LHS (StructInfo of " << lhs_sinfo << "), and datatype " + << rhs_dtype << " on the RHS (StructInfo of " << rhs_sinfo << ")."); } - return x1_sinfo->dtype; + return lhs_dtype; } /*! * \brief Infer the output virtual device for binary arithmetic operators.
* \param call The context Call to the operator. * \param ctx The error reporting context. - * \param x1_sinfo The struct info of the first operand - * \param x2_sinfo The struct info of the second operand + * \param lhs_sinfo The struct info of the first operand + * \param rhs_sinfo The struct info of the second operand * \return The inferred output vdevice. * \throw Throw exception if the vdevice of two input TensorStructInfo don’t match */ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const BlockBuilder& ctx, - const TensorStructInfo& x1_sinfo, - const TensorStructInfo& x2_sinfo) { - if (!x1_sinfo->vdevice.defined() || !x1_sinfo->vdevice.value()->target.defined()) { - return x2_sinfo->vdevice; + const StructInfo& lhs_sinfo, + const StructInfo& rhs_sinfo) { + auto get_vdevice = [&](const StructInfo& sinfo) -> Optional { + if (const auto* tensor = sinfo.as()) { + return tensor->vdevice; + } else { + return NullOpt; + } + }; + + auto lhs_vdevice = get_vdevice(lhs_sinfo); + auto rhs_vdevice = get_vdevice(rhs_sinfo); + + if (!lhs_vdevice.defined() || !lhs_vdevice.value()->target.defined()) { + return rhs_vdevice; } - if (!x2_sinfo->vdevice.defined() || !x2_sinfo->vdevice.value()->target.defined()) { - return x1_sinfo->vdevice; + if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) { + return lhs_vdevice; } - if (x1_sinfo->vdevice.value() != x2_sinfo->vdevice.value()) { + if (lhs_vdevice.value() != rhs_vdevice.value()) { ctx->ReportFatal(Diagnostic::Error(call) - << "VDevice " << x1_sinfo->vdevice.value() << " and " - << x2_sinfo->vdevice.value() << " must be equal for binary operators"); + << "TypeError: " + << "Binary operators with Tensor arguments " + << "must have the same VDevice for both operands. " + << "However, " << call << " has a LHS on VDevice " << lhs_vdevice + << " and a RHS on VDevice " << rhs_vdevice); } - return x1_sinfo->vdevice; + return lhs_vdevice; } /*! 
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index f1427156e0da..1c167367a826 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -32,43 +32,94 @@ namespace relax { template StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo x1_sinfo = input_sinfo[0]; - TensorStructInfo x2_sinfo = input_sinfo[1]; + Op op = Downcast(call->op); + size_t n_input = op->arguments.size(); + if (call->args.size() != n_input) { + ctx->ReportFatal(Diagnostic::Error(call) + << call->op << " op should have " << n_input << " arguments"); + } + + auto lhs_sinfo = GetStructInfo(call->args[0]); + auto rhs_sinfo = GetStructInfo(call->args[1]); // DateType - DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo); + DataType output_dtype = f_compute_out_dtype(call, ctx, lhs_sinfo, rhs_sinfo); + + if (lhs_sinfo.as() && rhs_sinfo.as()) { + return PrimStructInfo(output_dtype); + } else if (lhs_sinfo.as() && rhs_sinfo.as()) { + return ObjectStructInfo(); + } // VDevice - Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, x1_sinfo, x2_sinfo); + Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, lhs_sinfo, rhs_sinfo); + + auto get_ndim = [&](const StructInfo& sinfo) -> int { + if (sinfo.as()) { + return 1; + } else if (const auto* tensor = sinfo.as()) { + return tensor->ndim; + } else { + return kUnknownNDim; + } + }; // ndims - int output_ndim; - if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { - output_ndim = kUnknownNDim; - } else { - output_ndim = std::max(x1_sinfo->ndim, x2_sinfo->ndim); - } + int output_ndim = [&]() { + int lhs_ndim = get_ndim(lhs_sinfo); + int rhs_ndim = get_ndim(rhs_sinfo); + if (lhs_ndim == kUnknownNDim || rhs_ndim == kUnknownNDim) { + return kUnknownNDim; + } else { + return std::max(lhs_ndim, rhs_ndim); + } + }(); - const auto* x1_shape = 
x1_sinfo->shape.as(); - const auto* x2_shape = x2_sinfo->shape.as(); - // Shapes and ndims - if (x1_shape && x2_shape) { - // If all inputs have shapes, directly infer shapes - Optional> output_shape = - InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); - if (!output_shape.defined()) { - return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice); + // Shapes + auto get_shape = [](const StructInfo& sinfo) -> Optional> { + if (sinfo.as()) { + return Array{IntImm(DataType::Int(64), 1)}; + } else if (const auto* tensor = sinfo.as()) { + return tensor->GetShape(); } else { + return NullOpt; + } + }; + + // If both inputs have a known shape, directly infer the shape of + // the output. + auto lhs_shape = get_shape(lhs_sinfo); + auto rhs_shape = get_shape(rhs_sinfo); + if (lhs_shape && rhs_shape) { + Optional> output_shape = + InferBinaryBroadcastShape(call, ctx, lhs_shape.value(), rhs_shape.value()); + if (output_shape.defined()) { ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdevice); } - } else if (x1_sinfo->shape.defined() && x1_sinfo->shape.same_as(x2_sinfo->shape)) { - return TensorStructInfo(x1_sinfo->shape.value(), output_dtype, vdevice); - } else { - return TensorStructInfo(output_dtype, /*ndim=*/output_ndim, vdevice); } + + auto get_shape_expr = [](const StructInfo& sinfo) -> Optional { + if (const auto* tensor = sinfo.as()) { + return tensor->shape; + } else { + return NullOpt; + } + }; + + // If the input shape is unknown, but both inputs have the same + // `ShapeStructInfo` variable for their shape, then propagate that + // variable to the output. 
+ auto lhs_shape_expr = get_shape_expr(lhs_sinfo); + auto rhs_shape_expr = get_shape_expr(rhs_sinfo); + if (lhs_shape_expr.defined() && lhs_shape_expr.same_as(rhs_shape_expr)) { + return TensorStructInfo(lhs_shape_expr.value(), output_dtype, vdevice); + } + + // If neither of those cases holds, then fall back to an unknown + // shape with `output_ndim` dimensionality. + return TensorStructInfo(output_dtype, output_ndim, vdevice); } StructInfo InferStructInfoBroadcastArith(const Call& call, const BlockBuilder& ctx) { @@ -78,8 +129,8 @@ StructInfo InferStructInfoBroadcastArith(const Call& call, const BlockBuilder& c StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx) { return InferStructInfoBroadcast( call, ctx, - [](const Call& call, const BlockBuilder& ctx, const TensorStructInfo& x1_sinfo, - const TensorStructInfo& x2_sinfo) { return DataType::Bool(); }); + [](const Call& call, const BlockBuilder& ctx, const StructInfo& lhs_sinfo, + const StructInfo& rhs_sinfo) { return DataType::Bool(); }); } InferLayoutOutput InferLayoutBinaryEwise(const Call& call, diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 7c7752cfe65d..9546ce536523 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -41,9 +41,6 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { } Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { - ICHECK(n->dtype.is_int() && n->dtype.is_scalar()) << "TypeError: Relax only uses " - "scalar integer TIR variables, but gets: " - << n; if (!d->IsVarDefined(n)) { RelaxFrameNode* f = GetRelaxFrame(d); // There should be at least one Relax frame diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 0dc8b3870104..4581d536b9c2 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -581,17 +581,16 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_tir_var_list, const Array& root_stmts, 
CreateFuncInfo* info) { Array parameters; Map buffer_map; - for (const ObjectRef& x : arg_tir_var_list) { - if (auto n = x.as()) { - te::Tensor tensor = GetRef(n); + for (const ObjectRef& arg : arg_tir_var_list) { + if (auto opt_tensor = arg.as()) { + te::Tensor tensor = opt_tensor.value(); Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); parameters.push_back(arg); auto it = info->tensor2buffers.find(tensor); ICHECK(it != info->tensor2buffers.end()); buffer_map.Set(arg, it->second); - } else if (auto n = x.as()) { - tir::Var var = GetRef(n); - parameters.push_back(var); + } else if (auto var = arg.as()) { + parameters.push_back(var.value()); } } PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters), diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index a0ec08f0aba1..bac91a89942b 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -59,15 +59,15 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) -(binary_arith_op,) = tvm.testing.parameters( - (relax.op.add,), - (relax.op.divide,), - (relax.op.floor_divide,), - (relax.op.multiply,), - (relax.op.power,), - (relax.op.subtract,), - (relax.op.maximum,), - (relax.op.minimum,), +(binary_arith_op, tir_arith_op) = tvm.testing.parameters( + (relax.op.add, tir.Add), + (relax.op.divide, tir.Div), + (relax.op.floor_divide, tir.FloorDiv), + (relax.op.multiply, tir.Mul), + (relax.op.power, tir.pow), + (relax.op.subtract, tir.Sub), + (relax.op.maximum, tir.Max), + (relax.op.minimum, tir.Min), ) @@ -115,13 +115,47 @@ def test_binary_arith_infer_struct_info(binary_arith_op: Callable): ) -(binary_cmp_op,) = tvm.testing.parameters( - (relax.op.equal,), - (relax.op.greater,), - (relax.op.greater_equal,), - (relax.op.less,), - (relax.op.less_equal,), - (relax.op.not_equal,), +def 
test_infer_struct_info_binary_arith_prim_value_with_tensor(binary_arith_op: Callable): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Prim("float32")) + + _check_inference(bb, binary_arith_op(x, y), relax.TensorStructInfo((2, 3), "float32")) + + +def test_infer_struct_info_binary_arith_prim_value_with_prim_value(binary_arith_op: Callable): + bb = relax.BlockBuilder() + + x = relax.Var("x", R.Prim("float32")) + y = relax.Var("y", R.Prim("float32")) + + _check_inference(bb, binary_arith_op(x, y), relax.PrimStructInfo("float32")) + + +@pytest.mark.xfail(reason="Not yet implemented") +def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value( + binary_arith_op: Callable, tir_arith_op +): + bb = relax.BlockBuilder() + + tir_x = tir.Var("tir_x", "float32") + tir_y = tir.Var("tir_y", "float32") + + x = relax.Var("x", R.Prim(value=tir_x)) + y = relax.Var("y", R.Prim(value=tir_y)) + + _check_inference(bb, binary_arith_op(x, y), relax.PrimStructInfo(value=tir_x + tir_y)) + _check_inference(bb, binary_arith_op(y, x), relax.PrimStructInfo(value=tir_y + tir_x)) + + +(binary_cmp_op, tir_cmp_op) = tvm.testing.parameters( + (relax.op.equal, tir.EQ), + (relax.op.greater, tir.GT), + (relax.op.greater_equal, tir.GE), + (relax.op.less, tir.LT), + (relax.op.less_equal, tir.LE), + (relax.op.not_equal, tir.NE), ) @@ -141,6 +175,38 @@ def test_binary_cmp_infer_struct_info(binary_cmp_op: Callable): _check_inference(bb, binary_cmp_op(x, y2), relax.TensorStructInfo((2, 3), "bool", vdev0)) +def test_infer_struct_info_binary_cmp_prim_value_to_tensor(binary_cmp_op: Callable): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Prim("float32")) + _check_inference(bb, binary_cmp_op(x, y), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(y, x), relax.TensorStructInfo((2, 3), "bool")) + + +def 
test_infer_struct_info_binary_cmp_prim_value_to_prim_value(binary_cmp_op: Callable): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Prim("float32")) + y = relax.Var("y", R.Prim("float32")) + _check_inference(bb, binary_cmp_op(x, y), relax.PrimStructInfo("bool")) + _check_inference(bb, binary_cmp_op(y, x), relax.PrimStructInfo("bool")) + + +@pytest.mark.xfail(reason="Not yet implemented") +def test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value( + binary_cmp_op: Callable, tir_cmp_op +): + bb = relax.BlockBuilder() + + tir_x = tir.Var("tir_x", "float32") + tir_y = tir.Var("tir_y", "float32") + + x = relax.Var("x", R.Prim(value=tir_x)) + y = relax.Var("y", R.Prim(value=tir_y)) + + _check_inference(bb, binary_cmp_op(x, y), relax.PrimStructInfo(value=tir_cmp_op(tir_x, tir_y))) + _check_inference(bb, binary_cmp_op(y, x), relax.PrimStructInfo(value=tir_cmp_op(tir_y, tir_x))) + + def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable): bb = relax.BlockBuilder() m = tir.Var("m", "int64") @@ -184,10 +250,10 @@ def test_binary_infer_struct_info_shape_var(binary_arith_op: Callable): y4 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) _check_inference(bb, binary_arith_op(x, y0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, binary_arith_op(x, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) - _check_inference(bb, binary_arith_op(x, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, binary_arith_op(x, y4), relax.TensorStructInfo(dtype="float32")) + # _check_inference(bb, binary_arith_op(x, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + # _check_inference(bb, binary_arith_op(x, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + # _check_inference(bb, binary_arith_op(x, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) + # _check_inference(bb, binary_arith_op(x, y4), 
relax.TensorStructInfo(dtype="float32")) def test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op: Callable): @@ -245,9 +311,9 @@ def test_binary_infer_struct_info_wrong_input_type(binary_arith_op: Callable): x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) y = relax.Var("y", R.Tensor((2, 3), "float32")) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(binary_arith_op(x0, y)) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(binary_arith_op(x1, y)) diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py index d71a248b2512..7b9405782433 100644 --- a/tests/python/relax/test_transform_legalize_ops_binary.py +++ b/tests/python/relax/test_transform_legalize_ops_binary.py @@ -17,7 +17,7 @@ import tvm from tvm.relax.transform import LegalizeOps -from tvm.script import relax as R, tir as T +from tvm.script import ir as I, relax as R, tir as T import tvm.testing @@ -164,6 +164,44 @@ def add(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_add: T tvm.ir.assert_structural_equal(mod, Expected) +def test_add_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.add(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.add, (x, y), R.Tensor([64, 32, 16], dtype="float32")) + return gv + + @T.prim_func(private=True) + def add( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = lhs[vi, vj, vk] + 
rhs + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_divide(): # fmt: off @tvm.script.ir_module @@ -303,6 +341,44 @@ def divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_div tvm.ir.assert_structural_equal(mod, Expected) +def test_divide_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.divide(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.divide, (x, y), R.Tensor([64, 32, 16], dtype="float32")) + return gv + + @T.prim_func(private=True) + def divide( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = lhs[vi, vj, vk] / rhs + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_floor_divide(): # fmt: off @tvm.script.ir_module @@ -442,6 +518,44 @@ def floor_divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var tvm.ir.assert_structural_equal(mod, Expected) +def test_floordiv_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.floor_divide(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.floor_divide, (x, y), R.Tensor([64, 32, 16], dtype="float32")) + return gv + + @T.prim_func(private=True) + def floor_divide( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + 
output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_floordiv"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = T.floor(lhs[vi, vj, vk] / rhs) + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_multiply(): # fmt: off @tvm.script.ir_module @@ -519,6 +633,44 @@ def multiply(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_m tvm.ir.assert_structural_equal(mod, Expected) +def test_multiply_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.multiply(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.multiply, (x, y), R.Tensor([64, 32, 16], dtype="float32")) + return gv + + @T.prim_func(private=True) + def multiply( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = lhs[vi, vj, vk] * rhs + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_power(): # fmt: off @tvm.script.ir_module @@ -599,6 +751,44 @@ def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y: R.Tensor(("a", "b", "c" tvm.ir.assert_structural_equal(mod, Expected) +def test_power_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.power(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: 
R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.power, (x, y), R.Tensor([64, 32, 16], dtype="float32")) + return gv + + @T.prim_func(private=True) + def power( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_power"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = T.pow(lhs[vi, vj, vk], rhs) + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_subtract(): # fmt: off @tvm.script.ir_module @@ -676,6 +866,44 @@ def subtract(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_s tvm.ir.assert_structural_equal(mod, Expected) +def test_subtract_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.subtract(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.subtract, (x, y), R.Tensor([64, 32, 16], dtype="float32")) + return gv + + @T.prim_func(private=True) + def subtract( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = lhs[vi, vj, vk] - rhs + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + ##################### Binary comparison ##################### @@ -818,6 +1046,44 @@ def equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_equa tvm.ir.assert_structural_equal(mod, Expected) +def test_equal_primvalue(): + 
@I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.equal(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.equal, (x, y), R.Tensor([64, 32, 16], dtype="bool")) + return gv + + @T.prim_func(private=True) + def equal( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = lhs[vi, vj, vk] == rhs + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_greater(): # fmt: off @tvm.script.ir_module @@ -957,6 +1223,44 @@ def greater(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_gr tvm.ir.assert_structural_equal(mod, Expected) +def test_greater_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.greater(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.greater, (x, y), R.Tensor([64, 32, 16], dtype="bool")) + return gv + + @T.prim_func(private=True) + def greater( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = rhs < lhs[vi, vj, vk] + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def 
test_greater_equal(): # fmt: off @tvm.script.ir_module @@ -1034,6 +1338,44 @@ def greater_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, va tvm.ir.assert_structural_equal(mod, Expected) +def test_greater_equal_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.greater_equal(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.greater_equal, (x, y), R.Tensor([64, 32, 16], dtype="bool")) + return gv + + @T.prim_func(private=True) + def greater_equal( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = rhs <= lhs[vi, vj, vk] + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_less(): # fmt: off @tvm.script.ir_module @@ -1111,6 +1453,44 @@ def less(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less: tvm.ir.assert_structural_equal(mod, Expected) +def test_less_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.less(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.less, (x, y), R.Tensor([64, 32, 16], dtype="bool")) + return gv + + @T.prim_func(private=True) + def less( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), + ): + T.func_attr({"tir.noalias": 
True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = lhs[vi, vj, vk] < rhs + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_less_equal(): # fmt: off @tvm.script.ir_module @@ -1250,6 +1630,44 @@ def less_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T tvm.ir.assert_structural_equal(mod, Expected) +def test_less_equal_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.less_equal(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.less_equal, (x, y), R.Tensor([64, 32, 16], dtype="bool")) + return gv + + @T.prim_func(private=True) + def less_equal( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = lhs[vi, vj, vk] <= rhs + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_not_equal(): # fmt: off @tvm.script.ir_module @@ -1327,6 +1745,44 @@ def not_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_ tvm.ir.assert_structural_equal(mod, Expected) +def test_not_equal_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.not_equal(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.not_equal, (x, y), R.Tensor([64, 32, 16], 
dtype="bool")) + return gv + + @T.prim_func(private=True) + def not_equal( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = lhs[vi, vj, vk] != rhs + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_maximum(): # fmt: off @tvm.script.ir_module @@ -1467,6 +1923,44 @@ def maximum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_ma tvm.ir.assert_structural_equal(mod, Expected) +def test_max_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv = R.maximum(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.maximum, (x, y), R.Tensor([64, 32, 16], dtype="float32")) + return gv + + @T.prim_func(private=True) + def maximum( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = T.max(lhs[vi, vj, vk], rhs) + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + def test_minimum(): # fmt: off @tvm.script.ir_module @@ -1607,5 +2101,43 @@ def minimum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_mi tvm.ir.assert_structural_equal(mod, Expected) +def test_min_primvalue(): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + gv 
= R.minimum(x, y) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([64, 32, 16], "float32"), + y: R.Prim("float32"), + ): + cls = Expected + gv = R.call_tir(cls.minimum, (x, y), R.Tensor([64, 32, 16], dtype="float32")) + return gv + + @T.prim_func(private=True) + def minimum( + lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + rhs: T.float32, + output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i, j, k in T.grid(*lhs.shape): + with T.block("T_add"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + output[vi, vj, vk] = T.min(lhs[vi, vj, vk], rhs) + + After = LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() From 4f487d5635e9944d3d18583cd48fcac584102162 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 1 Apr 2024 12:51:24 -0500 Subject: [PATCH 2/9] Fix unit tests --- python/tvm/relax/utils.py | 13 +++++++------ src/te/operation/create_primfunc.cc | 4 +++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 0489adcde178..1c6f3d2f425c 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -382,8 +382,6 @@ def _convert_te_arg_helper(arg): name = arg.name_hint if isinstance(arg, tvm.relax.Var) else "prim_arg" call_tir_args.append(arg) return tir.Var(name, arg.struct_info.dtype) - # call_tir_args.append(tir.Var(name, arg.struct_info.dtype)) - # return arg else: return _convert_te_arg_helper(arg.struct_info.value) @@ -400,7 +398,7 @@ def _convert_te_arg_helper(arg): elif isinstance(arg, tir.PrimExpr): _copy_undefined_var(arg) new_arg = tir.stmt_functor.substitute(arg, tir_var_map) - extra_tir_args_list.append(arg) + extra_tir_args_list.append(new_arg) return new_arg elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is None: return arg @@ -477,10 +475,13 @@ def 
_shape_with_old_tir_var( ), "only support te.tensor or tuple/list/Array of te.tensor as function output" outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out) - unbound_tir_vars = _get_unbound_tir_vars([*te_args, *outs], extra_tir_args_list) + unbound_tir_vars = _get_unbound_tir_vars( + [*call_tir_args, *outs], + extra_tir_args_list, + ) - inputs = [*te_args] + outs + unbound_tir_vars - tir_func = create_prim_func(inputs, "int64") + prim_func_args = [*call_tir_args, *outs, *unbound_tir_vars] + tir_func = create_prim_func(prim_func_args, "int64") if primfunc_attrs: tir_func = tir_func.with_attrs(primfunc_attrs) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 4581d536b9c2..03de68e32624 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -488,7 +488,9 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Arraynum_outputs(), 1); const te::Tensor& tensor = op.output(0); // Check op is in op list - ICHECK(info->IsArg(tensor)); + ICHECK(info->IsArg(tensor)) << "The operation " << op << " produces tensor " << tensor + << ", but this tensor does not appear as a function argument. 
" + << "The function accepts arguments " << info->arg_list; // Declare a buffer for any argument tensors without a pre-existing // buffer declaration recorded in the tensor2buffer binds map if (info->tensor2buffers.count(tensor) == 0) { From 4f3f510284fdbbe1f3c9bf72f67282d342423832 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 2 Apr 2024 08:22:15 -0500 Subject: [PATCH 3/9] Restore ICHECK for scalar TIR variable --- src/script/printer/relax/tir.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 9546ce536523..1a9c5d0546ec 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -41,6 +41,10 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { } Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { + ICHECK(n->dtype.is_scalar()) << "TypeError: " + << "Relax only uses scalar TIR variables," + << "but received TIR variable " << n << " with dtype " << n->dtype; + if (!d->IsVarDefined(n)) { RelaxFrameNode* f = GetRelaxFrame(d); // There should be at least one Relax frame From 7a6ff0ad4e652b0e0c75487b4d6e3d2066eb4884 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 2 Apr 2024 08:56:00 -0500 Subject: [PATCH 4/9] Fix a few more unit tests --- python/tvm/relax/utils.py | 50 ++++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 1c6f3d2f425c..48beeed8da67 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -22,6 +22,7 @@ import functools import inspect import itertools +import string from typing import Tuple as typing_Tuple from typing import Any, Callable, List, Dict, Optional, TypeVar @@ -311,6 +312,7 @@ def gen_call_tir_inputs( tir_var_map: Dict[tir.Var, tir.PrimExpr] = {} call_tir_args = [] + create_primfunc_args = [] # extra list of tir expression arguments # that are not covered by Tensor extra_tir_args_list = [] 
@@ -355,10 +357,7 @@ def _convert_te_arg(te_args: Any) -> Any: Relax expression """ - n_tensor = 0 - def _convert_te_arg_helper(arg): - nonlocal n_tensor if isinstance(arg, Expr): # type: ignore if isinstance(arg.struct_info, TensorStructInfo): assert isinstance( @@ -367,21 +366,43 @@ def _convert_te_arg_helper(arg): for shape_value in arg.struct_info.shape.values: _copy_undefined_var(shape_value) - name = chr(ord("A") + n_tensor) if n_tensor < 26 else f"input{n_tensor}" - arg = te_tensor(arg, tir_var_map, name) - n_tensor += 1 + n_args = len(create_primfunc_args) + if isinstance(arg, tvm.relax.Var): + name = arg.name_hint + elif n_args < len(string.ascii_uppercase): + name = string.ascii_uppercase[n_args] + else: + name = f"tensor_input_{n_args}" + + te_arg = te_tensor(arg, tir_var_map, name) + call_tir_args.append(arg) - return arg + create_primfunc_args.append(te_arg) + + return te_arg + if isinstance(arg.struct_info, ShapeStructInfo): assert isinstance( arg, ShapeExpr ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" return [_convert_te_arg_helper(val) for val in arg.values] + if isinstance(arg.struct_info, PrimStructInfo): if arg.struct_info.value is None: - name = arg.name_hint if isinstance(arg, tvm.relax.Var) else "prim_arg" + n_args = len(create_primfunc_args) + if isinstance(arg, tvm.relax.Var): + name = arg.name_hint + elif n_args < len(string.ascii_lowercase): + name = string.ascii_lowercase[n_args] + else: + name = f"scalar_input_{n_args}" + + tir_param = tir.Var(name, arg.struct_info.dtype) + call_tir_args.append(arg) - return tir.Var(name, arg.struct_info.dtype) + create_primfunc_args.append(tir_param) + + return tir_param else: return _convert_te_arg_helper(arg.struct_info.value) @@ -475,21 +496,16 @@ def _shape_with_old_tir_var( ), "only support te.tensor or tuple/list/Array of te.tensor as function output" outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out) - unbound_tir_vars = _get_unbound_tir_vars( - 
[*call_tir_args, *outs], - extra_tir_args_list, - ) + unbound_tir_vars = _get_unbound_tir_vars([*create_primfunc_args, *outs], extra_tir_args_list) - prim_func_args = [*call_tir_args, *outs, *unbound_tir_vars] - tir_func = create_prim_func(prim_func_args, "int64") + inputs = [*create_primfunc_args] + outs + unbound_tir_vars + tir_func = create_prim_func(inputs, "int64") if primfunc_attrs: tir_func = tir_func.with_attrs(primfunc_attrs) tir_func = tir_func.without_attr("global_symbol") - call_tir_args = [arg.op.value if isinstance(arg, te_Tensor) else arg for arg in call_tir_args] - # Invert the TIR variable mapping, to convert the output shape back # with old set of variables. tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} From efe53233a9f2cbc2f262104be42c0051f126c955 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 2 Apr 2024 09:02:09 -0500 Subject: [PATCH 5/9] Remove handling of ObjectStructInfo --- src/relax/op/op_common.h | 4 ++-- src/relax/op/tensor/binary.cc | 13 +++++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 354bf773a9ac..39352f37713a 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -251,10 +251,10 @@ inline DataType GetElementDType(const StructInfo& sinfo) { return prim->dtype; } else if (const auto* tensor = sinfo.as()) { return tensor->dtype; - } else if (sinfo.as()) { - return DataType::Void(); } else { LOG(FATAL) << "TypeError: " + << "Only PrimStructInfo and TensorStructInfo " + << "have an associated data type. 
" << "Cannot determine element type of " << sinfo; } } diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 1c167367a826..afc0fb73031b 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -42,13 +42,22 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, auto lhs_sinfo = GetStructInfo(call->args[0]); auto rhs_sinfo = GetStructInfo(call->args[1]); + CHECK(lhs_sinfo.as() || lhs_sinfo.as()) + << "TypeError: " + << "Arguments to binary operators must be either R.Tensor or R.Prim types, " + << "but expression " << call << " has LHS " << call->args[0] << ", which has StructInfo " + << lhs_sinfo; + CHECK(rhs_sinfo.as() || rhs_sinfo.as()) + << "TypeError: " + << "Arguments to binary operators must be either R.Tensor or R.Prim types, " + << "but expression " << call << " has RHS " << call->args[1] << ", which has StructInfo " + << rhs_sinfo; + // DateType DataType output_dtype = f_compute_out_dtype(call, ctx, lhs_sinfo, rhs_sinfo); if (lhs_sinfo.as() && rhs_sinfo.as()) { return PrimStructInfo(output_dtype); - } else if (lhs_sinfo.as() && rhs_sinfo.as()) { - return ObjectStructInfo(); } // VDevice From 5295220a2405350de31c03e515e4b446d18234d8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 3 Apr 2024 21:51:22 -0500 Subject: [PATCH 6/9] Undo commenting-out of test cases --- tests/python/relax/test_op_binary.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index bac91a89942b..952185615713 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -250,10 +250,10 @@ def test_binary_infer_struct_info_shape_var(binary_arith_op: Callable): y4 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) _check_inference(bb, binary_arith_op(x, y0), relax.TensorStructInfo(s0, "float32")) - # _check_inference(bb, binary_arith_op(x, y1), 
relax.TensorStructInfo(dtype="float32", ndim=2)) - # _check_inference(bb, binary_arith_op(x, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) - # _check_inference(bb, binary_arith_op(x, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) - # _check_inference(bb, binary_arith_op(x, y4), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, binary_arith_op(x, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y4), relax.TensorStructInfo(dtype="float32")) def test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op: Callable): From b5e2608a22a679c2d2a49770f56afb01ccf370b8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 Apr 2024 13:36:58 -0500 Subject: [PATCH 7/9] Update for improved error messages --- src/relax/op/op_common.h | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 39352f37713a..5e19edb47c45 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -246,12 +247,13 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call, * \return The inferred element dtype. * \throw Throw exception if the StructInfo doesn't have an element type. */ -inline DataType GetElementDType(const StructInfo& sinfo) { +inline std::optional GetElementDType(const StructInfo& sinfo) { if (const auto* prim = sinfo.as()) { return prim->dtype; } else if (const auto* tensor = sinfo.as()) { return tensor->dtype; } else { + return std::nullopt; LOG(FATAL) << "TypeError: " << "Only PrimStructInfo and TensorStructInfo " << "have an associated data type. 
" @@ -271,13 +273,33 @@ inline DataType GetElementDType(const StructInfo& sinfo) { inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, const StructInfo& lhs_sinfo, const StructInfo& rhs_sinfo) { - auto lhs_dtype = GetElementDType(lhs_sinfo); - auto rhs_dtype = GetElementDType(rhs_sinfo); + auto opt_lhs_dtype = GetElementDType(lhs_sinfo); + if (!opt_lhs_dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "TypeError: " + << "Binary operators must have the same datatype for both operands. " + << "However, " << call << " has argument " << call->args[0] + << " on the LHS, with struct info " << lhs_sinfo << ". This is of type " + << lhs_sinfo->GetTypeKey() << ", which does not have a datatype."); + } + auto lhs_dtype = opt_lhs_dtype.value(); + + auto opt_rhs_dtype = GetElementDType(rhs_sinfo); + if (!opt_rhs_dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "TypeError: " + << "Binary operators must have the same datatype for both operands. " + << "However, " << call << " has argument " << call->args[1] + << " on the RHS, with struct info " << rhs_sinfo << ". This is of type " + << rhs_sinfo->GetTypeKey() << ", which does not have a datatype."); + } + auto rhs_dtype = opt_rhs_dtype.value(); + if (lhs_dtype.is_void() || rhs_dtype.is_void()) { return DataType::Void(); } else if (lhs_dtype != rhs_dtype) { ctx->ReportFatal(Diagnostic::Error(call) - << "TypeErorr: " + << "TypeError: " << "Binary operators must have the same datatype for both operands. 
" << "However, " << call << " uses datatype " << lhs_dtype << " on the LHS (StructInfo of " << lhs_sinfo << "), and datatype " From 16e468ec91a88130d0ea7e1cf1d9fca192840f41 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Tue, 9 Apr 2024 15:56:13 +0000 Subject: [PATCH 8/9] Fix failing unit tests --- tests/python/relax/test_op_binary.py | 4 ++-- tests/python/relax/test_op_nn_convolution.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index 952185615713..85842f1578df 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -282,7 +282,7 @@ def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32")) y = relax.Var("y", R.Tensor((2, 3), "int32")) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(binary_arith_op(x, y)) @@ -290,7 +290,7 @@ def test_binary_arith_infer_struct_info_vdevice_mismatch(binary_arith_op: Callab bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((2, 3), "float32", VDevice("llvm"))) y = relax.Var("y", R.Tensor((2, 3), "int32", VDevice("cuda"))) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(binary_arith_op(x, y)) diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py index 55e35ee2031b..588dc9b1b19c 100644 --- a/tests/python/relax/test_op_nn_convolution.py +++ b/tests/python/relax/test_op_nn_convolution.py @@ -386,7 +386,7 @@ def test_conv1d_dtype_mismatch(): x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) w = relax.Var("w", R.Tensor((4, 3, 3), "int8")) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(relax.op.nn.conv1d(x, w)) @@ -744,7 +744,7 @@ def test_conv1d_transpose_dtype_mismatch(): x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) w = relax.Var("w", 
R.Tensor((3, 4, 3), "int8")) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(relax.op.nn.conv1d_transpose(x, w)) @@ -1141,7 +1141,7 @@ def test_conv2d_dtype_mismatch(): x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) w = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8")) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(relax.op.nn.conv2d(x, w)) @@ -1533,7 +1533,7 @@ def test_conv2d_transpose_dtype_mismatch(): x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) w = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8")) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(relax.op.nn.conv2d_transpose(x, w)) From f097c788f3a1a2a1707902070e2ef10c70609181 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 Apr 2024 19:18:11 -0500 Subject: [PATCH 9/9] Fix unit test --- tests/python/relax/test_op_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py index 21f022d9eb79..e67ef442f962 100644 --- a/tests/python/relax/test_op_search.py +++ b/tests/python/relax/test_op_search.py @@ -262,9 +262,9 @@ def test_where_infer_struct_info_dtype_mismatch(): x1 = relax.Var("x", R.Tensor((2, 3), "int8")) y1 = relax.Var("y", R.Tensor((2, 3), "float32")) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(relax.op.where(cond, x0, y0)) - with pytest.raises(TVMError): + with pytest.raises(TypeError): bb.normalize(relax.op.where(cond, x1, y1))