diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index 78a0e29cea67..102a73b58eb6 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -373,16 +373,19 @@ static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
 }
 
 namespace {
-class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
+template <typename AtenOpT>
+class ConvertAtenScatterOp : public OpConversionPattern<AtenOpT> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
   LogicalResult
-  matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op.getLoc();
-    const TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter =
+        OpConversionPattern<AtenOpT>::getTypeConverter();
     Value self = adaptor.getSelf();
     Value index = adaptor.getIndex();
     Value src = adaptor.getSrc();
@@ -410,7 +413,19 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
         /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
         [&](OpBuilder &b, Location loc, Value updatesElement,
             Value inputElement) {
-          b.create<TMTensor::YieldOp>(loc, updatesElement);
+          if (isa<AtenScatterSrcOp>(op)) {
+            b.create<TMTensor::YieldOp>(loc, updatesElement);
+          } else if (isa<AtenScatterAddOp>(op)) {
+            if (isa<mlir::IntegerType>(selfType.getElementType())) {
+              Value add =
+                  b.create<arith::AddIOp>(loc, inputElement, updatesElement);
+              b.create<TMTensor::YieldOp>(loc, add);
+            } else if (isa<mlir::FloatType>(selfType.getElementType())) {
+              Value add =
+                  b.create<arith::AddFOp>(loc, inputElement, updatesElement);
+              b.create<TMTensor::YieldOp>(loc, add);
+            }
+          }
         });
 
     auto resultType = cast<RankedTensorType>(
@@ -2172,7 +2187,11 @@ class ConvertTorchToTMTensor
                                                                  context);
     target.addIllegalOp<AtenScatterSrcOp>();
-    patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
+    patterns.add<ConvertAtenScatterOp<AtenScatterSrcOp>>(typeConverter,
+                                                         context);
+    target.addIllegalOp<AtenScatterAddOp>();
+    patterns.add<ConvertAtenScatterOp<AtenScatterAddOp>>(typeConverter,
+                                                         context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 55d4ef7b9594..a016d0e30a09 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9808,6 +9808,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -11605,6 +11608,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.masked_scatter\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
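# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): eager-mode
# reference for the accumulate semantics the TMTensor scatter region above
# implements for aten.scatter_add. For dim = 0 the update is
#   out[index[i][j]][j] += src[i][j], starting from a copy of `self`.
# Shapes and values below are made up for illustration.
import torch

self_t = torch.zeros(4, 3)
index = torch.tensor([[0, 1, 2], [3, 0, 1]])
src = torch.arange(6, dtype=torch.float32).reshape(2, 3)

expected = self_t.clone()
for i in range(src.shape[0]):
    for j in range(src.shape[1]):
        expected[index[i, j], j] += src[i, j]

out = torch.ops.aten.scatter_add(self_t, 0, index, src)
assert torch.allclose(out, expected)
# ---------------------------------------------------------------------------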
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index bb92d8a77845..c8c70ccdf251 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2874,6 +2874,7 @@
     "ScatterReduceIntMaxModuleIncludeSelf",
     "ScatterReduceIntMinModuleIncludeSelf",
     "ScatterValueFloatModule_basic",
+    "ScatterAddStaticModule_basic",
     # Failure - onnx_lowering: onnx.ScatterND
     "IndexPut1DFloatAccumulateModule_basic",
     "IndexPut1DIntAccumulateModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6934753a9eae..2401e8837fb4 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1817,6 +1817,9 @@ def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], sr
 def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]:
     return self
 
+def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
+    return self
+
 def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
     return upstream_shape_functions.index_select(self, dim, index)
 
@@ -3140,6 +3143,12 @@ def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, i
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(
+    [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
+def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     [Invocation(TensorOfShape(3, dtype=dtype), TensorOfShape(3, dtype=torch.bool), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
 def aten〇masked_scatter〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int]) -> int:
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
index 7fd674c9f20d..efe927134b4e 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -1053,6 +1053,31 @@ def ScatterValueIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ScatterAddStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([10, 8, 6], torch.float32, True),
+            ([2, 4, 3], torch.int64, True),
+            ([5, 8, 6], torch.float32, True),
+        ]
+    )
+    def forward(self, input, index, src):
+        return torch.ops.aten.scatter_add(input, 0, index, src)
+
+
+@register_test_case(module_factory=lambda: ScatterAddStaticModule())
+def ScatterAddStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6))
+
+
+# ==============================================================================
+
+
 class ScatterReduceFloatModule(torch.nn.Module):
     include_self: bool
     reduce_type: str
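# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): the new shape
# and dtype functions above simply propagate `self`'s shape and dtype, which
# matches eager PyTorch. The shapes mirror ScatterAddStaticModule above.
import torch

inp = torch.rand(10, 8, 6)
index = torch.randint(0, 10, (2, 4, 3))
src = torch.rand(5, 8, 6)

out = torch.ops.aten.scatter_add(inp, 0, index, src)
assert out.shape == inp.shape and out.dtype == inp.dtype
# ---------------------------------------------------------------------------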
diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py
index f9071b9239e2..a68b27d43226 100644
--- a/python/torch_mlir/compiler_utils.py
+++ b/python/torch_mlir/compiler_utils.py
@@ -86,12 +86,12 @@ def run_pipeline_with_repro_report(
 
 
 class OutputType(Enum):
-    # Output torch dialect. When converting from FX, this will be immediately
-    # after the import from FX to MLIR. When converting from torchscript,
-    # this will come after some cleanup passes which attempt to de-alias,
-    # decompose and infer shapes. These should be roughly the same level of
-    # abstraction since those steps are done within PyTorch itself
-    # when coming directly from Dynamo/FX.
+    # Output torch dialect in backend form. When converting from TorchDynamo,
+    # this comes after some decomposition and reduce-op-variants passes have
+    # been applied to the raw torch dialect. When converting from TorchScript,
+    # this comes after some cleanup passes which attempt to de-alias, decompose,
+    # and infer shapes. These should be roughly the same level of abstraction,
+    # since those steps are done within PyTorch itself when coming directly from Dynamo/FX.
     TORCH = "torch"
 
     # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
@@ -108,7 +108,8 @@ class OutputType(Enum):
     # as taking the `TORCH` output type and lowering it to StableHLO.
     STABLEHLO = "stablehlo"
 
-    # Raw output of the JIT IR importer. This is not expected to be useful
+    # Raw output of the JIT IR importer in the TorchScript frontend, or of the
+    # FX IR importer in the TorchDynamo frontend. This is not expected to be useful
     # for end-users, but can be convenient for development or reporting bugs.
     RAW = "raw"
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index cb86406c55fd..c95df2504d03 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -1099,6 +1099,10 @@ def value_info_to_type(
             return self.get_vtensor_type(
                 val.size(), val.dtype, sparsity=sparsity, mutable=mutable
             )
+        elif isinstance(val, list) and all(
+            isinstance(x, TorchFakeTensor) for x in val
+        ):
+            return IrType.parse("!torch.list<vtensor>", context=self._c)
 
         # Note that None is a valid scalar here, so it is important that this
         # is always checked as the last fallback.
@@ -1227,6 +1231,7 @@ class GraphNodeImporter:
         "_v",
         "_symbol_to_value",
         "_multi_result_nodes",
+        "_unpack_list_values",
         "fx_importer",
     ]
 
@@ -1251,6 +1256,10 @@ def __init__(
         # Statically multi-result nodes which we have de-tupled are noted here.
         # They will have their getitem calls short-circuited.
        self._multi_result_nodes: Set[torch_fx.Node] = set()
+        # If an op returns a list, it needs to be unpacked entirely using
+        # prim.ListUnpack. Cache the unpacked results of these nodes so the
+        # list is only unpacked once, rather than on every getitem access.
+        self._unpack_list_values: Dict[torch_fx.Node, Tuple[Value]] = {}
 
     def bind_node_value(
         self,
@@ -1420,29 +1429,7 @@ def import_nodes(
             elif op == "call_function":
                 target = node.target
                 if target == operator.getitem:
-                    # Special case handling of getitem for when it is resolving
-                    # against a function call that we know has returned multiple
-                    # results. We short-circuit this case because we have modeled
-                    # function calls to natively return multiple results vs tupling.
-                    getitem_ref, getitem_index = node.args
-                    if getitem_ref in self._multi_result_nodes:
-                        try:
-                            self.bind_node_value(
-                                node,
-                                self.resolve_node_value(getitem_ref, getitem_index),
-                            )
-                        except IndexError:
-                            raise RuntimeError(
-                                f"getitem de-aliasing failed. This likely "
-                                f"indicates a programmer error that usually "
-                                f"would have happened at runtime. Please "
-                                f"notify developers if this case happens "
-                                f"(at {loc})."
-                            )
-                    else:
-                        raise NotImplementedError(
-                            f"General getitem access to non-multi-result ops"
-                        )
+                    self._import_getitem(loc, node)
                 elif target in SYMBOLIC_TORCH_OPS or (
                     is_symbolic(node.meta.get("val"))
                     and is_builtin_function_or_method(target)
@@ -2007,6 +1994,51 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
         with loc:
             return cvt(arg, self, self._cc)
 
+    def _import_getitem(self, loc: Location, node: torch.fx.Node):
+        ref_node, index = node.args
+        if ref_node in self._multi_result_nodes:
+            # Special case handling of getitem for when it is resolving
+            # against a function call that we know has returned multiple
+            # results. We short-circuit this case because we have modeled
+            # function calls to natively return multiple results vs tupling.
+            try:
+                self.bind_node_value(
+                    node,
+                    self.resolve_node_value(ref_node, index),
+                )
+            except IndexError:
+                raise RuntimeError(
+                    f"getitem de-aliasing failed. This likely "
+                    f"indicates a programmer error that usually "
+                    f"would have happened at runtime. Please "
+                    f"notify developers if this case happens "
+                    f"(at {loc})."
+                )
+        else:
+            # Handle nodes that return a torch.list<...> at the MLIR level.
+            # NOTE: the length of the list must be knowable at compile time.
+            if ref_node not in self._unpack_list_values:
+                node_result = self.resolve_node_value(ref_node, 0)
+                if str(node_result.type) in TORCH_LIST_TYPES:
+                    result_types = [
+                        self._cc.value_info_to_type(v) for v in ref_node.meta["val"]
+                    ]
+                    operation = Operation.create(
+                        "torch.prim.ListUnpack",
+                        results=result_types,
+                        operands=[node_result],
+                        loc=loc,
+                    )
+                    self._unpack_list_values[ref_node] = tuple(operation.results)
+
+            try:
+                self.bind_node_value(node, self._unpack_list_values[ref_node][index])
+            except IndexError:
+                raise RuntimeError(
+                    f"getitem failed. "
+                    f"getitem only supports lists of known length. (at {loc})"
+                )
+
     def _unpack_node_result_types(
         self, node: torch.fx.Node, schema: FunctionSchema
     ) -> List[IrType]:
@@ -2337,6 +2369,10 @@ def _ref_finalizer(self, ref_id: int):
     "vtensor": "!torch.list<optional<vtensor>>",
 }
 
+TORCH_LIST_TYPES = set(PY_TYPE_TO_TORCH_LIST_TYPE.values()) | set(
+    PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE.values()
+)
+
 SCALAR_TYPE_TO_TORCH_MLIR_TYPE = {
     torch.SymInt: "!torch.int",
     torch.SymFloat: "!torch.float",
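# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): operator.getitem
# nodes are what FX emits when a traced list/tuple value is indexed. The new
# _import_getitem path above resolves such nodes against a list-typed MLIR
# value by emitting a single torch.prim.ListUnpack per producing node (cached
# in _unpack_list_values). A minimal torch.fx illustration of the graph
# pattern; the importer itself consumes aten-level graphs from torch.export,
# as in the custom-op test further below.
import operator
import torch
import torch.fx

class SplitAndAdd(torch.nn.Module):
    def forward(self, x):
        parts = torch.split(x, 2)   # list-producing call...
        return parts[0] + parts[1]  # ...indexed through operator.getitem

gm = torch.fx.symbolic_trace(SplitAndAdd())
getitems = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target is operator.getitem
]
assert len(getitems) == 2
# ---------------------------------------------------------------------------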
diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py
index 5cd7d2d6e1f1..0d9ad77d2ff7 100644
--- a/python/torch_mlir/fx.py
+++ b/python/torch_mlir/fx.py
@@ -30,7 +30,7 @@ def _module_lowering(
     extra_library_file_name=None,
 ):
 
-    if output_type == OutputType.TORCH:
+    if output_type == OutputType.RAW:
         if verbose:
             print(torch_mod)
         return torch_mod
@@ -50,7 +50,7 @@ def _module_lowering(
 def export_and_import(
     f: Union[nn.Module, ExportedProgram],
     *args,
-    output_type: Union[str, OutputType] = OutputType.TORCH,
+    output_type: Union[str, OutputType] = OutputType.RAW,
     fx_importer: Optional[FxImporter] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     experimental_support_mutation: bool = False,
@@ -99,7 +99,7 @@ def export_and_import(
 
 def stateless_fx_import(
     gm: torch.fx.GraphModule,
-    output_type: Union[str, OutputType] = OutputType.TORCH,
+    output_type: Union[str, OutputType] = OutputType.RAW,
     fx_importer: Optional[FxImporter] = None,
     hooks: Optional[FxImporterHooks] = None,
     model_name: str = "main",
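# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): with the default
# output_type changed to OutputType.RAW, callers that relied on the previous
# default should now request the backend-form torch dialect explicitly. The
# module and input below are made up for illustration.
import torch
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

raw_module = fx.export_and_import(AddOne(), torch.rand(3))  # raw importer output
torch_module = fx.export_and_import(
    AddOne(), torch.rand(3), output_type=OutputType.TORCH
)  # torch dialect in backend form
# ---------------------------------------------------------------------------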
diff --git a/test/python/fx_importer/custom_op_test.py b/test/python/fx_importer/custom_op_test.py
index dbbc5ba057af..d9ce7d6096a5 100644
--- a/test/python/fx_importer/custom_op_test.py
+++ b/test/python/fx_importer/custom_op_test.py
@@ -84,3 +84,50 @@ def forward(self, x, y, z):
         import_symbolic_shape_expressions=True,
     )
     print(m)
+
+
+@run
+# CHECK-LABEL: test_custom_op_array_output
+# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,3],f32>)
+# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int
+# CHECK: %[[int:.+]] = torch.constant.int 4
+# CHECK: %[[V0:.+]] = torch.operator "torch.my_custom_library.array_output_op"(%[[int]], %[[ARG0]]) : (!torch.int, !torch.vtensor<[?,3],f32>) -> !torch.list<vtensor>
+# CHECK: %[[V1:.+]]:4 = torch.prim.ListUnpack %[[V0]] : !torch.list<vtensor> -> !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#0, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#1, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#2, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: torch.bind_symbolic_shape %[[V1]]#3, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: return %[[V1]]#0, %[[V1]]#1, %[[V1]]#2, %[[V1]]#3 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
+def test_custom_op_array_output():
+    m = Library("my_custom_library", "DEF")
+    m.define("array_output_op(int num_outs, Tensor a) -> Tensor[]")
+
+    @impl(m, "array_output_op", "CompositeExplicitAutograd")
+    def custom_op(num_outs, a):
+        return [a] * num_outs
+
+    @impl_abstract("my_custom_library::array_output_op")
+    def custom_op_meta(num_outs, a):
+        result = custom_op(num_outs, a)
+        return [torch.empty_like(t) for t in result]
+
+    class ArrayOutputCustomOp(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a):
+            return torch.ops.my_custom_library.array_output_op(4, a)
+
+    dim = Dim("n", max=10)
+    dynamic_shapes = {
+        "a": {0: dim},
+    }
+
+    a = torch.rand(2, 3)
+    m = fx.export_and_import(
+        ArrayOutputCustomOp(),
+        a,
+        import_symbolic_shape_expressions=True,
+        dynamic_shapes=dynamic_shapes,
+    )
+    print(m)
diff --git a/test/python/fx_importer/symbolic_shape_expr_test.py b/test/python/fx_importer/symbolic_shape_expr_test.py
index 3215e0f8213d..d86e98725499 100644
--- a/test/python/fx_importer/symbolic_shape_expr_test.py
+++ b/test/python/fx_importer/symbolic_shape_expr_test.py
@@ -84,7 +84,7 @@ def forward(self, x, y, z):
 # CHECK-LABEL: test_symbolic_dim_differ_by_one
 # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
 # CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
-# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
+# FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
 # CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
 # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
 # CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
@@ -262,7 +262,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 @run
 # CHECK-LABEL: test_shape_div
 # CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> {
-# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
+# FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
 # CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int
 # CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int
 # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32>
@@ -433,7 +433,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 @run
 # CHECK-LABEL: test_gather_elements
 # CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> {
-# CHECK: %[[S0]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
+# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
 # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
 # CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32>
 # CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32>
@@ -461,3 +461,38 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         import_symbolic_shape_expressions=True,
     )
     print(m)
+
+
+@run
+# CHECK-LABEL: test_nonzero
+# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,2],si64> {
+# FIXME: There's a bug in the torch 2.3 stable release which creates redundant symbolic_int ops for the nonzero
+# output, which is fixed in the 2.4 nightlies. Once we move to a 2.4 stable release, this check may be re-enabled.
+# CHECK-DISABLED: %[[U0:.+]] = torch.symbolic_int "u0" {min_val = 0, max_val = 9223372036854775806} : !torch.int
+# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 10} : !torch.int
+# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
+# CHECK: %[[NZERO:.+]] = torch.aten.nonzero %[[ARG0]] : !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,2],si64>
+# CHECK-DISABLED: torch.bind_symbolic_shape %[[NZERO]], [%[[U0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],si64>
+# CHECK: return %[[NZERO]] : !torch.vtensor<[?,2],si64>
+def test_nonzero():
+    class Nonzero(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return torch.nonzero(x)
+
+    # Sample inputs
+    x = torch.randn(4, 3)
+
+    # Dynamic dim constraints
+    batch = Dim("batch", min=3, max=10)
+    dynamic_shapes = {"x": {0: batch}}
+
+    m = fx.export_and_import(
+        Nonzero(),
+        x,
+        dynamic_shapes=dynamic_shapes,
+        import_symbolic_shape_expressions=True,
+    )
+    print(m)
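# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): torch.nonzero
# has a data-dependent output shape of (num_nonzero, rank), which is why only
# the second dimension is static in the expected !torch.vtensor<[?,2],si64>.
import torch

x = torch.tensor([[0.0, 1.0, 0.0], [2.0, 0.0, 3.0]])
nz = torch.nonzero(x)
assert nz.shape == (3, 2)  # 3 nonzero elements, each given by 2 coordinates
# ---------------------------------------------------------------------------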