
self.proj in CausalSelfAttention too large if config.n_query_groups < config.n_head? #1890

Closed
mseeger opened this issue Dec 25, 2024 · 7 comments
Labels
question Further information is requested

Comments

@mseeger
Contributor

mseeger commented Dec 25, 2024

There are only config.n_query_groups V vectors. The shape of self.proj should really be (config.head_size * config.n_query_groups, config.n_embd). It should be smaller in the same sense that self.attn is smaller if there are fewer query groups than heads.

In the code, you expand the V matrix before multiplying with the linear map. This is equivalent to using a smaller weight matrix, but what is done right now is more expensive and needs more memory.

I could send a PR to fix this, but I am wondering about the compatibility with Hugging Face. Are they also doing this? Do we need to change import scripts for pre-trained models then?

@mseeger mseeger added the question Further information is requested label Dec 25, 2024
@Andrei-Aksionov
Collaborator

Hello @mseeger

I'm not entirely sure why the size of self.proj should be smaller when using GQA (Grouped Query Attention). The input shape to the projection layer remains unchanged. GQA primarily shares/reuses the keys and values during computation, but the attention output (the input to the projection) should still have size number of heads multiplied by the size of each head.

@mseeger
Contributor Author

mseeger commented Dec 26, 2024

Here is what I'd do:

self.proj = nn.Linear(config.head_size * config.n_query_groups, config.n_embd, bias=config.bias)

The output of the whole module still has the same shape.

And remove this one: v = v.expand(*q.shape).

scaled_dot_product_attention is fine with v having a shorter final dimension.
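
For reference, PyTorch's F.scaled_dot_product_attention does accept a value tensor whose last dimension differs from the query/key head size. A minimal sketch, with shapes picked purely for illustration:

import torch
import torch.nn.functional as F

B, nh, T, hs, hs_v = 2, 8, 16, 64, 32
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
v = torch.randn(B, nh, T, hs_v)  # last dimension (Ev) differs from q/k

y = F.scaled_dot_product_attention(q, k, v)
print(y.shape)  # torch.Size([2, 8, 16, 32])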

@Andrei-Aksionov
Collaborator

Still don't understand 🙂

If you don't expand V, then you will get a mismatch at dimension 1, i.e.:

  • attention scores of shape (B, nh, T, T)
  • V of shape (B, nh_v, T, hs)

Dropping all the prep steps, this is the core of SDPA (Python pseudo-code):

# attention scores: (B, nh, T, hs) @ (B, nh, hs, T) --> (B, nh, T, T)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
# the final matmul needs V's head dimension to match (or broadcast with) nh:
# (B, nh, T, T) @ (B, nh_v, T, hs) fails when nh_v != nh
return attn_weight @ value
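
That last matmul is exactly where an unexpanded V trips up. A minimal reproduction, with shapes assumed purely for illustration (not litgpt's actual configuration):

import torch
import torch.nn.functional as F

B, nh, nh_v, T, hs = 2, 8, 2, 16, 64
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)    # K expanded to nh heads
v = torch.randn(B, nh_v, T, hs)  # V left at n_query_groups heads

# The attention scores are (B, nh, T, T); (B, nh, T, T) @ (B, nh_v, T, hs)
# does not broadcast because nh != nh_v, so SDPA (without enable_gqa) errors out.
try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as e:
    print(e)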

From the PR that I still need to merge 🫠:

# ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
y = self.scaled_dot_product_attention(q, k, v, mask)

> scaled_dot_product_attention is fine with v having a shorter final dimension.

The only reason I can see is that the latest version of SDPA has an enable_gqa argument, which does the expansion.
But it defaults to False. Perhaps you enabled it?

Or I'm clearly missing something.


On a side note, I hadn't noticed that SDPA was updated in 2.5.
We need to relax the version constraint for PyTorch so that the latest one (2.5) gets installed, and let SDPA deal with GQA.

@mseeger
Contributor Author

mseeger commented Dec 26, 2024

OK, I see. You are right. It is still a bit weird.

@mseeger mseeger closed this as completed Dec 26, 2024
@mseeger
Contributor Author

mseeger commented Dec 30, 2024

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention

Adding a comment here: If enable_gqa=True, you can have Hq > H. In your notation, H = n_query_groups, Hq = n_head. The final linear self.proj still has full size, though.

Not sure how important this is.

@mseeger mseeger reopened this Dec 30, 2024
@Andrei-Aksionov
Collaborator

Yes, if you set enable_gqa=True, then SDPA will expand K and V by itself.
It can be seen from the pseudocode in the docs:

...
if enable_gqa:
    key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
    value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
...

The benefit, as I understand it, is that the kernel can do the expansion during the computation, instead of expanding K and V first and then moving the larger tensors from HBM ("global" memory) to "shared" memory.
But I haven't noticed any speed-up.
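
For completeness, a minimal usage sketch of that path (enable_gqa requires PyTorch >= 2.5; shapes assumed purely for illustration). Note the output still has n_head heads, which is why self.proj keeps its full input size:

import torch
import torch.nn.functional as F

B, nh, n_groups, T, hs = 2, 8, 2, 16, 64
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, n_groups, T, hs)  # K left at n_query_groups heads
v = torch.randn(B, n_groups, T, hs)  # V left at n_query_groups heads

y = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
print(y.shape)  # torch.Size([2, 8, 16, 64]) -- n_head heads in the output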

@mseeger
Contributor Author

mseeger commented Jan 6, 2025

Resolved

@mseeger mseeger closed this as completed Jan 6, 2025