fix gptj HCCL issue that occurred in DDP (#318)
Signed-off-by: Wang, Yi A <[email protected]>
sywangyi authored and schoi-habana committed Aug 10, 2023
1 parent 613afe7 commit 5d4765e
Showing 4 changed files with 192 additions and 73 deletions.
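
The fix is applied through adapt_transformers_to_gaudi() (patched in the first file below), which now swaps the whole GPTJAttention class for GaudiGPTJAttention instead of only overriding its forward method. A minimal usage sketch, not part of this commit; the checkpoint name and call order are illustrative assumptions:

    # Hypothetical usage sketch: the patch must run before the model is instantiated,
    # so that transformers.models.gptj.modeling_gptj.GPTJAttention already points to
    # GaudiGPTJAttention when the GPT-J layers are built.
    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
    from transformers import AutoModelForCausalLM

    adapt_transformers_to_gaudi()  # monkey-patches transformers in place
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")  # illustrative checkpoint
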
6 changes: 4 additions & 2 deletions optimum/habana/transformers/modeling_utils.py
@@ -21,6 +21,7 @@
     GaudiBloomMLP,
     GaudiGPT2Attention,
     GaudiGPT2LMHeadModel,
+    GaudiGPTJAttention,
     GaudiGPTJForCausalLM,
     GaudiGPTNeoXForCausalLM,
     GaudiLlamaForCausalLM,
@@ -51,7 +52,6 @@
     gaudi_gpt_neox_attention_forward,
     gaudi_gpt_neox_layer_forward,
     gaudi_gpt_neox_model_forward,
-    gaudi_gptj_attention_forward,
     gaudi_gptj_block_forward,
     gaudi_gptj_model_forward,
     gaudi_invert_attention_mask,
@@ -152,7 +152,9 @@ def adapt_transformers_to_gaudi():
     transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding = GaudiOPTLearnedPositionalEmbedding

     # Optimization for GPTJ on Gaudi
-    transformers.models.gptj.modeling_gptj.GPTJAttention.forward = gaudi_gptj_attention_forward
+    # From Transformers 4.27, the bias in the GPTJAttention layer is a Boolean
+    # Since HCCL cannot handle this dtype, we revert it to uint8 (same behaviour as Transformers <= 4.26)
+    transformers.models.gptj.modeling_gptj.GPTJAttention = GaudiGPTJAttention
     transformers.models.gptj.modeling_gptj.GPTJForCausalLM = GaudiGPTJForCausalLM
     transformers.models.gptj.modeling_gptj.GPTJBlock.forward = gaudi_gptj_block_forward
     transformers.models.gptj.modeling_gptj.GPTJModel.forward = gaudi_gptj_model_forward
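
The comment above captures the root cause: DistributedDataParallel broadcasts module buffers when it wraps the model, and per that comment the HCCL backend cannot handle the bool dtype that the causal-mask bias buffer uses from Transformers 4.27 on. GaudiGPTJAttention therefore registers the buffer as uint8 and casts it to bool only where the mask is applied. A minimal, self-contained sketch of that pattern; the class name and sizes are illustrative, not code from this commit:

    import torch
    from torch import nn


    class CausalMaskBufferSketch(nn.Module):
        # Illustrative module (not from this commit) showing the buffer dtype choice:
        # the mask is stored as uint8 so that a collective backend without bool
        # support can broadcast it, and it is cast to bool only when it is used.
        def __init__(self, max_positions: int = 8):
            super().__init__()
            self.register_buffer(
                "bias",
                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                    1, 1, max_positions, max_positions
                ),
            )

        def causal_mask(self, query_length: int, key_length: int) -> torch.Tensor:
            # Same slicing as GaudiGPTJAttention._attn, with the cast to bool done here
            return self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
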
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/__init__.py
@@ -26,8 +26,8 @@
     gaudi_gpt_neox_model_forward,
 )
 from .gptj import (
+    GaudiGPTJAttention,
     GaudiGPTJForCausalLM,
-    gaudi_gptj_attention_forward,
     gaudi_gptj_block_forward,
     gaudi_gptj_model_forward,
 )
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/gptj/__init__.py
@@ -1,6 +1,6 @@
 from .modeling_gptj import (
+    GaudiGPTJAttention,
     GaudiGPTJForCausalLM,
-    gaudi_gptj_attention_forward,
     gaudi_gptj_block_forward,
     gaudi_gptj_model_forward,
 )
255 changes: 186 additions & 69 deletions optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -1,95 +1,212 @@
 from typing import Optional, Tuple, Union

 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, apply_rotary_pos_emb, logger
+from transformers.models.gptj.modeling_gptj import (
+    GPTJForCausalLM,
+    apply_rotary_pos_emb,
+    create_sinusoidal_positions,
+    logger,
+)


-def gaudi_gptj_attention_forward(
-    self,
-    hidden_states: torch.FloatTensor,
-    layer_past: Optional[Tuple[torch.Tensor]] = None,
-    attention_mask: Optional[torch.FloatTensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    head_mask: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = False,
-    output_attentions: Optional[bool] = False,
-    token_idx: Optional[torch.Tensor] = None,
-) -> Union[
-    Tuple[torch.Tensor, Tuple[torch.Tensor]],
-    Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
-]:
-    """
-    Copied from GPTJAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
-    The only differences are:
-    - add new args token_idx
-    - remove is_torch_fx_proxy
-    - optimize KV cache
-    """
-    query = self.q_proj(hidden_states)
-    key = self.k_proj(hidden_states)
-    value = self.v_proj(hidden_states)
-
-    query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
-    key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
-    value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
-
-    embed_positions = self._get_embed_positions(position_ids)
-
-    repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
-    sincos = torch.gather(embed_positions, 1, repeated_position_ids)
-    sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
-
-    if self.rotary_dim is not None:
-        k_rot = key[:, :, :, : self.rotary_dim]
-        k_pass = key[:, :, :, self.rotary_dim :]
-
-        q_rot = query[:, :, :, : self.rotary_dim]
-        q_pass = query[:, :, :, self.rotary_dim :]
-
-        k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
-        q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
-
-        key = torch.cat([k_rot, k_pass], dim=-1)
-        query = torch.cat([q_rot, q_pass], dim=-1)
-    else:
-        key = apply_rotary_pos_emb(key, sin, cos)
-        query = apply_rotary_pos_emb(query, sin, cos)
-
-    key = key.permute(0, 2, 1, 3)
-    query = query.permute(0, 2, 1, 3)
-
-    if layer_past is not None:
-        past_key = layer_past[0]
-        past_value = layer_past[1]
-
-        if token_idx is not None:
-            past_key.index_copy_(2, token_idx - 1, key)
-            past_value.index_copy_(2, token_idx - 1, value)
-            key = past_key
-            value = past_value
-        else:
-            key = torch.cat([past_key, key], dim=-2)
-            value = torch.cat([past_value, value], dim=-2)
-
-    if use_cache is True:
-        present = (key, value)
-    else:
-        present = None
-
-    # compute self-attention: V x Softmax(QK^T)
-    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
-    attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
-    attn_output = self.out_proj(attn_output)
-    attn_output = self.resid_dropout(attn_output)
-
-    outputs = (attn_output, present)
-    if output_attentions:
-        outputs += (attn_weights,)
-
-    return outputs  # a, present, (attentions)
+class GaudiGPTJAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        max_positions = config.max_position_embeddings
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e9))
+
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+        self.embed_dim = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_attention_heads
+        if self.head_dim * self.num_attention_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+                f" `num_attention_heads`: {self.num_attention_heads})."
+            )
+        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.rotary_dim = config.rotary_dim
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
+
+    def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
+        """
+        Splits hidden dim into attn_head_size and num_attention_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        if rotary:
+            return tensor
+        if len(tensor.shape) == 5:
+            return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
+        elif len(tensor.shape) == 4:
+            return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+        else:
+            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
+
+    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden dim
+        """
+        if len(tensor.shape) == 5:
+            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
+        elif len(tensor.shape) == 4:
+            tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        else:
+            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
+        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+        return tensor.view(new_shape)
+
+    def _attn(
+        self,
+        query,
+        key,
+        value,
+        attention_mask=None,
+        head_mask=None,
+    ):
+        # compute causal mask from causal mask buffer
+        query_length, key_length = query.size(-2), key.size(-2)
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+
+        # Keep the attention weights computation in fp32 to avoid overflow issues
+        query = query.to(torch.float32)
+        key = key.to(torch.float32)
+
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+        mask_value = torch.finfo(attn_weights.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+        attn_weights = attn_weights / self.scale_attn
+
+        if attention_mask is not None:
+            # Apply the attention mask
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = attn_weights.to(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
+    def _get_embed_positions(self, position_ids):
+        embed_positions = self.embed_positions
+        if embed_positions.device != position_ids.device:
+            embed_positions = embed_positions.to(position_ids.device)
+            self.embed_positions = embed_positions
+        return embed_positions.repeat(position_ids.shape[0], 1, 1)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        token_idx: Optional[torch.Tensor] = None,
+    ) -> Union[
+        Tuple[torch.Tensor, Tuple[torch.Tensor]],
+        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
+    ]:
+        """
+        Copied from GPTJAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
+        The only differences are:
+        - add new args token_idx
+        - remove is_torch_fx_proxy
+        - optimize KV cache
+        """
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
+        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
+        value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
+
+        embed_positions = self._get_embed_positions(position_ids)
+
+        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
+        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+
+        if self.rotary_dim is not None:
+            k_rot = key[:, :, :, : self.rotary_dim]
+            k_pass = key[:, :, :, self.rotary_dim :]
+
+            q_rot = query[:, :, :, : self.rotary_dim]
+            q_pass = query[:, :, :, self.rotary_dim :]
+
+            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+
+            key = torch.cat([k_rot, k_pass], dim=-1)
+            query = torch.cat([q_rot, q_pass], dim=-1)
+        else:
+            key = apply_rotary_pos_emb(key, sin, cos)
+            query = apply_rotary_pos_emb(query, sin, cos)
+
+        key = key.permute(0, 2, 1, 3)
+        query = query.permute(0, 2, 1, 3)
+
+        if layer_past is not None:
+            past_key = layer_past[0]
+            past_value = layer_past[1]
+
+            if token_idx is not None:
+                past_key.index_copy_(2, token_idx - 1, key)
+                past_value.index_copy_(2, token_idx - 1, value)
+                key = past_key
+                value = past_value
+            else:
+                key = torch.cat([past_key, key], dim=-2)
+                value = torch.cat([past_value, value], dim=-2)
+
+        if use_cache is True:
+            present = (key, value)
+        else:
+            present = None
+
+        # compute self-attention: V x Softmax(QK^T)
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)


 def gaudi_gptj_block_forward(
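The forward method above keeps the token_idx path: during generation, the current key and value are written in place into a pre-allocated cache at position token_idx - 1 rather than grown with torch.cat at every step, which is the "optimize KV cache" point mentioned in the docstring. A standalone sketch of that update; shapes and values are illustrative:

    import torch

    # Illustrative shapes: batch of 1, 2 heads, cache pre-allocated for 8 positions
    batch, heads, max_seq, head_dim = 1, 2, 8, 4
    past_key = torch.zeros(batch, heads, max_seq, head_dim)
    past_value = torch.zeros(batch, heads, max_seq, head_dim)

    key = torch.randn(batch, heads, 1, head_dim)    # key of the token being generated
    value = torch.randn(batch, heads, 1, head_dim)  # value of the token being generated
    token_idx = torch.tensor([3])                   # 1-based position of the current token

    # In-place write along the sequence dimension (dim=2), as in GaudiGPTJAttention.forward;
    # the cache keeps a static shape instead of growing at every decoding step.
    past_key.index_copy_(2, token_idx - 1, key)
    past_value.index_copy_(2, token_idx - 1, value)
    key, value = past_key, past_value  # attention then runs over the full cache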
