Skip to content

Commit

Permalink
[Linalg] Fix some segfault
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed May 31, 2024
1 parent 7bac2bd commit f1cce24
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,13 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
Location loc = op->getLoc();
const TypeConverter *typeConverter = this->getTypeConverter();
Value self = adaptor.getSelf();

int64_t input_height =
cast<RankedTensorType>(self.getType()).getShape()[-2];
int64_t input_width = cast<RankedTensorType>(self.getType()).getShape()[-1];
auto selfType = cast<RankedTensorType>(self.getType());

unsigned numDims = selfType.getRank();
int64_t wDim = numDims - 1;
int64_t hDim = numDims - 2;
int64_t input_height = selfType.getShape()[hDim];
int64_t input_width = selfType.getShape()[wDim];
Type inputElementType =
cast<RankedTensorType>(self.getType()).getElementType();
Type resultType = typeConverter->convertType(op.getType());
Expand Down Expand Up @@ -685,23 +688,24 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
// divisor = (ih1 - ih0) * (iw1 - iw0);
// }
// }
if (countIncludePad) {
divisor = convertScalarToDtype(b, loc, poolSize,
resultElementType);
} else {
Value ih1_ih0 =
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
Value iw1_iw0 =
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
}
// AtenAvgPool2dOp has an optional divisor_override attribute while AtenAvgPool1dOp does not.
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
if (!isa<Torch::NoneType>(
op.getDivisorOverride().getType()))
divisor = op.getDivisorOverride();
} else {
if (countIncludePad) {
divisor = convertScalarToDtype(b, loc, poolSize,
resultElementType);
} else {
Value ih1_ih0 =
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
Value iw1_iw0 =
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
}
}

divisor = convertScalarToDtype(b, loc, divisor, resultElementType);
Value avg;
if (isa<mlir::IntegerType>(resultElementType))
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
Expand Down

0 comments on commit f1cce24

Please sign in to comment.