From 08400240ea1da51f28348a508ce3596220723354 Mon Sep 17 00:00:00 2001
From: zyaoj
Date: Mon, 30 Dec 2024 12:05:47 +0000
Subject: [PATCH] enhance unit tests for position encoder

---
 tests/unit/nn/test_position_encoder.py | 184 +++++++++++++++++++++++++
 1 file changed, 184 insertions(+)

diff --git a/tests/unit/nn/test_position_encoder.py b/tests/unit/nn/test_position_encoder.py
index 21154b8fa..f77b13430 100644
--- a/tests/unit/nn/test_position_encoder.py
+++ b/tests/unit/nn/test_position_encoder.py
@@ -14,6 +14,8 @@
     IncrementalStateBag,
     LearnedPositionEncoder,
     RotaryEncoder,
+    Sinusoidal2dPositionEncoder,
+    Sinusoidal3dPositionEncoder,
     SinusoidalPositionEncoder,
 )
 from fairseq2.utils.rng import temporary_manual_seed
@@ -172,6 +174,35 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None:
 
         assert y.shape == (5, 2, 32)
 
+    def test_forward_works_with_padding_mask(self) -> None:
+        m = SinusoidalPositionEncoder(
+            encoding_dim=4, max_seq_len=10, _legacy_pad_idx=-1, device=device
+        )
+
+        x = torch.randn((3, 9, 4), device=device)
+        padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device)
+        padding_mask[:, 5:] = True
+
+        y = m(x, padding_mask=padding_mask)
+
+        assert y.shape == (3, 9, 4)
+
+    def test_forward_works_with_multiple_batch_dims_and_padding(self) -> None:
+        m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=10, device=device)
+
+        x = torch.randn((4, 3, 9, 4), device=device)
+        padding_mask = torch.zeros((4, 3, 9), dtype=torch.bool, device=device)
+        padding_mask[..., 5:] = True
+
+        y = m(x, padding_mask=padding_mask)
+
+        assert y.shape == (4, 3, 9, 4)
+
+    def test_extra_repr_works(self) -> None:
+        m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=10, device=device)
+
+        assert m.extra_repr() == "encoding_dim=4, max_seq_len=10"
+
 
 class TestLearnedPositionEncoder:
     def test_init_works(self) -> None:
@@ -249,6 +280,28 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None:
 
         assert y.shape == (5, 2, 32)
 
+    def test_forward_works_with_padding_mask(self) -> None:
+        m = LearnedPositionEncoder(encoding_dim=4, max_seq_len=10, device=device)
+
+        x = torch.randn((3, 9, 4), device=device)
+        padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device)
+        padding_mask[:, 5:] = True
+
+        y = m(x, padding_mask=padding_mask)
+
+        assert y.shape == (3, 9, 4)
+
+    def test_forward_works_with_multiple_batch_dims_and_padding(self) -> None:
+        m = LearnedPositionEncoder(encoding_dim=4, max_seq_len=10, device=device)
+
+        x = torch.randn((4, 3, 9, 4), device=device)
+        padding_mask = torch.zeros((4, 3, 9), dtype=torch.bool, device=device)
+        padding_mask[..., 5:] = True
+
+        y = m(x, padding_mask=padding_mask)
+
+        assert y.shape == (4, 3, 9, 4)
+
 
 class TestRotaryEncoder:
     def test_init_raises_error_when_encoding_dim_is_odd(self) -> None:
@@ -337,3 +390,134 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None:
         y = m(x, padding_mask=None, state_bag=state_bag)
 
         assert y.shape == (5, 2, 32)
