Enable Flash Attention (Fused SDPA) for Starcoder (#1114)
Co-authored-by: regisss <[email protected]>
abhilash1910 and regisss authored Jul 29, 2024
1 parent ac79d23 commit a4851cd
Showing 2 changed files with 56 additions and 6 deletions.
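In short, the Starcoder2 attention forward gains `use_flash_attention` and `flash_attention_recompute` arguments and, when the Habana `FusedSDPA` kernel is available, routes scaled dot-product attention through it instead of the eager softmax path. As a hedged sketch of how the new switches reach the text-generation example, mirroring the test change further below: the `run_generation.py` entry point and the baseline arguments are assumptions, only the three flash-attention flags come from this commit and the existing example CLI.

import subprocess

# Sketch only: assemble a command the way tests/test_text_generation_example.py does.
command = [
    "python3",
    "run_generation.py",              # assumed text-generation example entry point
    "--model_name_or_path", "bigcode/starcoder2-3b",
    "--use_kv_cache",                 # assumed baseline flag
    "--max_new_tokens", "100",        # assumed baseline flag
    "--bf16",                         # assumed baseline flag
    "--use_flash_attention",          # route attention through FusedSDPA (this commit)
    "--flash_attention_causal_mask",  # enabled for starcoder2 by the updated test below
    "--flash_attention_recompute",    # select the fused kernel's recompute mode
]
subprocess.run(command, check=True)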
@@ -28,6 +28,11 @@
    print("Not using HPU fused kernel for apply_rotary_pos_emb")
    FusedRoPE = None

try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    print("Not using HPU fused sdpa kernel ")
    FusedSDPA = None

logger = logging.get_logger(__name__)
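The import guard above mirrors the existing FusedRoPE fallback: if the Habana fused SDPA kernel cannot be imported, the attention code below keeps the eager softmax path. For readers without HPU hardware, a minimal reference sketch (illustrative only, not part of this commit) of the computation the fused kernel replaces, written with stock PyTorch SDPA and the same (query, key, value, attn_mask, dropout_p, is_causal) argument layout used by the FusedSDPA.apply calls further down:

import torch
import torch.nn.functional as F

def sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # softmax(QK^T / sqrt(head_dim) + mask) @ V, i.e. the math of the eager `else`
    # branch below, modulo its explicit fp32 upcast of the attention weights.
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )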

@@ -41,13 +46,17 @@ def gaudi_starcoder2_attention_forward(
    output_attentions: bool = False,
    use_cache: bool = False,
    token_idx: Optional[torch.Tensor] = None,
    use_flash_attention: Optional[bool] = False,
    flash_attention_recompute: Optional[bool] = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Copied from Starcoder2Attention.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py
    The only differences are:
    - add new arg token_idx
    - optimize KV cache
    - add new arg use_flash_attention
    - add new arg flash_attention_recompute
    """
    if "padding_mask" in kwargs:
        warnings.warn(
@@ -114,10 +123,32 @@ def gaudi_starcoder2_attention_forward(

        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)
    query_length = q_len if past_key_value is None else q_len + past_key_value.key_cache[self.layer_idx].shape[2]
    # Taken from mpt: https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/models/mpt/modeling_mpt.py
    if use_flash_attention and FusedSDPA:
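        # Both branches run under ht.sdp_kernel(enable_recompute=flash_attention_recompute):
        # single-token decode (query_length == 1) applies FusedSDPA to just the new query token,
        # prefill applies it to the full prompt, and prompts longer than 16384 query tokens take
        # the blocked gaudi_flash_attn_v1 path followed by an explicit ht.mark_step().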
        import habana_frameworks.torch.hpu as ht

        if query_length == 1:
            # next token
            with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None)
        else:
            # first token
            with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                if query_length > 16384:
                    attn_output = self.gaudi_flash_attn_v1(
                        query_states, key_states, value_states, attention_mask, 0.0, self.block_size
                    )
                    ht.mark_step()
                else:
                    attn_output = FusedSDPA.apply(
                        query_states, key_states, value_states, attention_mask, 0.0, False, None
                    )
        attn_weights = None
    else:
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
@@ -151,12 +182,16 @@ def forward(
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Copied from Starcoder2DecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py
        The only differences are:
        - add new arg token_idx
        - add new arg use_flash_attention
        - add new arg flash_attention_recompute
        """

        residual = hidden_states
@@ -172,6 +207,8 @@
            output_attentions=output_attentions,
            use_cache=use_cache,
            token_idx=token_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
        )
        hidden_states = residual + hidden_states

@@ -204,6 +241,8 @@ def gaudi_starcoder2_model_forward(
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    token_idx: Optional[torch.Tensor] = None,
    use_flash_attention: Optional[bool] = False,
    flash_attention_recompute: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """
    Copied from Starcoder2Model.forward: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -297,6 +336,8 @@ def gaudi_starcoder2_model_forward(
                output_attentions,
                use_cache,
                None,
                use_flash_attention,
                flash_attention_recompute,
            )
        else:
            layer_outputs = decoder_layer(
@@ -307,6 +348,8 @@
                output_attentions=output_attentions,
                use_cache=use_cache,
                token_idx=token_idx,
                use_flash_attention=use_flash_attention,
                flash_attention_recompute=flash_attention_recompute,
            )

        hidden_states = layer_outputs[0]
@@ -351,6 +394,8 @@ def forward(
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Inherits from Starcoder2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -375,6 +420,8 @@
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            token_idx=token_idx,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
        )

        hidden_states = outputs[0]
7 changes: 5 additions & 2 deletions tests/test_text_generation_example.py
@@ -36,7 +36,7 @@
        ("meta-llama/Llama-2-7b-hf", 512, False, 8711),  # in some cases like TGI, reuse_cache isn't used
        ("stabilityai/stablelm-2-12b", 1, False, 74.8904496532218),
        ("codellama/CodeLlama-34b-hf", 1, True, 32.644),
        ("bigcode/starcoder2-3b", 1, False, 234.2649120507936),
        ("bigcode/starcoder2-3b", 1, False, 261.07213776344133),
        ("adept/persimmon-8b-base", 4, False, 366.73968820698406),
        ("Qwen/Qwen1.5-7B", 4, False, 518.894516133132),
        ("google/gemma-7b", 1, False, 109.70751574382221),
@@ -139,9 +139,12 @@ def _test_text_generation(
    if "llama" in model_name.lower():
        command += ["--trim_logits", "--attn_softmax_bf16"]

    if "falcon" in model_name.lower():
    if "falcon" in model_name.lower() or "starcoder2" in model_name.lower():
        command += ["--use_flash_attention", "--flash_attention_causal_mask"]

    if "starcoder2" in model_name.lower():
        command += ["--flash_attention_recompute"]

    if reuse_cache or torch_compile:
        command += ["--reuse_cache"]
