flash-attention-3 #33522
Conversation
There are more models that support FAv2; these will be done next. Note that sliding window should be supported soon, after Dao-AILab/flash-attention#1233.
All models supporting FAv2 should now have FAv3 classes. The following models will need sliding window added back in when available. There are some areas that use
All occurrences of
Documentation and tests will be done next.
Some documentation and all tests are updated for FAv3. I'll run the tests on an H100 instance, then mark this as ready for (initial) review.
Generally the FAv3 tests are failing due to the small configurations used. Instead I've tested the majority of models from their examples, with a few exceptions like Gemma and Mistral, which I need to request access to, and particularly large models such as Jamba that my instance doesn't have space to download. All of the tested models with examples are OK, with the exception of
However this error also occurs with
StableLM models are currently not supported due to num_attention_heads/
I've attached test reports; the numerical accuracy failures may need special care as per
Wowowo super nice initiative thanks! 🔥 IMO since we already abstracted the flash attention API, let's try to keep it in flashAttentionLlama but maybe support `flash_attention_3` in the `attn_implementation`, for example! WDYT?
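For context, if `flash_attention_3` were accepted as an `attn_implementation` value, user-facing selection could mirror the existing FAv2 path. This is only an illustration: the `"flash_attention_3"` value is not supported on main, and the model name is just an example.

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical usage: "flash_attention_3" as an attn_implementation value,
# mirroring how "flash_attention_2" is selected today.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",  # illustrative; FAv2 uses "flash_attention_2"
    device_map="cuda",
)
```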
value_states = value_states.to(target_dtype)

# TODO: get `use_fp8` to here, add attention_kwargs or something
attn_output = _flash_attention_3_forward(
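As a rough illustration of the TODO above, one way to get `use_fp8` to this call site without changing every model signature is an optional kwarg with an environment-variable fallback. `FLASH_ATTENTION_3_FP8` is the name mentioned later in the PR description; the kwarg plumbing and the "set to 1" convention here are assumptions, not what the PR implements.

```python
import os
from typing import Optional


def _resolve_use_fp8(attention_kwargs: Optional[dict] = None) -> bool:
    # Sketch: prefer an explicit per-call kwarg if the model forward passes one,
    # otherwise fall back to the FLASH_ATTENTION_3_FP8 environment variable
    # (assumed to enable FP8 kernels when set to "1").
    if attention_kwargs and "use_fp8" in attention_kwargs:
        return bool(attention_kwargs["use_fp8"])
    return os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1"
```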
Hey! As far as I can tell, the only diff is the forward function right?
Yeah, the difference between the FlashAttention2 classes and FlashAttention3 is just the forward function, plus the lack of dropout/sliding window/softcap for FAv3. As you suggest, we could instead support v3 in the existing classes, using `config.attn_implementation` to select the appropriate function; happy to make this change if you think that's better.
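A minimal sketch of that idea, assuming `_flash_attention_forward` (the existing FAv2 helper) and `_flash_attention_3_forward` (added by this PR in `modeling_flash_attention_3_utils.py`) are both importable; the actual wiring would live inside the attention classes.

```python
from transformers.modeling_flash_attention_utils import _flash_attention_forward
# Added by this PR; module name taken from the PR description.
from transformers.modeling_flash_attention_3_utils import _flash_attention_3_forward


def select_flash_attention_forward(config):
    """Pick the flash attention helper from config._attn_implementation, so a
    single attention class can serve both flash_attention_2 and flash_attention_3."""
    if getattr(config, "_attn_implementation", None) == "flash_attention_3":
        # FAv3 path: no dropout, sliding window or softcap.
        return _flash_attention_3_forward
    return _flash_attention_forward
```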
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


def _flash_attention_3_forward(
Let's maybe replace `flash_attention_forward` with this one when flash attention 3 is available, WDYT?
AFAIK FAv3 will be for Hopper GPUs only
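A small illustrative guard for that constraint; the helper name is made up, but FAv3 kernels target Hopper, i.e. CUDA compute capability 9.0.

```python
import torch


def _is_fa3_capable_gpu() -> bool:
    """Hypothetical guard: only enable the FAv3 path on Hopper-class GPUs."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
```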
I've replaced the
I've renamed the
Note that while I was checking all
In
We could simplify the changes to
Checks like
What does this PR do?
This PR adds preliminary support for Flash Attention 3.
- `is_flash_attn_3_available` required a workaround in `_is_package_available`, as `package_version = importlib.metadata.version(pkg_name)` fails with `importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn_interface` (see the sketch after this list).
- `_supports_flash_attn_3` and `_check_and_enable_flash_attn_3` added to `modeling_utils.py`, a near duplicate of `_check_and_enable_flash_attn_2`.
- `_flash_attention_3_forward` implemented in `modeling_flash_attention_3_utils.py`, a near duplicate of `_flash_attention_forward`; currently FAv3 does not support dropout, sliding window or softcap, and in FAv3 `flash_attn_func`/`flash_attn_varlen_func` return a tuple.
- The `attention_mask is not None` and `position_ids is not None` paths depend on `_upad_input` and `prepare_fa2_from_position_ids` respectively; these are duplicated from `modeling_flash_attention_utils.py` and are not included in the FAv3 package, therefore FAv3 depends on `flash_attn`. This is reflected in `is_flash_attn_3_available`, which checks for `is_flash_attn_2_available`.
- FP8: `FLASH_ATTENTION_3_FP8` for this purpose; we can probably add something like `attention_kwargs` to model forwards to control this, or maybe another `_attn_implementation` type `flash_attention_3_fp8`. Best to get reviews first and consensus on the best way to do it. [1]
- `flash_attention_3` is added to Llama with `LlamaFlashAttention3`, similar to `LlamaFlashAttention2` with unsupported options like dropout and sliding window removed. Edit: added to other models, see comment below.
- `_update_causal_mask` is updated in various models due to `utils/check_copies.py`, and `_supports_flash_attn_3` is added to some other models already for the same reason.

Fixes #33373
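As referenced in the first bullet, here is a minimal sketch of what the availability fallback might look like, assuming FAv3 is importable as `flash_attn_interface` (the module named in the error above) but ships without package metadata; the real `_is_package_available` helper and the PR's version (which also requires FAv2 to be available) handle more cases.

```python
import importlib.metadata
import importlib.util


def is_flash_attn_3_available() -> bool:
    # flash_attn_interface can be importable without package metadata, so
    # importlib.metadata.version() raises PackageNotFoundError; fall back to
    # checking that the module spec exists at all.
    if importlib.util.find_spec("flash_attn_interface") is None:
        return False
    try:
        importlib.metadata.version("flash_attn_interface")
    except importlib.metadata.PackageNotFoundError:
        pass  # present but no metadata: still treat as available
    return True
```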
Todo

- `attention_mask is not None` and `position_ids is not None` paths
- Implement FlashAttention3 classes for other models. Done.

Notes
Llama tested on H100 SXM with:
(shortened) responses
FP16:
FP8:
All other models will be tested after I've finished adding FlashAttention3 classes.
Who can review?
cc @ArthurZucker