Skip to content

Commit

Permalink
[mlir][linalg] unfold projected permutation. (llvm#114704)
Browse files Browse the repository at this point in the history
Patterns to decompose the input operand(s) of a linalg.generic that has
a projected permutation`  affine-map -- i.e. effectively a folded `transpose`, 
`broadcast`, or a mixture of two --  into explicit transpose and broadcast.
This is useful for instance when trying to recognize named ops.
email: [email protected]
  • Loading branch information
javedabsar1 authored Nov 9, 2024
1 parent c236dbc commit 0ac4821
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,10 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
/// linalg.fill(%cst, tensor.extract_slice(%init)).
void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);

/// Add patterns to make explicit broadcasts and transforms in the
/// input operands of a genericOp.
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns);

/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
RewritePatternSet &patterns,
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TilingInterfaceImpl.cpp
Transforms.cpp
TransposeConv2D.cpp
DecomposeGenericByUnfoldingPermutation.cpp
Vectorization.cpp
WinogradConv2D.cpp

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that has
/// a `transpose`, `broadcast`, or a mixture of two, into explicit transpose
/// and broadcast. Having them folded into the linalg.generic is a good
/// optimization but sometimes we may want to unwrap, i.e., `unfold` them as
/// explicit transpose and broadcast. This rewrite pattern helps do it for
/// each input operand. This is useful for instance when trying to recognize
/// named ops.
///
/// The transpose, broadcast, or mixture of both, are expressed in the affine
/// map of the operand. Technically it is essentially `projected permutation`.
///
/// Example
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
/// { indexing_maps = [#projection, #identity, #identity],
/// iterator_types = ["parallel", "parallel", "parallel",
/// "parallel", "parallel"]}
/// ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
/// outs(%z : tensor<5x9x7x8x10xf32>) {
/// ^bb0(%in: f32, %in_1: f32, %out: f32):
/// %div = arith.divf %in, %in_1 : f32
/// linalg.yield %div : f32
/// } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR operand `%x` map is a projected-permutation. This can be
/// unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
/// ins(%x : tensor<7x8x9xf32>)
/// outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
/// ins(%x_trans : tensor<9x7x8xf32>)
/// outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
/// ins(%x_trans_bc, %y :
/// tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
/// outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
///
/// Note that linalg.generic has been 'specialized' to linalg.div.
///
/// To unfold it, it is more optimal to transpose first and then do the
/// broadcast. However, if transpose is done first, the permutation map needs
/// to be expressed in terms of reduced dimension as broadcast hasn't happened
/// yet. Also, the broadcast dimensions in a linalg.generic come from other
/// operands (those not broadcasted along that particular dimension). We work
/// this out by computing the convex-polyhedron shape of the linalg.generic
/// iteration space from shapes of all the operands, both inputs and outputs.
///
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override;
};

/// For the given `map`, determine what dimensions are transposed and what
/// dimensions are broadcasted.
/// Returns :
/// transpose-permutation, broadcast-dimensions` (empty if not needed)
///
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
assert(map.isProjectedPermutation(false) && "not a projection");

// As the map is a projection it likely operates on a smaller set of
// dimensions as far as the transpose is concerned (rest are broadcast).
int64_t minorSize = map.getNumResults();

SmallVector<int64_t> minorResult;
for (int64_t i = 0; i < minorSize; ++i) {
auto expr = cast<AffineDimExpr>(map.getResults()[i]);
minorResult.push_back(expr.getPosition());
}

// If dims are not monotonically increasing then transpose is present.
SmallVector<int64_t> sortedResMap(minorResult);
std::sort(sortedResMap.begin(), sortedResMap.end());
bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
sortedResMap.begin(), sortedResMap.end());

// Walk the sorted map result to determine which dimensions are broadcasted.
SmallVector<int64_t> broadcast;
for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
if (j < minorSize && sortedResMap[j] == i) {
j++;
continue;
}
broadcast.push_back(i);
}

