Skip to content

Commit

Permalink
[Torch] add support for aten.scatter_add (llvm#3534)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu authored Jul 12, 2024
1 parent 0fb8b01 commit 5e4f00a
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 6 deletions.
31 changes: 25 additions & 6 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,16 +373,19 @@ static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
}

namespace {
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
template <typename AtenOpT>
class ConvertAtenScatterOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern::OpConversionPattern;
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor,
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
const TypeConverter *typeConverter = getTypeConverter();
const TypeConverter *typeConverter =
OpConversionPattern<AtenOpT>::getTypeConverter();
Value self = adaptor.getSelf();
Value index = adaptor.getIndex();
Value src = adaptor.getSrc();
Expand Down Expand Up @@ -410,7 +413,19 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
/*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value updatesElement,
Value inputElement) {
b.create<TMTensor::YieldOp>(loc, updatesElement);
if (isa<AtenScatterSrcOp>(op)) {
b.create<TMTensor::YieldOp>(loc, updatesElement);
} else if (isa<AtenScatterAddOp>(op)) {
if (isa<mlir::IntegerType>(selfType.getElementType())) {
Value add =
b.create<arith::AddIOp>(loc, inputElement, updatesElement);
b.create<TMTensor::YieldOp>(loc, add);
} else if (isa<mlir::FloatType>(selfType.getElementType())) {
Value add =
b.create<arith::AddFOp>(loc, inputElement, updatesElement);
b.create<TMTensor::YieldOp>(loc, add);
}
}
});

auto resultType = cast<RankedTensorType>(
Expand Down Expand Up @@ -2169,7 +2184,11 @@ class ConvertTorchToTMTensor
context);

target.addIllegalOp<AtenScatterSrcOp>();
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
patterns.add<ConvertAtenScatterOp<AtenScatterSrcOp>>(typeConverter,
context);
target.addIllegalOp<AtenScatterAddOp>();
patterns.add<ConvertAtenScatterOp<AtenScatterAddOp>>(typeConverter,
context);
target.addIllegalOp<AtenKthvalueOp>();
patterns.add<ConvertAtenKthvalueOp>(typeConverter, context);

Expand Down
7 changes: 7 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9787,6 +9787,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -11567,6 +11570,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.masked_scatter\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2682,6 +2682,7 @@
"ScatterReduceIntMaxModuleIncludeSelf",
"ScatterReduceIntMinModuleIncludeSelf",
"ScatterValueFloatModule_basic",
"ScatterAddStaticModule_basic",
# Failure - onnx_lowering: onnx.ScatterND
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,9 @@ def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], sr
def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]:
return self

def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
return self

def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
return upstream_shape_functions.index_select(self, dim, index)

Expand Down Expand Up @@ -3115,6 +3118,12 @@ def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, i
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), TensorOfShape(3, dtype=torch.bool), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇masked_scatter〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int]) -> int:
Expand Down
25 changes: 25 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,31 @@ def ScatterValueIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ScatterAddStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([10, 8, 6], torch.float32, True),
([2, 4, 3], torch.int64, True),
([5, 8, 6], torch.float32, True),
]
)
def forward(self, input, index, src):
return torch.ops.aten.scatter_add(input, 0, index, src)


@register_test_case(module_factory=lambda: ScatterAddStaticModule())
def ScatterAddStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6))


# ==============================================================================


class ScatterReduceFloatModule(torch.nn.Module):
include_self: bool
reduce_type: str
Expand Down

0 comments on commit 5e4f00a

Please sign in to comment.