[Relax] Support input_axis_separator to allow 2D to 1D conversion (#17115)

* [Relax] Support input_axis_separator to allow 2D to 1D conversion

Introduce input_axis_separators in the relax.layout_transform op to allow conversion of 2D buffers to 1D buffers.
The 2D->1D conversion is handled while lowering the layout_transform operator.
Also introduce support for input_axis_separators in the AlterOpImpl pass.

* Fix LINT errors

* Fix review comments
abhikran-quic authored Jul 1, 2024
1 parent 4a5e22e commit ab7c1a9
Showing 9 changed files with 179 additions and 27 deletions.
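Before the per-file diff, a minimal sketch of the Python API this commit extends; the tensor shape, index map, and separator values below are illustrative assumptions, not code from the commit:

import tvm
from tvm import relax

# A 2D input whose physical layout carries an axis separator after axis 0,
# i.e. it is realized as a 2D buffer on the target.
x = relax.Var("x", relax.TensorStructInfo([16, 32], "float32"))

# Flatten both logical axes into one physical axis; the new
# input_axis_separators argument records the separator the *input* buffer
# already carries, so lowering can legalize the 2D -> 1D conversion.
y = relax.op.layout_transform(
    x,
    index_map=lambda i, j: [i * 32 + j],
    axis_separators=[],          # the output is a plain 1D buffer
    input_axis_separators=[1],   # separator of the incoming 2D buffer
)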
8 changes: 8 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -66,6 +66,12 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
* first input axis that is part of a new flattened axis.
*/
Optional<Array<IntImm>> axis_separators;
/*!
* axis_separators for input buffers.
* Needed to identify whether the input buffer to layout_transform
* contains axis separators.
*/
Optional<Array<IntImm>> input_axis_separators;

TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
@@ -74,6 +80,8 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
"padding. If not specified, the compiler is free to choose any value.");
TVM_ATTR_FIELD(axis_separators)
.describe("The separators between input axes when generating flat output axes");
TVM_ATTR_FIELD(input_axis_separators)
.describe("The separators between axes to regenerate output");
}
}; // struct LayoutTransformAttrs

4 changes: 3 additions & 1 deletion include/tvm/relax/transform.h
@@ -559,11 +559,13 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
* \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the
* PrimFunc i/o buffers.
* \param axis_separators Map from kOperatorName attr to axis_separators of each buffer_transforms
* \param input_axis_separators Map from kOperatorName attr to axis_separators for input buffers
* \return The Pass.
*/
TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
const Map<String, Array<Array<IntImm>>>& axis_separators);
const Map<String, Array<Array<IntImm>>>& axis_separators,
const Map<String, Array<Array<IntImm>>>& input_axis_separators);

/*!
* \brief Layout conversion pass.
8 changes: 7 additions & 1 deletion python/tvm/relax/op/manipulate.py
@@ -115,6 +115,7 @@ def layout_transform(
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
input_axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
):
"""Modifies the layout of a tensor.
@@ -158,7 +159,12 @@ def layout_transform(
if axis_separators is None:
axis_separators = []

return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators) # type: ignore
if input_axis_separators is None:
input_axis_separators = []

return _ffi_api.layout_transform(
x, index_map, pad_value, axis_separators, input_axis_separators
)


def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
13 changes: 10 additions & 3 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -181,6 +181,9 @@ def te_layout_transform(data, name):
name=name,
)

def set_axis_sep(axis_sep: list, sch: tir.Schedule, buffer_type: str):
sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)

index_map: tvm.tir.IndexMap = call.attrs.index_map
pad_value = call.attrs.pad_value
if pad_value is not None:
@@ -192,8 +195,10 @@ def te_layout_transform(data, name):
pad_value = float(0.0)

axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
input_axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.input_axis_separators

# Convert to list from array
axis_separators = list(map(lambda x: x.value, axis_separators))
axis_separators = [int(sep) for sep in axis_separators]
primfunc_name = "te_layout_transform"
_, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape)
if not isinstance(padding_predicate, tvm.tir.expr.IntImm):
@@ -206,8 +211,10 @@ def te_layout_transform(data, name):
# Create TIR schedule to apply layout changes with axis separators
sch = tir.Schedule(tir_func)
sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
if len(axis_separators) != 0:
sch.set_axis_separator(primfunc_name, ("write", 0), axis_separators=axis_separators)
set_axis_sep(axis_separators, sch, "write")
if input_axis_separators is not None:
input_axis_separators = [int(sep) for sep in input_axis_separators]
set_axis_sep(input_axis_separators, sch, "read")
gvar = bb.add_func(sch.mod["main"], primfunc_name)
output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))
output_dtype = call_args[0].struct_info.dtype
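For reference, a self-contained sketch of what this lowering does with a tir.Schedule; the copy kernel and shapes are assumptions chosen to mirror te_layout_transform, not code from the commit:

import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def te_layout_transform(A: T.Buffer((4, 8), "float32"),
                        B: T.Buffer((4, 8), "float32")):
    for i, j in T.grid(4, 8):
        with T.block("te_layout_transform"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

sch = tir.Schedule(te_layout_transform)
# Rewrite the output buffer to the requested layout, as the legalizer does.
sch.transform_layout("te_layout_transform", ("write", 0),
                     lambda i, j: [i * 8 + j])
# set_axis_sep(axis_separators, sch, "write") registers output separators;
# the new code path registers the input's separators on the "read" side.
sch.set_axis_separator("te_layout_transform", ("read", 0), axis_separators=[1])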
12 changes: 11 additions & 1 deletion python/tvm/relax/transform/transform.py
@@ -24,6 +24,7 @@
import numpy as np # type: ignore

import tvm.ir
from tvm.ir.container import Array
from tvm.relax import Expr, Var, StructInfo
from tvm.relax.dpl import DFPattern
from tvm.runtime import NDArray, Object
@@ -1280,6 +1281,7 @@ def AlterOpImpl(
op_impl_map: Dict[str, PrimFunc],
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
):
"""Replace all PrimFunc's which have matching 'operator_name' attribute, with replacement
PrimFunc that could possibly have different layouts on i/o buffers. The layout
@@ -1295,6 +1297,8 @@
op_kind to layout transformation map for each of the buffers
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
op_kind to axis_separator for each index_map
op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
op_kind to axis_separators of the input buffers for each index_map
Returns
-------
@@ -1303,13 +1307,19 @@
for operator_name, transform_list in op_buffer_transforms.items():
l = []
for transform in transform_list:
# Extract the index_map
if isinstance(transform, Callable):
transform = IndexMap.from_func_with_separators(transform)[0]
elif isinstance(transform, (Array, tuple)) and isinstance(transform[0], IndexMap):
transform = transform[0]
l.append(transform)
op_buffer_transforms[operator_name] = l

return _ffi_api.AlterOpImpl(
op_impl_map, op_buffer_transforms, op_buffer_axis_separators
op_impl_map,
op_buffer_transforms,
op_buffer_axis_separators,
op_buffer_input_axis_separators,
) # type: ignore


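A hedged sketch of calling the pass with its new fourth argument; the operator name, replacement kernel, and maps are hypothetical, and the snippet only constructs the pass (applying it requires a module whose PrimFuncs carry a matching "operator_name" attribute):

import tvm
from tvm import relax, tir
from tvm.script import tir as T

# Hypothetical replacement kernel operating on a flattened 1D layout.
@T.prim_func
def relu_1d(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")):
    T.func_attr({"operator_name": "relax.nn.relu"})
    for i in range(32):
        with T.block("relu"):
            vi = T.axis.remap("S", [i])
            B[vi] = T.max(A[vi], T.float32(0))

# 2D -> 1D transformation applied to both the input and the output buffer.
flatten = tir.IndexMap.from_func_with_separators(lambda i, j: [i * 8 + j])[0]

alter = relax.transform.AlterOpImpl(
    op_impl_map={"relax.nn.relu": relu_1d},
    op_buffer_transforms={"relax.nn.relu": [flatten, flatten]},
    op_buffer_axis_separators={"relax.nn.relu": [[], []]},
    # New in this commit: separators the original (pre-transform) buffers carry.
    op_buffer_input_axis_separators={"relax.nn.relu": [[1], [1]]},
)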
4 changes: 3 additions & 1 deletion src/relax/op/tensor/manipulate.cc
@@ -472,11 +472,13 @@ TVM_REGISTER_OP("relax.flatten")
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);

Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators) {
Optional<Array<IntImm>> axis_separators,
Optional<Array<IntImm>> input_axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
attrs->index_map = std::move(index_map);
attrs->pad_value = std::move(pad_value);
attrs->axis_separators = std::move(axis_separators);
attrs->input_axis_separators = std::move(input_axis_separators);

static const Op& op = Op::Get("relax.layout_transform");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
4 changes: 3 additions & 1 deletion src/relax/op/tensor/manipulate.h
@@ -67,10 +67,12 @@ Expr flatten(Expr x);
* not specified, any value can be used.
* \param axis_separators Array of values to differentiate between input axes
* when generating flattened output axes.
* \param input_axis_separators Array of axis separators for the input buffer.
* \return The transformed result.
*/
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators);
Optional<Array<IntImm>> axis_separators,
Optional<Array<IntImm>> input_axis_separators = NullOpt);

/*!
* \brief Permutes the dimensions of an array.
68 changes: 50 additions & 18 deletions src/relax/transform/alter_op_impl.cc
@@ -81,12 +81,14 @@ class AlterOpImplMutator : public ExprMutator {
public:
AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
const Map<String, Array<Array<IntImm>>>& axis_separators_)
const Map<String, Array<Array<IntImm>>>& axis_separators_,
const Map<String, Array<Array<IntImm>>>& input_axis_separators_)
: ExprMutator(mod),
mod_(mod),
op_impl_map_(op_impl_map),
op_buffer_transforms__(op_buffer_transforms_),
op_buffer_axis_separators__(axis_separators_) {}
op_buffer_axis_separators__(axis_separators_),
op_buffer_input_axis_separators__(input_axis_separators_) {}

IRModule Run() {
for (const auto& gv : mod_->GetGlobalVars()) {
@@ -127,9 +129,12 @@

Array<IndexMap> buffer_transforms;
Optional<Array<Array<IntImm>>> axis_separators;
Optional<Array<Array<IntImm>>> input_axis_separators;
if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
if (op_buffer_axis_separators__.count(op_kind))
axis_separators = op_buffer_axis_separators__[op_kind];
if (op_buffer_input_axis_separators__.count(op_kind))
input_axis_separators = op_buffer_input_axis_separators__[op_kind];

ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size())
<< "Either the i/o buffers do not require any transformations or transformations for each "
@@ -140,15 +145,17 @@
GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind);

auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators);
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators,
input_axis_separators);

ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
auto updated_call = builder_->Normalize(
Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo}));

// Now transform each of the outputs to previous layout.
return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators);
return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators,
input_axis_separators);
}

Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) {
@@ -175,7 +182,8 @@
}

Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
const Array<IntImm>& axis_separators) {
const Array<IntImm>& axis_separators,
const Array<IntImm>& input_axis_separators) {
if (IsScalarConstant(expr) || index_map.get() == nullptr) {
return expr;
}
Expand All @@ -185,6 +193,7 @@ class AlterOpImplMutator : public ExprMutator {
// so would confuse the structural equality check.
attrs->index_map = std::move(DeepCopyIndexMap(index_map));
attrs->axis_separators = std::move(axis_separators);
attrs->input_axis_separators = std::move(input_axis_separators);
return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
}

Expand Down Expand Up @@ -232,7 +241,8 @@ class AlterOpImplMutator : public ExprMutator {

Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
const TensorStructInfo& old_tensor_sinfo,
const Array<IntImm>& axis_separator) {
const Array<IntImm>& axis_separator,
const Array<IntImm>& input_axis_separator) {
if (IsScalarConstant(expr) || index_map.get() == nullptr) {
return expr;
}
@@ -243,10 +253,10 @@
index_map.NonSurjectiveInverse(initial_ranges, &analyzer);

if (tir::is_zero(padding_predicate)) {
return TransformLayout(expr, inverse_index_map, axis_separator);
return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator);
} else {
auto padded_expr =
builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator));
auto padded_expr = builder_->Normalize(
TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator));
const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);

GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
@@ -277,19 +287,26 @@
* \brief Updates call inputs with layout transformed inputs
*/
Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
const Optional<Array<Array<IntImm>>>& axis_separators) {
const Optional<Array<Array<IntImm>>>& axis_separators,
const Optional<Array<Array<IntImm>>>& input_axis_separators) {
if (transforms.empty()) return inputs;

Array<Expr> updated_inputs;
int index = 0;
for (const auto& input : inputs->fields) {
Array<IntImm> axis_separator;
Array<IntImm> input_axis_separator;
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_separator = axis_separators_value[index];
}
if (input_axis_separators.defined()) {
Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
input_axis_separator = input_axis_separators_value[index];
}
auto transform = transforms[index++];
updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
updated_inputs.push_back(
TransformLayout(input, transform, axis_separator, input_axis_separator));
}
return Tuple(updated_inputs);
}
@@ -338,12 +355,13 @@

Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
const StructInfo& old_struct_info,
const Optional<Array<Array<IntImm>>>& axis_separators) {
const Optional<Array<Array<IntImm>>>& axis_separators,
const Optional<Array<Array<IntImm>>>& input_axis_separators) {
if (buffer_transforms.empty()) return expr;

Array<TensorStructInfo> old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info);

Array<IntImm> axis_sep;
Array<IntImm> axis_sep, input_axis_sep;
size_t num_outputs = old_output_sinfo.size();
if (num_outputs == 0) return expr;

@@ -355,7 +373,12 @@
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[first_output_index];
}
return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep);
if (input_axis_separators.defined()) {
Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
input_axis_sep = input_axis_separators_value[first_output_index];
}
return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep,
input_axis_sep);
}

// In case of more than one output, we would have to get each item of the output tuple,
Expand All @@ -367,9 +390,13 @@ class AlterOpImplMutator : public ExprMutator {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[i + first_output_index];
}
if (input_axis_separators.defined()) {
Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
input_axis_sep = input_axis_separators_value[i + first_output_index];
}
auto output = builder_->Normalize(TupleGetItem(expr, static_cast<int>(i)));
transformed_outputs.push_back(
TransformLayoutInverse(output, output_map, old_output_sinfo[i], axis_sep));
transformed_outputs.push_back(TransformLayoutInverse(output, output_map, old_output_sinfo[i],
axis_sep, input_axis_sep));
}
return Tuple(transformed_outputs);
}
@@ -387,6 +414,8 @@
const Map<String, Array<IndexMap>>& op_buffer_transforms__;
/*! \brief Map from kOperatorName attribute to the axis separators on i/o buffers */
const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;
/*! \brief Map from kOperatorName attribute to the input axis separators */
const Map<String, Array<Array<IntImm>>>& op_buffer_input_axis_separators__;

const Op& call_tir_op_ = Op::Get("relax.call_tir");
const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -396,10 +425,13 @@ namespace transform {

Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
const Map<String, Array<Array<IntImm>>>& axis_separators_) {
const Map<String, Array<Array<IntImm>>>& axis_separators_,
const Map<String, Array<Array<IntImm>>>& input_axis_separators_) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_).Run();
return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_,
input_axis_separators_)
.Run();
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
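The padding_predicate branch in TransformLayoutInverse can be reproduced in isolation; a small sketch with an assumed 4x8 shape, where the flattening is exact and no padding is introduced:

import tvm
from tvm import tir

flatten, _ = tir.IndexMap.from_func_with_separators(lambda i, j: [i * 8 + j])
inverse, padding_predicate = flatten.non_surjective_inverse([4, 8])

print(inverse)            # maps the flat index back to (i, j)
print(padding_predicate)  # statically zero here: 4 * 8 fills all 32 elements

# With a zero predicate the pass emits a plain inverse layout_transform;
# a nonzero predicate would additionally route through remove_pad.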