[MLIR][Torch] Add TorchToLinalg lowering for AtenAvgPool3dOp (llvm#3030)
This commit also fixes the average pool op's test failing for the OnnxToLinalg lowering.

Signed-Off By: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 authored and josel-amd committed Jun 7, 2024
1 parent 40b292c commit 7bfc292
Showing 8 changed files with 380 additions and 63 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
@@ -97,6 +97,10 @@ getBackendTypeForScalarType(MLIRContext *context,

 bool isUnsignedTorchType(Type type);
 
+LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
+                            Location loc, SmallVector<int64_t> dimensions,
+                            Value input, Value &result);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
55 changes: 7 additions & 48 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1457,56 +1457,15 @@ class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
       return rewriter.notifyMatchFailure(op, "all dimensions must be constant");
 
     Value inVector = adaptor.getSelf();
-    auto inType = inVector.getType().cast<RankedTensorType>();
-    int64_t inputRank = inType.getRank();
-    auto outType = getTypeConverter()
-                       ->convertType(op->getResult(0).getType())
-                       .cast<RankedTensorType>();
-    Type elementType = inType.getElementType();
 
-    // Check if the dimensions are a valid constants.
-    int64_t numDimensions = dimensions.size();
-    if (inputRank != numDimensions)
+    Value result;
+    if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(),
+                                              dimensions, inVector, result)))
       return rewriter.notifyMatchFailure(
-          op, "size of `dims` must be equal to the rank of the input");
-    for (unsigned i = 0; i < numDimensions; i++) {
-      if (dimensions[i] < 0)
-        dimensions[i] = toPositiveDim(dimensions[i], inputRank);
-      if (!isValidDim(dimensions[i], inputRank))
-        return rewriter.notifyMatchFailure(op, "dimension out of range");
-    }
-
-    Location loc = op.getLoc();
-
-    SmallVector<Value> outputDims;
-    for (unsigned i = 0; i < inputRank; i++)
-      outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i]));
+          op, "failed to perform permutation of tensor");
 
-    Value outVector = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(outputDims), elementType);
-    SmallVector<AffineExpr> idExprs;
-    SmallVector<AffineExpr> swapExprs;
-    for (unsigned i = 0; i < inputRank; i++)
-      idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
-    for (unsigned i = 0; i < inputRank; i++)
-      swapExprs.push_back(idExprs[dimensions[i]]);
-
-    AffineMap inputMap =
-        AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext());
-    AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0,
-                                         swapExprs, op->getContext());
-    SmallVector<AffineMap> indexingMaps{inputMap, outputMap};
-    SmallVector<utils::IteratorType> iteratorTypes(
-        inputRank, utils::IteratorType::parallel);
-    auto transpose = rewriter
-                         .create<linalg::GenericOp>(
-                             loc, outVector.getType(), inVector, outVector,
-                             indexingMaps, iteratorTypes,
-                             [](OpBuilder &b, Location loc, ValueRange args) {
-                               b.create<linalg::YieldOp>(loc, args[0]);
-                             })
-                         .getResult(0);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
+    auto outType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, result);
     return success();
   }
 };
Expand Down
65 changes: 50 additions & 15 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -173,11 +173,42 @@ static LogicalResult createPoolingOp(
   Value windowTensor = rewriter.create<tensor::EmptyOp>(
       loc, getAsOpFoldResult(shape), elementType);
 
-  result = rewriter
-               .create<OpTy>(loc, outTensorInitialized.getType(),
-                             ValueRange{paddedInput, windowTensor},
-                             outTensorInitialized, stridesAttr, dilationAttr)
-               .getResult(0);
+  Value permutedInput = paddedInput, permutedOutput = outTensorInitialized;
+  if (dimensionality == 3) {
+    // Permute input and output tensor as follows:
+    // (n,c,d,h,w) -> (n,d,h,w,c)
+    SmallVector<int64_t> dimensions = {0, 2, 3, 4, 1};
+    if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(),
+                                              dimensions, paddedInput,
+                                              permutedInput)))
+      return rewriter.notifyMatchFailure(
+          op, "failed to perform permutation of tensor");
+
+    if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(),
+                                              dimensions, outTensorInitialized,
+                                              permutedOutput)))
+      return rewriter.notifyMatchFailure(
+          op, "failed to perform permutation of tensor");
+  }
+
+  Value poolingResult =
+      rewriter
+          .create<OpTy>(loc, permutedOutput.getType(),
+                        ValueRange{permutedInput, windowTensor}, permutedOutput,
+                        stridesAttr, dilationAttr)
+          .getResult(0);
+
+  result = poolingResult;
+  if (dimensionality == 3) {
+    // Permute output tensor as follows:
+    // (n,d,h,w,c) -> (n,c,d,h,w)
+    SmallVector<int64_t> dimensions = {0, 4, 1, 2, 3};
+    if (failed(torch_to_linalg::permuteTensor(
+            op, rewriter, op->getLoc(), dimensions, poolingResult, result)))
+      return rewriter.notifyMatchFailure(
+          op, "failed to perform permutation of tensor");
+  }
 
   return success();
 }
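Note: linalg::PoolingNdhwcSumOp expects a channels-last (n,d,h,w,c) layout, while the incoming torch tensor is (n,c,d,h,w); hence the permute, pool, permute-back sequence above. As a standalone illustration (plain C++, not part of this commit), the two permutation vectors used here are inverses, so the round trip restores the original layout:

#include <array>
#include <cassert>

int main() {
  // Convention from permuteTensor: output axis i reads input axis perm[i].
  std::array<int, 5> toNdhwc = {0, 2, 3, 4, 1}; // (n,c,d,h,w) -> (n,d,h,w,c)
  std::array<int, 5> toNcdhw = {0, 4, 1, 2, 3}; // (n,d,h,w,c) -> (n,c,d,h,w)
  std::array<int, 5> axes = {0, 1, 2, 3, 4};

  std::array<int, 5> once{}, twice{};
  for (int i = 0; i < 5; i++)
    once[i] = axes[toNdhwc[i]];  // permute NCDHW -> NDHWC
  for (int i = 0; i < 5; i++)
    twice[i] = once[toNcdhw[i]]; // permute back NDHWC -> NCDHW
  assert(twice == axes);         // round trip is the identity
  return 0;
}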

@@ -588,16 +619,17 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
             paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType),
             outTensorShape, paddedInput, sumPool)))
       return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
-    Value divisor;
-    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
-      Value kHtimeskW = rewriter.create<arith::MulIOp>(
-          loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
+    // }
+
+    Value divisor = kernelSizeIntValues[0];
+    for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) {
       divisor =
-          op.getDivisorOverride().getType().template isa<Torch::NoneType>()
-              ? kHtimeskW
-              : adaptor.getDivisorOverride();
-    } else {
-      divisor = kernelSizeIntValues[0];
+          rewriter.create<arith::MulIOp>(loc, divisor, kernelSizeIntValues[i]);
     }
+    if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
+      divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
+                    ? divisor
+                    : adaptor.getDivisorOverride();
+    }
     divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
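Note: the divisor logic above generalizes the old 2-D-only kH*kW product to a product over every kernel dimension, so one code path serves the 1-D, 2-D, and 3-D average-pool ops; divisor_override is consulted only for ops that carry that operand (AtenAvgPool1dOp does not). A minimal sketch of the resulting arithmetic, with a hypothetical helper name, not from this commit:

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper mirroring the lowering's divisor computation.
int64_t avgPoolDivisor(const std::vector<int64_t> &kernelSize,
                       std::optional<int64_t> divisorOverride) {
  int64_t divisor = kernelSize[0];
  for (std::size_t i = 1; i < kernelSize.size(); i++)
    divisor *= kernelSize[i];               // product over all kernel dims
  return divisorOverride.value_or(divisor); // override wins when present
}
// e.g. avgPoolDivisor({2, 2, 2}, std::nullopt) == 8 for a 2x2x2 window.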

@@ -1098,13 +1130,16 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
 
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
-  target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();
+  target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
       .add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
           typeConverter, context);
   patterns
       .add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, linalg::PoolingNchwSumOp, 2>>(
           typeConverter, context);
+  patterns
+      .add<ConvertAtenAvgPoolOp<AtenAvgPool3dOp, linalg::PoolingNdhwcSumOp, 3>>(
+          typeConverter, context);
   target.addIllegalOp<AtenAdaptiveAvgPool1dOp, AtenAdaptiveAvgPool2dOp,
                       AtenAdaptiveAvgPool3dOp, Aten_AdaptiveAvgPool3dOp>();
   patterns.add<ConvertAtenAdaptivePoolOp<AtenAdaptiveAvgPool1dOp>>(
52 changes: 52 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -576,3 +576,55 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) {
   llvm_unreachable("Unknown type checked for signedness");
   return false;
 }
+
+LogicalResult torch_to_linalg::permuteTensor(Operation *op,
+                                             PatternRewriter &rewriter,
+                                             Location loc,
+                                             SmallVector<int64_t> dimensions,
+                                             Value input, Value &result) {
+  auto inType = cast<RankedTensorType>(input.getType());
+  int64_t inputRank = inType.getRank();
+  Type elementType = inType.getElementType();
+
+  // Check if the dimensions are a valid constants.
+  int64_t numDimensions = dimensions.size();
+  if (inputRank != numDimensions)
+    return rewriter.notifyMatchFailure(
+        op, "size of `dims` must be equal to the rank of the input");
+  for (uint32_t i = 0; i < numDimensions; i++) {
+    if (dimensions[i] < 0)
+      dimensions[i] = toPositiveDim(dimensions[i], inputRank);
+    if (!isValidDim(dimensions[i], inputRank))
+      return rewriter.notifyMatchFailure(op, "dimension out of range");
+  }
+
+  SmallVector<Value> outputDims;
+  for (uint32_t i = 0; i < inputRank; i++)
+    outputDims.push_back(getDimOp(rewriter, loc, input, dimensions[i]));
+
+  Value outVector = rewriter.create<tensor::EmptyOp>(
+      loc, getAsOpFoldResult(outputDims), elementType);
+  SmallVector<AffineExpr> idExprs;
+  SmallVector<AffineExpr> swapExprs;
+  for (uint32_t i = 0; i < inputRank; i++)
+    idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
+  for (uint32_t i = 0; i < inputRank; i++)
+    swapExprs.push_back(idExprs[dimensions[i]]);
+
+  AffineMap inputMap =
+      AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext());
+  AffineMap outputMap =
+      AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs, op->getContext());
+  SmallVector<AffineMap> indexingMaps{inputMap, outputMap};
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  result = rewriter
+               .create<linalg::GenericOp>(
+                   loc, outVector.getType(), input, outVector, indexingMaps,
+                   iteratorTypes,
+                   [](OpBuilder &b, Location loc, ValueRange args) {
+                     b.create<linalg::YieldOp>(loc, args[0]);
+                   })
+               .getResult(0);
+  return success();
+}
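Note: the linalg.generic built by permuteTensor reads the input through the identity map and writes through the swapped map, so output dimension i takes its extent and data from input dimension dimensions[i]. A standalone sketch of that shape rule (plain C++, hypothetical helper, not from this commit):

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: the shape rule implemented by the linalg.generic above.
std::vector<int64_t> permuteShape(const std::vector<int64_t> &inShape,
                                  const std::vector<int64_t> &dimensions) {
  std::vector<int64_t> outShape(dimensions.size());
  for (std::size_t i = 0; i < dimensions.size(); i++)
    outShape[i] = inShape[dimensions[i]]; // out dim i <- in dim dimensions[i]
  return outShape;
}
// permuteShape({2, 16, 8, 8, 8}, {0, 2, 3, 4, 1}) yields {2, 8, 8, 8, 16}:
// an NCDHW tensor becomes NDHWC, as required by linalg::PoolingNdhwcSumOp.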