Skip to content

Commit

Permalink
Add RestructureNonConstantAxes pass to address reduce op tests failin…
Browse files Browse the repository at this point in the history
…g on non constant axes (llvm#3600)
  • Loading branch information
renxida authored Aug 26, 2024
1 parent 638ef14 commit eb7bf78
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ StringRef getAbstractInterpLibrary();

static const char kTorchOpPrefix[] = R"(torch.)";

void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
MLIRContext *context);

std::unique_ptr<OperationPass<func::FuncOp>>
createRestructureNonConstantAxesPass();

} // namespace Torch

/// Registers all Torch transformation passes.
Expand Down
20 changes: 20 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,24 @@ def VerifyBackendContractNoDecompositions
}];
}

def RestructureNonConstantAxes
: Pass<"torch-restructure-non-constant-axes", "func::FuncOp"> {
let summary = "Ensure that every Reduction.cpp op has a constant reduction axis.";
let constructor = [{
mlir::torch::Torch::createRestructureNonConstantAxesPass()
}];
let description = [{
This pass ensures that every Reduction.cpp op has a constant reduction axis.

It does so using reshapes. For example, a <1,2,3,4,5> tensor will be reshaped to a <?,?,?> tensor
and reduced on axis 1 to produce a <?,1,?> tensor. The resulting tensor will be reshaped back to the original shape.

Then when the axis is supplied at runtime (say axis = -2), the shapes will be computed as so:
<?,?,?> becomes <6,4,5>
which gets reduced to <6,1,5>
and rehsaped back to the original reduction op's output shape,
<1,2,3,1,5>
}];
}

#endif // TORCHMLIR_TORCH_PASSES
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_mlir_library(TorchMLIRTorchPasses
ReifyShapeCalculations.cpp
ReifyDtypeCalculations.cpp
ReifyAbstractInterpCalculationsUtils.cpp
RestructureNonConstantAxes.cpp
ScalarizeShapes.cpp
AbstractInterpLibrary.cpp
SimplifyShapeCalculations.cpp
Expand Down
277 changes: 277 additions & 0 deletions lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
//===- RestructureNonConstantAxes.cpp --------------------------------*-
// C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "torch-lower-to-backend-contract"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

template <typename SrcOp>
class ConstantifyDimArgument : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

bool isDimConstant(SrcOp op) const {
SmallVector<int64_t> dimList;
int64_t dim;
return matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList)) ||
matchPattern(op.getDim(), m_TorchConstantInt(&dim));
}

/*
This function renders the reduction dim constant by reshaping the input tensor
such that the dim argument is the middle dimension.
For example, if the input tensor has shape [3,4,5,6,7] and the dim argument is
-2, the input tensor is reshaped to [3,4,5,6,7] -> [12,5,42], the reduction
operation is applied, and the result is reshaped back to [3,4,1,6,7].
Since we don't know the dim argument at compile time, we need to compute the
arguments to the reshape op at runtime. We do this by computing the new shape
of the tensor by multiplying the shapes of the tensor before and after the dim
argument, and then reshaping the tensor to this new shape.
*/
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();

Value self = op.getSelf();
Value dim = op.getDim();

if (isDimConstant(op)) {
return rewriter.notifyMatchFailure(op,
"dim argument is already constant");
}

if (isa<Torch::NoneType>(dim.getType())) {
return rewriter.notifyMatchFailure(
op, "RestructureNonConstantAxes does not support None dim");
}

// when keepdim is not constant, check the ranks of the input and output
// tensors
ValueTensorType selfTy =
llvm::cast<ValueTensorType>(op.getSelf().getType());
ValueTensorType resultTy =
llvm::cast<ValueTensorType>(op.getResult().getType());
if (selfTy.hasSizes() && resultTy.hasSizes() &&
selfTy.getSizes().size() != resultTy.getSizes().size()) {
return rewriter.notifyMatchFailure(
op,
"RestructureNonConstantAxes does not yet support keepdim=false, but "
"the input and output tensors have different ranks");
}

Type intType = rewriter.getType<Torch::IntType>();
Type boolType = rewriter.getType<Torch::BoolType>();
auto createInt = [&](int value) {
return rewriter.create<Torch::ConstantIntOp>(
loc, intType,
rewriter.getIntegerAttr(rewriter.getIntegerType(64), value));
};
Value zero = createInt(0);
Value one = createInt(1);

// handle when dim is a single element list
bool oldDimIsList = isa<Torch::ListType>(dim.getType());
if (oldDimIsList) {
Value len = rewriter.create<Torch::AtenLenTOp>(loc, intType, dim);
Value dimListIsLengthOne =
rewriter.create<Torch::AtenEqIntOp>(loc, boolType, len, one);
rewriter.create<Torch::RuntimeAssertOp>(
loc, dimListIsLengthOne,
rewriter.getStringAttr("RestructureNonConstantAxes does not support "
"dim lists with more than one element"));
dim = rewriter.create<Torch::Aten__Getitem__TOp>(loc, intType, dim, zero);
}

// Normalize negative dim
Value rank = rewriter.create<Torch::AtenDimOp>(loc, intType, self);
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(loc, dim, zero);
Value rankOffset = rewriter.create<Torch::AtenMulIntOp>(
loc, intType,
rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isNegative), rank);
dim = rewriter.create<Torch::AtenAddIntOp>(loc, intType, dim, rankOffset);