SmallVector<int64_t> permutation;
if (hasTranspose) {
// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
// `x`s access is both transposed and broadcast. But when specifying
// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
// refering to d3, d4. Therefore, re-base the transpose dimensions so
// that they start from d0.
permutation.resize(minorSize);
std::map<int64_t, int64_t> minorMap;
for (int64_t i = 0; i < minorSize; ++i)
minorMap.insert({sortedResMap[i], i});

// Re-map the dimensions.
SmallVector<int64_t> remappedResult(minorSize);
for (int64_t i = 0; i < minorSize; ++i)
remappedResult[i] = minorMap[minorResult[i]];

/// Calculate the permutation for the transpose.
for (unsigned i = 0; i < minorSize; ++i) {
permutation[remappedResult[i]] = i;
}
}
return {permutation, broadcast};
}

LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
GenericOp op, PatternRewriter &rewriter) const {
if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
op.isSingleYieldOp() || !op.isAllParallelLoops())
return failure();

// If the map of an operand is not a `projected permutation` then
// it cannot be decomposed to mere transpose and broadcast.
// The requirement that all maps be `projected permutation` may be
// over-restrictive but since we need to determine shape of the
// iteration space as well, reject if any map violates assumption.
for (auto &opOperand : op->getOpOperands()) {
auto map = op.getMatchingIndexingMap(&opOperand);
if (!map.isProjectedPermutation(false))
return failure();
}

// Decomposing linalg.generic involves creating `tensor.empty`
// which can have dynamic shapes but then we would have to work
// out which operand can supply that runtime-value (tensor.dim).
// Leaving it as a future TODO.
if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
auto opType = cast<RankedTensorType>(oper.get().getType());
return ShapedType::isDynamicShape(opType.getShape());
}))
return failure();

auto outputShape = op.getStaticLoopRanges();

auto loc = op.getLoc();
bool isChanged = false;
SmallVector<Value> newInitValues = op.getDpsInputs();
SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

// Walk over each input operand and unfold if it is transposed, broadcast
// or mix of two via operand's affine-map.
for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
auto &map = newMap[i];
auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
auto elType = inputRTType.getElementType();

/// Nothing to do if map is already an identity.
if (map.isIdentity())
continue;

auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

// Does it need transpose?
if (!permutation.empty()) {
/// linalg.transpose permutes the dimensions of input using
/// rule: dim(result, i) = dim(input, permutation[i])
SmallVector<int64_t> transposedShape(map.getNumResults());
for (int64_t i = 0; i < map.getNumResults(); ++i)
transposedShape[i] = inputRTType.getShape()[permutation[i]];

Value emptyTensor =
rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
emptyTensor, permutation);
newInitValues[i] = transposeOp->getResult(0);
isChanged = true;
}

// Does it require broadcast?
if (!broadcastedDims.empty()) {
assert(broadcastedDims.size() && "should have non size broadcast");
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, outputShape, inputRTType.getElementType());

auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
loc, newInitValues[i], emptyTensor, broadcastedDims);

newInitValues[i] = broadcastOp->getResult(0);
isChanged = true;
}
newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
}

if (isChanged) {
SmallVector<Value> operands = op->getOperands();
ValueRange operandsRef(operands);

auto newOp = rewriter.create<linalg::GenericOp>(
/*location=*/op.getLoc(),
/*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/newInitValues,
/*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
/*indexingMaps=*/newMap,
/*iteratorTypes=*/op.getIteratorTypesArray());

newOp.getRegion().takeBody(op->getRegion(0));
rewriter.replaceOp(op, newOp->getResults());
}
return success();
}

} // namespace

void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
RewritePatternSet &patterns) {
patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ struct LinalgSpecializeGenericOpsPass
void LinalgSpecializeGenericOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgGenericOpsSpecializationPatterns(patterns);
populateDecomposeProjectedPermutationPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s

#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
#identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
%res = linalg.generic
{ indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<5x9x7x8x10xf32>
return %res : tensor<5x9x7x8x10xf32>
}

// CHECK-LABEL: transpose_and_broadcast
// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
// CHECK: %[[X_trans_bc:.+]] = linalg.broadcast ins(%[[X_trans]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans_bc]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>

func.func @transpose_only(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%res = linalg.generic
{ indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
outs(%z : tensor<2x16x32xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<2x16x32xf32>
return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: transpose_only
// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%res = linalg.generic
{ indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
outs(%z : tensor<2x16x32xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<2x16x32xf32>
return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: broadcast_only
// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

0 comments on commit 0ac4821

Please sign in to comment.