Skip to content

Commit

Permalink
enhance unit tests for position encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
zyaoj committed Dec 30, 2024
1 parent 47fd523 commit 0840024
Showing 1 changed file with 184 additions and 0 deletions.
184 changes: 184 additions & 0 deletions tests/unit/nn/test_position_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
IncrementalStateBag,
LearnedPositionEncoder,
RotaryEncoder,
Sinusoidal2dPositionEncoder,
Sinusoidal3dPositionEncoder,
SinusoidalPositionEncoder,
)
from fairseq2.utils.rng import temporary_manual_seed
Expand Down Expand Up @@ -172,6 +174,35 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None:

assert y.shape == (5, 2, 32)

def test_forward_works_with_padding_mask(self) -> None:
m = SinusoidalPositionEncoder(
encoding_dim=4, max_seq_len=10, _legacy_pad_idx=-1, device=device
)

x = torch.randn((3, 9, 4), device=device)
padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device)
padding_mask[:, 5:] = True

y = m(x, padding_mask=padding_mask)

assert y.shape == (3, 9, 4)

def test_forward_works_with_multiple_batch_dims_and_padding(self) -> None:
m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=10, device=device)

x = torch.randn((4, 3, 9, 4), device=device)
padding_mask = torch.zeros((4, 3, 9), dtype=torch.bool, device=device)
padding_mask[..., 5:] = True

y = m(x, padding_mask=padding_mask)

assert y.shape == (4, 3, 9, 4)

def test_extra_repr_works(self) -> None:
m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=10, device=device)

assert m.extra_repr() == "encoding_dim=4, max_seq_len=10"


class TestLearnedPositionEncoder:
def test_init_works(self) -> None:
Expand Down Expand Up @@ -249,6 +280,28 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None:

assert y.shape == (5, 2, 32)

def test_forward_works_with_padding_mask(self) -> None:
m = LearnedPositionEncoder(encoding_dim=4, max_seq_len=10, device=device)

x = torch.randn((3, 9, 4), device=device)
padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device)
padding_mask[:, 5:] = True

y = m(x, padding_mask=padding_mask)

assert y.shape == (3, 9, 4)

def test_forward_works_with_multiple_batch_dims_and_padding(self) -> None:
m = LearnedPositionEncoder(encoding_dim=4, max_seq_len=10, device=device)

x = torch.randn((4, 3, 9, 4), device=device)
padding_mask = torch.zeros((4, 3, 9), dtype=torch.bool, device=device)
padding_mask[..., 5:] = True

y = m(x, padding_mask=padding_mask)

assert y.shape == (4, 3, 9, 4)


class TestRotaryEncoder:
def test_init_raises_error_when_encoding_dim_is_odd(self) -> None:
Expand Down Expand Up @@ -337,3 +390,134 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None:
y = m(x, padding_mask=None, state_bag=state_bag)

assert y.shape == (5, 2, 32)

def test_forward_works_with_padding_mask(self) -> None:
m = RotaryEncoder(encoding_dim=4, max_seq_len=10, device=device)

x = torch.randn((3, 9, 4), device=device)
padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device)
padding_mask[:, 5:] = True

y = m(x, padding_mask=padding_mask)

assert y.shape == (3, 9, 4)

def test_forward_works_with_custom_freqs_init(self) -> None:
def custom_freqs_init(encoder: RotaryEncoder) -> torch.Tensor:
return torch.ones(encoder.encoding_dim // 2, device=device)

m = RotaryEncoder(
encoding_dim=4,
max_seq_len=10,
freqs_init_fn=custom_freqs_init,
device=device,
)

x = torch.randn((3, 9, 4), device=device)
y = m(x, padding_mask=None)

assert y.shape == (3, 9, 4)


class TestSinusoidal2dPositionEncoder:
def test_init_raises_error_when_encoding_dim_is_odd(self) -> None:
with pytest.raises(
ValueError, match=r"^`encoding_dim` must be even, but is 13 instead\.$"
):
Sinusoidal2dPositionEncoder(
encoding_dim=13, grid_dims=(10, 10), device=device
)

def test_forward_works(self) -> None:
m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device)

# Test with same dimensions as grid
x = torch.randn((2, 8, 8, 4), device=device)
y = m(x)

assert y.shape == (2, 8, 8, 4)
assert y.dtype == x.dtype

# Test with different dimensions (should trigger interpolation)
x = torch.randn((2, 16, 16, 4), device=device)
y = m(x)

assert y.shape == (2, 16, 16, 4)
assert y.dtype == x.dtype

def test_forward_raises_error_on_wrong_dims(self) -> None:
m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device)

# Test with wrong number of dimensions
x = torch.randn((2, 8, 4), device=device)

with pytest.raises(
ValueError,
match=r"^`x` must be 4 dimensional, but is 3 dimensional instead\.$",
):
m(x)

def test_extra_repr_works(self) -> None:
m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device)

assert m.extra_repr() == "encoding_dim=4, grid_dims=(8, 8)"


class TestSinusoidal3dPositionEncoder:
def test_init_raises_error_when_encoding_dim_is_odd(self) -> None:
with pytest.raises(
ValueError, match=r"^`encoding_dim` must be even, but is 13 instead\.$"
):
Sinusoidal3dPositionEncoder(
encoding_dim=13, grid_dims=(8, 8, 8), device=device
)

def test_forward_works(self) -> None:
m = Sinusoidal3dPositionEncoder(
encoding_dim=6, grid_dims=(4, 4, 4), device=device
)

# Test with same dimensions as grid
x = torch.randn((2, 4, 4, 4, 6), device=device)
y = m(x)

assert y.shape == (2, 4, 4, 4, 6)
assert y.dtype == x.dtype

# Test with different dimensions (should trigger interpolation)
x = torch.randn((2, 8, 8, 8, 6), device=device)
y = m(x)

assert y.shape == (2, 8, 8, 8, 6)
assert y.dtype == x.dtype

def test_forward_works_with_uniform_power(self) -> None:
m = Sinusoidal3dPositionEncoder(
encoding_dim=6, grid_dims=(4, 4, 4), uniform_power=True, device=device
)

x = torch.randn((2, 4, 4, 4, 6), device=device)
y = m(x)

assert y.shape == (2, 4, 4, 4, 6)

def test_forward_raises_error_on_wrong_dims(self) -> None:
m = Sinusoidal3dPositionEncoder(
encoding_dim=6, grid_dims=(4, 4, 4), device=device
)

# Test with wrong number of dimensions
x = torch.randn((2, 4, 4, 6), device=device)

with pytest.raises(
ValueError,
match=r"^`x` must be 5 dimensional, but is 4 dimensional instead\.$",
):
m(x)

def test_extra_repr_works(self) -> None:
m = Sinusoidal3dPositionEncoder(
encoding_dim=6, grid_dims=(4, 4, 4), device=device
)

assert m.extra_repr() == "encoding_dim=6, grid_dims=(4, 4, 4)"

0 comments on commit 0840024

Please sign in to comment.