From 88ad54e5cfac28f0ab9629fcf85b773197188077 Mon Sep 17 00:00:00 2001 From: Miroslav Goncharenko Date: Tue, 20 Aug 2024 13:06:47 +0200 Subject: [PATCH] gpt_big_code: make flash attention impl quantization friendly - introduce GaudiGPTBigCodeAttention class - wrapped FusedSDPA kernel to separate ModuleFusedSDPA class --- optimum/habana/transformers/modeling_utils.py | 9 +- .../habana/transformers/models/__init__.py | 2 +- .../models/gpt_bigcode/__init__.py | 2 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 424 +++++++++--------- 4 files changed, 231 insertions(+), 206 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index e0e95c2704..9526319dd0 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -55,6 +55,7 @@ GaudiGPT2Block, GaudiGPT2DoubleHeadsModel, GaudiGPT2LMHeadModel, + GaudiGPTBigCodeAttention, GaudiGPTBigCodeForCausalLM, GaudiGPTJAttention, GaudiGPTJBlock, @@ -148,7 +149,6 @@ gaudi_generate_speech, gaudi_get_extended_attention_mask, gaudi_gpt2_forward, - gaudi_gpt_bigcode_attention_forward, gaudi_gpt_bigcode_block_forward, gaudi_gpt_bigcode_model_forward, gaudi_gpt_neox_attention_forward, @@ -356,12 +356,13 @@ def adapt_transformers_to_gaudi(): transformers.models.gptj.modeling_gptj.GPTJModel = GaudiGPTJModel # Optimization for GPTBigCode on Gaudi - transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention.forward = ( - gaudi_gpt_bigcode_attention_forward - ) + transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention = GaudiGPTBigCodeAttention transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM = GaudiGPTBigCodeForCausalLM transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeBlock.forward = gaudi_gpt_bigcode_block_forward transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeModel.forward = gaudi_gpt_bigcode_model_forward + transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBIGCODE_ATTENTION_CLASSES.update( + {"eager": GaudiGPTBigCodeAttention} + ) # Optimization for gpt-neox generation on Gaudi transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM = GaudiGPTNeoXForCausalLM diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 3ef8aae0f4..e08a17ffa2 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -79,8 +79,8 @@ gaudi_gpt2_forward, ) from .gpt_bigcode import ( + GaudiGPTBigCodeAttention, GaudiGPTBigCodeForCausalLM, - gaudi_gpt_bigcode_attention_forward, gaudi_gpt_bigcode_block_forward, gaudi_gpt_bigcode_model_forward, ) diff --git a/optimum/habana/transformers/models/gpt_bigcode/__init__.py b/optimum/habana/transformers/models/gpt_bigcode/__init__.py index 556f61f8c7..08ccaf3725 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/__init__.py +++ b/optimum/habana/transformers/models/gpt_bigcode/__init__.py @@ -1,6 +1,6 @@ from .modeling_gpt_bigcode import ( + GaudiGPTBigCodeAttention, GaudiGPTBigCodeForCausalLM, - gaudi_gpt_bigcode_attention_forward, gaudi_gpt_bigcode_block_forward, gaudi_gpt_bigcode_model_forward, ) diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index c4c9337657..4cbe06c3ce 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ 
b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -6,7 +6,7 @@ import torch.utils.checkpoint from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions -from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention, GPTBigCodeForCausalLM from ...modeling_attn_mask_utils import GaudiAttentionMaskConverter @@ -20,230 +20,254 @@ import habana_frameworks.torch.core as htcore -def gaudi_flash_attn_v1( - query_layer, - key_layer, - value_layer, - attention_mask, - dropout_rate, - is_causal, - scale, - softmax_mode, - enable_recompute, - q_block_size, -): - """ - Gaudi version of Flash Attention V1 to support long sequence at prompt phase - Causal mask is not supported in this optimization - """ - if is_causal: - raise ValueError("Causal mask is not supported for long input sequences") +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA - q_len = query_layer.size(-2) - q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) - q_padding = q_tiles * q_block_size - q_len - query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0) - if attention_mask is not None: - attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0) - row_o_list = [] - for i in range(q_tiles): - s, e = i * q_block_size, (i + 1) * q_block_size - row_q = query_layer[:, :, s:e, :] - row_mask = attention_mask[:, :, s:e, :] - attn_output_partial = FusedSDPA.apply( - row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode, enable_recompute + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, enable_recompute): + return self._hpu_kernel_fsdpa.apply( + query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, enable_recompute ) - row_o_list.append(attn_output_partial) - attn_output = torch.cat(row_o_list, dim=-2) - if q_padding != 0: - attn_output = attn_output[:, :, :-q_padding, :] - return attn_output -def apply_FusedSDPA( - self, - query, - key, - value, - attention_mask=None, - flash_attention_recompute=False, - flash_attention_fast_softmax=False, - flash_attention_causal_mask=False, -): - """ - Copied from GPTBigCodeSdpaAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py - The only differences are: - - replaced torch.nn.functional.scaled_dot_product_attention with Habana's FusedSDPA - - removed WA for key and value tensor expanding over heads dimension. 
That WA also works but dramatically drops throughput - - added args use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask to control parameters of FusedSDPA - - added special case handling for input larger 8192 with function gaudi_flash_attn_v1 - """ +class GaudiGPTBigCodeAttention(GPTBigCodeAttention): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__(config, is_cross_attention, layer_idx) - scale = None - if not self.scale_attn_weights: - scale = 1 + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA is not None else None + self.block_size = 4096 - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) - query_shape = query.shape - batch_size = query_shape[0] + def gaudi_flash_attn_v1( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + dropout_rate, + is_causal, + scale, + softmax_mode, + enable_recompute, + q_block_size, + ): + """ + Gaudi version of Flash Attention V1 to support long sequence at prompt phase + Causal mask is not supported in this optimization + """ + if is_causal: + raise ValueError("Causal mask is not supported for long input sequences") - if self.multi_query: - query_length = query_shape[1] + q_len = query_layer.size(-2) + q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) + q_padding = q_tiles * q_block_size - q_len + query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0) + if attention_mask is not None: + attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0) + row_o_list = [] + for i in range(q_tiles): + s, e = i * q_block_size, (i + 1) * q_block_size + row_q = query_layer[:, :, s:e, :] + row_mask = attention_mask[:, :, s:e, :] + attn_output_partial = self.fused_scaled_dot_product_attention( + row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode, enable_recompute + ) + row_o_list.append(attn_output_partial) + attn_output = torch.cat(row_o_list, dim=-2) + if q_padding != 0: + attn_output = attn_output[:, :, :-q_padding, :] + return attn_output - # SDPA requires the dimension [..., sequence_length, head_dim]. - query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) + def apply_FusedSDPA( + self, + query, + key, + value, + attention_mask=None, + flash_attention_recompute=False, + flash_attention_fast_softmax=False, + flash_attention_causal_mask=False, + ): + """ + Copied from GPTBigCodeSdpaAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py + The only differences are: + - replaced torch.nn.functional.scaled_dot_product_attention with Habana's FusedSDPA + - removed WA for key and value tensor expanding over heads dimension. That WA also works but dramatically drops throughput + - added args use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask to control parameters of FusedSDPA + - added special case handling for input larger 8192 with function gaudi_flash_attn_v1 + """ - # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. 
- key = key.unsqueeze(1) - value = value.unsqueeze(1) + scale = None + if not self.scale_attn_weights: + scale = 1 - else: - query_length = query_shape[-1] + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] - if attention_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() + if self.multi_query: + query_length = query_shape[1] - sdpa_result = None - enable_recompute = flash_attention_recompute and query_length > 1 + # SDPA requires the dimension [..., sequence_length, head_dim]. + query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) - if query_length > 1 and flash_attention_causal_mask: - attention_mask = None - use_causal_mask = True - else: - use_causal_mask = self.is_causal and attention_mask is None and query_length > 1 - - if query_length > 8192: - sdpa_result = gaudi_flash_attn_v1( - query, - key, - value, - attention_mask, - self.attn_pdrop if self.training else 0.0, - use_causal_mask, - scale, - "fast" if flash_attention_fast_softmax else "None", - enable_recompute, - 4096, - ) - htcore.mark_step() - else: - sdpa_result = FusedSDPA.apply( - query, - key, - value, - attention_mask, - self.attn_pdrop if self.training else 0.0, - use_causal_mask, - scale, - "fast" if flash_attention_fast_softmax else "None", - enable_recompute, - ) + # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. + key = key.unsqueeze(1) + value = value.unsqueeze(1) - if self.multi_query: - # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) - sdpa_result = sdpa_result.transpose(1, 2) + else: + query_length = query_shape[-1] - # Reshape is kind of expensive here, as it does a memory copy, - # but I did not manage to make away without it (logits do not match when using view) - # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) - sdpa_result = sdpa_result.reshape(query_shape) + if attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() - return sdpa_result, None + sdpa_result = None + enable_recompute = flash_attention_recompute and query_length > 1 + if query_length > 1 and flash_attention_causal_mask: + attention_mask = None + use_causal_mask = True + else: + use_causal_mask = self.is_causal and attention_mask is None and query_length > 1 + + if query_length > 8192: + sdpa_result = self.gaudi_flash_attn_v1( + query, + key, + value, + attention_mask, + self.attn_pdrop if self.training else 0.0, + use_causal_mask, + scale, + "fast" if flash_attention_fast_softmax else "None", + enable_recompute, + self.block_size, + ) + htcore.mark_step() + else: + sdpa_result = self.fused_scaled_dot_product_attention( + query, + key, + value, + attention_mask, + self.attn_pdrop if self.training else 0.0, + use_causal_mask, + scale, + "fast" if flash_attention_fast_softmax else "None", + enable_recompute, + ) -def gaudi_gpt_bigcode_attention_forward( - self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - 
token_idx: Optional[torch.Tensor] = None, - use_flash_attention: Optional[bool] = False, - flash_attention_recompute: Optional[bool] = False, - flash_attention_fast_softmax: Optional[bool] = False, - flash_attention_causal_mask: Optional[bool] = False, -) -> Union[ - Tuple[torch.Tensor, Optional[torch.Tensor]], - Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], -]: - """ - Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py - The only differences are: - - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask - - optimize KV cache - """ - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn") or not self.is_cross_attention: - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." - ) + if self.multi_query: + # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) + sdpa_result = sdpa_result.transpose(1, 2) - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. - query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) + # Reshape is kind of expensive here, as it does a memory copy, + # but I did not manage to make away without it (logits do not match when using view) + # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) + sdpa_result = sdpa_result.reshape(query_shape) - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + return sdpa_result, None - if layer_past is not None: - past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1) - if token_idx is not None: - # Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled. 
- key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1)) - value = past_value.index_add(1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1)) + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + """ + Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py + The only differences are: + - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask + - optimize KV cache + """ + if use_flash_attention: + assert ( + self.fused_scaled_dot_product_attention is not None + ), "Can't load HPU fused scaled dot-product attention kernel. Please retry without flash attention" + + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) else: - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - present = torch.cat((key, value), dim=-1) if use_cache else None - - if not output_attentions and head_mask is None and use_flash_attention: - # Difference with the original implementation: there is no need to transpose the key here, - # as SDPA expects seq_length to be at index -2 for the key as well - attn_output, attn_weights = apply_FusedSDPA( - self, - query, - key, - value, - attention_mask, - flash_attention_recompute, - flash_attention_fast_softmax, - flash_attention_causal_mask, - ) - else: - attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. 
+ query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - outputs = (attn_output, present) - if output_attentions: - if self.multi_query: - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) + if layer_past is not None: + past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1) + if token_idx is not None: + # Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled. + key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1)) + value = past_value.index_add( + 1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1) + ) + else: + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = torch.cat((key, value), dim=-1) if use_cache else None + + if not output_attentions and head_mask is None and use_flash_attention: + # Difference with the original implementation: there is no need to transpose the key here, + # as SDPA expects seq_length to be at index -2 for the key as well + attn_output, attn_weights = self.apply_FusedSDPA( + query, + key, + value, + attention_mask, + flash_attention_recompute, + flash_attention_fast_softmax, + flash_attention_causal_mask, + ) + else: + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) - return outputs # a, present, (attentions) + return outputs # a, present, (attentions) def gaudi_gpt_bigcode_block_forward(
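Note on the design (illustrative, not part of the patch): FusedSDPA.apply() is a bare function call, so module-oriented quantization tooling cannot see or instrument it. Registering the kernel as a child nn.Module (self.fused_scaled_dot_product_attention, an instance of ModuleFusedSDPA) makes it discoverable and replaceable like any other submodule, which is what "quantization friendly" refers to in the subject line. The sketch below shows that pattern; QuantWrappedSDPA and wrap_sdpa_modules are hypothetical names for illustration, not optimum-habana APIs.

# Illustrative sketch only: QuantWrappedSDPA and wrap_sdpa_modules are hypothetical.
# Module-based quantizers typically walk named_modules()/named_children() and swap or
# instrument nn.Module instances; a raw FusedSDPA.apply(...) call inside forward() is
# invisible to that machinery, while a ModuleFusedSDPA child module is not.
import torch


class QuantWrappedSDPA(torch.nn.Module):
    """Stand-in for a quantizer-provided wrapper around the fused SDPA module."""

    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        # A real quantization flow would insert measurement or fp8 casts around this call.
        return self.inner(*args, **kwargs)


def wrap_sdpa_modules(model: torch.nn.Module) -> None:
    """Replace every ModuleFusedSDPA child of `model` with the wrapper above."""
    for parent in list(model.modules()):
        for name, child in list(parent.named_children()):
            if type(child).__name__ == "ModuleFusedSDPA":
                setattr(parent, name, QuantWrappedSDPA(child))

With such a wrapper in place the attention math is unchanged: GaudiGPTBigCodeAttention still calls self.fused_scaled_dot_product_attention(...) with the same arguments it previously passed to FusedSDPA.apply().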