From 9ee061d2ec366e955f2a348225b4051a31f8f244 Mon Sep 17 00:00:00 2001
From: rohan-tan-bhowmik <46410002+rohan-tan-bhowmik@users.noreply.github.com>
Date: Sat, 21 Sep 2024 23:07:18 -0700
Subject: [PATCH] [LinalgExt] Masked Attention Implementation (#18525)

Enables float/boolean masks as parameters and creates linalg.generic ops to apply the masking.

This image (https://imgur.com/a/1MePgcy) elaborates on the main files changed and how they enable masked attention:

- Blue boxes represent changed .cpp and .td files that enable/pass/decompose the mask
- Yellow boxes represent the different op classes
- Red boxes represent test mlir files pertaining to certain .cpp/.td implementations or ops

For quick reference, AggregatedOpInterfaceImpl.cpp contains the bulk of the actual mask decomposition (QK += mask). For clarification, TileAttention.cpp only holds the convertToOnlineAttention and getTileAttentionIndexingMaps functions; TilingInterfaceImpl.cpp contains the main tiling capabilities in the form of AttentionOp::getTiledImplementation and OnlineAttentionOp::getTiledImplementation.

Updated version of https://github.com/iree-org/iree/pull/18461. This version was created to include the scale affine map and to enable fused attention (incorporating https://github.com/IanWood1/iree/tree/raikonen/sdpa_mask).

- To that end, many of the modifications in the tests just add the scale affine map (without much functional change)
- For tiling and decomposition, most functionality tests live in "tiling.mlir" and "decompose_online_attention.mlir". On the other hand, "tile_attention.mlir" and "decompose_attention.mlir" exercise old paths intended to be retired and deprecated soon, so no major tests were added there.

Test directory for numerical verification: https://github.com/rohan-tan-bhowmik/iree-masked-attention-test

---------

Signed-off-by: Stanley Winata
Co-authored-by: Stanley Winata
Co-authored-by: Ian Wood
---
 .../ConvertTMTensorToLinalgExt.cpp            |  34 +-
 .../Torch/InputConversion/test/attention.mlir |  18 +-
 .../Common/GPU/test/gpu_tensor_alloc.mlir     |   3 +-
 .../test/select_x86_64_lowering_strategy.mlir |   1 +
 .../test/ROCDL/config_vector_distribute.mlir  |   2 +
 .../pipeline_vector_distribute_gfx940.mlir    |   5 +-
 .../Codegen/LLVMGPU/test/attention.mlir       |   1 +
 .../Codegen/LLVMGPU/test/attention_mfma.mlir  |   1 +
 .../LinalgExt/IR/LinalgExtInterfaces.td       |   6 +-
 .../Dialect/LinalgExt/IR/LinalgExtOps.cpp     | 102 +++++-
 .../Dialect/LinalgExt/IR/LinalgExtOps.td      |  84 ++++-
 .../Dialect/LinalgExt/IR/test/invalid.mlir    |  27 +-
 .../Dialect/LinalgExt/IR/test/roundtrip.mlir  |  16 +-
 .../Transforms/AggregatedOpInterfaceImpl.cpp  |  62 +++-
 .../LinalgExt/Transforms/ReshapeFusion.cpp    |  14 +-
 .../LinalgExt/Transforms/TileAttention.cpp    |   9 +-
 .../Transforms/TilingInterfaceImpl.cpp        |  40 ++-
 .../test/convert_to_online_attention.mlir     |   5 +-
 .../Transforms/test/decompose_attention.mlir  |   4 +
 .../test/decompose_online_attention.mlir      |  70 +++-
 .../Transforms/test/tile_attention.mlir       |   4 +
 .../LinalgExt/Transforms/test/tiling.mlir     | 307 +++++++++++++++++-
 .../Dialect/LinalgExt/Utils/IndexingUtils.cpp |   5 +-
 .../test/attention_fuse_by_expansion.mlir     | 248 +++++++++++++-
 ...clone_producers_into_dispatch_regions.mlir |  97 +++++-
 .../test/dispatch_linalg_ext_fusion.mlir      |  48 ++-
 .../DispatchCreation/test/fold_transpose.mlir |  51 ++-
 .../test/fold_attention_with_transpose.mlir   |  67 ++--
 .../attention/generate_e2e_attention_tests.py |   1 +
 tests/e2e/linalg_ext_ops/attention.mlir       |   3 +
 30 files changed, 1211 insertions(+), 124 deletions(-)
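For orientation before the diff itself, here is a minimal sketch of the new op form (hand-written for illustration; the function name, shapes, and the f16 mask type are assumptions, not taken from this patch's tests). The mask rides along as an extra `ins` operand after the scale, and its indexing map sits between the scale map and the output map:

#mQ = affine_map<(m, k1, k2, n) -> (m, k1)>
#mK = affine_map<(m, k1, k2, n) -> (k2, k1)>
#mV = affine_map<(m, k1, k2, n) -> (k2, n)>
#mS = affine_map<(m, k1, k2, n) -> ()>
#mM = affine_map<(m, k1, k2, n) -> (m, k2)>
#mO = affine_map<(m, k1, k2, n) -> (m, n)>
func.func @masked_attention_sketch(%q: tensor<1024x64xf16>, %k: tensor<1024x64xf16>,
                                   %v: tensor<1024x64xf16>, %mask: tensor<1024x1024xf16>)
    -> tensor<1024x64xf16> {
  %scale = arith.constant 0.125 : f16
  %acc = tensor.empty() : tensor<1024x64xf16>
  // Mask is the optional fifth input; QK += mask happens before the softmax.
  %out = iree_linalg_ext.attention
      {indexing_maps = [#mQ, #mK, #mV, #mS, #mM, #mO]}
      ins(%q, %k, %v, %scale, %mask
          : tensor<1024x64xf16>, tensor<1024x64xf16>, tensor<1024x64xf16>, f16, tensor<1024x1024xf16>)
      outs(%acc : tensor<1024x64xf16>) -> tensor<1024x64xf16>
  return %out : tensor<1024x64xf16>
}

A float mask is added to Q @ K.T directly, while a boolean mask selects between adding 0 (keep) and -infinity (drop).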
diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp index d775c1ccad5b..b27193d25053 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp @@ -73,21 +73,22 @@ struct ScatterOpConversion }; } // namespace -static SmallVector<AffineMap> -getStandardAttentionIndexingMaps(MLIRContext *ctx) { +static SmallVector<AffineMap> getStandardAttentionIndexingMaps(MLIRContext *ctx, + bool hasMask) { AffineExpr m, n, k1, k2; bindDims(ctx, m, n, k1, k2); - AffineMap qMap = - AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx); - AffineMap kMap = - AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx); - AffineMap vMap = - AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx); - AffineMap rMap = - AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx); - - return {qMap, kMap, vMap, rMap}; + auto qMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx); + auto kMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx); + auto vMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx); + auto sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, ctx); + auto rMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx); + if (hasMask) { + // Add mask map only if it exists + auto mMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k2}, ctx); + return {qMap, kMap, vMap, sMap, mMap, rMap}; + } + return {qMap, kMap, vMap, sMap, rMap}; } struct AttentionOpConversion @@ -100,6 +101,7 @@ struct AttentionOpConversion Value query = op.getQuery(); Value key = op.getKey(); Value value = op.getValue(); + std::optional<Value> optionalMask = op.getAttnMask(); ShapedType outputType = op.getOutputType(); @@ -147,10 +149,14 @@ loc, targetType, rewriter.getFloatAttr(targetType, dk)); // Add batches to standard attention indexing maps. - SmallVector<AffineMap> indexingMaps = getStandardAttentionIndexingMaps(ctx); + SmallVector<AffineMap> indexingMaps = + getStandardAttentionIndexingMaps(ctx, optionalMask.has_value()); + int64_t numBatches = op.getQueryType().getRank() - 2; for (AffineMap &map : indexingMaps) { map = map.shiftDims(numBatches); + if (map.getNumResults() == 0) + continue; for (int batch : llvm::seq<int>(numBatches)) { map = map.insertResult(rewriter.getAffineDimExpr(batch), batch); } @@ -158,7 +164,7 @@ auto attention = rewriter.create<IREE::LinalgExt::AttentionOp>( loc, result.getType(), query, key, value, scale, result, - rewriter.getAffineMapArrayAttr(indexingMaps)); + rewriter.getAffineMapArrayAttr(indexingMaps), optionalMask); rewriter.replaceOp(op, attention.getResult(0)); return success();
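To make the batch handling above concrete, a small hand-worked illustration (single-batch case, written out here for exposition rather than taken from the tests): shiftDims moves every dim up, insertResult then prepends the batch dims to each map's results, and the new early-continue leaves the rank-0 scale map untouched:

// Before adding batches (domain m, n, k1, k2; so m = d0, k1 = d2):
//   qMap: affine_map<(d0, d1, d2, d3) -> (d0, d2)>
//   sMap: affine_map<(d0, d1, d2, d3) -> ()>
// After map.shiftDims(1) plus insertResult of the batch dim d0:
//   qMap: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
//   sMap: affine_map<(d0, d1, d2, d3, d4) -> ()>   // no results to shift or insert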
diff --git a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir index 06e85e753a55..0de566e823ee 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir @@ -8,14 +8,15 @@ func.func @attention(%arg0: tensor<5x2x3x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)> // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)> +// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()> // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> // CHECK-LABEL: func.func @attention( // CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x3x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>, -// CHECK: %arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> { +// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> { // CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x3x4xf32> -// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> // CHECK: return %[[ATTN]] : tensor<5x2x3x4xf32> // ----- @@ -27,14 +28,15 @@ func.func @attention(%arg0: tensor<5x2x8x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)> // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)> +// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()> // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> // CHECK-LABEL: func.func @attention( // CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x8x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>, -// CHECK: %arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> { +// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> { // CHECK:
%[[SCALE:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x8x4xf32> -// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> // CHECK: return %[[ATTN]] : tensor<5x2x8x4xf32> // ----- @@ -46,14 +48,15 @@ func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2: // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)> // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)> +// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> // CHECK-LABEL: func.func @attention( // CHECK-SAME: %[[ARG0:.*]]: tensor<1x3x4xf32>, %[[ARG1:.*]]: tensor<1x3x4xf32>, %[[ARG2:.*]]: tensor<1x3x4xf32>, -// CHECK: %arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> { +// CHECK: %[[ARG3:.*]]: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> { // CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32> -// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32> +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32> // CHECK: return %[[ATTN]] : tensor<1x3x4xf32> // ----- @@ -65,6 +68,7 @@ func.func @attention_dyn(%arg0: tensor, %arg1: tensor, %ar // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)> // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)> +// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> // CHECK-LABEL: func.func @attention_dyn( @@ -76,5 +80,5 @@ func.func @attention_dyn(%arg0: tensor, %arg1: tensor, %ar // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor -// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor, tensor, tensor, f32) outs(%[[EMPTY]] : tensor) -> tensor +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor, tensor, 
tensor, f32) outs(%[[EMPTY]] : tensor) -> tensor // CHECK: return %[[ATTN]] : tensor diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir index 7598316bfc18..4010ae855aa7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tensor_alloc.mlir @@ -257,6 +257,7 @@ func.func @conv() attributes {translation_info = #iree_codegen.translation_info< #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> @@ -272,7 +273,7 @@ func.func @online_attention(%query: tensor<192x1024x64xf16>, %scale = arith.constant 1.0 : f16 %out:3 = iree_linalg_ext.online_attention - { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR], + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR], lowering_config = #config } ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir index e67e7afb8e97..6533f238e0d7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir @@ -1747,6 +1747,7 @@ func.func @attention() attributes {hal.executable.target = #executable_target_em %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%4, %5, %6, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir index 5ebd5c3fe18b..227ae1e689d1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir @@ -369,6 +369,7 @@ func.func @attention_20x4096x64x4096x64() { %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor> @@ -407,6 +408,7 @@ func.func @attention_large_head_dim_shared_mem() { %8 = 
iree_linalg_ext.attention {indexing_maps = [affine_map<(d1, d2, d3, d4) -> (d1, d2)>, affine_map<(d1, d2, d3, d4) -> (d3, d2)>, affine_map<(d1, d2, d3, d4) -> (d3, d4)>, + affine_map<(d1, d2, d3, d4) -> ()>, affine_map<(d1, d2, d3, d4) -> (d1, d4)>]} ins(%4, %5, %6, %cst : tensor<1024x512xf16>, tensor<128x512xf16>, tensor<128x512xf16>, f16) outs(%7 : tensor<1024x512xf16>) -> tensor<1024x512xf16> flow.dispatch.tensor.store %8, %3, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : tensor<1024x512xf16> -> !flow.dispatch.tensor> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index eb8f4f177396..6ebf8aba5c14 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -574,6 +574,7 @@ hal.executable private @attention_20x4096x64x4096x64 { %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> @@ -637,7 +638,7 @@ hal.executable private @attention_multiple_m_transpose { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16> + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16> %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) { ^bb0(%in: f16, %out: f16): linalg.yield %in : f16 @@ -692,7 +693,7 @@ hal.executable private @attention_mfma_32x32x8 { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = 
iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16> + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16> %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) { ^bb0(%in: f16, %out: f16): linalg.yield %in : f16 diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir index 649cbd9b65c0..3f2c090d41f8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir @@ -22,6 +22,7 @@ func.func @_attention_dispatch_0() { %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16> flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir index 13a84ec704f0..109b107dce04 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir @@ -22,6 +22,7 @@ func.func @attention_dispatch_0_attention_16x16384x128xf16() { %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16> flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td index 
596a925a5101..9607c9ef781a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td @@ -149,11 +149,7 @@ def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface"> { /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opOperand->getOwner() == $_op); - if(opOperand->getOperandNumber() >= $_op.getNumDpsInputs()){ - return $_op.getIndexingMapsForResults()[opOperand->getOperandNumber() - $_op.getNumDpsInputs()]; -}else { - return $_op.getIndexingMapsForOperands()[opOperand->getOperandNumber()]; - } + return getIndexingMapsArray()[opOperand->getOperandNumber()]; }] >, diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 180559758062..e08838b79dec 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -1200,6 +1200,15 @@ LogicalResult WinogradOutputTransformOp::reifyResultShapes( // AttentionOp //===----------------------------------------------------------------------===// +void AttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState, + TypeRange results, Value query, Value key, Value value, + Value scale, ValueRange outputs, ArrayAttr indexingMaps, + std::optional<Value> mask) { + Value maskIn = mask.value_or(Value()); + build(odsBuilder, odsState, results, query, key, value, scale, maskIn, + outputs, indexingMaps); +} + LogicalResult AttentionOp::verify() { AttentionOp attnOp = *this; @@ -1212,6 +1221,9 @@ LogicalResult AttentionOp::verify() { // Check if indexing maps can represent attention. SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray(); + if (indexingMaps.size() != getOperation()->getNumOperands()) { + return attnOp->emitOpError("expected an indexing map for each operand"); + } FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(indexingMaps); if (failed(maybeOpInfo)) { @@ -1246,8 +1258,8 @@ } if (shape[pos] != valShape[i]) { return attnOp->emitError("Shape Mismatch for ") - << operandName << ". Expected: " << shape[pos] - << " Got: " << valShape[i]; + << operandName << " at position " << i + << ". Expected: " << shape[pos] << " Got: " << valShape[i]; } } return success(); @@ -1261,6 +1273,38 @@ return failure(); } + // Additional check in case a mask exists + if (auto maskMap = getMaskMap()) { + if (failed(checkShape("Mask", getMaskType()->getShape(), *maskMap))) + return failure(); + } + + int expectedSymbols = getQueryMap().getNumInputs(); + auto checkDomain = + [&attnOp, &expectedSymbols](StringRef operandName, + AffineMap indexingMap) -> LogicalResult { + if (expectedSymbols != indexingMap.getNumInputs()) { + return attnOp->emitError("Mismatched map domain for ") + << operandName << ". Expected: " << expectedSymbols + << " Got: " << indexingMap.getNumInputs(); + } + return success(); + }; + + if (failed(checkDomain("Query", getQueryMap())) || + failed(checkDomain("Key", getKeyMap())) || + failed(checkDomain("Value", getValueMap())) || + failed(checkDomain("Scale", getScaleMap())) || + failed(checkDomain("Output", getOutputMap()))) { + return failure(); + } + + // Additional check in case a mask exists + if (auto maskMap = getMaskMap()) { + if (failed(checkDomain("Mask", *maskMap))) + return failure(); + } + if (isTiled) { // Tiled/Flash attention.
Type maxElementType = getMaxType()->getElementType(); @@ -1324,20 +1368,29 @@ SmallVector<int64_t> AttentionOp::getStaticLoopRanges() { SmallVector<AffineMap> AttentionOp::getIndexingMapsForOperands() { auto maps = getIndexingMapsArray(); - return SmallVector<AffineMap>(maps.begin(), - maps.begin() + getNumDpsInputs() - 1); + maps.resize(getNumDpsInputs()); + return maps; } SmallVector<AffineMap> AttentionOp::getIndexingMapsForResults() { auto maps = getIndexingMapsArray(); - return SmallVector<AffineMap>(maps.begin() + getNumDpsInputs() - 1, - maps.end()); + return SmallVector<AffineMap>(maps.begin() + getNumDpsInputs(), maps.end()); } //===----------------------------------------------------------------------===// // OnlineAttentionOp //===----------------------------------------------------------------------===// +void OnlineAttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState, + TypeRange results, Value query, Value key, + Value value, Value scale, Value output, Value max, + Value sum, ArrayAttr indexingMaps, + std::optional<Value> mask) { + Value maskIn = mask.value_or(Value()); + build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output, + max, sum, indexingMaps); +} + LogicalResult OnlineAttentionOp::verify() { OnlineAttentionOp attnOp = *this; @@ -1389,11 +1442,46 @@ LogicalResult OnlineAttentionOp::verify() { return failure(); } + // Additional check in case a mask exists + if (auto maskMap = getMaskMap()) { + if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) + return failure(); + } + + int expectedSymbols = getQueryMap().getNumInputs(); + auto checkDomain = + [&attnOp, &expectedSymbols](StringRef operandName, + AffineMap indexingMap) -> LogicalResult { + if (expectedSymbols != indexingMap.getNumInputs()) { + return attnOp->emitError("Mismatched map domain for ") + << operandName << ". Expected: " << expectedSymbols + << " Got: " << indexingMap.getNumInputs(); + } + return success(); + }; + + if (failed(checkDomain("Query", getQueryMap())) || + failed(checkDomain("Key", getKeyMap())) || + failed(checkDomain("Value", getValueMap())) || + failed(checkDomain("Scale", getScaleMap())) || + failed(checkDomain("Output", getOutputMap())) || + failed(checkDomain("Max", getMaxMap())) || + failed(checkDomain("Sum", getSumMap()))) { + return failure(); + } + + // Additional check in case a mask exists + if (auto maskMap = getMaskMap()) { + if (failed(checkDomain("Mask", *maskMap))) + return failure(); + } + return success(); } MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() { - return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3); + return MutableOperandRange(*this, /*numInputs=*/getMask() ?
5 : 4, /*numInits=*/3); } LogicalResult OnlineAttentionOp::reifyResultShapes( diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index c29245269530..c642fd02fadb 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -461,7 +461,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", "getLoopIteratorTypes", "getResultTilePosition", "getTiledImplementation", - "generateResultTileValue"]>]> { + "generateResultTileValue"]>, AttrSizedOperandSegments]> { let summary = "Attention operator"; let description = [{ Computes the scaled dot product attention function: Here Q, K, V are given tensors and scale is a scalar value specifying the scale to use. + If an additional mask argument M is included, the result of the first matmul is modified according to: + + Q @ K.T += M + For self-attention, all inputs and the result have the same shape BxNxd where B is the batch dimension, N is the sequence length and d is head dimension. Typically N >>> d. Usually, this operator also performs @@ -495,6 +499,7 @@ AnyShaped:$key, AnyShaped:$value, AnyFloat:$scale, + Optional<AnyShaped>:$mask, Variadic<AnyShaped>:$outputs, AffineMapArrayAttr:$indexing_maps ); let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ attr-dict - `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)` + `ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)? `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )? `)` `outs` `(` $outputs `:` type($outputs) `)` (`->` type($results)^)?
}]; + let builders = [ + OpBuilder<(ins "TypeRange":$results, + "Value":$query, + "Value":$key, + "Value":$value, + "Value":$scale, + "ValueRange":$outputs, + "ArrayAttr":$indexing_maps, + CArg<"std::optional<Value>", "std::nullopt">:$mask)> + ]; + let extraClassDeclaration = [{ // Method to implement for specifying output range for // DestinationStyleOpInterface @@ -530,9 +546,18 @@ AffineMap getValueMap() { return cast<AffineMap>(getIndexingMapsArray()[2]); } - AffineMap getOutputMap() { + AffineMap getScaleMap() { return cast<AffineMap>(getIndexingMapsArray()[3]); } + std::optional<AffineMap> getMaskMap() { + if (getMask()) { + return cast<AffineMap>(getIndexingMapsArray()[4]); + } + return std::nullopt; + } + AffineMap getOutputMap() { + return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]); + } int64_t getIterationDomainRank() { return getQueryMap().getNumDims(); } @@ -545,6 +570,11 @@ ShapedType getValueType() { return cast<ShapedType>(getValue().getType()); } + std::optional<ShapedType> getMaskType() { + std::optional<Value> mask = getMask(); + if (!mask) return std::nullopt; + return cast<ShapedType>(mask->getType()); + } FloatType getScaleType() { return cast<FloatType>(getScale().getType()); } @@ -559,12 +589,12 @@ std::optional<AffineMap> getMaxMap() { if (getNumResults() < 2) return std::nullopt; - return cast<AffineMap>(getIndexingMapsArray()[4]); + return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 1]); } std::optional<AffineMap> getSumMap() { if (getNumResults() < 3) return std::nullopt; - return cast<AffineMap>(getIndexingMapsArray()[5]); + return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 2]); } Value getOutput() { return getDpsInitOperand(0)->get(); } @@ -599,6 +629,12 @@ int64_t getValueRank() { return getValueType().getRank(); } + std::optional<int64_t> getMaskRank() { + std::optional<ShapedType> maskType = getMaskType(); + if (!maskType) + return std::nullopt; + return maskType->getRank(); + } int64_t getOutputRank() { return getOutputType().getRank(); } @@ -650,8 +686,12 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", online_attention(Q, K, V, scale, running_max, running_sum) = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V + If an additional mask argument M is included, the result of the first matmul is modified according to: + + Q @ K.T += M + The advantage of this online_normalizer is that it can be tiled along - it's reduction dimension, making the online_attention operator: + its reduction dimension, making the online_attention operator: - Tilable along softmax reduction dimension - Associative along softmax reduction dimension - Commutative along softmax associative dimension @@ -666,6 +706,7 @@ AnyShaped:$key, AnyShaped:$value, AnyFloat:$scale, + Optional<AnyShaped>:$mask, AnyShaped:$output, AnyShaped:$max, AnyShaped:$sum, @@ -677,11 +718,25 @@ let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ attr-dict - `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)` + `ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)?
`:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )?`)` `outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)` (`->` type($results)^)? }]; + let builders = [ + OpBuilder<(ins "TypeRange":$results, + "Value":$query, + "Value":$key, + "Value":$value, + "Value":$scale, + "Value":$output, + "Value":$max, + "Value":$sum, + "ArrayAttr":$indexing_maps, + CArg<"std::optional<Value>", "std::nullopt">:$mask)> + ]; + + let extraClassDeclaration = [{ // Method to implement for specifying output range for // DestinationStyleOpInterface @@ -698,14 +753,23 @@ AffineMap getValueMap() { return getIndexingMapsArray()[2]; } - AffineMap getOutputMap() { + AffineMap getScaleMap() { return getIndexingMapsArray()[3]; } + std::optional<AffineMap> getMaskMap() { + if (getMask()) { + return cast<AffineMap>(getIndexingMapsArray()[4]); + } + return std::nullopt; + } + AffineMap getOutputMap() { + return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]); + } AffineMap getMaxMap() { - return getIndexingMapsArray()[4]; + return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 1]); } AffineMap getSumMap() { - return getIndexingMapsArray()[5]; + return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 2]); } int64_t getIterationDomainRank() { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir index 2f8d8efe9ded..b3d5152e36c8 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir @@ -712,6 +712,7 @@ func.func @illegal_attention_inputs(%query: tensor<6x12x20x8xf32>, %key: tensor< %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%query, %key, %value, %scale : tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, f32) outs(%0 : tensor<6x12x20x8xf32>) -> tensor<6x12x20x8xf32> return %1 : tensor<6x12x20x8xf32> @@ -728,6 +729,9 @@ func.func @illegal_flash_attention_inputs(%query: tensor<20xf32>, %key: tensor<2 %1:3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3)>, + affine_map<(d0, d1, d2, d3) -> ()>, + affine_map<(d0, d1, d2, d3) -> (d0)>, + affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>]} ins(%query, %key, %value, %scale : tensor<20xf32>, tensor<20x8xf32>, tensor<20x8xf32>, f32) outs(%result, %max, %sum : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32> return %1#0, %1#1, %1#2 : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32> @@ -738,10 +742,11 @@ func.func @illegal_attention_inputs(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: f32) -> tensor<192x1024x64xf32> { %0 = tensor.empty() : tensor<192x1024x64xf32> %scale = arith.constant 1.0 : f32 - // expected-error @+5 {{custom op 'iree_linalg_ext.attention' invalid kind of type specified}} + // expected-error @+6 {{custom op 'iree_linalg_ext.attention' invalid kind of type specified}} %1 =
iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> @@ -749,12 +754,28 @@ func.func @illegal_attention_inputs(%query: tensor<192x1024x64xf32>, %key: tenso // ----- -func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { +func.func @attention_missing_affine_map(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { %0 = tensor.empty() : tensor<192x1024x64xf32> %scale = arith.constant 1.0 : f32 - // expected-error @below {{'iree_linalg_ext.attention' op failed to verify op's indexing maps}} + // expected-error @below {{'iree_linalg_ext.attention' op expected an indexing map for each operand}} %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + return %1 : tensor<192x1024x64xf32> +} + +// ----- + +func.func @attention_affine_map_domain_mismatch(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { + %0 = tensor.empty() : tensor<192x1024x64xf32> + %scale = arith.constant 1.0 : f32 + // expected-error @below {{Mismatched map domain for Scale. 
Expected: 5 Got: 4}} + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir index 77ba7aa38fd2..0f117cca1f75 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir @@ -1087,6 +1087,7 @@ func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> @@ -1095,6 +1096,7 @@ func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> // CHECK: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: @@ -1103,7 +1105,7 @@ func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf // CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32> // CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[D1:.+]] = iree_linalg_ext.attention -// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} +// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : // CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] : // CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> @@ -1118,6 +1120,7 @@ func.func @cross_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x204 %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> @@ -1125,6 +1128,7 @@ func.func @cross_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x204 // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) 
-> (d0, d1, d2)> // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> // CHECK: func.func @cross_attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: @@ -1133,7 +1137,7 @@ func.func @cross_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x204 // CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32> // CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[D1:.+]] = iree_linalg_ext.attention -// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} +// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : // CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] : // CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> @@ -1150,6 +1154,7 @@ func.func @cross_attention_transposev(%query: tensor<192x1024x64xf32>, %key: ten %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> @@ -1157,6 +1162,7 @@ func.func @cross_attention_transposev(%query: tensor<192x1024x64xf32>, %key: ten // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)> +// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> // CHECK: func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: @@ -1165,7 +1171,7 @@ func.func @cross_attention_transposev(%query: tensor<192x1024x64xf32>, %key: ten // CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32> // CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[D1:.+]] = iree_linalg_ext.attention -// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} +// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : // CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] : // CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> @@ -1179,6 +1185,7 @@ func.func @cross_attention_transposev_dyn(%query: tensor, %key: tenso %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor, tensor, tensor, f32) outs(%init : tensor) -> tensor 
return %1 : tensor @@ -1186,6 +1193,7 @@ func.func @cross_attention_transposev_dyn(%query: tensor, %key: tenso // CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> // CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> // CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)> +// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> // CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> // CHECK: func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor, %[[ARG1:[a-zA-Z0-9_]+]]: // CHECK-SAME: { // CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[D1:.+]] = iree_linalg_ext.attention -// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} +// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : // CHECK-SAME: tensor, tensor, tensor, f32) outs(%[[ARG3]] : // CHECK-SAME: tensor) -> tensor diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp index e31058961a47..2f2e046a00c7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp @@ -100,10 +100,10 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap, // used by attention's exp2 whose value is always > 0. Value mx = builder.create<arith::ConstantOp>( loc, builder.getFloatAttr(srcTy, mxDbl)); - Value clamped = b.create<arith::MinimumFOp>(loc, mx, args[0]); + Value clamp = b.create<arith::MinimumFOp>(loc, mx, args[0]); // Convert scale to the same datatype as input.
- Value trunc = convertScalarToDtype(b, loc, clamped, dstTy, + Value trunc = convertScalarToDtype(b, loc, clamp, dstTy, /*isUnsignedCast=*/false); b.create<linalg::YieldOp>(loc, trunc); }); @@ -175,6 +175,53 @@ static Value computeMatmul(OpBuilder &builder, Location loc, AffineMap lhsMap, return genericOp.getResult(0); } +static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, + AffineMap maskMap, Value qk, Value mask) { + + SmallVector<AffineMap> compressedMaps = + compressUnusedDims(SmallVector<AffineMap>{qkMap, maskMap}); + qkMap = compressedMaps[0]; + maskMap = compressedMaps[1]; + + SmallVector<utils::IteratorType> iteratorTypes(qkMap.getNumDims(), + utils::IteratorType::parallel); + + Value zero = builder.create<arith::ConstantOp>( + loc, builder.getFloatAttr(getElementTypeOrSelf(qk.getType()), 0.0)); + Value negInf = builder.create<arith::ConstantOp>( + loc, builder.getFloatAttr(getElementTypeOrSelf(qk.getType()), + -std::numeric_limits<double>::infinity())); + auto genericOp = builder.create<linalg::GenericOp>( + loc, qk.getType(), SmallVector<Value>{mask}, qk, + SmallVector<AffineMap>{maskMap, qkMap}, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value qkVal = args[1]; + Value maskVal = args[0]; + + // TODO: Replace bool mask condition once treated as i1 (instead of i8) + if (maskVal.getType().isInteger()) { + maskVal = + b.create<arith::TruncIOp>(loc, builder.getI1Type(), maskVal); + maskVal = b.create<arith::SelectOp>(loc, maskVal, zero, negInf); + } else { + maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(), + /*isUnsignedCast=*/false); + // Scaling to compensate for base-2 softmax + Value log2e = b.create<arith::ConstantOp>( + loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); + maskVal = b.create<arith::MulFOp>(loc, maskVal, log2e); + } + // Finally, set the returned value to the qk element plus the mask + // element (or 0/-infinity if bool mask). We opt for an AddFOp (instead + // of a SelectFOp) to stay consistent with the additive definition of + // attention masking. + Value add = b.create<arith::AddFOp>(loc, qkVal, maskVal); + b.create<linalg::YieldOp>(loc, add); + }); + + return genericOp.getResult(0); +} + // Compute output = exp2(output - input) static Value computeSubAndExp2(OpBuilder &builder, Location loc, AffineMap inputMap, AffineMap outputMap, @@ -240,6 +287,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { Value query = getQuery(); Value key = getKey(); Value value = getValue(); + std::optional<Value> mask = getMask(); Value oldAcc = getOutput(); Value oldMax = getMax(); Value oldSum = getSum(); @@ -265,6 +313,9 @@ auto qETy = getElementTypeOrSelf(query.getType()); auto vETy = getElementTypeOrSelf(value.getType()); + AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(), + /*symbolCount=*/0, getContext()); + // In the original algorithm, the scaling is done after the softmax: // softmax(Q @ K.T * scale) @ V // // But, it is mathematically equivalent to do it on Q first and then multiply // it by K.T. This just allows us to do the scaling once, instead of each // iteration of the reduction. Also, this doesn't // significantly affect numerics. if (qETy.getIntOrFloatBitWidth() > 8) { AffineMap qMap = getQueryMap(); - AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(), - /*symbolCount=*/0, getContext()); query = elementwiseValueInPlace(b, loc, qMap, scaleMap, query, scale); } @@ -325,6 +374,11 @@ offset); } + // S += mask + if (mask != nullptr) { + s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value()); + } + // TODO: This decomposition should be in a separate op called // "online softmax". // ---- Online Softmax ----
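As a rough picture of what applyMask builds for a float mask, here is a hand-written linalg.generic sketch (illustrative shapes and names, not compiler output): the mask element is widened to the S element type, pre-scaled by log2(e) so the addition composes with the base-2 softmax, and accumulated onto the S tile:

#mapSM = affine_map<(m, k2) -> (m, k2)>
func.func @apply_mask_sketch(%s: tensor<128x128xf32>, %mask: tensor<128x128xf16>) -> tensor<128x128xf32> {
  %log2e = arith.constant 1.4426950408889634 : f32
  %0 = linalg.generic {indexing_maps = [#mapSM, #mapSM],
                       iterator_types = ["parallel", "parallel"]}
      ins(%mask : tensor<128x128xf16>) outs(%s : tensor<128x128xf32>) {
  ^bb0(%m: f16, %acc: f32):
    %m32 = arith.extf %m : f16 to f32
    // Compensate for the exp2-based softmax used by the decomposition.
    %scaled = arith.mulf %m32, %log2e : f32
    // S += mask, the additive definition of attention masking.
    %sum = arith.addf %acc, %scaled : f32
    linalg.yield %sum : f32
  } -> tensor<128x128xf32>
  return %0 : tensor<128x128xf32>
}

For a boolean (i8-as-i1) mask, the body instead selects 0 or -infinity and adds that, so masked-out positions vanish after the exp2.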
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp index ef5c21ce647e..a01e08a4556a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp @@ -8,6 +8,7 @@ // The content of this file is adapted from linalg's ElemenwiseOpFusion.cpp and // modified to work with LinalgExt ops, specifically `LinalgExt::AttentionOp`. +#include #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -106,7 +107,7 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, SmallVector<AffineExpr> newExprs; for (AffineExpr expr : indexingMap.getResults()) { unsigned pos = cast<AffineDimExpr>(expr).getPosition(); - SmallVector<AffineExpr> expandedExprs = llvm::to_vector<4>( + auto expandedExprs = llvm::to_vector_of<AffineExpr>( llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) { return builder.getAffineDimExpr(static_cast<unsigned>(v)); })); @@ -187,8 +188,7 @@ static std::optional<SmallVector<Value>> fuseAttentionWithReshapeByExpansion( : collapsingReshapeOp.getReassociationMaps(), expandedType.getShape(), collapsedType.getShape(), rewriter))) return std::nullopt; - - SmallVector<AffineMap> expandedOpIndexingMaps = llvm::to_vector<4>( + auto expandedOpIndexingMaps = llvm::to_vector_of<AffineMap>( llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) { return getIndexingMapInExpandedOp(rewriter, m, expansionInfo); })); @@ -254,12 +254,18 @@ } } + Value maskOperand; + if (expandedOpOperands.size() > 4) { + maskOperand = expandedOpOperands[4]; + } + // Create a new `AttentionOp` that has the computed operands/indexing maps. TypeRange resultTypes = ValueRange(outputs).getTypes(); auto fusedOp = rewriter.create<AttentionOp>( attentionOp.getLoc(), resultTypes, expandedOpOperands[0], expandedOpOperands[1], expandedOpOperands[2], expandedOpOperands[3], - outputs, rewriter.getAffineMapArrayAttr(expandedOpIndexingMaps)); + outputs, rewriter.getAffineMapArrayAttr(expandedOpIndexingMaps), + maskOperand); // Reshape the result values to their original shape if this is a collapsing // reshape folded into its consumer. diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp index d68a2eb35e4c..8089cd0eecf6 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp @@ -161,6 +161,7 @@ getTileAttentionIndexingMaps(RewriterBase &rewriter, int64_t tiledInputRank, AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx); AffineMap vMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx); + AffineMap sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {}, ctx); AffineMap rMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx); AffineMap maxMap = @@ -174,7 +175,7 @@ vMap = AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), vDims, ctx); } - SmallVector<AffineMap> attentionMaps = {qMap, kMap, vMap, + SmallVector<AffineMap> attentionMaps = {qMap, kMap, vMap, sMap, rMap, maxMap, sumMap}; // Add batches to standard attention indexing maps.
int64_t numBatches = tiledInputRank - 2; @@ -417,10 +418,14 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp, SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray(); indexingMaps.push_back(maxMap); indexingMaps.push_back(sumMap); + + Value mask = attnOp.getMask() ? attnOp.getMask() : Value(); + OnlineAttentionOp onlineAttn = rewriter.create<OnlineAttentionOp>( loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()}, attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(), - accFill, maxFill, sumFill, rewriter.getAffineMapArrayAttr(indexingMaps)); + mask, accFill, maxFill, sumFill, + rewriter.getAffineMapArrayAttr(indexingMaps)); onlineAttn->setDiscardableAttrs(attnOp->getDiscardableAttrDictionary()); ops.push_back(onlineAttn);
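Before the tiling changes below, a hand-written sketch of the op convertToOnlineAttention now produces when the source attention op carries a mask (shapes illustrative; map order follows the op definition: Q, K, V, scale, mask, then output/max/sum):

#oQ = affine_map<(b, m, k1, k2, n) -> (b, m, k1)>
#oK = affine_map<(b, m, k1, k2, n) -> (b, k2, k1)>
#oV = affine_map<(b, m, k1, k2, n) -> (b, k2, n)>
#oS = affine_map<(b, m, k1, k2, n) -> ()>
#oM = affine_map<(b, m, k1, k2, n) -> (b, m, k2)>
#oO = affine_map<(b, m, k1, k2, n) -> (b, m, n)>
#oR = affine_map<(b, m, k1, k2, n) -> (b, m)>
func.func @masked_online_attention_sketch(
    %q: tensor<192x1024x64xf16>, %k: tensor<192x1024x64xf16>, %v: tensor<192x1024x64xf16>,
    %mask: tensor<192x1024x1024xf16>, %out: tensor<192x1024x64xf32>,
    %max: tensor<192x1024xf32>, %sum: tensor<192x1024xf32>)
    -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
  %scale = arith.constant 0.125 : f16
  // The mask is the fifth input; its map sits between the scale and output maps.
  %res:3 = iree_linalg_ext.online_attention
      {indexing_maps = [#oQ, #oK, #oV, #oS, #oM, #oO, #oR, #oR]}
      ins(%q, %k, %v, %scale, %mask
          : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16, tensor<192x1024x1024xf16>)
      outs(%out, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
      -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
  return %res#0, %res#1, %res#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
}

When this op is tiled along k2, getTiledImplementation (next diff) slices the mask with its (b, m, k2) map just like any other operand.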
5 : 4; + resultTypes.push_back(tiledOperands[baseIdx].getType()); if (max) { - resultTypes.push_back(tiledOperands[5].getType()); + resultTypes.push_back(tiledOperands[baseIdx + 1].getType()); } if (sum) { - resultTypes.push_back(tiledOperands[6].getType()); + resultTypes.push_back(tiledOperands[baseIdx + 2].getType()); } } @@ -2024,6 +2035,11 @@ OnlineAttentionOp::getTiledImplementation(OpBuilder &builder, SmallVector<Range> keySlice = getPermutedSlice(getKeyMap(), offsets, sizes); SmallVector<Range> valueSlice = getPermutedSlice(getValueMap(), offsets, sizes); + std::optional<SmallVector<Range>> maskSlice; + if (auto maskMap = getMaskMap()) { + maskSlice = getPermutedSlice(*maskMap, offsets, sizes); + } + SmallVector<Range> outputSlice = getPermutedSlice(getOutputMap(), offsets, sizes); SmallVector<Range> maxSlice = getPermutedSlice(getMaxMap(), offsets, sizes); @@ -2065,6 +2081,16 @@ OnlineAttentionOp::getTiledImplementation(OpBuilder &builder, tiledOperands.emplace_back(scale); + // Mask + Value attnMask = getMask(); + if (attnMask) { + SmallVector<Range> maskSlice = + getPermutedSlice(*getMaskMap(), offsets, sizes); + Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice); + tiledOperands.emplace_back(maskSliceOp->getResult(0)); + slices.push_back(maskSliceOp); + } + /// Output { Operation *outputSliceOp = getSlice(builder, loc, getOutput(), outputSlice); @@ -2096,9 +2122,9 @@ OnlineAttentionOp::getTiledImplementation(OpBuilder &builder, } SmallVector<Type> resultTypes; - resultTypes.push_back(tiledOperands[4].getType()); - resultTypes.push_back(tiledOperands[5].getType()); - resultTypes.push_back(tiledOperands[6].getType()); + resultTypes.push_back(tiledOperands[tiledOperands.size() - 3].getType()); + resultTypes.push_back(tiledOperands[tiledOperands.size() - 2].getType()); + resultTypes.push_back(tiledOperands[tiledOperands.size() - 1].getType()); Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir index 7eb4c0aa4fc1..202ab11d2bdc 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir @@ -3,14 +3,15 @@ #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)> -#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> func.func @attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16>, %v: tensor<2x10x4096x128xf16>) -> tensor<2x10x4096x128xf16> { %scale = arith.constant 0.125 : f16 %acc = tensor.empty() : tensor<2x10x4096x128xf16> %out = iree_linalg_ext.attention - {indexing_maps = [#map, #map1, #map2, #map3]} + {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%q, %k, %v, %scale : tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, f16) outs(%acc : tensor<2x10x4096x128xf16>) -> tensor<2x10x4096x128xf16> func.return %out : tensor<2x10x4096x128xf16> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir index
0344eb731491..19d6a6bf38df 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir @@ -6,6 +6,7 @@ func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> return %1 : tensor<1x1024x64xf32> @@ -108,6 +109,7 @@ func.func @attention(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> return %1 : tensor<?x?x?xf32> @@ -213,6 +215,7 @@ func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> return %1 : tensor<1x1024x64xf16> @@ -333,6 +336,7 @@ func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1 %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> return %1 : tensor<1x1024x64xf16> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir index df46f9f6a275..e0aa548e16b0 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir @@ -3,6 +3,7 @@ #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> @@ -16,7 +17,7 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>, %scale = arith.constant 1.0 : f16 %out:3 = iree_linalg_ext.online_attention - { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] } + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO,
#mapR, #mapR] } ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> @@ -82,6 +83,7 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>, #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> @@ -95,7 +97,7 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, %scale = arith.constant 1.0 : f32 %out:3 = iree_linalg_ext.online_attention - { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] } + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] } ins(%query, %key, %value, %scale : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f32) outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> @@ -165,3 +167,67 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, // CHECK: arith.mulf // CHECK: arith.addf // CHECK: linalg.yield + +// ----- + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, + %key: tensor<192x1024x64xf8E4M3FNUZ>, + %value: tensor<192x1024x64xf8E4M3FNUZ>, + %mask: tensor<192x1024x1024xf8E4M3FNUZ>, + %output: tensor<192x1024x64xf32>, + %max: tensor<192x1024xf32>, + %sum: tensor<192x1024xf32>) + -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) { + %scale = arith.constant 1.0 : f16 + + %out:3 = iree_linalg_ext.online_attention + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] } + ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, tensor<192x1024x64xf8E4M3FNUZ>, f16, tensor<192x1024x1024xf8E4M3FNUZ>) + outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) + -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + + return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> +} +// CHECK-LABEL: @attention_f8_masked +// S = Q @ K +// CHECK: linalg.generic +// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32 +// CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32 +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// S = S * scale +// CHECK: linalg.generic +// CHECK: arith.mulf +// S = S + mask +// CHECK: arith.addf +// newMax = max(oldMax, rowMax(S)) +// CHECK: linalg.generic +// CHECK: arith.maximumf +// CHECK: linalg.yield +// P = exp2(S - newMax) +// CHECK: linalg.generic +// CHECK: arith.subf +// CHECK: math.exp2 +// CHECK: linalg.yield +// norm = exp2(oldMax - newMax) +// CHECK: linalg.generic +// CHECK: arith.subf +// CHECK: math.exp2 +// CHECK: linalg.yield +// normSum = 
norm * oldSum +// CHECK: linalg.generic +// CHECK: arith.mulf +// CHECK: linalg.yield +// newSum = normSum + rowSum(P) +// CHECK: linalg.generic +// CHECK: arith.addf +// CHECK: linalg.yield diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir index be9c9da45760..51d3c517e999 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir @@ -8,6 +8,7 @@ func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> return %1 : tensor<1x1024x64xf32> @@ -61,6 +62,7 @@ func.func @attention(%query: tensor<?x?x?xf32>, %key: tensor<?x?x?xf32>, %value: %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> return %1 : tensor<?x?x?xf32> @@ -117,6 +119,7 @@ func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> return %1 : tensor<1x1024x64xf16> @@ -151,6 +154,7 @@ func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1 %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> return %1 : tensor<1x1024x64xf16> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir index 147d5925346e..9bfa8b4f708c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir @@ -628,6 +628,7 @@ module attributes { transform.with_named_sequence } { transform.yield } } + // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)> // CHECK: func.func @topk_tile_tensor // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] @@ -1537,6 +1538,7 @@ func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf %1 =
iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> @@ -1553,6 +1555,7 @@ module attributes { transform.with_named_sequence } { // CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> // CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> // CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> // CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> // CHECK: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: @@ -1580,7 +1583,7 @@ module attributes { transform.with_named_sequence } { // CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], %[[ARG5]], 0] [%[[D2]], // CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32> // CHECK: %[[D5:.+]] = iree_linalg_ext.attention -// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_O]]]} +// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_O]]]} // CHECK-SAME: ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]], // CHECK-SAME: %[[EXTRACTED_SLICE_1]], %[[C1_F32]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32) // CHECK-SAME: outs(%[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32> @@ -1595,11 +1598,152 @@ module attributes { transform.with_named_sequence } { // ----- +func.func @attention_float_mask(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>, %mask: tensor<192x1024x1024xf32>) -> tensor<192x1024x64xf32> { + %0 = tensor.empty() : tensor<192x1024x64xf32> + %scale = arith.constant 1.0 : f32 + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xf32>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + return %1 : tensor<192x1024x64xf32> +} +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 30] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 192, 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 1024, 30)> +// CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) ->
(d0, d3, d4)> +// CHECK-DAG: #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> +// CHECK-DAG: #[[MAP_M:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +// CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +// CHECK: func.func @attention_float_mask(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: +// CHECK-SAME: tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<192x1024x1024xf32>) -> tensor<192x1024x64xf32> +// CHECK-SAME: { +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index +// CHECK-DAG: %[[C1_F32:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C192:.+]] = arith.constant 192 : index +// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32> +// CHECK: %[[D1:.+]] = scf.for %[[ARG4:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C192]] step %[[C10]] +// CHECK-SAME: iter_args(%[[ARG5:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<192x1024x64xf32>) { +// CHECK: %[[D2:.+]] = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C30]] +// CHECK-SAME: iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG5]]) -> (tensor<192x1024x64xf32>) { +// CHECK-DAG: %[[D3:.+]] = affine.min #[[MAP]](%[[ARG4]]) +// CHECK-DAG: %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG6]]) +// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]], +// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32> +// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1, +// CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32> +// CHECK: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1, +// CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32> +// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG3]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]], +// CHECK-SAME: %[[D4]], 1024] [1, 1, 1] : tensor<192x1024x1024xf32> to tensor<?x?x1024xf32> +// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG7]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]], +// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32> +// CHECK: %[[D5:.+]] = iree_linalg_ext.attention +// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_M]], #[[MAP_O]]]} +// CHECK-SAME: ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]], +// CHECK-SAME: %[[EXTRACTED_SLICE_1]], %[[C1_F32]], %[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32, tensor<?x?x1024xf32>) +// CHECK-SAME: outs(%[[EXTRACTED_SLICE_3]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32> +// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG7]][%[[ARG4]], %[[ARG6]], 0] +// CHECK-SAME: [%[[D3]], %[[D4]], 64] [1, 1, 1] : tensor<?x?x64xf32> into tensor<192x1024x64xf32> +// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<192x1024x64xf32> +// CHECK: } +// CHECK: scf.yield %[[D2]] : tensor<192x1024x64xf32> +// CHECK: } +// CHECK: return %[[D1]] : tensor<192x1024x64xf32> +// CHECK: } + +// ----- + +func.func @attention_bool_mask(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>, %mask: tensor<192x1024x1024xi1>) -> tensor<192x1024x64xf32> { + %0 = tensor.empty() : tensor<192x1024x64xf32> + %scale = arith.constant 1.0 : f32 + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0,
d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xi1>) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + return %1 : tensor<192x1024x64xf32> +} +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 30] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 192, 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 1024, 30)> +// CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> +// CHECK-DAG: #[[MAP_M:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +// CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +// CHECK: func.func @attention_bool_mask(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: +// CHECK-SAME: tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: tensor<192x1024x1024xi1>) -> tensor<192x1024x64xf32> +// CHECK-SAME: { +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index +// CHECK-DAG: %[[C1_F32:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C192:.+]] = arith.constant 192 : index +// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32> +// CHECK: %[[D1:.+]] = scf.for %[[ARG4:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C192]] step %[[C10]] +// CHECK-SAME: iter_args(%[[ARG5:[a-zA-Z0-9_]+]] = %[[D0]]) -> (tensor<192x1024x64xf32>) { +// CHECK: %[[D2:.+]] = scf.for %[[ARG6:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C30]] +// CHECK-SAME: iter_args(%[[ARG7:[a-zA-Z0-9_]+]] = %[[ARG5]]) -> (tensor<192x1024x64xf32>) { +// CHECK-DAG: %[[D3:.+]] = affine.min #[[MAP]](%[[ARG4]]) +// CHECK-DAG: %[[D4:.+]] = affine.min #[[MAP1]](%[[ARG6]]) +// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]], +// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32> +// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1, +// CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32> +// CHECK: %[[EXTRACTED_SLICE_1:.+]] = tensor.extract_slice %[[ARG2]][%[[ARG4]], 0, 0] [%[[D3]], 1024, 64] [1, +// CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor<?x1024x64xf32> +// CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG3]][%[[ARG4]], %[[ARG6]], 0] [%[[D3]], +// CHECK-SAME: %[[D4]], 1024] [1, 1, 1] : tensor<192x1024x1024xi1> to tensor<?x?x1024xi1> +// CHECK: %[[EXTRACTED_SLICE_3:.+]] = tensor.extract_slice %[[ARG7]][%[[ARG4]], %[[ARG6]],
0] [%[[D3]], +// CHECK-SAME: %[[D4]], 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<?x?x64xf32> +// CHECK: %[[D5:.+]] = iree_linalg_ext.attention +// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_M]], #[[MAP_O]]]} +// CHECK-SAME: ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]], +// CHECK-SAME: %[[EXTRACTED_SLICE_1]], %[[C1_F32]], %[[EXTRACTED_SLICE_2]] : tensor<?x?x64xf32>, tensor<?x1024x64xf32>, tensor<?x1024x64xf32>, f32, tensor<?x?x1024xi1>) +// CHECK-SAME: outs(%[[EXTRACTED_SLICE_3]] : tensor<?x?x64xf32>) -> tensor<?x?x64xf32> +// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG7]][%[[ARG4]], %[[ARG6]], 0] +// CHECK-SAME: [%[[D3]], %[[D4]], 64] [1, 1, 1] : tensor<?x?x64xf32> into tensor<192x1024x64xf32> +// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<192x1024x64xf32> +// CHECK: } +// CHECK: scf.yield %[[D2]] : tensor<192x1024x64xf32> +// CHECK: } +// CHECK: return %[[D1]] : tensor<192x1024x64xf32> +// CHECK: } + +// ----- + func.func @attention_memref(%query: memref<192x1024x64xf32>, %key: memref<192x1024x64xf32>, %value: memref<192x1024x64xf32>, %output: memref<192x1024x64xf32>) { %scale = arith.constant 1.0 : f32 iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : memref<192x1024x64xf32>, memref<192x1024x64xf32>, memref<192x1024x64xf32>, f32) outs(%output : memref<192x1024x64xf32>) return @@ -1616,6 +1760,7 @@ module attributes { transform.with_named_sequence } { // CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> // CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> // CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> // CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> // CHECK: func.func @attention_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: @@ -1640,7 +1785,7 @@ module attributes { transform.with_named_sequence } { // CHECK: %[[SUBVIEW_2:.+]] = memref.subview %[[ARG3]][%[[ARG4]], %[[ARG5]], 0] [%[[D0]], %[[D1]], 64] [1, 1, // CHECK-SAME: 1] : memref<192x1024x64xf32> to memref<?x?x64xf32, strided<[65536, 64, 1], offset: ?>> // CHECK: iree_linalg_ext.attention -// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_O]]]} +// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_S]], #[[MAP_O]]]} // CHECK-SAME: ins(%[[SUBVIEW]], %[[SUBVIEW_0]], %[[SUBVIEW_1]], %[[C1_F32]] : memref<?x?x64xf32, strided<[65536, 64, 1], offset: ?>>, memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>, // CHECK-SAME: memref<?x1024x64xf32, strided<[65536, 64, 1], offset: ?>>, f32) outs(%[[SUBVIEW_2]] : @@ -1662,6 +1807,7 @@ func.func @attention_fusion( indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%query, %key, %value, %scale : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16) outs(%0 : tensor<2x10x4096x64xf16>) -> tensor<2x10x4096x64xf16> @@ -1706,6 +1852,7 @@ module attributes { transform.with_named_sequence } { #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2,
n) -> ()> #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> @@ -1723,7 +1870,7 @@ func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x10 %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32> %out:3 = iree_linalg_ext.online_attention - { indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR] } + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] } ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> @@ -1737,8 +1884,9 @@ func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x10 // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> -// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> -// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> +// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> +// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> // CHECK-LABEL: @online_attention // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2) // CHECK-DAG: %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]]) @@ -1751,7 +1899,7 @@ func.func @online_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x10 // CHECK-DAG: %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32> // CHECK-DAG: %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32> // CHECK-DAG: iree_linalg_ext.online_attention -// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP4]]]} +// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP5]]]} // CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}} : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32) // CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>) // CHECK: scf.forall.in_parallel @@ -1763,3 +1911,150 @@ module attributes { transform.with_named_sequence } { transform.yield } } + +// ----- + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_float_mask(%query: tensor<192x1024x64xf32>, + %key: tensor<192x1024x64xf32>, + %value: tensor<192x1024x64xf32>, + %mask: tensor<192x1024x1024xf32>) + -> tensor<192x1024x64xf32> { + %scale = arith.constant 1.0 : f32 + + %output_empty = tensor.empty() : tensor<192x1024x64xf32> + %row_red_empty = tensor.empty() : tensor<192x1024xf32> + + %sum_ident = arith.constant 0.000000e+00 : 
f32 + %max_ident = arith.constant -3.40282347E+38 : f32 + + %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32> + %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32> + + // Adjust the operation to correctly handle the mask + %out:3 = iree_linalg_ext.online_attention + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] } + ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xf32>) + outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) + -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + + return %out#0 : tensor<192x1024x64xf32> +} + +// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)> +// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> +// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> +// CHECK-LABEL: @online_attention_float_mask +// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2) +// CHECK-DAG: %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]]) +// CHECK-DAG: %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]]) +// CHECK-DAG: %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]]) +// CHECK-DAG: %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32> +// CHECK-DAG: %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32> +// CHECK-DAG: %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32> +// CHECK-DAG: %[[MASK:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 1024] [1, 1, 1] : tensor<192x1024x1024xf32> to tensor<4x128x1024xf32> +// CHECK-DAG: %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32> +// CHECK-DAG: %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32> +// CHECK-DAG: %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32> +// CHECK-DAG: iree_linalg_ext.online_attention +// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP6]], #[[$MAP6]]]} +// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}}, %[[MASK]] : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32, tensor<4x128x1024xf32>) +// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>) +// CHECK: scf.forall.in_parallel + 
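+// Note on the expectations above (a gloss, assuming tile sizes are applied
+// positionally by transform.structured.tile_using_forall): the script below
+// tiles the iteration space (batch, m, k1, k2, n) with sizes [4, 128, 0, 0, 32].
+// k1 and k2 stay untiled, so the mask slice checked above keeps its full k2
+// extent: (batch, m, k2) tiles to [4, 128, 1024].
+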
+module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapM = affine_map<(batch, m, k1, k2, n) -> (batch, m, k2)> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_bool_mask(%query: tensor<192x1024x64xf32>, + %key: tensor<192x1024x64xf32>, + %value: tensor<192x1024x64xf32>, + %mask: tensor<192x1024x1024xi1>) + -> tensor<192x1024x64xf32> { + %scale = arith.constant 1.0 : f32 + + %output_empty = tensor.empty() : tensor<192x1024x64xf32> + %row_red_empty = tensor.empty() : tensor<192x1024xf32> + + %sum_ident = arith.constant 0.000000e+00 : f32 + %max_ident = arith.constant -3.40282347E+38 : f32 + + %output_fill = linalg.fill ins(%sum_ident : f32) outs(%output_empty : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + %acc_fill = linalg.fill ins(%max_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32> + %sum_fill = linalg.fill ins(%sum_ident : f32) outs(%row_red_empty : tensor<192x1024xf32>) -> tensor<192x1024xf32> + + // Adjust the operation to correctly handle the mask + %out:3 = iree_linalg_ext.online_attention + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapM, #mapO, #mapR, #mapR] } + ins(%query, %key, %value, %scale, %mask : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, tensor<192x1024x1024xi1>) + outs(%output_fill, %acc_fill, %sum_fill : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) + -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + + return %out#0 : tensor<192x1024x64xf32> +} + + +// CHECK-DAG: #[[$IDXMAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK-DAG: #[[$IDXMAP1:.+]] = affine_map<(d0) -> (d0 * 128)> +// CHECK-DAG: #[[$IDXMAP2:.+]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()> +// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +// CHECK-DAG: #[[$MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +// CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> +// CHECK-LABEL: @online_attention_bool_mask +// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) in (48, 8, 2) +// CHECK-DAG: %[[I0:.+]] = affine.apply #[[$IDXMAP0]](%[[IV0]]) +// CHECK-DAG: %[[I1:.+]] = affine.apply #[[$IDXMAP1]](%[[IV1]]) +// CHECK-DAG: %[[I2:.+]] = affine.apply #[[$IDXMAP2]](%[[IV2]]) +// CHECK-DAG: %[[Q:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x64xf32> +// CHECK-DAG: %[[K:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 
0, 0] [4, 1024, 64] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x64xf32> +// CHECK-DAG: %[[V:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], 0, %[[I2]]] [4, 1024, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x1024x32xf32> +// CHECK-DAG: %[[MASK:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], 0] [4, 128, 1024] [1, 1, 1] : tensor<192x1024x1024xi1> to tensor<4x128x1024xi1> +// CHECK-DAG: %[[O:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]], %[[I2]]] [4, 128, 32] [1, 1, 1] : tensor<192x1024x64xf32> to tensor<4x128x32xf32> +// CHECK-DAG: %[[M:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32> +// CHECK-DAG: %[[S:.+]] = tensor.extract_slice %{{.*}}[%[[I0]], %[[I1]]] [4, 128] [1, 1] : tensor<192x1024xf32> to tensor<4x128xf32> +// CHECK-DAG: iree_linalg_ext.online_attention +// CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]], #[[$MAP5]], #[[$MAP6]], #[[$MAP6]]]} +// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V]], %{{.*}}, %[[MASK]] : tensor<4x128x64xf32>, tensor<4x1024x64xf32>, tensor<4x1024x32xf32>, f32, tensor<4x128x1024xi1>) +// CHECK-SAME: outs(%[[O]], %[[M]], %[[S]] : tensor<4x128x32xf32>, tensor<4x128xf32>, tensor<4x128xf32>) +// CHECK: scf.forall.in_parallel + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + %tiled_att, %grid = transform.structured.tile_using_forall %0 tile_sizes [4, 128, 0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp index 428f2401beeb..90797a0b4127 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" #include "llvm/ADT/SetOperations.h" +#include "llvm/Support/raw_ostream.h" namespace mlir::iree_compiler::IREE::LinalgExt { @@ -83,10 +84,6 @@ void AttentionOpDetail::inferFromIndexingMaps( FailureOr<AttentionOpDetail> AttentionOpDetail::get(ArrayRef<AffineMap> indexingMaps) { - if (indexingMaps.size() != 4 && indexingMaps.size() != 6) { - return failure(); - } - AttentionOpDetail opInfo; opInfo.inferFromIndexingMaps(indexingMaps); opInfo.maps = SmallVector<AffineMap>(indexingMaps); diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir index 6b0203621606..e0e0cef13281 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir @@ -3,11 +3,12 @@ #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> -#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> util.func public @attention_static(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x4096x64xf16> {
%0 = tensor.empty() : tensor<20x4096x64xf16> - %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16> util.return %expanded : tensor<2x10x4096x64xf16> } @@ -36,11 +37,51 @@ util.func public @attention_static(%arg0: tensor<20x4096x16xf16>, %arg1: tensor< #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> -#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @attention_static_masked(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16, %arg4: tensor<20x4096x1024xf16>) -> tensor<2x10x4096x64xf16> { + %0 = tensor.empty() : tensor<20x4096x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16, tensor<20x4096x1024xf16>) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> + %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16> + util.return %expanded : tensor<2x10x4096x64xf16> +} + +//CHECK-LABEL: func public @attention_static_masked( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> +// CHECK-SAME: %[[ARG3:.+]]: f16 +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<20x4096x1024xf16>) +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x4096x64xf16> +// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16] +// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16] +// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64] +// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 1024] +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] : +// CHECK-SAME: outs(%[[EMPTY]] : +// CHECK: util.return %[[ATTENTION]] + +// 
----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> util.func public @attention_expand_all(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x2048x2x2x32xf16> { %0 = tensor.empty() : tensor<20x4096x64xf16> - %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> %expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16> util.return %expanded : tensor<2x10x2048x2x2x32xf16> } @@ -66,11 +107,50 @@ util.func public @attention_expand_all(%arg0: tensor<20x4096x16xf16>, %arg1: ten // ----- +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @attention_expand_all_masked(%arg0: tensor<20x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16, %arg4: tensor<20x4096x1024xf16>) -> tensor<2x10x2048x2x2x32xf16> { + %0 = tensor.empty() : tensor<20x4096x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4: tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16, tensor<20x4096x1024xf16>) outs(%0 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> + %expanded = tensor.expand_shape %1 [[0, 1], [2, 3], [4, 5]] output_shape [2, 10, 2048, 2, 2, 32] : tensor<20x4096x64xf16> into tensor<2x10x2048x2x2x32xf16> + util.return %expanded : tensor<2x10x2048x2x2x32xf16> +} + +//CHECK-LABEL: func public @attention_expand_all_masked( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<20x4096x16xf16> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> +// CHECK-SAME: %[[ARG3:.+]]: f16 +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<20x4096x1024xf16>) +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x10x2048x2x2x32xf16> +// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 16] : tensor<20x4096x16xf16> into tensor<2x10x2048x2x16xf16> +// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3]] output_shape [2, 10, 1024, 16] : tensor<20x1024x16xf16> into tensor<2x10x1024x16xf16> +// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [2, 10, 1024, 2, 32] : tensor<20x1024x64xf16> into tensor<2x10x1024x2x32xf16> +// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[\[}}0, 1], [2, 3], [4]] output_shape [2, 10, 2048, 2, 1024] : 
tensor<20x4096x1024xf16> into tensor<2x10x2048x2x1024xf16> +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d4)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d5)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d6, d7)> +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] : +// CHECK-SAME: outs(%[[EMPTY]] : +// CHECK: util.return %[[ATTENTION]] + +// ----- #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> -#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16) -> tensor<2x?x?x?xf16> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16> %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16> %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16> %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16> %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16> %0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16> - %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16> %split = arith.divsi %d0, %c2 : index %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4] : tensor<?x?x?xf16> into tensor<2x?x?x?xf16> @@ -130,7 +210,78 @@ util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> -#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16, %arg4: tensor<?x?x?xf16>) -> tensor<2x?x?x?xf16> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf16> + %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf16> + %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf16> + %d3 = tensor.dim %arg1, %c1 : tensor<?x?x?xf16> + %d4 = tensor.dim %arg2, %c2 : tensor<?x?x?xf16> + %0 = tensor.empty(%d0, %d1, %d4) : tensor<?x?x?xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%arg0, %arg1, %arg2, %arg3, %arg4 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16, tensor<?x?x?xf16>) outs(%0 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16> + %split = arith.divsi %d0, %c2 : index + %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape[2, %split, %d1, %d4] + : tensor<?x?x?xf16> into tensor<2x?x?x?xf16> + util.return %expanded : tensor<2x?x?x?xf16> +} + +//CHECK-LABEL: func public @attention_dynamic_masked( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf16> +// CHECK-SAME: %[[ARG3:.+]]: f16 +//
CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]] +// CHECK-DAG: %[[SPLIT0:.+]] = arith.divui %[[D0]] +// CHECK-DAG: %[[VAL:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]] +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[VAL]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16> +// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]] +// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]] +// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]] +// CHECK-DAG: %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]] +// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]] +// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]] +// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]] +// CHECK-DAG: %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]] +// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]] +// CHECK-DAG: %[[D10:.+]] = tensor.dim %[[ARG4]], %[[C0]] +// CHECK-DAG: %[[D11:.+]] = tensor.dim %[[ARG4]], %[[C1]] +// CHECK-DAG: %[[D12:.+]] = tensor.dim %[[ARG4]], %[[C2]] +// CHECK-DAG: %[[SPLIT3:.+]] = arith.divui %[[D10]], %[[C2]] +// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT3]], %[[D11]], %[[D12]]] +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] : +// CHECK-SAME: outs(%[[EMPTY]] : +// CHECK: util.return %[[ATTENTION]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> util.func public @sink_through_attention(%0 : tensor<4x32x64x128xf16>, %1 : tens %13 = tensor.empty() : tensor<4x32x64x128xf16> %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> %collapsed_13 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> %collapsed_14 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> %17 = tensor.empty() : tensor<128x64x128xf16> - %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%collapsed_12, %collapsed_13, %collapsed_14,
%cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16> + %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16> util.return %18 : tensor<128x64x128xf16> } @@ -162,13 +313,52 @@ util.func public @sink_through_attention(%0 : tensor<4x32x64x128xf16>, %1 : tens #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> -#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @sink_through_attention_masked(%0 : tensor<4x32x64x128xf16>, %1 : tensor<4x32x64x128xf16>, %2 : tensor<4x32x64x128xf16>, %cst : f16, %3 : tensor<4x32x64x64xf16>) -> (tensor<128x64x128xf16>) { + %13 = tensor.empty() : tensor<4x32x64x128xf16> + %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> + %collapsed_13 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> + %collapsed_14 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> + %collapsed_15 = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<4x32x64x64xf16> into tensor<128x64x64xf16> + %17 = tensor.empty() : tensor<128x64x128xf16> + %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed_12, %collapsed_13, %collapsed_14, %cst, %collapsed_15 : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16, tensor<128x64x64xf16>) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16> + util.return %18 : tensor<128x64x128xf16> +} + +// CHECK-LABEL: util.func public @sink_through_attention_masked +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG3:.+]]: f16 +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]] : +// CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> +// CHECK: util.return %[[RET]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> util.func public @sink_single_collapse(%0 : tensor<4x32x64x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, 
%cst : f16) -> (tensor<128x64x128xf16>) { %13 = tensor.empty() : tensor<4x32x64x128xf16> %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> %17 = tensor.empty() : tensor<128x64x128xf16> - %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3]} ins(%collapsed_12, %1, %2, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16> + %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%collapsed_12, %1, %2, %cst : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16> util.return %18 : tensor<128x64x128xf16> } @@ -188,3 +378,41 @@ util.func public @sink_single_collapse(%0 : tensor<4x32x64x128xf16>, %1 : tensor // CHECK-SAME: ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]] : // CHECK: %[[RET:.+]] = tensor.collapse_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> // CHECK: util.return %[[RET]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + +util.func public @sink_single_collapse_masked(%0 : tensor<4x32x64x128xf16>, %1 : tensor<128x64x128xf16>, %2 : tensor<128x64x128xf16>, %cst : f16, %3 : tensor<128x64x64xf16>) -> (tensor<128x64x128xf16>) { + %13 = tensor.empty() : tensor<4x32x64x128xf16> + %collapsed_12 = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> + %17 = tensor.empty() : tensor<128x64x128xf16> + %18 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed_12, %1, %2, %cst, %3 : tensor<128x64x128xf16>, tensor<128x64x128xf16>, tensor<128x64x128xf16>, f16, tensor<128x64x64xf16>) outs(%17 : tensor<128x64x128xf16>) -> tensor<128x64x128xf16> + util.return %18 : tensor<128x64x128xf16> +} + +// CHECK-LABEL: util.func public @sink_single_collapse_masked +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG3:.+]]: f16 +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: +// CHECK-DAG: %[[EXPANDED1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128] +// CHECK-DAG: %[[EXPANDED2:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 128] +// CHECK-DAG: %[[EXPANDED3:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [4, 32, 64, 64] +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: indexing_maps = +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> +// CHECK-SAME: ins(%[[ARG0]], %[[EXPANDED1]], %[[EXPANDED2]], %[[ARG3]], %[[EXPANDED3]] : +// CHECK: %[[RET:.+]] = tensor.collapse_shape 
%[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<4x32x64x128xf16> into tensor<128x64x128xf16> +// CHECK: util.return %[[RET]] diff --git a/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir index 7f0ed697388f..edfbfcfffb93 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir @@ -388,7 +388,7 @@ util.func public @clone_gather_like(%arg0: tensor<4x1x4xi64>, %arg1: tensor<1638 } -> tensor<4x1x4x16x32x128xf16> %3 = tensor.empty() : tensor<4x1x32x1x128xf16> %4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) { - %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16> + %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16> flow.return %5 : tensor<4x1x32x1x128xf16> } %collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16> @@ -402,3 +402,98 @@ util.func public @clone_gather_like(%arg0: tensor<4x1x4xi64>, %arg1: tensor<1638 // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention // CHECK: ins({{.*}}, %[[GATHER0]], %[[GATHER1]] // CHECK: flow.return %[[ATTENTION]] + +// ----- + +// Clone 'gather-like' operations +util.func public @clone_gather_like(%arg0: tensor<4x1x4xi64>, %arg1: tensor<16384x16x32x128xf16>, %arg2: f16, %arg3: i64, %arg4: tensor<4x4x16x32x128xf16>, %arg5: tensor<4x1x32x1x128xf16>) -> tensor<4x32x1x128xf16> { + %0 = tensor.empty() : tensor<4x1x4x16x32x128xf16> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) { + ^bb0(%in: i64, %out: f16): + %5 = arith.index_cast %in : i64 to index + %6 = linalg.index 3 : index + %7 = linalg.index 4 : index + %8 = linalg.index 5 : index + %extracted = tensor.extract %arg1[%5, %6, %7, %8] : tensor<16384x16x32x128xf16> + linalg.yield %extracted : f16 + } -> tensor<4x1x4x16x32x128xf16> + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : 
tensor<4x1x4x16x32x128xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %5 = arith.addi %in, %arg3 : i64
+    %6 = arith.index_cast %5 : i64 to index
+    %7 = linalg.index 3 : index
+    %8 = linalg.index 4 : index
+    %9 = linalg.index 5 : index
+    %extracted = tensor.extract %arg1[%6, %7, %8, %9] : tensor<16384x16x32x128xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4x1x4x16x32x128xf16>
+  %3 = tensor.empty() : tensor<4x1x32x1x128xf16>
+  %4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) {
+    %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>]} ins(%arg5, %1, %2, %arg2 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
+    flow.return %5 : tensor<4x1x32x1x128xf16>
+  }
+  %collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16>
+  util.return %collapsed : tensor<4x32x1x128xf16>
+}
+
+// CHECK-LABEL: util.func public @clone_gather_like
+// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:           %[[GATHER0:.+]] = linalg.generic
+// CHECK:           %[[GATHER1:.+]] = linalg.generic
+// CHECK:           %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK:           ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
+// CHECK:           flow.return %[[ATTENTION]]
+
+// -----
+
+// Clone 'gather-like' operations when the attention op also carries a mask.
+util.func public @clone_gather_like_masked(%arg0: tensor<4x1x4xi64>, %arg1: tensor<16384x16x32x128xf16>, %arg2: f16, %arg3: i64, %arg4: tensor<4x4x16x32x128xf16>, %arg5: tensor<4x1x32x1x128xf16>, %arg6: tensor<4x1x32x128xf16>) -> tensor<4x32x1x128xf16> {
+  %0 = tensor.empty() : tensor<4x1x4x16x32x128xf16>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %5 = arith.index_cast %in : i64 to index
+    %6 = linalg.index 3 : index
+    %7 = linalg.index 4 : index
+    %8 = linalg.index 5 : index
+    %extracted = tensor.extract %arg1[%5, %6, %7, %8] : tensor<16384x16x32x128xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4x1x4x16x32x128xf16>
+
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x1x4xi64>) outs(%0 : tensor<4x1x4x16x32x128xf16>) {
+  ^bb0(%in: i64, %out: f16):
+    %5 = arith.addi %in, %arg3 : i64
+    %6 = arith.index_cast %5 : i64 to index
+    %7 = linalg.index 3 : index
+    %8 = linalg.index 4 : index
+    %9 = linalg.index 5 : index
+    %extracted = tensor.extract %arg1[%6, %7, %8, %9] : tensor<16384x16x32x128xf16>
+    linalg.yield %extracted : f16
+  } -> tensor<4x1x4x16x32x128xf16>
+
+  %3 = tensor.empty() : tensor<4x1x32x1x128xf16>
+
+  %4 = flow.dispatch.region -> (tensor<4x1x32x1x128xf16>) {
+    %5 = iree_linalg_ext.attention {
+      indexing_maps = [
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d4)>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d2, d7)>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d4)>,
+        affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
+      ]
+    } ins(%arg5, %1, %2, %arg2, %arg6 : tensor<4x1x32x1x128xf16>, tensor<4x1x4x16x32x128xf16>, tensor<4x1x4x16x32x128xf16>, f16, tensor<4x1x32x128xf16>) outs(%3 : tensor<4x1x32x1x128xf16>) -> tensor<4x1x32x1x128xf16>
+
+    flow.return %5 : tensor<4x1x32x1x128xf16>
+  }
+
+  %collapsed = tensor.collapse_shape %4 [[0, 1], [2], [3], [4]] : tensor<4x1x32x1x128xf16> into tensor<4x32x1x128xf16>
+  util.return %collapsed : tensor<4x32x1x128xf16>
+}
+
+// CHECK-LABEL: util.func public @clone_gather_like_masked
+// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:           %[[GATHER0:.+]] = linalg.generic
+// CHECK:           %[[GATHER1:.+]] = linalg.generic
+// CHECK:           %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK:           ins({{.*}}, %[[GATHER0]], %[[GATHER1]]
+// CHECK:           flow.return %[[ATTENTION]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
index 993b6903e910..1d10dfcb6592 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir
@@ -97,7 +97,7 @@ util.func public @attention_dispatch(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x
   } -> tensor<?x?x?xf16>
-  %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+  %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
   %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
   ^bb0(%in: f16, %out: f16):
@@ -123,3 +123,49 @@ util.func public @attention_dispatch(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x
+
+// -----
+
+util.func public @attention_dispatch_masked(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16, %arg4: tensor<?x?x?xf16>, %arg5: tensor<?x?x?xf16>, %arg6: tensor<?x?x?xf16>, %arg7: tensor<?x?x?xf16>) -> tensor<?x?x?xf16> {
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5 = arith.mulf %in, %in : f16
+    linalg.yield %5 : f16
+  } -> tensor<?x?x?xf16>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5 = arith.mulf %in, %in : f16
+    linalg.yield %5 : f16
+  } -> tensor<?x?x?xf16>
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5 = arith.mulf %in, %in : f16
+    linalg.yield %5 : f16
+  } -> tensor<?x?x?xf16>
+
+  %3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%0, %1, %2, %arg3, %arg4 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16, tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
+
+  %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %5 = arith.mulf %in, %in : f16
+    linalg.yield %5 : f16
+  } -> tensor<?x?x?xf16>
+  util.return %4 : tensor<?x?x?xf16>
+}
+
+// CHECK-LABEL: util.func public @attention_dispatch_masked
+// CHECK:      %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK-NEXT:   %[[GEN0:.+]] = linalg.generic
+// CHECK:        flow.return %[[GEN0]]
+// CHECK:      %[[DISPATCH1:.+]] = flow.dispatch.region
+// CHECK-NEXT:   %[[GEN1:.+]] = linalg.generic
+// CHECK:        flow.return %[[GEN1]]
+// CHECK:      %[[DISPATCH2:.+]] = flow.dispatch.region
+// CHECK-NEXT:   %[[GEN2:.+]] = linalg.generic
+// CHECK:        flow.return %[[GEN2]]
+// CHECK:      %[[RESULT:.+]] = flow.dispatch.region
+// CHECK:        %[[ATTN:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:     ins(%[[DISPATCH0]], %[[DISPATCH1]], %[[DISPATCH2]]
+// CHECK:        %[[GEN3:.+]] = linalg.generic
+// CHECK-SAME:     ins(%[[ATTN]]
+// CHECK:        flow.return %[[GEN3]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
index 079acf3d140e..d0dc854130b1 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_transpose.mlir
@@ -15,7 +15,7 @@ util.func public @transpose_attention(%arg0: tensor<4x64x32x128xf16>, %arg1: ten
     linalg.yield %in : f16
   } -> tensor<4x32x64x128xf16>
   %4 = tensor.empty() : tensor<4x32x64x128xf16>
-  %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16) outs(%4 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
+  %5 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16) outs(%4 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16>
   %6 = tensor.empty() : tensor<4x64x32x128xf16>
   %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<4x32x64x128xf16>) outs(%6 : tensor<4x64x32x128xf16>) {
   ^bb0(%in: f16, %out: f16):
@@ -35,6 +35,55 @@ util.func public @transpose_attention(%arg0: tensor<4x64x32x128xf16>,
%arg1: ten // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3)> // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d3)> // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d5)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] + +// ----- + +util.func public @transposed_attention_masked(%arg0: tensor<4x64x32x128xf16>, %arg1: tensor<4x64x32x128xf16>, %arg2: tensor<4x64x32x128xf16>, %arg3: f16, %arg4: tensor<4x64x32x64xf16>) -> tensor<4x64x4096xf16> { + %0 = tensor.empty() : tensor<4x32x64x128xf16> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x32x64x128xf16> + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x32x64x128xf16> + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4x64x32x128xf16>) outs(%0 : tensor<4x32x64x128xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x32x64x128xf16> + %empty = tensor.empty() : tensor<4x32x64x64xf16> + %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg4 : tensor<4x64x32x64xf16>) outs(%empty : tensor<4x32x64x64xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x32x64x64xf16> + %5 = tensor.empty() : tensor<4x32x64x128xf16> + %6 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%1, %2, %3, %arg3, %4 : tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, tensor<4x32x64x128xf16>, f16, tensor<4x32x64x64xf16>) outs(%5 : tensor<4x32x64x128xf16>) -> tensor<4x32x64x128xf16> + %7 = tensor.empty() : tensor<4x64x32x128xf16> + %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<4x32x64x128xf16>) outs(%7 : tensor<4x64x32x128xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x64x32x128xf16> + %collapsed = tensor.collapse_shape %8 [[0], [1], [2, 3]] : tensor<4x64x32x128xf16> into tensor<4x64x4096xf16> + util.return %collapsed : tensor<4x64x4096xf16> +} + +// CHECK-LABEL: util.func public @transposed_attention_masked +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: 
%[[ARG1:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME:     %[[ARG2:[A-Za-z0-9]+]]: tensor
+// CHECK-SAME:     %[[ARG3:[A-Za-z0-9]+]]: f16
+// CHECK-SAME:     %[[ARG4:[A-Za-z0-9]+]]: tensor
+// CHECK:        %[[RESULT:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:       indexing_maps =
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3)>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d3)>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d1, d5)>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d4)>
+// CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
index cbb6ebd7835c..447ca7d8a828 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/fold_attention_with_transpose.mlir
@@ -15,6 +15,7 @@ util.func public @fuse_attention_expand_transpose(
       indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                       affine_map<(d0, d1, d2, d3, d4) -> ()>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
       ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16) outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
@@ -39,26 +40,30 @@ util.func public @fuse_attention_expand_transpose(
 // CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:     %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG:     %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
-// CHECK-DAG:     %[[D_SPLIT:.+]] = arith.divsi %[[D0]], %[[C2]]
-// CHECK-DAG:     %[[EMPTY:.+]] = tensor.empty(%[[D1]], %[[D_SPLIT]], %[[D4]]) : tensor<2x?x?x?xf16>
-// CHECK-DAG:     %[[D_SPLIT2:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%[[D0]]]
-// CHECK-DAG:     %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG:     %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG:     %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D1]], %[[D2]]{{\]}}
-// CHECK-DAG:     %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D2]]{{\]}}
-// CHECK-DAG:     %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT2]], %[[D3]], %[[D4]]{{\]}}
+// CHECK-DAG:     %[[D:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG:     %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[ARG2]], %[[C2]]
+// CHECK-DAG:     %[[EMPTY:.+]] = tensor.empty(%[[D]], %[[D0]], %[[D1]]) : tensor<?x?x?xf16>
 // CHECK:         %[[ATTENTION:.+]] = iree_linalg_ext.attention
 // CHECK-SAME:        indexing_maps =
-// CHECK-SAME:        [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
-// CHECK-SAME:         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
-// CHECK-SAME:         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
-// CHECK-SAME:         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>]
-// CHECK-SAME:        ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] :
+// CHECK-SAME:        [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1,
d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> ()>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] : // CHECK-SAME: outs(%[[EMPTY]] : -// CHECK: util.return %[[ATTENTION]] +// CHECK-DAG: %[[D_SPLIT:.+]] = arith.divsi %[[D]], %[[C2]] +// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[D_SPLIT]], %[[D0]], %[[D1]]] +// CHECK-DAG: %[[OUTS:.+]] = tensor.empty(%[[D0]], %[[D_SPLIT]], %[[D1]]) : tensor<2x?x?x?xf16> +// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = +// CHECK-SAME: [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>] +// CHECK-SAME: ins(%[[EXPANDED]] : +// CHECK-SAME: outs(%[[OUTS]] : +// CHECK: linalg.yield +// CHECK: util.return %[[TRANSPOSE]] // ----- @@ -70,6 +75,7 @@ util.func public @fuse_attention_expand_transpose_static( indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%empty: tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> @@ -91,16 +97,23 @@ util.func public @fuse_attention_expand_transpose_static( // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16> // CHECK-SAME: %[[ARG3:.+]]: f16) -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4096x10x64xf16> -// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 16] -// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 16] -// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 1024, 64] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<20x4096x64xf16> // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention // CHECK-SAME: indexing_maps = -// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, -// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, -// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, -// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d5)>] -// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]] : +// CHECK-SAME: [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> ()>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] : // CHECK-SAME: outs(%[[EMPTY]] : -// CHECK: util.return %[[ATTENTION]] +// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ATTENTION]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 10, 4096, 64] +// CHECK-DAG: %[[OUTS:.+]] = tensor.empty() : tensor<2x4096x10x64xf16> +// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = +// CHECK-SAME: [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, 
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>] +// CHECK-SAME: ins(%[[EXPANDED]] : +// CHECK-SAME: outs(%[[OUTS]] : +// CHECK: linalg.yield +// CHECK: util.return %[[TRANSPOSE]] diff --git a/tests/e2e/attention/generate_e2e_attention_tests.py b/tests/e2e/attention/generate_e2e_attention_tests.py index d258dc1097b1..feaad0ee867e 100644 --- a/tests/e2e/attention/generate_e2e_attention_tests.py +++ b/tests/e2e/attention/generate_e2e_attention_tests.py @@ -217,6 +217,7 @@ def generate_function( f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n" f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n" f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n" + f" affine_map<(batch, m, n, k1, k2) -> ()>,\n" f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}" f" ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n" f" outs(%result0: {result_tensor_type}) -> {result_tensor_type}\n" diff --git a/tests/e2e/linalg_ext_ops/attention.mlir b/tests/e2e/linalg_ext_ops/attention.mlir index c418809eedd7..c2ca83256b48 100644 --- a/tests/e2e/linalg_ext_ops/attention.mlir +++ b/tests/e2e/linalg_ext_ops/attention.mlir @@ -14,6 +14,7 @@ func.func @attention1x3x4() { %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%init : tensor<1x3x4xf32>) -> tensor<1x3x4xf32> @@ -44,6 +45,7 @@ func.func @attention1x4x4() { %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<1x4x4xf32>, f32) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> @@ -90,6 +92,7 @@ func.func @attention3x3x4() { %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%query, %key, %value, %scale : tensor<3x3x4xf32>, tensor<3x3x4xf32>, tensor<3x3x4xf32>, f32) outs(%init : tensor<3x3x4xf32>) -> tensor<3x3x4xf32>