From 9f64748f97fa543a2b6b227cd26f570622cd26f1 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Mon, 29 Apr 2024 10:09:00 +0800 Subject: [PATCH 01/23] [FxImporter] Synchronize the collection of symbolic torch ops (#3236) --- python/torch_mlir/extras/fx_importer.py | 16 ++++------------ python/torch_mlir/fx.py | 4 ++-- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index c1eec37aab00..9acf4ad03a77 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -236,12 +236,6 @@ # set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP if _IS_TORCH_2_1_OR_EARLIER: - SYMBOLIC_TORCH_OPS = { - torch.ops.aten.sym_size, - torch.ops.aten.sym_stride, - torch.ops.aten.sym_numel, - } - SYMBOLIC_OP_TO_TORCH_OP = { (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default, (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int, @@ -249,13 +243,9 @@ (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int, (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default, } -else: - SYMBOLIC_TORCH_OPS = { - torch.ops.aten.sym_size.int, - torch.ops.aten.sym_stride.int, - torch.ops.aten.sym_numel.default, - } + SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP} +else: SYMBOLIC_OP_TO_TORCH_OP = { torch.ops.aten.sym_size.default: torch.ops.aten.size.default, torch.ops.aten.sym_size.int: torch.ops.aten.size.int, @@ -264,6 +254,8 @@ torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default, } + SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP} + @dataclass(frozen=True) class SparsityMeta: diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 0879dbe31218..651ccae673a6 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. 
-from typing import Optional, Union, Dict, Tuple, Any
+from typing import Optional, Union, Dict, Tuple, Any, Callable
 
 import warnings
 
@@ -25,7 +25,7 @@ def export_and_import(
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     experimental_support_mutation: bool = False,
     hooks: Optional[FxImporterHooks] = None,
-    decomposition_table: Optional[list] = None,
+    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
     func_name: str = "main",
     enable_graph_printing: bool = False,
     **kwargs,

From aed2cf3351ab2ffc8e9ccf1cc7e1f4a498071b13 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Mon, 29 Apr 2024 10:51:17 +0800
Subject: [PATCH 02/23] [Torch] emit aten.__contains__.str_list and add folder
 (#3249)

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 25 ++++++++++++
 .../torch-mlir/Dialect/Torch/IR/TorchOps.h    | 31 ++++++++++++++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 24 +++++++++++
 .../build_tools/torch_ods_gen.py              |  1 +
 test/Dialect/Torch/canonicalize.mlir          | 40 +++++++++++++++----
 5 files changed, 113 insertions(+), 8 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4a234307d00a..8ebd7b162fa7 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13626,6 +13626,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [
   }];
 }
 
+def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`";
+  let arguments = (ins
+    AnyTorchListOfTorchStringType:$l,
+    Torch_StringType:$item
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
index 4508518bf297..f49fef0721c2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h
@@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl<bool> &bind_values) {
   return detail::torch_list_of_constant_bools_op_binder(bind_values);
 }
 
+namespace detail {
+/// Matches the constant strs stored in a `torch.ListConstruct`.
+struct torch_list_of_constant_strs_op_binder {
+  SmallVectorImpl<std::string> &bind_values;
+
+  /// Creates a matcher instance that binds the value to bvs if match succeeds.
+  torch_list_of_constant_strs_op_binder(SmallVectorImpl<std::string> &bvs)
+      : bind_values(bvs) {}
+
+  bool match(Operation *op) {
+    auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
+    if (!listConstruct)
+      return false;
+    for (Value value : listConstruct.getElements()) {
+      std::string str;
+      if (matchPattern(value, m_TorchConstantStr(str)))
+        bind_values.push_back(str);
+      else
+        return false;
+    }
+    return true;
+  }
+};
+} // namespace detail
+
+/// Matches the constant strs stored in a `torch.prim.ListConstruct`.
+inline detail::torch_list_of_constant_strs_op_binder
+m_TorchListOfConstantStrs(SmallVectorImpl<std::string> &bind_values) {
+  return detail::torch_list_of_constant_strs_op_binder(bind_values);
+}
+
 namespace detail {
 /// Matches the expected tensor and dim from `torch.aten.size.int`.
 struct torch_tensor_size_int_op_binder {
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 33079e35fda1..376e7dd2e584 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2385,6 +2385,30 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// Aten__Contains__StrListOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
+  StringAttr item = dyn_cast<StringAttr>(adaptor.getItem());
+  if (!item)
+    return nullptr;
+
+  if (auto listConstruct = getL().getDefiningOp<Torch::PrimListConstructOp>()) {
+    if (isListPotentiallyMutated(listConstruct))
+      return nullptr;
+  }
+  llvm::SmallVector<std::string> strs;
+  if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) {
+    for (const auto &str : strs) {
+      if (item.getValue().str() == str)
+        return getI1IntegerAttr(getContext(), true);
+    }
+    return getI1IntegerAttr(getContext(), false);
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLtIntOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 6e449c277456..ca867723cc66 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -974,6 +974,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::format : (...) -> (str)")
     emit("aten::join : (str, str[]) -> (str)")
    emit("aten::warn : (str, int) -> ()")
+    emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True)
 
     # Type conversion ops.
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 7fd4e9832394..a317e4011b3e 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool { // CHECK-LABEL: func.func @torch.aten.eq.str$same_operand( // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true -// CHECK-NEXT: return %[[F]] : !torch.bool +// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true +// CHECK-NEXT: return %[[TRUE]] : !torch.bool func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool { %0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool return %0 : !torch.bool @@ -522,8 +522,8 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool { } // CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool { -// CHECK: %[[FALSE:.*]] = torch.constant.bool true -// CHECK: return %[[FALSE]] : !torch.bool +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool func.func @torch.aten.ne.str$different_value() -> !torch.bool { %str4 = torch.constant.str "4" %str5 = torch.constant.str "5" @@ -533,16 +533,16 @@ func.func @torch.aten.ne.str$different_value() -> !torch.bool { // CHECK-LABEL: func.func @torch.aten.ne.str$same_operand( // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false -// CHECK-NEXT: return %[[F]] : !torch.bool +// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-NEXT: return %[[FALSE]] : !torch.bool func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool { %0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool return %0 : !torch.bool } // CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool { -// CHECK: %[[TRUE:.*]] = torch.constant.bool false -// CHECK: return %[[TRUE]] : !torch.bool +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool func.func @torch.aten.ne.str$same_value() -> !torch.bool { %str4 = torch.constant.str "4" %str4_0 = torch.constant.str "4" @@ -568,6 +568,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int { return %2 : !torch.int } +// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func.func @torch.aten.__contains__.str_list$false() -> !torch.bool { + %str = torch.constant.str "c" + %str_0 = torch.constant.str "b" + %str_1 = torch.constant.str "a" + %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list + %2 = torch.aten.__contains__.str_list %1, %str : !torch.list, !torch.str -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func.func @torch.aten.__contains__.str_list$true() -> !torch.bool { + %str = torch.constant.str "aa" + %str_0 = torch.constant.str "aa" + %str_1 = torch.constant.str "ccc" + %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list + %2 = torch.aten.__contains__.str_list %1, %str : !torch.list, !torch.str -> !torch.bool + return %2 : !torch.bool +} + // CHECK-LABEL: func.func 
@torch.aten.__not__ // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool From b2185195e8fecb3568d53a97a502fc77a22a6daf Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Mon, 29 Apr 2024 11:06:01 +0800 Subject: [PATCH 03/23] [NFC] Update black version (#3256) * Update black version to support 3.11/3.12 * Reformat code --- .pre-commit-config.yaml | 2 +- build_tools/scrape_releases.py | 1 + .../torchscript_stablehlo_backend_tinybert.py | 1 + .../python/torch_mlir/_dynamo_fx_importer.py | 6 ++-- .../build_tools/torch_ods_gen.py | 6 ++-- .../configs/onnx_backend.py | 4 +-- .../linalg_on_tensors_backends/refbackend.py | 10 +++--- .../test_suite/elementwise.py | 4 +-- .../test_suite/slice_like.py | 1 + .../jit_ir/ivalue_import/debug-module-name.py | 1 + .../object-identity-torch-bug.py | 1 + .../jit_ir/ivalue_import/quantization.py | 1 + .../importer/jit_ir/node_import/debug-info.py | 1 + .../importer/jit_ir/node_import/elif.py | 1 + .../jit_ir/node_import/function-derefine.py | 1 + .../python/importer/jit_ir/node_import/if.py | 1 + .../importer/jit_ir/node_import/loop.py | 1 + .../importer/jit_ir/node_import/prim.py | 1 + .../importer/jit_ir/node_import/tuple.py | 1 + .../importer/jit_ir/node_import/types-bool.py | 1 + .../importer/jit_ir/node_import/types-none.py | 1 + .../importer/jit_ir/node_import/utils.py | 1 + python/torch_mlir/extras/fx_importer.py | 3 +- python/torch_mlir/extras/onnx_importer.py | 31 ++++++++++--------- 24 files changed, 49 insertions(+), 33 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72329026f716..f2938e28e8c7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: check-yaml - id: check-added-large-files - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 24.4.2 hooks: - id: black diff --git a/build_tools/scrape_releases.py b/build_tools/scrape_releases.py index 88f19d92bf7c..77aa41c155c6 100644 --- a/build_tools/scrape_releases.py +++ b/build_tools/scrape_releases.py @@ -2,6 +2,7 @@ See https://github.com/llvm/torch-mlir/issues/1374 """ + import argparse import json diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py index af2af2de3173..840ec519d5c8 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py @@ -3,6 +3,7 @@ from transformers import BertForMaskedLM + # Wrap the bert model to avoid multiple returns problem class BertTinyWrapper(torch.nn.Module): def __init__(self) -> None: diff --git a/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py b/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py index fcea14dc1df2..81908d80164d 100644 --- a/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py +++ b/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py @@ -257,9 +257,9 @@ def __init__(self, g: torch.fx.Graph, func_name: str): # FakeTensor's in case of a tuple return with multiple elements. 
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {} self._module = ir.Module.create(ir.Location.unknown()) - self._module.operation.attributes[ - "torch.debug_module_name" - ] = ir.StringAttr.get(func_name) + self._module.operation.attributes["torch.debug_module_name"] = ( + ir.StringAttr.get(func_name) + ) function_type = _extract_function_type_from_graph(g) func = func_dialect.FuncOp( func_name, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ca867723cc66..eea8d31a95a4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -285,9 +285,9 @@ def emit_with_mutating_variants(key, **kwargs): (ns, unqual + "_", overload if not is_functional_op else "") ), emitter_td, - traits=["IsTrailingUnderscoreInplaceVariant"] - if not is_functional_op - else [], + traits=( + ["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [] + ), ) # ========================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index 6fa845ab377e..7f630074e756 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -46,7 +46,7 @@ def convert_onnx(model, inputs): examples = [] input_names = [] dynamic_tensors = {} - for (index, arg) in enumerate(inputs): + for index, arg in enumerate(inputs): shape = map(lambda d: d if d >= 0 else 1, arg.shape) shape = tuple(shape) examples.append(torch.zeros(size=shape, dtype=arg.dtype)) @@ -55,7 +55,7 @@ def convert_onnx(model, inputs): input_names.append(input_name) dynamic_dims = {} - for (dimindex, dim) in enumerate(arg.shape): + for dimindex, dim in enumerate(arg.shape): if dim < 0: dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex) diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index a1611a1e5d2e..1e958a4d9451 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -101,10 +101,12 @@ def __init__(self, module): def consume_return_funcs(*args): self.result = tuple( [ - arg - if type in elemental_type_to_ctype - else unranked_memref_to_numpy( - arg, memref_type_to_np_dtype[type] + ( + arg + if type in elemental_type_to_ctype + else unranked_memref_to_numpy( + arg, memref_type_to_np_dtype[type] + ) ) for arg, type in zip(args, ret_types) ] diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index d034e6d1f426..8e287584295b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -803,9 +803,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: QuantizedReluInt32()) def QuantizedReluInt32_basic(module, tu: TestUtils): - module.forward( - tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32) - ) + module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)) # 
============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index 07f064de72ca..be2a80d84427 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils): # ============================================================================== + # For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1). # For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index). class SliceScatterModule(torch.nn.Module): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py index bd21c4e8bdb3..5af1a6b895c9 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py @@ -11,6 +11,7 @@ mb = ModuleBuilder() + # CHECK: module attributes {torch.debug_module_name = "TestModule"} class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py index 4c323ec01e41..4c325308b702 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py @@ -18,6 +18,7 @@ # `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so # naively duplicating a Tensor retains the identity of the TensorImpl. + # CHECK-LABEL: torch.class_type @__torch__.TestModule { class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py index e33985fac5cf..df6f1736c79f 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py @@ -12,6 +12,7 @@ mb = ModuleBuilder() + # CHECK-LABEL: torch.class_type @__torch__.TestModule { class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py index 1bc258a42179..7e8df49a0efc 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py @@ -9,6 +9,7 @@ mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.add3 # Note that line-level debug information for parts unannotated in the Torch # graph are ascribed to the first op that carries source information. 
Presently
diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py
index 5ee16e3916b0..f3ee0a557eb7 100644
--- a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py
+++ b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py
@@ -9,6 +9,7 @@
 
 mb = ModuleBuilder()
 
+
 # CHECK-LABEL: @__torch__.f
 @mb.import_function
 @torch.jit.script
diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py
index 2acde08caf5f..f9505b91fda3 100644
--- a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py
+++ b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py
@@ -11,6 +11,7 @@
 
 mb = ModuleBuilder()
 
+
 # CHECK-LABEL:   func.func @__torch__.optional_return(
 # CHECK-SAME:        %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
 # CHECK:           %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/if.py b/projects/pt1/test/python/importer/jit_ir/node_import/if.py
index 86390f707d12..02cb8d9f0033 100644
--- a/projects/pt1/test/python/importer/jit_ir/node_import/if.py
+++ b/projects/pt1/test/python/importer/jit_ir/node_import/if.py
@@ -13,6 +13,7 @@
 # else branch and making all defined values optional, so no special handling
 # is needed.
 
+
 # CHECK-LABEL: @__torch__.prim_If(
 # CHECK-SAME:      %[[B:.*]]: !torch.bool,
 # CHECK-SAME:      %[[I:.*]]: !torch.int) -> !torch.int {
diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py
index d432cd6ee35c..b28d63bb0927 100644
--- a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py
+++ b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py
@@ -11,6 +11,7 @@
 
 mb = ModuleBuilder()
 
+
 # CHECK-LABEL:   func.func @__torch__.prim_Loop_forlike(
 # CHECK-SAME:        %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
 # CHECK:           %[[BOOL_TRUE:.*]] = torch.constant.bool true
diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py
index 66959257e9b3..759292b6d35b 100644
--- a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py
+++ b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py
@@ -15,6 +15,7 @@
 
 mb = ModuleBuilder()
 
+
 # CHECK-LABEL:   func.func @__torch__.prim_NumToTensor(
 # CHECK-SAME:        %[[ARG:.*]]: !torch.int) -> !torch.tensor {
 # CHECK:           %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py
index a1f06c3902ad..b6a313cd4345 100644
--- a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py
+++ b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py
@@ -13,6 +13,7 @@
 mb = ModuleBuilder()
 NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])
 
+
 # CHECK-LABEL:   func.func @__torch__.tuple(
 # CHECK-SAME:        %[[T0:.*]]: !torch.tensor,
 # CHECK-SAME:        %[[T1:.*]]: !torch.tensor) ->
diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py
index 0a27692fc8f3..7cd4c3c16148 100644
--- a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py
+++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py
@@ -9,6 +9,7 @@ mb = ModuleBuilder() + # CHECK: @__torch__.returns_bool @mb.import_function @torch.jit.script diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py index 16a3359da1bb..b0358467ca63 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py @@ -9,6 +9,7 @@ mb = ModuleBuilder() + # CHECK: @__torch__.returns_none @mb.import_function @torch.jit.script diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/utils.py b/projects/pt1/test/python/importer/jit_ir/node_import/utils.py index 613ccb6a8502..b06c38fdf285 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/utils.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/utils.py @@ -9,6 +9,7 @@ # RUN: %PYTHON %s + # Import TorchScript IR string as ScriptFunction. def create_script_function(func_name, ts_ir_str, **kwargs): cu = CompilationUnit() diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9acf4ad03a77..24bda3f5b2c6 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1849,8 +1849,7 @@ def _emit_operation( # Opaque value to indicate something is empty. Used in cases where 'None' # may have a different meaning. -class EmptyType: - ... +class EmptyType: ... Empty = EmptyType() diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index f1064f976504..8d0e4cf5a8e1 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -156,8 +156,7 @@ def find_type_proto_for_name(self, name: str) -> onnx.TypeProto: return "" -class OnnxImportError(Exception): - ... +class OnnxImportError(Exception): ... 
class NodeImporter: @@ -235,22 +234,22 @@ def _populate_graph_attrs(self, container_op: Operation): else: default_opset_version = opset_import.version if default_opset_version: - container_op.attributes[ - "torch.onnx_meta.opset_version" - ] = IntegerAttr.get(i64_type, default_opset_version) + container_op.attributes["torch.onnx_meta.opset_version"] = ( + IntegerAttr.get(i64_type, default_opset_version) + ) if opset_versions: - container_op.attributes[ - "torch.onnx_meta.opset_versions" - ] = DictAttr.get(opset_versions) + container_op.attributes["torch.onnx_meta.opset_versions"] = ( + DictAttr.get(opset_versions) + ) container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( IntegerType.get_signed(64), m.ir_version ) container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( m.producer_name ) - container_op.attributes[ - "torch.onnx_meta.producer_version" - ] = StringAttr.get(m.producer_version) + container_op.attributes["torch.onnx_meta.producer_version"] = ( + StringAttr.get(m.producer_version) + ) def import_all(self, func=True): """Imports all nodes topologically.""" @@ -658,9 +657,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get( IntegerType.get_signed(64), - int.from_bytes(tp.raw_data, "little", signed=True) - if tp.HasField("raw_data") - else tp.int64_data[0], + ( + int.from_bytes(tp.raw_data, "little", signed=True) + if tp.HasField("raw_data") + else tp.int64_data[0] + ), ), ), # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB @@ -703,7 +704,7 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: ), onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get( np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False - ) + ), # Intentionally unsupported: STRING } From b1e22414794db1b25d938b6f8f7dca6376a50990 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 29 Apr 2024 09:30:01 +0530 Subject: [PATCH 04/23] [ONNX] Fix Onnx.Selu lowering and canonicalizer for IntImplicit op (#3221) Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 35 ++++++++++++++++--- lib/Dialect/Torch/IR/TorchOps.cpp | 19 +++++++--- projects/pt1/e2e_testing/xfail_sets.py | 3 -- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 16 ++++++--- 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 586b8d4ff053..edb36aee967b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -847,15 +847,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( patterns.onOp( "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0 Torch::ValueTensorType resultType; float alpha, gamma; Value operand; + // Refer https://onnx.ai/onnx/operators/onnx__Selu.html for the default + // alpha and gamma values. 
         if (binder.tensorOperand(operand) ||
-            binder.f32FloatAttr(alpha, "alpha") ||
-            binder.f32FloatAttr(gamma, "gamma") ||
+            binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
+            binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
             binder.tensorResultType(resultType))
           return failure();
 
+        Torch::ValueTensorType inputType =
+            operand.getType().cast<Torch::ValueTensorType>();
+
         Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
@@ -864,12 +870,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), gamma));
 
-        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
+        Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
 
-        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
-            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
+            binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
+            cstNone, cstNone);
+        Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
+                                                      resultType, operand);
+        Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, exp, vAlpha);
+        Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
+            binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
+        Value neg = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
+        Value pos = rewriter.create<Torch::AtenMulScalarOp>(
+            binder.getLoc(), resultType, operand, vScale);
+        Type compareType = inputType.getWithSizesAndDtype(
+            inputType.getOptionalSizes(), rewriter.getI1Type());
+        Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
+            binder.getLoc(), compareType, operand, zeroTensor);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
+            binder.op, resultType, xLessThanZero, neg, pos);
         return success();
       });
   patterns.onOp("ReduceL1", 1,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 376e7dd2e584..29911961dc06 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc,
     return nullptr;
 
   Type inputDtype = inputTensorType.getOptionalDtype();
-  if (!inputDtype || !inputDtype.isInteger(64))
+  if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
     return nullptr;
 
   std::optional<unsigned> inputRank = getTensorRank(input);
@@ -148,10 +148,19 @@ static Value getScalarIntValue(Value input, Location loc,
     return nullptr;
 
   if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
-    auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
-                   .getSplatValue<int64_t>();
-    return rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(val));
+    if (inputDtype.isInteger(64)) {
+      auto val = valueTensorLiteralOp.getValue()
+                     .cast<DenseIntElementsAttr>()
+                     .getSplatValue<int64_t>();
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(val));
+    } else {
+      auto val = valueTensorLiteralOp.getValue()
+                     .cast<DenseIntElementsAttr>()
+                     .getSplatValue<bool>();
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(val));
+    }
   } else if (auto primNumToTensorScalarOp =
                  input.getDefiningOp<PrimNumToTensorScalarOp>()) {
     return primNumToTensorScalarOp.getA();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 276cc47c1cc6..e45839617a08 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2124,7 +2124,6 @@
     "ElementwiseAtenFloorDivideTensorNegativeModule_basic",
     "ElementwiseLog10IntModule_basic",
     "ElementwiseLog2IntModule_basic",
-    "ElementwiseSeluModule_basic",
     "FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic", "HardsigmoidModule_basic", @@ -2637,8 +2636,6 @@ "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", - "DropoutTrainModule_basic", - "DropoutTrainStaticShapeModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 9c0ab351297f..5fe9c79d3089 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -582,10 +582,18 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to // CHECK-LABEL: func.func @test_selu func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { - // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1 - // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 - // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3 - // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]] + // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00 + // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00 + // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00 + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } From 0a5ff68d9d57c9c3948b6d60c1edb32da9fe3670 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Mon, 29 Apr 2024 17:40:30 +0800 Subject: [PATCH 05/23] [stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in stablehlo (#3230) --- .../TorchToStablehlo/StablehloLegalizeUtils.h | 11 ++ .../StablehloLegalizeUtils.cpp | 131 ++++++++++++++++++ lib/Conversion/TorchToStablehlo/ViewLike.cpp | 61 ++++---- projects/pt1/e2e_testing/xfail_sets.py | 8 +- 4 files changed, 181 insertions(+), 30 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index 
6e14b324b656..734ba81ea07a 100644
--- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
@@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value tensor, ArrayRef<int64_t> inputUnsqzDims,
                                  size_t dimSizeIndexBits);
 
+// Get a tensor that collapses the specified dimensions of the input tensor
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits);
+
+// Get a tensor that splits the specified dimensions of the input tensor
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits);
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType);
diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
index 40ec715cd62e..c4d629d4f5bb 100644
--- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
@@ -9,6 +9,7 @@
 
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
@@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
       .getResult();
 }
 
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits) {
+
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+
+  collapseStartDim = toPositiveDim(collapseStartDim, rank);
+  collapseEndDim = toPositiveDim(collapseEndDim, rank);
+
+  int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  Value collapseDimSize = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+  int64_t collapseShape = 1;
+
+  for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
+    if (k < 0 || k >= rank) {
+      return rewriter.notifyMatchFailure(
+          op, "collapse dimensions must be within the rank of the tensor");
+    }
+    if (collapseShape == ShapedType::kDynamic ||
+        oldShape[k] == ShapedType::kDynamic) {
+      collapseShape = ShapedType::kDynamic;
+    } else {
+      collapseShape *= oldShape[k];
+    }
+    collapseDimSize =
+        rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
+  }
+
+  for (int64_t k = 0; k < collapseStartDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(collapseDimSize);
+  newShape.push_back(collapseShape);
+  for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
+// TODO: support splitDim & outerLength to be Value
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits) {
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+  splitDim = toPositiveDim(splitDim, rank);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  if (splitDim < 0 || splitDim >= rank) {
+    return rewriter.notifyMatchFailure(
+        op, "split dimensions must be within the rank of the tensor");
+  }
+
+  int64_t newRank = rank + 1;
+  auto outerLengthValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, outerLength));
+
+  auto innerLengthValue = rewriter.create<arith::DivUIOp>(
+      loc, dimSizes[splitDim], outerLengthValue);
+
+  int64_t originShape = oldShape[splitDim];
+  int64_t outerShape = outerLength;
+  int64_t innerShape = originShape == ShapedType::kDynamic
+                           ? ShapedType::kDynamic
+                           : originShape / outerLength;
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  for (int64_t k = 0; k < splitDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(outerLengthValue);
+  newShape.push_back(outerShape);
+  newDimSizes.push_back(innerLengthValue);
+  newShape.push_back(innerShape);
+
+  for (int64_t k = splitDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType) {
diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp
index e43105ea1b2b..04952d84343a 100644
--- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp
+++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp
@@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only constant end is currently supported");
 
-  start = toPositiveDim(start, rank);
-  end = toPositiveDim(end, rank);
-  SmallVector<int64_t> dims;
-  dims.reserve(rank);
-  for (int r = 0; r < start; ++r)
-    dims.push_back(r);
-  int64_t collapsedDimSize = 1;
-  for (int r = start; r <= end; ++r) {
-    if (selfType.getShape()[r] == ShapedType::kDynamic)
-      return rewriter.notifyMatchFailure(
-          op, "the size of the dimension being collapsed is can't be unknown");
-    collapsedDimSize *= selfType.getShape()[r];
+  auto collapseTensorInfo = hlo::collapseTensor(
+      rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
+  if (failed(collapseTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
+
+  rewriter.replaceOp(op, *collapseTensorInfo);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
+    PrimsSplitDimOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
+  if (!selfType) {
+    return op.emitError("only tensor types are currently supported");
   }
-  dims.push_back(collapsedDimSize);
-  for (int r = end + 1; r < rank; ++r)
-    dims.push_back(r);
 
-  auto newDimSizesInfo = hlo::getDimSizesOfTensor(
-      rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
-  if (failed(newDimSizesInfo))
+  auto rank = selfType.getRank();
+  if (rank == 0)
     return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  auto newDimSizes = *newDimSizesInfo;
-  auto stablehloShape =
-      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
-  rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-      stablehloShape);
+        op, "the rank of tensor must be greater than 0");
+
+  int64_t dim, outerLength;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant dim is currently supported");
+  if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant outerLength is currently supported");
+
+  auto splitTensorInfo = hlo::splitTensor(
+      rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
+
+  if (failed(splitTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create split tensor");
+
+  rewriter.replaceOp(op, *splitTensorInfo);
   return success();
 }
 
@@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
   INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
   INSERT_ATENOP_PATTERN(PrimsCollapseOp);
+  INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e45839617a08..10c24b657128 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -678,11 +678,6 @@
     "NumToTensorIntModule_basic",
     "NumelModule_basic",
     "NumelZeroRankModule_basic",
-    "PixelShuffleModuleFullDynamic_basic",
-    "PixelShuffleModuleSpatiallyDynamic_basic",
-    "PixelShuffleModuleSpatiallyStatic_basic",
-    "PixelShuffleModuleStaticRank3Int64_basic",
-    "PixelShuffleModuleStaticRank4Float32_basic",
     "PowIntFloatModule_basic",
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
@@ -1157,6 +1152,8 @@
     "Permute0RankModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
+    "PixelShuffleModuleStaticRank3Int64_basic",
+    "PixelShuffleModuleStaticRank4Float32_basic",
     "PowIntFloatModule_basic",
     "PrimListUnpackNumMismatchModule_basic",
     "PrimMaxIntModule_basic",
@@ -1240,6 +1237,7 @@
     "SliceWholeTensorModule_basic",
     "SortIntListReverse_basic",
    "SortIntList_basic",
+    "SplitDimStaticModule_basic",
     "SplitTensorGetItem_Module_basic",
     "SplitTensorLastSmallerModule_basic",
     "SplitTensorListUnpackModule_basic",

From 2176176fefd696d929b9d61b5587a419fae8386d Mon Sep 17 00:00:00 2001
From: Sambhav Jain
Date: Mon, 29 Apr 2024 09:21:12 -0700
Subject: [PATCH 06/23] [FX] Add broadcast test with dynamic dim (#3123)

This scenario was uncovered in a downstream test that failed with a previous
snapshot of torch-mlir. See
https://github.com/cruise-automation/mlir-tcp/actions/runs/8605480116/job/23581829102?pr=65.
``` File "/home/runner/.cache/bazel/_bazel_runner/ce288f117ee4ca92dc028a6a28476a3d/sandbox/processwrapper-sandbox/2380/execroot/mlir-tcp/bazel-out/k8-opt-exec-2B5CBBC6/bin/test/AotCompile/broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic_torch_exporter.runfiles/pip_deps_torch_mlir/site-packages/torch_mlir/extras/fx_importer.py", line 969, in value_info_to_type raise NotImplementedError( NotImplementedError: Could not deduce type from value info: tensor_meta=None, val=s1, sparsity=None ``` It seems to have resolved on current HEAD. Adding this test to ensure coverage in the future. --- test/python/fx_importer/basic_test.py | 29 ++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 08ef9fdc9cd3..fde318630077 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -105,6 +105,33 @@ def forward(self, x): print(m) +@run +# CHECK-LABEL: test_broadcast_with_dynamic_shapes +# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> +def test_broadcast_with_dynamic_shapes(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + dim_0 = Dim("dim_0") + dynamic_shapes = { + "x": {}, + "y": {0: dim_0}, + } + + m = fx.export_and_import( + Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net" + ) + print(m) + + @make_boxed_compiler def fx_import_aot_autograd_backend( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] @@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend( @run # CHECK-LABEL: test_stateless_fx_import -# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> # CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> # CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32> def test_stateless_fx_import(): From 087fea0608dac3995b74e5c22ae7950287fe7a73 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 29 Apr 2024 21:54:04 +0530 Subject: [PATCH 07/23] build: manually update PyTorch version (#3257) Set PyTorch and TorchVision version to nightly release 2024-04-28. 
Signed-Off By: Vivek Khandelwal
---
 pytorch-hash.txt             | 2 +-
 pytorch-requirements.txt     | 2 +-
 torchvision-requirements.txt | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch-hash.txt b/pytorch-hash.txt
index c2f8c830c0ee..400586976392 100644
--- a/pytorch-hash.txt
+++ b/pytorch-hash.txt
@@ -1 +1 @@
-0a3e5f5badd8a0cb7fac97f5ec9d48c304e5c0b7
+34ade3521ca41f20af3469bba276c2b0499c3892
diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt
index 256104030b0a..7cd8d44e5425 100644
--- a/pytorch-requirements.txt
+++ b/pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torch==2.4.0.dev20240422
+torch==2.4.0.dev20240428
diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt
index a530cc800fcc..148f66152b88 100644
--- a/torchvision-requirements.txt
+++ b/torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torchvision==0.19.0.dev20240422
+torchvision==0.19.0.dev20240428

From db6721084a2b3f41216e9cc7e0ea9263c33f196e Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Mon, 29 Apr 2024 12:01:40 -0700
Subject: [PATCH 08/23] Integrate LLVM at
 llvm/llvm-project@593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf (#3260)

---
 externals/llvm-project                                        | 2 +-
 .../Dialect/TMTensor/IR/TMTensorInterfaces.h                  | 4 ++--
 lib/Dialect/TMTensor/IR/TMTensorOps.cpp                       | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/externals/llvm-project b/externals/llvm-project
index a952c123880e..593f6fdcb4bb 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit a952c123880eb1168f1021b116485e27170d48ca
+Subproject commit 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf
diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h
index 159bcea7899e..50045438f584 100644
--- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h
+++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h
@@ -30,8 +30,6 @@ namespace detail {
 LogicalResult verifyTMTensorOpInterface(Operation *op);
 }
 
-#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
-
 /// Include the generated interface declarations.
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
 
@@ -39,4 +37,6 @@ LogicalResult verifyTMTensorOpInterface(Operation *op);
 } // namespace torch
 } // namespace mlir
 
+#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
+
 #endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
index be07ca276dd6..218ecad3388f 100644
--- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
+++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
@@ -936,7 +936,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
   // If no operand comes from a tensor::CastOp and can be folded then fail.
   bool hasTensorCastOperand =
       llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
-        if (opOperand->get().isa<BlockArgument>())
+        if (isa<BlockArgument>(opOperand->get()))
           return false;
         auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
         return castOp && canFoldIntoConsumerOp(castOp);

From 122cf22cc2b1d2006607dc18e8d2309a94172321 Mon Sep 17 00:00:00 2001
From: "Jae Hoon (Antonio) Kim" <17433012+antoniojkim@users.noreply.github.com>
Date: Mon, 29 Apr 2024 15:02:12 -0400
Subject: [PATCH 09/23] Re-enable LTC Build (#3261)

The LTC build was disabled in https://github.com/llvm/torch-mlir/pull/3210
due to a regression in the packaging of the torch nightly wheels
(https://github.com/pytorch/pytorch/issues/124941), which is now resolved.
So, this PR re-enables the LTC build.

---
 build_tools/ci/build_posix.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/build_tools/ci/build_posix.sh b/build_tools/ci/build_posix.sh
index bacb736ba1f2..fec5e252e8d7 100755
--- a/build_tools/ci/build_posix.sh
+++ b/build_tools/ci/build_posix.sh
@@ -50,6 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
   -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \
   -DLLVM_TARGETS_TO_BUILD=host \
   -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+  -DTORCH_MLIR_ENABLE_LTC=ON
 echo "::endgroup::"
 
 echo "::group::Build"

From b64c22cfc12c110f9e77857530d014978b2577b8 Mon Sep 17 00:00:00 2001
From: jinchen <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 30 Apr 2024 00:44:41 -0700
Subject: [PATCH 10/23] Fix onnx sinh lowering (#3253)

iree tests `test_sinh` and `test_sinh_example` passed

---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 35 ++++++++++++------
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   | 10 ++++--
 2 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index edb36aee967b..197d9c536b9f 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1449,18 +1449,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       return success();
     });
 
-  patterns.onOp("Sinh", 9,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
+  patterns.onOp(
+      "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType))
+          return failure();
 
-                  rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+        // 1/2 * (exp(x) - exp(-x))
+        Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
+                                                    operand);
+        Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
+                                                      resultType, operand);
+        Value y =
+            rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value z = rewriter.create<Torch::AtenSubTensorOp>(
+            binder.getLoc(), resultType, x, y, cstOne);
+        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(2));
+        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
+            binder.op, resultType, z, cstTwo);
+        return success();
+      });
 
   // split with fixed-size parts
   // Arguments:
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1265,9 +1265,15 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>,
 
 // -----
 
-// CHECK-LABEL: func.func @test_sinh
+// CHECK-LABEL: func.func @test_sinh_example
 func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
-  // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[C1:.+]] = torch.constant.int 1
+  // CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+  // CHECK: %[[C2:.+]] = torch.constant.int 2
+  // CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
   %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }

From aa471f1d9612eb3a3b47a041aaef565944398dd2 Mon Sep 17 00:00:00 2001
From: jinchen <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 30 Apr 2024 00:49:29 -0700
Subject: [PATCH 11/23] Fix onnx cosh lowering (#3254)

iree tests `test_cosh` and `test_cosh_example` passed

---
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    | 36 +++++++++++++------
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir   | 16 +++++++--
 2 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 96f4e55fb12d..401c83991f7b 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1348,17 +1348,31 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         binder.op, resultType, operand);
       return success();
     });
-  patterns.onOp("Cosh", 9,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
-                  rewriter.replaceOpWithNewOp<Torch::AtenCoshOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+  patterns.onOp(
+      "Cosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // 1/2 * (exp(x) + exp(-x))
+        Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
+                                                    operand);
+        Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
+                                                      resultType, operand);
+        Value y =
+            rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value z = rewriter.create<Torch::AtenAddTensorOp>(
+            binder.getLoc(), resultType, x, y, cstOne);
+        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(2));
+        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
+            binder.op, resultType, z, cstTwo);
+        return success();
+      });
   patterns.onOp(
       "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Location loc = binder.getLoc();
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index f53e55a1679b..0719512f0d5b 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -665,7 +665,13 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 // CHECK-LABEL: @test_cosh_example func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[C2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -674,7 +680,13 @@ func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[ // CHECK-LABEL: @test_cosh func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[C2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } From fb499192dfe60476c72838e84b8d5b42dfcd6072 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Tue, 30 Apr 2024 00:49:44 -0700 Subject: [PATCH 12/23] Fix onnx acosh lowering (#3262) iree tests `test_acosh` and `test_acosh_example` passed --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 34 +++++++++++++------ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 14 ++++++-- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 401c83991f7b..f5b05327c2cb 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp 
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -242,17 +242,29 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp("Acosh", 9,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
-                  rewriter.replaceOpWithNewOp<Torch::AtenAcoshOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+  patterns.onOp(
+      "Acosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // log(x + sqrt(x**2 - 1))
+        Value square = rewriter.create<Torch::AtenSquareOp>(
+            binder.getLoc(), resultType, operand);
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value sub = rewriter.create<Torch::AtenSubScalarOp>(
+            binder.getLoc(), resultType, square, cstOne, cstOne);
+        Value sqrt = rewriter.create<Torch::AtenSqrtOp>(binder.getLoc(),
+                                                        resultType, sub);
+        Value add = rewriter.create<Torch::AtenAddTensorOp>(
+            binder.getLoc(), resultType, operand, sqrt, cstOne);
+        rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
+                                                      add);
+        return success();
+      });
   patterns.onOp("BatchNormalization", 15,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index 0719512f0d5b..967c35f130d8 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -695,7 +695,12 @@ func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,

 // CHECK-LABEL: @test_acosh_example
 func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[C1:.+]] = torch.constant.int 1
+  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
+  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+  // CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
   %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
@@ -704,7 +709,12 @@ func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<

 // CHECK-LABEL: @test_acosh
 func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[C1:.+]] = torch.constant.int 1
+  // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
   %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }

From bf04b53b072aa90ea72723b0189c418cbdc4857f Mon Sep 17 00:00:00 2001
From: jinchen <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 30 Apr 2024 00:49:57 -0700
Subject: [PATCH 13/23] Fix onnx asinh lowering (#3263)

IREE tests `test_asinh` and `test_asinh_example` now pass.
---
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    | 34 +++++++++++++------
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir   | 14 ++++--
 2 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index f5b05327c2cb..7b44e8510bbb 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -198,17 +198,29 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp("Asinh", 9,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
-                  rewriter.replaceOpWithNewOp<Torch::AtenAsinhOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+  patterns.onOp(
+      "Asinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // log(x + sqrt(x**2 + 1))
+        Value square = rewriter.create<Torch::AtenSquareOp>(
+            binder.getLoc(), resultType, operand);
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value add0 = rewriter.create<Torch::AtenAddScalarOp>(
+            binder.getLoc(), resultType, square, cstOne, cstOne);
+        Value sqrt = rewriter.create<Torch::AtenSqrtOp>(binder.getLoc(),
+                                                        resultType, add0);
+        Value add1 = rewriter.create<Torch::AtenAddTensorOp>(
+            binder.getLoc(), resultType, operand, sqrt, cstOne);
+        rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
+                                                      add1);
+        return success();
+      });
   patterns.onOp("Atan", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index 967c35f130d8..aca59b8aec1a 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -741,7 +741,12 @@ func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,

 // CHECK-LABEL: @test_asinh_example
 func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[C1:.+]] = torch.constant.int 1
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
+  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+  // CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
   %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
@@ -750,7 +755,12 @@ func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<

 // CHECK-LABEL: @test_asinh
 func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[C1:.+]] = torch.constant.int 1
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
   %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }

From fbbad2d81e7cad20b2590fbd2087889a207e2eb6 Mon Sep 17 00:00:00 2001
From: jinchen <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 30 Apr 2024 00:50:08 -0700
Subject: [PATCH 14/23] Fix onnx atanh lowering (#3264)

IREE tests `test_atanh` and `test_atanh_example` now pass.
---
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    | 38 +++++++++++++------
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir   |  9 ++++-
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 7b44e8510bbb..716ea3d6e202 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -232,17 +232,33 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp("Atanh", 9,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
-                  rewriter.replaceOpWithNewOp<Torch::AtenAtanhOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+  patterns.onOp(
+      "Atanh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // 1/2 * log((1 + x) / (1 - x))
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value add = rewriter.create<Torch::AtenAddScalarOp>(
+            binder.getLoc(), resultType, operand, cstOne, cstOne);
+        Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
+                                                      resultType, operand);
+        Value sub = rewriter.create<Torch::AtenAddScalarOp>(
+            binder.getLoc(), resultType, neg, cstOne, cstOne);
+        Value div = rewriter.create<Torch::AtenDivTensorOp>(
+            binder.getLoc(), resultType, add, sub);
+        Value log =
+            rewriter.create<Torch::AtenLogOp>(binder.getLoc(), resultType, div);
+        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(2));
+        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
+            binder.op, resultType, log, cstTwo);
+        return success();
+      });
   patterns.onOp("Acos", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index aca59b8aec1a..eb2cde696ef1 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -201,7 +201,14 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,

 // CHECK-LABEL: @test_atanh
 func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: torch.aten.atanh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg0, %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[NEG:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[SUB:.*]] = torch.aten.add.Scalar %[[NEG]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[DIV:.*]] = torch.aten.div.Tensor %[[ADD]], %[[SUB]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[LOG:.*]] = torch.aten.log %[[DIV]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: torch.aten.div.Scalar %[[LOG]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
   %0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }

From fb8aed09076bc5073808dfe7057268b3b80543d7 Mon Sep 17 00:00:00 2001
From: Sambhav Jain
Date: Tue, 30 Apr 2024 00:55:25 -0700
Subject: [PATCH 15/23] [Release Builds] Use `-no-build-isolation` to decouple
 from `pyproject.toml` (#3266)

Fixes https://github.com/llvm/torch-mlir/issues/3258

In addition, this disables the LTC builds, since they are already covered
in CI (build_posix.sh) and I am not aware of a consumer of this flow in
the binary releases of torch-mlir (the main dependency there is from
source).
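
As a rough sketch of the intended difference (the commands below are the
ones this script runs; only the surrounding comments are added here):

  # isolated build: pip resolves build-system.requires from
  # pyproject.toml into a fresh, throwaway environment
  python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir

  # decoupled build: reuse the torch and build deps already installed
  # in the surrounding environment
  python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir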
---
 build_tools/python_deploy/build_linux_packages.sh | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh
index 4feccdd64029..625020836797 100755
--- a/build_tools/python_deploy/build_linux_packages.sh
+++ b/build_tools/python_deploy/build_linux_packages.sh
@@ -432,6 +432,8 @@ function clean_build() {
 }

 function build_torch_mlir() {
+  # Disable LTC build for releases
+  export TORCH_MLIR_ENABLE_LTC=0
   local torch_version="$1"
   case $torch_version in
     nightly)
@@ -440,7 +442,7 @@ function build_torch_mlir() {
         --extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
       CMAKE_GENERATOR=Ninja \
       TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
-      python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir \
+      python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir \
         -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \
         -r /main_checkout/torch-mlir/whl-requirements.txt
       ;;
@@ -450,7 +452,7 @@ function build_torch_mlir() {
       python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
       CMAKE_GENERATOR=Ninja \
       TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
-      python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
+      python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
       ;;
     *)
       echo "Unrecognized torch version '$torch_version'"
@@ -474,7 +476,7 @@ function build_torch_mlir_core() {
   TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
   TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
   TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
-  python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
+  python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
 }

 function clean_wheels() {

From f32ada993d393581ae1e70ac6b47dbdd4a70dca1 Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Wed, 1 May 2024 00:06:13 +0800
Subject: [PATCH 16/23] [Stablehlo] Improve the lowering of pool op in
 stablehlo (#3259)

1. Handle case stride == None
2.
add avgpool3d maxpool1d maxpool3d lowering --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 28 ++ lib/Conversion/TorchToStablehlo/Pooling.cpp | 279 +++++++++++------- .../Transforms/AbstractInterpLibrary.cpp | 14 +- projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/abstract_interp_lib_gen.py | 13 +- .../build_tools/torch_ods_gen.py | 1 + 6 files changed, 216 insertions(+), 122 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8ebd7b162fa7..cb08ffd533b9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6637,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [ }]; } +def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenMaxPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 132410a2a358..9219b4af355f 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, auto constType = RankedTensorType::get({}, elementTy); // Avg pooling if (isa(op)) { + AtenAvgPool3dOp, AtenCumsumOp>(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( @@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, } // Max pooling - if (isa(op)) { + if (isa(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, @@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, return nullptr; } -// AtenMaxPool2dOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenMaxPool2dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = cast(input.getType()); - auto inputElemTy = inputTy.getElementType(); - - auto inputRank = inputTy.getRank(); - auto outTy = - cast(getTypeConverter()->convertType(op.getType())); - - if (inputRank <= 2) { - return op.emitError( - "max_pooling2d only supports inputs with rank higher than 2"); - } - SmallVector padding, kernelSize, stride, dilation; - bool ceilMode = false; - - if (!(matchPattern(op.getKernelSize(), - m_TorchListOfConstantInts(kernelSize)))) { - return rewriter.notifyMatchFailure( - op, "non-const int kernel size unsupported!"); - } - if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { - return rewriter.notifyMatchFailure(op, 
"non-const int stride unsupported!"); - } - if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { - return rewriter.notifyMatchFailure(op, - "non-const int padding unsupported!"); - } - if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) { - return rewriter.notifyMatchFailure(op, - "non-const int dilation unsupported!"); - } - if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { - return rewriter.notifyMatchFailure(op, - "non-const bool ceil_mode unsupported!"); - } - - // prepend 1 to kernelSize, stride, dilation until they are of same rank as - // input - SmallVector stablehloStride(inputRank, 1); - SmallVector stablehloDilation(inputRank, 1); - SmallVector stablehloKernelSize(inputRank, 1); - SmallVector stablehloPadding(inputRank * 2, 0); - std::copy(dilation.begin(), dilation.end(), - stablehloDilation.begin() + inputRank - 2); - std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - 2); - std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - 2); - - Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; - - auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); - auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); - DenseI64ArrayAttr baseDilations; - auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); - DenseIntElementsAttr pad = DenseIntElementsAttr::get( - RankedTensorType::get( - {static_cast(inputRank), static_cast(2)}, - rewriter.getI64Type()), - stablehloPadding); - auto reduceWindowOp = rewriter.create( - op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, - baseDilations, windowDilations, pad); - - Block &block = reduceWindowOp.getBody().emplaceBlock(); - - auto blockArgumentTy = RankedTensorType::get({}, inputElemTy); - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArg = block.args_begin(); - auto secondArg = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value result = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), result); - } - - rewriter.replaceOp(op, reduceWindowOp.getResults()); - return success(); -} - // AtenMaxPool2dWithIndicesOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +namespace { +template +class ConvertAtenMaxPoolOp : public ConvertAtenOp { +public: + using ConvertAtenOp::ConvertAtenOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + auto inputTy = cast(input.getType()); + auto inputElemTy = inputTy.getElementType(); + auto inputRank = inputTy.getRank(); + auto outTy = cast( + ConvertAtenOp::getTypeConverter()->convertType(op.getType())); + + if (inputRank <= Dim) { + return op.emitError( + "max_pooling1d/2d only supports inputs with rank higher than 1/2"); + } + SmallVector padding, kernelSize, stride, dilation; + bool ceilMode = false; + + if 
(!(matchPattern(op.getKernelSize(), + m_TorchListOfConstantInts(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); + } + if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { + return rewriter.notifyMatchFailure(op, + "non-const int stride unsupported!"); + } + if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); + } + if (!(matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilation)))) { + return rewriter.notifyMatchFailure(op, + "non-const int dilation unsupported!"); + } + if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { + return rewriter.notifyMatchFailure( + op, "non-const bool ceil_mode unsupported!"); + } + + if (stride.empty()) { + stride = kernelSize; + } + + // prepend 1 to kernelSize, stride, dilation until they are of same rank + // as input + SmallVector stablehloStride(inputRank, 1); + SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloKernelSize(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); + std::copy(dilation.begin(), dilation.end(), + stablehloDilation.begin() + inputRank - Dim); + std::copy(stride.begin(), stride.end(), + stablehloStride.begin() + inputRank - Dim); + std::copy(kernelSize.begin(), kernelSize.end(), + stablehloKernelSize.begin() + inputRank - Dim); + + Value initVal = + createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + if (Dim == 1) { + stablehloPadding[stablehloPadding.size() - 2] = padding[0]; + stablehloPadding[stablehloPadding.size() - 1] = padding[0]; + } else if (Dim == 2) { + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + } else if (Dim == 3) { + stablehloPadding[stablehloPadding.size() - 6] = padding[0]; + stablehloPadding[stablehloPadding.size() - 5] = padding[0]; + stablehloPadding[stablehloPadding.size() - 4] = padding[1]; + stablehloPadding[stablehloPadding.size() - 3] = padding[1]; + stablehloPadding[stablehloPadding.size() - 2] = padding[2]; + stablehloPadding[stablehloPadding.size() - 1] = padding[2]; + } else { + assert(false && "Unsupported pooling dimension"); + } + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); + + DenseIntElementsAttr pad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + stablehloPadding); + + auto reduceWindowOp = rewriter.create( + op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, + baseDilations, windowDilations, pad); + + Block &block = reduceWindowOp.getBody().emplaceBlock(); + + // Add bb argument + auto blockArgumentType = RankedTensorType::get({}, inputElemTy); + block.addArgument(blockArgumentType, op->getLoc()); + block.addArgument(blockArgumentType, op->getLoc()); + auto *firstArg = block.args_begin(); + auto secondArg = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + Value result = rewriter.create(op->getLoc(), *firstArg, + *secondArg); + rewriter.create(op->getLoc(), result); + } 
+ + rewriter.replaceOp(op, reduceWindowOp.getResults()); + return success(); + } +}; +} // namespace + namespace { template class ConvertAtenAvgPoolOp : public ConvertAtenOp { @@ -375,8 +404,8 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { auto outShape = outTy.getShape(); if (inputRank <= Dim) { - return op.emitError( - "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); + return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank " + "higher than 1/2/3"); } SmallVector padding, kernelSize, stride; bool ceilMode = false; @@ -405,6 +434,10 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { op, "non-const bool count_include_pad unsupported!"); } + if (stride.empty()) { + stride = kernelSize; + } + if constexpr (std::is_same()) { if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) return rewriter.notifyMatchFailure( @@ -425,11 +458,20 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { if (Dim == 1) { stablehloPadding[stablehloPadding.size() - 2] = padding[0]; stablehloPadding[stablehloPadding.size() - 1] = padding[0]; - } else { + } else if (Dim == 2) { stablehloPadding[stablehloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + } else if (Dim == 3) { + stablehloPadding[stablehloPadding.size() - 6] = padding[0]; + stablehloPadding[stablehloPadding.size() - 5] = padding[0]; + stablehloPadding[stablehloPadding.size() - 4] = padding[1]; + stablehloPadding[stablehloPadding.size() - 3] = padding[1]; + stablehloPadding[stablehloPadding.size() - 2] = padding[2]; + stablehloPadding[stablehloPadding.size() - 1] = padding[2]; + } else { + assert(false && "Unsupported pooling dimension"); } Value initVal = @@ -474,10 +516,17 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { divisor = hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) .value(); - } else { + } else if (Dim == 2) { divisor = hlo::getConstTensor( rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) .value(); + } else if (Dim == 3) { + divisor = hlo::getConstTensor( + rewriter, op, + {kernelSize[0] * kernelSize[1] * kernelSize[2]}, {}) + .value(); + } else { + assert(false && "Unsupported pooling dimension"); } divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); DenseI64ArrayAttr bcastDimensions; @@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, - context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); +#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp); + INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp); +#undef INSERT_ATEN_POOLING_PATTERN + +#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, \ + options) + INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1); + 
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2); + INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3); +#undef INSERT_ATEN_MAXPOOL_PATTERN + #define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, \ options) INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1); INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2); + INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3); #undef INSERT_ATEN_AVGPOOL_PATTERN } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 553a8dc74413..d9ac7a6d0c55 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7845,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %arg2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @__torch__.avg_pool1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list {\n" +" func.func @__torch__.pool1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool) -> !torch.list {\n" " %int-1 = torch.constant.int -1\n" " %int-2 = torch.constant.int -2\n" " %int-3 = torch.constant.int -3\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n" -" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n" +" %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n" +" %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" -" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n" +" %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n" " %int1 = torch.constant.int 1\n" " %int0 = torch.constant.int 0\n" " %int2 = torch.constant.int 2\n" @@ -7940,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %23 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 
10c24b657128..8ffe8d1c30c7 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1075,6 +1075,9 @@ "Matmul_vecmat", "MatmulStaticBroadcast_basic", "MaxPool2dStaticModule_basic", + "MaxPool2dEmptyStrideStaticModule_basic", + "MaxPool3dStaticModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", "MeanDimAllReduceModule_basic", "MeanDimEmptyDimModule_basic", "MeanDimNoneDimModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index da486fe4602c..eb60620561bf 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -961,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example. -def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool): - assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int" +def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool): + assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int" kL = kernel_size[0] - assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int" + assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int" dL = kL if len(stride) == 0 else stride[0] - assert len(padding) == 1, "avg_pool1d: padding must be a single int" + assert len(padding) == 1, "pool1d: padding must be a single int" padL = padding[0] dilationL = 1 @@ -1004,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]): return shape def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]: - return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad) + return pool1d(self, kernel_size, stride, padding, ceil_mode) + +def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]: + return pool1d(self, kernel_size, stride, padding, ceil_mode) def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]: return adaptive_avg_pool1d(self, output_size) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index eea8d31a95a4..e0329c8df54f 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -591,6 +591,7 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" ) + emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit( "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" From 
05f8b69bf66f7727fc4870e51efa74c2f276b624 Mon Sep 17 00:00:00 2001 From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:51:27 +0530 Subject: [PATCH 17/23] [MLIR][TORCH] Add OnnxToTorch support for BlackmanWindow function (#3181) Implements OnnxToTorch lowering for the BlackmanWindow Function. --- .../Conversion/TorchOnnxToTorch/Utils.h | 7 + .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 122 ++++++++++++++++++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 8 -- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 78 +++++++++++ 4 files changed, 207 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index d4ace352a9bd..919146c6a1c7 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -38,6 +38,13 @@ Value createConstantIntList(OpBinder binder, Type getQTorchTypeFromTorchIntType(Type ty); +template +Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, + Value &ofItem) { + return rewriter.create(binder.getLoc(), + rewriter.getType(), ofItem); +} + LogicalResult OnnxLstmExpander(OpBinder binder, ConversionPatternRewriter &rewriter); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 716ea3d6e202..bd5c57fac3ba 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -2240,4 +2240,126 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone); return success(); }); + patterns.onOp( + "BlackmanWindow", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value size; + Torch::ValueTensorType resultType; + int64_t periodic, output_datatype; + if (binder.tensorOperand(size) || + binder.s64IntegerAttr(output_datatype, "output_datatype", 1) || + binder.s64IntegerAttr(periodic, "periodic", 1) || + binder.tensorResultType(resultType)) { + return failure(); + } + double isPeriodicFp = static_cast(periodic); + Value a0 = rewriter.create( + binder.getLoc(), + rewriter.getFloatAttr(rewriter.getF64Type(), 0.42)); + Value a1 = rewriter.create( + binder.getLoc(), + rewriter.getFloatAttr(rewriter.getF64Type(), -0.5)); + Value a2 = rewriter.create( + binder.getLoc(), + rewriter.getFloatAttr(rewriter.getF64Type(), 0.08)); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0.0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(1.0)); + Value two = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(2.0)); + + constexpr double pi = llvm::numbers::pi; + Value tau = rewriter.create( + binder.getLoc(), + rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi)); + + Value noneVal = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value float32Type = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6)); + + // Create an f32 ValueTensorType with thse same size as size, the + // operand + auto shapeOfOperand = size.getType() + .dyn_cast() + .getOptionalSizes(); + auto f32ResultType = rewriter.getType( + shapeOfOperand, rewriter.getF32Type()); + Value periodicSizeFloat = rewriter.create( + binder.getLoc(), f32ResultType, size, float32Type, cstFalse, + cstFalse, noneVal); + Value symmetricSizeFloat = rewriter.create( + binder.getLoc(), 
periodicSizeFloat.getType(), periodicSizeFloat, + one, one); + + Value isPeriodic = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp)); + Value isSymmetricFloat = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp)); + + Value periodicComponent = rewriter.create( + binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat, + isPeriodic); + Value symmetricComponent = rewriter.create( + binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat, + isSymmetricFloat); + Value sizeFloat = rewriter.create( + binder.getLoc(), symmetricComponent.getType(), symmetricComponent, + periodicComponent, one); + + // Here, size can be used in the place of periodicSizeFloat, as the + // latter is just a float representation of the former. + Value scalarLimit = getItemOp(binder, rewriter, size); + + Value rangeArr = rewriter.create( + binder.getLoc(), resultType, zero, scalarLimit, one, noneVal, + noneVal, noneVal, noneVal); + + Value rangeTimesTau = rewriter.create( + binder.getLoc(), resultType, rangeArr, tau); + Value rangeAngular = rewriter.create( + binder.getLoc(), resultType, rangeTimesTau, sizeFloat); + Value twoRangeAngular = rewriter.create( + binder.getLoc(), resultType, rangeAngular, two); + + Value cosRangeAngular = rewriter.create( + binder.getLoc(), resultType, rangeAngular); + Value cosTwoRangeAngular = rewriter.create( + binder.getLoc(), resultType, twoRangeAngular); + + Value a1Component = rewriter.create( + binder.getLoc(), resultType, cosRangeAngular, a1); + Value a2Component = rewriter.create( + binder.getLoc(), resultType, cosTwoRangeAngular, a2); + + // AtenSubScalarOp actually requires a tensor operand as the LHS, that + // is, operand #1. Therefore, to avoid errors, the onnx implementation + // has been modified. a1 has been changed to negative half, and the + // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add + // operation is commutative. + Value subA1Component = rewriter.create( + binder.getLoc(), resultType, a1Component, a0, one); + Value result = rewriter.create( + binder.getLoc(), resultType, subA1Component, a2Component, one); + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(output_datatype); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value outputDtype = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + dtypeIntTorch.value())); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, result, outputDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/noneVal); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 197d9c536b9f..5f9da3faa18e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -31,15 +31,7 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. 
// utilities -// Templatized function to get an item op of a type namespace { -template -Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, - Value &ofItem) { - return rewriter.create(binder.getLoc(), - rewriter.getType(), ofItem); -} - // In case the ReduceSum Op was not the first operation performed on the data, // we provide the original operand through storeResult, which will be modified // if the result will be passed onto another operation, and will be used for diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index eb2cde696ef1..a068acbf2941 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -2035,3 +2035,81 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten %0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32> return %0 : !torch.vtensor<[3,?],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_blackmanwindow_symmetric +func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01 + // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01 + // CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02 + // CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00 + // CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, 
!torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32> + // CHECK: return %[[CAST]] : !torch.vtensor<[10],f32> + %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_blackmanwindow +func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01 + // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01 + // CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02 + // CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00 + // CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : 
!torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+  // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
+  // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
+  %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
+  return %0 : !torch.vtensor<[10],f32>
+}

From 9442c6685698b111c936c2d0c2e173b5e56b88d7 Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Tue, 30 Apr 2024 09:21:39 -0700
Subject: [PATCH 18/23] [torch-mlir][sparse] add a few missing passes to the
 ref pipeline (#3265)

For some sparse programs (and, I am sure, other unseen corner cases for
dense ones), some passes were missing in the reference pipeline,
eventually resulting in, e.g., an unresolved unrealized cast issue. This
PR adds some very obvious missing passes to avoid this situation.
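
Both passes exist in upstream MLIR, so their effect can be sanity-checked
in isolation with mlir-opt; a rough sketch (input.mlir is a placeholder
for a bufferized module produced by the earlier pipeline steps):

  # expand memref.realloc into explicit alloc/copy/dealloc before
  # finalize-memref-to-llvm runs
  mlir-opt input.mlir --expand-realloc --finalize-memref-to-llvm

  # lower vector ops (e.g. from sparse codegen) so no unrealized
  # conversion casts are left for reconcile-unrealized-casts
  mlir-opt input.mlir --convert-vector-to-llvm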
---
 .../linalg_on_tensors_backends/refbackend.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
index 1e958a4d9451..08e8ff64d08e 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -180,6 +180,7 @@ def invoke(*args):
             "func.func(tm-tensor-to-loops)",
             "func.func(refback-munge-memref-copy)",
             "func.func(convert-linalg-to-loops)",
+            "func.func(expand-realloc)",
             "func.func(lower-affine)",
             "convert-scf-to-cf",
             "func.func(refback-expand-ops-for-llvm)",
@@ -193,6 +194,7 @@ def invoke(*args):
             "convert-bufferization-to-memref",
             "finalize-memref-to-llvm",
             "func.func(convert-arith-to-llvm)",
+            "convert-vector-to-llvm",
             "convert-func-to-llvm",
             "convert-cf-to-llvm",
             "convert-complex-to-llvm",

From 72349f7522195645d1af7b468bea15b64a37b105 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Tue, 30 Apr 2024 11:23:09 -0500
Subject: [PATCH 19/23] [TorchToLinalg] Adds Quantization Support for ConvTranspose (#3240)

I spent a little while debugging numerics issues with some tests similar
to the ones in quantized_models.py, only to find that PyTorch's quantized
conv transpose is catastrophically inaccurate. I'll upstream that issue
and keep only the tests here that are of the form
quantize -> dequantize -> op.
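
For reference, the retained tests all reduce to the pattern sketched below
in plain PyTorch (shapes and scales here are illustrative); the committed
version is the ConvTranspose2DQInt8Module test in the diff that follows.

import torch

# quantize -> dequantize -> op: pack int8 data into a quantized tensor,
# dequantize immediately, then run the float op on the result.
x_int8 = torch.randint(-128, 127, (1, 2, 8, 8), dtype=torch.int8)
w_int8 = torch.randint(-128, 127, (2, 3, 3, 2), dtype=torch.int8)
x = torch.dequantize(torch._make_per_tensor_quantized_tensor(x_int8, 0.01, -25))
w = torch.dequantize(torch._make_per_tensor_quantized_tensor(w_int8, 0.01, 50))
y = torch.nn.functional.conv_transpose2d(x, w, stride=[2, 1], padding=[1, 1])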
---
 lib/Conversion/TorchToLinalg/Linear.cpp       | 59 +++++++++++--------
 projects/pt1/e2e_testing/xfail_sets.py        |  5 ++
 .../torch_mlir_e2e_test/test_suite/conv.py    | 53 +++++++++++++++++
 3 files changed, 92 insertions(+), 25 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 3f4e6ed66354..c49646e2f1c0 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
   if (!isUnsignedType)
     return;
   int64_t minSI = -(1 << (numBits - 1));
-  Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
+  Value minSIValue = rewriter.create<arith::ConstantIntOp>(
+      loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
   zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
   minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
   arg = torch_to_linalg::createElementwiseLinalgGeneric(
@@ -797,6 +798,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     auto resultTy = cast<ValueTensorType>(op.getType());
 
     Value inputZp, weightZp;
+    bool inputUnsigned = false;
+    bool weightUnsigned = false;
     if (auto make = op.getInput()
                         .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
       input = make.getSelf();
@@ -806,6 +809,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       inputZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(inputZp.getType()),
           inputZp);
+      auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+      inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
 
     if (auto make = op.getWeight()
                         .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
@@ -818,6 +823,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       weightZp = typeConverter->materializeTargetConversion(
          rewriter, loc, typeConverter->convertType(weightZp.getType()),
          weightZp);
+      auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+      weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
 
     if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
@@ -916,15 +923,35 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     SmallVector<Value> strideIntValues =
         getAsConstantIntValues(rewriter, loc, strideInts);
 
+    // convert any uint8 quantization to int8 quantization
+    if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
+      int64_t width = integerType.getWidth();
+      signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
+    }
+    if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
+      int64_t width = integerType.getWidth();
+      signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
+    }
     // Pad the input tensor according to padding.
     SmallVector<Value> outDims{inBatch, weightBatch};
     Value paddedInput;
-    if (transposed) {
-      if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
-          !isa<mlir::FloatType>(resultDTy))
-        return rewriter.notifyMatchFailure(
-            op, "transpose does not support non-fp type yet");
+    Value pad = inputZp;
+    if (!pad) {
+      if (isa<mlir::FloatType>(inputDTy))
+        pad = rewriter.create<arith::ConstantOp>(
+            op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
+      if (isa<mlir::IntegerType>(inputDTy))
+        pad = rewriter.create<arith::ConstantOp>(
+            op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
+    }
+    if (pad.getType() != inputDTy) {
+      if (isa<mlir::FloatType>(inputDTy))
+        pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
+      if (isa<mlir::IntegerType>(inputDTy))
+        pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
+    }
+    if (transposed) {
       Value c0 =
           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
       Value c1 =
@@ -994,7 +1021,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 
       // Allocate padded input tensor
       Value initTensor =
-          createZeroInitTensor(rewriter, loc, outerSizes, inputDTy);
+          createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
 
       // Insert input into allocated tensor
       SmallVector<Value> strideIndexValues{c1, c1};
@@ -1017,24 +1044,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       strideInts.clear();
       strideInts.append(numSpatialDims, 1);
     } else {
-      Value pad = inputZp;
-      if (!pad) {
-        if (isa<mlir::FloatType>(inputDTy))
-          pad = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
-        if (isa<mlir::IntegerType>(inputDTy))
-          pad = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
-      }
-
-      if (pad.getType() != inputDTy) {
-        if (isa<mlir::FloatType>(inputDTy))
-          pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
-
-        if (isa<mlir::IntegerType>(inputDTy))
-          pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
-      }
-
       // Pad input
       paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
           op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 8ffe8d1c30c7..fcb7e053a0db 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -272,6 +272,7 @@
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
+    "ConvTranspose2DQInt8_basic",
     # Dynamo not supporting conv_tbc
     "ConvTbcModule_basic",
     "FloatImplicitModule_basic",
@@ -372,6 +373,7 @@
     "Conv2dQInt8Module_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "ConvTbcModule_basic",
+    "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
@@ -544,6 +546,7 @@
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
     "ConvTbcModule_basic",
+    "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
@@ -2100,6 +2103,7 @@
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
     "Conv2dQInt8Module_basic",
+    "ConvTranspose2DQInt8_basic",
 }
 
 ONNX_XFAIL_SET = {
@@ -2254,6 +2258,7 @@
"Conv2dWithPaddingModule_basic", "Conv3dModule_basic", "ConvTbcModule_basic", + "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 9600b090032e..e99525c32d88 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils): weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8) bias = torch.rand(3) module.forward(inputVec, weight, bias) + + +N = 10 +Cin = 5 +Cout = 7 +Hin = 10 +Win = 8 +Hker = 3 +Wker = 2 + + +class ConvTranspose2DQInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int8, True), + ([-1, -1, -1, -1], torch.int8, True), + ([-1], torch.float, True), + ] + ) + def forward(self, input, weight, bias): + qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25) + qinput = torch.dequantize(qinput) + qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50) + qweight = torch.dequantize(qweight) + qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) + qbias = torch.dequantize(qbias) + qz = torch.ops.aten.convolution( + qinput, + qweight, + bias=qbias, + stride=[2, 1], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1, + ) + return qz + + +@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module()) +def ConvTranspose2DQInt8_basic(module, tu: TestUtils): + module.forward( + tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8), + tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), + torch.rand(Cout), + ) From 315dc6c3e377b74e8981776237cff2e733667811 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Tue, 30 Apr 2024 13:41:03 -0400 Subject: [PATCH 20/23] [torch] `aten.eye` should use dynamic dims when no static dims are available (#3202) Co-authored-by: Xida Ren --- .../Torch/Transforms/DecomposeComplexOps.cpp | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 677ccc4f241b..cc21f2155e46 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1059,44 +1059,44 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenEyeMOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - int64_t n; - - if (!matchPattern(op.getN(), m_TorchConstantInt(&n))) - return rewriter.notifyMatchFailure(op, - "unimplemented: n must be constant"); - int64_t m; - if (!matchPattern(op.getM(), m_TorchConstantInt(&m))) - return rewriter.notifyMatchFailure(op, - "unimplemented: m must be constant"); - Value none = rewriter.create(loc); - auto outType = dyn_cast(op.getType()); + auto outType = op.getType().dyn_cast(); if (!outType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); if (!outType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } - if (n < 0) { - return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0"); - } - if (m < 0) { - return rewriter.notifyMatchFailure(op, "m 
must be greater or equal to 0"); - } - + Value none = rewriter.create(loc); auto context = op.getContext(); auto int64Dtype = getDtypeIntValueForType( rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); - auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type); + + int64_t n = kUnknownSize; + int64_t m = kUnknownSize; + // prioritize getting shape from output shape + if (outType.hasSizes() && outType.getSizes().size() == 2) { + n = outType.getSizes().front(); + m = outType.getSizes().back(); + } + // if output shape is not available, try to get shape from input + if (n == kUnknownSize) + matchPattern(op.getN(), m_TorchConstantInt(&n)); + if (m == kUnknownSize) + matchPattern(op.getM(), m_TorchConstantInt(&m)); + + // prepare two unsqueezed ranges that are equal on and only on the diagonal + auto rangeNSize = llvm::SmallVector({n}); + Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type); Value rangeN = rewriter.create( - loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, + loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/op.getDevice(), /*pin_memory=*/none); - auto arangeType1 = - outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type); + auto rangeMSize = llvm::SmallVector({m}); + Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type); Value rangeM = rewriter.create( - loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, + loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); Value constMinusOne = rewriter.create( @@ -1109,7 +1109,6 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { } Value unsqzRangeN = *unsqzTensorInfo; - // compare unsqueezed input with boundaries auto eqType = ValueTensorType::get( context, cast(op.getType()).getSizes(), IntegerType::get(context, 1)); From 33eef15e428f848e3848d1038ed71faab893a686 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Tue, 30 Apr 2024 14:36:40 -0400 Subject: [PATCH 21/23] Support onnx.If (#2825) This is probably a decent PR for learning about blocks and regions. 

---------

Co-authored-by: Xida Ren
---
 .../Conversion/TorchOnnxToTorch/Patterns.h    | 25 +++++++++
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 54 +++++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        | 12 ++---
 .../test_suite/diagonal.py                    | 31 +++++++++++
 python/torch_mlir/extras/onnx_importer.py     |  8 ++-
 test/Conversion/TorchOnnxToTorch/ops/if.mlir  | 20 +++++++
 6 files changed, 141 insertions(+), 9 deletions(-)
 create mode 100644 test/Conversion/TorchOnnxToTorch/ops/if.mlir

diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
index d3260500cfa8..3230cc8b46a0 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -97,6 +97,31 @@ struct OpBinder {
     return success();
   }
 
+  ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
+    for (auto result : op->getResults()) {
+      auto t = toValidTensorType(result.getType());
+      if (!t)
+        return failure();
+      typeList.push_back(t);
+    }
+    return success();
+  }
+
+  // The importer imports Onnx.GraphProto attributes as regions attached to the
+  // op.
+  ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {
+    if (idx >= op->getNumRegions())
+      return failure();
+
+    region = &op->getRegion(idx);
+
+    if (region == nullptr) {
+      return failure();
+    }
+
+    return success();
+  }
+
   ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
                                       int64_t idx) {
     if (idx >= op->getNumResults())

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 7a150794cb4b..1f1e2e5d7f0c 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               alignCorners);
           return success();
         });
+  patterns.onOp(
+      "If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value conditionTensor;
+        if (binder.tensorOperand(conditionTensor)) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "condition bind failure");
+        }
+
+        auto conditionType =
+            conditionTensor.getType().cast<Torch::ValueTensorType>();
+        if (!conditionType || conditionType.getSizes().size() != 1)
+          return rewriter.notifyMatchFailure(
+              binder.op, "condition must have one single element per "
+                         "https://onnx.ai/onnx/operators/onnx__If.html");
+        auto conditionInt = rewriter.create<Torch::AtenItemOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            conditionTensor);
+        auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
+
+        llvm::SmallVector<mlir::Type> resultTypes;
+        if (binder.tensorResultTypes(resultTypes)) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "result type bind failure");
+        }
+
+        Region *thenRegion, *elseRegion;
+        if (binder.getRegionAtIndex(elseRegion, 0) ||
+            binder.getRegionAtIndex(thenRegion, 1)) {
+          return rewriter.notifyMatchFailure(binder.op, "region bind failure");
+        }
+
+        auto primIfOp = rewriter.create<Torch::PrimIfOp>(
+            binder.getLoc(), TypeRange(resultTypes), conditionBool);
+
+        auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
+          rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
+        };
+        inlineIfCase(*thenRegion, primIfOp.getThenRegion());
+        inlineIfCase(*elseRegion, primIfOp.getElseRegion());
+
+        auto replaceTerminator = [&](Region &region) {
+          PatternRewriter::InsertionGuard guard(rewriter);
+          Operation *terminator = region.front().getTerminator();
+          rewriter.setInsertionPoint(terminator);
+          rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
+              terminator, terminator->getOperands());
+        };
+        replaceTerminator(primIfOp.getThenRegion());
+        replaceTerminator(primIfOp.getElseRegion());
+
+        rewriter.replaceOp(binder.op, primIfOp.getResults());
+        return success();
+      });
   patterns.onOp("Less", 13,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index fcb7e053a0db..25d8fa9be5a2 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2562,16 +2562,12 @@
     "_ConvolutionDeprecated2DCudnnModule_basic",
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "_SoftmaxModule_basic",
+    # Failure - onnx_import
     # Failure - onnx_lowering: onnx.AveragePool
     "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
-    # Failure - onnx_lowering: onnx.If
-    "DiagonalModule_basic",
-    "DiagonalModule_nonsquare",
-    "DiagonalModule_transposed",
-    "DiagonalModule_with_dims",
-    "DiagonalModule_with_dims_and_offset",
-    "DiagonalModule_with_negative_dims",
-    "DiagonalModule_with_offset",
+    # these diagonal modules are currently failing due to dynamic shape.
+    # We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
+    # when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
     "TileBigDimsSizeModule_basic",
     "TileSmallDimsSizeModule_basic",
     # Failure - onnx_lowering: onnx.MaxPool

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py
index 6371f9a8d7a7..3bd3796dad8e 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py
@@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
 # ==============================================================================
 
 
+class DiagonalWithStaticShapeModule(torch.nn.Module):
+    """
+    Diagonal with static shape. The other diagonal modules are failing in onnx
+    because DecomposeAtenEyeMOp requires constants n, m, which are only constant
+    when the shape is static.
+
+    Please remove this module and associated test once the issue is fixed.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 9], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.diagonal(a)
+
+
+@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
+def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 9))
+
+
+# ==============================================================================
+
+
 class DiagonalTransposedModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index 8d0e4cf5a8e1..e0d3529d942e 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -347,8 +347,14 @@ def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
                 continue
             elif handler is False:
                 # Active error.
+                # try matching attribute type ID to name for a more descriptive error message
+                try:
+                    attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
+                except ValueError:
+                    attr_type_name = "UNKNOWN"
                 raise OnnxImportError(
-                    f"ONNX importer does not support generic node attribute type {attr_type}. "
+                    f"ONNX importer does not support generic node attribute type {attr_type_name} "
+                    f"with ID {attr_type}. "
                     f"This likely means that this is a special node which requires specific "
                     f"handling in the importer: {onnx_attr}"
                 )

diff --git a/test/Conversion/TorchOnnxToTorch/ops/if.mlir b/test/Conversion/TorchOnnxToTorch/ops/if.mlir
new file mode 100644
index 000000000000..1d95a3f5fc3a
--- /dev/null
+++ b/test/Conversion/TorchOnnxToTorch/ops/if.mlir
@@ -0,0 +1,20 @@
+// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
+
+// CHECK-LABEL: func.func @test_ifop_basic
+// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
+// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
+// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
+// CHECK-DAG: } else {
+// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
+// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
+func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
+    %1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
+    torch.operator_terminator %1 : !torch.vtensor<[1],f32>
+  }, {
+    %1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
+    torch.operator_terminator %1 : !torch.vtensor<[1],f32>
+  }
+  return %0 : !torch.vtensor<[1],f32>
+}

From 0a2d21b108602d2b11c208ca1a713a72f483f6c1 Mon Sep 17 00:00:00 2001
From: "Xida Ren (Cedar)"
Date: Tue, 30 Apr 2024 17:48:01 -0400
Subject: [PATCH 22/23] Add `.yamllint` and disable some annoying recurring warnings on every PR (#3224)

---
 .yamllint.yml | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)
 create mode 100644 .yamllint.yml

diff --git a/.yamllint.yml b/.yamllint.yml
new file mode 100644
index 000000000000..ec40711eb2bf
--- /dev/null
+++ b/.yamllint.yml
@@ -0,0 +1,22 @@
+---
+
+extends: default
+
+rules:
+  # These do not appear to be conventional in GitHub actions.
+  document-end:
+    present: false
+  document-start:
+    present: false
+  # GitHub actions use "on" for triggers.
+  truthy: disable
+  # We have lots of long strings and command lines.
+  line-length: disable
+  comments:
+    # Formatters may do this (e.g. Prettier does) and it seems like the most
+    # trivial thing to get a failing check for.
+    min-spaces-from-content: 1
+  # This is not a useful check, especially when disabling entire blocks.
+  comments-indentation: disable
+
+ignore: /third_party/*

From 8c48135a426b84fa412b031fc92e12826ff60b31 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Wed, 1 May 2024 12:06:53 +0530
Subject: [PATCH 23/23] [linalg] Fix bug for conversion of complex dtype (#3269)

Conversion of the complex dtype was neither supported nor checked; this
adds the support along with the required tests.
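
Conceptually, the new path converts a complex scalar by splitting it into
its real and imaginary parts, converting each part to the target element
type, and recombining. The same idea in plain PyTorch (a hedged sketch;
the dtype mapping is assumed, and the real change is the IR-level code
below):

import torch

# complex64 -> complex128 by converting the parts and recombining,
# mirroring what the new convertScalarToDtype branch does in IR.
def convert_complex(z: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    elem = torch.float64 if dtype == torch.complex128 else torch.float32
    return torch.complex(z.real.to(elem), z.imag.to(elem))

z = torch.tensor(1 + 2j, dtype=torch.complex64)
assert convert_complex(z, torch.complex128).dtype == torch.complex128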
Fixes: https://github.com/iree-org/iree/issues/17226#issuecomment-2087779158
---
 lib/Conversion/Utils/Utils.cpp                | 21 ++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  2 ++
 .../test_suite/elementwise.py                 | 28 +++++++++++++++++++
 3 files changed, 51 insertions(+)

diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index bae25cc7ac60..e014fbeaa9d4 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
     return b.create(loc, dtype, scalar);
   }
 
+  if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
+    if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
+      auto dtypeElemType = dtypeComplex.getElementType();
+
+      // Extract the real and imaginary parts of the scalar.
+      // Cast them to the target element type, and create a new complex
+      // value with the target complex type.
+      Value realVal = b.create<complex::ReOp>(loc, scalar);
+      Value imgVal = b.create<complex::ImOp>(loc, scalar);
+
+      realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
+      imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);
+
+      return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
+    }
+    mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
+                         << scalarType << "(scalar type) -> " << dtype
+                         << "(dtype)";
+  }
+
   llvm_unreachable("convertScalarToDtype should handle all the types");
 }
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 25d8fa9be5a2..33f1ed702273 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -575,6 +575,7 @@
     "ElementwiseErfIntModule_basic",
     "ElementwiseLogitModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
+    "ElementwiseMulTensorComplexDiffModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",
@@ -2314,6 +2315,7 @@
     "ElementwiseExpm1Module_basic",
     "ElementwiseFmodTensor_Int_basic",
     "ElementwiseMulTensorComplexModule_basic",
+    "ElementwiseMulTensorComplexDiffModule_basic",
     "ElementwiseOrTensorModule_basic",
     "ElementwiseOrTensorStaticShapeModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 8e287584295b..a26fd9809f13 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1839,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+# torch.complex32 is not supported by the refbackend.
+class ElementwiseMulTensorComplexDiffModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.complex64, True),
+            ([-1], torch.complex128, True),
+        ]
+    )
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule())
+def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(4, high=10).type(torch.complex64),
+        tu.randint(4, high=10).type(torch.complex128),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseMishModule(torch.nn.Module):
     def __init__(self):
         super().__init__()