From 76c07eeb4a4a9530dbd6413749695b3c383e1ab9 Mon Sep 17 00:00:00 2001
From: AmosLewis
Date: Fri, 31 May 2024 02:03:31 +0000
Subject: [PATCH] [Linalg] Fix some segfaults

---
 lib/Conversion/TorchToLinalg/Pooling.cpp      | 73 +++++++++++--------
 .../torch_mlir_e2e_test/test_suite/pooling.py | 28 ++++---
 2 files changed, 57 insertions(+), 44 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index ba33ebb74b71..041e05650bac 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -566,10 +566,13 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
     Location loc = op->getLoc();
     const TypeConverter *typeConverter = this->getTypeConverter();
     Value self = adaptor.getSelf();
+    auto selfType = cast<RankedTensorType>(self.getType());
 
-    int64_t input_height =
-        cast<RankedTensorType>(self.getType()).getShape()[-2];
-    int64_t input_width = cast<RankedTensorType>(self.getType()).getShape()[-1];
+    const int64_t selfRank = selfType.getRank();
+    int64_t wDim = toPositiveDim(-1, selfRank);
+    int64_t hDim = toPositiveDim(-2, selfRank);
+    int64_t inputHeight = selfType.getShape()[hDim];
+    int64_t inputWidth = selfType.getShape()[wDim];
     Type inputElementType =
         cast<RankedTensorType>(self.getType()).getElementType();
     Type resultType = typeConverter->convertType(op.getType());
@@ -620,6 +623,12 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
           /*indexingMaps=*/indexingMapsAvg,
           /*iteratorTypes=*/iteratorTypesAvg,
           [&](OpBuilder &b, Location loc, ValueRange args) {
+            // The algorithm for computing the divisor with
+            // count_include_pad is mainly based on the pytorch
+            // implementation. The following code is annotated with
+            // the corresponding pytorch code.
+            // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
+            Value divisor;
             Value indexOh =
                 b.create<linalg::IndexOp>(loc, /*value=*/dimH);
             Value oh = castIndexToInt64(b, loc, indexOh);
@@ -632,34 +641,33 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
                 loc, rewriter.getI64IntegerAttr(strideInts[0]));
             Value padH = rewriter.create<arith::ConstantOp>(
                 loc, rewriter.getI64IntegerAttr(paddingInts[0]));
-            Value oh_dH = b.create<arith::MulIOp>(loc, oh, dH);
-            Value ih0 = b.create<arith::SubIOp>(loc, oh_dH, padH);
+            Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
+            Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
             // int64_t iw0 = ow * dW - padW;
             Value dW = rewriter.create<arith::ConstantOp>(
                 loc, rewriter.getI64IntegerAttr(strideInts[1]));
             Value padW = rewriter.create<arith::ConstantOp>(
                 loc, rewriter.getI64IntegerAttr(paddingInts[1]));
-            Value ow_dW = b.create<arith::MulIOp>(loc, ow, dW);
-            Value iw0 = b.create<arith::SubIOp>(loc, ow_dW, padW);
+            Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
+            Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
             // int64_t ih1 = std::min(ih0 + kH, input_height + padH);
             Value ih = rewriter.create<arith::ConstantOp>(
-                loc, rewriter.getI64IntegerAttr(input_height));
-            Value ih0_kH =
+                loc, rewriter.getI64IntegerAttr(inputHeight));
+            Value ih0KH =
                 b.create<arith::AddIOp>(loc, ih0, kernelSizeIntValues[0]);
-            Value ih_padH = b.create<arith::AddIOp>(loc, ih, padH);
-            Value ih1 = b.create<arith::MinSIOp>(loc, ih0_kH, ih_padH);
+            Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
+            Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
             // int64_t iw1 = std::min(iw0 + kW, input_width + padW);
             Value iw = rewriter.create<arith::ConstantOp>(
-                loc, rewriter.getI64IntegerAttr(input_width));
-            Value iw0_kW =
+                loc, rewriter.getI64IntegerAttr(inputWidth));
+            Value iw0KW =
                 b.create<arith::AddIOp>(loc, iw0, kernelSizeIntValues[1]);
-            Value iw_padW = b.create<arith::AddIOp>(loc, iw, padW);
-            Value iw1 = b.create<arith::MinSIOp>(loc, iw0_kW, iw_padW);
+            Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
+            Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
             // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
-            Value ih1_ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
-            Value iw1_iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
-            Value poolSize =
-                b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
+            Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
+            Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
+            Value poolSize = b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
             // ih0 = std::max(ih0, 0);
             Value cstZero = rewriter.create<arith::ConstantOp>(
                 loc, rewriter.getI64IntegerAttr(0));
@@ -672,8 +680,6 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
             Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
             // iw1 = std::min(iw1, input_width);
             Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);
-
-            Value divisor;
             // if (divisor_override.has_value()) {
             //   divisor = divisor_override.value();
             // } else {
@@ -683,23 +689,26 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
             //     divisor = (ih1 - ih0) * (iw1 - iw0);
             //   }
             // }
+            if (countIncludePad) {
+              divisor = convertScalarToDtype(b, loc, poolSize,
+                                             resultElementType);
+            } else {
+              Value ih1_ih0 =
+                  b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
+              Value iw1_iw0 =
+                  b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
+              divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
+            }
+            // AtenAvgPool2/3dOp has an optional divisor_override
+            // attribute while AtenAvgPool1dOp does not.
             if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
               if (!isa<Torch::NoneType>(
                       op.getDivisorOverride().getType()))
                 divisor = op.getDivisorOverride();
-            } else {
-              if (countIncludePad) {
-                divisor = convertScalarToDtype(b, loc, poolSize,
-                                               resultElementType);
-              } else {
-                Value ih1_ih0 =
-                    b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
-                Value iw1_iw0 =
-                    b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
-                divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
-              }
             }
+            divisor =
+                convertScalarToDtype(b, loc, divisor, resultElementType);
 
             Value avg;
             if (isa<mlir::FloatType>(resultElementType))
               avg = b.create<arith::DivFOp>(loc, args[0], divisor);
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index 904beedfba1a..c0782573d685 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1020,18 +1020,22 @@ class AvgPool2dFloatStaticModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
-        self.ap2d = torch.nn.AvgPool2d(kernel_size=[3, 3],
-                                       stride=[1, 1],
-                                       padding=[1, 1],
-                                       ceil_mode=False,
-                                       count_include_pad=False,
-                                       divisor_override=None)
-
-    @export
-    @annotate_args([
-        None,
-        ([32, 384, 25, 25], torch.float32, True),
-    ])
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=[3, 3],
+            stride=[1, 1],
+            padding=[1, 1],
+            ceil_mode=False,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([32, 384, 25, 25], torch.float32, True),
+        ]
+    )
     def forward(self, x):
         return self.ap2d(x)
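
Note for reviewers: the divisor computation emitted into the linalg.generic body above mirrors the PyTorch CPU kernel linked in the code comment. A minimal Python sketch of the per-window rule, for reference only (the helper name avg_pool2d_divisor is made up here and is not part of this patch or of PyTorch):

def avg_pool2d_divisor(oh, ow, input_height, input_width, kH, kW, dH, dW,
                       padH, padW, count_include_pad, divisor_override=None):
    """Per-output-window divisor, following ATen's AvgPoolKernel.cpp logic."""
    # Window corners in input coordinates; may extend into the padding.
    ih0 = oh * dH - padH
    iw0 = ow * dW - padW
    ih1 = min(ih0 + kH, input_height + padH)
    iw1 = min(iw0 + kW, input_width + padW)
    # Window size including padded positions (count_include_pad=True case).
    pool_size = (ih1 - ih0) * (iw1 - iw0)
    # Clamp the window to the un-padded input (count_include_pad=False case).
    ih0, iw0 = max(ih0, 0), max(iw0, 0)
    ih1, iw1 = min(ih1, input_height), min(iw1, input_width)
    if divisor_override is not None:
        return divisor_override
    return pool_size if count_include_pad else (ih1 - ih0) * (iw1 - iw0)

With the settings used by AvgPool2dFloatStaticModule (25x25 spatial input, kernel 3x3, stride 1, padding 1, count_include_pad=False), the corner window at (oh, ow) = (0, 0) gets divisor (2 - 0) * (2 - 0) = 4 rather than the padded pool_size of 9, which is the case that e2e test exercises.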