[Linalg] Add countIncludePad support for averagepool
AmosLewis committed May 2, 2024
1 parent eb3b62f commit 9250bff
Showing 1 changed file with 74 additions and 18 deletions.
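This commit removes the match failure that previously rejected `count_include_pad=false` whenever the padding was nonzero. The divisor for the average is now computed per output element inside the `linalg.generic` body: the pooling window bounds are derived from the output indices, and the sum is divided by `divisor_override` when provided, by the full (padding-inclusive) window size when `count_include_pad` is true, or by the window's overlap with the unpadded input otherwise. A standalone sketch of this divisor logic follows the diff.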
92 changes: 74 additions & 18 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -551,6 +551,8 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
const TypeConverter *typeConverter = this->getTypeConverter();
Value self = adaptor.getSelf();

// Input spatial sizes, assumed static (NCHW layout). Index from the back
// explicitly; negative ArrayRef indexing is invalid.
ArrayRef<int64_t> inputShape = cast<RankedTensorType>(self.getType()).getShape();
int64_t inputHeight = inputShape[inputShape.size() - 2];
int64_t inputWidth = inputShape[inputShape.size() - 1];
Type inputElementType =
cast<RankedTensorType>(self.getType()).getElementType();
Type resultType = typeConverter->convertType(op.getType());
@@ -572,13 +574,6 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
return rewriter.notifyMatchFailure(
op, "count_include_pad must be a constant");

// If the padding is zero then there is no padding to include.
if (!countIncludePad &&
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
return rewriter.notifyMatchFailure(
op, "unimplemented: count_include_pad is expected to be true");
}

// `sumPool` contains the result of sumpool operation over the input.
Value sumPool, paddedInput;
SmallVector<Value, Dim + 2> outTensorShape;
@@ -588,17 +583,12 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType),
outTensorShape, paddedInput, sumPool)))
return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
Value divisor;
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
Value kHtimeskW = rewriter.create<arith::MulIOp>(
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
? kHtimeskW
: adaptor.getDivisorOverride();
} else {
divisor = kernelSizeIntValues[0];
}
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);

RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
// get rank of input (same as rank of output)
const int64_t rank = sumPoolType.getRank();
int dimH = toPositiveDim(-2, rank);
int dimW = toPositiveDim(-1, rank);

Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outTensorShape), resultElementType);
@@ -613,6 +603,72 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
/*indexingMaps=*/indexingMapsAvg,
/*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value indexOh = b.create<linalg::IndexOp>(loc, /*value=*/dimH);
Value oh = castIndexToInt64(b, loc, indexOh);
Value indexOw = b.create<linalg::IndexOp>(loc, /*value=*/dimW);
Value ow = castIndexToInt64(b, loc, indexOw);

// Materialize the strides, paddings, and static input sizes as i64
// constants; arith ops take Values, not attributes.
Value strideH = b.create<arith::ConstantIntOp>(loc, strideInts[0], /*width=*/64);
Value strideW = b.create<arith::ConstantIntOp>(loc, strideInts[1], /*width=*/64);
Value padH = b.create<arith::ConstantIntOp>(loc, paddingInts[0], /*width=*/64);
Value padW = b.create<arith::ConstantIntOp>(loc, paddingInts[1], /*width=*/64);
Value inputHeightConst = b.create<arith::ConstantIntOp>(loc, inputHeight, /*width=*/64);
Value inputWidthConst = b.create<arith::ConstantIntOp>(loc, inputWidth, /*width=*/64);
Value zero = b.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);

// int64_t ih0 = oh * dH - padH;
Value ih0 = b.create<arith::SubIOp>(
    loc, b.create<arith::MulIOp>(loc, oh, strideH), padH);
// int64_t iw0 = ow * dW - padW;
Value iw0 = b.create<arith::SubIOp>(
    loc, b.create<arith::MulIOp>(loc, ow, strideW), padW);
// int64_t ih1 = std::min(ih0 + kH, input_height + padH);
Value ih1 = b.create<arith::MinSIOp>(
    loc, b.create<arith::AddIOp>(loc, ih0, kernelSizeIntValues[0]),
    b.create<arith::AddIOp>(loc, inputHeightConst, padH));
// int64_t iw1 = std::min(iw0 + kW, input_width + padW);
Value iw1 = b.create<arith::MinSIOp>(
    loc, b.create<arith::AddIOp>(loc, iw0, kernelSizeIntValues[1]),
    b.create<arith::AddIOp>(loc, inputWidthConst, padW));
// int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
Value poolSize = b.create<arith::MulIOp>(
    loc, b.create<arith::SubIOp>(loc, ih1, ih0),
    b.create<arith::SubIOp>(loc, iw1, iw0));
// ih0 = std::max(ih0, 0);
Value ih0Clamped = b.create<arith::MaxSIOp>(loc, ih0, zero);
// iw0 = std::max(iw0, 0);
Value iw0Clamped = b.create<arith::MaxSIOp>(loc, iw0, zero);
// ih1 = std::min(ih1, input_height);
Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, inputHeightConst);
// iw1 = std::min(iw1, input_width);
Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, inputWidthConst);

// Value divisor;
// if (divisor_override.has_value()) {
// divisor = divisor_override.value();
// } else {
// if(count_include_pad) {
// divisor = pool_size;
// } else {
// divisor = (ih1 - ih0) * (iw1 - iw0);
// }
// }
Value divisor;
if (!isa<Torch::NoneType>(op.getDivisorOverride().getType())) {
  divisor = convertScalarToDtype(b, loc, adaptor.getDivisorOverride(),
                                 resultElementType);
} else if (countIncludePad) {
  divisor = convertScalarToDtype(b, loc, poolSize, resultElementType);
} else {
  // Count only the window elements that overlap the unpadded input.
  Value validArea = b.create<arith::MulIOp>(
      loc, b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped),
      b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped));
  divisor = convertScalarToDtype(b, loc, validArea, resultElementType);
}

Value avg;
if (isa<mlir::IntegerType>(resultElementType))
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
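For reference, the divisor selection mirrors the PyTorch pseudocode quoted in the comment inside the last hunk. Below is a minimal, dependency-free C++ sketch of that logic; the function name and the `main` driver are illustrative only, not part of torch-mlir.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>

// Per-output-element divisor for avg_pool2d, mirroring the pseudocode
// comment in the diff. Names here are illustrative, not torch-mlir APIs.
int64_t avgPoolDivisor(int64_t oh, int64_t ow, int64_t kH, int64_t kW,
                       int64_t dH, int64_t dW, int64_t padH, int64_t padW,
                       int64_t inputHeight, int64_t inputWidth,
                       bool countIncludePad,
                       std::optional<int64_t> divisorOverride) {
  if (divisorOverride)
    return *divisorOverride;
  // Window bounds in input coordinates, padding included.
  int64_t ih0 = oh * dH - padH;
  int64_t iw0 = ow * dW - padW;
  int64_t ih1 = std::min(ih0 + kH, inputHeight + padH);
  int64_t iw1 = std::min(iw0 + kW, inputWidth + padW);
  // Window size with padding counted.
  int64_t poolSize = (ih1 - ih0) * (iw1 - iw0);
  if (countIncludePad)
    return poolSize;
  // Otherwise clamp to the valid input region and count only real elements.
  ih0 = std::max(ih0, int64_t(0));
  iw0 = std::max(iw0, int64_t(0));
  ih1 = std::min(ih1, inputHeight);
  iw1 = std::min(iw1, inputWidth);
  return (ih1 - ih0) * (iw1 - iw0);
}

int main() {
  // 4x4 input, 3x3 kernel, stride 1, padding 1: the corner output (0, 0)
  // covers a 3x3 window of which only 2x2 elements are real input.
  std::cout << avgPoolDivisor(0, 0, 3, 3, 1, 1, 1, 1, 4, 4,
                              /*countIncludePad=*/true, std::nullopt)  // 9
            << " "
            << avgPoolDivisor(0, 0, 3, 3, 1, 1, 1, 1, 4, 4,
                              /*countIncludePad=*/false, std::nullopt) // 4
            << "\n";
}
```

Because `sumPool` is computed over the zero-padded input, dividing by the clamped valid area is exactly what `count_include_pad=false` requires: padded zeros contribute nothing to the sum and are excluded from the count.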
