Skip to content

Commit

Permalink
Add attribute control for splitK
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasRaoux authored and powderluv committed Apr 7, 2022
1 parent 9f09785 commit 683e17f
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions iree/compiler/Dialect/Flow/Transforms/SplitReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ namespace iree_compiler {
namespace IREE {
namespace Flow {

static const char kSplitKAttr[] = "iree_flow_split_k";

// TODO(thomasraoux): Move to attributes.
static llvm::cl::opt<int64_t> splitReductionRatio(
"iree-flow-split-matmul-reduction", llvm::cl::desc("split ratio"),
Expand Down Expand Up @@ -75,16 +77,18 @@ struct SplitReductionPass : public SplitReductionBase<SplitReductionPass> {
}

void runOnOperation() override {
if (splitReductionRatio <= 1) return;

RewritePatternSet patterns(&getContext());
patterns.add<LinalgSplitReduction>(
&getContext(),
[&](linalg::LinalgOp op) {
[](linalg::LinalgOp op) {
int64_t ratio = splitReductionRatio;
if (auto attr = op->getAttrOfType<IntegerAttr>(kSplitKAttr))
ratio = attr.getInt();
if (ratio <= 1) return std::make_pair(int64_t(0), 0);
// For matmul make the new parallel dimension first so that it looks
// like a batch_matmul and can follow the same codegen.
if (isa<linalg::MatmulOp>(op))
return std::make_pair(int64_t(splitReductionRatio), 0);
return std::make_pair(int64_t(ratio), 0);
// Currently disable spliting reduction for non-matmul op. This will
// get enabled after once tests are ready.
return std::make_pair(int64_t(0), 0);
Expand Down

0 comments on commit 683e17f

Please sign in to comment.