Skip to content

Commit

Permalink
[NFC] Update black version (#3256)
Browse files Browse the repository at this point in the history
* Update black version to support 3.11/3.12
* Reformat code
  • Loading branch information
penguin-wwy authored Apr 29, 2024
1 parent aed2cf3 commit b218519
Show file tree
Hide file tree
Showing 24 changed files with 49 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.4.2
hooks:
- id: black

Expand Down
1 change: 1 addition & 0 deletions build_tools/scrape_releases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
See https://github.com/llvm/torch-mlir/issues/1374
"""

import argparse
import json

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from transformers import BertForMaskedLM


# Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module):
def __init__(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions projects/pt1/python/torch_mlir/_dynamo_fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ def __init__(self, g: torch.fx.Graph, func_name: str):
# FakeTensor's in case of a tuple return with multiple elements.
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
self._module = ir.Module.create(ir.Location.unknown())
self._module.operation.attributes[
"torch.debug_module_name"
] = ir.StringAttr.get(func_name)
self._module.operation.attributes["torch.debug_module_name"] = (
ir.StringAttr.get(func_name)
)
function_type = _extract_function_type_from_graph(g)
func = func_dialect.FuncOp(
func_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ def emit_with_mutating_variants(key, **kwargs):
(ns, unqual + "_", overload if not is_functional_op else "")
),
emitter_td,
traits=["IsTrailingUnderscoreInplaceVariant"]
if not is_functional_op
else [],
traits=(
["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
),
)

# ==========================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
examples = []
input_names = []
dynamic_tensors = {}
for (index, arg) in enumerate(inputs):
for index, arg in enumerate(inputs):
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
shape = tuple(shape)
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
Expand All @@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
input_names.append(input_name)

dynamic_dims = {}
for (dimindex, dim) in enumerate(arg.shape):
for dimindex, dim in enumerate(arg.shape):
if dim < 0:
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,12 @@ def __init__(self, module):
def consume_return_funcs(*args):
self.result = tuple(
[
arg
if type in elemental_type_to_ctype
else unranked_memref_to_numpy(
arg, memref_type_to_np_dtype[type]
(
arg
if type in elemental_type_to_ctype
else unranked_memref_to_numpy(
arg, memref_type_to_np_dtype[type]
)
)
for arg, type in zip(args, ret_types)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,9 +803,7 @@ def forward(self, x):

@register_test_case(module_factory=lambda: QuantizedReluInt32())
def QuantizedReluInt32_basic(module, tu: TestUtils):
module.forward(
tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
)
module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))


# ==============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):

# ==============================================================================


# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
class SliceScatterModule(torch.nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

mb = ModuleBuilder()


# CHECK: module attributes {torch.debug_module_name = "TestModule"}
class TestModule(torch.nn.Module):
def __init__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
# naively duplicating a Tensor retains the identity of the TensorImpl.


# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
def __init__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
def __init__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: @__torch__.f
@mb.import_function
@torch.jit.script
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/test/python/importer/jit_ir/node_import/if.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# else branch and making all defined values optional, so no special handling
# is needed.


# CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

mb = ModuleBuilder()


# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
mb = ModuleBuilder()
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])


# CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

mb = ModuleBuilder()


# CHECK: @__torch__.returns_bool
@mb.import_function
@torch.jit.script
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

mb = ModuleBuilder()


# CHECK: @__torch__.returns_none
@mb.import_function
@torch.jit.script
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# RUN: %PYTHON %s


# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str, **kwargs):
cu = CompilationUnit()
Expand Down
3 changes: 1 addition & 2 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1849,8 +1849,7 @@ def _emit_operation(

# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
...
class EmptyType: ...


Empty = EmptyType()
Expand Down
31 changes: 16 additions & 15 deletions python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,7 @@ def find_type_proto_for_name(self, name: str) -> onnx.TypeProto:
return ""


class OnnxImportError(Exception):
...
class OnnxImportError(Exception): ...


class NodeImporter:
Expand Down Expand Up @@ -235,22 +234,22 @@ def _populate_graph_attrs(self, container_op: Operation):
else:
default_opset_version = opset_import.version
if default_opset_version:
container_op.attributes[
"torch.onnx_meta.opset_version"
] = IntegerAttr.get(i64_type, default_opset_version)
container_op.attributes["torch.onnx_meta.opset_version"] = (
IntegerAttr.get(i64_type, default_opset_version)
)
if opset_versions:
container_op.attributes[
"torch.onnx_meta.opset_versions"
] = DictAttr.get(opset_versions)
container_op.attributes["torch.onnx_meta.opset_versions"] = (
DictAttr.get(opset_versions)
)
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
IntegerType.get_signed(64), m.ir_version
)
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
m.producer_name
)
container_op.attributes[
"torch.onnx_meta.producer_version"
] = StringAttr.get(m.producer_version)
container_op.attributes["torch.onnx_meta.producer_version"] = (
StringAttr.get(m.producer_version)
)

def import_all(self, func=True):
"""Imports all nodes topologically."""
Expand Down Expand Up @@ -658,9 +657,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
RankedTensorType.get(shape, IntegerType.get_signed(64)),
IntegerAttr.get(
IntegerType.get_signed(64),
int.from_bytes(tp.raw_data, "little", signed=True)
if tp.HasField("raw_data")
else tp.int64_data[0],
(
int.from_bytes(tp.raw_data, "little", signed=True)
if tp.HasField("raw_data")
else tp.int64_data[0]
),
),
),
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
Expand Down Expand Up @@ -703,7 +704,7 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
),
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
)
),
# Intentionally unsupported: STRING
}

Expand Down

0 comments on commit b218519

Please sign in to comment.