diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index b9d0b9f53bb7..ef4265d73b4b 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -66,6 +66,12 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
    * first input axis that is part of a new flattened axis.
    */
   Optional<Array<IntImm>> axis_separators;
+  /*!
+   * axis_separators for input buffers.
+   * Needed to identify if the input buffer to layout_transform
+   * contains axis separators.
+   */
+  Optional<Array<IntImm>> input_axis_separators;
 
   TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
     TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
@@ -74,6 +80,8 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
         "padding. If not specified, the compiler is free to choose any value.");
     TVM_ATTR_FIELD(axis_separators)
         .describe("The separators between input axes when generating flat output axes");
+    TVM_ATTR_FIELD(input_axis_separators)
+        .describe("The separators between axes of the input buffer");
   }
 };  // struct LayoutTransformAttrs
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index d8f36e478669..5a7b85ac1376 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -559,11 +559,13 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
  * \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the
  * PrimFunc i/o buffers.
  * \param axis_separators Map from kOperatorName attr to axis_separators of each buffer_transforms
+ * \param input_axis_separators Map from kOperatorName attr to axis_separators of each input buffer
  * \return The Pass.
  */
 TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
                          const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
-                         const Map<String, Array<Array<IntImm>>>& axis_separators);
+                         const Map<String, Array<Array<IntImm>>>& axis_separators,
+                         const Map<String, Array<Array<IntImm>>>& input_axis_separators);
 
 /*!
  * \brief Layout conversion pass.
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 9bd99020e998..da0a09cc7b51 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -115,6 +115,7 @@ def layout_transform(
     index_map: Union[Callable, IndexMap],
     pad_value: Optional[Union[int, float, PrimValue]] = None,
     axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
+    input_axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
 ):
     """Modifies the layout of a tensor.
 
@@ -158,7 +159,12 @@ def layout_transform(
     if axis_separators is None:
         axis_separators = []
 
-    return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators)  # type: ignore
+    if input_axis_separators is None:
+        input_axis_separators = []
+
+    return _ffi_api.layout_transform(
+        x, index_map, pad_value, axis_separators, input_axis_separators
+    )
 
 
 def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index e56240dc0d12..4d30b97f6467 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -181,6 +181,9 @@ def te_layout_transform(data, name):
             name=name,
         )
 
+    def set_axis_sep(axis_sep: list, sch: tir.Schedule, buffer_type: str):
+        sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)
+
     index_map: tvm.tir.IndexMap = call.attrs.index_map
     pad_value = call.attrs.pad_value
     if pad_value is not None:
@@ -192,8 +195,10 @@ def te_layout_transform(data, name):
         pad_value = float(0.0)
 
     axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
+    input_axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.input_axis_separators
+
     # Convert to list from array
-    axis_separators = list(map(lambda x: x.value, axis_separators))
+    axis_separators = [int(sep) for sep in axis_separators]
     primfunc_name = "te_layout_transform"
     _, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape)
     if not isinstance(padding_predicate, tvm.tir.expr.IntImm):
@@ -206,8 +211,10 @@ def te_layout_transform(data, name):
     # Create TIR schedule to apply layout changes with axis separators
     sch = tir.Schedule(tir_func)
     sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
-    if len(axis_separators) != 0:
-        sch.set_axis_separator(primfunc_name, ("write", 0), axis_separators=axis_separators)
+    set_axis_sep(axis_separators, sch, "write")
+    if input_axis_separators is not None:
+        input_axis_separators = [int(sep) for sep in input_axis_separators]
+        set_axis_sep(input_axis_separators, sch, "read")
     gvar = bb.add_func(sch.mod["main"], primfunc_name)
     output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))
     output_dtype = call_args[0].struct_info.dtype
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 38e7994eb97f..3528b4429e6f 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -24,6 +24,7 @@
 import numpy as np  # type: ignore
 
 import tvm.ir
+from tvm.ir.container import Array
 from tvm.relax import Expr, Var, StructInfo
 from tvm.relax.dpl import DFPattern
 from tvm.runtime import NDArray, Object
@@ -1280,6 +1281,7 @@ def AlterOpImpl(
     op_impl_map: Dict[str, PrimFunc],
     op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
     op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
+    op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
 ):
     """Replace all PrimFunc's which have matching 'operator_name' attribute, with replacement
     PrimFunc that could possibly have different layouts on i/o buffers. The layout
@@ -1295,6 +1297,8 @@ def AlterOpImpl(
         op_kind to layout transformation map for each of the buffers
     op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
         op_kind to axis_separator for each index_map
+    op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
+        op_kind to axis_separators of the input buffer for each index_map
 
     Returns
     -------
@@ -1303,13 +1307,19 @@ def AlterOpImpl(
     for operator_name, transform_list in op_buffer_transforms.items():
         l = []
         for transform in transform_list:
+            # Extract the index_map
             if isinstance(transform, Callable):
                 transform = IndexMap.from_func_with_separators(transform)[0]
+            elif isinstance(transform, (Array, tuple)) and isinstance(transform[0], IndexMap):
+                transform = transform[0]
             l.append(transform)
         op_buffer_transforms[operator_name] = l
 
     return _ffi_api.AlterOpImpl(
-        op_impl_map, op_buffer_transforms, op_buffer_axis_separators
+        op_impl_map,
+        op_buffer_transforms,
+        op_buffer_axis_separators,
+        op_buffer_input_axis_separators,
     )  # type: ignore
 
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index ad2a812c8254..07c90756bf90 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -472,11 +472,13 @@ TVM_REGISTER_OP("relax.flatten")
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
 
 Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
-                      Optional<Array<IntImm>> axis_separators) {
+                      Optional<Array<IntImm>> axis_separators,
+                      Optional<Array<IntImm>> input_axis_separators) {
   ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
   attrs->index_map = std::move(index_map);
   attrs->pad_value = std::move(pad_value);
   attrs->axis_separators = std::move(axis_separators);
+  attrs->input_axis_separators = std::move(input_axis_separators);
 
   static const Op& op = Op::Get("relax.layout_transform");
   return Call(op, {std::move(x)}, Attrs{attrs}, {});
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index b19e3b85070d..32aa10776894 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -67,10 +67,12 @@ Expr flatten(Expr x);
  * not specified, any value can be used.
  * \param axis_separators Array of values to differentiate between input axes
  * when generating flattened output axes.
+ * \param input_axis_separators Array of axis separators for the input buffer.
  * \return The transformed result.
  */
 Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
-                      Optional<Array<IntImm>> axis_separators);
+                      Optional<Array<IntImm>> axis_separators,
+                      Optional<Array<IntImm>> input_axis_separators = NullOpt);
 
 /*!
  * \brief Permutes the dimensions of an array.
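Editorial note (not part of the patch): a minimal sketch of how the extended Python operator API is meant to be used, assuming a 16-element tensor tiled into a 4x4 layout; the variable names are hypothetical. The forward transform attaches the separator to its flattened output buffer, while the inverse transform uses input_axis_separators to record that its input buffer already carries one.

# Editorial sketch; assumes relax.op.layout_transform as extended above.
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((16,), "float32"))
# flat -> tiled: the transformed (output) buffer gets an axis separator after axis 0
tiled = relax.op.layout_transform(x, index_map=lambda i: (i // 4, i % 4), axis_separators=[1])

y = relax.Var("y", relax.TensorStructInfo((4, 4), "float32"))
# tiled -> flat: the input buffer of this layout_transform already carries the separator
flat = relax.op.layout_transform(
    y,
    index_map=lambda i, j: (i * 4 + j,),
    axis_separators=[],
    input_axis_separators=[1],
)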
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index 2cb226d56e27..aaf643f8011d 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -81,12 +81,14 @@ class AlterOpImplMutator : public ExprMutator {
  public:
   AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& op_impl_map,
                      const Map<String, Array<IndexMap>>& op_buffer_transforms_,
-                     const Map<String, Array<Array<IntImm>>>& axis_separators_)
+                     const Map<String, Array<Array<IntImm>>>& axis_separators_,
+                     const Map<String, Array<Array<IntImm>>>& input_axis_separators_)
       : ExprMutator(mod),
         mod_(mod),
         op_impl_map_(op_impl_map),
         op_buffer_transforms__(op_buffer_transforms_),
-        op_buffer_axis_separators__(axis_separators_) {}
+        op_buffer_axis_separators__(axis_separators_),
+        op_buffer_input_axis_separators__(input_axis_separators_) {}
 
   IRModule Run() {
     for (const auto& gv : mod_->GetGlobalVars()) {
@@ -127,9 +129,12 @@
     Array<IndexMap> buffer_transforms;
     Optional<Array<Array<IntImm>>> axis_separators;
+    Optional<Array<Array<IntImm>>> input_axis_separators;
     if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
     if (op_buffer_axis_separators__.count(op_kind))
       axis_separators = op_buffer_axis_separators__[op_kind];
+    if (op_buffer_input_axis_separators__.count(op_kind))
+      input_axis_separators = op_buffer_input_axis_separators__[op_kind];
 
     ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size())
         << "Either the i/o buffers do not require any transformations or transformations for each "
@@ -140,7 +145,8 @@
     GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind);
 
     auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
-    Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators);
+    Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators,
+                                        input_axis_separators);
 
     ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
     StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
@@ -148,7 +154,8 @@
         Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo}));
 
     // Now transform each of the outputs to previous layout.
-    return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators);
+    return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators,
+                            input_axis_separators);
   }
 
   Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) {
@@ -175,7 +182,8 @@
   }
 
   Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
-                       const Array<IntImm>& axis_separators) {
+                       const Array<IntImm>& axis_separators,
+                       const Array<IntImm>& input_axis_separators) {
     if (IsScalarConstant(expr) || index_map.get() == nullptr) {
       return expr;
     }
@@ -185,6 +193,7 @@
     // so would confuse the structural equality check.
     attrs->index_map = std::move(DeepCopyIndexMap(index_map));
     attrs->axis_separators = std::move(axis_separators);
+    attrs->input_axis_separators = std::move(input_axis_separators);
 
     return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
   }
@@ -232,7 +241,8 @@
 
   Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
                               const TensorStructInfo& old_tensor_sinfo,
-                              const Array<IntImm>& axis_separator) {
+                              const Array<IntImm>& axis_separator,
+                              const Array<IntImm>& input_axis_separator) {
     if (IsScalarConstant(expr) || index_map.get() == nullptr) {
       return expr;
     }
@@ -243,10 +253,10 @@
         index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
 
     if (tir::is_zero(padding_predicate)) {
-      return TransformLayout(expr, inverse_index_map, axis_separator);
+      return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator);
     } else {
-      auto padded_expr =
-          builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator));
+      auto padded_expr = builder_->Normalize(
+          TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator));
       const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);
 
       GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
@@ -277,19 +287,26 @@
    * \brief Updates call inputs with layout transformed inputs
    */
   Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
-                     const Optional<Array<Array<IntImm>>>& axis_separators) {
+                     const Optional<Array<Array<IntImm>>>& axis_separators,
+                     const Optional<Array<Array<IntImm>>>& input_axis_separators) {
     if (transforms.empty()) return inputs;
 
     Array<Expr> updated_inputs;
     int index = 0;
     for (const auto& input : inputs->fields) {
       Array<IntImm> axis_separator;
+      Array<IntImm> input_axis_separator;
       if (axis_separators.defined()) {
         Array<Array<IntImm>> axis_separators_value = axis_separators.value();
         axis_separator = axis_separators_value[index];
       }
+      if (input_axis_separators.defined()) {
+        Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+        input_axis_separator = input_axis_separators_value[index];
+      }
       auto transform = transforms[index++];
-      updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
+      updated_inputs.push_back(
+          TransformLayout(input, transform, axis_separator, input_axis_separator));
     }
     return Tuple(updated_inputs);
   }
@@ -338,12 +355,13 @@
 
   Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
                         const StructInfo& old_struct_info,
-                        const Optional<Array<Array<IntImm>>>& axis_separators) {
+                        const Optional<Array<Array<IntImm>>>& axis_separators,
+                        const Optional<Array<Array<IntImm>>>& input_axis_separators) {
     if (buffer_transforms.empty()) return expr;
 
     Array<TensorStructInfo> old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info);
-    Array<IntImm> axis_sep;
+    Array<IntImm> axis_sep, input_axis_sep;
 
     size_t num_outputs = old_output_sinfo.size();
     if (num_outputs == 0) return expr;
@@ -355,7 +373,12 @@
         Array<Array<IntImm>> axis_separators_value = axis_separators.value();
         axis_sep = axis_separators_value[first_output_index];
       }
-      return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep);
+      if (input_axis_separators.defined()) {
+        Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+        input_axis_sep = input_axis_separators_value[first_output_index];
+      }
+      return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep,
+                                    input_axis_sep);
     }
 
     // In case of more than one output, we would have to get each item of the output tuple,
@@ -367,9 +390,13 @@
         Array<Array<IntImm>> axis_separators_value = axis_separators.value();
         axis_sep = axis_separators_value[i + first_output_index];
       }
+      if (input_axis_separators.defined()) {
+        Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+        input_axis_sep = input_axis_separators_value[i + first_output_index];
+      }
       auto output = builder_->Normalize(TupleGetItem(expr, static_cast<int>(i)));
-      transformed_outputs.push_back(
-          TransformLayoutInverse(output, output_map, old_output_sinfo[i], axis_sep));
+      transformed_outputs.push_back(TransformLayoutInverse(output, output_map, old_output_sinfo[i],
+                                                           axis_sep, input_axis_sep));
     }
     return Tuple(transformed_outputs);
   }
@@ -387,6 +414,8 @@
   const Map<String, Array<IndexMap>>& op_buffer_transforms__;
   /*! \brief Map from kOperatorName attribute to the axis separatos on i/o buffers */
   const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;
+  /*! \brief Map from kOperatorName attribute to the input axis separators */
+  const Map<String, Array<Array<IntImm>>>& op_buffer_input_axis_separators__;
 
   const Op& call_tir_op_ = Op::Get("relax.call_tir");
   const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -396,10 +425,13 @@ namespace transform {
 
 Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
                  const Map<String, Array<IndexMap>>& op_buffer_transforms_,
-                 const Map<String, Array<Array<IntImm>>>& axis_separators_) {
+                 const Map<String, Array<Array<IntImm>>>& axis_separators_,
+                 const Map<String, Array<Array<IntImm>>>& input_axis_separators_) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule mod, PassContext pc) {
-        return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_).Run();
+        return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_,
+                                  input_axis_separators_)
+            .Run();
       };
   return CreateModulePass(/*pass_function=*/pass_func,  //
                           /*opt_level=*/0,              //
diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py
index f2bad31f2116..f1824eba6baa 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -26,12 +26,19 @@
 def _check(
-    before, expected, operator_name, replacement_primfunc, layout_changes, axis_separator=None
+    before,
+    expected,
+    operator_name,
+    replacement_primfunc,
+    layout_changes,
+    axis_separator=None,
+    input_axis_separator=None,
 ):
     after = relax.transform.AlterOpImpl(
         {operator_name: replacement_primfunc},
         {operator_name: layout_changes},
         {operator_name: axis_separator},
+        {operator_name: input_axis_separator},
     )(before)
     after = relax.transform.DeadCodeElimination()(after)
     tvm.ir.assert_structural_equal(after, expected)
@@ -572,5 +579,81 @@ def reshape_new(
     )
 
 
+def test_input_axis_separator():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), "float32")):
+            T.func_attr({"operator_name": "relax.some_op"})
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0], arg1[v_ax0])
+                    T.writes(output0[v_ax0], output1[v_ax0])
+                    output0[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+                    output1[v_ax0] = arg0[v_ax0] - arg1[v_ax0]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")):
+            with R.dataflow():
+                gv = R.call_tir(Before.some_op, (x, y), out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")])
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
+            T.func_attr({"operator_name": "relax.some_op"})
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+                    output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+                lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")])
+                lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
+                lv4: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None, axis_separators=[], input_axis_separators=[1])
+                lv5: R.Tensor((4, 4), dtype="float32") = lv2[1]
+                lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None, axis_separators=[], input_axis_separators=[1])
+                gv: R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")) = (lv4, lv6)
+                R.output(gv)
+            return gv
+
+    @T.prim_func(private=True)
+    def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
+        for ax0, ax1 in T.grid(4, 4):
+            with T.block("T_add"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+                output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+    # fmt: on
+
+    index_map_axis_sep = IndexMap.from_func_with_separators(
+        lambda i: (i // 4, IndexMap.AXIS_SEPARATOR, i % 4)
+    )
+
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.some_op",
+        replacement_primfunc=some_op_2d,
+        layout_changes=[
+            index_map_axis_sep,
+            index_map_axis_sep,
+            index_map_axis_sep,
+            index_map_axis_sep,
+        ],
+        axis_separator=[index_map_axis_sep[1], index_map_axis_sep[1], [], []],
+        input_axis_separator=[[], [], index_map_axis_sep[1], index_map_axis_sep[1]],
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
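Editorial note (not part of the test suite): a hedged sketch of driving the pass directly with the new per-operator map, mirroring the test above. `Before` and `some_op_2d` refer to the definitions inside `test_input_axis_separator` and are assumed to be in scope; the sketch unpacks `from_func_with_separators` instead of passing the tuple.

# Editorial sketch; buffer order follows the PrimFunc params: arg0, arg1, output0, output1.
index_map, axis_sep = IndexMap.from_func_with_separators(
    lambda i: (i // 4, IndexMap.AXIS_SEPARATOR, i % 4)
)
mod = relax.transform.AlterOpImpl(
    {"relax.some_op": some_op_2d},                    # replacement PrimFunc per operator_name
    {"relax.some_op": [index_map] * 4},               # layout transform for each i/o buffer
    {"relax.some_op": [axis_sep, axis_sep, [], []]},  # separators on the transformed input buffers
    {"relax.some_op": [[], [], axis_sep, axis_sep]},  # separators carried by the inputs of the inverse output transforms
)(Before)
mod = relax.transform.DeadCodeElimination()(mod)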