diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
index 66acf8cb670a..f26f3223f3ce 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
@@ -46,6 +46,84 @@ class TransposeUnitDimToShapeCast
   }
 };
 
+// TODO: Move this upstream
+// Hoists a vector.bitcast op to the output of the enclosing scf.if
+//
+// This transforms IR like:
+//   %0 = scf.if %1 -> (vector<16xi8>) {
+//     %2 = memref.load %4[%c0] : memref<?xvector<4xi32>>
+//     %3 = vector.bitcast %2 : vector<4xi32> to vector<16xi8>
+//     scf.yield %3 : vector<16xi8>
+//   } else {
+//     scf.yield %cst : vector<16xi8>
+//   }
+// Into:
+//   %0 = scf.if %1 -> (vector<4xi32>) {
+//     %2 = memref.load %4[%c0] : memref<?xvector<4xi32>>
+//     scf.yield %2 : vector<4xi32>
+//   } else {
+//     %3 = vector.bitcast %cst : vector<16xi8> to vector<4xi32>
+//     scf.yield %3 : vector<4xi32>
+//   }
+//   %3 = vector.bitcast %0 : vector<4xi32> to vector<16xi8>
+struct BubbleUpBitCastOfScfIf : public OpRewritePattern<scf::IfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::IfOp ifOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail on more than one result for now.
+    scf::YieldOp thenYield = ifOp.thenYield();
+    if (!thenYield || thenYield.getNumOperands() != 1)
+      return failure();
+    auto bitcastOp = thenYield.getOperand(0).getDefiningOp<vector::BitCastOp>();
+    // Bail out if there is no bitcast feeding the then-branch yield.
+    if (!bitcastOp)
+      return failure();
+
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+    // Skip 0-D vectors.
+    if (castSrcType.getRank() == 0)
+      return failure();
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to more elements.
+    if (castSrcLastDim > castDstLastDim)
+      return failure();
+
+    Location loc = ifOp.getLoc();
+
+    auto bitcastedIfOp =
+        rewriter.create<scf::IfOp>(loc, castSrcType, ifOp.getCondition());
+    bitcastedIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
+    bitcastedIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
+
+    scf::YieldOp newThenYield = bitcastedIfOp.thenYield();
+    auto newBitcastOp =
+        newThenYield.getOperand(0).getDefiningOp<vector::BitCastOp>();
+
+    newThenYield.setOperand(0, newBitcastOp.getSource());
+
+    auto newBitcast = rewriter.create<vector::BitCastOp>(
+        loc, castDstType, bitcastedIfOp.getResult(0));
+
+    scf::YieldOp elseYield = bitcastedIfOp.elseYield();
+    if (elseYield) {
+      OpBuilder::InsertionGuard elseGuard(rewriter);
+      rewriter.setInsertionPoint(elseYield);
+
+      Value yieldSrc = elseYield.getOperand(0);
+      auto elseBitcast =
+          rewriter.create<vector::BitCastOp>(loc, castSrcType, yieldSrc);
+      elseYield.setOperand(0, elseBitcast);
+    }
+    rewriter.replaceOp(ifOp, newBitcast);
+    return success();
+  }
+};
+
 static void loopInvariantCodeMotion(func::FuncOp funcOp) {
   // Walk through all loops in a function in innermost-loop-first order. This
   // way, we first LICM from the inner loop, and place the ops in
@@ -89,6 +167,7 @@ struct OptimizeVectorTransferPass {
 
     RewritePatternSet patterns(&getContext());
     vector::populateBubbleVectorBitCastOpPatterns(patterns);
+    patterns.add<BubbleUpBitCastOfScfIf>(&getContext());
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
     }
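
For reference, a lit test along these lines would exercise the new pattern end to end. This is a sketch, not part of the diff: the --iree-codegen-optimize-vector-transfer flag is an assumption about how iree-opt exposes this pass, and the function name and CHECK lines are illustrative.

// RUN: iree-opt --iree-codegen-optimize-vector-transfer %s | FileCheck %s

func.func @bubble_up_bitcast_of_scf_if(%cond: i1,
                                       %mem: memref<?xvector<4xi32>>) -> vector<16xi8> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<0> : vector<16xi8>
  %0 = scf.if %cond -> (vector<16xi8>) {
    %1 = memref.load %mem[%c0] : memref<?xvector<4xi32>>
    %2 = vector.bitcast %1 : vector<4xi32> to vector<16xi8>
    scf.yield %2 : vector<16xi8>
  } else {
    scf.yield %cst : vector<16xi8>
  }
  return %0 : vector<16xi8>
}

// CHECK-LABEL: func.func @bubble_up_bitcast_of_scf_if
//       CHECK:   %[[IF:.+]] = scf.if %{{.+}} -> (vector<4xi32>)
//       CHECK:     %[[LD:.+]] = memref.load
//       CHECK:     scf.yield %[[LD]] : vector<4xi32>
//       CHECK:   } else {
//       CHECK:     %[[CAST:.+]] = vector.bitcast %{{.+}} : vector<16xi8> to vector<4xi32>
//       CHECK:     scf.yield %[[CAST]] : vector<4xi32>
//       CHECK:   %[[RES:.+]] = vector.bitcast %[[IF]] : vector<4xi32> to vector<16xi8>
//       CHECK:   return %[[RES]]

After the rewrite both yields agree on vector<4xi32>, with the else branch picking up a compensating bitcast, and a single bitcast after the scf.if restores the original vector<16xi8> type so other users of the result are unaffected.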