From 1dbfb82e95f902f9985d5ac4fc0082be0376c338 Mon Sep 17 00:00:00 2001
From: McHaillet
Date: Wed, 13 Nov 2024 13:16:42 +0100
Subject: [PATCH] feat: add tests for 3d affine

---
 src/tttsa/affine/affine_transform.py  | 12 ++++++-
 tests/affine/test_affine_transform.py | 46 ++++++++++++++++++++++-----
 2 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/src/tttsa/affine/affine_transform.py b/src/tttsa/affine/affine_transform.py
index 8718b80..0290dc6 100644
--- a/src/tttsa/affine/affine_transform.py
+++ b/src/tttsa/affine/affine_transform.py
@@ -62,6 +62,11 @@ def affine_transform_2d(
         if images.dim() == 2
         else images
     )
+    if samples.shape[0] != grid_sample_coordinates.shape[0]:
+        raise ValueError(
+            "Provide either an equal batch of images and matrices or "
+            "multiple matrices for a single image."
+        )
     transformed = einops.rearrange(
         F.grid_sample(
             einops.rearrange(samples, "... h w -> ... 1 h w"),
@@ -90,7 +95,7 @@ def affine_transform_3d(
     grid = einops.rearrange(grid, "d h w coords -> 1 d h w coords 1")
     M = einops.rearrange(
         torch.linalg.inv(affine_matrices),  # invert so that each grid cell points
-        "... i j -> ... 1 1 i j",  # to where it needs to get data from
+        "... i j -> ... 1 1 1 i j",  # to where it needs to get data from
     ).to(device)
     grid = M @ grid
     grid = einops.rearrange(grid, "... d h w coords 1 -> ... d h w coords")[
@@ -104,6 +109,11 @@
         if images.dim() == 3
         else images
     )
+    if samples.shape[0] != grid_sample_coordinates.shape[0]:
+        raise ValueError(
+            "Provide either an equal batch of images and matrices or "
+            "multiple matrices for a single image."
+        )
     transformed = einops.rearrange(
         F.grid_sample(
             einops.rearrange(samples, "... d h w -> ... 1 d h w"),
diff --git a/tests/affine/test_affine_transform.py b/tests/affine/test_affine_transform.py
index b80fc7e..dba9495 100644
--- a/tests/affine/test_affine_transform.py
+++ b/tests/affine/test_affine_transform.py
@@ -1,8 +1,8 @@
 import pytest
 import torch
 
-from tttsa.affine import affine_transform_2d, stretch_image
-from tttsa.transformations import R_2d
+from tttsa.affine import affine_transform_2d, affine_transform_3d, stretch_image
+from tttsa.transformations import R_2d, Rz
 
 
 def test_stretch_image():
@@ -12,26 +12,56 @@
 
 
 def test_affine_transform_2d():
-    a = torch.zeros((4, 5))
     m1 = R_2d(torch.tensor(45.0))
+    m2 = R_2d(torch.randn(3))
+
+    # with a single image
+    a = torch.zeros((4, 5))
     b = affine_transform_2d(a, m1)
     assert a.shape == b.shape
     b = affine_transform_2d(a, m1, (5, 4))
     assert b.shape == (5, 4)
-    m2 = R_2d(torch.randn(3))
     b = affine_transform_2d(a, m2)
     assert b.shape == (3, 4, 5)
     b = affine_transform_2d(a, m2, (5, 4))
     assert b.shape == (3, 5, 4)
+
+    # with a batch of images
     a = torch.zeros((3, 4, 5))
     b = affine_transform_2d(a, m2)
     assert a.shape == b.shape
+    b = affine_transform_2d(a, m2, (5, 4))
+    assert b.shape == (3, 5, 4)
     a = torch.zeros((2, 4, 5))
-    b = affine_transform_2d(a, m1)
-    assert a.shape == b.shape
-    with pytest.raises(RuntimeError):
+    with pytest.raises(ValueError):
+        affine_transform_2d(a, m1)
+    with pytest.raises(ValueError):
         affine_transform_2d(a, m2)
 
 
 def test_affine_transform_3d():
-    pass
+    m1 = Rz(torch.tensor(45.0))
+    m2 = Rz(torch.randn(3))
+
+    # with a single image
+    a = torch.zeros((3, 4, 5))
+    b = affine_transform_3d(a, m1)
+    assert a.shape == b.shape
+    b = affine_transform_3d(a, m1, (5, 4, 3))
+    assert b.shape == (5, 4, 3)
+    b = affine_transform_3d(a, m2)
+    assert b.shape == (3, 3, 4, 5)
+    b = affine_transform_3d(a, m2, (5, 4, 3))
+    assert b.shape == (3, 5, 4, 3)
+
+    # with a batch of images
+    a = torch.zeros((3, 3, 4, 5))
+    b = affine_transform_3d(a, m2)
+    assert a.shape == b.shape
+    b = affine_transform_3d(a, m2, (5, 4, 3))
+    assert b.shape == (3, 5, 4, 3)
+    a = torch.zeros((2, 3, 4, 5))
+    with pytest.raises(ValueError):
+        affine_transform_3d(a, m1)
+    with pytest.raises(ValueError):
+        affine_transform_3d(a, m2)
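
A minimal sketch (not part of the commit) of the batching contract the new ValueError checks enforce; all names come from the modules touched above, and the shapes mirror the tests:

    import torch

    from tttsa.affine import affine_transform_2d
    from tttsa.transformations import R_2d

    matrices = R_2d(torch.randn(3))  # a batch of three in-plane rotations

    # a single image may be paired with a batch of matrices -> (3, 4, 5)
    out = affine_transform_2d(torch.zeros((4, 5)), matrices)

    # matched batches of images and matrices also work -> (3, 4, 5)
    out = affine_transform_2d(torch.zeros((3, 4, 5)), matrices)

    # mismatched batches now fail fast with a ValueError, where the old
    # tests expected a RuntimeError from further down the call stack
    affine_transform_2d(torch.zeros((2, 4, 5)), matrices)  # raises ValueError

The same contract applies to affine_transform_3d with an extra leading depth dimension on the images.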