Skip to content

Commit

Permalink
Fix LINT errors
Browse files Browse the repository at this point in the history
  • Loading branch information
abhikran-quic committed Jun 26, 2024
1 parent 78eb4e6 commit 28a4bb6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def layout_transform(
if input_axis_separators is None:
input_axis_separators = []

return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators, input_axis_separators) # type: ignore
return _ffi_api.layout_transform(
x, index_map, pad_value, axis_separators, input_axis_separators
)


def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
Expand Down
14 changes: 6 additions & 8 deletions src/relax/transform/alter_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class AlterOpImplMutator : public ExprMutator {

auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators,
replacement_func, buffer_pad_values, input_axis_separators);
input_axis_separators);

ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
Expand Down Expand Up @@ -252,11 +252,10 @@ class AlterOpImplMutator : public ExprMutator {
index_map.NonSurjectiveInverse(initial_ranges, &analyzer);

if (tir::is_zero(padding_predicate)) {
return TransformLayout(expr, inverse_index_map, axis_separator,
input_axis_separator);
return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator);
} else {
auto padded_expr = builder_->Normalize(TransformLayout(
expr, inverse_index_map, axis_separator, input_axis_separator));
auto padded_expr = builder_->Normalize(
TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator));
const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);

GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
Expand Down Expand Up @@ -304,8 +303,8 @@ class AlterOpImplMutator : public ExprMutator {
input_axis_separator = input_axis_separators_value[index];
}
auto transform = transforms[index++];
updated_inputs.push_back(TransformLayout(input, transform, axis_separator,
input_axis_separator));
updated_inputs.push_back(
TransformLayout(input, transform, axis_separator, input_axis_separator));
}
return Tuple(updated_inputs);
}
Expand Down Expand Up @@ -355,7 +354,6 @@ class AlterOpImplMutator : public ExprMutator {
Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
const StructInfo& old_struct_info,
const Optional<Array<Array<IntImm>>>& axis_separators,
const Optional<Array<IntImm>>& buffer_pad_values,
const Optional<Array<Array<IntImm>>>& input_axis_separators) {
if (buffer_transforms.empty()) return expr;

Expand Down

0 comments on commit 28a4bb6

Please sign in to comment.