
Add support for grouped-query attention #9

Open
wants to merge 1 commit into base: main

Conversation

yasuhisa-nakashima

# This is a noop for normal attention where ng == np. When using grouped-query attention this
# creates a view that has the keys and values virtually repeated along their dimension to
# match the number of queries.
key_layer = key_layer.repeat_interleave(
Contributor

How is this repeat_interleaving done? Does it use an explicit torch.view() operation?

Author

@amithrm
Unlike torch.view, which only reshapes a tensor, repeat_interleave repeats its elements along a given dimension.

In GQA there is no one-to-one correspondence between query heads and key/value heads as there is in MHA; instead, several query heads share a single key/value head. By virtually repeating the shared key/value heads until the head count reaches num_attention_heads, core_attention can treat MHA and GQA identically.

This operation transforms the middle and right layouts in Figure 2 into the same shape as the one on the left.

[Attached image: Figure 2]
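
For illustration only, here is a minimal, self-contained sketch of the idea (hypothetical shapes and variable names, not the exact code in this PR):

import torch

# Hypothetical Megatron-style layout [sq, b, ng, hn]: seq=3, batch=1,
# ng=2 key/value groups, hn=4 head dim, np=8 query heads.
sq, b, ng, hn, np_heads = 3, 1, 2, 4, 8
key_layer = torch.randn(sq, b, ng, hn)

# Repeat each key/value head (np // ng) times along the head dimension so that
# every query head has a matching key/value head, exactly as in MHA.
key_layer = key_layer.repeat_interleave(np_heads // ng, dim=2)
print(key_layer.shape)  # torch.Size([3, 1, 8, 4])

# torch.view, by contrast, only reinterprets the existing elements; it cannot
# expand [sq, b, 2, hn] into [sq, b, 8, hn] without repeating data.

When ng == np the repeat factor is 1 and the call changes nothing, which is what the code comment above means by a "noop for normal attention."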

Contributor

Wouldn't this cause calculations to be duplicated/redundant across GQA groups?

Author
yasuhisa-nakashima, Sep 27, 2023


This operation does not increase the time complexity. GQA's savings come from the projection layer: the output dimension of the projection that maps hidden states to key/value heads is reduced, so fewer parameters and FLOPs are needed there. After the key/value heads are repeated, the dot products in core attention have the same shapes, and therefore the same cost, as in MHA.
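
As a rough, self-contained sketch of where the savings come from (hypothetical sizes, not the configuration used in this repo):

import torch

# Hypothetical sizes: hidden=4096, 32 query heads, 8 key/value groups, head_dim=128.
hidden, n_heads, n_groups, head_dim = 4096, 32, 8, 128

# MHA: the key projection maps hidden -> n_heads * head_dim.
# GQA: it maps hidden -> n_groups * head_dim, i.e. roughly 4x fewer parameters and FLOPs here.
k_proj_mha = torch.nn.Linear(hidden, n_heads * head_dim)
k_proj_gqa = torch.nn.Linear(hidden, n_groups * head_dim)
print(sum(p.numel() for p in k_proj_mha.parameters()))  # 16,781,312
print(sum(p.numel() for p in k_proj_gqa.parameters()))  # 4,195,328

x = torch.randn(2, 16, hidden)                       # [batch, seq, hidden]
k = k_proj_gqa(x).view(2, 16, n_groups, head_dim)    # [b, s, ng, hd]
k = k.repeat_interleave(n_heads // n_groups, dim=2)  # [b, s, np, hd]

# After the repeat, Q @ K^T, the softmax, and the value contraction all have the
# same shapes as in MHA, so the core-attention FLOPs are identical; only the
# key/value projections (and the KV cache) get smaller.
print(k.shape)  # torch.Size([2, 16, 32, 128])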

@@ -5,6 +5,7 @@ export TP=8
export PP=4
export N_LAYERS=40
export N_AH=40
export N_QG=40
Contributor

Can you post performance numbers and convergence curves for pretraining with your changes? Can you use the config below?

tensor_parallel: 8
pipeline_parallel: 8
data_parallel: 1
global_batch_size: 256
activation_checkpointing: full
precision: bf16+SR
dataset: bookcorpus
lrscheduler: cosineannealing

Author


I have confirmed that the time per step decreases for the 7B model when num_query_groups is set smaller than in the standard (MHA-equivalent) configuration.

We plan to start training the 70B model next week and will share the results through our contact at AWS Japan.

In preparation, we also plan to run some experiments at smaller scales; I will share those results as soon as they are done.
