diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index b6bd3b8b6a36..3e37456f3086 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -373,16 +373,19 @@ static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
 }
 
 namespace {
-class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
+template <typename AtenOpT>
+class ConvertAtenScatterOp : public OpConversionPattern<AtenOpT> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
   LogicalResult
-  matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op.getLoc();
-    const TypeConverter *typeConverter = getTypeConverter();
+    const TypeConverter *typeConverter =
+        OpConversionPattern<AtenOpT>::getTypeConverter();
     Value self = adaptor.getSelf();
     Value index = adaptor.getIndex();
     Value src = adaptor.getSrc();
@@ -410,7 +413,19 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
         /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
         [&](OpBuilder &b, Location loc, Value updatesElement,
            Value inputElement) {
-          b.create<TMTensor::YieldOp>(loc, updatesElement);
+          if (isa<AtenScatterSrcOp>(op)) {
+            b.create<TMTensor::YieldOp>(loc, updatesElement);
+          } else if (isa<AtenScatterAddOp>(op)) {
+            if (isa<mlir::IntegerType>(selfType.getElementType())) {
+              Value add =
+                  b.create<arith::AddIOp>(loc, inputElement, updatesElement);
+              b.create<TMTensor::YieldOp>(loc, add);
+            } else if (isa<mlir::FloatType>(selfType.getElementType())) {
+              Value add =
+                  b.create<arith::AddFOp>(loc, inputElement, updatesElement);
+              b.create<TMTensor::YieldOp>(loc, add);
+            }
+          }
         });
 
     auto resultType = cast<RankedTensorType>(
@@ -2169,7 +2184,11 @@ class ConvertTorchToTMTensor
                                                             context);
 
     target.addIllegalOp<AtenScatterSrcOp>();
-    patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
+    patterns.add<ConvertAtenScatterOp<AtenScatterSrcOp>>(typeConverter,
+                                                         context);
+    target.addIllegalOp<AtenScatterAddOp>();
+    patterns.add<ConvertAtenScatterOp<AtenScatterAddOp>>(typeConverter,
+                                                         context);
 
     target.addIllegalOp<AtenSortOp>();
     patterns.add<ConvertAtenSortOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index bc8f252e6dfc..65f9f16e0425 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9787,6 +9787,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -11567,6 +11570,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.masked_scatter\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 504c7ca9d6f7..f9576c984c73 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2682,6 +2682,7 @@
     "ScatterReduceIntMaxModuleIncludeSelf",
     "ScatterReduceIntMinModuleIncludeSelf",
     "ScatterValueFloatModule_basic",
+    "ScatterAddStaticModule_basic",
     # Failure - onnx_lowering: onnx.ScatterND
     "IndexPut1DFloatAccumulateModule_basic",
     "IndexPut1DIntAccumulateModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 37db50050b43..553398905700 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1810,6 +1810,9 @@ def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], sr
 def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]:
     return self
 
+def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
+    return self
+
 def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
     return upstream_shape_functions.index_select(self, dim, index)
 
@@ -3115,6 +3118,12 @@ def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, i
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(
+    [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
+def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     [Invocation(TensorOfShape(3, dtype=dtype), TensorOfShape(3, dtype=torch.bool), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
 def aten〇masked_scatter〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int]) -> int:
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
index ba44dc076904..ee85855e4aa8 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -1020,6 +1020,31 @@ def ScatterValueIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ScatterAddStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([10, 8, 6], torch.float32, True),
+            ([2, 4, 3], torch.int64, True),
+            ([5, 8, 6], torch.float32, True),
+        ]
+    )
+    def forward(self, input, index, src):
+        return torch.ops.aten.scatter_add(input, 0, index, src)
+
+
+@register_test_case(module_factory=lambda: ScatterAddStaticModule())
+def ScatterAddStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6))
+
+
+# ==============================================================================
+
+
 class ScatterReduceFloatModule(torch.nn.Module):
     include_self: bool
     reduce_type: str
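--
Note (illustrative sketch, not part of the patch): the new scatter region body
yields inputElement + updatesElement for scatter_add, matching PyTorch's
documented semantics, i.e. for dim=0: self[index[i][j][k]][j][k] += src[i][j][k].
A minimal reference check in plain PyTorch, reusing the shapes from
ScatterAddStaticModule above; the loop formulation is only a sketch of the
accumulation the lowering implements, not code from this change:

    import torch

    inp = torch.rand(10, 8, 6)
    index = torch.randint(0, 4, (2, 4, 3))  # index values stay below inp.shape[0]
    src = torch.rand(5, 8, 6)

    # Out-of-place op converted by the new pattern (dim=0).
    out = torch.ops.aten.scatter_add(inp, 0, index, src)

    # Reference loop: accumulate src into a copy of inp at the indexed rows,
    # iterating only over the extent of the index tensor.
    expected = inp.clone()
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            for k in range(index.shape[2]):
                expected[index[i, j, k], j, k] += src[i, j, k]

    assert torch.allclose(out, expected)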