Add flash attention support for FP8 Mistral #1156

Closed
wants to merge 5 commits

Conversation

schoi-habana
Collaborator

@schoi-habana schoi-habana commented Jul 24, 2024

--use_flash_attention was not used for FP8 Mistral due to an accuracy issue (#985). Since that issue is resolved in 1.17.0, this PR refactors the Mistral modeling script to enable flash attention for the FP8 flow. The GaudiMistralAttentionLongSequence class is also removed because the model can serve seq_len == max_position_embedding and performs better without it.
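
In rough terms, the change means the attention forward takes the fused SDPA (flash attention) path whenever --use_flash_attention is passed, including when the FP8 quantization config is active, instead of the eager path or the removed long-sequence chunking path. The snippet below is an illustrative sketch only, not this PR's diff: torch.nn.functional.scaled_dot_product_attention stands in for Habana's fused SDPA kernel, and the function and argument names are hypothetical.

# Illustrative sketch, not the actual optimum-habana code.
# F.scaled_dot_product_attention stands in for the fused SDPA kernel used on HPU.
import torch
import torch.nn.functional as F

def attention_forward(query, key, value, attn_mask=None, use_flash_attention=False):
    """Dispatch between a fused SDPA path and the eager softmax path."""
    if use_flash_attention:
        # Single fused kernel; with this PR this branch is taken for the FP8 flow too,
        # and no separate long-sequence chunking class is needed.
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
    # Eager fallback: explicit matmul -> softmax -> matmul.
    scores = query @ key.transpose(-2, -1) / (query.size(-1) ** 0.5)
    if attn_mask is not None:
        scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ value

# Tiny CPU smoke test (shapes: batch, heads, seq_len, head_dim); both paths should agree.
q = k = v = torch.randn(1, 8, 16, 64)
print(torch.allclose(attention_forward(q, k, v, use_flash_attention=True),
                     attention_forward(q, k, v), atol=1e-5))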

For the case with --batch_size 7 --max_new_tokens 512 --max_input_tokens 32000:

Original code without flash attention:
Throughput (including tokenization) = 42.07969811168232 tokens/second
Number of HPU graphs = 121
Memory allocated = 37.51 GB
Max memory allocated = 94.31 GB
Total memory available = 94.62 GB
Graph compilation duration = 285.1875329967588 seconds
Time to first token = 77804.80473767966ms

This PR with flash attention:
Throughput (including tokenization) = 63.06745994995644 tokens/second
Number of HPU graphs = 85
Memory allocated = 37.52 GB
Max memory allocated = 94.0 GB
Total memory available = 94.62 GB
Graph compilation duration = 277.9795082975179 seconds
Time to first token = 49177.13945917785ms

command: QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path /root/tf/test/mistral/ --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --bf16 --batch_size 7 --max_new_tokens 512 --max_input_tokens 32000 --limit_hpu_graphs --use_flash_attention --flash_attention_recompute

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@schoi-habana added the "synapse 1.17_dependency" label (PR not backward compatible; can be merged only when Synapse 1.17 is available) on Jul 24, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@schoi-habana added the "run-test" label (Run CI for PRs from external contributors) on Jul 24, 2024
Contributor

@imangohari1 imangohari1 left a comment


@schoi-habana
Thanks for opening this.
Below are some changes/suggestions for this.
Let me know if you have any questions.
patch:
0001-fea-Updated-the-CI-tests-for-Mistral-fp8.patch

  • I've updated the CI tests to include the flash attention cases for Mistral FP8. Please review these and apply the patch with git am < 0001* (don't rebase here). A rough, hypothetical sketch of what such test entries could look like follows this list.
    • Below are the CI test results after the patch is applied. All 4 are passing.
  • After applying the patch, please push the changes and then sync/rebase this branch onto the top of OH main.
  • Please run the CI tests once beforehand to make sure everything is fine.
    • Below is the minimal setup to do so. Make sure to comment out the non-Mistral cases here (for local testing).
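
Rough, hedged sketch only, not the contents of the attached patch: one way the flash-attention FP8 Mistral cases could be parametrized in tests/test_text_generation_example.py; the actual structure of the test file may differ. The tuple values come from the test ID visible in the results below, and the test body is a placeholder.

import pytest

# (model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline tokens/s)
# Only this tuple is grounded in the test ID shown in the CI log below; the other
# flash-attention FP8 Mistral cases are omitted because their baselines are not
# shown in this conversation.
FP8_MISTRAL_CASES = [
    ("mistralai/Mistral-7B-Instruct-v0.2", 1, 896, True, 128, 128, 17068.965283763682),
]

@pytest.mark.parametrize(
    "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline",
    FP8_MISTRAL_CASES,
)
def test_text_generation_fp8(model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline):
    # In the real suite this would build the run_generation.py command (see "cmds"
    # below) and compare the measured throughput against `baseline`.
    ...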

CI tests

Setup

export RUN_SLOW=true
export GAUDI2_CI=1
python -m pytest tests/test_text_generation_example.py -s -v -k test_text_generation_fp8 --token $HFToken 

cmds

# grep "Command" ../ci-updates-tests-smig.log
Measure Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 1 --use_kv_cache --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits
Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 896 --use_kv_cache --max_new_tokens 128 --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits --max_input_tokens 128 --limit_hpu_graphs --output_dir /tmp/tmp47oitwl7
Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 120 --use_kv_cache --max_new_tokens 2048 --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits --max_input_tokens 128 --limit_hpu_graphs --output_dir /tmp/tmpc8a2f8_e
Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 120 --use_kv_cache --max_new_tokens 128 --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits --max_input_tokens 2048 --limit_hpu_graphs --output_dir /tmp/tmp7u1cq70q
Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 44 --use_kv_cache --max_new_tokens 2048 --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits --max_input_tokens 2048 --limit_hpu_graphs --output_dir /tmp/tmpg2tgdpjn

Results

============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /devops/sgohari/tests/codes/pr-reviews/1156/optimum-habana
configfile: setup.cfg
collecting ... collected 33 items / 29 deselected / 4 selected

tests/test_text_generation_example.py::test_text_generation_fp8[token0-mistralai/Mistral-7B-Instruct-v0.2-1-896-True-128-128-17068.965283763682] [WARNING|utils.py:212] 2024-07-24 20:40:16,674 >> optimum-habana v1.12.0.dev0 has been validated for SynapseAI v1.16.0 but habana-frameworks v1.17.0.417 was found, this could lead to undefined behavior!
[WARNING|utils.py:225] 2024-07-24 20:40:17,580 >> optimum-habana v1.12.0.dev0 has been validated for SynapseAI v1.16.0 but the driver version is v1.17.0, this could lead to undefined behavior!
/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:366: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
 warnings.warn(
/usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
 warnings.warn(
07/24/2024 20:40:18 - INFO - __main__ - Single-device run.
============================= HABANA PT BRIDGE CONFIGURATION ===========================

.
.
.

PASSED

================= 4 passed, 29 deselected in 906.15s (0:15:06) =================

Local test

 QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --bf16 --batch_size 256 --max_new_tokens 256 --max_input_tokens 128 --limit_hpu_graphs --use_flash_attention --flash_attention_recompute
Stats:
---------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 16553.935480822838 tokens/second
Number of HPU graphs                = 91
Memory allocated                    = 15.72 GB
Max memory allocated                = 23.5 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 23.994101402000524 seconds
---------------------------------------------------------------------------------------------------------------

@schoi-habana
Collaborator Author

@imangohari1 @yeonsily addressed your comment

Contributor

@imangohari1 imangohari1 left a comment


LGTM.
@regisss please take a final look when you have a chance.
Thank you.

Collaborator

@regisss regisss left a comment


LGTM!
Let's wait for the release of Synapse 1.17 to merge this one.

vidyasiv pushed a commit to emascarenhas/optimum-habana that referenced this pull request Aug 1, 2024
Add flash attention support for FP8 Mistral huggingface#1156
vidyasiv added a commit to emascarenhas/optimum-habana that referenced this pull request Aug 2, 2024
Add flash attention support for FP8 Mistral huggingface#1156
@libinta removed the "run-test" label (Run CI for PRs from external contributors) on Aug 2, 2024
@regisss
Collaborator

regisss commented Aug 7, 2024

Closing as it was integrated into #1163.

@regisss regisss closed this Aug 7, 2024