Allow multiple batch dimensions in SDPA
cbalioglu committed Sep 20, 2023
1 parent b038881 commit 3e3d009
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions src/fairseq2/nn/transformer/attention.py
@@ -46,33 +46,34 @@ def forward(
) -> Tuple[Tensor, Optional[Tensor]]:
"""
:param queries:
-The queries. *Shape:* :math:`(N,S,K)`, where :math:`N` is the batch
-size, :math:`S` is the sequence length, and :math:`K` is the key
-size.
+The queries. *Shape:* :math:`(*,S,K)`, where :math:`*` is any number
+of batch dimensions including none, :math:`S` is the sequence
+length, and :math:`K` is the key size.
:param keys:
-The keys. *Shape:* :math:`(N,S_{kv},K)`, where :math:`N` is the
-batch size, :math:`S_{kv}` is the key/value sequence length, and
-:math:`K` is the key size.
+The keys. *Shape:* :math:`(*,S_{kv},K)`, where :math:`*` is any
+number of batch dimensions including none, :math:`S_{kv}` is the
+key/value sequence length, and :math:`K` is the key size.
:param values:
-The values. *Shape:* :math:`(N,S_{kv},V)`, where :math:`N` is the
-batch size, :math:`S_{kv}` is the key/value sequence length, and
-:math:`V` is the value size.
+The values. *Shape:* :math:`(*,S_{kv},V)`, where :math:`*` is any
+number of batch dimensions including none, :math:`S_{kv}` is the
+key/value sequence length, and :math:`V` is the value size.
:param mask:
The float mask that will be added to the attention weights before
computing the attention. *Shape:* :math:`(S,S_{kv})` or
-:math:`(N,S,S_{kv})`, where :math:`N` is the batch size,
-:math:`S` is the sequence length, and :math:`S_{kv}` is the
-key/value sequence length.
+:math:`(*,S,S_{kv})`, where :math:`*` is any number of batch
+dimensions including none, :math:`S` is the sequence length, and
+:math:`S_{kv}` is the key/value sequence length.
:param needs_weights:
If ``True``, returns the attention weights.
:returns:
-- The attention values. *Shape:* :math:`(N,S,V)`, where
-  :math:`N` is the batch size, :math:`S` is the sequence length, and
-  :math:`V` is the value size.
-- The attention weights. *Shape:* :math:`(N,S,S_{kv})`, where
-  :math:`N` is the batch size, :math:`S` is the sequence length, and
-  :math:`S_{kv}` is the key/value sequence length.
+- The attention values. *Shape:* :math:`(*,S,V)`, where :math:`*`
+  is the same batch dimensions as input, :math:`S` is the sequence
+  length, and :math:`V` is the value size.
+- The attention weights. *Shape:* :math:`(*,S,S_{kv})`, where
+  :math:`*` is the same batch dimensions as input, :math:`S` is the
+  sequence length, and :math:`S_{kv}` is the key/value sequence
+  length.
"""

def extra_repr(self) -> str:
@@ -199,12 +200,11 @@ def _naive_scaled_dot_product_attention(
) -> Tuple[Tensor, Optional[Tensor]]:
queries = queries * (queries.size(-1) ** -0.5)

-if mask is None:
-    # (N, S, K) @ (N, K, S_kv) = (N, S, S_kv)
-    attn_weights = torch.bmm(queries, keys.transpose(1, 2))
-else:
-    # (N, S, S_kv) + ((N, S, K) @ (N, K, S_kv)) = (N, S, S_kv)
-    attn_weights = torch.baddbmm(mask, queries, keys.transpose(1, 2))
+# (*, S, K) @ (*, K, S_kv) = (*, S, S_kv)
+attn_weights = torch.matmul(queries, keys.transpose(-1, -2))
+
+if mask is not None:
+    attn_weights = attn_weights + mask

# For numerical stability run in single precision.
attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32)
@@ -214,7 +214,7 @@ def _naive_scaled_dot_product_attention(
if training and dropout_p > 0.0:
attn_weights = dropout(attn_weights, dropout_p)

-# (N, S, S_kv) @ (N, S_kv, V) = (N, S, V)
-attn = torch.bmm(attn_weights, values)
+# (*, S, S_kv) @ (*, S_kv, V) = (*, S, V)
+attn = torch.matmul(attn_weights, values)

return attn, attn_weights if needs_weights else None
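Replacing `torch.bmm`/`torch.baddbmm` with `torch.matmul` is what lifts the single-batch-dimension restriction: `bmm` and `baddbmm` only accept 3-D operands, while `matmul` broadcasts over any number of leading dimensions. The sketch below mirrors the patched logic and shows it running with zero and with two batch dimensions; the helper name `naive_sdpa`, the example shapes, and the final `type_as` cast are illustrative assumptions, not part of fairseq2.

```python
import torch
from torch.nn.functional import dropout, softmax


def naive_sdpa(queries, keys, values, mask=None, dropout_p=0.0, training=False, needs_weights=False):
    # Same steps as the patched _naive_scaled_dot_product_attention above.
    queries = queries * (queries.size(-1) ** -0.5)

    # matmul broadcasts over all leading (batch) dimensions; bmm/baddbmm
    # would require exactly 3-D inputs here.
    attn_weights = torch.matmul(queries, keys.transpose(-1, -2))

    if mask is not None:
        attn_weights = attn_weights + mask

    # For numerical stability run the softmax in single precision.
    attn_weights = softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(queries)

    if training and dropout_p > 0.0:
        attn_weights = dropout(attn_weights, dropout_p)

    attn = torch.matmul(attn_weights, values)

    return attn, attn_weights if needs_weights else None


# No batch dimension at all: (S, K), (S_kv, K), (S_kv, V).
q, k, v = torch.randn(10, 64), torch.randn(12, 64), torch.randn(12, 32)
out, _ = naive_sdpa(q, k, v)
print(out.shape)  # torch.Size([10, 32])

# Two batch dimensions, e.g. (N, H, S, K) with an explicit head axis.
q = torch.randn(4, 8, 10, 64)
k = torch.randn(4, 8, 12, 64)
v = torch.randn(4, 8, 12, 32)
mask = torch.zeros(10, 12)  # an (S, S_kv) mask broadcasts against (*, S, S_kv)
out, weights = naive_sdpa(q, k, v, mask=mask, needs_weights=True)
print(out.shape, weights.shape)  # torch.Size([4, 8, 10, 32]) torch.Size([4, 8, 10, 12])

# The old bmm/baddbmm path raises a RuntimeError for these 4-D inputs,
# since both ops are limited to 3-D tensors.
```

Adding the mask with a plain `attn_weights + mask` instead of fusing it into `baddbmm` keeps the same broadcasting behaviour, since `baddbmm`, like `bmm`, only works on 3-D operands.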
