Skip to content

Commit

Permalink
[LinalgExt] Masked Attention Implementation (iree-org#18525)
Browse files Browse the repository at this point in the history
Enables float/boolean mask as parameters and created linalg generic ops
to apply masking. This image (https://imgur.com/a/1MePgcy) elaborates on
the main files changed and how they enable masked attention:
- Blue boxes represent changed .cpp and .td files to
enable/pass/decompose the mask
  -  Yellow boxes represent the different op classes
- Red boxes represent test mlir files pertaining to certain .cpp/.td
implementations or ops
  
For quick reference, AggregateOpInterfaceImpl.cpp contains the bulk of
the actual mask decomposition (QK += mask)

And for clarification, TileAttention.cpp only holds the
convertToOnlineAttentionOp and getTileAttentionIndexingMaps functions;
TilingInterfaceImpl.cpp contains the main tiling capabilities in the
form of AttentionOp::getTiledImplementation and
OnlineAttentionOp::getTiledImplementation.
  

Updated version of iree-org#18461. This
version was created to include scale affine map and enable fused
attention (incorporated
https://github.com/IanWood1/iree/tree/raikonen/sdpa_mask).
- To that end, many modifications in tests are for adding the scale
affine map (without much functionality change)
- For tiling and decomposition tests, most functionality tests are
included in "tiling.mlir" and "decompose_online_attention.mlir". On the
other hand, the "tile_attention.mlir and "decompose_attention.mlir" are
old paths intended to be be retired and deprecate soon. Hence, no major
tests were added it there.

Test directory for numerical verification:
https://github.com/rohan-tan-bhowmik/iree-masked-attention-test

---------

Signed-off-by: Stanley Winata <[email protected]>
Co-authored-by: Stanley Winata <[email protected]>
Co-authored-by: Ian Wood <[email protected]>
  • Loading branch information
3 people authored Sep 22, 2024
1 parent 891f438 commit 9ee061d
Show file tree
Hide file tree
Showing 30 changed files with 1,211 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,22 @@ struct ScatterOpConversion
};
} // namespace

static SmallVector<AffineMap>
getStandardAttentionIndexingMaps(MLIRContext *ctx) {
static SmallVector<AffineMap> getStandardAttentionIndexingMaps(MLIRContext *ctx,
bool hasMask) {
AffineExpr m, n, k1, k2;
bindDims(ctx, m, n, k1, k2);

AffineMap qMap =
AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
AffineMap kMap =
AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
AffineMap vMap =
AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
AffineMap rMap =
AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);

return {qMap, kMap, vMap, rMap};
auto qMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
auto kMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
auto vMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
auto sMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, ctx);
auto rMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
if (hasMask) {
// Add mask map only if it exists
auto mMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k2}, ctx);
return {qMap, kMap, vMap, sMap, mMap, rMap};
}
return {qMap, kMap, vMap, sMap, rMap};
}

struct AttentionOpConversion
Expand All @@ -100,6 +101,7 @@ struct AttentionOpConversion
Value query = op.getQuery();
Value key = op.getKey();
Value value = op.getValue();
std::optional<Value> optionalMask = op.getAttnMask();

ShapedType outputType = op.getOutputType();

Expand Down Expand Up @@ -147,18 +149,22 @@ struct AttentionOpConversion
loc, targetType, rewriter.getFloatAttr(targetType, dk));

// Add batches to standard attention indexing maps.
SmallVector<AffineMap> indexingMaps = getStandardAttentionIndexingMaps(ctx);
SmallVector<AffineMap> indexingMaps =
getStandardAttentionIndexingMaps(ctx, optionalMask.has_value());

int64_t numBatches = op.getQueryType().getRank() - 2;
for (AffineMap &map : indexingMaps) {
map = map.shiftDims(numBatches);
if (map.getNumResults() == 0)
continue;
for (int batch : llvm::seq<int>(numBatches)) {
map = map.insertResult(rewriter.getAffineDimExpr(batch), batch);
}
}

auto attention = rewriter.create<IREE::LinalgExt::AttentionOp>(
loc, result.getType(), query, key, value, scale, result,
rewriter.getAffineMapArrayAttr(indexingMaps));
rewriter.getAffineMapArrayAttr(indexingMaps), optionalMask);

