diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index d112a7461..42dbdfeeb 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -763,51 +763,16 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         # pass through projection: [sq, b, h] --> [sq, b, ((np + 2 * kvp) * hn)]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-        # First: reshape so we have seqlen, batch, and num. query heads each as separate dims
-        # Final dim is not exactly head dim: the first (head dim) dims are query heads,
-        # The last (head dim * ratio of kv to q heads) each are the "k/v heads"
-        # (right now we treat like we have same num. heads, but smaller head dim)
-
-        # [sq, b, ((np + 2 * kvp) * hn)] --> [sq, b, np, (hn * (1 + 2 * (kvp / np)))]
-        new_qkv_shape = (
-            mixed_x_layer.shape[0],
-            mixed_x_layer.shape[1],
-            self.num_attention_heads_per_partition,
-            int(
-                self.hidden_size_per_attention_head
-                * (
-                    1
-                    + 2
-                    * (
-                        self.num_kv_heads_per_partition
-                        / self.num_attention_heads_per_partition
-                    )
-                )
-            ),
-        )
-        mixed_x_layer = mixed_x_layer.reshape(*new_qkv_shape)
-
-        # Next: split our fake head dim. (last dim) so that the first (head dim) dimensions go to Q,
-        # the last smaller 2 * (head dim * kv to q head ratio) each divided between K and V separately
+        # split the last dim, so that the first (q head * head dim) dimensions go to Q,
+        # the last smaller 2 * (kv head * head dim) each divided between K and V separately
         split_sizes = (
-            self.hidden_size_per_attention_head,
-            int(
-                (
-                    self.num_kv_heads_per_partition
-                    / self.num_attention_heads_per_partition
-                )
-                * self.hidden_size_per_attention_head
-            ),
-            int(
-                (
-                    self.num_kv_heads_per_partition
-                    / self.num_attention_heads_per_partition
-                )
-                * self.hidden_size_per_attention_head
-            ),
+            self.num_attention_heads_per_partition
+            * self.hidden_size_per_attention_head,
+            self.num_kv_heads_per_partition * self.hidden_size_per_attention_head,
+            self.num_kv_heads_per_partition * self.hidden_size_per_attention_head,
         )
 
-        # [sq, b, np, (hn * (1 + 2 * (kvp / np)))] --> 1 x [sq, b, np, hn] , 2 x [sq, b, np, (hn * (kvp / np))]
+        # [sq, b, ((np + 2 * kvp) * hn)] --> 1 x [sq, b, np * hn] , 2 x [sq, b, kvp * hn]
         (query_layer, key_layer, value_layer) = [
             x.contiguous()
             for x in torch.split(
@@ -817,6 +782,17 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
             )
         ]
 
+        # reshape Q to proper output shape (last dim = correct full "real" head size again)
+        # [sq, b, np * hn] --> [sq, b, np, hn]
+        new_query_shape = (
+            query_layer.size(0),
+            query_layer.size(1),
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+        )
+
+        query_layer = query_layer.view(*new_query_shape)
+
         # reshape K/V to proper output shape (last dim = correct full "real" head size again)
         # 2 x [sq, b, np, (hn * (kvp / np))] --> 2 x [sq, b, kvp, hn]
         new_kv_shape = (
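
For reference, here is a minimal, self-contained sketch of the split-then-reshape flow introduced by this patch. The dimension values and variable names below (`sq`, `b`, `num_q_heads`, `num_kv_heads`, `head_dim`) are illustrative placeholders, not values taken from any NeoX config; only `torch` is assumed:

```python
# Minimal sketch of the new GQA split-then-reshape flow (illustrative sizes only).
import torch

sq, b = 128, 4        # sequence length, batch size
num_q_heads = 16      # "np"  : query heads per partition
num_kv_heads = 4      # "kvp" : key/value heads per partition
head_dim = 64         # "hn"  : hidden size per attention head

# fused QKV projection output: [sq, b, (np + 2 * kvp) * hn]
mixed_x_layer = torch.randn(sq, b, (num_q_heads + 2 * num_kv_heads) * head_dim)

# split the flat last dim directly into the Q, K, and V chunks
split_sizes = (
    num_q_heads * head_dim,
    num_kv_heads * head_dim,
    num_kv_heads * head_dim,
)
query_layer, key_layer, value_layer = [
    x.contiguous() for x in torch.split(mixed_x_layer, split_sizes, dim=-1)
]

# restore a real head dimension on each tensor
query_layer = query_layer.view(sq, b, num_q_heads, head_dim)   # [sq, b, np, hn]
key_layer = key_layer.view(sq, b, num_kv_heads, head_dim)      # [sq, b, kvp, hn]
value_layer = value_layer.view(sq, b, num_kv_heads, head_dim)  # [sq, b, kvp, hn]

assert query_layer.shape == (sq, b, num_q_heads, head_dim)
assert key_layer.shape == value_layer.shape == (sq, b, num_kv_heads, head_dim)
```

Splitting the flat hidden dimension directly into `np * hn`, `kvp * hn`, `kvp * hn` chunks avoids the removed code's intermediate reshape into a combined "fake" head dimension and its `kvp / np` ratio arithmetic, and leaves K/V already grouped per kv head before the final view.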