+
+    def test_forward_works_with_padding_mask(self) -> None:
+        m = RotaryEncoder(encoding_dim=4, max_seq_len=10, device=device)
+
+        x = torch.randn((3, 9, 4), device=device)
+        padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device)
+        padding_mask[:, 5:] = True
+
+        y = m(x, padding_mask=padding_mask)
+
+        assert y.shape == (3, 9, 4)
+
+    def test_forward_works_with_custom_freqs_init(self) -> None:
+        def custom_freqs_init(encoder: RotaryEncoder) -> torch.Tensor:
+            return torch.ones(encoder.encoding_dim // 2, device=device)
+
+        m = RotaryEncoder(
+            encoding_dim=4,
+            max_seq_len=10,
+            freqs_init_fn=custom_freqs_init,
+            device=device,
+        )
+
+        x = torch.randn((3, 9, 4), device=device)
+        y = m(x, padding_mask=None)
+
+        assert y.shape == (3, 9, 4)
+
+
+class TestSinusoidal2dPositionEncoder:
+    def test_init_raises_error_when_encoding_dim_is_odd(self) -> None:
+        with pytest.raises(
+            ValueError, match=r"^`encoding_dim` must be even, but is 13 instead\.$"
+        ):
+            Sinusoidal2dPositionEncoder(
+                encoding_dim=13, grid_dims=(10, 10), device=device
+            )
+
+    def test_forward_works(self) -> None:
+        m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device)
+
+        # Test with same dimensions as grid
+        x = torch.randn((2, 8, 8, 4), device=device)
+        y = m(x)
+
+        assert y.shape == (2, 8, 8, 4)
+        assert y.dtype == x.dtype
+
+        # Test with different dimensions (should trigger interpolation)
+        x = torch.randn((2, 16, 16, 4), device=device)
+        y = m(x)
+
+        assert y.shape == (2, 16, 16, 4)
+        assert y.dtype == x.dtype
+
+    def test_forward_raises_error_on_wrong_dims(self) -> None:
+        m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device)
+
+        # Test with wrong number of dimensions
+        x = torch.randn((2, 8, 4), device=device)
+
+        with pytest.raises(
+            ValueError,
+            match=r"^`x` must be 4 dimensional, but is 3 dimensional instead\.$",
+        ):
+            m(x)
+
+    def test_extra_repr_works(self) -> None:
+        m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device)
+
+        assert m.extra_repr() == "encoding_dim=4, grid_dims=(8, 8)"
+
+
+class TestSinusoidal3dPositionEncoder:
+    def test_init_raises_error_when_encoding_dim_is_odd(self) -> None:
+        with pytest.raises(
+            ValueError, match=r"^`encoding_dim` must be even, but is 13 instead\.$"
+        ):
+            Sinusoidal3dPositionEncoder(
+                encoding_dim=13, grid_dims=(8, 8, 8), device=device
+            )
+
+    def test_forward_works(self) -> None:
+        m = Sinusoidal3dPositionEncoder(
+            encoding_dim=6, grid_dims=(4, 4, 4), device=device
+        )
+
+        # Test with same dimensions as grid
+        x = torch.randn((2, 4, 4, 4, 6), device=device)
+        y = m(x)
+
+        assert y.shape == (2, 4, 4, 4, 6)
+        assert y.dtype == x.dtype
+
+        # Test with different dimensions (should trigger interpolation)
+        x = torch.randn((2, 8, 8, 8, 6), device=device)
+        y = m(x)
+
+        assert y.shape == (2, 8, 8, 8, 6)
+        assert y.dtype == x.dtype
+
+    def test_forward_works_with_uniform_power(self) -> None:
+        m = Sinusoidal3dPositionEncoder(
+            encoding_dim=6, grid_dims=(4, 4, 4), uniform_power=True, device=device
+        )
+
+        x = torch.randn((2, 4, 4, 4, 6), device=device)
+        y = m(x)
+
+        assert y.shape == (2, 4, 4, 4, 6)
+
+    def test_forward_raises_error_on_wrong_dims(self) -> None:
+        m = Sinusoidal3dPositionEncoder(
+            encoding_dim=6, grid_dims=(4, 4, 4), device=device
+        )
+
+        # Test with wrong number of dimensions
+        x = torch.randn((2, 4, 4, 6), device=device)
+
+        with pytest.raises(
+            ValueError,
+            match=r"^`x` must be 5 dimensional, but is 4 dimensional instead\.$",
+        ):
+            m(x)
+
+    def test_extra_repr_works(self) -> None:
+        m = Sinusoidal3dPositionEncoder(
+            encoding_dim=6, grid_dims=(4, 4, 4), device=device
+        )
+
+        assert m.extra_repr() == "encoding_dim=6, grid_dims=(4, 4, 4)"