rewriter.replaceOp(op, attention.getResult(0));
return success();
Expand Down
18 changes: 11 additions & 7 deletions compiler/plugins/input/Torch/InputConversion/test/attention.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ func.func @attention(%arg0: tensor<5x2x3x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x3x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
// CHECK: %arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<5x2x3x4xf32>

// -----
Expand All @@ -27,14 +28,15 @@ func.func @attention(%arg0: tensor<5x2x8x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x8x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
// CHECK: %arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x8x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
// CHECK: return %[[ATTN]] : tensor<5x2x8x4xf32>

// -----
Expand All @@ -46,14 +48,15 @@ func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2:
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3x4xf32>, %[[ARG1:.*]]: tensor<1x3x4xf32>, %[[ARG2:.*]]: tensor<1x3x4xf32>,
// CHECK: %arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
// CHECK: %[[ARG3:.*]]: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<1x3x4xf32>

// -----
Expand All @@ -65,6 +68,7 @@ func.func @attention_dyn(%arg0: tensor<?x?x4xf32>, %arg1: tensor<?x?x4xf32>, %ar
// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>
// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

// CHECK-LABEL: func.func @attention_dyn(
Expand All @@ -76,5 +80,5 @@ func.func @attention_dyn(%arg0: tensor<?x?x4xf32>, %arg1: tensor<?x?x4xf32>, %ar
// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x4xf32>, tensor<?x?x4xf32>, tensor<?x?x4xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x4xf32>) -> tensor<?x?x4xf32>
// CHECK: return %[[ATTN]] : tensor<?x?x4xf32>
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ func.func @conv() attributes {translation_info = #iree_codegen.translation_info<
#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>

Expand All @@ -272,7 +273,7 @@ func.func @online_attention(%query: tensor<192x1024x64xf16>,
%scale = arith.constant 1.0 : f16

%out:3 = iree_linalg_ext.online_attention
{ indexing_maps = [#mapQ, #mapK, #mapV, #mapO, #mapR, #mapR],
{ indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR],
lowering_config = #config }
ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16)
outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,7 @@ func.func @attention() attributes {hal.executable.target = #executable_target_em
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ func.func @attention_20x4096x64x4096x64() {
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
Expand Down Expand Up @@ -407,6 +408,7 @@ func.func @attention_large_head_dim_shared_mem() {
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d1, d2, d3, d4) -> (d1, d2)>,
affine_map<(d1, d2, d3, d4) -> (d3, d2)>,
affine_map<(d1, d2, d3, d4) -> (d3, d4)>,
affine_map<(d1, d2, d3, d4) -> ()>,
affine_map<(d1, d2, d3, d4) -> (d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<1024x512xf16>, tensor<128x512xf16>, tensor<128x512xf16>, f16) outs(%7 : tensor<1024x512xf16>) -> tensor<1024x512xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0], sizes = [1024, 512], strides = [1, 1] : tensor<1024x512xf16> -> !flow.dispatch.tensor<writeonly:tensor<1024x512xf16>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ hal.executable private @attention_20x4096x64x4096x64 {
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
lowering_config = #config}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
Expand Down Expand Up @@ -637,7 +638,7 @@ hal.executable private @attention_multiple_m_transpose {
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
Expand Down Expand Up @@ -692,7 +693,7 @@ hal.executable private @attention_mfma_32x32x8 {
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func.func @_attention_dispatch_0() {
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func.func @attention_dispatch_0_attention_16x16384x128xf16() {
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,7 @@ def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == $_op);
if(opOperand->getOperandNumber() >= $_op.getNumDpsInputs()){
return $_op.getIndexingMapsForResults()[opOperand->getOperandNumber() - $_op.getNumDpsInputs()];
}else {
return $_op.getIndexingMapsForOperands()[opOperand->getOperandNumber()];
}
return getIndexingMapsArray()[opOperand->getOperandNumber()];
}]
>,

Expand Down
Loading

0 comments on commit 9ee061d

Please sign in to comment.