From 80da185ff7ff445517165e3b48d74e9dcf8ef1e8 Mon Sep 17 00:00:00 2001 From: Miroslav Goncharenko Date: Wed, 17 Jul 2024 15:27:44 +0200 Subject: [PATCH] gpt_bigcode: added FusedSDPA kernel Added support for the following options in the gpt_bigcode (starcoderbase) model: use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask --- examples/text-generation/utils.py | 6 + optimum/habana/transformers/modeling_utils.py | 4 +- .../habana/transformers/models/__init__.py | 2 +- .../models/gpt_bigcode/__init__.py | 2 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 554 ++++++++++++------ 5 files changed, 375 insertions(+), 193 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index fa1946b914..5b4df3093f 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -515,6 +515,12 @@ def initialize_model(args, logger): "token": args.token, "trust_remote_code": args.trust_remote_code, } + + # For the starcoderbase model it is essential to set up the proper flags for the SDPA kernel at the model level + # to avoid transposing tensors at each attention call + if "starcoderbase" in args.model_name_or_path: + model_kwargs["use_flash_attention"] = args.use_flash_attention + if args.trust_remote_code: logger.warning("`trust_remote_code` is set, there is no guarantee this model works properly and it may fail") diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 6b88086a0d..6961b1514c 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -46,6 +46,7 @@ GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, + GaudiGPTBigCodeModel, GaudiGPTBigCodeForCausalLM, GaudiGPTJAttention, GaudiGPTJBlock, @@ -131,7 +132,6 @@ gaudi_gpt2_forward, gaudi_gpt_bigcode_attention_forward, gaudi_gpt_bigcode_block_forward, - gaudi_gpt_bigcode_model_forward, gaudi_gpt_neox_attention_forward, gaudi_gpt_neox_layer_forward, gaudi_gpt_neox_model_forward, @@ -348,7 +348,7 @@ def adapt_transformers_to_gaudi(): ) transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM = GaudiGPTBigCodeForCausalLM transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeBlock.forward = gaudi_gpt_bigcode_block_forward - transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeModel.forward = gaudi_gpt_bigcode_model_forward + transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeModel = GaudiGPTBigCodeModel # Optimization for gpt-neox generation on Gaudi transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM = GaudiGPTNeoXForCausalLM diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 80ec7cd09f..4e38aef63b 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -68,9 +68,9 @@ from .gpt2 import GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, gaudi_gpt2_forward from .gpt_bigcode import ( GaudiGPTBigCodeForCausalLM, + GaudiGPTBigCodeModel, gaudi_gpt_bigcode_attention_forward, gaudi_gpt_bigcode_block_forward, - gaudi_gpt_bigcode_model_forward, ) from .gpt_neox import ( GaudiGPTNeoXForCausalLM, diff --git a/optimum/habana/transformers/models/gpt_bigcode/__init__.py b/optimum/habana/transformers/models/gpt_bigcode/__init__.py index 556f61f8c7..ff1c919b04 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/__init__.py +++ 
b/optimum/habana/transformers/models/gpt_bigcode/__init__.py @@ -1,6 +1,6 @@ from .modeling_gpt_bigcode import ( GaudiGPTBigCodeForCausalLM, + GaudiGPTBigCodeModel, gaudi_gpt_bigcode_attention_forward, gaudi_gpt_bigcode_block_forward, - gaudi_gpt_bigcode_model_forward, ) diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index e35b4cac52..304fa45cec 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -1,14 +1,138 @@ +import math from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions -from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM, GPTBigCodeModel from ...modeling_attn_mask_utils import GaudiAttentionMaskConverter +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + +import habana_frameworks.torch.core as htcore + + +def gaudi_flash_attn_v1(query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_causal, scale, softmax_mode, q_block_size): + """ + Gaudi version of Flash Attention V1 to support long sequences at the prompt phase + Causal masking is not supported in this optimization + """ + if is_causal: + raise ValueError( + "Causal mask is not supported for long input sequences" + ) + + q_len = query_layer.size(-2) + q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) + q_padding = q_tiles * q_block_size - q_len + query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0) + if attention_mask is not None: + attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0) + row_o_list = [] + for i in range(q_tiles): + s, e = i * q_block_size, (i + 1) * q_block_size + row_q = query_layer[:, :, s:e, :] + row_mask = attention_mask[:, :, s:e, :] if attention_mask is not None else None + attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode) + row_o_list.append(attn_output_partial) + attn_output = torch.cat(row_o_list, dim=-2) + if q_padding != 0: + attn_output = attn_output[:, :, :-q_padding, :] + return attn_output + +def apply_FusedSDPA(self, query, key, value, attention_mask=None, flash_attention_recompute=False, flash_attention_fast_softmax=False, flash_attention_causal_mask=False): + + """ + Copied from GPTBigCodeSdpaAttention._attn: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py + The only differences are: + - replaced torch.nn.functional.scaled_dot_product_attention with Habana's FusedSDPA + - removed the WA for key and value tensor expanding over the heads dimension. + That WA also works but dramatically drops throughput + - added args use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask to control parameters of FusedSDPA + - added special case handling for inputs longer than 8192 tokens with the function gaudi_flash_attn_v1 + """ + + scale = None + if not self.scale_attn_weights: + scale = 1 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + + if self.multi_query: + query_length = query_shape[1] + + # SDPA requires the dimension [..., sequence_length, head_dim]. + query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) + + # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. + key = key.unsqueeze(1) + value = value.unsqueeze(1) + + else: + query_length = query_shape[-1] + + if attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + sdpa_result = None + enable_recompute = flash_attention_recompute and query_length > 1 + + if query_length > 1 and flash_attention_causal_mask: + attention_mask = None + use_causal_mask = True + else: + use_causal_mask = self.is_causal and attention_mask is None and query_length > 1 + + import habana_frameworks.torch.hpu as ht + with ht.sdp_kernel(enable_recompute=enable_recompute): + if query_length > 8192: + sdpa_result = gaudi_flash_attn_v1( + query, + key, + value, + attention_mask, + self.attn_pdrop if self.training else 0.0, + use_causal_mask, + scale, + 'fast' if flash_attention_fast_softmax else 'None', + 4096 + ) + htcore.mark_step() + else: + sdpa_result = FusedSDPA.apply( + query, + key, + value, + attention_mask, + self.attn_pdrop if self.training else 0.0, + use_causal_mask, + scale, + 'fast' if flash_attention_fast_softmax else 'None') + + if self.multi_query: + # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) + sdpa_result = sdpa_result.transpose(1, 2) + + # Reshape is kind of expensive here, as it does a memory copy, + # but I did not manage to make away without it (logits do not match when using view) + # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) + sdpa_result = sdpa_result.reshape(query_shape) + + return sdpa_result, None + + + def gaudi_gpt_bigcode_attention_forward( self, hidden_states: torch.Tensor, @@ -20,6 +144,10 @@ def gaudi_gpt_bigcode_attention_forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[ Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], @@ -27,7 +155,7 @@ """ Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py The only differences are: - - add new args token_idx + - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask - optimize KV cache """ if encoder_hidden_states is not None: @@ -65,7 +193,12 @@ def 
gaudi_gpt_bigcode_attention_forward( value = torch.cat((past_value, value), dim=-2) present = torch.cat((key, value), dim=-1) if use_cache else None - attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + if not output_attentions and head_mask is None and use_flash_attention: + # Difference with the original implementation: there is no need to transpose the key here, + # as SDPA expects seq_length to be at index -2 for the key as well + attn_output, attn_weights = apply_FusedSDPA(self, query, key, value, attention_mask, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask) + else: + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) if not self.multi_query: attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) @@ -93,11 +226,15 @@ def gaudi_gpt_bigcode_block_forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py The only differences are: - - add new args token_idx + - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask """ residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -109,6 +246,10 @@ def gaudi_gpt_bigcode_block_forward( use_cache=use_cache, output_attentions=output_attentions, token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_fast_softmax=flash_attention_fast_softmax, + flash_attention_causal_mask=flash_attention_causal_mask ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -150,222 +291,245 @@ def gaudi_gpt_bigcode_block_forward( return outputs # hidden_states, present, (attentions, cross_attentions) - -def gaudi_gpt_bigcode_model_forward( - self, - input_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: +class GaudiGPTBigCodeModel(GPTBigCodeModel): """ - Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py + Inherits from GPTBigCodeModel: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py The only differences are: - - add new args token_idx - - 
if token_idx and past_key_values are passed, set self_attention_mask based on the static shape of past_key_values + - correctly set self._use_sdpa flag for attention_mask shape preparation """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0].size(-2) - - if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_length > 0: - position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] - elif position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) - - # Self-attention mask. - query_length = input_shape[-1] - key_length = past_length + query_length - if past_length > 0 and token_idx is not None: - self_attention_mask = self.bias[None, past_length - 1 : past_length, :past_length] - else: - self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + def __init__(self, config): + super().__init__(config) - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device - ) + self._use_sdpa = FusedSDPA is not None - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) - - if self._use_sdpa and head_mask is None and not output_attentions: - # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. 
- dtype = self.wte.weight.dtype - min_dtype = torch.finfo(dtype).min - self_attention_mask = torch.where( - self_attention_mask, - torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), - torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device), + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + """ + Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py + The only differences are: + - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask + - if token_idx and past_key_values are passed, set self_attention_mask based on the static shape of past_key_values + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - if self.multi_query: - # gpt_bigcode using MQA has the bad taste to use a causal mask with shape - # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. - self_attention_mask = self_attention_mask.transpose(1, 2) - - if query_length > 1 and attention_mask is not None: - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. 
Details: https://github.com/pytorch/pytorch/issues/110213 - self_attention_mask = GaudiAttentionMaskConverter._unmask_unattended( - self_attention_mask, min_dtype=min_dtype - ) + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") - attention_mask = self_attention_mask + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None and encoder_attention_mask is not None: - if encoder_attention_mask.dim() == 2: - encoder_attention_mask.unsqueeze(1) - assert encoder_attention_mask.dim() == 3 - encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) - else: - encoder_attention_mask = None + device = input_ids.device if input_ids is not None else inputs_embeds.device - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0].size(-2) - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds + if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_length > 0: + position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] + elif position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Self-attention mask. 
+ query_length = input_shape[-1] + key_length = past_length + query_length + if past_length > 0 and token_idx is not None: + self_attention_mask = self.bias[None, past_length - 1 : past_length, :past_length] + else: + self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - hidden_states = self.drop(hidden_states) + if attention_mask is not None: + self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( + dtype=torch.bool, device=self_attention_mask.device + ) - output_shape = input_shape + (hidden_states.size(-1),) + # MQA models: (batch_size, query_length, n_heads, key_length) + # MHA models: (batch_size, n_heads, query_length, key_length) + self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + + if self._use_sdpa and head_mask is None and not output_attentions: + # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. + dtype = self.wte.weight.dtype + min_dtype = torch.finfo(dtype).min + self_attention_mask = torch.where( + self_attention_mask, + torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), + torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device), + ) - presents = [] if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + if self.multi_query: + # gpt_bigcode using MQA has the bad taste to use a causal mask with shape + # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. + self_attention_mask = self_attention_mask.transpose(1, 2) + + if query_length > 1 and attention_mask is not None: + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. 
Details: https://github.com/pytorch/pytorch/issues/110213 + self_attention_mask = GaudiAttentionMaskConverter._unmask_unattended( + self_attention_mask, min_dtype=min_dtype + ) + + attention_mask = self_attention_mask + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None and encoder_attention_mask is not None: + if encoder_attention_mask.dim() == 2: + encoder_attention_mask.unsqueeze(1) + assert encoder_attention_mask.dim() == 3 + encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = [] if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + None, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_fast_softmax=flash_attention_fast_softmax, + flash_attention_causal_mask=flash_attention_causal_mask + ) + + hidden_states = outputs[0] + if use_cache: + presents.append(outputs[1]) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - None, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - 
encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None ) - hidden_states = outputs[0] - if use_cache: - presents.append(outputs[1]) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - class GaudiGPTBigCodeForCausalLM(GPTBigCodeForCausalLM): """ Inherits from GPTBigCodeForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py The only differences are: - - add new args token_idx - - add token_idx into model_inputs + - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask + - add token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask into model_inputs - when KV cache is enabled, slice next_input_ids from input_ids based on the token_idx - when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx """ + def __init__(self, config, use_flash_attention : Optional[bool] = False): + super().__init__(config) + + self.transformer._use_sdpa = self.transformer._use_sdpa and use_flash_attention + def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, token_idx=None, **kwargs ): @@ -422,6 +586,10 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, "token_type_ids": token_type_ids, "token_idx": token_idx, + "use_flash_attention": kwargs.get("use_flash_attention", False), + "flash_attention_recompute": kwargs.get("flash_attention_recompute", False), + "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax", False), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False), } ) return model_inputs @@ -443,6 +611,10 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -467,6 +639,10 
@@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_fast_softmax=flash_attention_fast_softmax, + flash_attention_causal_mask=flash_attention_causal_mask ) hidden_states = transformer_outputs[0]
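For reviewers, a minimal usage sketch (not part of the patch) of how the new options are intended to flow end to end, assuming adapt_transformers_to_gaudi() has been called and that the flash-attention flags are accepted as generate() kwargs in the same way prepare_inputs_for_generation() reads them above. The checkpoint name, dtype, prompt, and generation settings below are placeholders; in the bundled text-generation example these values are driven by the corresponding --use_flash_attention / --flash_attention_* CLI arguments instead.

    # Sketch only: placeholder model name and settings, HPU runtime assumed available.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

    # Swaps in the GaudiGPTBigCodeModel / GaudiGPTBigCodeForCausalLM classes patched above.
    adapt_transformers_to_gaudi()

    model_name = "bigcode/starcoderbase"  # placeholder checkpoint

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # use_flash_attention is assumed to reach GaudiGPTBigCodeForCausalLM.__init__ through
    # model kwargs, mirroring what the utils.py hunk does for starcoderbase checkpoints.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_flash_attention=True,
    ).to("hpu")

    inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("hpu")

    # The per-call flags are read back in prepare_inputs_for_generation() via kwargs.get(...)
    # and forwarded through GaudiGPTBigCodeModel down to apply_FusedSDPA in every layer.
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        use_flash_attention=True,
        flash_attention_recompute=True,
        flash_attention_fast_softmax=True,
        flash_attention_causal_mask=True,
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

With these flags set, prompts longer than 8192 tokens take the tiled gaudi_flash_attn_v1 path with a 4096-token query block size, while shorter prompts and single-token decode steps call FusedSDPA directly.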