Showing 38 changed files with 3,845 additions and 37 deletions.
@@ -0,0 +1,295 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from .utils import is_flash_attn_3_available


if is_flash_attn_3_available():
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    from flash_attn_interface import _flash_attn_forward, flash_attn_func, flash_attn_varlen_func

def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is
            (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
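# Illustrative example (not part of the module): for attention_mask [[1, 1, 0], [1, 1, 1]],
# _get_unpad_data returns indices=[0, 1, 3, 4, 5], cu_seqlens=[0, 2, 5], and
# max_seqlen_in_batch=3 -- the flattened layout that the flash-attn varlen kernels expect.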


def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
):
    """
    Unpads query, key, and value tensors, using a single dimension for all tokens even though they belong to
    different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid
    recomputing the same intermediary tensors for the query, key, and value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.

    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
    )
    if query_length == kv_seq_len:
        query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
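# Illustrative note (not part of the module): during decoding (query_length == 1) each example
# contributes exactly one query token, so cu_seqlens_q above is simply torch.arange(batch_size + 1)
# and the query only needs its length-1 dimension squeezed rather than a full unpad.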


def prepare_fa2_from_position_ids(query, key, value, position_ids):
    """
    This function returns the necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, and value states will be flattened.
    Cumulative lengths of each example in the batch will be extracted from position_ids.

    NOTE: ideally, cumulative lengths should be prepared at the data collator stage.

    Arguments:
        query (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Tensor of shape (batch_size, sequence_length) containing the position index of each token;
            the ids reset to 0 at the start of each packed example.

    Return:
        query (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.view(-1, key.size(-2), key.size(-1))
    value = value.view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )

    max_length = position_ids.max() + 1

    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
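# Illustrative example (not part of the module): for packed position_ids [[0, 1, 2, 0, 1]], the
# flattened ids reset to 0 at indices 0 and 3, so cu_seq_lens=[0, 3, 5] and max_length=3 (the
# length of the longest packed sub-sequence).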


def _flash_attention_3_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    use_top_left_mask: bool = False,
    deterministic: Optional[bool] = None,
    descale: float = 1.0,
    use_fp8: bool = False,
):
    """
    Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
    the input is first unpadded, the attention scores are computed, and the final attention output is re-padded.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to the Flash Attention API.
        key_states (`torch.Tensor`):
            Input key states to be passed to the Flash Attention API.
        value_states (`torch.Tensor`):
            Input value states to be passed to the Flash Attention API.
        attention_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        use_top_left_mask (`bool`, defaults to `False`):
            flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
            alignment, which became the default for flash_attn>=2.1. This attribute is used to handle the
            difference.
        deterministic (`bool`, *optional*):
            Determines whether the deterministic option is enabled.
        descale (`float`, defaults to `1.0`):
            Descaling factor for FP8 (currently unused, see the NOTE in the FP8 branch below).
        use_fp8 (`bool`, defaults to `False`):
            Whether to run the attention forward in FP8; can also be enabled with the
            `FLASH_ATTENTION_3_FP8=1` environment variable.
    """
    # FP8 can be requested either via the `use_fp8` argument or the environment variable
    # (the original line overwrote the argument unconditionally, leaving it dead).
    use_fp8 = use_fp8 or os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1"

    softmax_scale = softmax_scale or query_states.shape[-1] ** (-0.5)

    if not use_top_left_mask:
        causal = is_causal
    else:
        # With top-left aligned masks (flash_attn<2.1), causal masking is skipped for
        # single-token (decoding) queries.
        causal = is_causal and query_length != 1

    flash_kwargs = {}

    if deterministic is None:
        deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
    flash_kwargs["deterministic"] = deterministic
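    # Illustrative note (not part of the module): determinism can thus be toggled without code
    # changes, e.g. `FLASH_ATTENTION_DETERMINISTIC=1 python train.py` (hypothetical invocation).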
    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output_unpad, _ = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
    # If position_ids is provided and the ids are not monotonically non-decreasing, the batch contains packed
    # examples (position ids reset to 0 at each example boundary) rather than a single sequence per row.
    # In the pre-fill/training stage (query_length != 1), use `flash_attn_varlen_func` to prevent
    # cross-example attention and to enable the padding-free approach.
    elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
        batch_size = query_states.size(0)
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output, _ = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    else:
        if use_fp8:
            # NOTE: descale?
            attn_output = _flash_attn_forward(
                query_states.to(torch.float8_e4m3fn),
                key_states.to(torch.float8_e4m3fn),
                value_states.to(torch.float8_e4m3fn),
                softmax_scale=softmax_scale,
                causal=causal,
            )[0]
        else:
            attn_output, _ = flash_attn_func(
                query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
            )

    return attn_output
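# A minimal usage sketch (illustrative only, not part of this module), assuming flash-attn 3 is
# installed and a supported GPU is available; tensor shapes follow the docstrings above:
#
#     batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
#     q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
#     k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
#     v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
#     out = _flash_attention_3_forward(q, k, v, attention_mask=None, query_length=seq_len, is_causal=True)
#     # out: (batch_size, seq_len, num_heads, head_dim) -- the unpadded fast path via flash_attn_func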