Commit

fix a GQA issue (#1314) (#1315)
- Do not create a fake head dim; instead, split the 'mixed_x_layer' directly into Q, K, and V layers.
tiandeyu-cs authored and jahatef committed Nov 29, 2024
1 parent bdb3658 · commit ff7f328
Showing 1 changed file with 18 additions and 42 deletions.
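
For the shape bookkeeping behind the fix, here is a quick sanity check of the split sizes the new code computes. The numbers are illustrative assumptions (np = 8 query heads per partition, kvp = 2 KV heads, hn = 64 head dim), not values taken from the commit:

# illustrative shape arithmetic only; np_heads / kvp_heads / hn are assumed example values
np_heads, kvp_heads, hn = 8, 2, 64
proj_width = (np_heads + 2 * kvp_heads) * hn                   # last dim of mixed_x_layer: 768
split_sizes = (np_heads * hn, kvp_heads * hn, kvp_heads * hn)  # (512, 128, 128)
assert sum(split_sizes) == proj_width                          # the direct split covers the whole projection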
60 changes: 18 additions & 42 deletions megatron/model/transformer.py
@@ -763,51 +763,16 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         # pass through projection: [sq, b, h] --> [sq, b, ((np + 2 * kvp) * hn)]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-        # First: reshape so we have seqlen, batch, and num. query heads each as separate dims
-        # Final dim is not exactly head dim: the first (head dim) dims are query heads,
-        # The last (head dim * ratio of kv to q heads) each are the "k/v heads"
-        # (right now we treat like we have same num. heads, but smaller head dim)
-
-        # [sq, b, ((np + 2 * kvp) * hn)] --> [sq, b, np, (hn * (1 + 2 * (kvp / np)))]
-        new_qkv_shape = (
-            mixed_x_layer.shape[0],
-            mixed_x_layer.shape[1],
-            self.num_attention_heads_per_partition,
-            int(
-                self.hidden_size_per_attention_head
-                * (
-                    1
-                    + 2
-                    * (
-                        self.num_kv_heads_per_partition
-                        / self.num_attention_heads_per_partition
-                    )
-                )
-            ),
-        )
-        mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape)
-
-        # Next: split our fake head dim. (last dim) so that the first (head dim) dimensions go to Q,
-        # the last smaller 2 * (head dim * kv to q head ratio) each divided between K and V separately
+        # split the last dim, so that the first (q head * head dim) dimensions go to Q,
+        # the last smaller 2 * (kv head * head dim) each divided between K and V separately
         split_sizes = (
-            self.hidden_size_per_attention_head,
-            int(
-                (
-                    self.num_kv_heads_per_partition
-                    / self.num_attention_heads_per_partition
-                )
-                * self.hidden_size_per_attention_head
-            ),
-            int(
-                (
-                    self.num_kv_heads_per_partition
-                    / self.num_attention_heads_per_partition
-                )
-                * self.hidden_size_per_attention_head
-            ),
+            self.num_attention_heads_per_partition
+            * self.hidden_size_per_attention_head,
+            self.num_kv_heads_per_partition * self.hidden_size_per_attention_head,
+            self.num_kv_heads_per_partition * self.hidden_size_per_attention_head,
         )
 
-        # [sq, b, np, (hn * (1 + 2 * (kvp / np)))] --> 1 x [sq, b, np, hn] , 2 x [sq, b, np, (hn * (kvp / np))]
+        # [sq, b, ((np + 2 * kvp) * hn)] --> 1 x [sq, b, np * hn] , 2 x [sq, b, kvp * hn]
         (query_layer, key_layer, value_layer) = [
             x.contiguous()
             for x in torch.split(
@@ -817,6 +782,17 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
                 )
             ]
 
+        # reshape Q to proper output shape (last dim = correct full "real" head size again)
+        # [sq, b, np * hn] --> [sq, b, np, hn]
+        new_query_shape = (
+            query_layer.size(0),
+            query_layer.size(1),
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+        )
+
+        query_layer = query_layer.view(*new_query_shape)
+
         # reshape K/V to proper output shape (last dim = correct full "real" head size again)
         # 2 x [sq, b, np, (hn * (kvp / np))] --> 2 x [sq, b, kvp, hn]
         new_kv_shape = (
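
As a standalone illustration of what the new code path does — split the fused QKV projection directly into Q, K, and V chunks and only then restore the per-head dim — here is a minimal PyTorch sketch. The sizes and free-standing variables (sq, b, np_heads, kvp_heads, hn) are assumptions chosen to mirror the comment notation in the diff; this is not the module's actual code, which reads those values from self.

import torch

# assumed example sizes, mirroring the diff's comments:
# sq = seq len, b = batch, np = query heads, kvp = KV heads, hn = head dim
sq, b, np_heads, kvp_heads, hn = 5, 2, 8, 2, 64

# stand-in for the fused QKV projection output: [sq, b, (np + 2 * kvp) * hn]
mixed_x_layer = torch.randn(sq, b, (np_heads + 2 * kvp_heads) * hn)

# split the last dim directly into Q, K, V chunks (no intermediate "fake" head dim)
split_sizes = (np_heads * hn, kvp_heads * hn, kvp_heads * hn)
query_layer, key_layer, value_layer = [
    x.contiguous()
    for x in torch.split(mixed_x_layer, split_sizes, dim=mixed_x_layer.dim() - 1)
]

# restore the real head dim on each chunk
query_layer = query_layer.view(sq, b, np_heads, hn)   # [sq, b, np, hn]
key_layer = key_layer.view(sq, b, kvp_heads, hn)      # [sq, b, kvp, hn]
value_layer = value_layer.view(sq, b, kvp_heads, hn)  # [sq, b, kvp, hn]

assert query_layer.shape == (sq, b, np_heads, hn)
assert key_layer.shape == value_layer.shape == (sq, b, kvp_heads, hn)

The three chunk widths above are the same terms the new split_sizes tuple computes from self.num_attention_heads_per_partition, self.num_kv_heads_per_partition, and self.hidden_size_per_attention_head; the chunk boundaries fall exactly on the Q/K/V boundaries of the fused projection rather than being derived from a per-query-head "fake" head dim.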