From ac85338491efd59be5aa6532bf9e80738c798290 Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Fri, 26 Apr 2024 15:47:44 +0800
Subject: [PATCH] [Stablehlo] Support AtenPowScalarOp, AtenTanOp, AtenAsinhOp, AtenAcoshOp, AtenAtanhOp, Atan2Op (#3233)

---
 lib/Conversion/TorchToStablehlo/Basic.cpp | 43 ++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py | 20 ++++--
 .../test_suite/elementwise.py | 70 +++++++++++++++++++
 3 files changed, 126 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 2334d0180865..5cc6b0928898 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -960,6 +960,43 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
   return success();
 }
 
+// AtenPowScalarOp
+template <>
+LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
+    AtenPowScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value lhs = adaptor.getSelf();
+  auto lhsType = dyn_cast<TensorType>(lhs.getType());
+  Value rhs = adaptor.getExponent();
+  auto rhsType = dyn_cast<TensorType>(rhs.getType());
+
+  if (!rhsType)
+    return op.emitError("only Tensor types supported in StableHLO");
+
+  auto outType = cast<TensorType>(
+      OpConversionPattern<AtenPowScalarOp>::getTypeConverter()->convertType(
+          op.getType()));
+
+  Type outElemTy = outType.getElementType();
+  if (!outElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+
+  if (!lhsType) {
+    lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
+  }
+  DenseI64ArrayAttr bcastDimensions;
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
+  auto loc = op.getLoc();
+  Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
+                                                       bcastDimensions);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 // PrimNumToTensorScalarOp
 template <>
 LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
@@ -2050,11 +2087,15 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp);
   INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp);
   INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanOp, chlo::TanOp);
   INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp);
   INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinhOp, chlo::SinhOp);
   INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp);
   INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCoshOp, chlo::CoshOp);
   INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinhOp, chlo::AsinhOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcoshOp, chlo::AcoshOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanhOp, chlo::AtanhOp);
 #undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN
 
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@@ -2137,6 +2178,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenTensorIntOp);
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
   INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
+  INSERT_ATENOP_PATTERN(AtenPowScalarOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
   INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
@@ -2181,5 +2223,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseAndTensorOp, chlo::BroadcastAndOp);
   INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseOrTensorOp, chlo::BroadcastOrOp);
   INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseXorTensorOp, chlo::BroadcastXorOp);
+  INSERT_BINARY_BROADCAST_PATTERN(AtenAtan2Op, chlo::BroadcastAtan2Op);
 #undef INSERT_BINARY_BROADCAST_PATTERN
 }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 7e6a2883bd6a..55a005e681dd 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -622,16 +622,10 @@
     "DiagonalModule_with_offset",
     "DivFloatModule_basic",
     "DivIntModule_basic",
-    "ElementwiseAcoshIntModule_basic",
-    "ElementwiseAcoshModule_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
-    "ElementwiseAsinhIntModule_basic",
-    "ElementwiseAsinhModule_basic",
     "ElementwiseAtan2FloatIntModule_basic",
     "ElementwiseAtan2TensorFloatModule_basic",
     "ElementwiseAtan2TensorIntModule_basic",
-    "ElementwiseAtanhIntModule_basic",
-    "ElementwiseAtanhModule_basic",
     "ElementwiseBitwiseLeftShiftInt32Module_basic",
     "ElementwiseBitwiseLeftShiftInt64Module_basic",
     "ElementwiseBitwiseLeftShiftInt8Module_basic",
@@ -643,7 +637,6 @@
     "ElementwiseErfIntModule_basic",
     "ElementwiseLogitModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
-    "ElementwisePowScalarModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",
@@ -992,6 +985,15 @@
     "DropoutEvalIntModule_basic",
     "ElementwiseAbsFloatModule_basic",
     "ElementwiseAbsIntModule_basic",
+    "ElementwiseAcoshIntModule_basic",
+    "ElementwiseAcoshModule_basic",
+    "ElementwiseAsinhIntModule_basic",
+    "ElementwiseAsinhModule_basic",
+    "ElementwiseAtanhIntModule_basic",
+    "ElementwiseAtanhModule_basic",
+    "ElementwiseAtan2TensorFloatStaticModule_basic",
+    "ElementwiseAtan2TensorIntStaticModule_basic",
+    "ElementwiseAtan2FloatIntStaticModule_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseAtenIsinfOpModule_basic",
@@ -1057,6 +1059,7 @@
     "ElementwiseNegModule_basic",
     "ElementwiseOrTensorStaticShapeModule_basic",
     "ElementwiseAndScalarStaticShapeModule_basic",
+    "ElementwisePowScalarModule_basic",
     "ElementwisePowTensorBroadcastStaticModule_basic",
     "ElementwisePowTensorStaticModule_basic",
     "ElementwisePreluStaticModule_basic",
@@ -1069,6 +1072,8 @@
     "ElementwiseSigmoidModule_basic",
     "ElementwiseSinModule_basic",
     "ElementwiseSqrtModule_basic",
+    "ElementwiseTanIntModule_basic",
+    "ElementwiseTanModule_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToI8Module_basic",
     "ElementwiseToDtypeIdentityModule_basic",
@@ -2180,6 +2185,7 @@
     "AvgPool2dDivisorOverrideModule_basic",
     "BroadcastDynamicDimModule_basic",
     "ElementwiseAtan2TensorIntModule_basic",
+    "ElementwiseAtan2TensorIntStaticModule_basic",
     "ElementwiseAtenFloorDivideScalarNegativeModule_basic",
     "ElementwiseAtenFloorDivideTensorNegativeModule_basic",
     "ElementwiseLog10IntModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 3aa8f10ff9dd..cbd2868b71d6 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1792,6 +1792,28 @@ def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseAtan2TensorFloatStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.float32, True),
+        ([4, 5, 6], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.atan2(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatStaticModule())
+def ElementwiseAtan2TensorFloatStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6))
+
+
+# ==============================================================================
+
 class ElementwiseAtan2TensorIntModule(torch.nn.Module):
 
     def __init__(self):
@@ -1816,6 +1838,30 @@ def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseAtan2TensorIntStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.int32, True),
+        ([4, 5, 6], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.atan2(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntStaticModule())
+def ElementwiseAtan2TensorIntStaticModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(4, 5, 6, low=1, high=10).type(torch.int32), tu.randint(4, 5, 6, low=1, high=10))
+
+
+# ==============================================================================
+
+
 class ElementwiseAtan2FloatIntModule(torch.nn.Module):
 
     def __init__(self):
@@ -1840,6 +1886,30 @@ def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseAtan2FloatIntStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.int32, True),
+        ([4, 5, 6], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.atan2(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntStaticModule())
+def ElementwiseAtan2FloatIntStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 5, 6, low=1, high=10).to(torch.int32),
+                   tu.rand(4, 5, 6).double())
+
+
+# ==============================================================================
+
+
 class ElementwiseLogModule(torch.nn.Module):
 
     def __init__(self):
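
As a sanity check on the semantics these new patterns must match, a minimal
eager-mode PyTorch sketch follows (not part of the patch; values are
illustrative, and only the standard torch.pow and torch.atan2 APIs are
assumed):

    import torch

    # aten.pow.Scalar: scalar base, tensor exponent. The conversion above
    # wraps the scalar in a 0-d tensor (hlo::scalarToStablehloTensor),
    # promotes both operands to the result element type, and emits
    # chlo.broadcast_pow.
    exp = torch.tensor([1, 2, 3])
    print(torch.pow(2.0, exp))  # tensor([2., 4., 8.]) -- result is float

    # aten.atan2 promotes integer operands to float; this is the case the
    # ElementwiseAtan2*IntStatic modules above exercise, lowered via
    # chlo.broadcast_atan2.
    a = torch.tensor([1, 0], dtype=torch.int32)
    b = torch.tensor([1, 1], dtype=torch.int64)
    print(torch.atan2(a, b))  # tensor([0.7854, 0.0000])

Routing the binary cases through the chlo broadcast ops with a null
broadcast_dimensions attribute leaves rank and shape broadcasting to the CHLO
expansion, matching how the existing INSERT_BINARY_BROADCAST_PATTERN entries
in this file are handled.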