Add support for grouped-query attention #9
base: main
Conversation
# This is a no-op for normal attention where ng == np. When using grouped-query attention this
# creates a view that has the keys and values virtually repeated along their dimension to
# match the number of queries.
key_layer = key_layer.repeat_interleave(
How is this repeat_interleaving done? Does it use an explicit torch.view() operation?
@amithrm repeat_interleave does not reshape a tensor the way torch.view does; it repeats elements of the tensor.
Unlike MHA, GQA has no one-to-one correspondence between query heads and key/value heads: multiple query heads share a single key/value head.
By virtually repeating the shared key/value heads until the number of heads equals num_attention_heads, core_attention can treat MHA and GQA identically.
Through this operation, the middle and right illustrations in Figure 2 are transformed to have the same shape as the left one.
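As a minimal sketch of what the call does (the tensor layout and dimension names here are my assumptions for illustration, not necessarily the exact shapes used in this repository):

import torch

# Hypothetical shapes for illustration only; the actual layout in this repo may differ.
seq_len, batch, head_dim = 4, 2, 8
num_attention_heads = 8   # np: number of query heads
num_query_groups = 2      # ng: number of key/value heads under GQA

key_layer = torch.randn(seq_len, batch, num_query_groups, head_dim)

# Each key/value head is shared by np // ng query heads, so repeat every
# KV head that many times along the head dimension (dim=2 in this layout).
key_layer = key_layer.repeat_interleave(
    num_attention_heads // num_query_groups, dim=2
)

print(key_layer.shape)  # torch.Size([4, 2, 8, 8]) -- now matches the number of query heads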
Wouldn't this cause calculations to be duplicated/redundant across GQA groups?
No, this operation does not increase the time complexity.
GQA reduces the cost by shrinking the output dimension of the projection layer that transforms hidden states into key/value heads.
The complexity of the dot product in the core attention is the same for GQA and MHA.
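As a rough illustration of where the savings come from (hypothetical dimensions, not the values used in this repository):

# Hypothetical dimensions for illustration; not taken from this repository's config.
hidden_size = 8192
head_dim = 128
num_attention_heads = 64   # query heads (same for MHA and GQA)
num_query_groups = 8       # key/value heads under GQA

# Parameter count of the key (or value) projection from the hidden states:
kv_proj_params_mha = hidden_size * num_attention_heads * head_dim  # 8192 * 8192
kv_proj_params_gqa = hidden_size * num_query_groups * head_dim     # 8192 * 1024 -> 8x smaller

# After repeat_interleave, the Q.K^T and attention.V matmuls have identical shapes
# in both cases, so the core-attention FLOPs are the same for MHA and GQA.
print(kv_proj_params_mha, kv_proj_params_gqa)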
@@ -5,6 +5,7 @@ export TP=8
 export PP=4
 export N_LAYERS=40
 export N_AH=40
+export N_QG=40
Can you post performance numbers and convergence curves for pretraining with your changes? Can you use the config below?
tensor_parallel: 8
pipeline_parallel: 8
data_parallel: 1
global_batch_size: 256
activation_checkpointing: full
precision: bf16+SR
dataset: bookcorpus
lrscheduler: cosineannealing
I have confirmed that the time per step decreases on the 7B model when num_query_groups is set smaller than the standard setting.
We plan to start training the 70B model next week and will share the results through our contact at AWS Japan.
In preparation, we plan to run some experiments with smaller settings; I will share those results as soon as they are done.
Add support for grouped-query attention, required for compatibility with Llama 2 70B and Code Llama 34B.
References: