Skip to content

Commit

Permalink
[mlir][SCF] Use Affine ops for indexing math. (#108450)
Browse files Browse the repository at this point in the history
For index type of induction variable, the indexing math is better
represented using affine ops such as `affine.delinearize_index`.

This also further demonstrates that some of these `affine` ops might
need to move to a different dialect. For one, these ops only support
`IndexType` when they should be able to work with any integer type.

This change also includes some canonicalization patterns for
`affine.delinearize_index` operation to
1) Drop unit `basis` values
2) Remove the `delinearize_index` op when the `linear_index` is a loop
induction variable of a normalized loop and the `basis` is of size 1 and
is also the upper bound of the normalized loop.

---------

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored Sep 28, 2024
1 parent d33fa70 commit cca3217
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 207 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
];

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

#endif // AFFINE_OPS
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
let summary = "Coalesce nested loops with independent bounds into a single "
"loop";
let constructor = "mlir::affine::createLoopCoalescingPass()";
let dependentDialects = ["arith::ArithDialect"];
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
}

def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
def TestSCFParallelLoopCollapsing : Pass<"test-scf-parallel-loop-collapsing"> {
let summary = "Test parallel loops collapsing transformation";
let constructor = "mlir::createTestSCFParallelLoopCollapsingPass()";
let dependentDialects = ["affine::AffineDialect"];
let description = [{
This pass is purely for testing the scf::collapseParallelLoops
transformation. The transformation does not have opinions on how a
Expand Down
127 changes: 127 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4537,6 +4537,133 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}

namespace {

/// Canonicalization for `affine.delinearize_index` that strips every basis
/// operand matching the constant 1.
///
/// Each result paired with a unit basis is replaced by a constant index 0. The
/// remaining basis operands (if any) feed a new, smaller
/// `affine.delinearize_index` whose results stand in for the remaining results
/// of the original op, in order. The pattern fails to match when no basis
/// operand is a constant 1.
struct DropUnitExtentBasis
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = delinearizeOp->getLoc();
    auto basisOperands = delinearizeOp.getBasis();

    // One slot per original result; a null slot means "not yet decided".
    SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);

    // The zero constant is materialized only once, and only if some basis
    // operand actually turns out to be 1, so a failed match creates no IR.
    Value zeroCst;
    SmallVector<Value> keptBasis;
    for (auto [position, extent] : llvm::enumerate(basisOperands)) {
      if (!matchPattern(extent, m_One())) {
        keptBasis.push_back(extent);
        continue;
      }
      if (!zeroCst)
        zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      replacements[position] = zeroCst;
    }

    // Nothing was dropped: leave the op untouched.
    if (keptBasis.size() == basisOperands.size())
      return failure();

    if (!keptBasis.empty()) {
      auto reducedOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
          loc, delinearizeOp.getLinearIndex(), keptBasis);
      // Fill the still-empty slots, in order, with the results of the reduced
      // delinearization.
      int nextResult = 0;
      for (Value &slot : replacements)
        if (!slot)
          slot = reducedOp->getResult(nextResult++);
    }

    rewriter.replaceOp(delinearizeOp, replacements);
    return success();
  }
};

