diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 1abdda647364..0ec2d8cafa64 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -144,6 +144,7 @@ iree_compiler_cc_library( "TileDispatchUsingInterface.cpp", "TileSizeSelection.cpp", "TypePropagationPass.cpp", + "UnrollAnnotatedLoops.cpp", "UserConfig.cpp", "VectorizeMemrefCopy.cpp", ], diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 7f54a2980d52..428bb49353c2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -135,6 +135,7 @@ iree_cc_library( "TileDispatchUsingInterface.cpp" "TileSizeSelection.cpp" "TypePropagationPass.cpp" + "UnrollAnnotatedLoops.cpp" "UserConfig.cpp" "VectorizeMemrefCopy.cpp" DEPS diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 385a7eb89cc7..a9a208beade9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -544,6 +544,11 @@ def TypePropagationPass : let summary = "Propogate the type of tensor to avoid load/stores of illegal bit widths"; } +def UnrollAnnotatedLoopsPass : + InterfacePass<"iree-codegen-unroll-annotated-loops", "mlir::FunctionOpInterface"> { + let summary = "Unrolls all scf.for loops marked with `unroll_loop`"; +} + def VectorizeMemrefCopyPass : Pass<"iree-codegen-vectorize-memref-copy", ""> { let summary = "Vectorizes memref copy operations."; diff --git a/compiler/src/iree/compiler/Codegen/Common/UnrollAnnotatedLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/UnrollAnnotatedLoops.cpp new file mode 100644 index 000000000000..2ce3b2f3c40f --- /dev/null +++ 
b/compiler/src/iree/compiler/Codegen/Common/UnrollAnnotatedLoops.cpp @@ -0,0 +1,81 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Utils/MarkerUtils.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_UNROLLANNOTATEDLOOPSPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { + +/// Returns the trip count of `forOp` if its lower bound, upper bound and step +/// are constants with 0 <= lower <= upper and step > 0, or std::nullopt +/// otherwise. The trip count is ceilDiv(upperBound - lowerBound, step). +static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) { + std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound()); + std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound()); + std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep()); + if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value()) { + return std::nullopt; + } + + // Reject negative or reversed bounds so `ub - lb` below is non-negative. + if (*lbCstOp < 0 || *ubCstOp < *lbCstOp || *stepCstOp <= 0) { + return std::nullopt; + } + return llvm::divideCeil(*ubCstOp - *lbCstOp, *stepCstOp); +} + +/// Fully unrolls each scf.for carrying the unroll marker whose trip count is a +/// positive constant. The marker is removed from every marked loop either way. +struct UnrollAnnotatedLoopsPass final + : impl::UnrollAnnotatedLoopsPassBase<UnrollAnnotatedLoopsPass> { + void runOnOperation() override { + FunctionOpInterface funcOp = getOperation(); + + // Get the list of operations to unroll in post-order so that the inner + // most loops get unrolled before the outer most loops. + // (This is the default but set explicitly here because it's required). 
+ SmallVector<scf::ForOp> unrollTargets; + funcOp.walk<WalkOrder::PostOrder>([&](scf::ForOp forOp) { + if (getLoopUnrollMarker(forOp)) { + unrollTargets.push_back(forOp); + } + }); + + for (scf::ForOp forOp : unrollTargets) { + removeLoopUnrollMarker(forOp); + + std::optional<int64_t> maybeTripCount = getConstantTripCount(forOp); + if (maybeTripCount.value_or(0) <= 0) { + continue; + } + + (void)loopUnrollByFactor(forOp, *maybeTripCount); + } + + // Cleanup unrolled loops. + { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + scf::ForOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + funcOp->emitError("Failed to apply post unroll cleanup"); + return signalPassFailure(); + } + } + } +}; + +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index eba43b07c821..f0ce080ad3f0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -76,6 +76,7 @@ iree_lit_test_suite( "transpose_canonicalization.mlir", "type_propagation.mlir", "type_propagation_packing.mlir", + "unroll_annotated_loops.mlir", "vectorize_memref_copy.mlir", "vectorize_tensor_pad.mlir", "vector_layout_analysis.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index 751fcf190981..ba2b67bc7746 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -72,6 +72,7 @@ iree_lit_test_suite( "transpose_canonicalization.mlir" "type_propagation.mlir" "type_propagation_packing.mlir" + "unroll_annotated_loops.mlir" "vector_layout_analysis.mlir" "vectorize_memref_copy.mlir" "vectorize_tensor_pad.mlir" diff --git 
a/compiler/src/iree/compiler/Codegen/Common/test/unroll_annotated_loops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/unroll_annotated_loops.mlir new file mode 100644 index 000000000000..7659a8dc1274 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/unroll_annotated_loops.mlir @@ -0,0 +1,107 @@ +// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-unroll-annotated-loops))" \ +// RUN: --allow-unregistered-dialect | FileCheck %s + +func.func @basic_unroll() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + scf.for %i = %c0 to %c3 step %c1 { + "unregistered.loop_body"(%i) : (index) -> () + } {unroll_loop} + return +} + +// CHECK-LABEL: func.func @basic_unroll +// CHECK: "unregistered.loop_body"(%c0) +// CHECK: "unregistered.loop_body"(%c1) +// CHECK: "unregistered.loop_body"(%c2) + +// ----- + +func.func @no_annotation_no_unroll() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + scf.for %i = %c0 to %c3 step %c1 { + "unregistered.loop_body"(%i) : (index) -> () + } + return +} + +// CHECK-LABEL: func.func @no_annotation_no_unroll +// CHECK: scf.for +// CHECK: "unregistered.loop_body" + +// ----- + +func.func @no_unroll_dynamic_trip(%x: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.for %i = %c0 to %x step %c1 { + "unregistered.loop_body"(%i) : (index) -> () + } {unroll_loop} + return +} + +// CHECK-LABEL: func.func @no_unroll_dynamic_trip +// CHECK: scf.for +// CHECK: "unregistered.loop_body" +// CHECK-NOT: unroll_loop + +// ----- + +func.func @unroll_non_normalized() { + %c5 = arith.constant 5 : index + %c10 = arith.constant 10 : index + %c2 = arith.constant 2 : index + scf.for %i = %c5 to %c10 step %c2 { + "unregistered.loop_body"(%i) : (index) -> () + } {unroll_loop} + return +} + +// CHECK-LABEL: func.func @unroll_non_normalized +// CHECK: 
"unregistered.loop_body"(%c5) +// CHECK: "unregistered.loop_body"(%c7) +// CHECK: "unregistered.loop_body"(%c9) + +// ----- + +func.func @unroll_iter_arg() -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %init = arith.constant 1 : i32 + %0 = scf.for %i = %c0 to %c3 step %c1 iter_args(%it = %init) -> i32 { + %1 = "unregistered.loop_body"(%it) : (i32) -> (i32) + scf.yield %1 : i32 + } {unroll_loop} + return %0 : i32 +} + +// CHECK-LABEL: func.func @unroll_iter_arg +// CHECK: %[[INIT:.+]] = arith.constant 1 : i32 +// CHECK: %[[IT0:.+]] = "unregistered.loop_body"(%[[INIT]]) +// CHECK: %[[IT1:.+]] = "unregistered.loop_body"(%[[IT0]]) +// CHECK: %[[IT2:.+]] = "unregistered.loop_body"(%[[IT1]]) +// CHECK: return %[[IT2]] + +// ----- + +func.func @nested_unroll() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + scf.for %i = %c0 to %c2 step %c1 { + scf.for %j = %c0 to %c2 step %c1 { + "unregistered.loop_body"(%i, %j) : (index, index) -> () + } {unroll_loop} + } {unroll_loop} + return +} + +// CHECK-LABEL: func.func @nested_unroll +// CHECK: "unregistered.loop_body"(%c0, %c0) +// CHECK: "unregistered.loop_body"(%c0, %c1) +// CHECK: "unregistered.loop_body"(%c1, %c0) +// CHECK: "unregistered.loop_body"(%c1, %c1) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir index f512b11ce1ab..02f3d10cff4a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_forall.mlir @@ -59,6 +59,7 @@ module attributes { transform.with_named_sequence } { // CHECK: %[[COPY:.+]] = linalg.copy ins(%[[INSLICE0]] : tensor<2x128xf32>) outs(%[[INSLICE1]] : tensor<2x128xf32>) -> 
tensor<2x128xf32> // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[INID0]], %[[IDS]]#1] [2, 128] [1, 1] // CHECK: scf.yield %[[INSERT]] +// CHECK: } {unroll_loop} // CHECK: iree_gpu.yield %[[LOOP]] // CHECK: } : tensor<128x128xf32> @@ -125,6 +126,7 @@ module attributes { transform.with_named_sequence } { // CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[ALLOC]] : tensor<128x128xf32>) // CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %{{.*}}) // CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]] +// CHECK: unroll_loop // CHECK: } : tensor<128x128xf32> // CHECK: } {mapping = [#gpu.warp, #gpu.warp]} @@ -182,6 +184,7 @@ module attributes { transform.with_named_sequence } { // CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>): // CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[INTERMEDIATE]]) // CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]] +// CHECK: unroll_loop // CHECK: iree_gpu.yield %[[LOOP]] // CHECK: } : tensor<128x128xf32> // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[BARRIER]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128] @@ -253,6 +256,7 @@ module attributes { transform.with_named_sequence } { // CHECK: %[[COPY:.+]] = linalg.copy // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[IDX]], %[[IDS]]#1] [2, 128] // CHECK: scf.yield %[[INSERT]] +// CHECK: unroll_loop // CHECK: } : tensor<128x128xf32> // CHECK: } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]} @@ -309,6 +313,7 @@ module attributes { transform.with_named_sequence } { // CHECK: scf.for %[[I:.+]] = %[[LINEARID]] to %c32{{.*}} step %c64{{.*}} // CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%c32) : index // CHECK: scf.yield +// CHECK: unroll_loop // CHECK: } {mapping = [#gpu.thread, #gpu.thread]} // ----- @@ -364,4 +369,5 @@ module attributes { transform.with_named_sequence } { // CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to 
%[[PRODCOUNT]] step %c64{{.*}} // CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%[[Z]], %[[Y]], %[[X]]) : index // CHECK: scf.yield +// CHECK: unroll_loop // CHECK: } {mapping = [#gpu.thread, #gpu.thread]} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index ed1c10e533bb..161546f24c42 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -10,6 +10,7 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Utils/MarkerUtils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" @@ -234,6 +235,7 @@ LogicalResult fuseForallIntoConsumer(RewriterBase &rewriter, getValueOrCreateConstantIndexOp(rewriter, loc, consumerWorkerCount); auto newProducer = rewriter.create( loc, lb, ub, step, barrierOp.getBody()->getArgument(0)); + setLoopUnrollMarker(newProducer); Block *loopBody = newProducer.getBody(); // Get the replacement IDs for the producer loop. diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index d657985f798f..a2320c95a07b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -431,6 +431,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, // Step 9. Remaining post-bufferization optimizations/lowerings. 
funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass()); + funcPassManager.addPass(createUnrollAnnotatedLoopsPass()); funcPassManager.addPass(createLoopInvariantCodeMotionPass()); if (pipelineOptions.enableReduceSharedMemoryBankConflicts) { GPUReduceBankConflictsPassOptions options = {}; diff --git a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp index b7f3040a4294..9c2931c20ea7 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp @@ -142,4 +142,15 @@ void setMarker(Operation *op, StringRef marker) { StringAttr::get(op->getContext(), marker)); } +constexpr StringLiteral kUnrollLoopName = "unroll_loop"; +void setLoopUnrollMarker(Operation *op) { + op->setAttr(kUnrollLoopName, UnitAttr::get(op->getContext())); +} + +Attribute getLoopUnrollMarker(Operation *op) { + return op->getAttr(kUnrollLoopName); +} + +void removeLoopUnrollMarker(Operation *op) { op->removeAttr(kUnrollLoopName); } + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h index 40dc0012840e..f2c9a3fa80f6 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h @@ -4,10 +4,9 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===- MarkerUtils.h - Methods for manipulating markers on Linalg op ------===// +//===- MarkerUtils.h - Methods for manipulating transformation markers ----===// // -// Method that set markers on Linalg operations that determine which processor -// heirarchy to use for partitioning +// Methods that set markers on various operations that affect later transforms. 
// //===----------------------------------------------------------------------===// @@ -121,6 +120,13 @@ bool hasMarker(Operation *, ArrayRef markers = {}); /// Sets a given marker on an operation. void setMarker(Operation *, StringRef); +/// Markers for other operations. + +// Set, query, and remove the `unroll_loop` unit attribute on an operation. +void setLoopUnrollMarker(Operation *op); +Attribute getLoopUnrollMarker(Operation *op); +void removeLoopUnrollMarker(Operation *op); + } // namespace mlir::iree_compiler #endif // IREE_COMPILER_CODEGEN_CODEGENUTILS_MARKERUTILS_H_