diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py
index 0b157d7fc0..8d8f771c4a 100644
--- a/optimum/habana/transformers/models/gptj/modeling_gptj.py
+++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -43,8 +43,6 @@ def __init__(self, config):
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.rotary_dim = config.rotary_dim
-        pos_embd_dim = self.rotary_dim or self.embed_dim
-        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
 
     def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
         """
@@ -86,9 +84,8 @@ def _attn(
         query_length, key_length = query.size(-2), key.size(-2)
         causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
 
-        # Keep the attention weights computation in fp32 to avoid overflow issues
-        query = query.to(torch.float32).contiguous()
-        key = key.to(torch.float32).contiguous()
+        query = query.contiguous()
+        key = key.contiguous()
         value = value.contiguous()
 
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
@@ -96,7 +93,7 @@ def _attn(
         mask_value = torch.finfo(attn_weights.dtype).min
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
         attn_weights = torch.where(causal_mask, attn_weights, mask_value)
 
         attn_weights = attn_weights / self.scale_attn
@@ -117,13 +114,6 @@ def _attn(
 
         return attn_output, attn_weights
 
-    def _get_embed_positions(self, position_ids):
-        embed_positions = self.embed_positions
-        if embed_positions.device != position_ids.device:
-            embed_positions = embed_positions.to(position_ids.device)
-            self.embed_positions = embed_positions
-        return embed_positions.repeat(position_ids.shape[0], 1, 1)
-
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -134,6 +124,8 @@ def forward(
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
         token_idx: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        cos: Optional[torch.Tensor] = None,
     ) -> Union[
         Tuple[torch.Tensor, Tuple[torch.Tensor]],
         Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
@@ -144,6 +136,7 @@ def forward(
        - add new args token_idx
        - remove is_torch_fx_proxy
        - optimize KV cache
+       - pass sin and cos from upper level as they are identical for each attn block
        """
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
@@ -153,30 +146,24 @@ def forward(
        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True).contiguous()
        value = self._split_heads(value, self.num_attention_heads, self.head_dim, False).contiguous()
 
-        embed_positions = self._get_embed_positions(position_ids)
-
-        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
-        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
-        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
-
        if self.rotary_dim is not None:
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]
 
            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]
-
-            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
-            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+            # Note: bf16 RoPE (whether using the fused kernel or not) appears to cause accuracy issues, hence fp32 RoPE is used here. Fused kernel feasibility needs to be confirmed in the future.
+            k_rot = apply_rotary_pos_emb(k_rot.to(torch.float32), sin, cos).to(torch.bfloat16)
+            q_rot = apply_rotary_pos_emb(q_rot.to(torch.float32), sin, cos).to(torch.bfloat16)
 
            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
-            key = apply_rotary_pos_emb(key, sin, cos)
-            query = apply_rotary_pos_emb(query, sin, cos)
+            key = apply_rotary_pos_emb(key.to(torch.float32), sin, cos).to(torch.bfloat16)
+            query = apply_rotary_pos_emb(query.to(torch.float32), sin, cos).to(torch.bfloat16)
 
-        key = key.permute(0, 2, 1, 3)
-        query = query.permute(0, 2, 1, 3)
+        key = key.permute(0, 2, 1, 3).contiguous()
+        query = query.permute(0, 2, 1, 3).contiguous()
 
        if layer_past is not None:
            past_key = layer_past[0]
@@ -220,11 +207,14 @@ def gaudi_gptj_block_forward(
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
    token_idx: Optional[torch.Tensor] = None,
+    sin: Optional[torch.Tensor] = None,
+    cos: Optional[torch.Tensor] = None,
 ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
    """
    Copied from GPTJBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
    The only differences are:
    - add new args token_idx
+    - pass sin and cos from upper level as they are identical for each attn block
    """
    residual = hidden_states
    hidden_states = self.ln_1(hidden_states)
@@ -237,6 +227,8 @@ def gaudi_gptj_block_forward(
        use_cache=use_cache,
        output_attentions=output_attentions,
        token_idx=token_idx,
+        sin=sin,
+        cos=cos,
    )
    attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
    outputs = attn_outputs[1:]
@@ -266,11 +258,14 @@ def gaudi_gptj_model_forward(
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    token_idx: Optional[torch.Tensor] = None,
+    sin: Optional[torch.Tensor] = None,
+    cos: Optional[torch.Tensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
    """
    Copied from GPTJModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
    The only differences are:
    - add new args token_idx
+    - pass sin and cos from upper level as they are identical for each attn block
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
@@ -359,6 +354,22 @@ def gaudi_gptj_model_forward(
    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None
+
+    # Replace the original `_get_embed_positions` method and the per-block sin/cos calculation here to improve performance.
+    rotary_dim = self.config.rotary_dim
+    embed_dim = self.config.hidden_size
+    pos_embd_dim = rotary_dim or embed_dim
+    max_positions = self.config.max_position_embeddings
+    embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim).to(torch.bfloat16)
+    embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1)
+    if embed_positions.device != position_ids.device:
+        embed_positions = embed_positions.to(position_ids.device)
+    repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+    sincos = torch.gather(embed_positions, 1, repeated_position_ids)
+    sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+    sin = sin.contiguous()
+    cos = cos.contiguous()
+
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
        # Model parallel
        if self.model_parallel:
@@ -401,6 +412,8 @@ def custom_forward(*inputs):
                use_cache=use_cache,
                output_attentions=output_attentions,
                token_idx=token_idx,
+                sin=sin,
+                cos=cos,
            )
 
            hidden_states = outputs[0]
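
For context on the pattern the patch implements, here is a minimal, self-contained sketch (not the PR's code): the rotary sin/cos tensors are computed once per model forward from the position ids and then reused by every attention block, with the rotation done in fp32 and cast back to the working dtype, mirroring the accuracy note above. The helpers make_sin_cos and apply_rope are simplified, hypothetical stand-ins for create_sinusoidal_positions / apply_rotary_pos_emb; they compute the angles directly instead of gathering from a precomputed sinusoidal table.

import torch


def make_sin_cos(position_ids: torch.Tensor, rotary_dim: int, base: float = 10000.0):
    """Build sin/cos for the given positions, each shaped [batch, seq_len, rotary_dim // 2]."""
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    angles = position_ids.unsqueeze(-1).to(torch.float32) * inv_freq  # [batch, seq_len, rotary_dim // 2]
    return torch.sin(angles).contiguous(), torch.cos(angles).contiguous()


def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs of x, shaped [batch, seq_len, heads, rotary_dim] (GPT-J style interleaving)."""
    sin = torch.repeat_interleave(sin[:, :, None, :], 2, dim=3)  # broadcast over heads, expand to rotary_dim
    cos = torch.repeat_interleave(cos[:, :, None, :], 2, dim=3)
    x1, x2 = x[..., ::2], x[..., 1::2]
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return x * cos + rotated * sin


batch, seq_len, heads, rotary_dim, n_layers = 2, 8, 4, 64, 3
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

# Computed once per forward pass at the model level (as gaudi_gptj_model_forward now does) ...
sin, cos = make_sin_cos(position_ids, rotary_dim)

# ... and reused by every block, instead of rebuilding the table and re-gathering per layer.
# The rotation runs in fp32 and is cast back to the working dtype, per the accuracy note in the patch.
q = torch.randn(batch, seq_len, heads, rotary_dim, dtype=torch.bfloat16)
for _ in range(n_layers):
    q_rot = apply_rope(q.to(torch.float32), sin, cos).to(q.dtype)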