/// Drop delinearization pattern related to loops in the following way
///
/// ```
/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
///   %0 = affine.delinearize_index %iv into (%ub) : index
///   <some_use>(%0)
/// }
/// ```
///
/// can be canonicalized to
///
/// ```
/// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
///   <some_use>(%iv)
/// }
/// ```
///
/// The pattern only fires when the op has a single basis value, the linear
/// index is the sole induction variable of a `LoopLikeOpInterface` op, and
/// that loop has lower bound 0, step 1 and an upper bound equal to the basis.
struct DropDelinearizeOfSingleLoop
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto basis = delinearizeOp.getBasis();
    if (basis.size() != 1)
      return failure();

    // Check that the `linear_index` is an induction variable.
    // This must be `dyn_cast`: the linear index is frequently an op result,
    // and `cast` would assert instead of letting the pattern bail out.
    auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex());
    if (!inductionVar)
      return failure();

    // Check that the parent is a `LoopLikeOpInterface`.
    // Again `dyn_cast`: the block argument may belong to a function or any
    // other region-holding op that does not implement the interface.
    auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(
        inductionVar.getParentRegion()->getParentOp());
    if (!loopLikeOp)
      return failure();

    // Check that loop is unit-rank and that the `linear_index` is the induction
    // variable.
    auto inductionVars = loopLikeOp.getLoopInductionVars();
    if (!inductionVars || inductionVars->size() != 1 ||
        inductionVars->front() != inductionVar) {
      return rewriter.notifyMatchFailure(
          delinearizeOp, "`linear_index` is not loop induction variable");
    }

    // Check that the upper-bound is the basis.
    auto upperBounds = loopLikeOp.getLoopUpperBounds();
    if (!upperBounds || upperBounds->size() != 1 ||
        upperBounds->front() != getAsOpFoldResult(basis.front())) {
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "`basis` is not upper bound");
    }

    // Check that the lower bound is zero.
    auto lowerBounds = loopLikeOp.getLoopLowerBounds();
    if (!lowerBounds || lowerBounds->size() != 1 ||
        !isZeroIndex(lowerBounds->front())) {
      return rewriter.notifyMatchFailure(delinearizeOp,
                                         "loop lower bound is not zero");
    }

    // Check that the step is one.
    auto steps = loopLikeOp.getLoopSteps();
    if (!steps || steps->size() != 1 || !isConstantIntValue(steps->front(), 1))
      return rewriter.notifyMatchFailure(delinearizeOp, "loop step is not one");

    // With bounds [0, ub) and unit step, delinearizing the induction variable
    // into (ub) yields the induction variable itself.
    rewriter.replaceOp(delinearizeOp, inductionVar);
    return success();
  }
};

} // namespace

/// Registers the canonicalization patterns of `affine.delinearize_index`:
/// removal of the op on a normalized single loop, and removal of unit basis
/// values.
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Transforms/RegionUtils.h"
Expand Down
78 changes: 77 additions & 1 deletion mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -671,9 +672,26 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
return success();
}

/// Computes the bounds of the normalized form of a loop whose bounds are of
/// `index` type. The returned range has offset 0 and stride 1; its size is
/// `ceildiv(ub - lb, step)`, built (and folded when possible) as an affine
/// apply.
// NOTE(review): unlike `denormalizeInductionVariableForIndexType` in this
// file, this helper is not marked `static`; confirm whether external linkage
// is intended.
Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
                                           OpFoldResult lb, OpFoldResult ub,
                                           OpFoldResult step) {
  // Symbols are bound positionally to the operands {lb, ub, step}.
  AffineExpr lbSym, ubSym, stepSym;
  bindSymbols(rewriter.getContext(), lbSym, ubSym, stepSym);
  AffineExpr tripCount = (ubSym - lbSym).ceilDiv(stepSym);
  OpFoldResult normalizedSize = affine::makeComposedFoldedAffineApply(
      rewriter, loc, tripCount, {lb, ub, step});

  return Range{/*offset=*/rewriter.getIndexAttr(0),
               /*size=*/normalizedSize,
               /*stride=*/rewriter.getIndexAttr(1)};
}

Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
if (getType(lb).isIndex()) {
return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
}
// For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
Expand Down Expand Up @@ -714,9 +732,38 @@ Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
return {newLowerBound, newUpperBound, newStep};
}

/// Index-type implementation of induction variable denormalization: computes
/// `normalizedIv * origStep + origLb` through a folded affine apply and
/// redirects the uses of `normalizedIv` to that value.
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
                                                     Location loc,
                                                     Value normalizedIv,
                                                     OpFoldResult origLb,
                                                     OpFoldResult origStep) {
  auto *ctx = rewriter.getContext();
  // One dimension (the normalized IV) followed by two symbols (lower bound,
  // step); this matches the operand order passed below.
  AffineExpr ivDim, lbSym, stepSym;
  bindDims(ctx, ivDim);
  bindSymbols(ctx, lbSym, stepSym);
  AffineExpr denormExpr = ivDim * stepSym + lbSym;

  OpFoldResult denormalized = affine::makeComposedFoldedAffineApply(
      rewriter, loc, denormExpr,
      ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
  Value denormalizedVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, denormalized);

  // The op that defines the denormalized value is excluded from the
  // replacement below so that it is not rewritten along with the other users
  // of `normalizedIv`. It is only recorded when the bounds are not already
  // known to be `origLb == 0` and `origStep == 1`.
  SmallPtrSet<Operation *, 1> keepUsers;
  bool alreadyNormalized =
      isConstantIntValue(origLb, 0) && isConstantIntValue(origStep, 1);
  if (!alreadyNormalized)
    if (Operation *definingOp = denormalizedVal.getDefiningOp())
      keepUsers.insert(definingOp);

  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedVal, keepUsers);
}

