Enable MPT fp8 support
Add Softmax and FusedSDPA
Update GaudiMptAttention forward to r4.44.1 base

Co-authored-by: Thanaji Rao Thakkalapelli <[email protected]>
atakaha and tthakkal committed Aug 21, 2024
1 parent 8abf818 commit b3f729f
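
For context on the fp8 path named in the message: Habana exposes a fused scaled-dot-product-attention kernel (FusedSDPA) that the Gaudi MPT attention can route through. The snippet below is a minimal sketch only, not the code from this commit; it assumes FusedSDPA.apply mirrors the argument order of torch.nn.functional.scaled_dot_product_attention, and the helper name gaudi_sdpa is made up for illustration.

# Minimal sketch, not the code from this commit.
# Assumption: FusedSDPA.apply mirrors torch.nn.functional.scaled_dot_product_attention
# (query, key, value, attn_mask, dropout_p, is_causal, scale).
import torch

try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA  # Gaudi fused SDPA kernel
except ImportError:
    FusedSDPA = None  # not running on Gaudi

def gaudi_sdpa(query, key, value, attention_mask=None):
    # Prefer the fused Gaudi kernel when available; fall back to stock PyTorch.
    if FusedSDPA is not None:
        return FusedSDPA.apply(query, key, value, attention_mask, 0.0, False, None)
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )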
Showing 4 changed files with 187 additions and 166 deletions.
8 changes: 4 additions & 4 deletions optimum/habana/transformers/modeling_utils.py
@@ -74,6 +74,8 @@
     GaudiMixtralDecoderLayer,
     GaudiMixtralForCausalLM,
     GaudiMixtralModel,
+    GaudiMptAttention,
+    GaudiMptBlock,
     GaudiMptForCausalLM,
     GaudiMptModel,
     GaudiOPTForCausalLM,
@@ -150,8 +152,6 @@
     gaudi_mistral_rmsnorm_forward,
     gaudi_mixtral_block_sparse_moe_forward,
     gaudi_mixtral_rmsnorm_forward,
-    gaudi_mpt_attention_forward,
-    gaudi_mpt_block_forward,
     gaudi_opt_attention_forward,
     gaudi_opt_decoder_forward,
     gaudi_opt_decoder_layer_forward,
@@ -415,8 +415,8 @@ def adapt_transformers_to_gaudi():
     # Optimization for mpt on Gaudi
     transformers.models.mpt.modeling_mpt.MptForCausalLM = GaudiMptForCausalLM
     transformers.models.mpt.modeling_mpt.MptModel = GaudiMptModel
-    transformers.models.mpt.modeling_mpt.MptAttention.forward = gaudi_mpt_attention_forward
-    transformers.models.mpt.modeling_mpt.MptBlock.forward = gaudi_mpt_block_forward
+    transformers.models.mpt.modeling_mpt.MptAttention = GaudiMptAttention
+    transformers.models.mpt.modeling_mpt.MptBlock = GaudiMptBlock

     # Optimization for mistral on Gaudi
     transformers.models.mistral.modeling_mistral.MistralForCausalLM = GaudiMistralForCausalLM
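
Note on the hunk above: adapt_transformers_to_gaudi now replaces the whole MptAttention and MptBlock classes instead of only rebinding their forward methods. Replacing the class lets the Gaudi variant add new state in __init__ (for example fused-kernel submodules), which a forward-only patch cannot do. A minimal, self-contained sketch of the same monkey-patching pattern, using placeholder names rather than optimum-habana symbols:

# Sketch of the class-replacement pattern; "Attention", "GaudiAttention" and
# "fake_module" are placeholders, not real transformers/optimum-habana objects.
from types import SimpleNamespace

class Attention:
    def forward(self, x):
        return x  # stock implementation

class GaudiAttention(Attention):
    def __init__(self):
        self.fused_kernel = object()  # new state that only a full class swap can introduce
    def forward(self, x):
        return x  # would dispatch to the fused kernel here

fake_module = SimpleNamespace(Attention=Attention)
# Method-level patching (the old approach) changes behaviour but not construction:
#   fake_module.Attention.forward = some_new_forward
# Class-level replacement (this commit) affects every instance created afterwards:
fake_module.Attention = GaudiAttention
assert isinstance(fake_module.Attention(), GaudiAttention)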
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/__init__.py
@@ -138,10 +138,10 @@
     gaudi_invert_attention_mask,
 )
 from .mpt import (
+    GaudiMptAttention,
+    GaudiMptBlock,
     GaudiMptForCausalLM,
     GaudiMptModel,
-    gaudi_mpt_attention_forward,
-    gaudi_mpt_block_forward,
 )
 from .opt import (
     GaudiOPTForCausalLM,
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/mpt/__init__.py
@@ -1,6 +1,6 @@
 from .modeling_mpt import (
+    GaudiMptAttention,
+    GaudiMptBlock,
     GaudiMptForCausalLM,
     GaudiMptModel,
-    gaudi_mpt_attention_forward,
-    gaudi_mpt_block_forward,
 )
optimum/habana/transformers/models/mpt/modeling_mpt.py (diff not rendered in this view)
