gpt_bigcode: added FusedSDPA kernel #1138

Merged 1 commit on Jul 29, 2024
@@ -1,6 +1,8 @@
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
@@ -9,6 +11,141 @@
from ...modeling_attn_mask_utils import GaudiAttentionMaskConverter


try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

import habana_frameworks.torch.core as htcore


def gaudi_flash_attn_v1(
query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_causal, scale, softmax_mode, q_block_size
):
"""
Gaudi version of Flash Attention V1 to support long sequences at the prompt phase.
Causal mask is not supported in this optimization.
"""
if is_causal:
raise ValueError("Causal mask is not supported for long input sequences")

q_len = query_layer.size(-2)
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
q_padding = q_tiles * q_block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
if attention_mask is not None:
attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)
row_o_list = []
for i in range(q_tiles):
s, e = i * q_block_size, (i + 1) * q_block_size
row_q = query_layer[:, :, s:e, :]
row_mask = attention_mask[:, :, s:e, :] if attention_mask is not None else None
attn_output_partial = FusedSDPA.apply(
row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode
)
row_o_list.append(attn_output_partial)
attn_output = torch.cat(row_o_list, dim=-2)
if q_padding != 0:
attn_output = attn_output[:, :, :-q_padding, :]
return attn_output
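
As a quick illustration of the tiling above: the query length is padded up to a multiple of q_block_size, FusedSDPA runs once per query block against the full key/value, and the padding is sliced off the concatenated output. A minimal sketch of the arithmetic, assuming a hypothetical 10000-token prompt and the 4096 block size used later in this file:

import math

q_len, q_block_size = 10000, 4096           # hypothetical prompt length; block size from apply_FusedSDPA
q_tiles = math.ceil(q_len / q_block_size)   # 3 query blocks
q_padding = q_tiles * q_block_size - q_len  # 2288 padded query rows, removed again after concatenation
print(q_tiles, q_padding)                   # -> 3 2288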


def apply_FusedSDPA(
self,
query,
key,
value,
attention_mask=None,
flash_attention_recompute=False,
flash_attention_fast_softmax=False,
flash_attention_causal_mask=False,
):
"""
Copied from GPTBigCodeSdpaAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- replaced torch.nn.functional.scaled_dot_product_attention with Habana's FusedSDPA
- removed the workaround that expanded the key and value tensors over the heads dimension; it also works, but dramatically drops throughput
- added args flash_attention_recompute, flash_attention_fast_softmax and flash_attention_causal_mask to control FusedSDPA
- added special-case handling through gaudi_flash_attn_v1 for inputs longer than 8192 tokens
"""

scale = None
if not self.scale_attn_weights:
scale = 1

# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]

if self.multi_query:
query_length = query_shape[1]

# SDPA requires the dimension [..., sequence_length, head_dim].
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)

# Without these unsqueezes, SDPA complains that the query and key/value have different numbers of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)

else:
query_length = query_shape[-1]

if attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

sdpa_result = None
enable_recompute = flash_attention_recompute and query_length > 1

if query_length > 1 and flash_attention_causal_mask:
attention_mask = None
use_causal_mask = True
else:
use_causal_mask = self.is_causal and attention_mask is None and query_length > 1

import habana_frameworks.torch.hpu as ht

with ht.sdp_kernel(enable_recompute=enable_recompute):
if query_length > 8192:
sdpa_result = gaudi_flash_attn_v1(
query,
key,
value,
attention_mask,
self.attn_pdrop if self.training else 0.0,
use_causal_mask,
scale,
"fast" if flash_attention_fast_softmax else "None",
4096,
)
htcore.mark_step()
else:
sdpa_result = FusedSDPA.apply(
query,
key,
value,
attention_mask,
self.attn_pdrop if self.training else 0.0,
use_causal_mask,
scale,
"fast" if flash_attention_fast_softmax else "None",
)

if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
sdpa_result = sdpa_result.transpose(1, 2)

# Reshape is kind of expensive here, as it does a memory copy,
# but I did not manage to do away with it (logits do not match when using view)
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
sdpa_result = sdpa_result.reshape(query_shape)

return sdpa_result, None
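
A minimal sketch of the MQA shape handling above, with hypothetical sizes (SDPA-style kernels broadcast the single key/value head across all query heads):

import torch

batch_size, query_length, num_heads, head_dim = 2, 16, 8, 64
query = torch.randn(batch_size, query_length, num_heads * head_dim)  # MQA query layout
key = torch.randn(batch_size, query_length, head_dim)                # single shared KV head
value = torch.randn(batch_size, query_length, head_dim)

query = query.view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)
key, value = key.unsqueeze(1), value.unsqueeze(1)
print(query.shape, key.shape)  # torch.Size([2, 8, 16, 64]) torch.Size([2, 1, 16, 64])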


def gaudi_gpt_bigcode_attention_forward(
self,
hidden_states: torch.Tensor,
@@ -20,14 +157,18 @@ def gaudi_gpt_bigcode_attention_forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
"""
Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- add new args token_idx
- add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
- optimize KV cache
"""
if encoder_hidden_states is not None:
@@ -65,7 +206,21 @@ def gaudi_gpt_bigcode_attention_forward(
value = torch.cat((past_value, value), dim=-2)
present = torch.cat((key, value), dim=-1) if use_cache else None

attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
if not output_attentions and head_mask is None and use_flash_attention:
# Difference with the original implementation: there is no need to transpose the key here,
# as SDPA expects seq_length to be at index -2 for the key as well
attn_output, attn_weights = apply_FusedSDPA(
self,
query,
key,
value,
attention_mask,
flash_attention_recompute,
flash_attention_fast_softmax,
flash_attention_causal_mask,
)
else:
attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)

if not self.multi_query:
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
@@ -93,11 +248,15 @@ def gaudi_gpt_bigcode_block_forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- add new args token_idx
- add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
"""
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@@ -109,6 +268,10 @@
use_cache=use_cache,
output_attentions=output_attentions,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
@@ -167,13 +330,21 @@ def gaudi_gpt_bigcode_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
"""
Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- add new args token_idx
- add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
- if token_idx and past_key_values are passed, set self_attention_mask based on the static shape of past_key_values
"""

# This flag is used so that tensors are reshaped correctly for the attention kernel
self._use_sdpa = use_flash_attention

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -322,6 +493,10 @@ def gaudi_gpt_bigcode_model_forward(
use_cache=use_cache,
output_attentions=output_attentions,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
)

hidden_states = outputs[0]
@@ -358,10 +533,10 @@ def gaudi_gpt_bigcode_model_forward(

class GaudiGPTBigCodeForCausalLM(GPTBigCodeForCausalLM):
"""
Inherits from GPTBigCodeForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Inherits from GPTBigCodeForCausalLM: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- add new args token_idx
- add token_idx into model_inputs
- add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
- add token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask into model_inputs
- when KV cache is enabled, slice next_input_ids from input_ids based on the token_idx
- when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx
"""
@@ -422,6 +597,10 @@ def prepare_inputs_for_generation(
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"token_idx": token_idx,
"use_flash_attention": kwargs.get("use_flash_attention", False),
"flash_attention_recompute": kwargs.get("flash_attention_recompute", False),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax", False),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False),
}
)
return model_inputs
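
Because the new flags are read here with kwargs.get, they can be supplied directly as generation kwargs. A minimal usage sketch, assuming an already-loaded, Gaudi-adapted GPT-BigCode model and tokenizer (placeholder objects) and an illustrative prompt:

# model / tokenizer: already-loaded, Gaudi-adapted GPT-BigCode checkpoint (placeholders)
inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("hpu")
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    use_flash_attention=True,           # picked up via kwargs.get("use_flash_attention", False) above
    flash_attention_fast_softmax=True,  # optional fast softmax mode for FusedSDPA
)
print(tokenizer.decode(outputs[0]))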
@@ -443,6 +622,10 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -467,6 +650,10 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
)
hidden_states = transformer_outputs[0]

5 changes: 4 additions & 1 deletion tests/test_text_generation_example.py
@@ -25,7 +25,7 @@
("EleutherAI/gpt-neox-20b", 1, False, 50.67672679310354),
("meta-llama/Llama-2-7b-hf", 1, True, 141.25776956002076),
("tiiuae/falcon-40b", 1, True, 25.202450111088346),
("bigcode/starcoder", 1, False, 65.58632640700114),
("bigcode/starcoder", 256, False, 4329.754794647058),
("Salesforce/codegen2-1B", 1, False, 446.4029486883532),
("mosaicml/mpt-30b", 1, False, 36.06464336116623),
("mistralai/Mistral-7B-v0.1", 1, True, 130.2172236767782),
@@ -142,6 +142,9 @@ def _test_text_generation(
if "falcon" in model_name.lower() or "starcoder2" in model_name.lower():
command += ["--use_flash_attention", "--flash_attention_causal_mask"]

if "starcoder" in model_name.lower() and "starcoder2" not in model_name.lower():
command += ["--use_flash_attention"]

if "starcoder2" in model_name.lower():
command += ["--flash_attention_recompute"]
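
The test change enables the fused kernel for starcoder (but not starcoder2, which sets its own flags) and raises the benchmarked batch size to 256. For a comparable manual run, a sketch of the command the test assembles; only --use_flash_attention comes from this diff, while the script path and other flags are assumptions:

command = [
    "python", "examples/text-generation/run_generation.py",  # assumed script location
    "--model_name_or_path", "bigcode/starcoder",
    "--batch_size", "256",
    "--use_flash_attention",
]
# e.g. subprocess.run(command, check=True)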
