Improve MPT fp8 #1256

Merged
merged 25 commits into huggingface:main on Sep 23, 2024

Conversation

atakaha (Contributor) commented Aug 14, 2024

Add Softmax and FusedSDPA
Remove unnecessary args from the self._gradient_checkpointing_func() call.

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

atakaha (Contributor, Author) commented Aug 14, 2024

  • Observing new-token differences between the original path and --use_flash_attention. This is due to SDPA implementation differences, especially in attn_mask handling, between MPT's attention and torch.nn.functional.scaled_dot_product_attention (see the sketch after this list).

  • Observing new-token differences between the original path and the softmax path. This could be caused by a precision difference, since the Softmax module calls torch.ops.hpu.softmax_fp8.

  • bf16 sample command line (uses the softmax path)

python run_generation.py --model_name_or_path mosaicml/mpt-7b --use_hpu_graphs --use_kv_cache --limit_hpu_graphs --max_input_tokens 128 --max_new_tokens 128 --batch_size 256 --bf16 
  • bf16 with flash attention sample command line (does not use the softmax path)
python run_generation.py --model_name_or_path mosaicml/mpt-7b --use_hpu_graphs --use_kv_cache --limit_hpu_graphs --max_input_tokens 128 --max_new_tokens 128 --batch_size 256 --bf16 --use_flash_attention
  • throughput sample

| config | batch size / max input tokens / max new tokens | Throughput incl. tokenization (tokens/s) | HPU graphs | Memory allocated (GB) | Max memory allocated (GB) | TP improvement vs bf16 |
|---|---|---|---|---|---|---|
| org bf16 | 128 / 128 / 128 | 5076 | 14 | 45.3 | 83.51 | N/A |
| bf16 + flash_attention | ditto | 5150 | 14 | 46.3 | 84.48 | 1.4% |
| bf16 + softmax | ditto | 5148 | 14 | 46.3 | 84.48 | 1.4% |
| fp8 + flash_attention | ditto | 7091 | 21 | 40.1 | 78.29 | 39.7% |
| fp8 + softmax | ditto | 6921 | 22 | 38.85 | 77.04 | 36.3% |
| org bf16 | 32 / 128 / 1024 | 1604 | 14 | 33.31 | 54.79 | N/A |
| bf16 + flash_attention | ditto | 1700 | 14 | 35.84 | 57.32 | 1.5% |
| bf16 + softmax | ditto | 1726 | 14 | 33.31 | 54.79 | 1.5% |
| fp8 + flash_attention | ditto | 2038 | 21 | 29.64 | 51.12 | 19.9% |
| fp8 + softmax | ditto | 2084 | 22 | 27.12 | 48.62 | 22.6% |
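
For context on the first bullet above, here is a minimal, hypothetical sketch (not code from this PR) of why routing MPT's additive attention bias through torch.nn.functional.scaled_dot_product_attention as attn_mask can change sampled tokens even though the math is equivalent: the fused kernel evaluates the same expression with different precision and operation ordering.

import torch
import torch.nn.functional as F

# Illustration only: MPT-style eager attention (explicit matmul + softmax with an
# additive bias) vs. the fused SDPA call taking the same bias as attn_mask.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
bias = torch.randn(1, 8, 16, 16)  # additive attention bias (alibi-style)

scale = q.shape[-1] ** -0.5
scores = torch.matmul(q, k.transpose(-1, -2)) * scale + bias
eager = torch.matmul(F.softmax(scores, dim=-1), v)

fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
# The difference is tiny but non-zero; during generation it can occasionally flip a token.
print(torch.max((eager - fused).abs()))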

flash_attention_recompute: Optional[bool] = False,
):
"""
Copied from MptAttention.forward: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py

Contributor:
At least part of the code looks like it was copied from a newer version than v4.32.0. Could you verify and update this comment?

atakaha (Author):
This line is original, line 123

Contributor:
Still, please update to the latest version, as we are copying code from the latest.

atakaha (Author):
Merged MptAttention.forward from the r4.44.1 base.

@atakaha atakaha marked this pull request as draft August 15, 2024 01:09
@atakaha atakaha force-pushed the mpt_fp8 branch 2 times, most recently from 30d8302 to f7704e4 on August 15, 2024 01:22
@atakaha atakaha closed this Aug 15, 2024
@atakaha atakaha reopened this Aug 15, 2024
@atakaha atakaha marked this pull request as ready for review August 15, 2024 15:21
@atakaha atakaha force-pushed the mpt_fp8 branch 3 times, most recently from 5f901df to b3f729f on August 21, 2024 19:05
Add Softmax and FusedSDPA
Update GaudiMptAttention forward to r4.44.1 base

Co-authored-by: Thanaji Rao Thakkalapelli <[email protected]>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

atakaha (Contributor, Author) commented Sep 3, 2024

@regisss, @mandy-li, please review this PR.

mandy-li (Collaborator) commented Sep 3, 2024

@atakaha, please use make style to fix the code formatting issue.

atakaha (Contributor, Author) commented Sep 3, 2024

> @atakaha, please use make style to fix the code formatting issue.

@mandy-li, the make style issue is fixed.

mandy-li (Collaborator) commented Sep 5, 2024

@libinta, please have somebody review this PR. Thanks.

optimum/habana/transformers/models/mpt/modeling_mpt.py (outdated review thread, resolved)
attn_weights = None
else:
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale


atakaha (Author):

I tested it, and throughput slowed down. The trace showed that when MPT uses it, that path does not run in fp8, while plain torch.matmul does use fp8. That is why MPT does not use the fp8 matmul kernel this time.

Collaborator:

That's weird. I wonder if fp8 softmax is actually forcing the matmul to also run in fp8. Do you see good accuracy with fp8 softmax?

atakaha (Author):

class Matmul(torch.nn.Module):
    # Thin wrapper around torch.matmul; registering it as an nn.Module lets the
    # quantization toolkit observe and patch the matmul call.
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)
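
For reference, a hedged sketch (not MPT's merged code) of how other optimum-habana models register this wrapper as submodules so the quantization toolkit can observe the matmuls; MPT deliberately keeps plain torch.matmul calls instead, for the reasons given below.

import torch

class Attn(torch.nn.Module):
    # Hypothetical fragment reusing the Matmul wrapper above; attribute names
    # follow the matmul_qk convention mentioned in the list below.
    def __init__(self, softmax_scale=0.125):
        super().__init__()
        self.matmul_qk = Matmul()   # scores = Q @ K^T
        self.matmul_av = Matmul()   # context = P @ V
        self.softmax_scale = softmax_scale

    def forward(self, q, k, v):
        scores = self.matmul_qk(q, k.transpose(-1, -2)) * self.softmax_scale
        probs = torch.nn.functional.softmax(scores, dim=-1)
        return self.matmul_av(probs, v)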

Comparing FP8 profile data between GPT-J, Mistral, and MPT:

  • GPT-J and Mistral spend <70% of the time in MME, but MPT spends ~30%.
  • GPT-J and Mistral call the index_copy_fwd_hf8 and cast_bf16_to_hf8 kernels, but these kernel calls do not appear in MPT. When matmul_qk and matmul_qv are added to the blocklist of maxabs_quant.json for MPT, it then calls the same kernels as GPT-J/Mistral (a hedged sketch of such a blocklist entry follows below).
    This is the reason we didn't add this.
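
A hedged sketch of what such a blocklist entry in maxabs_quant.json might look like, written out in Python so it can be annotated (the surrounding fields mirror a typical config and may differ by INC version; only the blocklist names are the point):

import json

# Hypothetical config sketch; "matmul_qk" / "matmul_qv" follow the module names
# mentioned above and would need to match the actual registered module names.
config = {
    "method": "HOOKS",
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "blocklist": {"types": [], "names": ["matmul_qk", "matmul_qv"]},
    "dump_stats_path": "./hqt_output/measure",
}
print(json.dumps(config, indent=4))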

Collaborator:

Also, can you make sure your change doesn't break anything for the training case, since the model file is used for both inference and training?

if use_flash_attention and FusedSDPA:
    import habana_frameworks.torch.hpu as ht

    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):

Collaborator:

Can you also check other models to see whether this enable_recompute should be set based on fp8/bf16 and q_len?

atakaha (Author):

If enable_recompute is set the same way as in other models, FP8 throughput drops to half. In the trace, softmax_stage1_fwd_f32 appears and spends a lot of time; this kernel does not show up in the enable_recompute = False case.
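
For illustration only (this is not the merged code, and the exact heuristic other models use may differ): a hypothetical sketch of tying enable_recompute to the phase, which is the kind of alternative discussed here. In this PR, enable_recompute stays driven solely by the flash_attention_recompute flag because the alternative halved FP8 throughput.

import habana_frameworks.torch.hpu as ht  # requires an HPU environment

def sdpa_context(flash_attention_recompute, q_len):
    # Hypothetical gating: enable recompute only during prefill (q_len > 1),
    # never for single-token decode steps (q_len == 1).
    enable_recompute = flash_attention_recompute and q_len > 1
    return ht.sdp_kernel(enable_recompute=enable_recompute)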

Collaborator:

what's your fp8 command with flash_attention?

Collaborator:

Did you test any case that enables causal_mask and enable_recompute?
How about a longer prompt? Usually causal_mask shows better perf for long prompts.

atakaha (Author) commented Sep 10, 2024:

> what's your fp8 command with flash_attention?

The command line is:

QUANT_CONFIG=./quantization_config/maxabs_quant.json \
python run_generation.py \
--model_name_or_path mosaicml/mpt-7b \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--max_input_tokens 128 \
--max_new_tokens 128 \
--batch_size 128 \
--bf16 \
--use_flash_attention
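
As a reminder (a hedged sketch assuming the usual two-step optimum-habana FP8 flow, and that maxabs_measure.json sits next to maxabs_quant.json), a calibration run with the measurement config normally precedes the quantized run above so that statistics exist for maxabs_quant.json to consume:

# Assumed measurement step; adjust the config path to your checkout.
QUANT_CONFIG=./quantization_config/maxabs_measure.json \
python run_generation.py \
--model_name_or_path mosaicml/mpt-7b \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--max_input_tokens 128 \
--max_new_tokens 128 \
--batch_size 128 \
--bf16 \
--use_flash_attention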

atakaha (Author):

> Did you test any case that enables causal_mask and enable_recompute? How about a longer prompt? Usually causal_mask shows better perf for long prompts.

No, I haven't tested these cases.

        super().__init__()

    def forward(self, x, dim=None, invAttnHead=None):
        return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)

Contributor:

Since INC is enabled, please use torch.nn.functional.softmax, as it is supported by INC for quantization:
https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html#supported-functions

Suggested change:
- return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)
+ return torch.nn.functional.softmax(x, dim)
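
After applying the suggestion, the wrapper module would look roughly like this (a minimal sketch; the class name is illustrative):

import torch

class Softmax(torch.nn.Module):
    # Routing through torch.nn.functional.softmax lets INC patch this module for
    # FP8 quantization instead of calling the HPU-specific op directly.
    def forward(self, x, dim=None, invAttnHead=None):
        return torch.nn.functional.softmax(x, dim)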

atakaha (Author):

Updated the code with your suggestion.

jiminha (Collaborator) left a comment:

Looks good.
Could you investigate further, though, whether causal_mask and recompute need to be enabled differently for long prompts, and submit a separate patch if needed?

@jiminha jiminha added the run-test label (Run CI for PRs from external contributors) on Sep 12, 2024

The code quality check failed, please run make style.

atakaha (Contributor, Author) commented Sep 12, 2024

> The code quality check failed, please run make style.

Fixed the ruff error.

atakaha (Contributor, Author) commented Sep 12, 2024

> Looks good. Could you investigate further, though, whether causal_mask and recompute need to be enabled differently for long prompts, and submit a separate patch if needed?

Sure, I will.

atakaha (Contributor, Author) commented Sep 16, 2024

@regisss, please review this PR.

@regisss regisss merged commit b75216c into huggingface:main Sep 23, 2024
3 of 4 checks passed
Labels: run-test (Run CI for PRs from external contributors)