Improvements to attention mask handling
cbalioglu committed Oct 20, 2023
1 parent e969622 commit fa7886d
Showing 15 changed files with 438 additions and 326 deletions.
9 changes: 2 additions & 7 deletions src/fairseq2/models/nllb/builder.py
@@ -253,9 +253,7 @@ def build_decoder_layer(self) -> TransformerDecoderLayer:
         """Build a Transformer decoder layer."""
         self_attn = self.build_attention(self.config.num_decoder_attn_heads)

-        encoder_decoder_attn = self.build_attention(
-            self.config.num_decoder_attn_heads, encoder_decoder=True
-        )
+        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)

         ffn = self.build_ffn()
@@ -269,16 +267,13 @@ def build_decoder_layer(self) -> TransformerDecoderLayer:
             dtype=self.dtype,
         )

-    def build_attention(
-        self, num_heads: int, encoder_decoder: bool = False
-    ) -> MultiheadAttention:
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
         """Build a Transformer multi-head attention layer."""
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)

         return StandardMultiheadAttention(
             self.config.model_dim,
             num_heads,
-            static_kv=encoder_decoder,
             sdpa=sdpa,
             device=self.device,
             dtype=self.dtype,
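For reference, a minimal sketch of the simplified construction (dimension and dropout values are placeholders, and the import path assumes both symbols are re-exported by fairseq2.nn.transformer): self-attention and encoder-decoder attention are now built identically, with no static_kv/encoder_decoder flag.

from fairseq2.nn.transformer import StandardMultiheadAttention, create_default_sdpa

sdpa = create_default_sdpa(attn_dropout_p=0.1)

# Same call for both attention types; the layer no longer needs to be told
# whether its keys and values come from the encoder output.
self_attn = StandardMultiheadAttention(1024, 16, sdpa=sdpa)
encoder_decoder_attn = StandardMultiheadAttention(1024, 16, sdpa=sdpa)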
7 changes: 2 additions & 5 deletions src/fairseq2/models/s2t_transformer/builder.py
@@ -404,7 +404,7 @@ def build_decoder_layer(self) -> TransformerDecoderLayer:
         """Build a Transformer decoder layer."""
         self_attn = self.build_decoder_attention()

-        encoder_decoder_attn = self.build_decoder_attention(encoder_decoder=True)
+        encoder_decoder_attn = self.build_decoder_attention()

         ffn = self.build_ffn()
@@ -450,16 +450,13 @@ def build_encoder_attention(self) -> MultiheadAttention:
             dtype=self.dtype,
         )

-    def build_decoder_attention(
-        self, encoder_decoder: bool = False
-    ) -> MultiheadAttention:
+    def build_decoder_attention(self) -> MultiheadAttention:
         """Build a Transformer decoder multi-head attention layer."""
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)

         return StandardMultiheadAttention(
             self.config.model_dim,
             self.config.num_decoder_attn_heads,
-            static_kv=encoder_decoder,
             sdpa=sdpa,
             device=self.device,
             dtype=self.dtype,
4 changes: 2 additions & 2 deletions src/fairseq2/nn/padding.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Optional, Sequence, Tuple, cast
+from typing import Any, Optional, Sequence, Tuple, cast

 import torch
 from torch import Tensor
@@ -76,7 +76,7 @@ def to_padding_mask(seq_lens: Tensor, batch_seq_len: int) -> Tensor:


 def apply_padding_mask(
-    seqs: Tensor, padding_mask: Optional[PaddingMask], fill_value: float = 0.0
+    seqs: Tensor, padding_mask: Optional[PaddingMask], fill_value: Any = 0
 ) -> Tensor:
     """Apply the specified padding mask to ``seqs``.
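As context for the widened fill_value type, here is a minimal plain-PyTorch sketch of what applying a padding mask amounts to (an illustration under assumed shapes, not fairseq2's exact implementation): positions past each sequence's length are overwritten with fill_value, which may now be an integer or boolean rather than only a float.

import torch

seq_lens = torch.tensor([4, 2])                # lengths of two sequences
batch_seq_len = 4

# True where a position holds real data, False where it is padding.
mask = torch.arange(batch_seq_len) < seq_lens.unsqueeze(1)   # (N, S)

seqs = torch.rand(2, batch_seq_len, 8)         # (N, S, M)
fill_value = 0                                 # an int fill is now accepted

masked = seqs.where(mask.unsqueeze(-1), torch.tensor(fill_value, dtype=seqs.dtype))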
36 changes: 18 additions & 18 deletions src/fairseq2/nn/position_encoder.py
@@ -63,12 +63,12 @@ def forward(
             Same as ``seqs``.
         """
         if self.max_seq_len is not None:
-            if not self.training and state_bag is not None:
-                start = state_bag.step
+            if self.training or state_bag is None:
+                start_step = 0
             else:
-                start = 0
+                start_step = state_bag.step

-            if (seq_len := start + seqs.size(-2)) > self.max_seq_len:
+            if (seq_len := start_step + seqs.size(-2)) > self.max_seq_len:
                 raise ValueError(
                     f"The input sequence length must be less than or equal to the maximum sequence length ({self.max_seq_len}), but is {seq_len} instead."
                 )
@@ -196,13 +196,13 @@ def reset_non_persistent_buffers(self) -> None:
         l_half = self.freqs[:, :num_sin]
         r_half = self.freqs[:, num_sin:]

-        start = self._sin_offset
+        start_step = self._sin_offset

         assert self.max_seq_len is not None

         # (S)
         steps = torch.arange(
-            start, start + self.max_seq_len, device=device, dtype=dtype
+            start_step, start_step + self.max_seq_len, device=device, dtype=dtype
         )

         # (E)
@@ -229,12 +229,12 @@ def _do_forward(
         """:meta private:"""
         seq_len = seqs.size(-2)

-        if not self.training and state_bag is not None:
-            start = state_bag.step
+        if self.training or state_bag is None:
+            start_step = 0
         else:
-            start = 0
+            start_step = state_bag.step

-        fp32_seqs = seqs.float() + self.freqs[start : start + seq_len]
+        fp32_seqs = seqs.float() + self.freqs[start_step : start_step + seq_len]

         return fp32_seqs.type_as(seqs)

@@ -291,13 +291,13 @@ def _do_forward(
         """:meta private:"""
         seq_len = seqs.size(-2)

-        if not self.training and state_bag is not None:
-            start = state_bag.step
+        if self.training or state_bag is None:
+            start_step = 0
         else:
-            start = 0
+            start_step = state_bag.step

         steps = torch.arange(
-            start, start + seq_len, device=seqs.device, dtype=torch.int64
+            start_step, start_step + seq_len, device=seqs.device, dtype=torch.int64
         )

         return seqs + embedding(steps, self.weight)
@@ -373,17 +373,17 @@ def _do_forward(
         """:meta private:"""
         seq_len = seqs.size(-2)

-        if not self.training and state_bag is not None:
-            start = state_bag.step
+        if self.training or state_bag is None:
+            start_step = 0
         else:
-            start = 0
+            start_step = state_bag.step

         # (*, S, E) -> (*, S, E / 2, 2)
         seqs = seqs.unflatten(-1, (-1, 2))

         complex_seqs = torch.view_as_complex(seqs.float())

-        complex_seqs = complex_seqs * self.freqs[start : start + seq_len]
+        complex_seqs = complex_seqs * self.freqs[start_step : start_step + seq_len]

         # (*, S, E / 2, 2) -> (*, S, E)
         fp32_seqs = torch.view_as_real(complex_seqs).flatten(-2)
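The renaming of start to start_step tracks the incremental-decoding offset more clearly. A minimal standalone sketch of that logic (plain PyTorch, with a stand-in freqs table and a plain integer in place of state_bag.step; not fairseq2 code): during training, or when no state bag is given, positions start at zero; during incremental decoding they continue from the number of steps already generated.

from typing import Optional

import torch

max_seq_len, model_dim = 16, 8
freqs = torch.rand(max_seq_len, model_dim)  # stand-in for precomputed encodings


def encode(seqs: torch.Tensor, *, training: bool, step: Optional[int]) -> torch.Tensor:
    if training or step is None:
        start_step = 0          # full-sequence forward pass
    else:
        start_step = step       # continue from the already-decoded prefix

    seq_len = seqs.size(-2)

    return seqs + freqs[start_step : start_step + seq_len]


# Full forward over 5 positions, then one incremental step at position 5.
full = encode(torch.zeros(1, 5, model_dim), training=True, step=None)
step5 = encode(torch.zeros(1, 1, model_dim), training=False, step=5)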
8 changes: 4 additions & 4 deletions src/fairseq2/nn/transformer/__init__.py
@@ -18,13 +18,13 @@
     AttentionMaskFactory as AttentionMaskFactory,
 )
 from fairseq2.nn.transformer.attention_mask import (
-    CustomAttentionMask as CustomAttentionMask,
+    CausalAttentionMask as CausalAttentionMask,
 )
 from fairseq2.nn.transformer.attention_mask import (
-    GlobalCausalAttentionMask as GlobalCausalAttentionMask,
+    CausalAttentionMaskFactory as CausalAttentionMaskFactory,
 )
 from fairseq2.nn.transformer.attention_mask import (
-    GlobalCausalAttentionMaskFactory as GlobalCausalAttentionMaskFactory,
+    CustomAttentionMask as CustomAttentionMask,
 )
 from fairseq2.nn.transformer.decoder import (
     DecoderLayerOutputHook as DecoderLayerOutputHook,
@@ -69,7 +69,7 @@
     AttentionWeightHook as AttentionWeightHook,
 )
 from fairseq2.nn.transformer.multihead_attention import (
-    GlobalAttentionState as GlobalAttentionState,
+    FullAttentionState as FullAttentionState,
 )
 from fairseq2.nn.transformer.multihead_attention import (
     LocalAttentionState as LocalAttentionState,
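Summarizing the renames visible in the import list above, the public mask and state symbols re-exported by fairseq2.nn.transformer after this commit are:

from fairseq2.nn.transformer import (
    CausalAttentionMask,         # previously GlobalCausalAttentionMask
    CausalAttentionMaskFactory,  # previously GlobalCausalAttentionMaskFactory
    CustomAttentionMask,
    FullAttentionState,          # previously GlobalAttentionState
)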
50 changes: 26 additions & 24 deletions src/fairseq2/nn/transformer/attention.py
@@ -16,10 +16,7 @@
 from torch.nn.functional import dropout, softmax

 from fairseq2.nn.padding import PaddingMask
-from fairseq2.nn.transformer.attention_mask import (
-    AttentionMask,
-    GlobalCausalAttentionMask,
-)
+from fairseq2.nn.transformer.attention_mask import AttentionMask, CausalAttentionMask
 from fairseq2.typing import finaloverride
 from fairseq2.utils.version import is_pt2_or_greater
@@ -43,7 +40,7 @@ def __init__(self, *, attn_dropout_p: float = 0.0) -> None:
     @abstractmethod
     def forward(
         self,
-        queries: Tensor,
+        seqs: Tensor,
         keys: Tensor,
         key_padding_mask: Optional[PaddingMask],
         values: Tensor,
@@ -52,10 +49,10 @@ def forward(
         needs_weights: bool = False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
-        :param queries:
-            The queries. *Shape:* :math:`(N,H,S,K)`, where :math:`N` is the
-            batch size, :math:`H` is the number of heads, :math:`S` is the
-            sequence length, and :math:`K` is the key size.
+        :param seqs:
+            The sequences to query. *Shape:* :math:`(N,H,S,K)`, where :math:`N`
+            is the batch size, :math:`H` is the number of heads, :math:`S` is
+            the sequence length, and :math:`K` is the key size.
         :param keys:
             The keys. *Shape:* :math:`(N,H,S_{kv},K)`, where :math:`N` is the
             batch size, :math:`H` is the number of heads, :math:`S_{kv}` is the
@@ -107,17 +104,17 @@ def __init__(self, *, attn_dropout_p: float = 0.0) -> None:
     @finaloverride
     def forward(
         self,
-        queries: Tensor,
+        seqs: Tensor,
         keys: Tensor,
         key_padding_mask: Optional[PaddingMask],
         values: Tensor,
         *,
         attn_mask: Optional[AttentionMask] = None,
         needs_weights: bool = False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        if not queries.is_cuda:
+        if not seqs.is_cuda:
             return _naive_scaled_dot_product_attention(
-                queries,
+                seqs,
                 keys,
                 key_padding_mask,
                 values,
@@ -136,7 +133,7 @@ def forward(
                 self._has_warned = True

             return _naive_scaled_dot_product_attention(
-                queries,
+                seqs,
                 keys,
                 key_padding_mask,
                 values,
@@ -160,26 +157,31 @@ def forward(
             mask = mask[:, None, None, :]

             # (N, 1, 1, S_kv) -> (N, H, S, S_kv)
-            mask = mask.expand(-1, queries.size(1), queries.size(2), -1)
+            mask = mask.expand(-1, seqs.size(1), seqs.size(2), -1)

             if attn_mask is not None:
                 # ([H], S, S_kv)
                 m = attn_mask.materialize()

                 # (N, H, S, S_kv)
                 mask = torch.where(mask, m, -torch.inf)
-        elif isinstance(attn_mask, GlobalCausalAttentionMask):
-            mask = None
+        elif isinstance(attn_mask, CausalAttentionMask):
+            # PyTorch SDPA supports only full causal attention.
+            if attn_mask.attn_window_len is None:
+                mask = None

-            is_causal = True
+                is_causal = True
+            else:
+                # ([H], S, S_kv)
+                mask = attn_mask.materialize()
         elif attn_mask is not None:
             # ([H], S, S_kv)
             mask = attn_mask.materialize()
         else:
             mask = None

         attn = F.scaled_dot_product_attention(  # type: ignore[attr-defined]
-            queries,
+            seqs,
             keys,
             values,
             attn_mask=mask,
@@ -197,7 +199,7 @@ class NaiveSDPA(SDPA):
     @finaloverride
     def forward(
         self,
-        queries: Tensor,
+        seqs: Tensor,
         keys: Tensor,
         key_padding_mask: Optional[PaddingMask],
         values: Tensor,
@@ -206,7 +208,7 @@ def forward(
         needs_weights: bool = False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         return _naive_scaled_dot_product_attention(
-            queries,
+            seqs,
             keys,
             key_padding_mask,
             values,
@@ -218,7 +220,7 @@ def forward(


 def _naive_scaled_dot_product_attention(
-    queries: Tensor,
+    seqs: Tensor,
     keys: Tensor,
     key_padding_mask: Optional[PaddingMask],
     values: Tensor,
@@ -228,9 +230,9 @@ def _naive_scaled_dot_product_attention(
     training: bool,
 ) -> Tuple[Tensor, Optional[Tensor]]:
     # (N, H, S, K) @ (N, H, K, S_kv) = (N, H, S, S_kv)
-    attn_weights = torch.matmul(queries, keys.transpose(-1, -2))
+    attn_weights = torch.matmul(seqs, keys.transpose(-1, -2))

-    attn_weights = attn_weights * (queries.size(-1) ** -0.5)
+    attn_weights = attn_weights * (seqs.size(-1) ** -0.5)

     if attn_mask is not None:
         # (S, S_kv)
@@ -251,7 +253,7 @@ def _naive_scaled_dot_product_attention(
     # For numerical stability run in single precision.
     attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)

-    attn_weights = attn_weights.type_as(queries)
+    attn_weights = attn_weights.type_as(seqs)

     if training and dropout_p > 0.0:
         attn_weights = dropout(attn_weights, dropout_p)
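For readers unfamiliar with the fallback path, a minimal standalone sketch of the math in _naive_scaled_dot_product_attention (plain PyTorch with assumed shapes, omitting the mask and dropout handling):

import torch

N, H, S, S_kv, K = 2, 4, 5, 7, 16
seqs = torch.rand(N, H, S, K)     # the renamed "queries" argument
keys = torch.rand(N, H, S_kv, K)
values = torch.rand(N, H, S_kv, K)

# (N, H, S, K) @ (N, H, K, S_kv) = (N, H, S, S_kv), scaled by 1/sqrt(K)
attn_weights = torch.matmul(seqs, keys.transpose(-1, -2)) * (seqs.size(-1) ** -0.5)

# Softmax in single precision for numerical stability, then cast back.
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(seqs)

# (N, H, S, S_kv) @ (N, H, S_kv, K) = (N, H, S, K)
attn = torch.matmul(attn_weights, values)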