From 1793c366fd49f96685481e086d7584f21afef450 Mon Sep 17 00:00:00 2001 From: Sasha Doubov Date: Tue, 21 Nov 2023 15:48:38 -0800 Subject: [PATCH] Fix flash attention GQA bug to use the dynamic size of the key/value tensors - used for eval/inference (#756) --- llmfoundry/models/layers/attention.py | 8 ++++---- tests/test_flash_triton_torch.py | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 0503d6d75a..dd7f40cd19 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -296,11 +296,11 @@ def flash_attn_fn( # we use .view to modify {key, value}_unpad appropriately key_unpad = repeat_kv_for_gqa( - key_unpad.view(batch_size, seqlen, kv_n_heads, -1), - n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1) + key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), + n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1) value_unpad = repeat_kv_for_gqa( - value_unpad.view(batch_size, seqlen, kv_n_heads, -1), - n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1) + value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), + n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1) dropout_p = dropout_p if training else 0.0 diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 2059585a35..6d75efa1d2 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -54,12 +54,14 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize( 'attn_type', ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) +@pytest.mark.parametrize('pad_attention_mask', [True, False]) def test_attn_impl(attn_impl_0: str, attn_impl_1: str, clip_qkv: bool, qk_ln: bool, pos_emb_config: dict, attn_type: str, + pad_attention_mask: bool, device: str = 'cuda'): """Compare all attn impl with each other. 
@@ -98,6 +100,11 @@ def test_attn_impl(attn_impl_0: str, attention_mask = torch.ones(n, s).to(device).bool() + if pad_attention_mask: + # zero out the first third of the attention mask + # to simulate padding + attention_mask[:, :s // 3] = 0 + def gen_bias(attn_impl: str): causal = True attn_bias = None