Improve padding and attention mask handling #104
Merged
This PR includes several closely-coupled changes:
- New `PaddingMask` and `AttentionMask` types that replace all uses of `seq_lens` and the tensor-based `padding_mask` and `self_attn_mask` parameters. This is also the main reason why this PR touches many files, since those parameters are used in many places in the code base. The main benefit of these two new types is the lazy initialization of masks, which lets us avoid materializing them at all whenever possible (e.g. when a global causal attention mask is used with PyTorch SDPA or xformers, we entirely skip constructing the mask). This can significantly improve performance in the right settings and also reduces memory use (again, sometimes significantly with very large context lengths). A minimal sketch of the idea is shown after this list.
- `MultiheadAttention` also needed to be refactored. Mask handling is moved down to the `SDPA` implementations, since each SDPA has varying support for different mask types (see the dispatch sketch after this list). As part of this work, `add_bias_kv` and `add_zero_attn` are also removed. We initially introduced them for feature parity with fairseq MHA, but it turns out that these two parameters are not used anywhere. We can consider introducing them again in the future if requested, but they add needless complexity to mask handling, and both were superseded quite a while ago by superior techniques (e.g. pre-LN).
- Tested with LLaMA 2 and SeamlessM4T models in various settings and verified bitwise output parity.
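
For illustration, here is a minimal sketch of how lazily-materialized mask types can work. The names below (`PaddingMaskSketch`, `CausalAttentionMaskSketch`, `materialize()`) are hypothetical stand-ins, not the actual types introduced by this PR; the point is only that the full mask tensor is built on demand, or never built at all.

```python
from typing import Optional

import torch
from torch import Tensor


class PaddingMaskSketch:
    """Hypothetical stand-in: wraps ``seq_lens`` and builds the boolean mask lazily."""

    def __init__(self, seq_lens: Tensor, batch_seq_len: int) -> None:
        self.seq_lens = seq_lens
        self.batch_seq_len = batch_seq_len
        self._materialized: Optional[Tensor] = None

    def materialize(self) -> Tensor:
        # (N, S) boolean mask: True for real positions, False for padding.
        if self._materialized is None:
            steps = torch.arange(self.batch_seq_len, device=self.seq_lens.device)
            self._materialized = steps[None, :] < self.seq_lens[:, None]
        return self._materialized


class CausalAttentionMaskSketch:
    """Hypothetical stand-in: a global causal mask, materialized only if a backend asks."""

    def __init__(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        self.seq_len = seq_len
        self.device = device
        self.dtype = dtype

    def materialize(self) -> Tensor:
        # (S, S) additive float mask with -inf strictly above the diagonal.
        full = torch.full(
            (self.seq_len, self.seq_len), float("-inf"), device=self.device, dtype=self.dtype
        )
        return torch.triu(full, diagonal=1)
```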
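
And a sketch of the kind of dispatch that moving mask handling into the `SDPA` implementations enables: a fused backend can be told that the mask is causal and never receive a materialized tensor, while other mask types are materialized on demand. The wrapper below (reusing `CausalAttentionMaskSketch` from the previous sketch) is again illustrative, not the actual fairseq2 API, and glosses over shape/broadcasting details of non-causal masks.

```python
import torch.nn.functional as F
from torch import Tensor


def torch_sdpa_sketch(q: Tensor, k: Tensor, v: Tensor, mask=None) -> Tensor:
    """Hypothetical SDPA wrapper that decides how (or whether) to materialize the mask."""
    if mask is None:
        return F.scaled_dot_product_attention(q, k, v)

    # A plain global causal mask never has to be materialized; the fused
    # kernel applies causality internally.
    if isinstance(mask, CausalAttentionMaskSketch):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Any other mask type is materialized on demand (proper broadcasting to
    # (N, H, L, S) is omitted here for brevity).
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask.materialize())
```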