From 28a4bb6e3dfae549a437db5d36c6f89dc2c5d505 Mon Sep 17 00:00:00 2001 From: abhikran Date: Tue, 25 Jun 2024 12:31:45 +0530 Subject: [PATCH] Fix LINT errors --- python/tvm/relax/op/manipulate.py | 4 +++- src/relax/transform/alter_op_impl.cc | 14 ++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index a81d0621da05..da0a09cc7b51 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -162,7 +162,9 @@ def layout_transform( if input_axis_separators is None: input_axis_separators = [] - return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators, input_axis_separators) # type: ignore + return _ffi_api.layout_transform( + x, index_map, pad_value, axis_separators, input_axis_separators + ) def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr: diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index d6d0dc81d669..308a55411e9c 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -145,7 +145,7 @@ class AlterOpImplMutator : public ExprMutator { auto call_tir_inputs_tuple = GetRef(call->args[1].as()); Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators, - replacement_func, buffer_pad_values, input_axis_separators); + input_axis_separators); ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1"; StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms); @@ -252,11 +252,10 @@ class AlterOpImplMutator : public ExprMutator { index_map.NonSurjectiveInverse(initial_ranges, &analyzer); if (tir::is_zero(padding_predicate)) { - return TransformLayout(expr, inverse_index_map, axis_separator, - input_axis_separator); + return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator); } else { - auto padded_expr = builder_->Normalize(TransformLayout( - expr, inverse_index_map, axis_separator, input_axis_separator)); + auto padded_expr = builder_->Normalize( + TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator)); const auto& tensor_sinfo = Downcast(padded_expr->struct_info_); GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype); @@ -304,8 +303,8 @@ class AlterOpImplMutator : public ExprMutator { input_axis_separator = input_axis_separators_value[index]; } auto transform = transforms[index++]; - updated_inputs.push_back(TransformLayout(input, transform, axis_separator, - input_axis_separator)); + updated_inputs.push_back( + TransformLayout(input, transform, axis_separator, input_axis_separator)); } return Tuple(updated_inputs); } @@ -355,7 +354,6 @@ class AlterOpImplMutator : public ExprMutator { Expr TransformOutputs(const Expr& expr, const Array& buffer_transforms, const StructInfo& old_struct_info, const Optional>>& axis_separators, - const Optional>& buffer_pad_values, const Optional>>& input_axis_separators) { if (buffer_transforms.empty()) return expr;