gpt_big_code: make flash attention impl quantization friendly
- introduce a GaudiGPTBigCodeAttention class
- wrap the FusedSDPA kernel in a separate ModuleFusedSDPA class (the wrapper pattern is sketched below)
mgonchar committed Aug 21, 2024
1 parent afa3d22 commit 31b1591
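
The second commit-message item is what makes the flash attention implementation quantization friendly: PyTorch quantization tooling hooks and swaps torch.nn.Module instances, so moving the fused SDPA kernel call behind a small module gives it a stable, patchable call site instead of a bare function call inside the attention forward. The class name ModuleFusedSDPA comes from the commit message; everything else below is a minimal sketch of that wrapper pattern (assuming Habana's FusedSDPA kernel interface), not the code added by this commit:

    import torch

    # Assumption: Habana's fused scaled-dot-product-attention kernel. On a Gaudi/HPU
    # setup this import is expected to succeed; elsewhere we fall back to None.
    try:
        from habana_frameworks.torch.hpex.kernels import FusedSDPA
    except ImportError:
        FusedSDPA = None

    class ModuleFusedSDPA(torch.nn.Module):
        """Thin nn.Module wrapper around the fused SDPA kernel.

        Because it is a module, quantization frameworks can observe or replace this
        call site just like any other layer.
        """

        def __init__(self, fused_sdpa_kernel):
            super().__init__()
            self._hpu_kernel_fsdpa = fused_sdpa_kernel

        def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale):
            # Delegate to the fused kernel; the argument order mirrors SDPA-style calls.
            return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale)
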
Showing 4 changed files with 212 additions and 196 deletions.
optimum/habana/transformers/modeling_utils.py (7 changes: 3 additions & 4 deletions)
@@ -49,6 +49,7 @@
     GaudiGPT2Block,
     GaudiGPT2DoubleHeadsModel,
     GaudiGPT2LMHeadModel,
+    GaudiGPTBigCodeAttention,
     GaudiGPTBigCodeForCausalLM,
     GaudiGPTJAttention,
     GaudiGPTJBlock,
@@ -137,7 +138,6 @@
     gaudi_generate_speech,
     gaudi_get_extended_attention_mask,
     gaudi_gpt2_forward,
-    gaudi_gpt_bigcode_attention_forward,
     gaudi_gpt_bigcode_block_forward,
     gaudi_gpt_bigcode_model_forward,
     gaudi_gpt_neox_attention_forward,
@@ -344,12 +344,11 @@ def adapt_transformers_to_gaudi():
     transformers.models.gptj.modeling_gptj.GPTJModel = GaudiGPTJModel
 
     # Optimization for GPTBigCode on Gaudi
-    transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention.forward = (
-        gaudi_gpt_bigcode_attention_forward
-    )
+    transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention = GaudiGPTBigCodeAttention
     transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM = GaudiGPTBigCodeForCausalLM
     transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeBlock.forward = gaudi_gpt_bigcode_block_forward
     transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeModel.forward = gaudi_gpt_bigcode_model_forward
+    transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBIGCODE_ATTENTION_CLASSES.update({"eager": GaudiGPTBigCodeAttention})
 
     # Optimization for gpt-neox generation on Gaudi
     transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM = GaudiGPTNeoXForCausalLM
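
For context on how this substitution is consumed: adapt_transformers_to_gaudi() patches the transformers classes in place, so it has to run before the model is instantiated. A short usage sketch (the checkpoint name is only an illustrative example, not something this commit prescribes):

    from transformers import AutoModelForCausalLM
    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

    # Patch transformers in place; GPTBigCodeAttention now resolves to GaudiGPTBigCodeAttention.
    adapt_transformers_to_gaudi()

    # Any GPT-BigCode checkpoint loaded afterwards picks up the Gaudi attention implementation.
    model = AutoModelForCausalLM.from_pretrained("bigcode/starcoderbase-1b")
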
optimum/habana/transformers/models/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -78,8 +78,8 @@
     gaudi_gpt2_forward,
 )
 from .gpt_bigcode import (
+    GaudiGPTBigCodeAttention,
     GaudiGPTBigCodeForCausalLM,
-    gaudi_gpt_bigcode_attention_forward,
     gaudi_gpt_bigcode_block_forward,
     gaudi_gpt_bigcode_model_forward,
 )
optimum/habana/transformers/models/gpt_bigcode/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 from .modeling_gpt_bigcode import (
+    GaudiGPTBigCodeAttention,
     GaudiGPTBigCodeForCausalLM,
-    gaudi_gpt_bigcode_attention_forward,
     gaudi_gpt_bigcode_block_forward,
     gaudi_gpt_bigcode_model_forward,
 )
optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py (diff not rendered on this page; by the counts above it carries the remaining 207 additions and 190 deletions, including the new GaudiGPTBigCodeAttention class and, per the commit message, the ModuleFusedSDPA wrapper)
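
Since that diff is not rendered here, the following is only a schematic of how an attention class typically routes its flash-attention path through the module wrapper rather than calling the kernel directly. Apart from the two class names taken from the commit message, all names and signatures below are hypothetical, and the snippet builds on the ModuleFusedSDPA sketch above:

    import torch
    import torch.nn.functional as F

    # Schematic only - not the actual contents of modeling_gpt_bigcode.py.
    # Reuses ModuleFusedSDPA and FusedSDPA from the sketch earlier on this page.
    class GaudiGPTBigCodeAttention(torch.nn.Module):
        """Attention block whose fused SDPA call sits behind a quantization-visible module."""

        def __init__(self):
            super().__init__()
            # Holding the kernel behind an nn.Module (instead of calling it inline)
            # is what makes the flash-attention path visible to quantization passes.
            self.fused_scaled_dot_product_attention = (
                ModuleFusedSDPA(FusedSDPA) if FusedSDPA is not None else None
            )

        def sdpa(self, query, key, value, attention_mask, use_flash_attention: bool):
            if use_flash_attention and self.fused_scaled_dot_product_attention is not None:
                # HPU flash-attention path: go through the wrapper module.
                return self.fused_scaled_dot_product_attention(
                    query, key, value, attention_mask, 0.0, False, None
                )
            # Fallback: stock PyTorch scaled_dot_product_attention.
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
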
