Skip to content

Commit

Permalink
[Codegen] Add pass for unrolling annotated for loops (iree-org#18641)
Browse files Browse the repository at this point in the history
This allows `scf.for` loops formed in earlier compilation stages to be
annotated so that they are unrolled later on. The case this is used for today
is unrolling the loops produced by tiling producers of matmul operands
(typically a copy).

If a loop has a dynamic trip count, the attribute will be dropped
silently (unrolling is best effort).
  • Loading branch information
qedawkins authored Oct 1, 2024
1 parent 9c39a29 commit 451ef71
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 3 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ iree_compiler_cc_library(
"TileDispatchUsingInterface.cpp",
"TileSizeSelection.cpp",
"TypePropagationPass.cpp",
"UnrollAnnotatedLoops.cpp",
"UserConfig.cpp",
"VectorizeMemrefCopy.cpp",
],
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ iree_cc_library(
"TileDispatchUsingInterface.cpp"
"TileSizeSelection.cpp"
"TypePropagationPass.cpp"
"UnrollAnnotatedLoops.cpp"
"UserConfig.cpp"
"VectorizeMemrefCopy.cpp"
DEPS
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,11 @@ def TypePropagationPass :
let summary = "Propogate the type of tensor to avoid load/stores of illegal bit widths";
}

// Pass registered under the `iree-codegen-unroll-annotated-loops` argument and
// declared as an InterfacePass over `mlir::FunctionOpInterface`.
def UnrollAnnotatedLoopsPass :
InterfacePass<"iree-codegen-unroll-annotated-loops", "mlir::FunctionOpInterface"> {
let summary = "Unrolls all scf.for loops marked with `unroll_loop`";
}

def VectorizeMemrefCopyPass :
Pass<"iree-codegen-vectorize-memref-copy", ""> {
let summary = "Vectorizes memref copy operations.";
Expand Down
79 changes: 79 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/UnrollAnnotatedLoops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_UNROLLANNOTATEDLOOPSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Returns the constant trip count of `forOp` when its lower bound, upper
/// bound, and step are all constants, or std::nullopt otherwise.
///
/// std::nullopt is also returned when either bound is negative or the step is
/// not strictly positive. For supported loops the trip count is
/// ceilDiv(upperBound - lowerBound, step); an empty range
/// (upperBound <= lowerBound) is reported as zero iterations.
static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
  std::optional<int64_t> lbCst = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ubCst = getConstantIntValue(forOp.getUpperBound());
  std::optional<int64_t> stepCst = getConstantIntValue(forOp.getStep());
  if (!lbCst || !ubCst || !stepCst) {
    return std::nullopt;
  }

  // Only non-negative bounds and strictly positive steps are supported.
  if (*lbCst < 0 || *ubCst < 0 || *stepCst <= 0) {
    return std::nullopt;
  }

  // An empty range executes zero times. Handle it before dividing: the
  // difference would be negative, and llvm::divideCeil works on unsigned
  // values, so a negative numerator could come back as a huge bogus count.
  if (*ubCst <= *lbCst) {
    return 0;
  }
  return llvm::divideCeil(*ubCst - *lbCst, *stepCst);
}

/// Function-level pass that fully unrolls `scf.for` loops carrying the loop
/// unroll marker. The marker is removed from every marked loop; loops without
/// a positive constant trip count are otherwise left as they are.
struct UnrollAnnotatedLoopsPass final
    : impl::UnrollAnnotatedLoopsPassBase<UnrollAnnotatedLoopsPass> {
  void runOnOperation() override {
    FunctionOpInterface funcOp = getOperation();

    // Gather the marked loops first, innermost before outermost. Post-order
    // is the walker's default, but it is spelled out because the unrolling
    // below relies on inner loops being handled ahead of their parents.
    SmallVector<scf::ForOp> markedLoops;
    funcOp.walk<WalkOrder::PostOrder>([&markedLoops](scf::ForOp candidate) {
      if (!getLoopUnrollMarker(candidate)) {
        return;
      }
      markedLoops.push_back(candidate);
    });

    for (scf::ForOp loop : markedLoops) {
      // The marker is consumed whether or not the loop ends up unrolled.
      removeLoopUnrollMarker(loop);

      std::optional<int64_t> tripCount = getConstantTripCount(loop);
      const bool canUnroll = tripCount.has_value() && *tripCount > 0;
      if (!canUnroll) {
        continue;
      }

      // Unroll by the full trip count; the result is deliberately ignored.
      (void)loopUnrollByFactor(loop, *tripCount);
    }

    // Run the scf.for canonicalization patterns to clean up after unrolling.
    MLIRContext *context = &getContext();
    RewritePatternSet cleanupPatterns(context);
    scf::ForOp::getCanonicalizationPatterns(cleanupPatterns, context);
    if (failed(applyPatternsAndFoldGreedily(funcOp,
                                            std::move(cleanupPatterns)))) {
      funcOp->emitError("Failed to apply post unroll cleanup");
      signalPassFailure();
    }
  }
};

} // namespace
} // namespace mlir::iree_compiler
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ iree_lit_test_suite(
"transpose_canonicalization.mlir",
"type_propagation.mlir",
"type_propagation_packing.mlir",
"unroll_annotated_loops.mlir",
"vectorize_memref_copy.mlir",
"vectorize_tensor_pad.mlir",
"vector_layout_analysis.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ iree_lit_test_suite(
"transpose_canonicalization.mlir"
"type_propagation.mlir"
"type_propagation_packing.mlir"
"unroll_annotated_loops.mlir"
"vector_layout_analysis.mlir"
"vectorize_memref_copy.mlir"
"vectorize_tensor_pad.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-unroll-annotated-loops))" \
// RUN: --allow-unregistered-dialect | FileCheck %s

// Annotated loop with constant bounds (0 to 3, step 1); the checks below
// expect one copy of the body per iteration.
func.func @basic_unroll() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
scf.for %i = %c0 to %c3 step %c1 {
"unregistered.loop_body"(%i) : (index) -> ()
} {unroll_loop}
return
}

