`self.proj` in `CausalSelfAttention` too large if `config.n_query_groups < config.n_head`? #1890
Hello @mseeger, I'm not entirely sure why the size of `self.proj` should be smaller.
Here is what I'd do: … The output of all this still has the same shape. And remove this one: …
Still don't understand 🙂 If you don't expand k and v, the shapes in the attention matmul won't match.

If we drop all the prep steps, this is the core of SDPA (Python pseudo-code):

attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value

From the PR that I still need to merge 🫠:

# ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
y = self.scaled_dot_product_attention(q, k, v, mask)

The only reason that I see is that in the latest version of SDPA there is `enable_gqa`. Or I'm clearly missing something. On a side note, I hadn't noticed that SDPA was updated in 2.5.
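To make the shape point concrete, here is a minimal sketch (the sizes below are made up for illustration, not litgpt defaults): each query head gets its own attention weights, so even with shared K/V groups the SDPA output still has `n_head` heads, and that is what `self.proj` has to consume.

```python
import torch
import torch.nn.functional as F

B, T = 2, 5                                   # batch size, sequence length (made up)
n_head, n_query_groups, head_size = 8, 2, 16  # hypothetical GQA config

q = torch.randn(B, n_head, T, head_size)          # one query head per attention head
k = torch.randn(B, n_query_groups, T, head_size)  # K shared within each query group
v = torch.randn(B, n_query_groups, T, head_size)  # V shared within each query group

# Expand K and V so every query head has a matching K/V head (what the current code does).
k = k.repeat_interleave(n_head // n_query_groups, dim=1)
v = v.repeat_interleave(n_head // n_query_groups, dim=1)

y = F.scaled_dot_product_attention(q, k, v)
print(y.shape)  # torch.Size([2, 8, 5, 16]) -> n_head heads, so proj input is n_head * head_size
```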
OK, I see. You are right. It is still a bit weird.
Adding a comment here: if SDPA is called with `enable_gqa=True`, the explicit expansion of k and v could be dropped. Not sure how important this is.
Yes, if you set `enable_gqa=True`, SDPA does the expansion for you:

...
if enable_gqa:
    key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
    value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
...

The benefit, as I understand it, is that the kernel might do the expansion during the computation, instead of expanding K and V first and then moving the larger tensors from HBM ("global" memory) to the "shared" memory.
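For completeness, a sketch of the same call without the manual expansion, assuming PyTorch >= 2.5 where `scaled_dot_product_attention` accepts `enable_gqa` (tensor sizes again made up):

```python
import torch
import torch.nn.functional as F

B, T, n_head, n_query_groups, head_size = 2, 5, 8, 2, 16
q = torch.randn(B, n_head, T, head_size)
k = torch.randn(B, n_query_groups, T, head_size)  # not expanded
v = torch.randn(B, n_query_groups, T, head_size)  # not expanded

# PyTorch >= 2.5: the group-to-head mapping is handled inside the call; a fused
# kernel can avoid materializing the expanded K/V, while the reference fallback
# still does the repeat_interleave shown above.
y = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
print(y.shape)  # torch.Size([2, 8, 5, 16])
```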
Resolved
There are only `config.n_query_groups` V vectors. The shape of `self.proj` should really be `(config.head_size * config.n_query_groups, config.n_embd)`. It should be smaller in the same sense as `self.attn` is smaller if there are fewer query groups than heads.

In the code, you expand the V matrix before multiplying with the linear map. This is equivalent to using a smaller weight matrix, but what is done right now is more expensive and needs more memory.
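For a rough sense of the sizes being discussed, a back-of-the-envelope sketch with made-up, Llama-like numbers (not the defaults of any particular litgpt config):

```python
# Hypothetical config values, for illustration only.
n_embd, head_size, n_head, n_query_groups = 4096, 128, 32, 8

current_proj_in = head_size * n_head            # 4096 input features to self.proj today
proposed_proj_in = head_size * n_query_groups   # 1024 input features under the proposal

print(current_proj_in * n_embd)    # 16777216 weights
print(proposed_proj_in * n_embd)   # 4194304 weights
```

As the comments above explain, the attention output nevertheless keeps `n_head` heads per token, which is why the larger projection is needed after all.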
I could send a PR to fix this, but I am wondering about the compatibility with Hugging Face. Are they also doing this? Do we need to change import scripts for pre-trained models then?