diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 637f0873f1..74df003397 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1216,6 +1216,12 @@ Utility :members: :special-members: __call__ +`ApplyTransformToPoints` +"""""""""""""""""""""""" +.. autoclass:: ApplyTransformToPoints + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -2265,6 +2271,12 @@ Utility (Dict) :members: :special-members: __call__ +`ApplyTransformToPointsd` +""""""""""""""""""""""""" +.. autoclass:: ApplyTransformToPointsd + :members: + :special-members: __call__ + MetaTensor ^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9548443768..6fe0e62f49 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -493,6 +493,7 @@ from .utility.array import ( AddCoordinateChannels, AddExtremePointsChannel, + ApplyTransformToPoints, AsChannelLast, CastToType, ClassesToIndices, @@ -532,6 +533,9 @@ AddExtremePointsChanneld, AddExtremePointsChannelD, AddExtremePointsChannelDict, + ApplyTransformToPointsd, + ApplyTransformToPointsD, + ApplyTransformToPointsDict, AsChannelLastd, AsChannelLastD, AsChannelLastDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 5dfbcb0e91..fee546bea3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,7 +31,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor -from monai.data.utils import is_no_channel, no_collation +from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps from monai.networks.layers.simplelayers import ( ApplyFilter, EllipticalFilter, @@ -42,16 +42,17 @@ SharpenFilter, median_filter, ) -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, TraceableTransform from monai.transforms.traits import MultiSampleTrait from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform from monai.transforms.utils import ( + apply_affine_to_points, extreme_points_to_image, get_extreme_points, map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices +from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices from monai.utils import ( MetaKeys, TraceKeys, @@ -66,7 +67,7 @@ ) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype +from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -106,6 +107,7 @@ "ToCupy", "ImageFilter", "RandImageFilter", + "ApplyTransformToPoints", ] @@ -1715,3 +1717,133 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None = None) -> Nd if self._do_transform: img = self.filter(img) return img + + +class ApplyTransformToPoints(InvertibleTransform, Transform): + """ + Transform points between image coordinates and world coordinates. + The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels + and N denotes the number of points. It will return a tensor with the same shape as the input. + + Args: + dtype: The desired data type for the output. + affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates + from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary + Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when + applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly. + The matrix is always converted to float64 for computation, which can be computationally + expensive when applied to a large number of points. + If None, will try to use the affine matrix from the input data. + invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``. + Typically, the affine matrix is derived from an image and represents its location in world space, + while the points are in world coordinates. A value of ``True`` represents transforming these + world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation. + affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system + or you're using `ITKReader` with `affine_lps_to_ras=True`. + This ensures the correct application of the affine transformation between LPS (left-posterior-superior) + and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine + matrix are in the same coordinate system. + + Use Cases: + - Transforming points between world space and image space, and vice versa. + - Automatically handling inverse transformations between image space and world space. + - If points have an existing affine transformation, the class computes and + applies the required delta affine transformation. + + """ + + def __init__( + self, + dtype: DtypeLike | torch.dtype | None = None, + affine: torch.Tensor | None = None, + invert_affine: bool = True, + affine_lps_to_ras: bool = False, + ) -> None: + self.dtype = dtype + self.affine = affine + self.invert_affine = invert_affine + self.affine_lps_to_ras = affine_lps_to_ras + + def transform_coordinates( + self, data: torch.Tensor, affine: torch.Tensor | None = None + ) -> tuple[torch.Tensor, dict]: + """ + Transform coordinates using an affine transformation matrix. + + Args: + data: The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation, + which can be computationally expensive when applied to a large number of points. + + Returns: + Transformed coordinates. + """ + data = convert_to_tensor(data, track_meta=get_track_meta()) + # applied_affine is the affine transformation matrix that has already been applied to the point data + applied_affine = getattr(data, "affine", None) + + if affine is None and self.invert_affine: + raise ValueError("affine must be provided when invert_affine is True.") + + affine = applied_affine if affine is None else affine + affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine + original_affine: torch.Tensor = affine + if self.affine_lps_to_ras: + affine = orientation_ras_lps(affine) + + # the final affine transformation matrix that will be applied to the point data + _affine: torch.Tensor = affine + if self.invert_affine: + _affine = linalg_inv(affine) + if applied_affine is not None: + # consider the affine transformation already applied to the data in the world space + # and compute delta affine + _affine = _affine @ linalg_inv(applied_affine) + out = apply_affine_to_points(data, _affine, dtype=self.dtype) + + extra_info = { + "invert_affine": self.invert_affine, + "dtype": get_dtype_string(self.dtype), + "image_affine": original_affine, # record for inverse operation + "affine_lps_to_ras": self.affine_lps_to_ras, + } + xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine) + meta_info = TraceableTransform.track_transform_meta( + data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info() + ) + + return out, meta_info + + def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None): + """ + Args: + data: The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``. + """ + if data.ndim != 3 or data.shape[-1] not in (2, 3): + raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.") + affine = self.affine if affine is None else affine + if affine is not None and affine.shape not in ((3, 3), (4, 4)): + raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.") + + out, meta_info = self.transform_coordinates(data, affine) + + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(data) + # Create inverse transform + dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] + invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"] + affine = transform[TraceKeys.EXTRA_INFO]["image_affine"] + affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"] + inverse_transform = ApplyTransformToPoints( + dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + # Apply inverse + with inverse_transform.trace_transform(False): + data = inverse_transform(data, affine) + + return data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7e3a7b0454..df63044793 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -35,6 +35,7 @@ from monai.transforms.utility.array import ( AddCoordinateChannels, AddExtremePointsChannel, + ApplyTransformToPoints, AsChannelLast, CastToType, ClassesToIndices, @@ -180,6 +181,9 @@ "ClassesToIndicesd", "ClassesToIndicesD", "ClassesToIndicesDict", + "ApplyTransformToPointsd", + "ApplyTransformToPointsD", + "ApplyTransformToPointsDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -1740,6 +1744,75 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ApplyTransformToPointsd(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`. + The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + The output has the same shape as the input. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + refer_key: The key of the reference item used for transformation. + It can directly refer to an affine or an image from which the affine can be derived. + dtype: The desired data type for the output. + affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates + from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary + Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when + applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly. + The matrix is always converted to float64 for computation, which can be computationally + expensive when applied to a large number of points. + If None, will try to use the affine matrix from the refer data. + invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``. + Typically, the affine matrix is derived from the image, while the points are in world coordinates. + If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``. + affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system + or you're using `ITKReader` with `affine_lps_to_ras=True`. + This ensures the correct application of the affine transformation between LPS (left-posterior-superior) + and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine + matrix are in the same coordinate system. + allow_missing_keys: Don't raise exception if key is missing. + """ + + def __init__( + self, + keys: KeysCollection, + refer_key: str | None = None, + dtype: DtypeLike | torch.dtype = torch.float64, + affine: torch.Tensor | None = None, + invert_affine: bool = True, + affine_lps_to_ras: bool = False, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + self.refer_key = refer_key + self.converter = ApplyTransformToPoints( + dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]): + d = dict(data) + if self.refer_key is not None: + if self.refer_key in d: + refer_data = d[self.refer_key] + else: + raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.") + else: + refer_data = None + affine = getattr(refer_data, "affine", refer_data) + for key in self.key_iterator(d): + coords = d[key] + d[key] = self.converter(coords, affine) + return d + + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter.inverse(d[key]) + return d + + RandImageFilterD = RandImageFilterDict = RandImageFilterd ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd @@ -1780,3 +1853,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N RandCuCIMD = RandCuCIMDict = RandCuCIMd AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd +ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 363fce91be..a82dced252 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -27,6 +27,7 @@ import monai from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.data.utils import to_affine_nd from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose @@ -35,6 +36,7 @@ from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, ascontiguousarray, + concatenate, cumsum, isfinite, nonzero, @@ -2509,5 +2511,26 @@ def distance_transform_edt( return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] +def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None): + """ + apply affine transformation to a set of points. + + Args: + data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4). + dtype: output data dtype. + """ + data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64) + affine = to_affine_nd(data_.shape[-1], affine) + + homogeneous: torch.Tensor = concatenate((data_, torch.ones((data_.shape[0], data_.shape[1], 1))), axis=2) # type: ignore + transformed_homogeneous = torch.matmul(homogeneous, affine.T) + transformed_coordinates = transformed_homogeneous[:, :, :-1] + out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype) + + return out + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 03fa1ceed1..4e36e3cd47 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -148,6 +148,7 @@ dtype_numpy_to_torch, dtype_torch_to_numpy, get_dtype, + get_dtype_string, get_equivalent_dtype, get_numpy_dtype_from_string, get_torch_dtype_from_string, diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index e4f97fc4a6..420e935b33 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -33,6 +33,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "get_dtype_string", "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", @@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype: return type(data) +def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str: + """Get a string representation of the dtype.""" + if isinstance(dtype, torch.dtype): + return str(dtype)[6:] + return str(dtype)[3:] + + def convert_to_tensor( data: Any, dtype: DtypeLike | torch.dtype = None, diff --git a/tests/test_apply_transform_to_points.py b/tests/test_apply_transform_to_points.py new file mode 100644 index 0000000000..bd5fb9e393 --- /dev/null +++ b/tests/test_apply_transform_to_points.py @@ -0,0 +1,121 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.utility.array import ApplyTransformToPoints +from monai.utils import set_determinism + +set_determinism(seed=0) + +DATA_2D = torch.rand(1, 64, 64) +DATA_3D = torch.rand(1, 64, 64, 64) +POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]]) +POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]]) +POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]]) +POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) +POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) +POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) + +TEST_CASES = [ + [ + MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_2D_WORLD, + None, + True, + False, + POINT_2D_IMAGE, + ], + [ + None, + MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + None, + False, + False, + POINT_2D_WORLD, + ], + [ + None, + MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), + False, + False, + POINT_2D_WORLD, + ], + [ + MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_2D_WORLD, + None, + True, + True, + POINT_2D_IMAGE_RAS, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_3D_WORLD, + None, + True, + False, + POINT_3D_IMAGE, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + None, + False, + False, + POINT_3D_WORLD, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_3D_WORLD, + None, + True, + True, + POINT_3D_IMAGE_RAS, + ], +] + +TEST_CASES_WRONG = [ + [POINT_2D_WORLD, True, None], + [POINT_2D_WORLD.unsqueeze(0), False, None], + [POINT_3D_WORLD[..., 0:1], False, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])], +] + + +class TestCoordinateTransform(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output): + transform = ApplyTransformToPoints( + dtype=torch.int64, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + affine = image.affine if image is not None else None + output = transform(points, affine) + self.assertTrue(torch.allclose(output, expected_output)) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out, points)) + + @parameterized.expand(TEST_CASES_WRONG) + def test_wrong_input(self, input, invert_affine, affine): + transform = ApplyTransformToPoints(dtype=torch.int64, invert_affine=invert_affine) + with self.assertRaises(ValueError): + transform(input, affine) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py new file mode 100644 index 0000000000..4cedfa9d66 --- /dev/null +++ b/tests/test_apply_transform_to_pointsd.py @@ -0,0 +1,133 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.utility.dictionary import ApplyTransformToPointsd +from monai.utils import set_determinism + +set_determinism(seed=0) + +DATA_2D = torch.rand(1, 64, 64) +DATA_3D = torch.rand(1, 64, 64, 64) +POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]]) +POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]]) +POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]]) +POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) +POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) +POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) + +TEST_CASES = [ + [ + MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_2D_WORLD, + None, + True, + False, + POINT_2D_IMAGE, + ], + [ + None, + MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + None, + False, + False, + POINT_2D_WORLD, + ], + [ + None, + MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), + False, + False, + POINT_2D_WORLD, + ], + [ + MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_2D_WORLD, + None, + True, + True, + POINT_2D_IMAGE_RAS, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_3D_WORLD, + None, + True, + False, + POINT_3D_IMAGE, + ], + ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + None, + False, + False, + POINT_3D_WORLD, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_3D_WORLD, + None, + True, + True, + POINT_3D_IMAGE_RAS, + ], +] + +TEST_CASES_WRONG = [ + [POINT_2D_WORLD, True, None], + [POINT_2D_WORLD.unsqueeze(0), False, None], + [POINT_3D_WORLD[..., 0:1], False, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])], +] + + +class TestCoordinateTransform(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output): + data = { + "image": image, + "point": points, + "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]), + } + refer_key = "image" if (image is not None and image != "affine") else image + transform = ApplyTransformToPointsd( + keys="point", + refer_key=refer_key, + dtype=torch.int64, + affine=affine, + invert_affine=invert_affine, + affine_lps_to_ras=affine_lps_to_ras, + ) + output = transform(data) + + self.assertTrue(torch.allclose(output["point"], expected_output)) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out["point"], points)) + + @parameterized.expand(TEST_CASES_WRONG) + def test_wrong_input(self, input, invert_affine, affine): + transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine) + with self.assertRaises(ValueError): + transform({"point": input}) + + +if __name__ == "__main__": + unittest.main()