diff --git a/src/fairseq2/nn/position_encoder.py b/src/fairseq2/nn/position_encoder.py index 74a7b676b..d9494e483 100644 --- a/src/fairseq2/nn/position_encoder.py +++ b/src/fairseq2/nn/position_encoder.py @@ -47,14 +47,13 @@ def forward( """ :param seqs: The sequences to encode with positional information. *Shape:* - :math:`(N,S,E)`, where :math:`N` is the batch size, :math:`S` is the - sequence length, and :math:`E` is the dimensionality of the - positional encodings. + :math:`(*,S,E)`, where :math:`*` is any number of batch dimensions + including none, :math:`S` is the sequence length, and :math:`E` is + the dimensionality of the positional encodings. :param padding_mask: - The float padding mask of ``seqs``. *Shape:* :math:`(N_{msk},S)`, - where :math:`N_{msk}` is the mask batch size and :math:`S` is the - sequence length. :math:`N` can be a multiple of :math:`N_{msk}` in - which case the mask will be tiled before being applied. + The float padding mask of ``seqs``. *Shape:* :math:`(*,S)`, where + :math:`*` is any number of batch dimensions including none and + :math:`S` is the sequence length. :param state_bag: The state bag to use for incremental evaluation. @@ -68,7 +67,7 @@ def forward( else: start_step = 0 - if (seq_len := start_step + seqs.size(1)) > self.max_seq_len: + if (seq_len := start_step + seqs.size(-2)) > self.max_seq_len: raise ValueError( f"The input sequence length must be less than or equal to the maximum sequence length ({self.max_seq_len}), but is {seq_len} instead." ) @@ -85,15 +84,13 @@ def _do_forward( """ :param seqs: The sequences to encode with positional information. *Shape:* - :math:`(N,S,E)`, where :math:`N` is the batch size, :math:`S` is the - sequence length, and :math:`E` is the dimensionality of the - positional encodings. + :math:`(*,S,E)`, where :math:`*` is any number of batch dimensions + including none, :math:`S` is the sequence length, and :math:`E` is + the dimensionality of the positional encodings. :param padding_mask: - The float padding mask of ``seqs``. *Shape:* :math:`(N_{msk},S)`, - where :math:`N_{msk}` is the mask batch size and :math:`S` is the - sequence length. If padding has to be applied, a derived class - should use the :func:`~fairseq2.nn.utils.mask.apply_padding_mask` - function that handles tiling. + The float padding mask of ``seqs``. *Shape:* :math:`(*,S)`, where + :math:`*` is any number of batch dimensions including none and + :math:`S` is the sequence length. :param state_bag: The state bag to use for incremental evaluation. @@ -237,7 +234,7 @@ def _do_forward( state_bag: Optional[IncrementalStateBag], ) -> Tensor: """:meta private:""" - seq_len = seqs.size(1) + seq_len = seqs.size(-2) if not self.training and state_bag is not None: start_step = state_bag.step @@ -297,7 +294,7 @@ def _do_forward( state_bag: Optional[IncrementalStateBag], ) -> Tensor: """:meta private:""" - seq_len = seqs.size(1) + seq_len = seqs.size(-2) if not self.training and state_bag is not None: start_step = state_bag.step @@ -385,7 +382,7 @@ def _do_forward( state_bag: Optional[IncrementalStateBag], ) -> Tensor: """:meta private:""" - seq_len = seqs.size(1) + seq_len = seqs.size(-2) if not self.training and state_bag is not None: start_step = state_bag.step diff --git a/tests/unit/nn/test_position_encoder.py b/tests/unit/nn/test_position_encoder.py index a1d0cfa1d..44593c2da 100644 --- a/tests/unit/nn/test_position_encoder.py +++ b/tests/unit/nn/test_position_encoder.py @@ -116,6 +116,15 @@ def test_forward_works(self) -> None: assert_close(y - x, m.weight[:9].expand_as(y)) + # Test with multiple batch dimensions. + x = torch.randn((4, 3, 9, 4), device=device) + + y = m(x, padding_mask=None) + + assert y.shape == (4, 3, 9, 4) + + assert_close(y - x, m.weight[:9].expand_as(y)) + @pytest.mark.parametrize("step", [0, 1, 2]) def test_forward_works_in_incremental_eval(self, step: int) -> None: m = SinusoidalPositionEncoder(encoding_dim=32, max_seq_len=4, device=device) @@ -182,6 +191,15 @@ def test_forward_works(self) -> None: assert_close(y - x, m.weight[:9].expand_as(y)) + # Test with multiple batch dimensions. + x = torch.randn((4, 3, 9, 4), device=device) + + y = m(x, padding_mask=None) + + assert y.shape == (4, 3, 9, 4) + + assert_close(y - x, m.weight[:9].expand_as(y)) + @pytest.mark.parametrize("step", [0, 1, 2]) def test_forward_works_in_incremental_eval(self, step: int) -> None: m = LearnedPositionEncoder(encoding_dim=32, max_seq_len=4, device=device) @@ -235,7 +253,7 @@ def test_init_raises_error_when_encoding_dim_is_odd(self) -> None: def test_forward_works(self) -> None: m = RotaryEncoder(encoding_dim=4, max_seq_len=10, device=device) - x = torch.randn((3, 9, 4), device=device) + x = torch.randn((4, 3, 9, 4), device=device) y = m(x, padding_mask=None)