From cc28d566ff02904bde4855f3e2c8a124ecb6f4d6 Mon Sep 17 00:00:00 2001
From: Wu Yuan
Date: Mon, 20 May 2024 15:49:24 +0800
Subject: [PATCH] [Stablehlo] Support AtenTrilOp (#3359)

1. Lower aten.tril to StableHLO, composed of iota, select, and other ops.
2. Add related e2e test cases.
---
 lib/Conversion/TorchToStablehlo/Basic.cpp     | 73 +++++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  6 +-
 .../test_suite/elementwise.py                 | 69 ++++++++++++++++++
 test/Conversion/TorchToStablehlo/basic.mlir   | 22 ++++++
 4 files changed, 167 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 377795d843d9..792de89b8a53 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -2052,6 +2052,77 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
+    AtenTrilOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  Location loc = op.getLoc();
+
+  Value self = adaptor.getSelf();
+
+  auto selfTy = self.getType().cast<RankedTensorType>();
+  if (!selfTy.hasStaticShape()) {
+    return op->emitError("dynamic shaped input is not supported");
+  }
+
+  ArrayRef<int64_t> selfShape = selfTy.getShape();
+  int64_t selfRank = selfTy.getRank();
+  auto iotaElementTy = mlir::IntegerType::get(op.getContext(), 64);
+  auto iotaTy = RankedTensorType::get(
+      {selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy);
+  Value colIdxTensor =
+      rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 1).getResult();
+  Value rowIdxTensor =
+      rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 0).getResult();
+
+  Value diagonal = adaptor.getDiagonal();
+  Value diagonalTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, diagonal).getResult();
+
+  auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1});
+  Value shiftedRowIdxTensor = rewriter.create<chlo::BroadcastAddOp>(
+      loc, rowIdxTensor, diagonalTensor, bcastDimensions);
+
+  auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get(
+      rewriter.getContext(), stablehlo::ComparisonDirection::LE);
+  auto cmpTypeAttr = stablehlo::ComparisonTypeAttr::get(
+      rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
+  auto cmpTy = iotaTy.clone(rewriter.getI1Type());
+  Value cmpRes = rewriter.create<stablehlo::CompareOp>(
+      loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr,
+      cmpTypeAttr);
+
+  auto resTy =
+      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+
+  auto bcastTy = resTy.clone(rewriter.getI1Type());
+  auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
+  Value bcastedCmpRes = rewriter.create<stablehlo::BroadcastInDimOp>(
+      loc, bcastTy, cmpRes, bcastAttr);
+
+  auto resElemTy = resTy.getElementType();
+  Value zeroTensor;
+  if (resElemTy.isa<mlir::FloatType>()) {
+    auto constAttr = SplatElementsAttr::get(
+        resTy, llvm::APFloat::getZero(
+                   resElemTy.cast<mlir::FloatType>().getFloatSemantics(), false));
+    zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
+  } else if (resElemTy.isa<mlir::IntegerType>()) {
+    auto constAttr = SplatElementsAttr::get(
+        resTy,
+        llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
+    zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
+  } else {
+    return op.emitError("element type is not float or integer");
+  }
+
+  rewriter.replaceOpWithNewOp<stablehlo::SelectOp>(
+      op.getOperation(), resTy, bcastedCmpRes, self, zeroTensor);
+
+  return success();
+}
+
 void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options) {
@@ -2218,6 +2289,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
   INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
   INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
+
+  INSERT_ATENOP_PATTERN(AtenTrilOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp)                   \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index d5b682a22ec1..9d7cf7beb795 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -524,9 +524,6 @@
     "AtenSubFloatModule_basic",
     "AtenTopKModule_basic",
     "AtenTopKSmallestModule_basic",
-    "AtenTrilModule_basic",
-    "AtenTrilWithNegDiagonalModule_basic",
-    "AtenTrilWithPosDiagonalModule_basic",
     "Aten_EmbeddingBagExample_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "BernoulliTensorModule_basic",
@@ -867,6 +864,9 @@
     "AtenRoundIntModule_basic",
     "AtenSubFloatModule_basic",
     "AtenToDeviceModule_basic",
+    "AtenTrilStaticModule_basic",
+    "AtenTrilWithNegDiagonalStaticModule_basic",
+    "AtenTrilWithPosDiagonalStaticModule_basic",
     "Aten_CastFloatModule_basic",
     "Aten_CastLongModule_basic",
     "AvgPool1dStaticModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index f5e3c9fc4b9b..a7f27df555ba 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -5338,6 +5338,29 @@ def AtenTrilModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class AtenTrilStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([8, 8], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.tril(x)
+
+
+@register_test_case(module_factory=lambda: AtenTrilStaticModule())
+def AtenTrilStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 8))
+
+
+# ==============================================================================
+
+
 class AtenTrilWithPosDiagonalModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -5361,6 +5384,29 @@ def AtenTrilWithPosDiagonalModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class AtenTrilWithPosDiagonalStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([9, 4, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.tril(x, diagonal=2)
+
+
+@register_test_case(module_factory=lambda: AtenTrilWithPosDiagonalStaticModule())
+def AtenTrilWithPosDiagonalStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(9, 4, 3))
+
+
+# ==============================================================================
+
+
 class AtenTrilWithNegDiagonalModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -5384,6 +5430,29 @@ def AtenTrilWithNegDiagonalModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class AtenTrilWithNegDiagonalStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 1, 5, 9], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.tril(x, diagonal=-4)
+
+
+@register_test_case(module_factory=lambda: AtenTrilWithNegDiagonalStaticModule())
+def AtenTrilWithNegDiagonalStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 5, 9))
+
+
+# ==============================================================================
+
+
 class AtenRoundFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir
index 30f8716ebdf0..5dd685fedf30 100644
--- a/test/Conversion/TorchToStablehlo/basic.mlir
+++ b/test/Conversion/TorchToStablehlo/basic.mlir
@@ -319,3 +319,25 @@ func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si6
   %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
   return %0 : !torch.vtensor<[3,4],si64>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.tril(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[2,3,5],f32>,
+// CHECK-SAME: %[[ARG_1:.*]]: !torch.int) -> !torch.vtensor<[2,3,5],f32>
+// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[2,3,5],f32> -> tensor<2x3x5xf32>
+// CHECK: %[[VAL_1:.*]] = torch_c.to_i64 %[[ARG_1]]
+// CHECK: %[[VAL_2:.*]] = stablehlo.iota dim = 1 : tensor<3x5xi64>
+// CHECK: %[[VAL_3:.*]] = stablehlo.iota dim = 0 : tensor<3x5xi64>
+// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xi64>
+// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_3]], %[[VAL_4]] {broadcast_dimensions = array<i64: 1>} : (tensor<3x5xi64>, tensor<1xi64>) -> tensor<3x5xi64>
+// CHECK: %[[VAL_6:.*]] = stablehlo.compare LE, %[[VAL_2]], %[[VAL_5]], SIGNED : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_7:.*]] = stablehlo.broadcast_in_dim %[[VAL_6]], dims = [1, 2] : (tensor<3x5xi1>) -> tensor<2x3x5xi1>
+// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x3x5xf32>
+// CHECK: %[[VAL_9:.*]] = stablehlo.select %[[VAL_7]], %[[VAL_0]], %[[VAL_8]] : tensor<2x3x5xi1>, tensor<2x3x5xf32>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x3x5xf32> -> !torch.vtensor<[2,3,5],f32>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[2,3,5],f32>
+func.func @torch.aten.tril(%arg0: !torch.vtensor<[2,3,5],f32>, %arg1: !torch.int) -> !torch.vtensor<[2,3,5],f32> {
+  %0 = torch.aten.tril %arg0, %arg1 : !torch.vtensor<[2,3,5],f32>, !torch.int -> !torch.vtensor<[2,3,5],f32>
+  return %0 : !torch.vtensor<[2,3,5],f32>
+}
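
Note on the lowering scheme: for an input of shape [..., R, C] and diagonal
offset d, torch.tril keeps x[..., i, j] exactly when j <= i + d. The pattern
builds this mask arithmetically: the two stablehlo.iota ops give per-element
column and row indices, chlo.broadcast_add shifts the row indices by d,
stablehlo.compare LE produces the boolean mask, stablehlo.broadcast_in_dim
replays it across the leading batch dimensions, and stablehlo.select picks
between the input and a zero splat. A minimal NumPy sketch of the same
semantics follows; tril_reference and the sample shapes are illustrative
assumptions, not part of the patch:

    import numpy as np

    def tril_reference(x: np.ndarray, diagonal: int = 0) -> np.ndarray:
        # Illustrative reference for the lowering above, not patch code.
        rows, cols = x.shape[-2], x.shape[-1]
        # stablehlo.iota dim = 1 / dim = 0: per-element column and row indices.
        col_idx = np.broadcast_to(np.arange(cols), (rows, cols))
        row_idx = np.broadcast_to(np.arange(rows)[:, None], (rows, cols))
        # chlo.broadcast_add + stablehlo.compare LE: keep j <= i + diagonal.
        keep = col_idx <= row_idx + diagonal
        # stablehlo.broadcast_in_dim + stablehlo.select: the 2-D mask
        # broadcasts over leading batch dims; masked-off entries become zero.
        return np.where(keep, x, np.zeros((), dtype=x.dtype))

    # Sanity check against NumPy's tril on one of the e2e test shapes.
    x = np.random.rand(3, 1, 5, 9).astype(np.float32)
    assert np.array_equal(tril_reference(x, diagonal=-4), np.tril(x, k=-4))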