[Linalg] Fix some segfault
AmosLewis committed Jun 4, 2024
1 parent e371ef8 commit 89ee8b7
Showing 2 changed files with 57 additions and 42 deletions.
73 changes: 41 additions & 32 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -566,10 +566,13 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
Location loc = op->getLoc();
const TypeConverter *typeConverter = this->getTypeConverter();
Value self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType());

int64_t input_height =
cast<RankedTensorType>(self.getType()).getShape()[-2];
int64_t input_width = cast<RankedTensorType>(self.getType()).getShape()[-1];
const int64_t selfRank = selfType.getRank();
int64_t wDim = toPositiveDim(-1, selfRank);
int64_t hDim = toPositiveDim(-2, selfRank);
int64_t inputHeight = selfType.getShape()[hDim];
int64_t inputWidth = selfType.getShape()[wDim];
Type inputElementType =
cast<RankedTensorType>(self.getType()).getElementType();
Type resultType = typeConverter->convertType(op.getType());
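The segfault referenced in the commit title comes from the old getShape()[-2] / getShape()[-1] lines above: RankedTensorType::getShape() returns an ArrayRef<int64_t>, and indexing it with a negative value is out-of-bounds in C++ rather than Python-style wrap-around, so the read can crash. The replacement normalizes the dimension before indexing. A minimal stand-alone sketch of that pattern; the local helper below is only a stand-in for torch-mlir's toPositiveDim utility:

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for torch-mlir's toPositiveDim: map a Python-style negative dim
// to a non-negative index before touching the shape array.
static int64_t toPositiveDim(int64_t dim, int64_t rank) {
  return dim < 0 ? dim + rank : dim;
}

int main() {
  std::vector<int64_t> shape = {2, 2, 10, 20}; // NCHW, as in the test input
  int64_t rank = static_cast<int64_t>(shape.size());
  int64_t hDim = toPositiveDim(-2, rank); // 2
  int64_t wDim = toPositiveDim(-1, rank); // 3
  assert(shape[hDim] == 10 && shape[wDim] == 20);
  // shape[-2] here would be undefined behavior, unlike shape[-2] in Python.
  return 0;
}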
@@ -620,6 +623,12 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
/*indexingMaps=*/indexingMapsAvg,
/*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) {
// The algorithm for computing the divisor with
// count_include_pad is mainly based on the PyTorch
// implementation. The code below is annotated with
// the corresponding PyTorch statements in comments.
// https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
Value divisor;
Value indexOh =
b.create<linalg::IndexOp>(loc, /*value=*/dimH);
Value oh = castIndexToInt64(b, loc, indexOh);
@@ -632,34 +641,33 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
loc, rewriter.getI64IntegerAttr(strideInts[0]));
Value padH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[0]));
Value oh_dH = b.create<arith::MulIOp>(loc, oh, dH);
Value ih0 = b.create<arith::SubIOp>(loc, oh_dH, padH);
Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
// int64_t iw0 = ow * dW - padW;
Value dW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(strideInts[1]));
Value padW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[1]));
Value ow_dW = b.create<arith::MulIOp>(loc, ow, dW);
Value iw0 = b.create<arith::SubIOp>(loc, ow_dW, padW);
Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
// int64_t ih1 = std::min(ih0 + kH, input_height + padH);
Value ih = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(input_height));
Value ih0_kH =
loc, rewriter.getI64IntegerAttr(inputHeight));
Value ih0KH =
b.create<arith::AddIOp>(loc, ih0, kernelSizeIntValues[0]);
Value ih_padH = b.create<arith::AddIOp>(loc, ih, padH);
Value ih1 = b.create<arith::MinSIOp>(loc, ih0_kH, ih_padH);
Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
// int64_t iw1 = std::min(iw0 + kW, input_width + padW);
Value iw = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(input_width));
Value iw0_kW =
loc, rewriter.getI64IntegerAttr(inputWidth));
Value iw0KW =
b.create<arith::AddIOp>(loc, iw0, kernelSizeIntValues[1]);
Value iw_padW = b.create<arith::AddIOp>(loc, iw, padW);
Value iw1 = b.create<arith::MinSIOp>(loc, iw0_kW, iw_padW);
Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
// int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
Value ih1_ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
Value iw1_iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
Value poolSize =
b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
Value poolSize = b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
// ih0 = std::max(ih0, 0);
Value cstZero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(0));
@@ -672,8 +680,6 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
// iw1 = std::min(iw1, input_width);
Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);

Value divisor;
// if (divisor_override.has_value()) {
// divisor = divisor_override.value();
// } else {
@@ -683,23 +689,26 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
// divisor = (ih1 - ih0) * (iw1 - iw0);
// }
// }
if (countIncludePad) {
divisor = convertScalarToDtype(b, loc, poolSize,
resultElementType);
} else {
Value ih1_ih0 =
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
Value iw1_iw0 =
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
}
// AtenAvgPool2/3dOp has an optional divisor_override
// attribute while AtenAvgPool1dOp does not.
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
if (!isa<Torch::NoneType>(
op.getDivisorOverride().getType()))
divisor = op.getDivisorOverride();
} else {
if (countIncludePad) {
divisor = convertScalarToDtype(b, loc, poolSize,
resultElementType);
} else {
Value ih1_ih0 =
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
Value iw1_iw0 =
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
}
}

divisor =
convertScalarToDtype(b, loc, divisor, resultElementType);
Value avg;
if (isa<mlir::IntegerType>(resultElementType))
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
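The if constexpr guard in the last hunk exists because, as the new comment notes, AtenAvgPool2/3dOp carries an optional divisor_override operand while AtenAvgPool1dOp does not, so the getDivisorOverride() call may only be instantiated for an op type that actually has it. A small self-contained sketch of that C++ technique; the two structs are made-up stand-ins, not the torch-mlir op classes:

#include <iostream>
#include <type_traits>

struct Pool2dLike { int divisorOverride = 7; };   // has the extra attribute
struct Pool1dLike { /* no divisor_override */ };  // does not

template <typename OpTy>
int pickDivisor(const OpTy &op, int computedDivisor) {
  // The true branch is only compiled when OpTy is Pool2dLike, so
  // Pool1dLike never needs a divisorOverride member.
  if constexpr (std::is_same<OpTy, Pool2dLike>::value)
    return op.divisorOverride;
  else
    return computedDivisor;
}

int main() {
  std::cout << pickDivisor(Pool2dLike{}, 9) << "\n"; // prints 7
  std::cout << pickDivisor(Pool1dLike{}, 9) << "\n"; // prints 9
  return 0;
}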
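For reference, the divisor selection described by the pseudocode comment above (and assembled value-by-value inside the linalg.generic body) can be written as a plain scalar function following the PyTorch AvgPoolKernel linked in the new comment. This is an illustrative sketch, not the torch-mlir code: parameter names mirror the diff (dH/dW are the strides, padH/padW the padding), and the values in main() are made up, using the 3x3 kernel, stride 1, padding 1 configuration from the test below:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <optional>

int64_t avgPoolDivisor(int64_t oh, int64_t ow, int64_t kH, int64_t kW,
                       int64_t dH, int64_t dW, int64_t padH, int64_t padW,
                       int64_t inputHeight, int64_t inputWidth,
                       bool countIncludePad,
                       std::optional<int64_t> divisorOverride) {
  // Window bounds in input coordinates, padding included.
  int64_t ih0 = oh * dH - padH;
  int64_t iw0 = ow * dW - padW;
  int64_t ih1 = std::min(ih0 + kH, inputHeight + padH);
  int64_t iw1 = std::min(iw0 + kW, inputWidth + padW);
  int64_t poolSize = (ih1 - ih0) * (iw1 - iw0); // counts padded positions
  // Clamp to the real input extent.
  ih0 = std::max<int64_t>(ih0, 0);
  iw0 = std::max<int64_t>(iw0, 0);
  ih1 = std::min(ih1, inputHeight);
  iw1 = std::min(iw1, inputWidth);
  if (divisorOverride.has_value())
    return *divisorOverride;
  return countIncludePad ? poolSize : (ih1 - ih0) * (iw1 - iw0);
}

int main() {
  // Corner output (0, 0) of a 25x25 input, kernel 3x3, stride 1, padding 1:
  // only 4 real input elements overlap the window.
  long long d = avgPoolDivisor(0, 0, 3, 3, 1, 1, 1, 1, 25, 25,
                               /*countIncludePad=*/false, std::nullopt);
  std::printf("%lld\n", d); // prints 4 (9 if countIncludePad were true)
  return 0;
}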
26 changes: 16 additions & 10 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1016,22 +1016,27 @@ def forward(self, x):
def AvgPool2dStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 10, 20, low=-1))


class AvgPool2dFloatStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(kernel_size=[3, 3],
stride=[1, 1],
padding=[1, 1],
ceil_mode=False,
count_include_pad=False,
divisor_override=None)
self.ap2d = torch.nn.AvgPool2d(
kernel_size=[3, 3],
stride=[1, 1],
padding=[1, 1],
ceil_mode=False,
count_include_pad=False,
divisor_override=None,
)

@export
@annotate_args([
None,
([32, 384, 25, 25], torch.float32, True),
])
@annotate_args(
[
None,
([32, 384, 25, 25], torch.float32, True),
]
)
def forward(self, x):
return self.ap2d(x)

@@ -1040,6 +1045,7 @@ def forward(self, x):
def AvgPool2dFloatStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(32, 384, 25, 25, low=-1))


class AvgPool2dDivisorOverrideModule(torch.nn.Module):
def __init__(self):
super().__init__()
