From 1fd7b358869490e5a2cca48c4392c4490908efc8 Mon Sep 17 00:00:00 2001
From: Matthias Gehre
Date: Mon, 10 Jul 2023 09:00:11 +0200
Subject: [PATCH] Tosa: Support AtenRsubScalarOp

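aten.rsub(self, other, alpha) computes other - self * alpha, so the existing
ConvertAtenAddSubOp pattern can be reused for AtenRsubScalarOp: the scalar
operand is materialized as a TOSA constant up front, and lhs/rhs are then
swapped when the pattern is instantiated for rsub, letting the shared
tosa.sub lowering subtract in the right order. A scalar operand is now
created with the output element type and cast to the i32 multiply type
afterwards, which is why basic.mlir gains an extra tosa.cast.

For reference, the torch-level semantics the new e2e test exercises (an
illustrative snippet, not part of this patch; shape and values are
arbitrary):

    import torch

    x = torch.randint(0, 100, (3, 4), dtype=torch.int64)
    # Default alpha == 1: rsub(self, other) == other - self
    assert torch.equal(torch.rsub(x, 2.), 2. - x)
    # Explicit alpha: rsub(self, other, alpha) == other - self * alpha
    assert torch.equal(torch.rsub(x, 2., alpha=3), 2. - x * 3)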
---
 e2e_testing/xfail_sets.py                  |  5 +++
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 38 +++++++++++--------
 .../test_suite/elementwise.py              | 23 +++++++++++
 test/Conversion/TorchToTosa/basic.mlir     |  5 ++-
 4 files changed, 53 insertions(+), 18 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 7ace41ffead3..b4b7e4d5cca7 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -248,6 +248,7 @@
     # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
     "ElementwiseSubScalarFloatModule_basic",
     "ElementwiseSubScalarIntModule_basic",
+    "RsubIntStaticModule_noalpha_basic",
 
     # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
     "ElementwiseDivRoundingModeFloorModule_basic",
@@ -646,6 +647,7 @@
     "RsubFloatModule_noalpha_basic",
     "RsubIntModule_basic",
     "RsubIntModule_noalpha_basic",
+    "RsubIntStaticModule_noalpha_basic",
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarTensorDefaultDtypeModule_basic",
     "ScalarTensorFloat32Module_basic",
@@ -970,6 +972,9 @@
     "ElementwiseCeilModule_basic",
     "ElementwiseReciprocalModule_basic",
     "ElementwiseIsnanModule_basic",
+    "RsubIntModule_basic",
+    "RsubIntModule_noalpha_basic",
+    "RsubIntStaticModule_noalpha_basic",
     "TypePromotionAlphaWiderModule_basic",
     "Conv1dNoPaddingModule_basic",
     "Conv1dNoPaddingGroupModule_basic",
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index f0d9e9beb2ad..cb157af5ca68 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -263,6 +263,22 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
           op, "Only floating-point or integer datatype legalization supported");
     }
 
+    if (!rhsType) {
+      if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
+                                         rhs, outElemTy, {}))) {
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA operation");
+      }
+      rhsType = rhs.getType().dyn_cast<TensorType>();
+    }
+
+    // aten.rsub(lhs, rhs, alpha) computes rhs - lhs * alpha; swap the operands
+    if constexpr (std::is_same<AtenOpT, AtenRsubScalarOp>::value) {
+      std::swap(lhs, rhs);
+      std::swap(lhsType, rhsType);
+    }
+
     Type rhsAlphaMulElemType;
     if (outElemTy.isa<mlir::FloatType>()) {
       rhsAlphaMulElemType = outElemTy;
@@ -271,25 +287,14 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
       rhsAlphaMulElemType = rewriter.getIntegerType(32);
     }
 
-    // if right is scalar, rhgType==None, which need to be manually cast to
-    // TensorType else right is tensor, rhsType==tensor
-    Value rhsAsTensor;
-    if (!rhsType) {
-      if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
-                                         rhsAsTensor, rhsAlphaMulElemType, {})))
-        return rewriter.notifyMatchFailure(
-            op, "Currently only scalar constants are supported for "
-                "conversion in TOSA operation");
-    } else if (rhsType.getElementType() != rhsAlphaMulElemType) {
+    if (rhsType.getElementType() != rhsAlphaMulElemType) {
       // right is tensor, rhsType == tensor
       // right must be cast to same type as the alpha, so MulOp success
+      rhsType = RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType);
       rhs = rewriter.create<tosa::CastOp>(
           op->getLoc(),
-          RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs);
-      // reinitialize right value type to tensor
-      rhsType = rhs.getType().dyn_cast<TensorType>();
+          rhsType, rhs);
     }
-    auto rhsTensor = rhsType ? rhs : rhsAsTensor;
 
     // Handle scalar value alpha.
     // It should be either f32/i32
@@ -305,8 +310,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
 
     auto mulAlphaOp = tosa::createMulOpAndCast(
         rewriter, op,
-        rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType),
-        rhsTensor, alphaTensor, /*shift=*/0);
+        rhsType,
+        rhs, alphaTensor, /*shift=*/0);
 
     if (outElemTy.isInteger(64)) {
       // Tosa doesn't support 64-bit elementwise addition and subtraction.
@@ -5759,6 +5764,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
     INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
     INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
+    INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, tosa::SubOp)
 #undef INSERT_BINARY_ADDSUB_PATTERN
 
 #define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp)                          \
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 2a68d4ba5883..723a87d1eec6 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -844,6 +844,29 @@ def forward(self, x):
 def RsubIntModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, high=100))
 
+
+# ==============================================================================
+
+
+class RsubIntStaticModule_noalpha(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.rsub(x, 2.)
+
+
+@register_test_case(module_factory=lambda: RsubIntStaticModule_noalpha())
+def RsubIntStaticModule_noalpha_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, high=100))
+
+
 # ==============================================================================
 
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index df8c148902b9..2705f453bdf5 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1015,9 +1015,10 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 256
-// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor<i64>}> : () -> tensor<i64>
+// CHECK: %[[VAL_4_CAST:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<i64>) -> tensor<i32>
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
-// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_4_CAST]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<i32>, tensor<i32>) -> tensor<i32>
 // CHECK: %[[VAL_7:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32>
 // CHECK: %[[VAL_8:.*]] = "tosa.add"(%[[VAL_7]], %[[VAL_6]]) : (tensor<1x1x128x128xi32>, tensor<i32>) -> tensor<1x1x128x128xi32>
 // CHECK: %[[VAL_9:.*]] = "tosa.cast"(%[[VAL_8]]) : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64>