flash-attention-3 #33522

Draft: hlky wants to merge 1 commit into main from flash-attention-3

Conversation

hlky (Contributor) commented Sep 17, 2024

What does this PR do?

This PR adds preliminary support for Flash Attention 3.

  • is_flash_attn_3_available required a workaround in _is_package_available, because package_version = importlib.metadata.version(pkg_name) fails with importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn_interface. (A minimal sketch of this check appears after this list.)
  • _supports_flash_attn_3 and _check_and_enable_flash_attn_3 are added to modeling_utils.py; _check_and_enable_flash_attn_3 is a near duplicate of _check_and_enable_flash_attn_2.
  • _flash_attention_3_forward is implemented in modeling_flash_attention_3_utils.py.
    • It is similar to the FAv2 _flash_attention_forward, except that FAv3 currently does not support dropout, sliding window, or softcap, and FAv3's flash_attn_func/flash_attn_varlen_func return a tuple.
    • The attention_mask is not None and position_ids is not None paths depend on _upad_input and prepare_fa2_from_position_ids respectively. These are duplicated from modeling_flash_attention_utils.py and are not included in the FAv3 package, so FAv3 depends on flash_attn; this is reflected in is_flash_attn_3_available, which also checks is_flash_attn_2_available.
    • In the remaining path FAv3 supports FP8. This PR currently uses the environment variable FLASH_ATTENTION_3_FP8 for that purpose (a rough illustration appears after this list). We could probably add something like attention_kwargs to model forwards to control this, or maybe another _attn_implementation type such as flash_attention_3_fp8; best to get reviews and consensus on the best way to do it first.[1]
  • flash_attention_3 is added to Llama via LlamaFlashAttention3, similar to LlamaFlashAttention2 but with unsupported options such as dropout and sliding window removed. Edit: added to other models, see comment below.
  • _update_causal_mask is updated in various models due to utils/check_copies.py, and _supports_flash_attn_3 is added to some other models for the same reason. See comment below.
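
For reference, here is a minimal sketch of the availability-check workaround described in the first bullet above; the exact handling in the PR may differ, and the fallback behaviour shown here is an assumption:

import importlib.metadata
import importlib.util

from transformers.utils import is_flash_attn_2_available


def _is_package_available(pkg_name: str, return_version: bool = False):
    # flash_attn_interface is importable but ships no package metadata, so
    # importlib.metadata.version() raises PackageNotFoundError for it.
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # No metadata: keep the spec-based result instead of treating the package as missing.
            pass
    if return_version:
        return package_exists, package_version
    return package_exists


def is_flash_attn_3_available():
    # FAv3 still relies on the unpad/varlen helpers from flash_attn (as noted in the list above),
    # so the check also requires FAv2 to be available.
    return _is_package_available("flash_attn_interface") and is_flash_attn_2_available()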
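
And a rough illustration of how the FLASH_ATTENTION_3_FP8 toggle mentioned above could feed into the path without attention_mask/position_ids; the helper name, the float8 dtype, the cast-back step, and how the flag is parsed are all assumptions for illustration, not necessarily what this PR does:

import os

import torch
from flash_attn_interface import flash_attn_func  # FAv3 kernels


def _fa3_dense_path(query_states, key_states, value_states, is_causal, softmax_scale=None):
    # Hypothetical helper: opt into FP8 via the environment variable discussed above.
    use_fp8 = os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1"
    input_dtype = query_states.dtype
    if use_fp8:
        # The exact FP8 dtype/scaling handling is an assumption here.
        query_states = query_states.to(torch.float8_e4m3fn)
        key_states = key_states.to(torch.float8_e4m3fn)
        value_states = value_states.to(torch.float8_e4m3fn)
    # Unlike FAv2, the FAv3 functions return a tuple; the attention output is the first element.
    outputs = flash_attn_func(
        query_states, key_states, value_states, softmax_scale=softmax_scale, causal=is_causal
    )
    attn_output = outputs[0]
    # Cast back so downstream projections see the model's compute dtype.
    return attn_output.to(input_dtype) if attn_output.dtype != input_dtype else attn_output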

Fixes #33373

Todo

  • Test attention_mask is not None and position_ids is not None paths
  • Implement FlashAttention3 classes for other models (done)
  • FP8 usage[1]
  • Documentation
  • Benchmarks would be nice

Notes

Llama tested on H100 SXM with:

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained('NousResearch/Hermes-3-Llama-3.1-8B', trust_remote_code=True)
model = LlamaForCausalLM.from_pretrained(
    "NousResearch/Hermes-3-Llama-3.1-8B",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_3"
)

prompts = [
    """<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
    ]

for chat in prompts:
    print(chat)
    input_ids = tokenizer(chat, return_tensors="pt").input_ids.to("cuda")
    generated_ids = model.generate(input_ids, max_new_tokens=750, temperature=0.8, repetition_penalty=1.1, do_sample=True, eos_token_id=tokenizer.eos_token_id)
    response = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print(f"Response: {response}")

(shortened) responses

FP16:

In the vast expanse of the universe, there existed a celestial realm where beings of extraordinary powers roamed freely. One such being was Goku, the legendary warrior known for his indomitable spirit and unconquerable will.

One fateful day, as Goku trained under the golden sun, he sensed an unusual disturbance in the cosmic energy. Puzzled by this anomaly, Goku rushed to the source of the disturbance and arrived at the edge of a hidden dimension.

To his horror, Goku witnessed Kirby and Majin Buu working together, devising devious plans to obliterate entire planets. Their combined strength was formidable, their intentions sinister, and their alliance unprecedented.

FP8:

In the vast expanse of the universe, there existed a planet called Planet Vegeta, where the powerful Saiyan warrior, Goku, lived among his friends in the city of Earth.

One fateful day, Goku was training on the top of a mountain when he sensed an unusual energy signature. His spidey senses immediately tingled with suspicion.

"Who could that be?" he wondered aloud as he leapt into the sky, soaring towards the source of the disturbance.

As Goku arrived at the scene, he discovered something utterly shocking - Kirby, the infamous villain from another galaxy, had formed an alliance with Majin Buu, the mischievous yet formidable entity who had caused Goku so much trouble in the past.

All other models will be tested after I've finished adding FlashAttention3 classes.

Who can review?

cc @ArthurZucker

hlky force-pushed the flash-attention-3 branch 3 times, most recently from b6afd63 to 0976545 on September 17, 2024, 14:28
hlky (Contributor Author) commented Sep 17, 2024

FlashAttention3 classes have been added to the models that had to be modified due to utils/check_copies.py. The following models do not currently support Flash Attention and were only modified because of utils/check_copies.py: bloom, codegen, gpt_neox_japanese, idefics, persimmon.

There are more models that support FAv2; these will be done next.

Note that sliding window should be supported soon, after Dao-AILab/flash-attention#1233.

hlky force-pushed the flash-attention-3 branch 2 times, most recently from 5aa58ab to 7ae105e on September 17, 2024, 17:49
hlky (Contributor Author) commented Sep 17, 2024

All models supporting FAv2 should now have FAv3 classes.

The following models will need sliding window added back in when it becomes available:
chameleon/modeling_chameleon.py
gemma/modeling_gemma.py
gemma2/modeling_gemma2.py
granite/modeling_granite.py
idefics2/modeling_idefics2.py
jamba/modeling_jamba.py
llama/modeling_llama.py
mistral/modeling_mistral.py
mixtral/modeling_mixtral.py
nemotron/modeling_nemotron.py
phi3/modeling_phi3.py
qwen2/modeling_qwen2.py
qwen2_moe/modeling_qwen2_moe.py
qwen2_vl/modeling_qwen2_vl.py
starcoder2/modeling_starcoder2.py

There are some areas that use config._attn_implementation == "flash_attention_2" that I'll update next.

hlky (Contributor Author) commented Sep 17, 2024

All occurrences of config._attn_implementation == .../self._use_flash_attention_ and other mentions of "flash_attention_2" should now be updated with FAv3 versions. This should be about it on the modeling side, with the exception of sliding window.

Documentation and tests will be done next.

hlky (Contributor Author) commented Sep 17, 2024

Some documentation and all tests are updated for FAv3. I'll run the tests on an H100 instance, then mark this as ready for (initial) review.

hlky (Contributor Author) commented Sep 17, 2024

Generally FAv3 tests are failing due to the small configurations used: RuntimeError: Only support head size 64, 128, and 256 for now.

Instead I've tested the majority of models from their examples, with a few exceptions like Gemma and Mistral that I need to request access to, and particularly large models such as Jamba that my instance doesn't have space to download.

All of the tested models with examples are ok, with the exception of HuggingFaceM4/idefics2-8b-base:

  File "/workspace/transformers/src/transformers/modeling_flash_attention_3_utils.py", line 118, in _upad_input
    query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: too many values to unpack (expected 4)

However, this error also occurs with flash_attention_2.

StableLM models are currently not supported: their num_attention_heads results in a head size that fails the "Only support head size 64, 128, and 256 for now" check.
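
For context, the head size comes from the model config (hidden_size divided by num_attention_heads, unless head_dim is set explicitly), which is why both the small test configs and StableLM hit the error above. A hypothetical guard for illustration, not part of the PR:

# FAv3 currently supports only these head sizes (per the runtime error above).
SUPPORTED_FA3_HEAD_DIMS = {64, 128, 256}


def config_head_dim_supported_by_fa3(config) -> bool:
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    return head_dim in SUPPORTED_FA3_HEAD_DIMS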

I've attached test reports; the numerical accuracy failures may need special care, as per flash-attention/hopper/test_flash_attn.py here and here.

test_report.zip

ArthurZucker (Collaborator) left a comment:

Wowowo super nice initiative thanks! 🔥
IMO since we already abstracted the flash attention API, let's try to keep it in flashAttentionLlama but maybe support flash_attention_3 in the attn_implementation for example! WDYT?

On this diff hunk:

value_states = value_states.to(target_dtype)

# TODO: get `use_fp8` to here, add attention_kwargs or something
attn_output = _flash_attention_3_forward(

ArthurZucker (Collaborator):

Hey! As far as I can tell, the only diff is the forward function right?

hlky (Contributor Author):

Yeah, the difference between the FlashAttention2 classes and FlashAttention3 is just the forward function, plus the lack of dropout/sliding window/softcap for FAv3. As you suggest, we could support v3 in the existing classes instead, using config._attn_implementation to select the appropriate function; happy to make this change if you think that's better.

On another diff hunk (in modeling_flash_attention_3_utils.py):

return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


def _flash_attention_3_forward(

ArthurZucker (Collaborator):

let's maybe replace flash_attention_forward by this one when flash attention3 is available WDYT?

hlky (Contributor Author):

AFAIK FAv3 will be for Hopper GPUs only

hlky (Contributor Author) commented Sep 18, 2024

I've removed the separate FlashAttention3 classes and integrated FAv3 support into the existing FlashAttention2 classes. self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" is added to control which version to use, via if self._flash_attn_3:.

I've renamed the FlashAttention2 classes to just FlashAttention, since ATTENTION_CLASSES would look strange with e.g. "flash_attention_3": Qwen2VLFlashAttention2.

Note that while checking that all Attention classes take config, I added some missing type annotations to the config parameters; I then had to add a few more types due to Copied from.

In src/transformers/models/qwen2_vl/modeling_qwen2_vl.py I've changed VisionFlashAttention and VisionSdpaAttention to subclass VisionAttention and added config as a parameter; this was needed for the self.config._attn_implementation == "flash_attention_3" check.

We could simplify the changes to the FlashAttention classes further by creating a wrapper around both _flash_attention_forward and _flash_attention_3_forward, with the above _flash_attn_3 as a parameter (a rough sketch follows below).

Checks like config._attn_implementation != "flash_attention_2"/config._attn_implementation == "flash_attention_2" could also be changed to something like "flash_attention" not in config._attn_implementation/"flash_attention" in config._attn_implementation.
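
For illustration, a minimal sketch of the kind of wrapper suggested above; the name, the signature, and the forwarded arguments are assumptions that loosely follow the FAv2 helper, not what the PR currently implements:

# The FAv2 helper is the existing transformers utility; the FAv3 one is the module added in this PR.
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_flash_attention_3_utils import _flash_attention_3_forward


def _flash_attention_wrapper(
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    is_causal,
    use_flash_attn_3: bool = False,
    **kwargs,
):
    # Dispatch to the FAv3 forward when requested, otherwise fall back to FAv2.
    # FAv3 currently lacks dropout/sliding window/softcap, so those kwargs are
    # only forwarded on the FAv2 path.
    if use_flash_attn_3:
        return _flash_attention_3_forward(
            query_states, key_states, value_states, attention_mask, query_length, is_causal=is_causal
        )
    return _flash_attention_forward(
        query_states, key_states, value_states, attention_mask, query_length, is_causal=is_causal, **kwargs
    )

Call sites in the attention classes could then pass self._flash_attn_3 (or, with the membership-style check above, derive it from config._attn_implementation) and keep a single code path.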
