Allow multiple batch dimensions in position encoders (#60)
cbalioglu authored Sep 21, 2023
1 parent c0b893f commit 173c7ae
Showing 2 changed files with 35 additions and 20 deletions.
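In short, the position encoders previously required inputs of shape `(N, S, E)` with exactly one batch dimension; after this commit they accept `(*, S, E)` with any number of leading batch dimensions, including none. A minimal sketch of the new behavior, mirroring the updated tests below (the import path is assumed from the file layout shown in this diff):

```python
import torch

from fairseq2.nn.position_encoder import SinusoidalPositionEncoder

m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=10)

# Single batch dimension, as before: (N, S, E).
y = m(torch.randn(3, 9, 4), padding_mask=None)

# Multiple leading batch dimensions now work too: (*, S, E).
y = m(torch.randn(4, 3, 9, 4), padding_mask=None)

assert y.shape == (4, 3, 9, 4)
```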
35 changes: 16 additions & 19 deletions src/fairseq2/nn/position_encoder.py
@@ -47,14 +47,13 @@ def forward(
"""
:param seqs:
The sequences to encode with positional information. *Shape:*
:math:`(N,S,E)`, where :math:`N` is the batch size, :math:`S` is the
sequence length, and :math:`E` is the dimensionality of the
positional encodings.
:math:`(*,S,E)`, where :math:`*` is any number of batch dimensions
including none, :math:`S` is the sequence length, and :math:`E` is
the dimensionality of the positional encodings.
:param padding_mask:
The float padding mask of ``seqs``. *Shape:* :math:`(N_{msk},S)`,
where :math:`N_{msk}` is the mask batch size and :math:`S` is the
sequence length. :math:`N` can be a multiple of :math:`N_{msk}` in
which case the mask will be tiled before being applied.
The float padding mask of ``seqs``. *Shape:* :math:`(*,S)`, where
:math:`*` is any number of batch dimensions including none and
:math:`S` is the sequence length.
:param state_bag:
The state bag to use for incremental evaluation.
@@ -68,7 +67,7 @@
         else:
             start_step = 0
 
-        if (seq_len := start_step + seqs.size(1)) > self.max_seq_len:
+        if (seq_len := start_step + seqs.size(-2)) > self.max_seq_len:
             raise ValueError(
                 f"The input sequence length must be less than or equal to the maximum sequence length ({self.max_seq_len}), but is {seq_len} instead."
             )
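The heart of the change is `seqs.size(1)` becoming `seqs.size(-2)`: counting the sequence dimension from the end keeps it correct no matter how many batch dimensions come first. A quick illustration in plain PyTorch:

```python
import torch

a = torch.randn(3, 9, 4)     # (N, S, E)
b = torch.randn(4, 3, 9, 4)  # (N1, N2, S, E)
c = torch.randn(9, 4)        # (S, E), no batch dimension at all

assert a.size(1) == 9   # happens to be the sequence length
assert b.size(1) == 3   # picks up a batch dimension instead

# size(-2) is the sequence length in every case.
assert all(t.size(-2) == 9 for t in (a, b, c))
```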
@@ -85,15 +84,13 @@ def _do_forward(
"""
:param seqs:
The sequences to encode with positional information. *Shape:*
:math:`(N,S,E)`, where :math:`N` is the batch size, :math:`S` is the
sequence length, and :math:`E` is the dimensionality of the
positional encodings.
:math:`(*,S,E)`, where :math:`*` is any number of batch dimensions
including none, :math:`S` is the sequence length, and :math:`E` is
the dimensionality of the positional encodings.
:param padding_mask:
The float padding mask of ``seqs``. *Shape:* :math:`(N_{msk},S)`,
where :math:`N_{msk}` is the mask batch size and :math:`S` is the
sequence length. If padding has to be applied, a derived class
should use the :func:`~fairseq2.nn.utils.mask.apply_padding_mask`
function that handles tiling.
The float padding mask of ``seqs``. *Shape:* :math:`(*,S)`, where
:math:`*` is any number of batch dimensions including none and
:math:`S` is the sequence length.
:param state_bag:
The state bag to use for incremental evaluation.
@@ -237,7 +234,7 @@ def _do_forward(
         state_bag: Optional[IncrementalStateBag],
     ) -> Tensor:
         """:meta private:"""
-        seq_len = seqs.size(1)
+        seq_len = seqs.size(-2)
 
         if not self.training and state_bag is not None:
             start_step = state_bag.step
@@ -297,7 +294,7 @@ def _do_forward(
         state_bag: Optional[IncrementalStateBag],
     ) -> Tensor:
         """:meta private:"""
-        seq_len = seqs.size(1)
+        seq_len = seqs.size(-2)
 
         if not self.training and state_bag is not None:
             start_step = state_bag.step
@@ -385,7 +382,7 @@ def _do_forward(
         state_bag: Optional[IncrementalStateBag],
     ) -> Tensor:
         """:meta private:"""
-        seq_len = seqs.size(1)
+        seq_len = seqs.size(-2)
 
         if not self.training and state_bag is not None:
             start_step = state_bag.step
20 changes: 19 additions & 1 deletion tests/unit/nn/test_position_encoder.py
@@ -116,6 +116,15 @@ def test_forward_works(self) -> None:

         assert_close(y - x, m.weight[:9].expand_as(y))
 
+        # Test with multiple batch dimensions.
+        x = torch.randn((4, 3, 9, 4), device=device)
+
+        y = m(x, padding_mask=None)
+
+        assert y.shape == (4, 3, 9, 4)
+
+        assert_close(y - x, m.weight[:9].expand_as(y))
+
     @pytest.mark.parametrize("step", [0, 1, 2])
     def test_forward_works_in_incremental_eval(self, step: int) -> None:
         m = SinusoidalPositionEncoder(encoding_dim=32, max_seq_len=4, device=device)
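The unchanged assertion passes for the 4-D input because the encoder's `(S, E)` positional table broadcasts across all leading batch dimensions when it is added to `seqs`. A sketch of that broadcasting rule using only standard PyTorch (the `weight` tensor here is a stand-in for the encoder's table):

```python
import torch

weight = torch.randn(9, 4)   # positional table: (S, E)
x = torch.randn(4, 3, 9, 4)  # input: (N1, N2, S, E)

# Broadcasting aligns trailing dimensions, so the same (S, E) table
# is added to every (N1, N2) slice of the batch.
y = x + weight

assert torch.allclose(y - x, weight.expand_as(y))
```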
@@ -182,6 +191,15 @@ def test_forward_works(self) -> None:

         assert_close(y - x, m.weight[:9].expand_as(y))
 
+        # Test with multiple batch dimensions.
+        x = torch.randn((4, 3, 9, 4), device=device)
+
+        y = m(x, padding_mask=None)
+
+        assert y.shape == (4, 3, 9, 4)
+
+        assert_close(y - x, m.weight[:9].expand_as(y))
+
     @pytest.mark.parametrize("step", [0, 1, 2])
     def test_forward_works_in_incremental_eval(self, step: int) -> None:
         m = LearnedPositionEncoder(encoding_dim=32, max_seq_len=4, device=device)
@@ -235,7 +253,7 @@ def test_init_raises_error_when_encoding_dim_is_odd(self) -> None:
     def test_forward_works(self) -> None:
         m = RotaryEncoder(encoding_dim=4, max_seq_len=10, device=device)
 
-        x = torch.randn((3, 9, 4), device=device)
+        x = torch.randn((4, 3, 9, 4), device=device)
 
         y = m(x, padding_mask=None)
 
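For the rotary test only the input shape changes, since `RotaryEncoder` applies a position-dependent rotation rather than an addition, so there is no simple `y - x` check. As a hedged sketch, the textbook RoPE formulation below (not necessarily fairseq2's exact implementation) shows why the rotation is likewise indifferent to leading batch dimensions:

```python
import torch

def rotary(seqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to a (*, S, E) tensor, E even."""
    seq_len, dim = seqs.size(-2), seqs.size(-1)

    # Per-pair rotation frequencies and per-position angles.
    inv_freq = 1.0 / (10_000 ** (torch.arange(0, dim, 2) / dim))
    angles = torch.arange(seq_len).unsqueeze(1) * inv_freq  # (S, E/2)

    cos, sin = angles.cos(), angles.sin()

    x1, x2 = seqs[..., 0::2], seqs[..., 1::2]

    out = torch.empty_like(seqs)
    out[..., 0::2] = x1 * cos - x2 * sin  # broadcasts over batch dims
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

y = rotary(torch.randn(4, 3, 9, 4))
assert y.shape == (4, 3, 9, 4)
```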
