Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Llama2 and gpt_neox performance with Habana fused RoPE and RMSNorm #321

Merged
merged 7 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
gaudi_llama_attention_forward,
gaudi_llama_decoder_layer_forward,
gaudi_llama_model_forward,
gaudi_llama_rmsnorm_forward,
gaudi_opt_attention_forward,
gaudi_opt_decoder_forward,
gaudi_opt_decoder_layer_forward,
Expand Down Expand Up @@ -170,6 +171,7 @@ def adapt_transformers_to_gaudi():
transformers.models.llama.modeling_llama.LlamaModel.forward = gaudi_llama_model_forward
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = gaudi_llama_decoder_layer_forward
transformers.models.llama.modeling_llama.LlamaAttention.forward = gaudi_llama_attention_forward
transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward

# Dropout kernel improvement for Flan-T5
transformers.models.t5.modeling_t5.T5Stack = GaudiT5Stack
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
gaudi_llama_attention_forward,
gaudi_llama_decoder_layer_forward,
gaudi_llama_model_forward,
gaudi_llama_rmsnorm_forward,
)
from .modeling_all_models import gaudi_conv1d_forward, gaudi_get_extended_attention_mask, gaudi_invert_attention_mask
from .opt import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM, apply_rotary_pos_emb, logger


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
print("Not using HPU fused kernel for apply_rotary_pos_emb")
FusedRoPE = None


def gaudi_gpt_neox_attention_forward(
self,
hidden_states: torch.FloatTensor,
Expand Down Expand Up @@ -51,7 +58,7 @@ def gaudi_gpt_neox_attention_forward(
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query, key = apply_customized_rope(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

Expand Down Expand Up @@ -378,3 +385,10 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"token_idx": token_idx,
}


def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
mandy-li marked this conversation as resolved.
Show resolved Hide resolved
return FusedRoPE.apply(q, cos, sin, position_ids), FusedRoPE.apply(k, cos, sin, position_ids)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
gaudi_llama_attention_forward,
gaudi_llama_decoder_layer_forward,
gaudi_llama_model_forward,
gaudi_llama_rmsnorm_forward,
)
43 changes: 42 additions & 1 deletion optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@
from transformers.models.llama.modeling_llama import LlamaForCausalLM, apply_rotary_pos_emb, logger


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
print("Not using HPU fused kernel for apply_rotary_pos_emb")
FusedRoPE = None

try:
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm
except ImportError:
print("Not using HPU fused kernel for RMSNorm")
FusedRMSNorm = None


def gaudi_llama_attention_forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -34,7 +47,7 @@ def gaudi_llama_attention_forward(
else:
kv_seq_len = past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
Expand Down Expand Up @@ -238,6 +251,27 @@ def custom_forward(*inputs):
)


def gaudi_llama_rmsnorm_forward(self, hidden_states):
"""
Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- override RMSNorm with Habana fused RMSNorm
"""
if hidden_states.device.type == "hpu" and FusedRMSNorm:
mandy-li marked this conversation as resolved.
Show resolved Hide resolved
orig_dtype = hidden_states.dtype
hidden_states = FusedRMSNorm.apply(hidden_states.float(), self.weight.float(), self.variance_epsilon)
return hidden_states.to(orig_dtype)
else:
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states


class GaudiLlamaForCausalLM(LlamaForCausalLM):
"""
Inherits from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand Down Expand Up @@ -346,3 +380,10 @@ def prepare_inputs_for_generation(
}
)
return model_inputs


def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
mandy-li marked this conversation as resolved.
Show resolved Hide resolved
return FusedRoPE.apply(q, cos, sin, position_ids), FusedRoPE.apply(k, cos, sin, position_ids)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
Loading