forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Codegen] Add pass for unrolling annotated for loops (iree-org#18641)
This allows annotating for loops formed in earlier compilation stages to be unrolled later on. The case this is used for today is for unrolling loops from tiling producers of matmul operands (typically a copy). If a loop has a dynamic trip count, the attribute will be dropped silently (unrolling is best effort).
- Loading branch information
Showing
12 changed files
with
224 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
79 changes: 79 additions & 0 deletions
79
compiler/src/iree/compiler/Codegen/Common/UnrollAnnotatedLoops.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Codegen/Utils/MarkerUtils.h" | ||
#include "mlir/Dialect/SCF/Utils/Utils.h" | ||
#include "mlir/IR/Visitors.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_UNROLLANNOTATEDLOOPSPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
/// Returns the trip count of `forOp` if its' low bound, high bound and step are | ||
/// constants, or optional otherwise. Trip count is computed as | ||
/// ceilDiv(highBound - lowBound, step). | ||
static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) { | ||
std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound()); | ||
std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound()); | ||
std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep()); | ||
if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value()) { | ||
return std::nullopt; | ||
} | ||
|
||
// Constant loop bounds computation. | ||
if (lbCstOp < 0 || ubCstOp < 0 || stepCstOp <= 0) { | ||
return std::nullopt; | ||
} | ||
return llvm::divideCeil(*ubCstOp - *lbCstOp, *stepCstOp); | ||
} | ||
|
||
struct UnrollAnnotatedLoopsPass final | ||
: impl::UnrollAnnotatedLoopsPassBase<UnrollAnnotatedLoopsPass> { | ||
void runOnOperation() override { | ||
FunctionOpInterface funcOp = getOperation(); | ||
|
||
// Get the list of operations to unroll in post-order so that the inner | ||
// most loops get unrolled before the outer most loops. | ||
// (This is the default but set explicitly here because it's required). | ||
SmallVector<scf::ForOp> unrollTargets; | ||
funcOp.walk<WalkOrder::PostOrder>([&](scf::ForOp forOp) { | ||
if (getLoopUnrollMarker(forOp)) { | ||
unrollTargets.push_back(forOp); | ||
} | ||
}); | ||
|
||
for (scf::ForOp forOp : unrollTargets) { | ||
removeLoopUnrollMarker(forOp); | ||
|
||
std::optional<int64_t> maybeTripCount = getConstantTripCount(forOp); | ||
if (maybeTripCount.value_or(0) <= 0) { | ||
continue; | ||
} | ||
|
||
(void)loopUnrollByFactor(forOp, *maybeTripCount); | ||
} | ||
|
||
// Cleanup unrolled loops. | ||
{ | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet patterns(context); | ||
scf::ForOp::getCanonicalizationPatterns(patterns, context); | ||
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { | ||
funcOp->emitError("Failed to apply post unroll cleanup"); | ||
return signalPassFailure(); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
107 changes: 107 additions & 0 deletions
107
compiler/src/iree/compiler/Codegen/Common/test/unroll_annotated_loops.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-unroll-annotated-loops))" \ | ||
// RUN: --allow-unregistered-dialect | FileCheck %s | ||
|
||
func.func @basic_unroll() { | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c3 = arith.constant 3 : index | ||
scf.for %i = %c0 to %c3 step %c1 { | ||
"unregistered.loop_body"(%i) : (index) -> () | ||
} {unroll_loop} | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @basic_unroll | ||
// CHECK: "unregistered.loop_body"(%c0) | ||
// CHECK: "unregistered.loop_body"(%c1) | ||
// CHECK: "unregistered.loop_body"(%c2) | ||
|
||
// ----- | ||
|
||
func.func @no_annotation_no_unroll() { | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c3 = arith.constant 3 : index | ||
scf.for %i = %c0 to %c3 step %c1 { | ||
"unregistered.loop_body"(%i) : (index) -> () | ||
} | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @no_annotation_no_unroll | ||
// CHECK: scf.for | ||
// CHECK: "unregistered.loop_body" | ||
|
||
// ----- | ||
|
||
func.func @no_unroll_dynamic_trip(%x: index) { | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
scf.for %i = %c0 to %x step %c1 { | ||
"unregistered.loop_body"(%i) : (index) -> () | ||
} {unroll_loop} | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @no_unroll_dynamic_trip | ||
// CHECK: scf.for | ||
// CHECK: "unregistered.loop_body" | ||
// CHECK-NOT: unroll_loop | ||
|
||
// ----- | ||
|
||
func.func @unroll_non_normalized() { | ||
%c5 = arith.constant 5 : index | ||
%c10 = arith.constant 10 : index | ||
%c2 = arith.constant 2 : index | ||
scf.for %i = %c5 to %c10 step %c2 { | ||
"unregistered.loop_body"(%i) : (index) -> () | ||
} {unroll_loop} | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @unroll_non_normalized | ||
// CHECK: "unregistered.loop_body"(%c5) | ||
// CHECK: "unregistered.loop_body"(%c7) | ||
// CHECK: "unregistered.loop_body"(%c9) | ||
|
||
// ----- | ||
|
||
func.func @unroll_iter_arg() -> i32 { | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c3 = arith.constant 3 : index | ||
%init = arith.constant 1 : i32 | ||
%0 = scf.for %i = %c0 to %c3 step %c1 iter_args(%it = %init) -> i32 { | ||
%1 = "unregistered.loop_body"(%it) : (i32) -> (i32) | ||
scf.yield %1 : i32 | ||
} {unroll_loop} | ||
return %0 : i32 | ||
} | ||
|
||
// CHECK-LABEL: func.func @unroll_iter_arg | ||
// CHECK: %[[INIT:.+]] = arith.constant 1 : i32 | ||
// CHECK: %[[IT0:.+]] = "unregistered.loop_body"(%[[INIT]]) | ||
// CHECK: %[[IT1:.+]] = "unregistered.loop_body"(%[[IT0]]) | ||
// CHECK: %[[IT2:.+]] = "unregistered.loop_body"(%[[IT1]]) | ||
// CHECK: return %[[IT2]] | ||
|
||
// ----- | ||
|
||
func.func @nested_unroll() { | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c2 = arith.constant 2 : index | ||
scf.for %i = %c0 to %c2 step %c1 { | ||
scf.for %j = %c0 to %c2 step %c1 { | ||
"unregistered.loop_body"(%i, %j) : (index, index) -> () | ||
} {unroll_loop} | ||
} {unroll_loop} | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @nested_unroll | ||
// CHECK: "unregistered.loop_body"(%c0, %c0) | ||
// CHECK: "unregistered.loop_body"(%c0, %c1) | ||
// CHECK: "unregistered.loop_body"(%c1, %c0) | ||
// CHECK: "unregistered.loop_body"(%c1, %c1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters