Improve Llama2 and gpt_neox performance with Habana fused RoPE and RMSNorm (#321)
mandy-li authored and schoi-habana committed Aug 10, 2023
1 parent 5d4765e commit 7b6f024
Showing 5 changed files with 61 additions and 2 deletions.
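The entry point patched below is adapt_transformers_to_gaudi(): once it runs, the stock transformers Llama forwards are swapped for the Gaudi implementations, so the fused kernels are picked up without touching model code. A minimal sketch of how this is triggered, assuming optimum-habana is installed (the import path comes from the diff below; the rest is illustrative):

    import transformers
    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

    # Monkey-patch the stock transformers classes with the Gaudi-optimized forwards,
    # including the gaudi_llama_rmsnorm_forward added in this commit.
    adapt_transformers_to_gaudi()

    # LlamaRMSNorm.forward now points at the Gaudi implementation, which dispatches
    # to Habana's FusedRMSNorm kernel when the tensors live on an HPU device.
    print(transformers.models.llama.modeling_llama.LlamaRMSNorm.forward.__name__)
    # expected: gaudi_llama_rmsnorm_forward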
2 changes: 2 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -58,6 +58,7 @@
gaudi_llama_attention_forward,
gaudi_llama_decoder_layer_forward,
gaudi_llama_model_forward,
gaudi_llama_rmsnorm_forward,
gaudi_opt_attention_forward,
gaudi_opt_decoder_forward,
gaudi_opt_decoder_layer_forward,
@@ -170,6 +171,7 @@ def adapt_transformers_to_gaudi():
transformers.models.llama.modeling_llama.LlamaModel.forward = gaudi_llama_model_forward
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = gaudi_llama_decoder_layer_forward
transformers.models.llama.modeling_llama.LlamaAttention.forward = gaudi_llama_attention_forward
transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward

# Dropout kernel improvement for Flan-T5
transformers.models.t5.modeling_t5.T5Stack = GaudiT5Stack
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -36,6 +36,7 @@
gaudi_llama_attention_forward,
gaudi_llama_decoder_layer_forward,
gaudi_llama_model_forward,
gaudi_llama_rmsnorm_forward,
)
from .modeling_all_models import gaudi_conv1d_forward, gaudi_get_extended_attention_mask, gaudi_invert_attention_mask
from .opt import (
16 changes: 15 additions & 1 deletion optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -6,6 +6,13 @@
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM, apply_rotary_pos_emb, logger


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
print("Not using HPU fused kernel for apply_rotary_pos_emb")
FusedRoPE = None


def gaudi_gpt_neox_attention_forward(
self,
hidden_states: torch.FloatTensor,
@@ -51,7 +58,7 @@ def gaudi_gpt_neox_attention_forward(
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
- query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+ query, key = apply_customized_rope(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

@@ -378,3 +385,10 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"token_idx": token_idx,
}


def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
return FusedRoPE.apply(q, cos, sin, position_ids), FusedRoPE.apply(k, cos, sin, position_ids)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
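apply_customized_rope only takes the fused path when both conditions hold: the RotaryPosEmbeddingHelperV2 import succeeded and the query tensor lives on an HPU device; otherwise it falls through to the stock apply_rotary_pos_emb. A tiny illustrative check of that dispatch condition (uses_fused_rope is a hypothetical helper, not part of the commit):

    import torch

    # On machines without habana_frameworks, the try/except above leaves FusedRoPE as None.
    FusedRoPE = None

    def uses_fused_rope(q: torch.Tensor) -> bool:
        # Mirrors the condition in apply_customized_rope: HPU tensor AND kernel available.
        return q.device.type == "hpu" and FusedRoPE is not None

    q_cpu = torch.randn(1, 8, 4, 16)  # (batch, heads, seq_len, head_dim) on CPU
    print(uses_fused_rope(q_cpu))     # False -> the stock apply_rotary_pos_emb path is used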
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/llama/__init__.py
@@ -3,4 +3,5 @@
gaudi_llama_attention_forward,
gaudi_llama_decoder_layer_forward,
gaudi_llama_model_forward,
gaudi_llama_rmsnorm_forward,
)
43 changes: 42 additions & 1 deletion optimum/habana/transformers/models/llama/modeling_llama.py
@@ -7,6 +7,19 @@
from transformers.models.llama.modeling_llama import LlamaForCausalLM, apply_rotary_pos_emb, logger


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
print("Not using HPU fused kernel for apply_rotary_pos_emb")
FusedRoPE = None

try:
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm
except ImportError:
print("Not using HPU fused kernel for RMSNorm")
FusedRMSNorm = None


def gaudi_llama_attention_forward(
self,
hidden_states: torch.Tensor,
@@ -34,7 +47,7 @@ def gaudi_llama_attention_forward(
else:
kv_seq_len = past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
@@ -238,6 +251,27 @@ def custom_forward(*inputs):
)


def gaudi_llama_rmsnorm_forward(self, hidden_states):
"""
Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only difference is:
- override RMSNorm with Habana fused RMSNorm when running on HPU
"""
if hidden_states.device.type == "hpu" and FusedRMSNorm:
orig_dtype = hidden_states.dtype
hidden_states = FusedRMSNorm.apply(hidden_states.float(), self.weight.float(), self.variance_epsilon)
return hidden_states.to(orig_dtype)
else:
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states


class GaudiLlamaForCausalLM(LlamaForCausalLM):
"""
Inherits from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -346,3 +380,10 @@ def prepare_inputs_for_generation(
}
)
return model_inputs


def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
return FusedRoPE.apply(q, cos, sin, position_ids), FusedRoPE.apply(k, cos, sin, position_ids)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
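For reference, the fallback branch of gaudi_llama_rmsnorm_forward computes a standard RMSNorm with the mean of squares taken in float32; the FusedRMSNorm kernel is expected to produce the same values on HPU. A self-contained CPU sketch of that computation (shapes and epsilon are illustrative, not taken from the commit):

    import torch

    def reference_rmsnorm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Mean of squares in float32 for numerical stability, as in the fallback branch.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        # Cast back if the learned weight is half precision.
        if weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(weight.dtype)
        return weight * hidden_states

    x = torch.randn(2, 4, 16)             # (batch, seq_len, hidden_size)
    w = torch.ones(16)                     # learned scale, initialized to 1
    print(reference_rmsnorm(x, w).shape)   # torch.Size([2, 4, 16])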
