Skip to content

Commit

Permalink
PR #17: add unittests for alignment, correlation, and affine libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
McHaillet authored Nov 13, 2024
2 parents ae175a0 + 1dbfb82 commit dda6553
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 39 deletions.
23 changes: 19 additions & 4 deletions src/tttsa/affine/affine_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def affine_transform_2d(
if images.dim() == 2
else images
)
if samples.shape[0] != grid_sample_coordinates.shape[0]:
raise ValueError(
"Provide either an equal batch of images and matrices or "
"multiple matrices for a single image."
)
transformed = einops.rearrange(
F.grid_sample(
einops.rearrange(samples, "... h w -> ... 1 h w"),
Expand Down Expand Up @@ -90,18 +95,28 @@ def affine_transform_3d(
grid = einops.rearrange(grid, "d h w coords -> 1 d h w coords 1")
M = einops.rearrange(
torch.linalg.inv(affine_matrices), # invert so that each grid cell points
"... i j -> ... 1 1 i j", # to where it needs to get data from
"... i j -> ... 1 1 1 i j", # to where it needs to get data from
).to(device)
grid = M @ grid
grid = einops.rearrange(grid, "... d h w coords 1 -> ... d h w coords")[
..., :3
].contiguous()
grid_sample_coordinates = array_to_grid_sample(grid, images.shape[-3:])
if images.dim() == 3: # needed for grid sample
images = einops.repeat(images, "d h w -> n d h w", n=M.shape[0])
samples = (
einops.repeat( # needed for grid sample
images, "d h w -> n d h w", n=M.shape[0]
)
if images.dim() == 3
else images
)
if samples.shape[0] != grid_sample_coordinates.shape[0]:
raise ValueError(
"Provide either an equal batch of images and matrices or "
"multiple matrices for a single image."
)
transformed = einops.rearrange(
F.grid_sample(
einops.rearrange(images, "... d h w -> ... 1 d h w"),
einops.rearrange(samples, "... d h w -> ... 1 d h w"),
grid_sample_coordinates,
align_corners=True,
mode=interpolation,
Expand Down
8 changes: 8 additions & 0 deletions src/tttsa/alignment/find_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def find_image_shift(
correlation = correlate_2d(image_a, image_b, normalize=True)
maximum_idx = torch.unravel_index(correlation.argmax().cpu(), shape=image_a.shape)
y, x = maximum_idx
# Ensure that the max index is not on the border
if (
y == 0
or y == correlation.shape[0] - 1
or x == 0
or x == correlation.shape[1] - 1
):
return torch.tensor([float(y), float(x)]) - center
# Parabolic interpolation in the y direction
f_y0 = correlation[y - 1, x]
f_y1 = correlation[y, x]
Expand Down
34 changes: 0 additions & 34 deletions src/tttsa/alignment/tests/test_find_shift.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/tttsa/correlation/correlate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def correlate_2d(
# AreTomo using some like this (filtered FFT-based approach):
# result = result / torch.sqrt(result.abs() + .0001)
# result = bfactor_dft(result, 300, (result.shape[-2], ) * 2, 1, True)
result = torch.fft.irfftn(result, dim=(-2, -1), s=a.shape)
result = torch.fft.irfftn(result, dim=(-2, -1), s=a.shape[-2:])
result = torch.real(torch.fft.ifftshift(result, dim=(-2, -1)))
if normalize is True:
result = result / (h * w)
Expand Down
67 changes: 67 additions & 0 deletions tests/affine/test_affine_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest
import torch

from tttsa.affine import affine_transform_2d, affine_transform_3d, stretch_image
from tttsa.transformations import R_2d, Rz


def test_stretch_image():
a = torch.zeros((5, 5))
b = stretch_image(a, 1.1, -85)
assert a.shape == b.shape


def test_affine_transform_2d():
m1 = R_2d(torch.tensor(45.0))
m2 = R_2d(torch.randn(3))

# with a single image
a = torch.zeros((4, 5))
b = affine_transform_2d(a, m1)
assert a.shape == b.shape
b = affine_transform_2d(a, m1, (5, 4))
assert b.shape == (5, 4)
b = affine_transform_2d(a, m2)
assert b.shape == (3, 4, 5)
b = affine_transform_2d(a, m2, (5, 4))
assert b.shape == (3, 5, 4)

# with a batch of images
a = torch.zeros((3, 4, 5))
b = affine_transform_2d(a, m2)
assert a.shape == b.shape
b = affine_transform_2d(a, m2, (5, 4))
assert b.shape == (3, 5, 4)
a = torch.zeros((2, 4, 5))
with pytest.raises(ValueError):
affine_transform_2d(a, m1)
with pytest.raises(ValueError):
affine_transform_2d(a, m2)


def test_affine_transform_3d():
m1 = Rz(torch.tensor(45.0))
m2 = Rz(torch.randn(3))

# with a single image
a = torch.zeros((3, 4, 5))
b = affine_transform_3d(a, m1)
assert a.shape == b.shape
b = affine_transform_3d(a, m1, (5, 4, 3))
assert b.shape == (5, 4, 3)
b = affine_transform_3d(a, m2)
assert b.shape == (3, 3, 4, 5)
b = affine_transform_3d(a, m2, (5, 4, 3))
assert b.shape == (3, 5, 4, 3)

# with a batch of images
a = torch.zeros((3, 3, 4, 5))
b = affine_transform_3d(a, m2)
assert a.shape == b.shape
b = affine_transform_3d(a, m2, (5, 4, 3))
assert b.shape == (3, 5, 4, 3)
a = torch.zeros((2, 3, 4, 5))
with pytest.raises(ValueError):
affine_transform_3d(a, m1)
with pytest.raises(ValueError):
affine_transform_3d(a, m2)
30 changes: 30 additions & 0 deletions tests/alignment/test_find_shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
import torch

from tttsa.alignment import find_image_shift


def test_find_image_shift():
a = torch.zeros((4, 4))
a[0, 0] = 1
b = torch.zeros((4, 4))
b[2, 2] = 0.7
b[2, 3] = 0.3
shift = find_image_shift(a, b)
print(shift)
assert shift.dtype == torch.float32
assert torch.all(shift == -2.0), (
"Interpolating a shift too close to a border is "
"not possible, so an integer shift should be "
"returned."
)
a = torch.zeros((8, 8))
a[3, 3] = 1
b = torch.zeros((8, 8))
b[4, 4] = 0.7
b[4, 5] = 0.3
shift = find_image_shift(a, b)
# values should interpolated with floating point precision
assert shift.dtype == torch.float32
assert shift[0] == pytest.approx(-1.1, 0.1)
assert shift[1] == pytest.approx(-1.2, 0.1)
14 changes: 14 additions & 0 deletions tests/test_correlate.py → tests/correlation/test_correlate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,17 @@ def test_correlate_2d():
assert torch.allclose(shift, torch.tensor([-1, -1]))
assert torch.allclose(cross_correlation[peak_position], torch.tensor([1.0]))
assert cross_correlation.shape == a.shape


def test_correlate_2d_stacks():
# test for stacks of images
a = torch.zeros((3, 10, 11))
a[:, 5, 5] = 1
b = torch.zeros((1, 10, 11))
b[:, 6, 6] = 1
cross_correlation = correlate_2d(a, b, normalize=True)
assert cross_correlation.shape == a.shape
cross_correlation = correlate_2d(b, a, normalize=True)
assert cross_correlation.shape == a.shape
cross_correlation = correlate_2d(a, a, normalize=True)
assert cross_correlation.shape == a.shape

0 comments on commit dda6553

Please sign in to comment.