diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 4a40aab015..58900c48ee 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -21,6 +21,7 @@
     GaudiBloomMLP,
     GaudiGPT2Attention,
     GaudiGPT2LMHeadModel,
+    GaudiGPTJAttention,
     GaudiGPTJForCausalLM,
     GaudiGPTNeoXForCausalLM,
     GaudiLlamaForCausalLM,
@@ -51,7 +52,6 @@
     gaudi_gpt_neox_attention_forward,
     gaudi_gpt_neox_layer_forward,
     gaudi_gpt_neox_model_forward,
-    gaudi_gptj_attention_forward,
     gaudi_gptj_block_forward,
     gaudi_gptj_model_forward,
     gaudi_invert_attention_mask,
@@ -152,7 +152,9 @@ def adapt_transformers_to_gaudi():
     transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding = GaudiOPTLearnedPositionalEmbedding
 
     # Optimization for GPTJ on Gaudi
-    transformers.models.gptj.modeling_gptj.GPTJAttention.forward = gaudi_gptj_attention_forward
+    # From Transformers 4.27, the bias in the GPTJAttention layer is a Boolean
+    # Since HCCL cannot handle this dtype, we revert it back to uint8 (same behaviour as Transformers <= 4.26)
+    transformers.models.gptj.modeling_gptj.GPTJAttention = GaudiGPTJAttention
     transformers.models.gptj.modeling_gptj.GPTJForCausalLM = GaudiGPTJForCausalLM
     transformers.models.gptj.modeling_gptj.GPTJBlock.forward = gaudi_gptj_block_forward
     transformers.models.gptj.modeling_gptj.GPTJModel.forward = gaudi_gptj_model_forward
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index e11e091e67..b8b4876254 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -26,8 +26,8 @@
     gaudi_gpt_neox_model_forward,
 )
 from .gptj import (
+    GaudiGPTJAttention,
     GaudiGPTJForCausalLM,
-    gaudi_gptj_attention_forward,
     gaudi_gptj_block_forward,
     gaudi_gptj_model_forward,
 )
diff --git a/optimum/habana/transformers/models/gptj/__init__.py b/optimum/habana/transformers/models/gptj/__init__.py
index d0de98c47b..9b3b6a6434 100644
--- a/optimum/habana/transformers/models/gptj/__init__.py
+++ b/optimum/habana/transformers/models/gptj/__init__.py
@@ -1,6 +1,6 @@
 from .modeling_gptj import (
+    GaudiGPTJAttention,
     GaudiGPTJForCausalLM,
-    gaudi_gptj_attention_forward,
    gaudi_gptj_block_forward,
     gaudi_gptj_model_forward,
 )
diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py
index 5b14e45b52..112e169808 100644
--- a/optimum/habana/transformers/models/gptj/modeling_gptj.py
+++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -1,95 +1,212 @@
 from typing import Optional, Tuple, Union
 
 import torch
+from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, apply_rotary_pos_emb, logger
+from transformers.models.gptj.modeling_gptj import (
+    GPTJForCausalLM,
+    apply_rotary_pos_emb,
+    create_sinusoidal_positions,
+    logger,
+)
+
+
+class GaudiGPTJAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        max_positions = config.max_position_embeddings
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e9))
+
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+        self.embed_dim = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_attention_heads
+        if self.head_dim * self.num_attention_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+                f" `num_attention_heads`: {self.num_attention_heads})."
+            )
+        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.rotary_dim = config.rotary_dim
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
 
 
-def gaudi_gptj_attention_forward(
-    self,
-    hidden_states: torch.FloatTensor,
-    layer_past: Optional[Tuple[torch.Tensor]] = None,
-    attention_mask: Optional[torch.FloatTensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    head_mask: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = False,
-    output_attentions: Optional[bool] = False,
-    token_idx: Optional[torch.Tensor] = None,
-) -> Union[
-    Tuple[torch.Tensor, Tuple[torch.Tensor]],
-    Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
-]:
-    """
-    Copied from GPTJAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
-    The only differences are:
-    - add new args token_idx
-    - remove is_torch_fx_proxy
-    - optimize KV cache
-    """
-    query = self.q_proj(hidden_states)
-    key = self.k_proj(hidden_states)
-    value = self.v_proj(hidden_states)
+    def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
+        """
+        Splits hidden dim into attn_head_size and num_attention_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        if rotary:
+            return tensor
+        if len(tensor.shape) == 5:
+            return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
+        elif len(tensor.shape) == 4:
+            return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+        else:
+            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
+
+    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden dim
+        """
+        if len(tensor.shape) == 5:
+            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
+        elif len(tensor.shape) == 4:
+            tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        else:
+            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
+        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+        return tensor.view(new_shape)
+
+    def _attn(
+        self,
+        query,
+        key,
+        value,
+        attention_mask=None,
+        head_mask=None,
+    ):
+        # compute causal mask from causal mask buffer
+        query_length, key_length = query.size(-2), key.size(-2)
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
 
-    query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
-    key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
-    value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
+        # Keep the attention weights computation in fp32 to avoid overflow issues
+        query = query.to(torch.float32)
+        key = key.to(torch.float32)
 
-    embed_positions = self._get_embed_positions(position_ids)
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
-    repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
-    sincos = torch.gather(embed_positions, 1, repeated_position_ids)
-    sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+        mask_value = torch.finfo(attn_weights.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
 
-    if self.rotary_dim is not None:
-        k_rot = key[:, :, :, : self.rotary_dim]
-        k_pass = key[:, :, :, self.rotary_dim :]
+        attn_weights = attn_weights / self.scale_attn
 
-        q_rot = query[:, :, :, : self.rotary_dim]
-        q_pass = query[:, :, :, self.rotary_dim :]
+        if attention_mask is not None:
+            # Apply the attention mask
+            attn_weights = attn_weights + attention_mask
 
-        k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
-        q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = attn_weights.to(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
 
-        key = torch.cat([k_rot, k_pass], dim=-1)
-        query = torch.cat([q_rot, q_pass], dim=-1)
-    else:
-        key = apply_rotary_pos_emb(key, sin, cos)
-        query = apply_rotary_pos_emb(query, sin, cos)
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
 
-    key = key.permute(0, 2, 1, 3)
-    query = query.permute(0, 2, 1, 3)
+        attn_output = torch.matmul(attn_weights, value)
 
-    if layer_past is not None:
-        past_key = layer_past[0]
-        past_value = layer_past[1]
+        return attn_output, attn_weights
 
-        if token_idx is not None:
-            past_key.index_copy_(2, token_idx - 1, key)
-            past_value.index_copy_(2, token_idx - 1, value)
-            key = past_key
-            value = past_value
+    def _get_embed_positions(self, position_ids):
+        embed_positions = self.embed_positions
+        if embed_positions.device != position_ids.device:
+            embed_positions = embed_positions.to(position_ids.device)
+            self.embed_positions = embed_positions
+        return embed_positions.repeat(position_ids.shape[0], 1, 1)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        token_idx: Optional[torch.Tensor] = None,
+    ) -> Union[
+        Tuple[torch.Tensor, Tuple[torch.Tensor]],
+        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
+    ]:
+        """
+        Copied from GPTJAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
+        The only differences are:
+        - add new args token_idx
+        - remove is_torch_fx_proxy
+        - optimize KV cache
+        """
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
+        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
+        value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
+
+        embed_positions = self._get_embed_positions(position_ids)
+
+        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
+        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+
+        if self.rotary_dim is not None:
+            k_rot = key[:, :, :, : self.rotary_dim]
+            k_pass = key[:, :, :, self.rotary_dim :]
+
+            q_rot = query[:, :, :, : self.rotary_dim]
+            q_pass = query[:, :, :, self.rotary_dim :]
+
+            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+
+            key = torch.cat([k_rot, k_pass], dim=-1)
+            query = torch.cat([q_rot, q_pass], dim=-1)
         else:
-            key = torch.cat([past_key, key], dim=-2)
-            value = torch.cat([past_value, value], dim=-2)
+            key = apply_rotary_pos_emb(key, sin, cos)
+            query = apply_rotary_pos_emb(query, sin, cos)
 
-    if use_cache is True:
-        present = (key, value)
-    else:
-        present = None
+        key = key.permute(0, 2, 1, 3)
+        query = query.permute(0, 2, 1, 3)
+
+        if layer_past is not None:
+            past_key = layer_past[0]
+            past_value = layer_past[1]
+
+            if token_idx is not None:
+                past_key.index_copy_(2, token_idx - 1, key)
+                past_value.index_copy_(2, token_idx - 1, value)
+                key = past_key
+                value = past_value
+            else:
+                key = torch.cat([past_key, key], dim=-2)
+                value = torch.cat([past_value, value], dim=-2)
+
+        if use_cache is True:
+            present = (key, value)
+        else:
+            present = None
 
-    # compute self-attention: V x Softmax(QK^T)
-    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        # compute self-attention: V x Softmax(QK^T)
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 
-    attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
-    attn_output = self.out_proj(attn_output)
-    attn_output = self.resid_dropout(attn_output)
+        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
 
-    outputs = (attn_output, present)
-    if output_attentions:
-        outputs += (attn_weights,)
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
 
-    return outputs  # a, present, (attentions)
+        return outputs  # a, present, (attentions)
 
 
 def gaudi_gptj_block_forward(
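
Note (not part of the patch): the snippet below is a minimal, self-contained sketch of the static-shape KV-cache update that the token_idx path in GaudiGPTJAttention.forward relies on. The tensor names and sizes are illustrative assumptions, not values taken from the model.

    import torch

    # Hypothetical cache layout: (batch, num_heads, max_seq_len, head_dim), preallocated once.
    past_key = torch.zeros(1, 2, 8, 4)
    new_key = torch.randn(1, 2, 1, 4)  # key of the token currently being decoded
    token_idx = torch.tensor([3])      # 1-based position of that token in the sequence

    # In-place write at position token_idx - 1: the cache keeps a fixed shape across decoding
    # steps, whereas torch.cat([past_key, new_key], dim=-2) would grow the tensor every step.
    past_key.index_copy_(2, token_idx - 1, new_key)
    assert torch.equal(past_key[:, :, 2:3, :], new_key)

Keeping the cache shape constant avoids dynamic-shape recompilation during generation on Gaudi, which is why the concatenation path is only used when token_idx is not provided.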