Skip to content

Commit

Permalink
feat: start on tests for affine transform 2d
Browse files Browse the repository at this point in the history
  • Loading branch information
McHaillet committed Nov 9, 2024
1 parent c585074 commit e46362e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/tttsa/affine/affine_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,16 @@ def affine_transform_3d(
..., :3
].contiguous()
grid_sample_coordinates = array_to_grid_sample(grid, images.shape[-3:])
if images.dim() == 3: # needed for grid sample
images = einops.repeat(images, "d h w -> n d h w", n=M.shape[0])
samples = (
einops.repeat( # needed for grid sample
images, "d h w -> n d h w", n=M.shape[0]
)
if images.dim() == 3
else images
)
transformed = einops.rearrange(
F.grid_sample(
einops.rearrange(images, "... d h w -> ... 1 d h w"),
einops.rearrange(samples, "... d h w -> ... 1 d h w"),
grid_sample_coordinates,
align_corners=True,
mode=interpolation,
Expand Down
37 changes: 37 additions & 0 deletions tests/affine/test_affine_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import torch

from tttsa.affine import affine_transform_2d, stretch_image
from tttsa.transformations import R_2d


def test_stretch_image():
a = torch.zeros((5, 5))
b = stretch_image(a, 1.1, -85)
assert a.shape == b.shape


def test_affine_transform_2d():
a = torch.zeros((4, 5))
m1 = R_2d(torch.tensor(45.0))
b = affine_transform_2d(a, m1)
assert a.shape == b.shape
b = affine_transform_2d(a, m1, (5, 4))
assert b.shape == (5, 4)
m2 = R_2d(torch.randn(3))
b = affine_transform_2d(a, m2)
assert b.shape == (3, 4, 5)
b = affine_transform_2d(a, m2, (5, 4))
assert b.shape == (3, 5, 4)
a = torch.zeros((3, 4, 5))
b = affine_transform_2d(a, m2)
assert a.shape == b.shape
a = torch.zeros((2, 4, 5))
b = affine_transform_2d(a, m1)
assert a.shape == b.shape
with pytest.raises(RuntimeError):
affine_transform_2d(a, m2)


def test_affine_transform_3d():
pass

0 comments on commit e46362e

Please sign in to comment.