void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
Value normalizedIv, OpFoldResult origLb,
OpFoldResult origStep) {
if (getType(origLb).isIndex()) {
return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
origLb, origStep);
}
Value denormalizedIv;
SmallPtrSet<Operation *, 2> preserve;
bool isStepOne = isConstantIntValue(origStep, 1);
Expand All @@ -739,10 +786,29 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}

/// Multiplies together a non-empty list of index values. The running product
/// starts at the index attribute 1 and each factor is multiplied in through a
/// folded affine apply of `s0 * s1`.
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
                                        ArrayRef<OpFoldResult> values) {
  assert(!values.empty() && "unexecpted empty array");

  AffineExpr lhs, rhs;
  bindSymbols(rewriter.getContext(), lhs, rhs);
  AffineExpr mulExpr = lhs * rhs;

  OpFoldResult runningProduct = rewriter.getIndexAttr(1);
  for (OpFoldResult factor : values)
    runningProduct = affine::makeComposedFoldedAffineApply(
        rewriter, loc, mulExpr, ArrayRef<OpFoldResult>{runningProduct, factor});
  return runningProduct;
}

/// Helper function to multiply a sequence of values.
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values) {
assert(!values.empty() && "unexpected empty list");
if (getType(values.front()).isIndex()) {
SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
return getValueOrCreateConstantIndexOp(rewriter, loc, product);
}
std::optional<Value> productOf;
for (auto v : values) {
auto vOne = getConstantIntValue(v);
Expand All @@ -757,7 +823,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
if (!productOf) {
productOf = rewriter
.create<arith::ConstantOp>(
loc, rewriter.getOneAttr(values.front().getType()))
loc, rewriter.getOneAttr(getType(values.front())))
.getResult();
}
return productOf.value();
Expand All @@ -774,6 +840,16 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
Value linearizedIv, ArrayRef<Value> ubs) {

if (linearizedIv.getType().isIndex()) {
Operation *delinearizedOp =
rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
ubs);
auto resultVals = llvm::map_to_vector(
delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
}

SmallVector<Value> delinearizedIvs(ubs.size());
SmallPtrSet<Operation *, 2> preservedUsers;

Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1466,3 +1466,51 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
}
return
}

// -----

// Test input: a six-way delinearization in which four of the basis values are
// the constant 1 and two are function arguments. The expectations that follow
// look for a two-result delinearization over the two arguments and a constant
// 0 in every position that had a unit basis.
func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index, index, index, index, index, index) {
%c1 = arith.constant 1 : index
%0:6 = affine.delinearize_index %arg0 into (%c1, %arg1, %c1, %c1, %arg2, %c1)
: index, index, index, index, index, index
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : index, index, index, index, index, index
}
// CHECK-LABEL: func @drop_unit_basis_in_delinearize(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], %[[ARG2]])
// CHECK: return %[[C0]], %[[DELINEARIZE]]#0, %[[C0]], %[[C0]], %[[DELINEARIZE]]#1, %[[C0]]

// -----

// Test input: every basis value is the constant 1. The expectations that
// follow require that no delinearization op remains and that both results
// become the constant 0.
func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
%c1 = arith.constant 1 : index
%0:2 = affine.delinearize_index %arg0 into (%c1, %c1) : index, index
return %0#0, %0#1 : index, index
}
// CHECK-LABEL: func @drop_all_unit_bases(
// CHECK-SAME: %[[ARG0:.+]]: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-NOT: affine.delinearize_index
// CHECK: return %[[C0]], %[[C0]]

// -----

// Test input: an `scf.for` running from 0 to %arg1 with step 1 whose induction
// variable is delinearized into the single basis %arg1 (the loop upper bound).
// The expectations that follow require the delinearization to disappear and
// "some_use" to consume the induction variable directly.
func.func @drop_single_loop_delinearize(%arg0 : index, %arg1 : index) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%2 = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%arg2 = %c0) -> index {
%0 = affine.delinearize_index %iv into (%arg1) : index
%1 = "some_use"(%arg2, %0) : (index, index) -> (index)
scf.yield %1 : index
}
return %2 : index
}
// CHECK-LABEL: func @drop_single_loop_delinearize(
// CHECK-SAME: %[[ARG0:.+]]: index)
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-NOT: affine.delinearize_index
// CHECK: "some_use"(%{{.+}}, %[[IV]])
Loading

0 comments on commit cca3217

Please sign in to comment.