Improve padding and attention mask handling #104

Merged: cbalioglu merged 2 commits into main from mask on Oct 16, 2023

Conversation

cbalioglu
Contributor

This PR includes several closely coupled changes:

  1. The major change is the introduction of the new PaddingMask and AttentionMask types, which replace all uses of seq_lens and the tensor-based padding_mask and self_attn_mask parameters. This is also the main reason why this PR touches so many files, since those parameters are used throughout the code base. The main benefit of the two new types is lazy initialization: masks are not materialized at all whenever possible (e.g. when a global causal attention mask is used with PyTorch SDPA or xFormers, we skip constructing the mask entirely); see the first sketch after this list. This can significantly improve performance in the right settings and also reduces memory use, sometimes substantially with very large context lengths.
  2. MultiheadAttention also had to be refactored. Mask handling is moved down to the SDPA implementations, since each SDPA has varying support for different mask types; see the second sketch after this list. As part of this work, add_bias_kv and add_zero_attn are removed as well. We initially introduced them for feature parity with the fairseq MHA, but it turns out that these two parameters are not used anywhere. We can consider reintroducing them in the future if requested, but they add needless complexity to mask handling, and both were superseded quite a while ago by superior techniques (e.g. pre-LN).
  3. Various nit updates at several places in the code base (such as renaming some local variables and updating docstrings), mainly related to the first two points.
  4. Fixes the causal attention mask to support bfloat16.
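
The following is a minimal sketch, not the actual fairseq2 API, of how such a lazily initialized padding mask can work: the wrapper keeps only the sequence lengths around, and the full boolean mask tensor is built (and cached) the first time a consumer actually asks for it. The class and method names below are illustrative.

```python
from typing import Optional

import torch
from torch import Tensor


class LazyPaddingMask:
    """Wraps `seq_lens` and materializes a boolean padding mask on demand."""

    def __init__(self, seq_lens: Tensor, batch_seq_len: int) -> None:
        self.seq_lens = seq_lens            # (N,) length of each sequence in the batch
        self.batch_seq_len = batch_seq_len  # padded sequence length S
        self._materialized: Optional[Tensor] = None

    def materialize(self) -> Tensor:
        """Return an (N, S) boolean mask; True marks real (non-pad) positions."""
        if self._materialized is None:
            positions = torch.arange(self.batch_seq_len, device=self.seq_lens.device)
            self._materialized = positions[None, :] < self.seq_lens[:, None]

        return self._materialized


# A kernel that natively understands sequence lengths never calls
# `materialize()`, so no (N, S) tensor is allocated on that path.
mask = LazyPaddingMask(torch.tensor([3, 5]), batch_seq_len=5)
print(mask.materialize())
```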

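Similarly, here is a hedged sketch of the idea behind pushing mask handling into the SDPA layer: a purely causal attention pattern is forwarded to torch.nn.functional.scaled_dot_product_attention via is_causal=True, so no (S, S) mask tensor is ever allocated, and a custom boolean mask is converted to an additive mask in the query's dtype (using the dtype's own minimum rather than a hard-coded float32 -inf), which is one way to keep bfloat16 working. The function name and conversion details are illustrative, not the code in this PR.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def sdpa_with_lazy_causal_mask(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    *,
    causal: bool,
    attn_mask: Optional[Tensor] = None,
) -> Tensor:
    if causal and attn_mask is None:
        # Fast path: the causal pattern is handled inside the kernel and
        # no mask tensor is ever constructed.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

    if attn_mask is not None and attn_mask.dtype == torch.bool:
        # Convert a boolean mask (True = attend) to an additive float mask
        # in the computation dtype so bfloat16 inputs behave correctly.
        float_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
        float_mask = float_mask.masked_fill(~attn_mask, torch.finfo(q.dtype).min)
        attn_mask = float_mask

    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```
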
Tested with LLaMA 2 and SeamlessM4T models in various settings and verified bitwise output parity.

@facebook-github-bot added the CLA Signed label on Oct 16, 2023
@cbalioglu
Contributor Author

Manually verified the clang-tidy check; the GitHub runner is (again) out of disk space.

@cbalioglu merged commit 4604d73 into main on Oct 16, 2023
18 of 19 checks passed
@cbalioglu deleted the mask branch on October 16, 2023 at 16:04