Skip to content

Commit

Permalink
[Stablehlo] Support AtenTrilOp (llvm#3359)
Browse files Browse the repository at this point in the history
1. lower aten.tril to stablehlo composed by iota, select and so forth
2. add related e2e test cases
  • Loading branch information
william0021224 authored May 20, 2024
1 parent 8814d0a commit cc28d56
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 3 deletions.
73 changes: 73 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,77 @@ LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
AtenTrilOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

Location loc = op.getLoc();

Value self = adaptor.getSelf();

auto selfTy = self.getType().cast<RankedTensorType>();
if (!selfTy.hasStaticShape()) {
return op->emitError("dynamic shaped input is not supported");
}

ArrayRef<int64_t> selfShape = selfTy.getShape();
int64_t selfRank = selfTy.getRank();
auto iotaElementTy = mlir::IntegerType::get(op.getContext(), 64);
auto iotaTy = RankedTensorType::get(
{selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy);
Value colIdxTensor =
rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 1).getResult();
Value rowIdxTensor =
rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 0).getResult();

Value diagonal = adaptor.getDiagonal();
Value diagonalTensor =
rewriter.create<tensor::FromElementsOp>(loc, diagonal).getResult();

auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1});
Value shiftedRowIdxTensor = rewriter.create<chlo::BroadcastAddOp>(
loc, rowIdxTensor, diagonalTensor, bcastDimensions);

auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::LE);
auto cmpTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
auto cmpTy = iotaTy.clone(rewriter.getI1Type());
Value cmpRes = rewriter.create<stablehlo::CompareOp>(
loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr,
cmpTypeAttr);

auto resTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

auto bcastTy = resTy.clone(rewriter.getI1Type());
auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
Value bcastedCmpRes = rewriter.create<stablehlo::BroadcastInDimOp>(
loc, bcastTy, cmpRes, bcastAttr);

auto resElemTy = resTy.getElementType();
Value zeroTensor;
if (resElemTy.isa<mlir::FloatType>()) {
auto constAttr = SplatElementsAttr::get(
resTy, llvm::APFloat::getZero(
resElemTy.cast<FloatType>().getFloatSemantics(), false));
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
} else if (resElemTy.isa<mlir::IntegerType>()) {
auto constAttr = SplatElementsAttr::get(
resTy,
llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
} else {
return op.emitError("element type is not float or integer");
}

rewriter.replaceOpWithNewOp<stablehlo::SelectOp>(
op.getOperation(), resTy, bcastedCmpRes, self, zeroTensor);

return success();
}

void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
Expand Down Expand Up @@ -2218,6 +2289,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);

INSERT_ATENOP_PATTERN(AtenTrilOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
Expand Down
6 changes: 3 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,6 @@
"AtenSubFloatModule_basic",
"AtenTopKModule_basic",
"AtenTopKSmallestModule_basic",
"AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic",
"Aten_EmbeddingBagExample_basic",
"AvgPool2dDivisorOverrideModule_basic",
"BernoulliTensorModule_basic",
Expand Down Expand Up @@ -867,6 +864,9 @@
"AtenRoundIntModule_basic",
"AtenSubFloatModule_basic",
"AtenToDeviceModule_basic",
"AtenTrilStaticModule_basic",
"AtenTrilWithNegDiagonalStaticModule_basic",
"AtenTrilWithPosDiagonalStaticModule_basic",
"Aten_CastFloatModule_basic",
"Aten_CastLongModule_basic",
"AvgPool1dStaticModule_basic",
Expand Down
69 changes: 69 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5338,6 +5338,29 @@ def AtenTrilModule_basic(module, tu: TestUtils):
# ==============================================================================


class AtenTrilStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([8, 8], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.tril(x)


@register_test_case(module_factory=lambda: AtenTrilStaticModule())
def AtenTrilStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(8, 8))


# ==============================================================================


class AtenTrilWithPosDiagonalModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -5361,6 +5384,29 @@ def AtenTrilWithPosDiagonalModule_basic(module, tu: TestUtils):
# ==============================================================================


class AtenTrilWithPosDiagonalStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([9, 4, 3], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.tril(x, diagonal=2)


@register_test_case(module_factory=lambda: AtenTrilWithPosDiagonalStaticModule())
def AtenTrilWithPosDiagonalStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(9, 4, 3))


# ==============================================================================


class AtenTrilWithNegDiagonalModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -5384,6 +5430,29 @@ def AtenTrilWithNegDiagonalModule_basic(module, tu: TestUtils):
# ==============================================================================


class AtenTrilWithNegDiagonalStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([3, 1, 5, 9], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.tril(x, diagonal=-4)


@register_test_case(module_factory=lambda: AtenTrilWithNegDiagonalStaticModule())
def AtenTrilWithNegDiagonalStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 5, 9))


# ==============================================================================


class AtenRoundFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
22 changes: 22 additions & 0 deletions test/Conversion/TorchToStablehlo/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,25 @@ func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si6
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
return %0 : !torch.vtensor<[3,4],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.tril(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[2,3,5],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.int) -> !torch.vtensor<[2,3,5],f32>
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[2,3,5],f32> -> tensor<2x3x5xf32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_i64 %[[ARG_1]]
// CHECK: %[[VAL_2:.*]] = stablehlo.iota dim = 1 : tensor<3x5xi64>
// CHECK: %[[VAL_3:.*]] = stablehlo.iota dim = 0 : tensor<3x5xi64>
// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_3]], %[[VAL_4]] {broadcast_dimensions = array<i64: 1>} : (tensor<3x5xi64>, tensor<1xi64>) -> tensor<3x5xi64>
// CHECK: %[[VAL_6:.*]] = stablehlo.compare LE, %[[VAL_2]], %[[VAL_5]], SIGNED : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
// CHECK: %[[VAL_7:.*]] = stablehlo.broadcast_in_dim %[[VAL_6]], dims = [1, 2] : (tensor<3x5xi1>) -> tensor<2x3x5xi1>
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x3x5xf32>
// CHECK: %[[VAL_9:.*]] = stablehlo.select %[[VAL_7]], %[[VAL_0]], %[[VAL_8]] : tensor<2x3x5xi1>, tensor<2x3x5xf32>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x3x5xf32> -> !torch.vtensor<[2,3,5],f32>
// CHECK: return %[[VAL_10:.*]] : !torch.vtensor<[2,3,5],f32>
func.func @torch.aten.tril(%arg0: !torch.vtensor<[2,3,5],f32>, %arg1: !torch.int) -> !torch.vtensor<[2,3,5],f32> {
%0 = torch.aten.tril %arg0, %arg1:!torch.vtensor<[2,3,5],f32>, !torch.int -> !torch.vtensor<[2,3,5],f32>
return %0 : !torch.vtensor<[2,3,5],f32>
}

0 comments on commit cc28d56

Please sign in to comment.