From 5e4f00acb13f3f849a05e5ac28ee39307a5fdbff Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Fri, 12 Jul 2024 09:15:42 +0800 Subject: [PATCH 1/4] [Torch] add support for aten.scatter_add (#3534) --- .../TorchToTMTensor/TorchToTMTensor.cpp | 31 +++++++++++++++---- .../Transforms/AbstractInterpLibrary.cpp | 7 +++++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 9 ++++++ .../torch_mlir_e2e_test/test_suite/scatter.py | 25 +++++++++++++++ 5 files changed, 67 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index b6bd3b8b6a36..3e37456f3086 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -373,16 +373,19 @@ static FailureOr> createTMTensorTopkOp( } namespace { -class ConvertAtenScatterSrcOp : public OpConversionPattern { +template +class ConvertAtenScatterOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult - matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor, + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); - const TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = + OpConversionPattern::getTypeConverter(); Value self = adaptor.getSelf(); Value index = adaptor.getIndex(); Value src = adaptor.getSrc(); @@ -410,7 +413,19 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern { /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value updatesElement, Value inputElement) { - b.create(loc, updatesElement); + if (isa(op)) { + b.create(loc, updatesElement); + } else if (isa(op)) { + if (isa(selfType.getElementType())) { + Value add = + b.create(loc, inputElement, updatesElement); + b.create(loc, add); + } else if (isa(selfType.getElementType())) { + Value add = + b.create(loc, inputElement, updatesElement); + b.create(loc, add); + } + } }); auto resultType = cast( @@ -2169,7 +2184,11 @@ class ConvertTorchToTMTensor context); target.addIllegalOp(); - patterns.add(typeConverter, context); + patterns.add>(typeConverter, + context); + target.addIllegalOp(); + patterns.add>(typeConverter, + context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index bc8f252e6dfc..65f9f16e0425 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9787,6 +9787,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.float) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11567,6 +11570,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.masked_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 504c7ca9d6f7..f9576c984c73 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2682,6 +2682,7 @@ "ScatterReduceIntMaxModuleIncludeSelf", "ScatterReduceIntMinModuleIncludeSelf", "ScatterValueFloatModule_basic", + "ScatterAddStaticModule_basic", # Failure - onnx_lowering: onnx.ScatterND "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 37db50050b43..553398905700 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1810,6 +1810,9 @@ def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], sr def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]: return self +def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]: + return self + def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]: return upstream_shape_functions.index_select(self, dim, index) @@ -3115,6 +3118,12 @@ def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, i self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function( + [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES]) +def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( [Invocation(TensorOfShape(3, dtype=dtype), TensorOfShape(3, dtype=torch.bool), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES]) def aten〇masked_scatter〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int]) -> int: diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py index ba44dc076904..ee85855e4aa8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -1020,6 +1020,31 @@ def ScatterValueIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ScatterAddStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([10, 8, 6], torch.float32, True), + ([2, 4, 3], torch.int64, True), + ([5, 8, 6], torch.float32, True), + ] + ) + def forward(self, input, index, src): + return torch.ops.aten.scatter_add(input, 0, index, src) + + +@register_test_case(module_factory=lambda: ScatterAddStaticModule()) +def ScatterAddStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +# ============================================================================== + + class ScatterReduceFloatModule(torch.nn.Module): include_self: bool reduce_type: str From cdbcf519f7fad36f2426371da2aa4853918e93a5 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Sun, 14 Jul 2024 10:33:47 -0700 Subject: [PATCH 2/4] [NFC] Expose both raw Torch dialect and Torch dialect in backend form with Dynamo/FX (#3541) This is a non-functional change. It merely allows intercepting the Torch dialect during TorchDynamo export at two stages: 1. `OutputType.RAW`: This gets us the torch dialect as-imported from the FX graph 2. `OutputType.TORCH`: This gets us the torch dialect after the raw torch goes through DecomposeComplexOps and ReduceOpVariants. Prior to this, there was no way of accessing the Torch dialect in backend compliant form (right after running the `torchdynamo-export-to-torch-backend-pipeline`) because both [here](https://sourcegraph.com/github.com/llvm/torch-mlir@5e4f00acb13f3f849a05e5ac28ee39307a5fdbff/-/blob/python/torch_mlir/fx.py?L33) and [here](https://sourcegraph.com/github.com/llvm/torch-mlir@5e4f00acb13f3f849a05e5ac28ee39307a5fdbff/-/blob/python/torch_mlir/compiler_utils.py?L138) the same `OutputType.TORCH` were used, meaning the 2nd condition would never be reached. Since the default behavior is unchanged, this is an NFC. --- python/torch_mlir/compiler_utils.py | 15 ++++++++------- python/torch_mlir/fx.py | 6 +++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index c1315abd47f9..cb2799f85d51 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -82,12 +82,12 @@ def run_pipeline_with_repro_report( class OutputType(Enum): - # Output torch dialect. When converting from FX, this will be immediately - # after the import from FX to MLIR. When converting from torchscript, - # this will come after some cleanup passes which attempt to de-alias, - # decompose and infer shapes. These should be roughly the same level of - # abstraction since those steps are done within PyTorch itself - # when coming directly from Dynamo/FX. + # Output torch dialect in backend form. When converting from TorchDynamo, + # this comes after some decomposition and reduce op variants passes are + # applied to the raw torch dialect. When converting from TorchScript, this + # comes after some cleanup passes which attempt to de-alias, decompose and infer shapes. + # These should be roughly the same level of abstraction since those + # steps are done within PyTorch itself when coming directly from Dynamo/FX. TORCH = "torch" # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and @@ -104,7 +104,8 @@ class OutputType(Enum): # as taking the `TORCH` output type and lowering it to StableHLO. STABLEHLO = "stablehlo" - # Raw output of the JIT IR importer. This is not expected to be useful + # Raw output of the JIT IR importer in the TorchScript frontend or that of + # the FX IR importer in the TorchDynamo frontend. This is not expected to be useful # for end-users, but can be convenient for development or reporting bugs. RAW = "raw" diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 5cd7d2d6e1f1..0d9ad77d2ff7 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -30,7 +30,7 @@ def _module_lowering( extra_library_file_name=None, ): - if output_type == OutputType.TORCH: + if output_type == OutputType.RAW: if verbose: print(torch_mod) return torch_mod @@ -50,7 +50,7 @@ def _module_lowering( def export_and_import( f: Union[nn.Module, ExportedProgram], *args, - output_type: Union[str, OutputType] = OutputType.TORCH, + output_type: Union[str, OutputType] = OutputType.RAW, fx_importer: Optional[FxImporter] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, experimental_support_mutation: bool = False, @@ -99,7 +99,7 @@ def export_and_import( def stateless_fx_import( gm: torch.fx.GraphModule, - output_type: Union[str, OutputType] = OutputType.TORCH, + output_type: Union[str, OutputType] = OutputType.RAW, fx_importer: Optional[FxImporter] = None, hooks: Optional[FxImporterHooks] = None, model_name: str = "main", From 7411ff2f69b5d74a283ac31c6684e9d7670013fd Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Sun, 14 Jul 2024 11:52:03 -0700 Subject: [PATCH 3/4] [Symbolic Shapes] Test coverage for unbacked symint from data dependent ops (#3542) We do have support for translating unbacked symbolic_ints that arise from data-dependent ops like `aten.nonzero`. This PR adds the python lit test coverage for the same. --- .../fx_importer/symbolic_shape_expr_test.py | 41 +++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/test/python/fx_importer/symbolic_shape_expr_test.py b/test/python/fx_importer/symbolic_shape_expr_test.py index 3215e0f8213d..d86e98725499 100644 --- a/test/python/fx_importer/symbolic_shape_expr_test.py +++ b/test/python/fx_importer/symbolic_shape_expr_test.py @@ -84,7 +84,7 @@ def forward(self, x, y, z): # CHECK-LABEL: test_symbolic_dim_differ_by_one # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { # CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int -# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) +# FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) # CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> # CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> @@ -262,7 +262,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @run # CHECK-LABEL: test_shape_div # CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> { -# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) +# FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) # CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int # CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32> @@ -433,7 +433,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @run # CHECK-LABEL: test_gather_elements # CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> { -# CHECK: %[[S0]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> # CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32> # CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32> @@ -461,3 +461,38 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: import_symbolic_shape_expressions=True, ) print(m) + + +@run +# CHECK-LABEL: test_nonzero +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,2],si64> { +# FIXME: There's a bug in the torch 2.3 stable release which creates redundant symbolic_int ops for the nonzero +# output which is fixed in the 2.4 nightlies. Once we move to a 2.4 stable release, this check may be re-enabled +# CHECK-DISABLED: %[[U0:.+]] = torch.symbolic_int "u0" {min_val = 0, max_val = 9223372036854775806} : !torch.int +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 10} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[NZERO:.+]] = torch.aten.nonzero %[[ARG0]] : !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,2],si64> +# CHECK-DISABLED: torch.bind_symbolic_shape %[[NZERO]], [%[[U0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],si64> +# CHECK: return %[[NZERO]] : !torch.vtensor<[?,2],si64> +def test_nonzero(): + class Nonzero(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nonzero(x) + + # Sample inputs + x = torch.randn(4, 3) + + # Dynamic dim constraints + batch = Dim("batch", min=3, max=10) + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + Nonzero(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) From fe9db781209d3ebd5e77b49436ac80ad80b4796a Mon Sep 17 00:00:00 2001 From: Matthew Francis-Landau Date: Sun, 14 Jul 2024 14:54:23 -0400 Subject: [PATCH 4/4] Allow custom ops to return an array of tensors (#3531) This PR adds support to `fx_importer.py` for handling custom ops that return an array of tensors. As long as the length of the array is consistent across runs (determined statically), then this patch will work. This does not require that the number of tensors returned is determined by the op's definition. CC @sjain-stanford --- python/torch_mlir/extras/fx_importer.py | 82 ++++++++++++++++------- test/python/fx_importer/custom_op_test.py | 47 +++++++++++++ 2 files changed, 106 insertions(+), 23 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index cb86406c55fd..c95df2504d03 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1099,6 +1099,10 @@ def value_info_to_type( return self.get_vtensor_type( val.size(), val.dtype, sparsity=sparsity, mutable=mutable ) + elif isinstance(val, list) and all( + isinstance(x, TorchFakeTensor) for x in val + ): + return IrType.parse("!torch.list", context=self._c) # Note that None is a valid scalar here, so it is important that this # is always checked as the last fallback. @@ -1227,6 +1231,7 @@ class GraphNodeImporter: "_v", "_symbol_to_value", "_multi_result_nodes", + "_unpack_list_values", "fx_importer", ] @@ -1251,6 +1256,10 @@ def __init__( # Statically multi-result nodes which we have de-tupled are noted here. # They will have their getitem calls short-circuited. self._multi_result_nodes: Set[torch_fx.Node] = set() + # If a OP returns a list, then it needs to be unpacked entirely using + # prim.ListUnpack. Cache the result of these nodes so that it only + # unpacks once instead of every time that getitem is used + self._unpack_list_values: Dict[torch_fx.Node, Tuple[Value]] = {} def bind_node_value( self, @@ -1420,29 +1429,7 @@ def import_nodes( elif op == "call_function": target = node.target if target == operator.getitem: - # Special case handling of getitem for when it is resolving - # against a function call that we know has returned multiple - # results. We short-circuit this case because we have modeled - # function calls to natively return multiple results vs tupling. - getitem_ref, getitem_index = node.args - if getitem_ref in self._multi_result_nodes: - try: - self.bind_node_value( - node, - self.resolve_node_value(getitem_ref, getitem_index), - ) - except IndexError: - raise RuntimeError( - f"getitem de-aliasing failed. This likely " - f"indicates a programmer error that usually " - f"would have happened at runtime. Please " - f"notify developers if this case happens " - f"(at {loc})." - ) - else: - raise NotImplementedError( - f"General getitem access to non-multi-result ops" - ) + self._import_getitem(loc, node) elif target in SYMBOLIC_TORCH_OPS or ( is_symbolic(node.meta.get("val")) and is_builtin_function_or_method(target) @@ -2007,6 +1994,51 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value: with loc: return cvt(arg, self, self._cc) + def _import_getitem(self, loc: Location, node: torch.fx.Node): + ref_node, index = node.args + if ref_node in self._multi_result_nodes: + # Special case handling of getitem for when it is resolving + # against a function call that we know has returned multiple + # results. We short-circuit this case because we have modeled + # function calls to natively return multiple results vs tupling. + try: + self.bind_node_value( + node, + self.resolve_node_value(ref_node, index), + ) + except IndexError: + raise RuntimeError( + f"getitem de-aliasing failed. This likely " + f"indicates a programmer error that usually " + f"would have happened at runtime. Please " + f"notify developers if this case happens " + f"(at {loc})." + ) + else: + # handle nodes that return a torch.list<...> at the MLIR level + # NOTE: the length of the list must be knowable at compile time. + if ref_node not in self._unpack_list_values: + node_result = self.resolve_node_value(ref_node, 0) + if str(node_result.type) in TORCH_LIST_TYPES: + result_types = [ + self._cc.value_info_to_type(v) for v in ref_node.meta["val"] + ] + operation = Operation.create( + "torch.prim.ListUnpack", + results=result_types, + operands=[node_result], + loc=loc, + ) + self._unpack_list_values[ref_node] = tuple(operation.results) + + try: + self.bind_node_value(node, self._unpack_list_values[ref_node][index]) + except IndexError: + raise RuntimeError( + f"getitem failed. " + f"getitem only supports lists of known length. (at {loc})" + ) + def _unpack_node_result_types( self, node: torch.fx.Node, schema: FunctionSchema ) -> List[IrType]: @@ -2337,6 +2369,10 @@ def _ref_finalizer(self, ref_id: int): "vtensor": "!torch.list>", } +TORCH_LIST_TYPES = set(PY_TYPE_TO_TORCH_LIST_TYPE.values()) | set( + PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE.values() +) + SCALAR_TYPE_TO_TORCH_MLIR_TYPE = { torch.SymInt: "!torch.int", torch.SymFloat: "!torch.float", diff --git a/test/python/fx_importer/custom_op_test.py b/test/python/fx_importer/custom_op_test.py index dbbc5ba057af..d9ce7d6096a5 100644 --- a/test/python/fx_importer/custom_op_test.py +++ b/test/python/fx_importer/custom_op_test.py @@ -84,3 +84,50 @@ def forward(self, x, y, z): import_symbolic_shape_expressions=True, ) print(m) + + +@run +# CHECK-LABEL: test_custom_op_array_output +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,3],f32>) +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int +# CHECK: %[[int:.+]] = torch.constant.int 4 +# CHECK: %[[V0:.+]] = torch.operator "torch.my_custom_library.array_output_op"(%[[int]], %[[ARG0]]) : (!torch.int, !torch.vtensor<[?,3],f32>) -> !torch.list +# CHECK: %[[V1:.+]]:4 = torch.prim.ListUnpack %[[V0]] : !torch.list -> !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[V1]]#0, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[V1]]#1, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[V1]]#2, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[V1]]#3, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: return %[[V1]]#0, %[[V1]]#1, %[[V1]]#2, %[[V1]]#3 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32> +def test_custom_op_array_output(): + m = Library("my_custom_library", "DEF") + m.define("array_output_op(int num_outs, Tensor a) -> Tensor[]") + + @impl(m, "array_output_op", "CompositeExplicitAutograd") + def custom_op(num_outs, a): + return [a] * num_outs + + @impl_abstract("my_custom_library::array_output_op") + def custom_op_meta(num_outs, a): + result = custom_op(num_outs, a) + return [torch.empty_like(t) for t in result] + + class ArrayOutputCustomOp(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + return torch.ops.my_custom_library.array_output_op(4, a) + + dim = Dim("n", max=10) + dynamic_shapes = { + "a": {0: dim}, + } + + a = torch.rand(2, 3) + m = fx.export_and_import( + ArrayOutputCustomOp(), + a, + import_symbolic_shape_expressions=True, + dynamic_shapes=dynamic_shapes, + ) + print(m)