remove GPTJ dma before mha #468

Merged: 6 commits merged on Oct 25, 2023
65 changes: 39 additions & 26 deletions optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -43,8 +43,6 @@ def __init__(self, config):
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim
pos_embd_dim = self.rotary_dim or self.embed_dim
self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
"""
@@ -86,17 +84,16 @@ def _attn(
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()

# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32).contiguous()
key = key.to(torch.float32).contiguous()
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

attn_weights = torch.matmul(query, key.transpose(-1, -2))

mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

attn_weights = attn_weights / self.scale_attn
@@ -117,13 +114,6 @@ def _attn(

return attn_output, attn_weights

def _get_embed_positions(self, position_ids):
embed_positions = self.embed_positions
if embed_positions.device != position_ids.device:
embed_positions = embed_positions.to(position_ids.device)
self.embed_positions = embed_positions
return embed_positions.repeat(position_ids.shape[0], 1, 1)

def forward(
self,
hidden_states: torch.FloatTensor,
@@ -134,6 +124,8 @@ def forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
cos: Optional[torch.Tensor] = None,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
@@ -144,6 +136,7 @@
- add new args token_idx
- remove is_torch_fx_proxy
- optimize KV cache
- pass sin and cos from the upper level, as they are identical for each attn block
"""
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
@@ -153,30 +146,24 @@
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True).contiguous()
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False).contiguous()

embed_positions = self._get_embed_positions(position_ids)

repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]

q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]

k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
# Note: bf16 RoPE (with or without the fused kernel) appears to cause accuracy issues, hence fp32 RoPE is used here. Fused-kernel feasibility needs to be confirmed in the future.
k_rot = apply_rotary_pos_emb(k_rot.to(torch.float32), sin, cos).to(torch.bfloat16)
q_rot = apply_rotary_pos_emb(q_rot.to(torch.float32), sin, cos).to(torch.bfloat16)

key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)
key = apply_rotary_pos_emb(key.to(torch.float32), sin, cos).to(torch.bfloat16)
query = apply_rotary_pos_emb(query.to(torch.float32), sin, cos).to(torch.bfloat16)

key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3).contiguous()
query = query.permute(0, 2, 1, 3).contiguous()

if layer_past is not None:
past_key = layer_past[0]
@@ -220,11 +207,14 @@ def gaudi_gptj_block_forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
cos: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
"""
Copied from GPTJBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
The only differences are:
- add new args token_idx
- pass sin and cos from the upper level, as they are identical for each attn block
"""
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@@ -237,6 +227,8 @@
use_cache=use_cache,
output_attentions=output_attentions,
token_idx=token_idx,
sin=sin,
cos=cos,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
@@ -266,11 +258,14 @@ def gaudi_gptj_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
cos: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from GPTJModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
The only differences are:
- add new args token_idx
- pass sin and cos from the upper level, as they are identical for each attn block
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -359,6 +354,22 @@ def gaudi_gptj_model_forward(
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None

# Replaces the original `_get_embed_positions` method and the sin/cos calculation in the attn block, to improve perf
rotary_dim = self.config.rotary_dim
embed_dim = self.config.hidden_size
pos_embd_dim = rotary_dim or embed_dim
max_positions = self.config.max_position_embeddings
embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim).to(torch.bfloat16)
embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1)
if embed_positions.device != position_ids.device:
embed_positions = embed_positions.to(position_ids.device)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
sin = sin.contiguous()
cos = cos.contiguous()
Comment on lines +359 to +371
Collaborator:
I understand that this piece of code comes from the code blocks that were removed above. Could this be moved to a dedicated method that would be called here please?

Contributor (author):
Theoretically it can be done, but testing shows that an additional memcpy occurs; the perf-drop details are as follows.
Throughput (including tokenization) = 3885.6094019038055 tokens/second
Memory allocated = 27.33 GB
Max memory allocated = 28.73 GB
Total memory available = 94.46 GB
Graph compilation duration = 8.958231755999805 seconds

Contributor (author), @BaihuiJin, Oct 24, 2023:
FYI, the change looks like this:
def get_embed_positions(embed_positions, position_ids):
    embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1)
    if embed_positions.device != position_ids.device:
        embed_positions = embed_positions.to(position_ids.device)
    return embed_positions
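
For context, a self-contained sketch (with illustrative sizes, not taken from the PR) of how this helper variant would be wired up; the merged PR keeps the equivalent logic inline in gaudi_gptj_model_forward instead:

import torch
from transformers.models.gptj.modeling_gptj import create_sinusoidal_positions

def get_embed_positions(embed_positions, position_ids):
    # Broadcast the sinusoidal table to the batch and move it to the right device.
    embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1)
    if embed_positions.device != position_ids.device:
        embed_positions = embed_positions.to(position_ids.device)
    return embed_positions

# Illustrative sizes (hypothetical, not from the PR).
max_positions, rotary_dim, batch, seq = 2048, 64, 1, 8
position_ids = torch.arange(seq).unsqueeze(0).repeat(batch, 1)

table = create_sinusoidal_positions(max_positions, rotary_dim)
embed_positions = get_embed_positions(table, position_ids)

# Gather the rows for the current positions and split them into sin/cos halves,
# mirroring the inline block added in gaudi_gptj_model_forward.
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)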

Collaborator:
That's surprising, as objects are passed to functions by reference if I'm not mistaken.
Okay, in that case, could you just add a comment above this block saying which methods it replaces, and also add a blank line right above and below please?
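
For reference, a minimal standalone sketch (not part of the PR) of the call-semantics point above: passing a tensor into a Python function does not, by itself, copy its storage.

import torch

def passthrough(t):
    # The parameter is just another reference to the same tensor object.
    return t

x = torch.randn(4, 4)
y = passthrough(x)
print(x.data_ptr() == y.data_ptr())  # True: same underlying storage, no copy at the Python level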

Contributor (author):
> That's surprising, as objects are passed to functions by reference if I'm not mistaken. Okay, in that case, could you just add a comment above this block saying which methods it replaces, and also add a blank line right above and below please?

Surprising indeed. Anyway, changed accordingly.

Contributor (author):
> That's surprising, as objects are passed to functions by reference if I'm not mistaken. Okay, in that case, could you just add a comment above this block saying which methods it replaces, and also add a blank line right above and below please?

By the way, I think `make style` removed the blank line I added below this block. :)


for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
@@ -401,6 +412,8 @@ def custom_forward(*inputs):
use_cache=use_cache,
output_attentions=output_attentions,
token_idx=token_idx,
sin=sin,
cos=cos,
)

hidden_states = outputs[0]