gpt_big_code: make flash attention impl quantization friendly
- introduce a GaudiGPTBigCodeAttention class
- wrap the FusedSDPA kernel call in a separate ModuleFusedSDPA class (see the sketch below)
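The second bullet is the quantization-relevant change: the fused SDPA call is exposed behind its own torch.nn.Module boundary, because FP8 quantization tooling generally discovers and instruments nn.Module instances rather than bare function or kernel calls. A minimal sketch of the wrapper idea follows; the names mirror the commit, but the exact signature is an assumption, not the repository's verbatim code:

import torch

try:
    # HPU-only fused scaled-dot-product-attention kernel from Habana's
    # PyTorch bridge; unavailable on non-HPU machines.
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    FusedSDPA = None  # lets this sketch be imported on non-HPU machines


class ModuleFusedSDPA(torch.nn.Module):
    """Sketch: wraps the fused kernel so it appears as a module boundary."""

    def __init__(self, fused_sdpa):
        super().__init__()
        self._hpu_kernel_fsdpa = fused_sdpa

    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale):
        # Delegating through an nn.Module lets per-module quantization
        # hooks observe and convert this call (e.g. to FP8).
        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale)

An attention class can then hold something like self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) and invoke that attribute in place of a direct kernel call.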
mgonchar committed Sep 25, 2024
1 parent cf7bfaa commit 88ad54e
Showing 4 changed files with 231 additions and 206 deletions.
9 changes: 5 additions & 4 deletions optimum/habana/transformers/modeling_utils.py
@@ -55,6 +55,7 @@
     GaudiGPT2Block,
     GaudiGPT2DoubleHeadsModel,
     GaudiGPT2LMHeadModel,
+    GaudiGPTBigCodeAttention,
     GaudiGPTBigCodeForCausalLM,
     GaudiGPTJAttention,
     GaudiGPTJBlock,
@@ -148,7 +149,6 @@
     gaudi_generate_speech,
     gaudi_get_extended_attention_mask,
     gaudi_gpt2_forward,
-    gaudi_gpt_bigcode_attention_forward,
     gaudi_gpt_bigcode_block_forward,
     gaudi_gpt_bigcode_model_forward,
     gaudi_gpt_neox_attention_forward,
@@ -356,12 +356,13 @@ def adapt_transformers_to_gaudi():
     transformers.models.gptj.modeling_gptj.GPTJModel = GaudiGPTJModel
 
     # Optimization for GPTBigCode on Gaudi
-    transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention.forward = (
-        gaudi_gpt_bigcode_attention_forward
-    )
+    transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention = GaudiGPTBigCodeAttention
     transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM = GaudiGPTBigCodeForCausalLM
     transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeBlock.forward = gaudi_gpt_bigcode_block_forward
     transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeModel.forward = gaudi_gpt_bigcode_model_forward
+    transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBIGCODE_ATTENTION_CLASSES.update(
+        {"eager": GaudiGPTBigCodeAttention}
+    )
 
     # Optimization for gpt-neox generation on Gaudi
     transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM = GaudiGPTNeoXForCausalLM
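Note that the patch now swaps the whole attention class and registers it as the "eager" implementation, instead of monkey-patching a single forward method; any GPTBigCode model built after the adapter runs therefore picks up GaudiGPTBigCodeAttention. A hypothetical usage sketch (the checkpoint name is only an example):

from transformers import AutoModelForCausalLM

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Install the Gaudi-specific classes/forwards shown above, then load the model.
adapt_transformers_to_gaudi()
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoderbase-1b")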
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/__init__.py
@@ -79,8 +79,8 @@
     gaudi_gpt2_forward,
 )
 from .gpt_bigcode import (
+    GaudiGPTBigCodeAttention,
     GaudiGPTBigCodeForCausalLM,
-    gaudi_gpt_bigcode_attention_forward,
     gaudi_gpt_bigcode_block_forward,
     gaudi_gpt_bigcode_model_forward,
 )
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/gpt_bigcode/__init__.py
@@ -1,6 +1,6 @@
 from .modeling_gpt_bigcode import (
+    GaudiGPTBigCodeAttention,
     GaudiGPTBigCodeForCausalLM,
-    gaudi_gpt_bigcode_attention_forward,
     gaudi_gpt_bigcode_block_forward,
     gaudi_gpt_bigcode_model_forward,
 )
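The import changes in these __init__ files are mechanical; the underlying benefit is that a kernel wrapped in a module can be found by type. A purely illustrative helper follows (the function name and wrapper_cls are hypothetical, reusing the ModuleFusedSDPA class sketched earlier; real tools such as Intel Neural Compressor perform this matching internally):

import torch.nn as nn

def replace_sdpa_modules(model: nn.Module, wrapper_cls) -> None:
    # Recursively swap every ModuleFusedSDPA child for a quantization
    # wrapper. Illustrative only: it shows why a module boundary, unlike
    # a patched forward function, is visible to per-module passes.
    for name, child in model.named_children():
        if isinstance(child, ModuleFusedSDPA):  # class sketched above
            setattr(model, name, wrapper_cls(child))
        else:
            replace_sdpa_modules(child, wrapper_cls)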
424 changes: 224 additions & 200 deletions optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py (large diff not rendered here)
