From 949b88e0da598f88b46734305cbe0b352115a727 Mon Sep 17 00:00:00 2001 From: "Jin, Baihui" Date: Tue, 17 Oct 2023 03:24:27 +0000 Subject: [PATCH 1/5] remove dma --- optimum/habana/transformers/models/gptj/modeling_gptj.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py index 0b157d7fc0..da0a336834 100644 --- a/optimum/habana/transformers/models/gptj/modeling_gptj.py +++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py @@ -175,8 +175,8 @@ def forward( key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) - key = key.permute(0, 2, 1, 3) - query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3).contiguous() + query = query.permute(0, 2, 1, 3).contiguous() if layer_past is not None: past_key = layer_past[0] From 99ef1e1d0f4340bd89064341412701b52e01b1e7 Mon Sep 17 00:00:00 2001 From: "Jin, Baihui" Date: Mon, 23 Oct 2023 08:50:54 +0000 Subject: [PATCH 2/5] update --- .../transformers/models/gptj/modeling_gptj.py | 54 +++++++++++-------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py index da0a336834..93bd6d3613 100644 --- a/optimum/habana/transformers/models/gptj/modeling_gptj.py +++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py @@ -43,8 +43,6 @@ def __init__(self, config): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.rotary_dim = config.rotary_dim - pos_embd_dim = self.rotary_dim or self.embed_dim - self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim) def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): """ @@ -87,8 +85,8 @@ def _attn( causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() # Keep the attention weights computation in fp32 to avoid overflow issues - query = query.to(torch.float32).contiguous() - key = key.to(torch.float32).contiguous() + query = query.contiguous() + key = key.contiguous() value = value.contiguous() attn_weights = torch.matmul(query, key.transpose(-1, -2)) @@ -96,7 +94,7 @@ def _attn( mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device="hpu")
         attn_weights = torch.where(causal_mask, attn_weights, mask_value)
 
         attn_weights = attn_weights / self.scale_attn
@@ -117,13 +115,6 @@ def _attn(
 
         return attn_output, attn_weights
 
-    def _get_embed_positions(self, position_ids):
-        embed_positions = self.embed_positions
-        if embed_positions.device != position_ids.device:
-            embed_positions = embed_positions.to(position_ids.device)
-            self.embed_positions = embed_positions
-        return embed_positions.repeat(position_ids.shape[0], 1, 1)
-
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -134,6 +125,8 @@ def forward(
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
         token_idx: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        cos: Optional[torch.Tensor] = None,
     ) -> Union[
         Tuple[torch.Tensor, Tuple[torch.Tensor]],
         Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
@@ -153,27 +146,21 @@ def forward(
         key = self._split_heads(key, self.num_attention_heads, self.head_dim, True).contiguous()
         value = self._split_heads(value, self.num_attention_heads, self.head_dim, False).contiguous()
 
-        embed_positions = self._get_embed_positions(position_ids)
-
-        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
-        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
-        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
-
         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]
             k_pass = key[:, :, :, self.rotary_dim :]
 
             q_rot = query[:, :, :, : self.rotary_dim]
             q_pass = query[:, :, :, self.rotary_dim :]
-
-            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
-            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+            # Note: bf16 RoPE (with or without the fused kernel) appears to cause accuracy issues, hence fp32 RoPE is used here. Fused-kernel feasibility needs to be confirmed in the future.
+            k_rot = apply_rotary_pos_emb(k_rot.to(torch.float32), sin, cos).to(torch.bfloat16)
+            q_rot = apply_rotary_pos_emb(q_rot.to(torch.float32), sin, cos).to(torch.bfloat16)
 
             key = torch.cat([k_rot, k_pass], dim=-1)
             query = torch.cat([q_rot, q_pass], dim=-1)
         else:
-            key = apply_rotary_pos_emb(key, sin, cos)
-            query = apply_rotary_pos_emb(query, sin, cos)
+            key = apply_rotary_pos_emb(key.to(torch.float32), sin, cos).to(torch.bfloat16)
+            query = apply_rotary_pos_emb(query.to(torch.float32), sin, cos).to(torch.bfloat16)
 
         key = key.permute(0, 2, 1, 3).contiguous()
         query = query.permute(0, 2, 1, 3).contiguous()
@@ -220,6 +207,8 @@ def gaudi_gptj_block_forward(
     use_cache: Optional[bool] = False,
     output_attentions: Optional[bool] = False,
     token_idx: Optional[torch.Tensor] = None,
+    sin: Optional[torch.Tensor] = None,
+    cos: Optional[torch.Tensor] = None,
 ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
     """
     Copied from GPTJBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
@@ -237,6 +226,8 @@ def gaudi_gptj_block_forward(
         use_cache=use_cache,
         output_attentions=output_attentions,
         token_idx=token_idx,
+        sin=sin,
+        cos=cos,
     )
     attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
     outputs = attn_outputs[1:]
@@ -266,6 +257,8 @@ def gaudi_gptj_model_forward(
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None, + cos: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from GPTJModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py @@ -359,6 +352,19 @@ def gaudi_gptj_model_forward( presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + rotary_dim = self.config.rotary_dim + embed_dim = self.config.hidden_size + pos_embd_dim = rotary_dim or embed_dim + max_positions = self.config.max_position_embeddings + embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim).to(torch.bfloat16) + embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1) + if embed_positions.device != position_ids.device: + embed_positions = embed_positions.to(position_ids.device) + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + sin = sin.contiguous() + cos = cos.contiguous() for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel if self.model_parallel: @@ -401,6 +407,8 @@ def custom_forward(*inputs): use_cache=use_cache, output_attentions=output_attentions, token_idx=token_idx, + sin=sin, + cos=cos, ) hidden_states = outputs[0] From 3df001ea90dbb61408aaa3d2e8cdcd9c432ce2b6 Mon Sep 17 00:00:00 2001 From: "Jin, Baihui" Date: Mon, 23 Oct 2023 10:04:32 +0000 Subject: [PATCH 3/5] fix --- optimum/habana/transformers/models/gptj/modeling_gptj.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py index 93bd6d3613..f061936dc5 100644 --- a/optimum/habana/transformers/models/gptj/modeling_gptj.py +++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py @@ -84,7 +84,6 @@ def _attn( query_length, key_length = query.size(-2), key.size(-2) causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() - # Keep the attention weights computation in fp32 to avoid overflow issues query = query.contiguous() key = key.contiguous() value = value.contiguous() @@ -94,7 +93,7 @@ def _attn( mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device="hpu")
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
         attn_weights = torch.where(causal_mask, attn_weights, mask_value)
 
         attn_weights = attn_weights / self.scale_attn
@@ -137,6 +136,7 @@ def forward(
         - add new args token_idx
         - remove is_torch_fx_proxy
         - optimize KV cache
+        - pass sin and cos from upper level as they are identical for each attn block
         """
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)
@@ -214,6 +214,7 @@ def gaudi_gptj_block_forward(
     Copied from GPTJBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
     The only differences are:
     - add new args token_idx
+    - pass sin and cos from upper level as they are identical for each attn block
     """
     residual = hidden_states
     hidden_states = self.ln_1(hidden_states)
@@ -264,6 +265,7 @@
     Copied from GPTJModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
     The only differences are:
     - add new args token_idx
+    - pass sin and cos from upper level as they are identical for each attn block
     """
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (

From abf3f80da7f1087ce2f8337bed52afa6b217b744 Mon Sep 17 00:00:00 2001
From: "Jin, Baihui"
Date: Wed, 25 Oct 2023 06:18:37 +0000
Subject: [PATCH 4/5] add common

---
 optimum/habana/transformers/models/gptj/modeling_gptj.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py
index f061936dc5..48ec8c06e6 100644
--- a/optimum/habana/transformers/models/gptj/modeling_gptj.py
+++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -354,6 +354,8 @@ def gaudi_gptj_model_forward(
     presents = () if use_cache else None
     all_self_attentions = () if output_attentions else None
     all_hidden_states = () if output_hidden_states else None
+
+    # Replace the original `_get_embed_positions` method and the per-block sin/cos calculation here to improve performance
     rotary_dim = self.config.rotary_dim
     embed_dim = self.config.hidden_size
     pos_embd_dim = rotary_dim or embed_dim
     max_positions = self.config.max_position_embeddings
     embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim).to(torch.bfloat16)
     embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1)
     if embed_positions.device != position_ids.device:
         embed_positions = embed_positions.to(position_ids.device)
     repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
     sincos = torch.gather(embed_positions, 1, repeated_position_ids)
     sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
     sin = sin.contiguous()
     cos = cos.contiguous()
+
     for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
         # Model parallel
         if self.model_parallel:

From ad550fea979fea367334123b92d7b928b54f4f35 Mon Sep 17 00:00:00 2001
From: "Jin, Baihui"
Date: Wed, 25 Oct 2023 06:19:50 +0000
Subject: [PATCH 5/5] formatting

---
 optimum/habana/transformers/models/gptj/modeling_gptj.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py
index 48ec8c06e6..8d8f771c4a 100644
--- a/optimum/habana/transformers/models/gptj/modeling_gptj.py
+++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -354,7 +354,7 @@ def gaudi_gptj_model_forward(
     presents = () if use_cache else None
     all_self_attentions = () if output_attentions else None
     all_hidden_states = () if output_hidden_states else None
-
+
     # Replace the original `_get_embed_positions` method and the per-block sin/cos calculation here to improve performance
     rotary_dim = self.config.rotary_dim
     embed_dim = self.config.hidden_size
@@ -369,7 +369,7 @@ def gaudi_gptj_model_forward(
     sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
     sin = sin.contiguous()
     cos = cos.contiguous()
-
+
     for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
         # Model parallel
         if self.model_parallel:
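
For reference, the optimization the series converges on (patches 2 and 4) can be reproduced standalone: build the rotary sin/cos tables once per model forward and hand the same tensors to every attention block, instead of rebuilding `embed_positions` and gathering sin/cos inside each block. The sketch below is illustrative only, not code from the patches: it runs on plain CPU tensors, omits the bfloat16 cast and Gaudi device handling, and `precompute_sin_cos` is a hypothetical helper name; `create_sinusoidal_positions` mirrors the transformers GPT-J helper of the same name.

import torch

def create_sinusoidal_positions(num_pos, dim):
    # (num_pos, dim) table laid out as [sin | cos] halves along the last axis.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    sinusoid_inp = torch.einsum("i,j->ij", torch.arange(num_pos, dtype=torch.float32), inv_freq)
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

def precompute_sin_cos(position_ids, rotary_dim, max_positions):
    # Hoisted once per model forward; every attention block then receives the
    # same sin/cos tensors instead of recomputing them layer by layer.
    embed_positions = create_sinusoidal_positions(max_positions, rotary_dim)
    embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1)
    repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
    sincos = torch.gather(embed_positions, 1, repeated_position_ids)
    sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
    return sin.contiguous(), cos.contiguous()

position_ids = torch.arange(8).unsqueeze(0)  # batch of 1, sequence length 8
sin, cos = precompute_sin_cos(position_ids, rotary_dim=64, max_positions=2048)
print(sin.shape, cos.shape)  # torch.Size([1, 8, 32]) torch.Size([1, 8, 32])

On Gaudi the series additionally casts the embedding table to bfloat16 (patch 2) while applying RoPE itself in fp32, per the accuracy note in the attention-block comment.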