From 84b5d96df9b1ffb7dbd2edf34b0a03fd4fe4220b Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:06:24 -0800 Subject: [PATCH] Shashank/seq id flash attn (#738) * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * Update llmfoundry/models/layers/attention.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. --------- Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> --- llmfoundry/models/layers/attention.py | 104 ++++--- llmfoundry/models/layers/blocks.py | 3 + llmfoundry/models/mpt/configuration_mpt.py | 15 +- llmfoundry/models/mpt/modeling_mpt.py | 142 ++++++++-- tests/models/layers/test_flash_attn.py | 255 ++++++++++++++++++ .../models/layers/test_flash_triton_torch.py | 60 ++++- tests/models/test_model.py | 110 ++++++++ 7 files changed, 613 insertions(+), 76 deletions(-) create mode 100644 tests/models/layers/test_flash_attn.py diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index dd7f40cd19..86e49c315d 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -92,7 +92,6 @@ def scaled_multihead_dot_product_attention( multiquery: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: - if multiquery: warnings.warn( DeprecationWarning( @@ -219,6 +218,9 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, + attention_mask_in_length: Optional[torch.Tensor] = None, + should_repeat_kv_for_gqa: Optional[bool] = True, + sliding_window_size: int = -1, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: @@ -249,58 +251,65 @@ def flash_attn_fn( past_key_value = (key, value) - if attn_bias is not None: - # clamp to 0 necessary for torch 2.0 compile() - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if attn_bias is not None: raise NotImplementedError(f'attn_bias not implemented for flash attn.') batch_size, seqlen = query.shape[:2] - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1):] + if attention_mask_in_length is None: + if key_padding_mask is None: + key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -query.size(1):] + unpadding_function = bert_padding.unpad_input + else: + key_padding_mask = attention_mask_in_length + query_padding_mask = attention_mask_in_length + unpadding_function = bert_padding.unpad_input_for_concatenated_sequences - query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input( + query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( query, query_padding_mask) query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) - key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input( + key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function( key, key_padding_mask) key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask) + value_unpad, _, _, _ = unpadding_function(value, key_padding_mask) value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - # multi-query case - if kv_n_heads == 1: - # Expanding a tensor does not allocate new memory, but only creates a new - # view on the existing tensor where a dimension of size one is expanded - # to a larger size by setting the stride to 0. - # - pytorch docs - # - # hopefully the kernels can utilize this and we're jot just wasting BW here - key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, - key_unpad.size(-1)) - value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, - value_unpad.size(-1)) - # grouped query case - elif kv_n_heads < n_heads: - # Each query belong to a group of kv heads of group size n_heads // kv_n_heads - # We repeat each kv head by the group size number to use the underlying MHA kernels - - # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d) - # we use .view to modify {key, value}_unpad appropriately + if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( + not should_repeat_kv_for_gqa): + raise ValueError( + 'For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.' + ) - key_unpad = repeat_kv_for_gqa( - key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), - n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1) - value_unpad = repeat_kv_for_gqa( - value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), - n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1) + if should_repeat_kv_for_gqa: + # multi-query case + if kv_n_heads == 1: + # Expanding a tensor does not allocate new memory, but only creates a new + # view on the existing tensor where a dimension of size one is expanded + # to a larger size by setting the stride to 0. + # - pytorch docs + # + # hopefully the kernels can utilize this and we're jot just wasting BW here + key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, + key_unpad.size(-1)) + value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, + value_unpad.size(-1)) + # grouped query case + elif kv_n_heads < n_heads: + # Each query belong to a group of kv heads of group size n_heads // kv_n_heads + # We repeat each kv head by the group size number to use the underlying MHA kernels + + # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d) + # we use .view to modify {key, value}_unpad appropriately + + key_unpad = repeat_kv_for_gqa( + key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), + n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1) + value_unpad = repeat_kv_for_gqa( + value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), + n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1) dropout_p = dropout_p if training else 0.0 @@ -331,7 +340,8 @@ def flash_attn_fn( dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, - return_attn_probs=needs_weights) + return_attn_probs=needs_weights, + window_size=(sliding_window_size, sliding_window_size)) else: raise RuntimeError( 'flash-attn==1.0.9 or flash-attn==2.3.2 is required.') @@ -490,6 +500,7 @@ def __init__( fc_type: str = 'torch', device: Optional[str] = None, bias: bool = True, + sliding_window_size: int = -1, ): super().__init__() @@ -500,6 +511,7 @@ def __init__( self.d_model = d_model self.n_heads = n_heads self.kv_n_heads = kv_n_heads + self.sliding_window_size = sliding_window_size self.head_dim = d_model // n_heads @@ -569,6 +581,7 @@ def forward( rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, + attention_mask_in_length: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -626,6 +639,14 @@ def forward( query = query.view(bsz, seqlen, self.d_model) key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim) + extra_attn_kwargs = {} + if self.attn_impl == 'flash': + extra_attn_kwargs = { + 'attention_mask_in_length': attention_mask_in_length, + 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), + 'sliding_window_size': self.sliding_window_size, + } + context, attn_weights, past_key_value = self.attn_fn( query, key, @@ -640,6 +661,7 @@ def forward( dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, + **extra_attn_kwargs, ) return self.out_proj(context), attn_weights, past_key_value @@ -665,6 +687,7 @@ def __init__( fc_type: str = 'torch', device: Optional[str] = None, bias: bool = True, + sliding_window_size: int = -1, ): super().__init__( d_model=d_model, @@ -679,6 +702,7 @@ def __init__( fc_type=fc_type, device=device, bias=bias, + sliding_window_size=sliding_window_size, ) @@ -702,6 +726,7 @@ def __init__( fc_type: str = 'torch', device: Optional[str] = None, bias: bool = True, + sliding_window_size: int = -1, ): super().__init__( d_model=d_model, @@ -716,6 +741,7 @@ def __init__( fc_type=fc_type, device=device, bias=bias, + sliding_window_size=sliding_window_size, ) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 6605807c6b..6db9ff22ca 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -21,6 +21,7 @@ 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, + 'sliding_window_size': -1, 'alibi': False, 'alibi_bias_max': 8, 'rope': False, @@ -113,6 +114,7 @@ def forward( attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, + attention_mask_in_length: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -124,6 +126,7 @@ def forward( attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, + attention_mask_in_length=attention_mask_in_length, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index f8022808bf..47fd5ac9e5 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -91,6 +91,7 @@ def __init__( When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates which sub-sequence each token belongs to. Defaults to ``False`` meaning any provided `sequence_id` will be ignored. + sliding_window_size (int): Window size for sliding window local attention. Defaults to -1, which means no sliding window. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size, i + seqlen_k - seqlen_q + window_size] inclusive. Only works for flash attention v2.3.0 or higher. alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. rope (bool): Whether to use rotary positional embeddings. @@ -221,10 +222,12 @@ def _validate_config(self) -> None: ]: raise NotImplementedError( 'alibi only implemented with torch and triton attention.') - if self.attn_config['attn_uses_sequence_id'] and self.attn_config[ - 'attn_impl'] not in ['torch', 'triton']: + if self.attn_config['attn_uses_sequence_id'] and not ( + self.attn_config['attn_impl'] in ['torch', 'triton'] or + (self.attn_config['attn_impl'] == 'flash' and + is_flash_v2_installed(v2_version='v2.1.2'))): raise NotImplementedError( - 'attn_uses_sequence_id only implemented with torch and triton attention.' + 'attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention.' ) if self.attn_config['rope'] and (self.attn_config['rope_impl'] not in ['dail', 'hf']): @@ -251,6 +254,12 @@ def _validate_config(self) -> None: raise ImportError( 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support' ) + if self.attn_config['sliding_window_size'] != -1 and not ( + self.attn_config['attn_impl'] == 'flash' and + is_flash_v2_installed(v2_version='v2.3.0')): + raise NotImplementedError( + 'sliding window only implemented with flash attention v2.3.0 or higher.' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 34b8992d3e..e2d2ee6fbc 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -132,6 +132,114 @@ def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, raise ValueError('rope_impl needs to be either dail or hf') +def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, + attn_uses_sequence_id: bool, attn_impl: str, + attention_mask: Union[torch.Tensor, None]): + """Generates the attention mask used for sequence masking in FA v2. + + Only supports sequence id based sparse attention for no attention masking or attention masking with right padding. + In case of left padding: + 1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407). + 2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention. + + Args: + sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len). + S (int): Sequence length + attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking. + attn_impl (str): Attention implementation. This function is only creates attention_mask_in_length for flash attention. + attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len) + + Returns: + attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: + ``` + [ + [2, 3, 0, 0, 0, 0], + [3, 2, 0, 0, 0, 0], + [6, 0, 0, 0, 0, 0] + ] + ``` + , which refers to the 3D-attention mask: + ``` + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1] + ] + ] + ```. + (The description above is taken verbatim from https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/bert_padding.py#L125 .) + """ + attention_mask_in_length = None + if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl + == 'flash'): + # Check if sequence has left padding. If yes, raise an error. + if (attention_mask is not None) and (attention_mask[:, 0].sum() != + attention_mask.shape[0]): + raise NotImplementedError( + 'Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.' + ) + if S != sequence_id.shape[-1]: + raise ValueError( + f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).' + ) + attention_mask_in_length = torch.nn.functional.one_hot(sequence_id) + if attention_mask is not None: + attention_mask_in_length = attention_mask_in_length.masked_fill( + ~attention_mask.unsqueeze(-1), 0) + attention_mask_in_length = attention_mask_in_length.sum(dim=1) + attention_mask_in_length = torch.nn.functional.pad( + attention_mask_in_length, + (0, S - attention_mask_in_length.shape[-1]), + mode='constant', + value=0) + + return attention_mask_in_length + + +def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor, + max_seq_len: int) -> torch.Tensor: + seq_len = sequence_id.shape[-1] + if seq_len > max_seq_len: + raise ValueError( + f'sequence_id sequence length cannot exceed max_seq_len={max_seq_len}' + ) + + # select seq_len subset of attn mask + attn_bias = attn_bias[..., :seq_len, :seq_len] + + # Restrict attention to tokens that share the same value + # in sequence_id + cannot_attend = torch.logical_not( + torch.eq( + sequence_id.view(-1, seq_len, 1), + sequence_id.view(-1, 1, seq_len), + )).unsqueeze(1) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + + return attn_bias + + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' @@ -286,7 +394,8 @@ def _attn_bias( # If using torch or triton, we incorporate sequence_id (if appropriate) if self.attn_uses_sequence_id and sequence_id is not None: assert isinstance(attn_bias, torch.Tensor) # pyright - attn_bias = self._apply_sequence_id(attn_bias, sequence_id) + attn_bias = apply_sequence_id(attn_bias, sequence_id, + self.config.max_seq_len) # If using torch or triton, we incorporate attention_mask. This will output # None in place of attention_mask since it will not be further needed in the @@ -343,29 +452,6 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor, return attn_bias - def _apply_sequence_id(self, attn_bias: torch.Tensor, - sequence_id: torch.LongTensor) -> torch.Tensor: - seq_len = sequence_id.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}' - ) - - # select seq_len subset of attn mask - attn_bias = attn_bias[..., :seq_len, :seq_len] - - # Restrict attention to tokens that share the same value - # in sequence_id - cannot_attend = torch.logical_not( - torch.eq( - sequence_id.view(-1, seq_len, 1), - sequence_id.view(-1, 1, seq_len), - )).unsqueeze(1) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - - return attn_bias - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -515,7 +601,12 @@ def forward( prefix_mask=prefix_mask, sequence_id=sequence_id, ) - + attention_mask_in_length = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=S, + attn_uses_sequence_id=self.attn_uses_sequence_id, + attn_impl=self.attn_impl, + attention_mask=attention_mask) # initialize the past key values cache if it should be used presents = () if use_cache else None if use_cache and past_key_values is None: @@ -538,6 +629,7 @@ def forward( attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), + attention_mask_in_length=attention_mask_in_length, ) if presents is not None: presents += (present,) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py new file mode 100644 index 0000000000..acefd2c42d --- /dev/null +++ b/tests/models/layers/test_flash_attn.py @@ -0,0 +1,255 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import math + +import pytest +import torch + +from llmfoundry.models.layers.attention import (flash_attn_fn, + is_flash_v2_installed, + triton_flash_attn_fn) + + +@pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(), + reason='GQA natively only supported by Flash Attention after v2.') +@pytest.mark.parametrize('kv_n_heads', [1, 4, 8]) +def test_gqa_kv_repetition(kv_n_heads: int): + # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same + # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own. + d = 128 + n_heads = 8 + seqlen_1 = 6 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, + kv_n_heads * d).to(torch.bfloat16).cuda() + value_1.requires_grad = True + + output_1, _, _ = flash_attn_fn(query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None, + should_repeat_kv_for_gqa=True) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + output_2, _, _ = flash_attn_fn(query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None, + should_repeat_kv_for_gqa=False) + + output_2.sum().backward() + assert torch.allclose(output_1, output_2) + assert torch.allclose(query_1.grad, query_2.grad) # type: ignore + assert torch.allclose(key_1.grad, key_2.grad) # type: ignore + assert torch.allclose(value_1.grad, value_2.grad) # type: ignore + + +@pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(v2_version='v2.1.2'), + reason= + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' +) +def test_seq_id_masking_FA_v2(): + # Test that flash attention v2 with sequence id masking works correctly. + d = 128 + n_heads = 4 + kv_n_heads = 4 + seqlen_1 = 6 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, + kv_n_heads * d).to(torch.bfloat16).cuda() + value_1.requires_grad = True + + seq_ranges = [ + (0, 3), (3, 5), (5, 6) + ] # Each batch has 3 sequences of length 3, 2, and 1 respectively. + attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], + [3, 2, 1, 0, 0, + 0]]).to(torch.int64).cuda() + + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=attention_mask_in_length_1) + + output_1.sum().backward() + + for seq_range in seq_ranges: + query_2 = query_1.detach().clone()[:, seq_range[0]:seq_range[1], :] + query_2.requires_grad = True + key_2 = key_1.detach().clone()[:, seq_range[0]:seq_range[1], :] + key_2.requires_grad = True + value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :] + value_2.requires_grad = True + + output_2, _, _ = flash_attn_fn(query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None) + + output_2.sum().backward() + assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], + output_2) + assert torch.allclose( + query_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + query_2.grad) # type: ignore + assert torch.allclose( + key_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + key_2.grad) # type: ignore + assert torch.allclose( + value_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + value_2.grad) # type: ignore + + +@pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(v2_version='v2.3.0'), + reason= + 'Sliding window attention only supported by Flash Attention after v2.3.0.') +@pytest.mark.parametrize('sliding_window_size', [1, 4, 8]) +def test_sliding_window(sliding_window_size: int): + # Test that sliding window attention works as expected. + dtype = torch.bfloat16 + device = 'cuda' + d = 128 + n_heads = 8 + seqlen_1 = 8 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + value_1.requires_grad = True + + output_1, _, _ = flash_attn_fn(query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None, + should_repeat_kv_for_gqa=True, + sliding_window_size=sliding_window_size) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + attn_bias_2 = torch.zeros(1, 1, seqlen_1, seqlen_1).to(dtype=dtype, + device=device) + + window_mask_2 = torch.tril( + torch.ones(seqlen_1, seqlen_1), diagonal=-(sliding_window_size + 1)).to( + dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min + attn_bias_2 = attn_bias_2 + window_mask_2 + output_2, _, _ = triton_flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=attn_bias_2, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + ) + + output_2.sum().backward() + + assert torch.allclose(output_1, output_2) + assert torch.norm(query_2.grad - query_1.grad # type: ignore + ) <= 1e-2 + 1e-2 * torch.norm(query_2.grad) + assert torch.norm(key_2.grad - key_1.grad # type: ignore + ) <= 1e-2 + 1e-2 * torch.norm(key_2.grad) + assert torch.norm(value_2.grad - value_1.grad # type: ignore + ) <= 1e-2 + 1e-2 * torch.norm(value_2.grad) diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index e140f678bc..454fda311d 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -7,7 +7,9 @@ from llmfoundry.models.layers import attention from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding +from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, + gen_attention_mask_in_length, + gen_rotary_embedding) def allclose_helper(t0: torch.Tensor, @@ -54,6 +56,7 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize( 'attn_type', ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) +@pytest.mark.parametrize('attn_uses_sequence_id', [True, False]) @pytest.mark.parametrize('pad_attention_mask', [True, False]) def test_attn_impl(attn_impl_0: str, attn_impl_1: str, @@ -61,6 +64,7 @@ def test_attn_impl(attn_impl_0: str, qk_ln: bool, pos_emb_config: dict, attn_type: str, + attn_uses_sequence_id: bool, pad_attention_mask: bool, device: str = 'cuda'): """Compare all attn impl with each other. @@ -77,6 +81,16 @@ def test_attn_impl(attn_impl_0: str, == 'dail') and (not is_flash_v2_installed()): pytest.skip('dail implementation of rope requires flash attention 2.') + if attn_uses_sequence_id and ( + attn_impl_0 == 'flash' or attn_impl_1 + == 'flash') and (not is_flash_v2_installed(v2_version='v2.1.2')): + pytest.skip( + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' + ) + + if not (alibi or rope) and attn_uses_sequence_id: + pytest.skip('attn_uses_sequence_id requires alibi or rope.') + cfg = om.create({ 'attn_impl': 'flash', 'd_model': 64, @@ -91,6 +105,14 @@ def test_attn_impl(attn_impl_0: str, if attn_type == 'grouped_query_attention': cfg.kv_n_heads = 2 + sequence_id = None + if attn_uses_sequence_id: + assert n == 2 + assert s >= 4 + sequence_id = torch.LongTensor([[0] * 2 + [1] * (s - 2), + [0] * 4 + [1] * (s - 4) + ]).to(device=device) + cfg.attn_impl = attn_impl_0 attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) cfg.attn_impl = attn_impl_1 @@ -113,7 +135,7 @@ def gen_bias(attn_impl: str): s, alibi, prefix_lm=False, - use_sequence_id=False, + use_sequence_id=attn_uses_sequence_id, causal=causal) if bs is not None: attn_bias = torch.zeros(*bs, device=device) @@ -126,17 +148,35 @@ def gen_bias(attn_impl: str): alibi=alibi, alibi_bias_max=8, ) + if attn_impl != 'flash' and attn_uses_sequence_id and sequence_id is not None: + assert isinstance(attn_bias, torch.Tensor) # pyright + attn_bias = apply_sequence_id( + attn_bias, + sequence_id, # type: ignore + s) return attn_bias + attention_mask_in_length_0 = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=s, + attn_uses_sequence_id=attn_uses_sequence_id, + attn_impl=attn_impl_0, + attention_mask=attention_mask) + attention_mask_in_length_1 = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=s, + attn_uses_sequence_id=attn_uses_sequence_id, + attn_impl=attn_impl_1, + attention_mask=attention_mask) + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True x1.requires_grad = True with torch.autocast(x0.device.type): - attn_bias = gen_bias(attn0.attn_impl) - + attn_bias_0 = gen_bias(attn_impl_0) rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( @@ -165,17 +205,19 @@ def gen_bias(attn_impl: str): y0, _, _ = attn0(x0, past_key_value=None, - attn_bias=attn_bias, + attn_bias=attn_bias_0, attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, - is_causal=True) - attn_bias = gen_bias(attn1.attn_impl) + is_causal=True, + attention_mask_in_length=attention_mask_in_length_0) + attn_bias_1 = gen_bias(attn_impl_1) y1, _, _ = attn1(x1, past_key_value=None, - attn_bias=attn_bias, + attn_bias=attn_bias_1, attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, - is_causal=True) + is_causal=True, + attention_mask_in_length=attention_mask_in_length_1) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index acb2074ae9..98a556f534 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -555,6 +555,116 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): assert block.resid_ffn_dropout.p == 0.2 +@pytest.mark.gpu +@pytest.mark.parametrize('attention_impl', ['flash', 'triton', 'torch']) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): + # Testing the output of concatenated sequence with sequence id masking vs individual sequences. + alibi = pos_emb_config['alibi'] + if alibi and attention_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + + rope = pos_emb_config['rope'] + if rope and pos_emb_config[ + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + + if attention_impl == 'flash' and ( + not is_flash_v2_installed(v2_version='v2.1.2')): + pytest.skip( + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' + ) + + composer_device = get_device(None) + + hf_config = MPTConfig( + init_device='cpu', + d_model=128, + n_heads=1, + n_layers=2, + expansion_ratio=2, + max_seq_len=2048, + emb_pdrop=0.1, + resid_pdrop=0.2, + attn_config={ + 'attn_impl': attention_impl, + 'attn_uses_sequence_id': True, + **pos_emb_config, + }, + init_config={ + 'name': 'baseline_', + 'init_std': 0.02, + }, + ) + mpt = MPTForCausalLM(hf_config) + mpt.eval() + mpt = composer_device.module_to_device(mpt) + + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + # padding on the right side of the input + concatenated_seq_ids = torch.tensor([[11274, 16390, 11, 4332, 323, 423], + [2342, 12, 111, 123, 50256, 342]]) + concatenated_seq_ids = composer_device.tensor_to_device( + concatenated_seq_ids) + + sequence_id = torch.tensor([[0, 0, 0, 1, 2, 2], [0, 0, 0, 1, 2, 2]]) + sequence_id = composer_device.tensor_to_device(sequence_id) + + first_seq_ids = torch.tensor([[11274, 16390, 11], [2342, 12, 111]]) + first_seq_ids = composer_device.tensor_to_device(first_seq_ids) + + second_seq_ids = torch.tensor([[4332], [123]]) + second_seq_ids = composer_device.tensor_to_device(second_seq_ids) + + third_seq_ids = torch.tensor([[323, 423], [50256, 342]]) + third_seq_ids = composer_device.tensor_to_device(third_seq_ids) + + concatenated_seq_output = mpt(concatenated_seq_ids, + sequence_id=sequence_id).logits + first_seq_output = mpt(first_seq_ids).logits + second_seq_output = mpt(second_seq_ids).logits + third_seq_output = mpt(third_seq_ids).logits + + assert torch.allclose(concatenated_seq_output[:, :3], + first_seq_output, + atol=2e-6 if attention_impl == 'torch' else 1e-8) + assert torch.allclose(concatenated_seq_output[:, 3:4], + second_seq_output, + atol=2e-6 if attention_impl == 'torch' else 1e-8) + atol = 1e-8 + if attention_impl == 'torch': + atol = 2e-6 + elif pos_emb_config['rope']: + atol = 2e-2 + assert torch.allclose(concatenated_seq_output[:, 4:6], + third_seq_output, + atol=atol) + + @pytest.mark.parametrize('attention_impl', [ 'torch', pytest.param('flash', marks=pytest.mark.gpu),