use function.
Signed-off-by: Ye, Xinyu <[email protected]>
XinyuYe-Intel committed Sep 20, 2024
1 parent 70d09ad commit dad1d95
Showing 4 changed files with 45 additions and 58 deletions.
4 changes: 2 additions & 2 deletions optimum/habana/transformers/modeling_utils.py
@@ -57,7 +57,6 @@
     GaudiGPTJForCausalLM,
     GaudiGPTJModel,
     GaudiGPTNeoForCausalLM,
-    GaudiGPTNeoBlock,
     GaudiGPTNeoXForCausalLM,
     GaudiGPTNeoXLayer,
     GaudiLlamaAttention,
@@ -144,6 +143,7 @@
     gaudi_gpt_bigcode_attention_forward,
     gaudi_gpt_bigcode_block_forward,
     gaudi_gpt_bigcode_model_forward,
+    gaudi_gpt_neo_block_forward,
     gaudi_gpt_neo_attention_forward,
     gaudi_gpt_neo_model_forward,
     gaudi_gpt_neox_attention_forward,
@@ -364,7 +364,7 @@ def adapt_transformers_to_gaudi():
     # Optimization for gpt-neo generation on Gaudi
     transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM = GaudiGPTNeoForCausalLM
     transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoModel.forward = gaudi_gpt_neo_model_forward
-    transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock = GaudiGPTNeoBlock
+    transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock.forward = gaudi_gpt_neo_block_forward
     transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoAttention.forward = gaudi_gpt_neo_attention_forward
     transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention.forward = gaudi_gpt_neo_selfattention_forward

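Rebinding forward on the stock GPTNeoBlock class (rather than registering a GaudiGPTNeoBlock subclass) is enough here because Python resolves methods through the class at call time, so every existing and future block instance picks up the patched function. A minimal, self-contained sketch of that pattern; the Block class and patched_forward below are illustrative stand-ins, not optimum-habana code:

class Block:
    def forward(self, hidden_states):
        return hidden_states

def patched_forward(self, hidden_states):
    # A plain function assigned to a class attribute becomes a bound method,
    # so self here is the original, unmodified Block instance.
    return hidden_states + 1

block = Block()                # instantiated before the patch
Block.forward = patched_forward
assert block.forward(1) == 2   # existing instances pick up the new forward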
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/__init__.py
@@ -85,7 +85,7 @@
 )
 from .gpt_neo import (
     GaudiGPTNeoForCausalLM,
-    GaudiGPTNeoBlock,
+    gaudi_gpt_neo_block_forward,
     gaudi_gpt_neo_attention_forward,
     gaudi_gpt_neo_selfattention_forward,
     gaudi_gpt_neo_model_forward,
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/gpt_neo/__init__.py
@@ -1,6 +1,6 @@
 from .modeling_gpt_neo import (
     GaudiGPTNeoForCausalLM,
-    GaudiGPTNeoBlock,
+    gaudi_gpt_neo_block_forward,
     gaudi_gpt_neo_attention_forward,
     gaudi_gpt_neo_selfattention_forward,
     gaudi_gpt_neo_model_forward,
95 changes: 41 additions & 54 deletions optimum/habana/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -10,9 +10,6 @@
 )
 from transformers.models.gpt_neo.modeling_gpt_neo import (
     GPTNeoForCausalLM,
-    GPTNeoBlock,
-    GPTNeoAttention,
-    GPTNeoMLP,
     logger
 )

@@ -99,59 +96,49 @@ def gaudi_gpt_neo_selfattention_forward(
     return outputs  # a, present, (attentions)


-class GaudiGPTNeoBlock(GPTNeoBlock):
-    def __init__(self, config, layer_id):
-        super(GPTNeoBlock, self).__init__()
-        hidden_size = config.hidden_size
-        inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
-        self.ln_1 = torch.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTNeoAttention(config, layer_id)
-        self.ln_2 = torch.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTNeoMLP(inner_dim, config)
-
-    def forward(
-        self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
-        token_idx=None,
-    ):
-        """
-        Copied from GPTNeoBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neo/modeling_gpt_neo.py
-        The only differences are:
-        - add new args token_idx
-        """
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
-        attn_outputs = self.attn(
-            hidden_states,
-            layer_past=layer_past,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            token_idx=token_idx,
-        )
-        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]
-        # residual connection
-        hidden_states = attn_output + residual
-
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        # residual connection
-        hidden_states = residual + feed_forward_hidden_states
-
-        if use_cache:
-            outputs = (hidden_states,) + outputs
-        else:
-            outputs = (hidden_states,) + outputs[1:]
-
-        return outputs  # hidden_states, present, (attentions, cross_attentions)
+def gaudi_gpt_neo_block_forward(
+    self,
+    hidden_states,
+    layer_past=None,
+    attention_mask=None,
+    head_mask=None,
+    use_cache=False,
+    output_attentions=False,
+    token_idx=None,
+):
+    """
+    Copied from GPTNeoBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+    The only differences are:
+    - add new args token_idx
+    """
+    residual = hidden_states
+    hidden_states = self.ln_1(hidden_states)
+    attn_outputs = self.attn(
+        hidden_states,
+        layer_past=layer_past,
+        attention_mask=attention_mask,
+        head_mask=head_mask,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        token_idx=token_idx,
+    )
+    attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+    outputs = attn_outputs[1:]
+    # residual connection
+    hidden_states = attn_output + residual
+
+    residual = hidden_states
+    hidden_states = self.ln_2(hidden_states)
+    feed_forward_hidden_states = self.mlp(hidden_states)
+    # residual connection
+    hidden_states = residual + feed_forward_hidden_states
+
+    if use_cache:
+        outputs = (hidden_states,) + outputs
+    else:
+        outputs = (hidden_states,) + outputs[1:]
+
+    return outputs  # hidden_states, present, (attentions, cross_attentions)


 def gaudi_gpt_neo_model_forward(
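With the patches above applied by adapt_transformers_to_gaudi(), a stock GPT-Neo checkpoint runs through gaudi_gpt_neo_block_forward transparently. A rough usage sketch, assuming a Gaudi/HPU environment with optimum-habana and the Habana PyTorch bridge installed; the checkpoint name and prompt are illustrative:

import habana_frameworks.torch.core  # registers the hpu device (environment-dependent)
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Monkey-patch the transformers GPT-Neo classes in place; after this call
# GPTNeoBlock.forward points at gaudi_gpt_neo_block_forward.
adapt_transformers_to_gaudi()

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B").to("hpu")

inputs = tokenizer("Habana Gaudi is", return_tensors="pt").to("hpu")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))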
