gpt_bigcode: added FusedSDPA kernel
Added support for the following options to the gpt_bigcode (starcoderbase) model:
use_flash_attention,
flash_attention_recompute,
flash_attention_fast_softmax,
flash_attention_causal_mask
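
For reference, a minimal sketch of how these options typically drive Habana's FusedSDPA kernel, following the pattern used by other optimum-habana models; the FusedSDPA.apply argument order, the sdp_kernel recompute context, and the softmax mode strings are assumptions that may differ across SynapseAI releases:

# Hedged sketch, not the exact code from this commit.
import habana_frameworks.torch.hpu as ht
from habana_frameworks.torch.hpex.kernels import FusedSDPA

def fused_sdpa_attention(query, key, value, attention_mask, args):
    # flash_attention_causal_mask lets the kernel derive the causal mask itself,
    # so the explicit mask tensor can be dropped.
    is_causal = args.flash_attention_causal_mask
    mask = None if is_causal else attention_mask
    # flash_attention_fast_softmax selects an approximate softmax variant (assumed names).
    softmax_mode = "fast" if args.flash_attention_fast_softmax else "None"
    # flash_attention_recompute trades memory for recomputation inside the kernel.
    with ht.sdp_kernel(enable_recompute=args.flash_attention_recompute):
        # apply(q, k, v, mask, dropout_p, is_causal, scale, softmax_mode)
        return FusedSDPA.apply(query, key, value, mask, 0.0, is_causal, None, softmax_mode)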
mgonchar committed Jul 17, 2024
1 parent c495f47 commit 80da185
Showing 5 changed files with 375 additions and 193 deletions.
6 changes: 6 additions & 0 deletions examples/text-generation/utils.py
@@ -515,6 +515,12 @@ def initialize_model(args, logger):
"token": args.token,
"trust_remote_code": args.trust_remote_code,
}

+# For the starcoderbase model it is essential to set the proper flags for the SDPA kernel at the model level
+# to avoid transposing tensors on every attention call
+if 'starcoderbase' in args.model_name_or_path:
+    model_kwargs["use_flash_attention"] = args.use_flash_attention

if args.trust_remote_code:
logger.warning("`trust_remote_code` is set, there is no guarantee this model works properly and it may fail")

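A hedged sketch of where this model-level kwarg ends up: later in initialize_model(), model_kwargs is forwarded to from_pretrained (assumed from the surrounding code, which this hunk does not show), so the flag reaches the Gaudi model class at construction time:

# Illustrative only; the exact call site is outside this hunk.
model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    **model_kwargs,  # includes use_flash_attention for starcoderbase checkpoints
)
# The model can then arrange tensor layouts once at build time instead of
# transposing inside every attention forward.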
4 changes: 2 additions & 2 deletions optimum/habana/transformers/modeling_utils.py
@@ -46,6 +46,7 @@
GaudiGPT2Attention,
GaudiGPT2Block,
GaudiGPT2LMHeadModel,
+GaudiGPTBigCodeModel,
GaudiGPTBigCodeForCausalLM,
GaudiGPTJAttention,
GaudiGPTJBlock,
@@ -131,7 +132,6 @@
gaudi_gpt2_forward,
gaudi_gpt_bigcode_attention_forward,
gaudi_gpt_bigcode_block_forward,
-gaudi_gpt_bigcode_model_forward,
gaudi_gpt_neox_attention_forward,
gaudi_gpt_neox_layer_forward,
gaudi_gpt_neox_model_forward,
@@ -348,7 +348,7 @@ def adapt_transformers_to_gaudi():
)
transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM = GaudiGPTBigCodeForCausalLM
transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeBlock.forward = gaudi_gpt_bigcode_block_forward
-transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeModel.forward = gaudi_gpt_bigcode_model_forward
+transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeModel = GaudiGPTBigCodeModel

# Optimization for gpt-neox generation on Gaudi
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM = GaudiGPTNeoXForCausalLM
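
The gpt_bigcode lines above swap a forward-method patch for a full class replacement; a toy sketch of the difference (stand-in names, not the real classes): replacing the class lets __init__ accept construction-time kwargs such as use_flash_attention, which patching forward alone cannot.

# Toy illustration of the two patching styles; all names here are hypothetical.
class Model:  # stands in for transformers' GPTBigCodeModel
    def __init__(self, config):
        self.config = config

class GaudiModel(Model):  # stands in for GaudiGPTBigCodeModel
    def __init__(self, config, use_flash_attention=False):
        super().__init__(config)
        # Stored once here, so forward() never re-derives kernel settings
        # or re-transposes tensors on every call.
        self.use_flash_attention = use_flash_attention

# Method patch (before this commit): Model.forward = new_forward
# Class replacement (this commit):   Model = GaudiModel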
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/__init__.py
@@ -68,9 +68,9 @@
from .gpt2 import GaudiGPT2Attention, GaudiGPT2Block, GaudiGPT2LMHeadModel, gaudi_gpt2_forward
from .gpt_bigcode import (
GaudiGPTBigCodeForCausalLM,
+GaudiGPTBigCodeModel,
gaudi_gpt_bigcode_attention_forward,
gaudi_gpt_bigcode_block_forward,
-gaudi_gpt_bigcode_model_forward,
)
from .gpt_neox import (
GaudiGPTNeoXForCausalLM,
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/gpt_bigcode/__init__.py
@@ -1,6 +1,6 @@
from .modeling_gpt_bigcode import (
GaudiGPTBigCodeForCausalLM,
+GaudiGPTBigCodeModel,
gaudi_gpt_bigcode_attention_forward,
gaudi_gpt_bigcode_block_forward,
-gaudi_gpt_bigcode_model_forward,
)
(Diff for the fifth changed file, optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py, did not load in this view.)
