feat: add tests for 3d affine
McHaillet committed Nov 13, 2024
1 parent e46362e commit 1dbfb82
Showing 2 changed files with 49 additions and 9 deletions.
12 changes: 11 additions & 1 deletion src/tttsa/affine/affine_transform.py
@@ -62,6 +62,11 @@ def affine_transform_2d(
         if images.dim() == 2
         else images
     )
+    if samples.shape[0] != grid_sample_coordinates.shape[0]:
+        raise ValueError(
+            "Provide either an equal batch of images and matrices or "
+            "multiple matrices for a single image."
+        )
     transformed = einops.rearrange(
         F.grid_sample(
             einops.rearrange(samples, "... h w -> ... 1 h w"),
@@ -90,7 +95,7 @@ def affine_transform_3d(
     grid = einops.rearrange(grid, "d h w coords -> 1 d h w coords 1")
     M = einops.rearrange(
         torch.linalg.inv(affine_matrices),  # invert so that each grid cell points
-        "... i j -> ... 1 1 i j",  # to where it needs to get data from
+        "... i j -> ... 1 1 1 i j",  # to where it needs to get data from
     ).to(device)
     grid = M @ grid
     grid = einops.rearrange(grid, "... d h w coords 1 -> ... d h w coords")[
@@ -104,6 +109,11 @@
         if images.dim() == 3
         else images
     )
+    if samples.shape[0] != grid_sample_coordinates.shape[0]:
+        raise ValueError(
+            "Provide either an equal batch of images and matrices or "
+            "multiple matrices for a single image."
+        )
     transformed = einops.rearrange(
         F.grid_sample(
             einops.rearrange(samples, "... d h w -> ... 1 d h w"),
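
Note on the two changes above: the 3D sampling grid has three spatial axes (d, h, w), so the inverted affine matrices need three singleton axes to broadcast over the grid before the matrix product, not the two that suffice in 2D; and the new shape check reports a batch mismatch as an explicit ValueError instead of letting F.grid_sample fail later. The following sketch is illustrative only, with shapes assumed from the diff rather than copied from the repository:

# Illustrative sketch only; shapes assumed from the diff, not taken from tttsa itself.
import einops
import torch

d, h, w = 3, 4, 5
grid = torch.rand(1, d, h, w, 4, 1)        # homogeneous voxel coordinates with a leading batch axis
matrices = torch.eye(4).repeat(2, 1, 1)    # a batch of two 3D affine matrices

# old pattern: two singleton axes, so the matrix batch collides with the depth axis
M_old = einops.rearrange(matrices, "... i j -> ... 1 1 i j")    # (2, 1, 1, 4, 4)
# M_old @ grid  # raises a broadcasting RuntimeError

# new pattern: one singleton axis per spatial dimension (d, h, w)
M_new = einops.rearrange(matrices, "... i j -> ... 1 1 1 i j")  # (2, 1, 1, 1, 4, 4)
out = M_new @ grid
print(out.shape)  # torch.Size([2, 3, 4, 5, 4, 1])
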
46 changes: 38 additions & 8 deletions tests/affine/test_affine_transform.py
@@ -1,8 +1,8 @@
 import pytest
 import torch
 
-from tttsa.affine import affine_transform_2d, stretch_image
-from tttsa.transformations import R_2d
+from tttsa.affine import affine_transform_2d, affine_transform_3d, stretch_image
+from tttsa.transformations import R_2d, Rz
 
 
 def test_stretch_image():
@@ -12,26 +12,56 @@ def test_stretch_image():
 
 
 def test_affine_transform_2d():
-    a = torch.zeros((4, 5))
     m1 = R_2d(torch.tensor(45.0))
+    m2 = R_2d(torch.randn(3))
+
+    # with a single image
+    a = torch.zeros((4, 5))
     b = affine_transform_2d(a, m1)
     assert a.shape == b.shape
     b = affine_transform_2d(a, m1, (5, 4))
     assert b.shape == (5, 4)
-    m2 = R_2d(torch.randn(3))
     b = affine_transform_2d(a, m2)
     assert b.shape == (3, 4, 5)
     b = affine_transform_2d(a, m2, (5, 4))
     assert b.shape == (3, 5, 4)
+
+    # with a batch of images
+    a = torch.zeros((3, 4, 5))
+    b = affine_transform_2d(a, m2)
+    assert a.shape == b.shape
+    b = affine_transform_2d(a, m2, (5, 4))
+    assert b.shape == (3, 5, 4)
     a = torch.zeros((2, 4, 5))
-    b = affine_transform_2d(a, m1)
-    assert a.shape == b.shape
-    with pytest.raises(RuntimeError):
+    with pytest.raises(ValueError):
+        affine_transform_2d(a, m1)
+    with pytest.raises(ValueError):
         affine_transform_2d(a, m2)
 
 
 def test_affine_transform_3d():
-    pass
+    m1 = Rz(torch.tensor(45.0))
+    m2 = Rz(torch.randn(3))
+
+    # with a single image
+    a = torch.zeros((3, 4, 5))
+    b = affine_transform_3d(a, m1)
+    assert a.shape == b.shape
+    b = affine_transform_3d(a, m1, (5, 4, 3))
+    assert b.shape == (5, 4, 3)
+    b = affine_transform_3d(a, m2)
+    assert b.shape == (3, 3, 4, 5)
+    b = affine_transform_3d(a, m2, (5, 4, 3))
+    assert b.shape == (3, 5, 4, 3)
+
+    # with a batch of images
+    a = torch.zeros((3, 3, 4, 5))
+    b = affine_transform_3d(a, m2)
+    assert a.shape == b.shape
+    b = affine_transform_3d(a, m2, (5, 4, 3))
+    assert b.shape == (3, 5, 4, 3)
+    a = torch.zeros((2, 3, 4, 5))
+    with pytest.raises(ValueError):
+        affine_transform_3d(a, m1)
+    with pytest.raises(ValueError):
+        affine_transform_3d(a, m2)
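
As a hedged usage sketch (shapes assumed; imports taken from the test file above), the new guard surfaces to callers as a ValueError on a batch mismatch, while multiple matrices applied to a single volume remain allowed:

# Hedged usage sketch; relies only on names imported in the test file above.
import torch

from tttsa.affine import affine_transform_3d
from tttsa.transformations import Rz

volumes = torch.zeros((2, 3, 4, 5))  # batch of two volumes
matrices = Rz(torch.randn(3))        # batch of three rotation matrices

try:
    affine_transform_3d(volumes, matrices)  # 2 volumes vs. 3 matrices
except ValueError as err:
    print(err)  # "Provide either an equal batch of images and matrices or ..."

# allowed: multiple matrices applied to a single volume
volume = torch.zeros((3, 4, 5))
out = affine_transform_3d(volume, matrices)
print(out.shape)  # torch.Size([3, 3, 4, 5]), matching the assertions above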
