diff --git a/src/fairseq2/nn/transformer/attention.py b/src/fairseq2/nn/transformer/attention.py
index 94559965f..254de5c10 100644
--- a/src/fairseq2/nn/transformer/attention.py
+++ b/src/fairseq2/nn/transformer/attention.py
@@ -46,33 +46,34 @@ def forward(
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
         :param queries:
-            The queries. *Shape:* :math:`(N,S,K)`, where :math:`N` is the batch
-            size, :math:`S` is the sequence length, and :math:`K` is the key
-            size.
+            The queries. *Shape:* :math:`(*,S,K)`, where :math:`*` is any number
+            of batch dimensions including none, :math:`S` is the sequence
+            length, and :math:`K` is the key size.
         :param keys:
-            The keys. *Shape:* :math:`(N,S_{kv},K)`, where :math:`N` is the
-            batch size, :math:`S_{kv}` is the key/value sequence length, and
-            :math:`K` is the key size.
+            The keys. *Shape:* :math:`(*,S_{kv},K)`, where :math:`*` is any
+            number of batch dimensions including none, :math:`S_{kv}` is the
+            key/value sequence length, and :math:`K` is the key size.
         :param values:
-            The values. *Shape:* :math:`(N,S_{kv},V)`, where :math:`N` is the
-            batch size, :math:`S_{kv}` is the key/value sequence length, and
-            :math:`V` is the value size.
+            The values. *Shape:* :math:`(*,S_{kv},V)`, where :math:`*` is any
+            number of batch dimensions including none, :math:`S_{kv}` is the
+            key/value sequence length, and :math:`V` is the value size.
         :param mask:
             The float mask that will be added to the attention weights before
             computing the attention. *Shape:* :math:`(S,S_{kv})` or
-            :math:`(N,S,S_{kv})`, where :math:`N` is the batch size,
-            :math:`S` is the sequence length, and :math:`S_{kv}` is the
-            key/value sequence length.
+            :math:`(*,S,S_{kv})`, where :math:`*` is any number of batch
+            dimensions including none, :math:`S` is the sequence length, and
+            :math:`S_{kv}` is the key/value sequence length.
         :param needs_weights:
             If ``True``, returns the attention weights.

         :returns:
-            - The attention values. *Shape:* :math:`(N,S,V)`, where
-              :math:`N` is the batch size, :math:`S` is the sequence length, and
-              :math:`V` is the value size.
-            - The attention weights. *Shape:* :math:`(N,S,S_{kv})`, where
-              :math:`N` is the batch size, :math:`S` is the sequence length, and
-              :math:`S_{kv}` is the key/value sequence length.
+            - The attention values. *Shape:* :math:`(*,S,V)`, where :math:`*`
+              matches the batch dimensions of the inputs, :math:`S` is the
+              sequence length, and :math:`V` is the value size.
+            - The attention weights. *Shape:* :math:`(*,S,S_{kv})`, where
+              :math:`*` matches the batch dimensions of the inputs, :math:`S`
+              is the sequence length, and :math:`S_{kv}` is the key/value
+              sequence length.
         """

     def extra_repr(self) -> str:
@@ -199,12 +200,11 @@ def _naive_scaled_dot_product_attention(
 ) -> Tuple[Tensor, Optional[Tensor]]:
     queries = queries * (queries.size(-1) ** -0.5)
 
-    if mask is None:
-        # (N, S, K) @ (N, K, S_kv) = (N, S, S_kv)
-        attn_weights = torch.bmm(queries, keys.transpose(1, 2))
-    else:
-        # (N, S, S_kv) + ((N, S, K) @ (N, K, S_kv)) = (N, S, S_kv)
-        attn_weights = torch.baddbmm(mask, queries, keys.transpose(1, 2))
+    # (*, S, K) @ (*, K, S_kv) = (*, S, S_kv)
+    attn_weights = torch.matmul(queries, keys.transpose(-1, -2))
+
+    if mask is not None:
+        attn_weights = attn_weights + mask
 
     # For numerical stability run in single precision.
     attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)
@@ -214,7 +214,7 @@ def _naive_scaled_dot_product_attention(
     if training and dropout_p > 0.0:
         attn_weights = dropout(attn_weights, dropout_p)
 
-    # (N, S, S_kv) @ (N, S_kv, V) = (N, S, V)
-    attn = torch.bmm(attn_weights, values)
+    # (*, S, S_kv) @ (*, S_kv, V) = (*, S, V)
+    attn = torch.matmul(attn_weights, values)
 
     return attn, attn_weights if needs_weights else None
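
Note (not part of the patch): the swap from torch.bmm / torch.baddbmm to torch.matmul is what enables the ":math:`*` batch dimensions" wording in the updated docstring, since matmul broadcasts over any number of leading dimensions while bmm and baddbmm require strictly 3-D inputs; adding the mask separately also lets a 2-D (S, S_kv) mask broadcast against those batch dimensions. The snippet below is a minimal sketch of that behavior, with made-up shapes chosen purely for illustration.

    # Sketch only: verify that the matmul-based path accepts extra batch dims.
    import torch

    q = torch.randn(2, 8, 5, 16)   # (*, S, K) with * = (2, 8)
    k = torch.randn(2, 8, 7, 16)   # (*, S_kv, K)
    v = torch.randn(2, 8, 7, 32)   # (*, S_kv, V)
    mask = torch.zeros(5, 7)       # (S, S_kv), broadcast over *

    # Same steps as _naive_scaled_dot_product_attention, without dropout.
    attn_weights = torch.matmul(q * q.size(-1) ** -0.5, k.transpose(-1, -2))
    attn_weights = (attn_weights + mask).softmax(dim=-1)
    attn = torch.matmul(attn_weights, v)

    assert attn.shape == (2, 8, 5, 32)   # (*, S, V)

    # torch.bmm(q, k.transpose(-1, -2)) would raise here: bmm expects 3-D tensors.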