From f1cce24c2f0b204c7b9c27ecc524a796d9d23b0c Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Fri, 31 May 2024 02:03:31 +0000 Subject: [PATCH] [Linalg] Fix some segfault --- lib/Conversion/TorchToLinalg/Pooling.cpp | 34 +++++++++++++----------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index ddbe52edca52..a6e198a137a9 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -568,10 +568,13 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { Location loc = op->getLoc(); const TypeConverter *typeConverter = this->getTypeConverter(); Value self = adaptor.getSelf(); - - int64_t input_height = - cast(self.getType()).getShape()[-2]; - int64_t input_width = cast(self.getType()).getShape()[-1]; + auto selfType = cast(self.getType()); + + unsigned numDims = selfType.getRank(); + int64_t wDim = numDims - 1; + int64_t hDim = numDims - 2; + int64_t input_height = selfType.getShape()[hDim]; + int64_t input_width = selfType.getShape()[wDim]; Type inputElementType = cast(self.getType()).getElementType(); Type resultType = typeConverter->convertType(op.getType()); @@ -685,23 +688,24 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { // divisor = (ih1 - ih0) * (iw1 - iw0); // } // } + if (countIncludePad) { + divisor = convertScalarToDtype(b, loc, poolSize, + resultElementType); + } else { + Value ih1_ih0 = + b.create(loc, ih1Clamped, ih0Clamped); + Value iw1_iw0 = + b.create(loc, iw1Clamped, iw0Clamped); + divisor = b.create(loc, ih1_ih0, iw1_iw0); + } + // AtenAvgPool2dOp has an optional divisor_override attribute while AtenAvgPool1dOp does not. if constexpr (std::is_same()) { if (!isa( op.getDivisorOverride().getType())) divisor = op.getDivisorOverride(); - } else { - if (countIncludePad) { - divisor = convertScalarToDtype(b, loc, poolSize, - resultElementType); - } else { - Value ih1_ih0 = - b.create(loc, ih1Clamped, ih0Clamped); - Value iw1_iw0 = - b.create(loc, iw1Clamped, iw0Clamped); - divisor = b.create(loc, ih1_ih0, iw1_iw0); - } } + divisor = convertScalarToDtype(b, loc, divisor, resultElementType); Value avg; if (isa(resultElementType)) avg = b.create(loc, args[0], divisor);