// CHECK-LABEL: func.func @basic_unroll
// CHECK: "unregistered.loop_body"(%c0)
// CHECK: "unregistered.loop_body"(%c1)
// CHECK: "unregistered.loop_body"(%c2)

// -----

// Same constant-bound loop as the basic case but without the unroll
// annotation; the checks below expect the loop to remain.
func.func @no_annotation_no_unroll() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
scf.for %i = %c0 to %c3 step %c1 {
"unregistered.loop_body"(%i) : (index) -> ()
}
return
}

// CHECK-LABEL: func.func @no_annotation_no_unroll
// CHECK: scf.for
// CHECK: "unregistered.loop_body"

// -----

// Annotated loop whose upper bound is a function argument (dynamic trip
// count); the checks below expect the loop to remain and the annotation to be
// gone.
func.func @no_unroll_dynamic_trip(%x: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %i = %c0 to %x step %c1 {
"unregistered.loop_body"(%i) : (index) -> ()
} {unroll_loop}
return
}

// CHECK-LABEL: func.func @no_unroll_dynamic_trip
// CHECK: scf.for
// CHECK: "unregistered.loop_body"
// CHECK-NOT: unroll_loop

// -----

// Annotated loop with a non-zero lower bound and non-unit step (5 to 10,
// step 2); the checks below expect bodies for induction values 5, 7, and 9.
func.func @unroll_non_normalized() {
%c5 = arith.constant 5 : index
%c10 = arith.constant 10 : index
%c2 = arith.constant 2 : index
scf.for %i = %c5 to %c10 step %c2 {
"unregistered.loop_body"(%i) : (index) -> ()
} {unroll_loop}
return
}

// CHECK-LABEL: func.func @unroll_non_normalized
// CHECK: "unregistered.loop_body"(%c5)
// CHECK: "unregistered.loop_body"(%c7)
// CHECK: "unregistered.loop_body"(%c9)

// -----

// Annotated loop carrying an iter_arg; the checks below expect each unrolled
// body to consume the previous body's result, with the last one returned.
func.func @unroll_iter_arg() -> i32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%init = arith.constant 1 : i32
%0 = scf.for %i = %c0 to %c3 step %c1 iter_args(%it = %init) -> i32 {
%1 = "unregistered.loop_body"(%it) : (i32) -> (i32)
scf.yield %1 : i32
} {unroll_loop}
return %0 : i32
}

// CHECK-LABEL: func.func @unroll_iter_arg
// CHECK: %[[INIT:.+]] = arith.constant 1 : i32
// CHECK: %[[IT0:.+]] = "unregistered.loop_body"(%[[INIT]])
// CHECK: %[[IT1:.+]] = "unregistered.loop_body"(%[[IT0]])
// CHECK: %[[IT2:.+]] = "unregistered.loop_body"(%[[IT1]])
// CHECK: return %[[IT2]]

// -----

// Two nested annotated loops, each running 0 to 2 with step 1; the checks
// below expect four bodies covering every (i, j) pair.
func.func @nested_unroll() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
scf.for %i = %c0 to %c2 step %c1 {
scf.for %j = %c0 to %c2 step %c1 {
"unregistered.loop_body"(%i, %j) : (index, index) -> ()
} {unroll_loop}
} {unroll_loop}
return
}

