Image space and World Space convert transforms (#7942)
Fixes # .

### Description
Add `ApplyTransformToPoints` and `ApplyTransformToPointsd` to convert point coordinates between image space and world space.
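
A minimal usage sketch of the new array transform (illustrative only; the tensor values and variable names below are assumptions, not part of the diff):

```python
import torch
from monai.transforms import ApplyTransformToPoints

# one channel of two points in world coordinates, shape (C=1, N=2, 3)
points = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])

# an image affine with 2mm isotropic spacing and no rotation or translation
affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))

# world -> image space: the affine is inverted before application (the default)
world_to_image = ApplyTransformToPoints(dtype=torch.float64, affine=affine, invert_affine=True)
image_points = world_to_image(points)  # [[[0.5, 1.0, 1.5], [2.0, 2.5, 3.0]]]
```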

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: Mingxin Zheng <[email protected]>
4 people authored Aug 29, 2024
1 parent 244967d commit 298a8d6
Showing 9 changed files with 512 additions and 4 deletions.
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
@@ -1216,6 +1216,12 @@ Utility
:members:
:special-members: __call__

`ApplyTransformToPoints`
""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPoints
:members:
:special-members: __call__

Dictionary Transforms
---------------------

@@ -2265,6 +2271,12 @@ Utility (Dict)
:members:
:special-members: __call__

`ApplyTransformToPointsd`
"""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPointsd
:members:
:special-members: __call__


MetaTensor
^^^^^^^^^^
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
@@ -493,6 +493,7 @@
from .utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
@@ -532,6 +533,9 @@
AddExtremePointsChanneld,
AddExtremePointsChannelD,
AddExtremePointsChannelDict,
ApplyTransformToPointsd,
ApplyTransformToPointsD,
ApplyTransformToPointsDict,
AsChannelLastd,
AsChannelLastD,
AsChannelLastDict,
140 changes: 136 additions & 4 deletions monai/transforms/utility/array.py
@@ -31,7 +31,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import is_no_channel, no_collation
from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps
from monai.networks.layers.simplelayers import (
ApplyFilter,
EllipticalFilter,
@@ -42,16 +42,17 @@
SharpenFilter,
median_filter,
)
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.inverse import InvertibleTransform, TraceableTransform
from monai.transforms.traits import MultiSampleTrait
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
from monai.transforms.utils import (
apply_affine_to_points,
extreme_points_to_image,
get_extreme_points,
map_binary_to_indices,
map_classes_to_indices,
)
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices
from monai.utils import (
MetaKeys,
TraceKeys,
@@ -66,7 +67,7 @@
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
@@ -106,6 +107,7 @@
"ToCupy",
"ImageFilter",
"RandImageFilter",
"ApplyTransformToPoints",
]


@@ -1715,3 +1717,133 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None = None) -> NdarrayOrTensor:
if self._do_transform:
img = self.filter(img)
return img


class ApplyTransformToPoints(InvertibleTransform, Transform):
"""
Transform points between image coordinates and world coordinates.
The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels
and N denotes the number of points. It will return a tensor with the same shape as the input.
Args:
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when
applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.
The matrix is always converted to float64 for computation, which can be computationally
expensive when applied to a large number of points.
If None, will try to use the affine matrix from the input data.
invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
Typically, the affine matrix is derived from an image and represents its location in world space,
while the points are in world coordinates. A value of ``True`` represents transforming these
world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation.
affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system
or you're using `ITKReader` with `affine_lps_to_ras=True`.
This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
matrix are in the same coordinate system.
Use Cases:
- Transforming points between world space and image space, and vice versa.
- Automatically handling inverse transformations between image space and world space.
- If points have an existing affine transformation, the class computes and
applies the required delta affine transformation.
"""

def __init__(
self,
dtype: DtypeLike | torch.dtype | None = None,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
) -> None:
self.dtype = dtype
self.affine = affine
self.invert_affine = invert_affine
self.affine_lps_to_ras = affine_lps_to_ras

def transform_coordinates(
self, data: torch.Tensor, affine: torch.Tensor | None = None
) -> tuple[torch.Tensor, dict]:
"""
Transform coordinates using an affine transformation matrix.
Args:
data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
where C represents the number of channels and N denotes the number of points.
affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation,
which can be computationally expensive when applied to a large number of points.
Returns:
Transformed coordinates.
"""
data = convert_to_tensor(data, track_meta=get_track_meta())
# applied_affine is the affine transformation matrix that has already been applied to the point data
applied_affine = getattr(data, "affine", None)

if affine is None and self.invert_affine:
raise ValueError("affine must be provided when invert_affine is True.")

affine = applied_affine if affine is None else affine
affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine
original_affine: torch.Tensor = affine
if self.affine_lps_to_ras:
affine = orientation_ras_lps(affine)

# the final affine transformation matrix that will be applied to the point data
_affine: torch.Tensor = affine
if self.invert_affine:
_affine = linalg_inv(affine)
if applied_affine is not None:
# consider the affine transformation already applied to the data in the world space
# and compute delta affine
_affine = _affine @ linalg_inv(applied_affine)
out = apply_affine_to_points(data, _affine, dtype=self.dtype)

extra_info = {
"invert_affine": self.invert_affine,
"dtype": get_dtype_string(self.dtype),
"image_affine": original_affine, # record for inverse operation
"affine_lps_to_ras": self.affine_lps_to_ras,
}
xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine)
meta_info = TraceableTransform.track_transform_meta(
data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
)

return out, meta_info

def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):
"""
Args:
data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
where C represents the number of channels and N denotes the number of points.
affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``.
"""
if data.ndim != 3 or data.shape[-1] not in (2, 3):
raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.")
affine = self.affine if affine is None else affine
if affine is not None and affine.shape not in ((3, 3), (4, 4)):
raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.")

out, meta_info = self.transform_coordinates(data, affine)

return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out

def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
# Create inverse transform
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"]
affine = transform[TraceKeys.EXTRA_INFO]["image_affine"]
affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"]
inverse_transform = ApplyTransformToPoints(
dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)
# Apply inverse
with inverse_transform.trace_transform(False):
data = inverse_transform(data, affine)

return data
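
A hedged sketch of the round trip the class above enables (values are illustrative; meta tracking, which is on by default, is assumed so that `inverse` can replay the recorded trace):

```python
import torch
from monai.transforms import ApplyTransformToPoints

points = torch.tensor([[[10.0, 20.0, 30.0]]])  # (C=1, N=1, 3), world coordinates
affine = torch.tensor(
    [[2.0, 0.0, 0.0, -10.0],
     [0.0, 2.0, 0.0, -10.0],
     [0.0, 0.0, 2.0, -10.0],
     [0.0, 0.0, 0.0, 1.0]]
)

to_image = ApplyTransformToPoints(dtype=torch.float64, invert_affine=True)
image_points = to_image(points, affine)  # inv(affine) applied: [[[10.0, 15.0, 20.0]]]

# the traced metadata records the image affine, so the operation can be undone
world_points = to_image.inverse(image_points)  # back to [[[10.0, 20.0, 30.0]]]
```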
74 changes: 74 additions & 0 deletions monai/transforms/utility/dictionary.py
@@ -35,6 +35,7 @@
from monai.transforms.utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
@@ -180,6 +181,9 @@
"ClassesToIndicesd",
"ClassesToIndicesD",
"ClassesToIndicesDict",
"ApplyTransformToPointsd",
"ApplyTransformToPointsD",
"ApplyTransformToPointsDict",
]

DEFAULT_POST_FIX = PostFix.meta()
@@ -1740,6 +1744,75 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
return d


class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`.
The input coordinates are assumed to be in the shape (C, N, 2 or 3),
where C represents the number of channels and N denotes the number of points.
The output has the same shape as the input.
Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
refer_key: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived.
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when
applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.
The matrix is always converted to float64 for computation, which can be computationally
expensive when applied to a large number of points.
If None, will try to use the affine matrix from the refer data.
invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
Typically, the affine matrix is derived from the image, while the points are in world coordinates.
If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``.
affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system
or you're using `ITKReader` with `affine_lps_to_ras=True`.
This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
matrix are in the same coordinate system.
allow_missing_keys: Don't raise exception if key is missing.
"""

def __init__(
self,
keys: KeysCollection,
refer_key: str | None = None,
dtype: DtypeLike | torch.dtype = torch.float64,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.refer_key = refer_key
self.converter = ApplyTransformToPoints(
dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)

def __call__(self, data: Mapping[Hashable, torch.Tensor]):
d = dict(data)
if self.refer_key is not None:
if self.refer_key in d:
refer_data = d[self.refer_key]
else:
raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
else:
refer_data = None
affine = getattr(refer_data, "affine", refer_data)
for key in self.key_iterator(d):
coords = d[key]
d[key] = self.converter(coords, affine)
return d

def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter.inverse(d[key])
return d


RandImageFilterD = RandImageFilterDict = RandImageFilterd
ImageFilterD = ImageFilterDict = ImageFilterd
IdentityD = IdentityDict = Identityd
@@ -1780,3 +1853,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
RandCuCIMD = RandCuCIMDict = RandCuCIMd
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
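
A sketch of the dictionary wrapper resolving the affine through `refer_key` (keys and values here are illustrative assumptions, not part of the diff):

```python
import torch
from monai.data import MetaTensor
from monai.transforms import ApplyTransformToPointsd

# the image carries the affine; the transform looks it up via refer_key="image"
image = MetaTensor(
    torch.zeros(1, 64, 64, 64),
    affine=torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0])),
)
sample = {
    "image": image,
    "points": torch.tensor([[[4.0, 6.0, 8.0]]]),  # (C=1, N=1, 3), world coordinates
}

to_image = ApplyTransformToPointsd(keys="points", refer_key="image")
out = to_image(sample)            # out["points"] is now in image (voxel) coordinates
restored = to_image.inverse(out)  # maps the points back to world coordinates
```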
23 changes: 23 additions & 0 deletions monai/transforms/utils.py
@@ -27,6 +27,7 @@
import monai
from monai.config import DtypeLike, IndexSelection
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.utils import to_affine_nd
from monai.networks.layers import GaussianFilter
from monai.networks.utils import meshgrid_ij
from monai.transforms.compose import Compose
@@ -35,6 +36,7 @@
from monai.transforms.utils_pytorch_numpy_unification import (
any_np_pt,
ascontiguousarray,
concatenate,
cumsum,
isfinite,
nonzero,
@@ -2509,5 +2511,26 @@ def distance_transform_edt(
return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0]


def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None):
"""
apply affine transformation to a set of points.
Args:
data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3),
where C represents the number of channels and N denotes the number of points.
affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4).
dtype: output data dtype.
"""
data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64)
affine = to_affine_nd(data_.shape[-1], affine)

homogeneous: torch.Tensor = concatenate((data_, torch.ones((data_.shape[0], data_.shape[1], 1))), axis=2) # type: ignore
transformed_homogeneous = torch.matmul(homogeneous, affine.T)
transformed_coordinates = transformed_homogeneous[:, :, :-1]
out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)

return out


if __name__ == "__main__":
print_transform_backends()
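
The homogeneous-coordinate step in `apply_affine_to_points` amounts to the following equivalence (an illustrative check, not part of the diff):

```python
import torch

points = torch.tensor([[[1.0, 2.0, 3.0]]], dtype=torch.float64)  # (C=1, N=1, 3)
affine = torch.tensor(
    [[2.0, 0.0, 0.0, 5.0],
     [0.0, 2.0, 0.0, 5.0],
     [0.0, 0.0, 2.0, 5.0],
     [0.0, 0.0, 0.0, 1.0]],
    dtype=torch.float64,
)

# pad each point to (x, y, z, 1), right-multiply by affine^T, drop the last column
ones = torch.ones((points.shape[0], points.shape[1], 1), dtype=torch.float64)
homogeneous = torch.cat((points, ones), dim=2)
transformed = (homogeneous @ affine.T)[..., :-1]
print(transformed)  # tensor([[[ 7.,  9., 11.]]], dtype=torch.float64)
```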
1 change: 1 addition & 0 deletions monai/utils/__init__.py
@@ -148,6 +148,7 @@
dtype_numpy_to_torch,
dtype_torch_to_numpy,
get_dtype,
get_dtype_string,
get_equivalent_dtype,
get_numpy_dtype_from_string,
get_torch_dtype_from_string,
8 changes: 8 additions & 0 deletions monai/utils/type_conversion.py
@@ -33,6 +33,7 @@
"get_equivalent_dtype",
"convert_data_type",
"get_dtype",
"get_dtype_string",
"convert_to_cupy",
"convert_to_numpy",
"convert_to_tensor",
@@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype:
return type(data)


def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str:
"""Get a string representation of the dtype."""
if isinstance(dtype, torch.dtype):
return str(dtype)[6:]
return str(dtype)[3:]


def convert_to_tensor(
data: Any,
dtype: DtypeLike | torch.dtype = None,
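
For reference, `get_dtype_string` strips the module prefix from the dtype's string form; for the torch branch this behaves as below (a quick illustration, assuming only that `str(torch.float32)` is `"torch.float32"`):

```python
import torch

# "torch.float32" with the 6-character "torch." prefix removed
assert str(torch.float32)[6:] == "float32"
```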