Skip to content

Commit

Permalink
[NFC] Cleanups to flow op folders. (iree-org#18974)
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanik authored Oct 31, 2024
1 parent dc43032 commit 57fb10f
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 66 deletions.
126 changes: 65 additions & 61 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace mlir::iree_compiler::IREE::Flow {
// Folding utilities
//===----------------------------------------------------------------------===//

namespace {

// Erases an op if it has no uses.
// This is to support ops that are "pure" but can't be marked as such because
// the MLIR CSE pass would deduplicate them.
Expand Down Expand Up @@ -170,6 +172,8 @@ static SmallVector<Value> refreshDimsOnTypeChange(Operation *op, Type oldType,
return newDims;
}

} // namespace

//===----------------------------------------------------------------------===//
// flow.dispatch.workgroups
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -365,6 +369,8 @@ void DispatchWorkgroupsOp::getCanonicalizationPatterns(
// flow.dispatch.workload.ordinal
//===----------------------------------------------------------------------===//

namespace {

// Bubble up the ordinal ops so that all uses go through this operation.
struct BubbleUpOrdinalOp : public OpRewritePattern<DispatchWorkloadOrdinalOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -399,6 +405,8 @@ struct BubbleUpOrdinalOp : public OpRewritePattern<DispatchWorkloadOrdinalOp> {
}
};

} // namespace

/// Fold away following sequence of `flow.dispatch.workload.ordinal`.
///
/// ```mlir
Expand Down Expand Up @@ -863,25 +871,6 @@ OpFoldResult TensorReshapeOp::fold(FoldAdaptor operands) {
return {};
}

//===----------------------------------------------------------------------===//
// flow.tensor.bitcast
//===----------------------------------------------------------------------===//

// Folds a bitcast away when it is a provable no-op: the element types match
// and the (possibly dynamic) shapes compare equal, so the result is simply
// the source tensor.
OpFoldResult TensorBitCastOp::fold(FoldAdaptor operands) {
  auto sourceType = llvm::cast<ShapedType>(getSource().getType());
  auto resultType = llvm::cast<ShapedType>(getResult().getType());
  // A true bitcast changes the element type; only identity casts fold.
  bool sameElementType =
      sourceType.getElementType() == resultType.getElementType();
  if (sameElementType &&
      compareShapesEqual(sourceType, getSourceDims(), resultType,
                         getResultDims())) {
    return getSource();
  }
  return {};
}

namespace {

// Flatten a chain of reshapes or bitcasts (reshape/bitcast feeding into
Expand Down Expand Up @@ -930,48 +919,6 @@ struct FlattenTensorCastLikeChain : public OpRewritePattern<CastOpTy> {
}
};

// Replace `flow.tensor.splat`-`flow.tensor.load` op-pairs by the input
// primitive value for the splat op.
// Replaces a `flow.tensor.load` whose source is produced by a
// `flow.tensor.splat` with the splatted primitive value: every element of a
// splat is that value, so no tensor materialization is needed.
struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> {
  using OpRewritePattern<TensorLoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorLoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    // Only fires when the load source comes straight from a splat op.
    auto *definingOp = loadOp.getSource().getDefiningOp();
    if (auto splatOp = dyn_cast_or_null<TensorSplatOp>(definingOp)) {
      rewriter.replaceOp(loadOp, splatOp.getValue());
      return success();
    }
    return failure();
  }
};

// Folds a `flow.tensor.splat` whose sole use is a `flow.tensor.reshape` into
// a single splat that directly produces the reshaped type.
struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorSplatOp> {
  using OpRewritePattern<TensorSplatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorSplatOp splatOp,
                                PatternRewriter &rewriter) const override {
    // Only safe when the reshape is the splat's single consumer; other users
    // would still require the original splat result.
    if (!splatOp.getResult().hasOneUse())
      return failure();

    // That single user must itself be a reshape of the splatted tensor.
    auto reshapeOp = dyn_cast_or_null<TensorReshapeOp>(
        splatOp.getResult().use_begin()->getOwner());
    if (!reshapeOp)
      return failure();

    // Create the replacement splat at the reshape's position so the reshape's
    // result dims (used as the new splat's dims) dominate the new op.
    PatternRewriter::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(reshapeOp);
    rewriter.replaceOpWithNewOp<TensorSplatOp>(
        reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(),
        reshapeOp.getResultDims());
    // The original splat has no remaining uses after the replacement above
    // (the reshape was its only user), so erase it eagerly.
    rewriter.eraseOp(splatOp);

    return success();
  }
};

struct ResolveShapedRank : public OpRewritePattern<tensor::RankOp> {
using OpRewritePattern<tensor::RankOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::RankOp op,
Expand Down Expand Up @@ -1032,6 +979,25 @@ void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.insert<ResolveShapedDim>(context);
}

//===----------------------------------------------------------------------===//
// flow.tensor.bitcast
//===----------------------------------------------------------------------===//

// Folds an identity bitcast to its source: when both the element types and
// the (possibly dynamic) shapes agree, the op is a no-op.
OpFoldResult TensorBitCastOp::fold(FoldAdaptor operands) {
  auto sourceType = llvm::cast<ShapedType>(getSource().getType());
  auto resultType = llvm::cast<ShapedType>(getResult().getType());
  // Differing element types mean a real bitcast; nothing folds in that case.
  if (sourceType.getElementType() == resultType.getElementType() &&
      compareShapesEqual(sourceType, getSourceDims(), resultType,
                         getResultDims())) {
    return getSource();
  }
  return {};
}

void TensorBitCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<ReplaceOpIfTensorOperandZeroElements<TensorBitCastOp, 0>>(
Expand Down Expand Up @@ -1060,6 +1026,25 @@ OpFoldResult TensorLoadOp::fold(FoldAdaptor operands) {
return {};
}

namespace {

// Replace `flow.tensor.splat`-`flow.tensor.load` op-pairs by the input
// primitive value for the splat op.
// Replace `flow.tensor.splat`-`flow.tensor.load` op-pairs by the input
// primitive value for the splat op: any element loaded from a splat is the
// splatted value itself.
struct FoldSplatLoadIntoPrimitive : public OpRewritePattern<TensorLoadOp> {
  using OpRewritePattern<TensorLoadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorLoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    // Match only when the load reads directly from a splat result.
    if (auto sourceOp = dyn_cast_or_null<TensorSplatOp>(
            loadOp.getSource().getDefiningOp())) {
      rewriter.replaceOp(loadOp, sourceOp.getValue());
      return success();
    }
    return failure();
  }
};

} // namespace

void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<FoldSplatLoadIntoPrimitive>(context);
Expand Down Expand Up @@ -1116,6 +1101,25 @@ void TensorEmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
// flow.tensor.splat
//===----------------------------------------------------------------------===//

namespace {

// Rewrites reshape(splat(x)) into a single splat producing the reshaped
// type directly, using the reshape's result dims for the new splat.
struct FoldSplatReshapeIntoSplat : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Only fires when the reshape's source comes straight from a splat.
    auto splatOp =
        dyn_cast_if_present<TensorSplatOp>(reshapeOp.getSource().getDefiningOp());
    if (!splatOp) {
      return failure();
    }
    // Splat the same primitive value at the reshape's target shape.
    rewriter.replaceOpWithNewOp<TensorSplatOp>(
        reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(),
        reshapeOp.getResultDims());
    return success();
  }
};

} // namespace

void TensorSplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO(benvanik): canonicalize splat+slice to smaller splat.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,14 +411,14 @@ util.func public @cloneDynamicZeroElements(%arg0: tensor<0x?xf32>, %dim: index)

// CHECK-LABEL: @ElideRedundantTransfer
// CHECK-SAME: (%[[OPERAND:.+]]: tensor<4x?xf32>, %[[DIM:.+]]: index)
util.func public @ElideRedundantTransfer(%arg0: tensor<4x?xf32>, %dim: index) -> tensor<4x?xi32> {
// CHECK: %[[TRANSFER:.+]] = flow.tensor.transfer %arg0
%transfer = flow.tensor.transfer %arg0 : tensor<4x?xf32>{%dim} to "target"
util.func public @ElideRedundantTransfer(%operand: tensor<4x?xf32>, %dim: index) -> tensor<4x?xi32> {
// CHECK: %[[TRANSFER:.+]] = flow.tensor.transfer %[[OPERAND]]
%transfer = flow.tensor.transfer %operand : tensor<4x?xf32>{%dim} to "target"
// CHECK: %[[BITCAST:.+]] = flow.tensor.bitcast %[[TRANSFER]]
%bitcast = flow.tensor.bitcast %transfer : tensor<4x?xf32>{%dim} -> tensor<4x?xi32>{%dim}
// CHECK-NOT: flow.transfer
// CHECK-NOT: flow.tensor.transfer
%redundant = flow.tensor.transfer %bitcast : tensor<4x?xi32>{%dim} to "target"
// CHECK-NEXT: %[[BITCAST]]
// CHECK-NEXT: util.return %[[BITCAST]]
util.return %redundant : tensor<4x?xi32>
}

Expand Down

0 comments on commit 57fb10f

Please sign in to comment.