Add flash attention support for FP8 Mistral #1156
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@schoi-habana
Thanks for opening this.
Below are some changes and suggestions.
Let me know if you have any questions.
patch:
0001-fea-Updated-the-CI-tests-for-Mistral-fp8.patch
- I've updated the CI tests to include the flash attention Mistral FP8 cases. Please review them and apply the patch with git am < 0001* (don't rebase here; see the sketch after this list).
- Below are the CI test results after the patch is applied. All 4 are passing.
- After applying the patch, please push the changes and then sync/rebase this branch to the top of OH main.
- Please run the CI tests once beforehand to make sure everything is fine.
- Below is the minimal setup to do so. Make sure to comment out the non-Mistral cases (for local testing); a -k filter alternative is sketched after the setup commands.
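A minimal sketch of the apply-and-sync flow, assuming the patch file sits in the working directory and the huggingface/optimum-habana remote is named upstream (remote and branch names are assumptions):
# apply the reviewer's patch on top of the current branch, keeping its commit message and authorship
git am < 0001-fea-Updated-the-CI-tests-for-Mistral-fp8.patch
# push the applied patch, then sync/rebase onto the top of OH main
git push origin HEAD
git fetch upstream
git rebase upstream/main
git push --force-with-lease origin HEAD  # force push is required after the rebase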
CI tests
Setup
export RUN_SLOW=true
export GAUDI2_CI=1
python -m pytest tests/test_text_generation_example.py -s -v -k test_text_generation_fp8 --token $HFToken
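As an alternative to commenting out the non-Mistral cases in the test file, a narrower -k expression should deselect them at collection time; this relies on the model name appearing in the parametrized test IDs, as it does in the collected IDs below:
# select only the Mistral FP8 parametrizations via substring match on the test IDs
python -m pytest tests/test_text_generation_example.py -s -v -k "test_text_generation_fp8 and Mistral" --token $HFToken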
cmds
# grep "Command" ../ci-updates-tests-smig.log
Measure Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 1 --use_kv_cache --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits
Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 896 --use_kv_cache --max_new_tokens 128 --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits --max_input_tokens 128 --limit_hpu_graphs --output_dir /tmp/tmp47oitwl7
Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 120 --use_kv_cache --max_new_tokens 2048 --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits --max_input_tokens 128 --limit_hpu_graphs --output_dir /tmp/tmpc8a2f8_e
Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 120 --use_kv_cache --max_new_tokens 128 --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits --max_input_tokens 2048 --limit_hpu_graphs --output_dir /tmp/tmp7u1cq70q
Command to test: python3 /root/ run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --batch_size 44 --use_kv_cache --max_new_tokens 2048 --reuse_cache --use_hpu_graphs --use_flash_attention --flash_attention_recompute --attn_softmax_bf16 --trim_logits --bf16 --trim_logits --max_input_tokens 2048 --limit_hpu_graphs --output_dir /tmp/tmpg2tgdpjn
Results
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /devops/sgohari/tests/codes/pr-reviews/1156/optimum-habana
configfile: setup.cfg
collecting ... collected 33 items / 29 deselected / 4 selected
tests/test_text_generation_example.py::test_text_generation_fp8[token0-mistralai/Mistral-7B-Instruct-v0.2-1-896-True-128-128-17068.965283763682] [WARNING|utils.py:212] 2024-07-24 20:40:16,674 >> optimum-habana v1.12.0.dev0 has been validated for SynapseAI v1.16.0 but habana-frameworks v1.17.0.417 was found, this could lead to undefined behavior!
[WARNING|utils.py:225] 2024-07-24 20:40:17,580 >> optimum-habana v1.12.0.dev0 has been validated for SynapseAI v1.16.0 but the driver version is v1.17.0, this could lead to undefined behavior!
/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:366: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
warnings.warn(
/usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
warnings.warn(
07/24/2024 20:40:18 - INFO - __main__ - Single-device run.
============================= HABANA PT BRIDGE CONFIGURATION ===========================
.
.
.
PASSED
================= 4 passed, 29 deselected in 906.15s (0:15:06) =================
Local test
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --bf16 --batch_size 256 --max_new_tokens 256 --max_input_tokens 128 --limit_hpu_graphs --use_flash_attention --flash_attention_recompute
Stats:
---------------------------------------------------------------------------------------------------------------
Throughput (including tokenization) = 16553.935480822838 tokens/second
Number of HPU graphs = 91
Memory allocated = 15.72 GB
Max memory allocated = 23.5 GB
Total memory available = 94.62 GB
Graph compilation duration = 23.994101402000524 seconds
---------------------------------------------------------------------------------------------------------------
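Note that maxabs_quant.json consumes FP8 measurement statistics collected in a prior calibration pass; if those are missing, the usual flow is to run once with the measurement config first. A sketch, assuming the maxabs_measure.json file that ships next to maxabs_quant.json in the example's quantization_config directory (calibration flags are illustrative):
# one-time calibration pass to produce the measurement files used by maxabs_quant.json
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 --use_hpu_graphs --use_kv_cache --reuse_cache --bf16 --batch_size 1 --max_new_tokens 128 --max_input_tokens 128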
@imangohari1 @yeonsily addressed your comment.
LGTM.
@regisss please take a final look here when you have a chance.
Thank you.
LGTM!
Let's wait for the release of Synapse 1.17 to merge this one.
Closing as it was integrated into #1163.
--use_flash_attention was not used for FP8 Mistral due to an accuracy issue (#985). Since that accuracy issue is gone in 1.17.0, this PR refactors the Mistral modeling script to enable flash attention for the FP8 flow. The GaudiMistralAttentionLongSequence class is also removed, because the model can serve seq_len == max_position_embeddings without it and performs better that way.
For the case with --batch_size 7 --max_new_tokens 512 --max_input_tokens 32000
Original code without flash attention:
Throughput (including tokenization) = 42.07969811168232 tokens/second
Number of HPU graphs = 121
Memory allocated = 37.51 GB
Max memory allocated = 94.31 GB
Total memory available = 94.62 GB
Graph compilation duration = 285.1875329967588 seconds
Time to first token = 77804.80473767966ms
This PR with flash attention:
Throughput (including tokenization) = 63.06745994995644 tokens/second
Number of HPU graphs = 85
Memory allocated = 37.52 GB
Max memory allocated = 94.0 GB
Total memory available = 94.62 GB
Graph compilation duration = 277.9795082975179 seconds
Time to first token = 49177.13945917785ms
command: QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path /root/tf/test/mistral/ --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --bf16 --batch_size 7 --max_new_tokens 512 --max_input_tokens 32000 --limit_hpu_graphs --use_flash_attention --flash_attention_recompute
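For a rough on-branch comparison against the non-flash path, the same command can be rerun with the flash-attention flags dropped; note this only approximates the "original code" numbers above, since this PR also removes GaudiMistralAttentionLongSequence:
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path /root/tf/test/mistral/ --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --bf16 --batch_size 7 --max_new_tokens 512 --max_input_tokens 32000 --limit_hpu_graphs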