diff --git a/src/fairseq2/nn/incremental_state.py b/src/fairseq2/nn/incremental_state.py index 3e8602ce5..c99a34037 100644 --- a/src/fairseq2/nn/incremental_state.py +++ b/src/fairseq2/nn/incremental_state.py @@ -58,8 +58,8 @@ def increment_step(self, delta: int = 1) -> None: """Increment the step. This method should be called after every incremental evaluation (e.g. - beam search). It is used by modules to keep track of the position in - the sequence. + beam search) step. It is used by modules to keep track of the position + in the sequence. :param delta: The value by which to increment the step. diff --git a/src/fairseq2/nn/transformer/attention.py b/src/fairseq2/nn/transformer/attention.py index 254de5c10..db4b2ffd7 100644 --- a/src/fairseq2/nn/transformer/attention.py +++ b/src/fairseq2/nn/transformer/attention.py @@ -198,11 +198,11 @@ def _naive_scaled_dot_product_attention( needs_weights: bool, training: bool, ) -> Tuple[Tensor, Optional[Tensor]]: - queries = queries * (queries.size(-1) ** -0.5) - # (*, S, K) @ (*, K, S_kv) = (*, S, S_kv) attn_weights = torch.matmul(queries, keys.transpose(-1, -2)) + attn_weights = attn_weights * (queries.size(-1) ** -0.5) + if mask is not None: attn_weights = attn_weights + mask diff --git a/src/fairseq2/nn/transformer/multihead_attention.py b/src/fairseq2/nn/transformer/multihead_attention.py index 1d8c0716b..84d80093f 100644 --- a/src/fairseq2/nn/transformer/multihead_attention.py +++ b/src/fairseq2/nn/transformer/multihead_attention.py @@ -111,9 +111,13 @@ def register_attn_weight_hook(self, hook: "AttentionWeightHook") -> RemovableHan return handle - def _run_attn_weight_hooks(self, attn_weights: Tensor) -> None: + def _run_attn_weight_hooks(self, attn: Tensor, attn_weights: Tensor) -> None: """Run registered attention weight hooks. + :param attn: + The computed attention values. *Shape:* :math:`(N,S,V)`, where + :math:`N` is the batch size, :math:`S` is the sequence length, and + :math:`V` is the value size. :param attn_weights: The computed attention weights. *Shape:* :math:`(N,S,S_{kv})`, where :math:`N` is the batch size, :math:`S` is the sequence length, and @@ -122,7 +126,7 @@ def _run_attn_weight_hooks(self, attn_weights: Tensor) -> None: :meta public: """ for hook in self._attn_weight_hooks.values(): - hook(self, attn_weights) + hook(self, attn, attn_weights) def extra_repr(self) -> str: """:meta private:""" @@ -133,10 +137,16 @@ class AttentionWeightHook(Protocol): """Represents a hook to pass to :meth:`~MultiheadAttention.register_attn_weight_hook`.""" - def __call__(self, m: MultiheadAttention, attn_weights: Tensor) -> None: + def __call__( + self, m: MultiheadAttention, attn: Tensor, attn_weights: Tensor + ) -> None: """ :param m: The module that has computed the attention weights. + :param attn: + The computed attention values. *Shape:* :math:`(N,S,V)`, where + :math:`N` is the batch size, :math:`S` is the sequence length, and + :math:`V` is the value size. :param attn_weights: The computed attention weights. *Shape:* :math:`(N,S,S_{kv})`, where :math:`N` is the batch size, :math:`S` is the sequence length, and @@ -151,17 +161,19 @@ class StoreAttentionWeights: This class follows the :class:`AttentionWeightHook` protocol. """ - _storage: MutableSequence[Tensor] + _storage: MutableSequence[Tuple[Tensor, Tensor]] - def __init__(self, storage: MutableSequence[Tensor]) -> None: + def __init__(self, storage: MutableSequence[Tuple[Tensor, Tensor]]) -> None: """ :param storage: The storage in which to store attention weights. 
""" self._storage = storage - def __call__(self, m: "MultiheadAttention", attn_weights: Tensor) -> None: - self._storage.append(attn_weights) + def __call__( + self, m: MultiheadAttention, attn: Tensor, attn_weights: Tensor + ) -> None: + self._storage.append((attn, attn_weights)) @final @@ -209,7 +221,7 @@ def __init__( The number of key/value heads for Grouped Query Attention as described in :cite:t:`https://doi.org/10.48550/arXiv.2305.13245`. If ``None`` or set to ``num_heads``, it is equivalent to standard - Multi Head Attention (MHA). If set to 1, it is equivalent to Multi + Multi Head Attention (MHA); if set to 1, it is equivalent to Multi Query Attention (MQA). :param q_proj: The projection to apply to queries before computing attention. If @@ -387,107 +399,79 @@ def forward( key_padding_mask: Optional[Tensor] = None, state_bag: Optional[IncrementalStateBag] = None, ) -> Tensor: - # (*, M) -> (N, S, K_proj) - q = self.q_proj(queries) + # (N, S, M) -> (N, H, S, K_h) + q = self._project_q(queries, padding_mask, state_bag) if self.training or state_bag is None: - # (*, K) -> (N, S_kv, K_proj) - k = self.k_proj(keys) - # (*, V) -> (N, S_kv, V_proj) - v = self.v_proj(values) + # k: (N, S_kv, M) -> (N, H_kv, S_kv, K_h) + # v: (N, S_kv, M) -> (N, H_kv, S_kv, V_h) + k, v = self._project_kv(keys, key_padding_mask, values) else: - state = state_bag.get_state(self, MultiheadAttentionState) - encoder_decoder_attn = keys is values and keys is not queries - if encoder_decoder_attn: + static_state = state_bag.get_state(self, StaticMultiheadAttentionState) + # The K and V tensors of an encoder-decoder attention (i.e. the # projected encoder outputs) remain static during evaluation. - if state is not None: - k = state.prev_k - v = state.prev_v + if static_state is not None: + k = static_state.k + v = static_state.v else: - # (*, K) -> (N, S_kv, K_proj) - k = self.k_proj(keys) - # (*, V) -> (N, S_kv, V_proj) - v = self.v_proj(values) + # k: (N, S_kv, M) -> (N, H_kv, S_kv, K_h) + # v: (N, S_kv, M) -> (N, H_kv, S_kv, V_h) + k, v = self._project_kv(keys, key_padding_mask, values) - state_bag.set_state(self, MultiheadAttentionState(k, v)) + state_bag.set_state(self, StaticMultiheadAttentionState(k, v)) else: - # (*, K) -> (N, S_kv, K_proj) - k = self.k_proj(keys) - # (*, V) -> (N, S_kv, V_proj) - v = self.v_proj(values) + # k: (N, S_step, M) -> (N, H_kv, S_step, K_h) + # v: (N, S_step, M) -> (N, H_kv, S_step, V_h) + k, v = self._project_kv(keys, key_padding_mask, values, state_bag) - if state is not None: - k, v, key_padding_mask = state.append(k, v, key_padding_mask) - else: - state_bag.set_state( - self, MultiheadAttentionState(k, v, key_padding_mask) - ) - - # (N, S, Q_proj) -> (N, S, H, K_h) - q = q.unflatten(-1, (self.num_heads, -1)) - # (N, S_kv, K_proj) -> (N, S_kv, H_kv, K_h) - k = k.unflatten(-1, (self.num_key_value_heads, -1)) - # (N, S_kv, V_proj) -> (N, S_kv, H_kv, V_h) - v = v.unflatten(-1, (self.num_key_value_heads, -1)) - - # (N, S, H, K_h) -> (N, H, S, K_h) - q = q.transpose(1, 2) - # (N, S_kv, H_kv, K_h) -> (N, H_kv, S_kv, K_h) - k = k.transpose(1, 2) - # (N, S_kv, H_kv, V_h) -> (N, H_kv, S_kv, V_h) - v = v.transpose(1, 2) - - # With Grouped Query Attention, each key/value head is repeated - # `num_query_groups` times to match `num_heads`. 
- num_query_groups = self.num_heads // self.num_key_value_heads - - # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h) - k = torch.repeat_interleave(k, dim=1, repeats=num_query_groups) - # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h) - v = torch.repeat_interleave(v, dim=1, repeats=num_query_groups) - - # (N, H, S, K_h) -> (N x H, S, K_h) - q = q.flatten(0, 1) - # (N, H, S_kv, K_h) -> (N x H, S_kv, K_h) - k = k.flatten(0, 1) - # (N, H, S_kv, V_h) -> (N x H, S_kv, V_h) - v = v.flatten(0, 1) + state = state_bag.get_state(self, MultiheadAttentionState) + if state is None: + state = MultiheadAttentionState(k, v) - if self.pos_encoder is not None: - q = self.pos_encoder(q, padding_mask, state_bag=state_bag) - k = self.pos_encoder(k, key_padding_mask) + state_bag.set_state(self, state) + + # k: (N, H_kv, S_kv, K_h) + # v: (N, H_kv, S_kv, V_h) + k, v, key_padding_mask = state.append(k, v, key_padding_mask) + + # With Grouped Query Attention, each key/value head is repeated. + if (num_query_groups := self.num_heads // self.num_key_value_heads) > 1: + # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h) + k = torch.repeat_interleave(k, dim=1, repeats=num_query_groups) + # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, V_h) + v = torch.repeat_interleave(v, dim=1, repeats=num_query_groups) mask_pad = 0 if self.bias_k is not None and self.bias_v is not None: batch_size = keys.size(0) - # (H, 1, K_proj) -> (N x H, 1, K_proj) - bias_k = self.bias_k.repeat(batch_size, 1, 1) + # (H, 1, K_proj) -> (N, H, 1, K_proj) + bias_k = self.bias_k.expand(batch_size, -1, -1, -1) # (H, 1, V_proj) -> (N x H, 1, V_proj) - bias_v = self.bias_v.repeat(batch_size, 1, 1) + bias_v = self.bias_v.expand(batch_size, -1, -1, -1) - # (N x H, S_kv, K_h) -> (N x H, S_kv + 1, K_h) - k = torch.cat([k, bias_k], dim=1) - # (N x H, S_kv, V_h) -> (N x H, S_kv + 1, V_h) - v = torch.cat([v, bias_v], dim=1) + # (N, H, S_kv, K_h) -> (N, H, S_kv + 1, K_h) + k = torch.cat([k, bias_k], dim=2) + # (N, H, S_kv, V_h) -> (N, H, S_kv + 1, V_h) + v = torch.cat([v, bias_v], dim=2) mask_pad += 1 if self.add_zero_attn: - # (N x H, S_kv, K_h) -> (N x H, S_kv + 1, K_h) - k = torch.cat([k, k.new_zeros((k.size(0), 1, k.size(2)))], dim=1) - # (N x H, S_kv, V_h) -> (N x H, S_kv + 1, V_h) - v = torch.cat([v, v.new_zeros((v.size(0), 1, v.size(2)))], dim=1) + # (N, H, S_kv, K_h) -> (N, H, S_kv + 1, K_h) + k = torch.cat([k, k.new_zeros((k.size(0), k.size(1), 1, k.size(3)))], dim=2) + # (N, H, S_kv, V_h) -> (N, H, S_kv + 1, V_h) + v = torch.cat([v, v.new_zeros((v.size(0), v.size(1), 1, v.size(3)))], dim=2) mask_pad += 1 if mask_pad > 0: if attn_mask is not None: - # (T, S_kv) -> (T, S_kv + mask_pad) + # (S, S_kv) -> (S, S_kv + mask_pad) attn_mask = pad(attn_mask, (0, mask_pad)) if key_padding_mask is not None: @@ -496,7 +480,8 @@ def forward( if key_padding_mask is not None: # (N, S_kv) -> (N, 1, 1, S_kv) - key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) + key_padding_mask = key_padding_mask[:, None, None, :] + # (N, 1, 1, S_kv) -> (N, H, 1, S_kv) key_padding_mask = key_padding_mask.expand(-1, self.num_heads, -1, -1) @@ -504,25 +489,19 @@ def forward( # (N, H, 1, S_kv) attn_mask = key_padding_mask else: - # (N, H, 1, S_kv) + ([H,], S, S_kv) = (N, H, S, S_kv) + # (N, H, 1, S_kv) + ([H], S, S_kv) = (N, H, S, S_kv) attn_mask = key_padding_mask + attn_mask - # (N, H, S, S_kv) -> (N x H, 1, S_kv) - attn_mask = attn_mask.flatten(0, 1) - needs_weights = len(self._attn_weight_hooks) > 0 - # attn: (N x H, S, V_h) - # attn_weights: (N x H, S, S_kv) + # attn: (N, H, S, V_h) + # 
attn_weights: (N, H, S, S_kv) attn, attn_weights = self.sdpa( q, k, v, mask=attn_mask, needs_weights=needs_weights ) if attn_weights is not None: - self._run_attn_weight_hooks(attn_weights) - - # (N x H, S, V_h) -> (N, H, S, V_h) - attn = attn.unflatten(0, (-1, self.num_heads)) + self._run_attn_weight_hooks(attn, attn_weights) # (N, H, S, V_h) -> (N, S, H, V_h) attn = attn.permute(0, 2, 1, 3) @@ -531,12 +510,51 @@ def forward( attn = torch.einsum("nshv,h->nshv", attn, self.head_scale_weight) # (N, S, H, V_h) -> (N, S, V_proj) - attn = attn.flatten(-2, -1) + attn = attn.flatten(2, 3) # (N, S, V_proj) -> (N, S, M) attn = self.output_proj(attn) - return attn # type: ignore + return attn # type: ignore[no-any-return] + + def _project_q( + self, + queries: Tensor, + padding_mask: Optional[Tensor], + state_bag: Optional[IncrementalStateBag] = None, + ) -> Tensor: + # (N, S, M) -> (N, S, K_proj) + q = self.q_proj(queries) + + # (N, S, K_proj) -> (N, H, S, K_h) + q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) + + if self.pos_encoder is not None: + q = self.pos_encoder(q, padding_mask, state_bag=state_bag) + + return q # type: ignore[no-any-return] + + def _project_kv( + self, + keys: Tensor, + key_padding_mask: Optional[Tensor], + values: Tensor, + state_bag: Optional[IncrementalStateBag] = None, + ) -> Tuple[Tensor, Tensor]: + # (N, S, K) -> (N, S, K_proj) + k = self.k_proj(keys) + # (N, S, V) -> (N, S, V_proj) + v = self.v_proj(values) + + # (N, S, K_proj) -> (N, H, S, K_h) + k = k.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2) + # (N, S, V_proj) -> (N, H, S, V_h) + v = v.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2) + + if self.pos_encoder is not None: + k = self.pos_encoder(k, key_padding_mask, state_bag=state_bag) + + return k, v def extra_repr(self) -> str: """:meta private:""" @@ -601,112 +619,147 @@ class MultiheadAttentionState(IncrementalState): """Holds the state of a :class:`MultiheadAttention` module during an incremental evaluation.""" - prev_k: Tensor - """The projected keys accumulated from the past incremental evaluations. - *Shape:* :math:`(N,S_{prv},K_{proj})`, where :math:`N` is the batch size, - :math:`S_{prv}` is the accumulated key/value sequence length, and - :math:`K_{proj}` is the projected key size.""" + cache_reserve_size = 512 + """The reserved sequence length capacity of :attr:`k` and :attr:`v` will be + increased by multiplies of the specified value.""" + + seq_len: int + """The current sequence length of :attr:`k` and :attr:`v`.""" - prev_v: Tensor - """The projected values accumulated from the past incremental evaluations. - *Shape:* :math:`(N,S_{prv},V_{proj})`, where :math:`N` is the batch size, - :math:`S_{prv}` is the accumulated key/value sequence length, and - :math:`V_{proj}` is the projected value size.""" + k: Tensor + """The projected keys accumulated from the past incremental evaluation + steps. *Shape:* :math:`(N,H,S,K_{proj})`, where :math:`N` is the batch + size, :math:`H` is the number of heads, :math:`S` is the reserved sequence + length capacity, and :math:`K_{proj}` is the projected key size.""" - prev_key_padding_mask: Optional[Tensor] + v: Tensor + """The projected values accumulated from the past incremental evaluation + steps. 
*Shape:* :math:`(N,H,S,V_{proj})`, where :math:`N` is the batch + size, :math:`H` is the number of heads, :math:`S` is the reserved sequence + length capacity, and :math:`V_{proj}` is the projected value size.""" + + key_padding_mask: Tensor """The float key padding mask accumulated from the past incremental - evaluations. *Shape:* :math:`(N,S_{prv})`, where :math:`N` is the batch size - and :math:`S_{prv}` is the accumulated key/value sequence length.""" + evaluation steps. *Shape:* :math:`(N,S)`, where :math:`N` is the batch + size and :math:`S` is the reserved sequence length capacity.""" - def __init__( - self, k: Tensor, v: Tensor, key_padding_mask: Optional[Tensor] = None - ) -> None: + has_mask: bool + + def __init__(self, k: Tensor, v: Tensor) -> None: """ :param k: - The initial projected keys. *Shape:* :math:`(N,S_{int},K_{proj})`, - where :math:`N` is the batch size, :math:`S_{int}` is the initial - key/value sequence length, and :math:`K_{proj}` is the projected key - size. + The projected keys to bootstrap the internal state. :param v: - The initial projected values. *Shape:* :math:`(N,S_{int},V_{proj})`, - where :math:`N` is the batch size, :math:`S_{int}` is the initial - key/value sequence length, and :math:`V_{proj}` is the projected - value size. - :param key_padding_mask: - The initial float key padding mask. *Shape:* :math:`(N,S_{int})`, - where :math:`N` is the batch size and :math:`S_{int}` is the initial - key/value sequence length. + The projected values to bootstrap the internal state. """ - self.prev_k = k - self.prev_v = v + batch_size, num_heads, _, head_dim = k.shape + + self.seq_len = 0 + + self.k = k.new_empty((batch_size, num_heads, 0, head_dim)) + self.v = v.new_empty((batch_size, num_heads, 0, head_dim)) - self.prev_key_padding_mask = key_padding_mask + self.key_padding_mask = k.new_zeros((batch_size, 0)) + + self.has_mask = False def append( self, k: Tensor, v: Tensor, key_padding_mask: Optional[Tensor] ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: """Append the projected key, projected value, and float key padding mask - of the current incremental evaluation to :attr:`prev_k`, :attr:`prev_v`, - and :attr:`key_padding_mask`. + of the current incremental evaluation step to :attr:`k`, + :attr:`v`, and :attr:`key_padding_mask`. :param k: - The projected key of the current incremental evaluation. *Shape:* - :math:`(N,S_{stp},K_{proj})`, where :math:`N` is the batch size, - :math:`S_{stp}` is the step length (e.g. 1), and :math:`K_{proj}` is - the projected key size. + The projected key of the current incremental evaluation step. + *Shape:* :math:`(N,H,S_{stp},K_{proj})`, where :math:`N` is the + batch size, :math:`H` is the number of heads, :math:`S_{stp}` is the + step length (e.g. 1), and :math:`K_{proj}` is the projected key + size. :param v: - The projected value of the current incremental evaluation. *Shape:* - :math:`(N,S_{stp},V_{proj})`, where :math:`N` is the batch size, - :math:`S_{stp}` is the step length (e.g. 1), and :math:`V_{proj}` is - the projected value size. + The projected value of the current incremental evaluation step. + *Shape:* :math:`(N,H,S_{stp},V_{proj})`, where :math:`N` is the + batch size, :math:`H` is the number of heads, :math:`S_{stp}` is the + step length (e.g. 1), and :math:`V_{proj}` is the projected value + size. :param key_padding_mask: - The float key padding mask of the current incremental evaluation. - *Shape:* :math:`(N,S_{stp})`, where :math:`N` is the batch size and - :math:`S_{stp}` is the step length (e.g. 1). 
+ The float key padding mask of the current incremental evaluation + step. *Shape:* :math:`(N,S_{stp})`, where :math:`N` is the batch + size and :math:`S_{stp}` is the step length (e.g. 1). :returns: The projected keys, projected values, and float key padding mask that should be used to compute the attention. """ - seq_len = k.size(1) + seq_len = k.size(2) + + start, end = self.seq_len, self.seq_len + seq_len - prev_seq_len = self.prev_k.size(1) + if end > self.k.size(2): + batch_size, num_heads, seq_len, head_dim = k.shape - self.prev_k = torch.cat([self.prev_k, k], dim=1) - self.prev_v = torch.cat([self.prev_v, v], dim=1) + # Ensure that the reserved space is always at least as long as the + # input sequence. + extra_capacity = self.cache_reserve_size * ( + (self.cache_reserve_size + seq_len - 1) // self.cache_reserve_size + ) - # Appending the key padding mask is trickier since the previous or - # current mask can be `None`. - self._append_key_padding_mask(key_padding_mask, seq_len, prev_seq_len) + cache_k = k.new_empty((batch_size, num_heads, extra_capacity, head_dim)) + cache_v = v.new_empty((batch_size, num_heads, extra_capacity, head_dim)) - return self.prev_k, self.prev_v, self.prev_key_padding_mask + self.k = torch.cat([self.k, cache_k], dim=2) + self.v = torch.cat([self.v, cache_v], dim=2) - def _append_key_padding_mask( - self, curr_mask: Optional[Tensor], curr_seq_len: int, prev_seq_len: int - ) -> None: - prev_mask = self.prev_key_padding_mask + cache_key_padding_mask = k.new_zeros((batch_size, extra_capacity)) - if prev_mask is None and curr_mask is None: - return + self.key_padding_mask = torch.cat( + [self.key_padding_mask, cache_key_padding_mask], dim=1 + ) - batch_size = self.prev_k.size(0) + self.k[:, :, start:end] = k + self.v[:, :, start:end] = v + + if key_padding_mask is not None: + self.has_mask = True - # One of the masks can be `None`. We have to ensure that both of them - # are fully materialized before concatenating. - if prev_mask is None: - prev_mask = self.prev_k.new_zeros((batch_size, prev_seq_len)) + self.key_padding_mask[:, start:end] = key_padding_mask - if curr_mask is None: - curr_mask = self.prev_k.new_zeros((batch_size, curr_seq_len)) + self.seq_len = end - self.prev_key_padding_mask = torch.cat([prev_mask, curr_mask], dim=1) + k = self.k[:, :, :end] + v = self.v[:, :, :end] + + key_padding_mask = self.key_padding_mask[:, :end] if self.has_mask else None + + return k, v, key_padding_mask @override def reorder(self, new_order: Tensor) -> None: - self.prev_k = self.prev_k.index_select(0, new_order) - self.prev_v = self.prev_v.index_select(0, new_order) + self.k = self.k.index_select(0, new_order) + self.v = self.v.index_select(0, new_order) + + if self.has_mask: + self.key_padding_mask = self.key_padding_mask.index_select(0, new_order) + - if self.prev_key_padding_mask is not None: - mask = self.prev_key_padding_mask.index_select(0, new_order) +class StaticMultiheadAttentionState(IncrementalState): + """Holds the state of an encoder-decoder :class:`MultiheadAttention` module + during an incremental evaluation.""" - self.prev_key_padding_mask = mask + k: Tensor + v: Tensor + + def __init__(self, k: Tensor, v: Tensor) -> None: + """ + :param k: + The encoder output projected as key. + :param v: + The encoder output projected as value. 
+ """ + self.k = k + self.v = v + + @override + def reorder(self, new_order: Tensor) -> None: + self.k = self.k.index_select(0, new_order) + self.v = self.v.index_select(0, new_order) diff --git a/src/fairseq2/nn/transformer/relative_attention.py b/src/fairseq2/nn/transformer/relative_attention.py index 5f63a2273..81d6941e7 100644 --- a/src/fairseq2/nn/transformer/relative_attention.py +++ b/src/fairseq2/nn/transformer/relative_attention.py @@ -62,7 +62,7 @@ def __init__( if pos_encoding.encoding_dim != model_dim: raise ValueError( - f"`encoding_dim` of `pos_encoding` must be equal `model_dim` ({model_dim}), but is {pos_encoding.encoding_dim} instead." + f"`encoding_dim` of `pos_encoding` must be equal to `model_dim` ({model_dim}), but is {pos_encoding.encoding_dim} instead." ) self.pos_encoding = pos_encoding @@ -97,6 +97,11 @@ def forward( mask: Optional[Tensor] = None, needs_weights: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: + if queries.ndim != 4 or keys.ndim != 4 or values.ndim != 4: + raise ValueError( + "`RelativePositionSDPA` can only be used as part of a multi-head attention layer and expects its input tensors to be 4 dimensional." + ) + q = queries k = keys @@ -104,29 +109,23 @@ def forward( u_bias = self.u_bias.unsqueeze(1) v_bias = self.v_bias.unsqueeze(1) - # (N x H, S, K_h) -> (N, H, S, K_h) - q = q.unflatten(0, (-1, self.num_heads)) - # (N, H, S, K_h) + (H, 1, K_h) -> (N, H, S, K_h) q_with_u_bias = q + u_bias q_with_v_bias = q + v_bias - # (N, H, S, K_h) -> (N x H, S, K_h) - q_with_u_bias = q_with_u_bias.flatten(0, 1) - q_with_v_bias = q_with_v_bias.flatten(0, 1) - - # (N x H, 2 x S - 1, K_h) + # (N, H, 2 x S - 1, K_h) r = self._compute_r(k, batch_size=q.size(0)) - # (N x H, S, K_h) @ (N x H, K_h, S) = (N x H, S, S) - ac = torch.bmm(q_with_u_bias, k.transpose(1, 2)) + # (N, H, S, K_h) @ (N, H, K_h, S) = (N, H, S, S) + ac = torch.matmul(q_with_u_bias, k.transpose(-1, -2)) - # (N x H, S, K_h) @ (N x H, K_h, 2 x S - 1) = (N x H, S, 2 x S - 1) - bd = torch.bmm(q_with_v_bias, r.transpose(1, 2)) + # (N, H, S, K_h) @ (N, H, K_h, 2 x S - 1) = (N, H, S, 2 x S - 1) + bd = torch.matmul(q_with_v_bias, r.transpose(-1, -2)) - # (N x H, S, 2 x S -1) -> (N x H, S, S) + # (N, H, S, 2 x S - 1) -> (N, H, S, S) bd = self._shift_bd(bd) + # (N, H, S, S) attn_weights = (ac + bd) * (q.size(-1) ** -0.5) if mask is not None: @@ -139,47 +138,44 @@ def forward( if self.training and self.attn_dropout_p > 0.0: attn_weights = dropout(attn_weights, self.attn_dropout_p) - # (N x H, S, S) @ (N x H, S, V_h) = (N x H, S, V_h) - attn = torch.bmm(attn_weights, values) + # (N, H, S, S) @ (N, H, S, V_h) = (N, H, S, V_h) + attn = torch.matmul(attn_weights, values) return attn, attn_weights if needs_weights else None def _compute_r(self, k: Tensor, batch_size: int) -> Tensor: - # (S, K) -> (2 x S - 1, K) + # (2 x S - 1, K) r = self.pos_encoding(k) # (2 x S - 1, K) -> (2 x S - 1, K) r = self.r_proj(r) # (2 x S - 1, K) -> (1, 2 x S - 1, H, K_h) - r = r.view(1, -1, self.num_heads, k.size(2)) + r = r.view(1, -1, self.num_heads, k.size(-1)) # (1, 2 x S - 1, H, K_h) -> (N, H, 2 x S - 1, K_h) r = r.transpose(1, 2).expand(batch_size, -1, -1, -1) - # (N, H, 2 x S - 1, K_h) -> (N x H, 2 x S - 1, K_h) - r = r.flatten(0, 1) - return r # type: ignore[no-any-return] def _shift_bd(self, bd: Tensor) -> Tensor: - # (N x H, S, 2 x S - 1) -> (N x H, S, 2 x S) + # (N, H, S, 2 x S - 1) -> (N, H, S, 2 x S) x = pad(bd, (1, 0)) - # (N x H, S, 2 x S) -> (N x H, 2 x S, S) - x = x.view(x.size(0), x.size(2), x.size(1)) + # (N, H, S, 
2 x S) -> (N, H, 2 x S, S) + x = x.view(x.size(0), x.size(1), x.size(3), x.size(2)) # Discard the first set of positive positions. - # (N x H, 2 x S, S) -> (N x H, 2 x S - 1, S) - x = x[:, 1:, :] + # (N, H, 2 x S, S) -> (N, H, 2 x S - 1, S) + x = x[:, :, 1:, :] # This op effectively shifts each row by an extra step. - # (N x H, S, 2 x S - 1) + # (N, H, 2 x S - 1, S) -> (N, H, S, 2 x S - 1) x = x.view_as(bd) # Discard positions used for shift. - # (N x H, S, 2 x S - 1) -> (N x H, S, S) - x = x[..., : bd.size(1)] + # (N, H, S, 2 x S - 1) -> (N, H, S, S) + x = x[..., : bd.size(2)] return x @@ -280,9 +276,9 @@ def forward(self, seqs: Tensor) -> Tensor: """ :param seqs: The sequences for which to return positional encodings. *Shape:* - :math:`(N,S,*)`, where :math:`N` is the batch size, :math:`S` is the - sequence length, and :math:`*` is any number of sequence-specific - dimensions including none. + :math:`(*,S,E)`, where :math:`*` is any number of batch dimensions + including none, :math:`S` is the sequence length, and :math:`E` is + the dimensionality of the positional encodings. :returns: The positional encodings to use in shift trick in @@ -290,7 +286,7 @@ def forward(self, seqs: Tensor) -> Tensor: where :math:`S` is the sequence length and :math:`E` is the dimensionality of the positional encodings. """ - seq_len = seqs.size(1) + seq_len = seqs.size(-2) if seq_len > self.max_seq_len: raise ValueError( diff --git a/src/fairseq2/nn/utils/mask.py b/src/fairseq2/nn/utils/mask.py index f70ade78e..92cdc9121 100644 --- a/src/fairseq2/nn/utils/mask.py +++ b/src/fairseq2/nn/utils/mask.py @@ -87,10 +87,8 @@ def apply_padding_mask(seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor: the batch size, :math:`S` is the sequence length, and :math:`*` is any number of sequence-specific dimensions including none. :param padding_mask: - The float padding mask to apply. *Shape:* :math:`(N_{msk},S)`, where - :math:`N_{msk}` is the mask batch size and :math:`S` is the sequence - length. :math:`N` can be a multiple of :math:`N_{msk}` in which case the - mask will be tiled before being applied. + The float padding mask to apply. *Shape:* :math:`(N,S)`, where :math:`N` + is the batch size and :math:`S` is the sequence length. :returns: The input sequences with mask applied. *Shape:* Same as ``seqs``. @@ -100,15 +98,6 @@ def apply_padding_mask(seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor: bool_mask = padding_mask.isinf() - seq_batch_size, mask_batch_size = seqs.size(0), padding_mask.size(0) - if seq_batch_size != mask_batch_size: - if seq_batch_size % mask_batch_size != 0: - raise ValueError( - f"`seqs.size(0)` must be a multiple of `padding_mask.size(0)` ({mask_batch_size}), but is {seq_batch_size} instead." - ) - - bool_mask = bool_mask.repeat(seq_batch_size // mask_batch_size, 1) - if seqs.ndim > 2: bool_mask = bool_mask.unsqueeze(2)
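
The change to _naive_scaled_dot_product_attention moves the 1/sqrt(K) factor from the queries to the attention logits. The two orderings are mathematically identical since the factor is a scalar; the following standalone check (not part of the patch, with made-up shapes) illustrates the equivalence in the (N, H, S, K_h) layout used above.

import torch

# Standalone check: scaling the logits after the matmul is equivalent to
# scaling the queries before it, because the scale is a scalar.
q = torch.randn(2, 4, 5, 8)   # (N, H, S, K_h)
k = torch.randn(2, 4, 7, 8)   # (N, H, S_kv, K_h)

scale = q.size(-1) ** -0.5

# Old ordering: scale the queries, then compute the logits.
logits_old = torch.matmul(q * scale, k.transpose(-1, -2))

# New ordering: compute the logits, then scale them.
logits_new = torch.matmul(q, k.transpose(-1, -2)) * scale

torch.testing.assert_close(logits_old, logits_new)
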
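
The attention weight hook protocol now receives the attention output alongside the weights, and StoreAttentionWeights stores (attn, attn_weights) pairs. A minimal usage sketch follows; attach_weight_store is a hypothetical helper, and the import path assumes both names are re-exported from fairseq2.nn.transformer (otherwise import them from fairseq2.nn.transformer.multihead_attention).

from typing import List, Tuple

from torch import Tensor

# Assumed import path; adjust if these names are not re-exported at this level.
from fairseq2.nn.transformer import MultiheadAttention, StoreAttentionWeights


def attach_weight_store(mha: MultiheadAttention) -> List[Tuple[Tensor, Tensor]]:
    """Record the ``(attn, attn_weights)`` pair of every forward call of ``mha``."""
    storage: List[Tuple[Tensor, Tensor]] = []

    # The returned handle could be kept to remove the hook again later.
    mha.register_attn_weight_hook(StoreAttentionWeights(storage))

    return storage

While the hook stays registered, the storage grows by one (attn, attn_weights) pair per forward call.
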
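
With Grouped Query Attention, forward now repeats each key/value head only when the group size is greater than one. A standalone sketch of that head bookkeeping (not fairseq2 code; num_heads and num_key_value_heads are example values):

import torch

# With num_key_value_heads == num_heads the group size is 1 and the repeat is
# skipped (standard MHA); with num_key_value_heads == 1 every query head shares
# the same key/value head (MQA).
num_heads, num_key_value_heads = 8, 2

k = torch.randn(2, num_key_value_heads, 10, 64)  # (N, H_kv, S_kv, K_h)

num_query_groups = num_heads // num_key_value_heads

if num_query_groups > 1:
    # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h)
    k = torch.repeat_interleave(k, dim=1, repeats=num_query_groups)

assert k.shape == (2, num_heads, 10, 64)
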
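
MultiheadAttentionState now preallocates its key/value cache in multiples of cache_reserve_size instead of concatenating one step at a time. The helper below restates the capacity arithmetic used by append in isolation; it is a hypothetical standalone function, not fairseq2 API.

def extra_capacity(step_len: int, cache_reserve_size: int = 512) -> int:
    # Round the requested growth up to the next multiple of the reserve size,
    # mirroring the expression in MultiheadAttentionState.append().
    return cache_reserve_size * ((cache_reserve_size + step_len - 1) // cache_reserve_size)


assert extra_capacity(1) == 512     # a single decoding step reserves one full block
assert extra_capacity(512) == 512   # an exact multiple reserves exactly that much
assert extra_capacity(513) == 1024  # anything larger rounds up to the next block
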
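
RelativePositionSDPA now operates on 4-dimensional (N, H, S, *) tensors rather than flattened (N x H, S, *) ones, and _shift_bd gains the extra head dimension. A standalone restatement of both variants of the shift trick (a sketch with made-up shapes, not the patched module itself) to show they agree:

import torch
from torch.nn.functional import pad


def shift_bd_3d(bd: torch.Tensor) -> torch.Tensor:
    # Pre-change variant, operating on (N x H, S, 2 x S - 1).
    x = pad(bd, (1, 0))
    x = x.view(x.size(0), x.size(2), x.size(1))
    x = x[:, 1:, :]
    x = x.view_as(bd)
    return x[..., : bd.size(1)]


def shift_bd_4d(bd: torch.Tensor) -> torch.Tensor:
    # Post-change variant, operating on (N, H, S, 2 x S - 1).
    x = pad(bd, (1, 0))
    x = x.view(x.size(0), x.size(1), x.size(3), x.size(2))
    x = x[:, :, 1:, :]
    x = x.view_as(bd)
    return x[..., : bd.size(2)]


N, H, S = 2, 4, 5

bd = torch.randn(N, H, S, 2 * S - 1)

torch.testing.assert_close(
    shift_bd_4d(bd), shift_bd_3d(bd.flatten(0, 1)).unflatten(0, (N, H))
)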