Skip to content

Commit

Permalink
Add pattern for bubbling vector.bitcast through an enclosing scf.if
Browse files Browse the repository at this point in the history
Types narrower than 32 bits are handled on the SPIR-V side by introducing
bitcasts to and from i32 and bubbling them toward the center of the kernel,
in the hope that they cancel. This adds a pattern for a bitcast feeding the
result of an scf.if, which arises from the way padding is handled (a
transfer_read in the `then` branch, while the `else` branch yields a splat constant).
  • Loading branch information
qedawkins committed Jul 20, 2023
1 parent 44a733b commit f8976fd
Showing 1 changed file with 79 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,84 @@ class TransposeUnitDimToShapeCast
}
};

// TODO: Move this upstream
// Hoists a vector.bitcast op to the output of the enclosing scf.if.
//
// The pattern matches an scf.if with exactly one result whose `then` branch
// yields the result of a vector.bitcast. The scf.if is rebuilt to yield the
// bitcast's source type, the `else` branch gets a compensating bitcast back to
// that source type, and a single bitcast to the original type is placed on the
// result of the new scf.if.
//
// This transforms IR like:
//   %0 = scf.if %cond -> (vector<16xi8>) {
//     %1 = memref.load %mem[%c0] : memref<?xvector<4xi32>>
//     %2 = vector.bitcast %1 : vector<4xi32> to vector<16xi8>
//     scf.yield %2 : vector<16xi8>
//   } else {
//     scf.yield %cst : vector<16xi8>
//   }
// Into:
//   %0 = scf.if %cond -> (vector<4xi32>) {
//     %1 = memref.load %mem[%c0] : memref<?xvector<4xi32>>
//     scf.yield %1 : vector<4xi32>
//   } else {
//     %2 = vector.bitcast %cst : vector<16xi8> to vector<4xi32>
//     scf.yield %2 : vector<4xi32>
//   }
//   %3 = vector.bitcast %0 : vector<4xi32> to vector<16xi8>
struct BubbleUpBitCastOfScfIf : public OpRewritePattern<scf::IfOp> {
  using OpRewritePattern::OpRewritePattern;

  // Returns success() after rewriting `ifOp` as described above, or failure()
  // (leaving the IR untouched) when the op does not match.
  LogicalResult matchAndRewrite(scf::IfOp ifOp,
                                PatternRewriter &rewriter) const override {
    // Bail on anything but a single result for now. Having a result also
    // means both branches end in an scf.yield with one operand, so the else
    // yield can be accessed unconditionally below.
    if (ifOp->getNumResults() != 1)
      return failure();

    scf::YieldOp thenYield = ifOp.thenYield();
    auto bitcastOp = thenYield.getOperand(0).getDefiningOp<vector::BitCastOp>();
    // Bail out if the `then` branch does not yield a bitcast.
    if (!bitcastOp)
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors; they have no trailing dimension to compare.
    if (castSrcType.getRank() == 0)
      return failure();

    // Only handle casts whose trailing dimension does not shrink, i.e. the
    // bitcast produces at least as many elements as it consumes.
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    if (castSrcLastDim > castDstLastDim)
      return failure();

    Location loc = ifOp.getLoc();

    // Build the replacement scf.if producing the bitcast's source type and
    // move both branch bodies into it. All mutations go through the rewriter
    // so that the pattern driver is notified of every change.
    auto bitcastedIfOp =
        rewriter.create<scf::IfOp>(loc, castSrcType, ifOp.getCondition());
    rewriter.inlineRegionBefore(ifOp.getThenRegion(),
                                bitcastedIfOp.getThenRegion(),
                                bitcastedIfOp.getThenRegion().end());
    rewriter.inlineRegionBefore(ifOp.getElseRegion(),
                                bitcastedIfOp.getElseRegion(),
                                bitcastedIfOp.getElseRegion().end());

    // The regions were moved, not cloned, so `thenYield` and `bitcastOp` are
    // still the same operations. Yield the bitcast's source directly.
    rewriter.updateRootInPlace(thenYield, [&]() {
      thenYield->setOperand(0, bitcastOp.getSource());
    });

    // Cast the value yielded by the `else` branch to the new result type.
    scf::YieldOp elseYield = bitcastedIfOp.elseYield();
    {
      OpBuilder::InsertionGuard elseGuard(rewriter);
      rewriter.setInsertionPoint(elseYield);
      Value elseBitcast = rewriter.create<vector::BitCastOp>(
          loc, castSrcType, elseYield.getOperand(0));
      rewriter.updateRootInPlace(
          elseYield, [&]() { elseYield->setOperand(0, elseBitcast); });
    }

    // Recover the original type with one bitcast on the new scf.if result.
    // The insertion point is back in front of the old `ifOp`, which is after
    // `bitcastedIfOp`.
    rewriter.replaceOpWithNewOp<vector::BitCastOp>(ifOp, castDstType,
                                                   bitcastedIfOp.getResult(0));
    return success();
  }
};

static void loopInvariantCodeMotion(func::FuncOp funcOp) {
// Walk through all loops in a function in innermost-loop-first order. This
// way, we first LICM from the inner loop, and place the ops in
Expand Down Expand Up @@ -89,6 +167,7 @@ struct OptimizeVectorTransferPass
{
RewritePatternSet patterns(&getContext());
vector::populateBubbleVectorBitCastOpPatterns(patterns);
patterns.add<BubbleUpBitCastOfScfIf>(&getContext());
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
Expand Down

0 comments on commit f8976fd

Please sign in to comment.