// CHECK-LABEL: func.func @nested_unroll
// CHECK: "unregistered.loop_body"(%c0, %c0)
// CHECK: "unregistered.loop_body"(%c0, %c1)
// CHECK: "unregistered.loop_body"(%c1, %c0)
// CHECK: "unregistered.loop_body"(%c1, %c1)
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> tensor<2x128xf32>
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1]
// CHECK: scf.yield %[[INSERT]]
// CHECK: } {unroll_loop}
// CHECK: iree_gpu.yield %[[LOOP]]
// CHECK: } : tensor<128x128xf32>

Expand Down Expand Up @@ -125,6 +126,7 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[ALLOC]] : tensor<128x128xf32>)
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %{{.*}})
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
// CHECK: unroll_loop
// CHECK: } : tensor<128x128xf32>
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}

Expand Down Expand Up @@ -182,6 +184,7 @@ module attributes { transform.with_named_sequence } {
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[INTERMEDIATE]])
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
// CHECK: unroll_loop
// CHECK: iree_gpu.yield %[[LOOP]]
// CHECK: } : tensor<128x128xf32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[BARRIER]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128]
Expand Down Expand Up @@ -253,6 +256,7 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[COPY:.+]] = linalg.copy
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[IDX]], %[[IDS]]#1] [2, 128]
// CHECK: scf.yield %[[INSERT]]
// CHECK: unroll_loop
// CHECK: } : tensor<128x128xf32>

// CHECK: } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]}
Expand Down Expand Up @@ -309,6 +313,7 @@ module attributes { transform.with_named_sequence } {
// CHECK: scf.for %[[I:.+]] = %[[LINEARID]] to %c32{{.*}} step %c64{{.*}}
// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%c32) : index
// CHECK: scf.yield
// CHECK: unroll_loop
// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}

// -----
Expand Down Expand Up @@ -364,4 +369,5 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %[[PRODCOUNT]] step %c64{{.*}}
// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%[[Z]], %[[Y]], %[[X]]) : index
// CHECK: scf.yield
// CHECK: unroll_loop
// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -234,6 +235,7 @@ LogicalResult fuseForallIntoConsumer(RewriterBase &rewriter,
getValueOrCreateConstantIndexOp(rewriter, loc, consumerWorkerCount);
auto newProducer = rewriter.create<scf::ForOp>(
loc, lb, ub, step, barrierOp.getBody()->getArgument(0));
setLoopUnrollMarker(newProducer);
Block *loopBody = newProducer.getBody();

// Get the replacement IDs for the producer loop.
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,

// Step 9. Remaining post-bufferization optimizations/lowerings.
funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
funcPassManager.addPass(createUnrollAnnotatedLoopsPass());
funcPassManager.addPass(createLoopInvariantCodeMotionPass());
if (pipelineOptions.enableReduceSharedMemoryBankConflicts) {
GPUReduceBankConflictsPassOptions options = {};
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,15 @@ void setMarker(Operation *op, StringRef marker) {
StringAttr::get(op->getContext(), marker));
}

/// Name of the unit attribute used to mark a loop for unrolling.
constexpr StringLiteral kUnrollLoopName = "unroll_loop";

/// Tags `op` with the loop unroll marker, stored as a unit attribute.
void setLoopUnrollMarker(Operation *op) {
  UnitAttr marker = UnitAttr::get(op->getContext());
  op->setAttr(kUnrollLoopName, marker);
}

/// Looks up the loop unroll marker on `op` and returns the attribute found
/// under that name.
Attribute getLoopUnrollMarker(Operation *op) {
  Attribute marker = op->getAttr(kUnrollLoopName);
  return marker;
}

/// Strips the loop unroll marker from `op`.
void removeLoopUnrollMarker(Operation *op) {
  op->removeAttr(kUnrollLoopName);
}

} // namespace mlir::iree_compiler
12 changes: 9 additions & 3 deletions compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===- MarkerUtils.h - Methods for manipulating markers on Linalg op ------===//
//===- MarkerUtils.h - Methods for manipulating transformation markers ----===//
//
// Method that set markers on Linalg operations that determine which processor
// heirarchy to use for partitioning
// Method that set markers on various operations that affect later transforms.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -121,6 +120,13 @@ bool hasMarker(Operation *, ArrayRef<StringRef> markers = {});
/// Sets a given marker on an operation.
void setMarker(Operation *, StringRef);

/// Markers for other operations.

/// Marks `op` as a loop to be unrolled.
void setLoopUnrollMarker(Operation *op);
/// Returns the loop unroll marker attribute looked up on `op`.
Attribute getLoopUnrollMarker(Operation *op);
/// Removes the loop unroll marker from `op`.
void removeLoopUnrollMarker(Operation *op);

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_CODEGEN_CODEGENUTILS_MARKERUTILS_H_

0 comments on commit 451ef71

Please sign in to comment.