Skip to content

Commit

Permalink
[MLIR][SCF] Loop pipelining fails on failed predication (no assert) (l…
Browse files Browse the repository at this point in the history
…lvm#107442)

The SCFLoopPipelining allows predication on peeled or loop ops. When the
predicationFn returns a nullptr this signifies the op type is
unsupported and the pipeliner fails except in `emitPrologue` where it
asserts.

This patch fixes handling in the prologue to gracefully fail.
  • Loading branch information
sjw36 authored Sep 5, 2024
1 parent d219c63 commit 1892666
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ struct LoopPipelinerInternal {
bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
/// Emits the prologue, this creates `maxStage - 1` part which will contain
/// operations from stages [0; i], where i is the part index.
void emitPrologue(RewriterBase &rewriter);
LogicalResult emitPrologue(RewriterBase &rewriter);
/// Gather liverange information for Values that are used in a different stage
/// than its definition.
llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
Expand Down Expand Up @@ -263,7 +263,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
return clone;
}

void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
// Initialize the iteration argument to the loop initial values.
for (auto [arg, operand] :
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
Expand Down Expand Up @@ -311,7 +311,8 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
if (predicates[predicateIdx]) {
OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
assert(newOp && "failed to predicate op.");
if (newOp == nullptr)
return failure();
}
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
Expand Down Expand Up @@ -339,6 +340,7 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
}
}
}
return success();
}

llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
Expand Down Expand Up @@ -772,7 +774,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
*modifiedIR = true;

// 1. Emit prologue.
pipeliner.emitPrologue(rewriter);
if (failed(pipeliner.emitPrologue(rewriter)))
return failure();

// 2. Track values used across stages. When a value cross stages it will
// need to be passed as loop iteration arguments.
Expand Down

0 comments on commit 1892666

Please sign in to comment.