diff --git a/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp b/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp index d9667ddb9099..0a0386dc35ba 100644 --- a/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp @@ -22,6 +22,8 @@ namespace iree_compiler { namespace IREE { namespace Flow { +static const char kSplitKAttr[] = "iree_flow_split_k"; + // TODO(thomasraoux): Move to attributes. static llvm::cl::opt splitReductionRatio( "iree-flow-split-matmul-reduction", llvm::cl::desc("split ratio"), @@ -75,16 +77,18 @@ struct SplitReductionPass : public SplitReductionBase { } void runOnOperation() override { - if (splitReductionRatio <= 1) return; - RewritePatternSet patterns(&getContext()); patterns.add( &getContext(), - [&](linalg::LinalgOp op) { + [](linalg::LinalgOp op) { + int64_t ratio = splitReductionRatio; + if (auto attr = op->getAttrOfType(kSplitKAttr)) + ratio = attr.getInt(); + if (ratio <= 1) return std::make_pair(int64_t(0), 0); // For matmul make the new parallel dimension first so that it looks // like a batch_matmul and can follow the same codegen. if (isa(op)) - return std::make_pair(int64_t(splitReductionRatio), 0); + return std::make_pair(int64_t(ratio), 0); // Currently disable spliting reduction for non-matmul op. This will // get enabled after once tests are ready. return std::make_pair(int64_t(0), 0);