Enable HPU RMS fused kernel for T5 (#344)
ZhaiFeiyue authored Aug 17, 2023
1 parent 0c4ab75 commit 8a0a81d
Showing 4 changed files with 35 additions and 0 deletions.
4 changes: 4 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -72,6 +72,7 @@
gaudi_opt_model_forward,
gaudi_rot_matmul,
gaudi_rot_vec_mul,
gaudi_t5_layernorm_forward,
gaudi_vit_self_attention_forward,
gaudi_wav2vec2_forward,
)
@@ -201,3 +202,6 @@ def adapt_transformers_to_gaudi():
transformers.models.falcon.modeling_falcon.FalconDecoderLayer.forward = gaudi_falcon_decoder_layer_forward
transformers.models.falcon.modeling_falcon.FalconAttention.forward = gaudi_falcon_attention_forward
transformers.models.falcon.modeling_falcon.FalconRotaryEmbedding.forward = gaudi_falcon_rotary_embedding_forward

# Optimization for t5 on Gaudi
transformers.models.t5.modeling_t5.T5LayerNorm.forward = gaudi_t5_layernorm_forward
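Not part of the diff: the hunk above is what wires the kernel in, since adapt_transformers_to_gaudi() monkeypatches T5LayerNorm.forward on the upstream transformers class. A minimal usage sketch under that assumption (the model name is illustrative, and actually running on HPU additionally requires the Habana PyTorch bridge):

from transformers import T5ForConditionalGeneration

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # patches T5LayerNorm.forward, among other Gaudi overrides

model = T5ForConditionalGeneration.from_pretrained("t5-small")
# Once the model is moved to an HPU device and run in eval mode, every T5LayerNorm
# dispatches to gaudi_t5_layernorm_forward and, when available, the fused RMSNorm kernel.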
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -66,6 +66,7 @@
gaudi_opt_decoder_layer_forward,
gaudi_opt_model_forward,
)
from .t5 import gaudi_t5_layernorm_forward
from .vit import gaudi_vit_self_attention_forward
from .wav2vec2 import (
_gaudi_wav2vec2_compute_mask_indices,
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/t5/__init__.py
@@ -0,0 +1 @@
from .modeling_t5 import gaudi_t5_layernorm_forward
29 changes: 29 additions & 0 deletions optimum/habana/transformers/models/t5/modeling_t5.py
@@ -0,0 +1,29 @@
import torch


try:
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm
except ImportError:
print("Not using HPU fused kernel for RMSNorm")
FusedRMSNorm = None


def gaudi_t5_layernorm_forward(self, hidden_states):
"""
Copied from T5LayerNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
The only differences are:
- override RMSNorm with Habana fused RMSNorm
"""
if not self.training and hidden_states.device.type == "hpu" and FusedRMSNorm:
orig_dtype = hidden_states.dtype
hidden_states = FusedRMSNorm.apply(hidden_states.float(), self.weight.float(), self.variance_epsilon)
return hidden_states.to(orig_dtype)
else:
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states
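Not part of the diff: the new forward only takes the fused path in eval mode, on an HPU tensor, and when the import above succeeded; otherwise it reproduces the stock eager RMSNorm math. A minimal CPU-only sanity sketch under the assumption that habana_frameworks is absent (so FusedRMSNorm is None and the fallback branch runs):

import torch
from transformers.models.t5.modeling_t5 import T5LayerNorm

from optimum.habana.transformers.models.t5 import gaudi_t5_layernorm_forward

layer = T5LayerNorm(hidden_size=8)  # weight starts as ones, variance_epsilon defaults to 1e-6
x = torch.randn(2, 4, 8)

# The patched function applied to the stock module should match the stock forward here.
assert torch.allclose(layer(x), gaudi_t5_layernorm_forward(layer, x))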
