Tosa: Support AtenRsubScalarOp
mgehre-amd committed Jul 10, 2023
1 parent a5670e8 commit 1fd7b35
Showing 4 changed files with 53 additions and 18 deletions.
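
For context: aten.rsub is PyTorch's reverse subtraction — torch.rsub(input, other, alpha) computes other - alpha * input, i.e. a subtraction with the operands of aten.sub swapped. A minimal eager-mode sketch of the semantics this commit lowers to TOSA (illustration only, not part of the diff):

import torch

x = torch.randint(0, 100, (3, 4))
assert torch.equal(torch.rsub(x, 2.), 2. - x)               # default alpha = 1
assert torch.equal(torch.rsub(x, 2., alpha=3), 2. - 3 * x)  # other - alpha * input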
5 changes: 5 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -248,6 +248,7 @@
     # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
     "ElementwiseSubScalarFloatModule_basic",
     "ElementwiseSubScalarIntModule_basic",
+    "RsubIntStaticModule_noalpha_basic",

     # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
     "ElementwiseDivRoundingModeFloorModule_basic",
@@ -646,6 +647,7 @@
     "RsubFloatModule_noalpha_basic",
     "RsubIntModule_basic",
     "RsubIntModule_noalpha_basic",
+    "RsubIntStaticModule_noalpha_basic",
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarTensorDefaultDtypeModule_basic",
     "ScalarTensorFloat32Module_basic",
@@ -970,6 +972,9 @@
     "ElementwiseCeilModule_basic",
     "ElementwiseReciprocalModule_basic",
     "ElementwiseIsnanModule_basic",
+    "RsubIntModule_basic",
+    "RsubIntModule_noalpha_basic",
+    "RsubIntStaticModule_noalpha_basic",
     "TypePromotionAlphaWiderModule_basic",
     "Conv1dNoPaddingModule_basic",
     "Conv1dNoPaddingGroupModule_basic",
38 changes: 22 additions & 16 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -263,6 +263,22 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
           op, "Only floating-point or integer datatype legalization supported");
     }

+    if (!rhsType) {
+      if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
+                                         rhs, outElemTy, {}))) {
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA operation");
+      }
+      rhsType = rhs.getType().dyn_cast<TensorType>();
+    }
+
+    // aten.rsub(lhs, rhs, alpha) computes rhs - lhs * alpha
+    if constexpr(std::is_same<AtenOpT, AtenRsubScalarOp>::value) {
+      std::swap(lhs, rhs);
+      std::swap(lhsType, rhsType);
+    }
+
     Type rhsAlphaMulElemType;
     if (outElemTy.isa<mlir::FloatType>()) {
       rhsAlphaMulElemType = outElemTy;
@@ -271,25 +287,14 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
       rhsAlphaMulElemType = rewriter.getIntegerType(32);
     }

-    // if right is scalar, rhgType==None, which need to be manually cast to
-    // TensorType else right is tensor, rhsType==tensor<i32/i64/f32>
-    Value rhsAsTensor;
-    if (!rhsType) {
-      if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
-                                         rhsAsTensor, rhsAlphaMulElemType, {})))
-        return rewriter.notifyMatchFailure(
-            op, "Currently only scalar constants are supported for "
-                "conversion in TOSA operation");
-    } else if (rhsType.getElementType() != rhsAlphaMulElemType) {
+    if (rhsType.getElementType() != rhsAlphaMulElemType) {
       // right is tensor, rhsType == tensor<i32/i64/f32>
       // right must be cast to same type as the alpha, so MulOp success
+      rhsType = RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType);
       rhs = rewriter.create<tosa::CastOp>(
           op->getLoc(),
-          RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs);
-      // reinitialize right value type to tensor<i32/f32>
-      rhsType = rhs.getType().dyn_cast<TensorType>();
+          rhsType, rhs);
     }
-    auto rhsTensor = rhsType ? rhs : rhsAsTensor;

     // Handle scalar value alpha.
     // It should be either f32/i32
@@ -305,8 +310,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {

     auto mulAlphaOp = tosa::createMulOpAndCast(
         rewriter, op,
-        rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType),
-        rhsTensor, alphaTensor, /*shift=*/0);
+        rhsType,
+        rhs, alphaTensor, /*shift=*/0);

     if (outElemTy.isInteger(64)) {
       // Tosa doesn't support 64-bit elementwise addition and subtraction.
@@ -5759,6 +5764,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
     INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
     INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
+    INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, tosa::SubOp)
 #undef INSERT_BINARY_ADDSUB_PATTERN

 #define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
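
The lowering reuses the shared add/sub pattern: a scalar right-hand side is first materialized as a TOSA tensor, and for AtenRsubScalarOp the operands are then swapped so that the existing lhs - alpha * rhs path computes rhs - alpha * lhs. The swap rests on this identity, checked here in eager PyTorch (a sketch for intuition, not part of the diff):

import torch

x = torch.randint(0, 100, (3, 4))
scalar = torch.tensor(2.)

# aten.sub(a, b, alpha) computes a - alpha * b; with operands swapped it
# yields scalar - alpha * x, which is exactly aten.rsub(x, scalar, alpha).
assert torch.equal(torch.rsub(x, 2., alpha=3), torch.sub(scalar, x, alpha=3))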
23 changes: 23 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -844,6 +844,29 @@ def forward(self, x):
 def RsubIntModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, high=100))

+
+# ==============================================================================
+
+
+class RsubIntStaticModule_noalpha(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.rsub(x, 2.)
+
+
+@register_test_case(module_factory=lambda: RsubIntStaticModule_noalpha())
+def RsubIntStaticModule_noalpha_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, high=100))
+
+
 # ==============================================================================

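To run just the new test case locally, the usual torch-mlir e2e harness invocation should work (flag names assumed from the standard workflow, not taken from this diff):

python -m e2e_testing.main --config=tosa --filter RsubIntStaticModule_noalpha_basic

The module feeds a random 3x4 int64 tensor through torch.rsub(x, 2.), so the backend is checked against the expected 2.0 - x result.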
5 changes: 3 additions & 2 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -1015,9 +1015,10 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 256
-// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor<i64>}> : () -> tensor<i64>
+// CHECK: %[[VAL_4_CAST:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<i64>) -> tensor<i32>
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_4_CAST]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK: %[[VAL_7:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32>
 // CHECK: %[[VAL_8:.*]] = "tosa.add"(%[[VAL_7]], %[[VAL_6]]) : (tensor<1x1x128x128xi32>, tensor<i32>) -> tensor<1x1x128x128xi32>
 // CHECK: %[[VAL_9:.*]] = "tosa.cast"(%[[VAL_8]]) : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64>
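
The updated CHECK lines capture the new constant-materialization order: the scalar 256 is now created in the Torch scalar's own width (i64) and cast to the i32 compute type before the multiply with alpha, rather than being emitted directly as an i32 constant. In rough PyTorch terms (an illustration of the dtype flow only, assuming the i32 compute type comes from rhsAlphaMulElemType above):

import torch

scalar = torch.tensor(256, dtype=torch.int64)  # "tosa.const" : tensor<i64>
scalar_i32 = scalar.to(torch.int32)            # "tosa.cast"  : i64 -> i32
alpha = torch.tensor(1, dtype=torch.int32)     # "tosa.const" : tensor<i32>
product = scalar_i32 * alpha                   # "tosa.mul" performed in i32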