auto createConditionalMult = [&](Value self, Value multiplier,
Value condition) {
// compute:
// result = codition ? (self * multiplier) : self
// via
// result = self * (1 + (multiplier - 1) * condition)
// which translates to:

// result = multiplier - 1
Value result = rewriter.create<Torch::AtenSubIntOp>(
loc, intType, multiplier, createInt(1));
// result = result * condition
result =
rewriter.create<Torch::AtenMulIntOp>(loc, intType, result, condition);
// result = result + 1
result = rewriter.create<Torch::AtenAddIntOp>(loc, intType, result,
createInt(1));
// result = self * result
result = rewriter.create<Torch::AtenMulIntOp>(loc, intType, self, result);
return result;
};

// new shape = [beforeDim, dimSize, afterDim]
Value beforeProd = createInt(1);
Value afterProd = createInt(1);
Value dimSize = createInt(1);

for (size_t i = 0; i < selfTy.getSizes().size(); ++i) {
Value idx = createInt(i);
Value size =
rewriter.create<Torch::AtenSizeIntOp>(loc, intType, self, idx);

Value isBeforeDim =
rewriter.create<Torch::AtenLtIntOp>(loc, boolType, idx, dim);
isBeforeDim =
rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isBeforeDim);
Value isAfterDim =
rewriter.create<Torch::AtenGtIntOp>(loc, boolType, idx, dim);
isAfterDim =
rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isAfterDim);

Value isEqualToDim =
rewriter.create<Torch::AtenEqIntOp>(loc, boolType, idx, dim);
isEqualToDim =
rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isEqualToDim);
dimSize = createConditionalMult(dimSize, size, isEqualToDim);

beforeProd = createConditionalMult(beforeProd, size, isBeforeDim);
afterProd = createConditionalMult(afterProd, size, isAfterDim);
}

Value newShape = rewriter.create<Torch::PrimListConstructOp>(
loc, rewriter.getType<Torch::ListType>(intType),
ValueRange{beforeProd, dimSize, afterProd});

// Reshape input
auto newSelfTy = selfTy.getWithSizesAndDtype(
SmallVector<int64_t>{Torch::kUnknownSize, Torch::kUnknownSize,
Torch::kUnknownSize},
selfTy.getDtype());
Value reshapedSelf =
rewriter.create<Torch::AtenViewOp>(loc, newSelfTy, self, newShape);

// construct new operange range where self is replaced with reshapedSelf
// tensor, and dim is replaced with 1
Value newDim;
if (oldDimIsList) {
newDim = rewriter.create<Torch::PrimListConstructOp>(
loc, rewriter.getType<Torch::ListType>(intType), ValueRange{one});
} else {
newDim = one;
}
ValueRange oldOperands = op->getOperands();
SmallVector<Value> newOperandsVect;
for (size_t i = 0; i < oldOperands.size(); ++i) {
if (oldOperands[i] == op.getSelf()) {
newOperandsVect.push_back(reshapedSelf);
} else if (oldOperands[i] == op.getDim()) {
newOperandsVect.push_back(newDim);
} else {
newOperandsVect.push_back(oldOperands[i]);
}
}
ValueRange newOperands = ValueRange(newOperandsVect);

// construct new reduction op result type
ValueTensorType newResultTy =
cast<ValueTensorType>(resultTy.getWithSizesAndDtype(
SmallVector<int64_t>{Torch::kUnknownSize, 1, Torch::kUnknownSize},
resultTy.getDtype()));

Value newReductionOp =
rewriter.create<SrcOp>(loc, newResultTy, newOperands, op->getAttrs());

// Reshape the result back to original shape
ValueTensorType oldResultTy =
cast<ValueTensorType>(op.getResult().getType());
SmallVector<Value> shapeValues;
for (auto dim : oldResultTy.getSizes()) {
shapeValues.push_back(createInt(dim));
}
Value originalShape = rewriter.create<Torch::PrimListConstructOp>(
loc, rewriter.getType<Torch::ListType>(intType), shapeValues);
Value result = rewriter.create<Torch::AtenViewOp>(
loc, op->getResult(0).getType(), newReductionOp, originalShape);

rewriter.replaceOp(op, result);
return success();
};
};

template <typename... OpTypes>
void addConstantifyDimArgumentPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// simple variadic template to sugar up adding the patterns
(patterns.add<ConstantifyDimArgument<OpTypes>>(context), ...);
}

void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
MLIRContext *context) {
// these are the reduction ops with a dim argument

addConstantifyDimArgumentPatterns<
// not supported because they have multiple results
// AtenMaxDimOp,
// AtenMinDimOp,
AtenSumDimIntListOp, AtenAllDimOp, AtenLinalgVectorNormOp,
AtenFrobeniusNormDimOp>(patterns, context);
}

class RestructureNonConstantAxesPass
: public RestructureNonConstantAxesBase<RestructureNonConstantAxesPass> {
public:
RestructureNonConstantAxesPass() = default;

void runOnOperation() override {
MLIRContext *context = &getContext();

RewritePatternSet patterns(context);

populateRestructureNonConstantAxesPattern(patterns, context);

// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRestructureNonConstantAxesPass() {
return std::make_unique<RestructureNonConstantAxesPass>();
}
4 changes: 4 additions & 0 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ void mlir::torch::registerTorchConversionPasses() {

void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm) {
// Fix non constant dims passed to reduction ops
pm.addNestedPass<func::FuncOp>(
torch::Torch::createRestructureNonConstantAxesPass());

// We want to fuse quantized operations together before lowering to linalg.
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());
Expand Down

0 comments on commit eb7bf78

Please sign in to comment.