Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image space and World Space convert transforms #7942

Merged
merged 67 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 65 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
c4a67d8
add image and world space convert transform
KumoLiu Jul 23, 2024
558a1da
Merge remote-tracking branch 'origin/geometric' into world2image
KumoLiu Jul 24, 2024
87550de
combine to one transform
KumoLiu Jul 25, 2024
46ec145
add dictionary version
KumoLiu Jul 25, 2024
7930bed
add enum
KumoLiu Jul 25, 2024
52f4253
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
25c3686
support invert
KumoLiu Jul 25, 2024
7ed5f97
include demo notebook
KumoLiu Jul 25, 2024
3b29c2e
Merge branch 'world2image' of https://github.com/KumoLiu/MONAI into w…
KumoLiu Jul 25, 2024
6e06d65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
a9f1704
fix format
KumoLiu Jul 25, 2024
5647a8d
upload link
KumoLiu Jul 26, 2024
3d3bacd
remove invert
KumoLiu Jul 26, 2024
4f4942c
ensure affine
KumoLiu Jul 30, 2024
0dfdf7e
update notebook
KumoLiu Jul 30, 2024
be7c897
add coronal and sagittal cases
KumoLiu Aug 1, 2024
e7b3e88
add `apply_affine_to_points`
KumoLiu Aug 1, 2024
9b6e01c
update test cases
KumoLiu Aug 2, 2024
2651de4
add luna pipeline
KumoLiu Aug 8, 2024
a819491
add `get_dtype_string`
KumoLiu Aug 15, 2024
9425248
add type hint
KumoLiu Aug 15, 2024
83de050
add type hint
KumoLiu Aug 15, 2024
c0b3bb9
add shape check
KumoLiu Aug 15, 2024
d196339
add warning
KumoLiu Aug 15, 2024
7d4cd68
Merge branch 'geometric' into world2image
KumoLiu Aug 19, 2024
83f59b4
support multi-channel
KumoLiu Aug 22, 2024
87b0c47
support 2d points
KumoLiu Aug 22, 2024
524d509
enhance coordinate trans
KumoLiu Aug 22, 2024
6421c0b
minor modify
KumoLiu Aug 22, 2024
a4a14c5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
4853f25
update
KumoLiu Aug 22, 2024
608816b
Merge branch 'world2image' of https://github.com/KumoLiu/MONAI into w…
KumoLiu Aug 22, 2024
5c4637a
minor fix
KumoLiu Aug 22, 2024
93afb03
add test
KumoLiu Aug 22, 2024
5e6fa65
minor modify
KumoLiu Aug 22, 2024
5ed2db6
minor fix
KumoLiu Aug 22, 2024
690a3c1
address comments
KumoLiu Aug 23, 2024
54422e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2024
3a18f10
add unittests
KumoLiu Aug 23, 2024
6d676ca
Merge branch 'world2image' of https://github.com/KumoLiu/MONAI into w…
KumoLiu Aug 23, 2024
d2319ad
modify test cases
KumoLiu Aug 23, 2024
5761a7b
fix format
KumoLiu Aug 23, 2024
6eb23b9
address comments
KumoLiu Aug 26, 2024
34d64c1
address comments
KumoLiu Aug 26, 2024
d535b4a
format fix
KumoLiu Aug 26, 2024
16f14a1
Update monai/transforms/utility/array.py
KumoLiu Aug 27, 2024
b9e3984
address comments
KumoLiu Aug 27, 2024
d6cc8f0
Update monai/transforms/utility/array.py
KumoLiu Aug 27, 2024
48f97b5
enhance docstring
KumoLiu Aug 27, 2024
0b200bd
address comments
KumoLiu Aug 27, 2024
7a91b1e
fix format
KumoLiu Aug 27, 2024
1469af7
remove notebook
KumoLiu Aug 27, 2024
13441b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
4d258d1
fix format
KumoLiu Aug 27, 2024
b1b5a87
enhance docstring
KumoLiu Aug 27, 2024
815b702
fix mypy
KumoLiu Aug 27, 2024
0724878
fix format
KumoLiu Aug 27, 2024
ffab390
address comments
KumoLiu Aug 28, 2024
5808439
Update monai/transforms/utility/array.py
KumoLiu Aug 28, 2024
729f11b
address comments
KumoLiu Aug 28, 2024
6f34194
Merge branch 'world2image' of https://github.com/KumoLiu/MONAI into w…
KumoLiu Aug 28, 2024
48707ed
fix format
KumoLiu Aug 28, 2024
a7d201b
add in init
KumoLiu Aug 28, 2024
f21f967
fix format
KumoLiu Aug 28, 2024
c078fe2
fix format
KumoLiu Aug 28, 2024
0c548c4
fix docstring format
KumoLiu Aug 28, 2024
ee2ac77
enhance docstring
KumoLiu Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,12 @@ Utility
:members:
:special-members: __call__

`ApplyTransformToPoints`
""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPoints
:members:
:special-members: __call__

Dictionary Transforms
---------------------

Expand Down Expand Up @@ -2265,6 +2271,12 @@ Utility (Dict)
:members:
:special-members: __call__

`ApplyTransformToPointsd`
"""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPointsd
:members:
:special-members: __call__


MetaTensor
^^^^^^^^^^
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@
from .utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
Expand Down Expand Up @@ -532,6 +533,9 @@
AddExtremePointsChanneld,
AddExtremePointsChannelD,
AddExtremePointsChannelDict,
ApplyTransformToPointsd,
ApplyTransformToPointsD,
ApplyTransformToPointsDict,
AsChannelLastd,
AsChannelLastD,
AsChannelLastDict,
Expand Down
142 changes: 138 additions & 4 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import is_no_channel, no_collation
from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps
from monai.networks.layers.simplelayers import (
ApplyFilter,
EllipticalFilter,
Expand All @@ -42,16 +42,17 @@
SharpenFilter,
median_filter,
)
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.inverse import InvertibleTransform, TraceableTransform
from monai.transforms.traits import MultiSampleTrait
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
from monai.transforms.utils import (
apply_affine_to_points,
extreme_points_to_image,
get_extreme_points,
map_binary_to_indices,
map_classes_to_indices,
)
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices
from monai.utils import (
MetaKeys,
TraceKeys,
Expand All @@ -66,7 +67,7 @@
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
Expand Down Expand Up @@ -106,6 +107,7 @@
"ToCupy",
"ImageFilter",
"RandImageFilter",
"ApplyTransformToPoints",
]


Expand Down Expand Up @@ -1715,3 +1717,135 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None = None) -> Nd
if self._do_transform:
img = self.filter(img)
return img


class ApplyTransformToPoints(InvertibleTransform, Transform):
mingxin-zheng marked this conversation as resolved.
Show resolved Hide resolved
"""
Transform points between image coordinates and world coordinates.
ericspod marked this conversation as resolved.
Show resolved Hide resolved
The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels
and N denotes the number of points. It will return a tensor with the same shape as the input.

Args:
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when
applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.
The matrix is always converted to float64 for computation, which can be computationally
expensive when applied to a large number of points.
If None, will try to use the affine matrix from the input data.
invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
Typically, the affine matrix is derived from an image and represents its location in world space,
while the points are in world coordinates. A value of ``True`` represents transforming these
world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation.
affine_lps_to_ras: Defaults to ``False``. Set this to ``True`` if all of the following are true:
1) The image is read by `ITKReader`,
2) The `ITKReader` has `affine_lps_to_ras=True`,
3) The data is in world coordinates.
This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
matrix are in the same coordinate system.

vikashg marked this conversation as resolved.
Show resolved Hide resolved
Use Cases:
- Transforming points between world space and image space, and vice versa.
- Automatically handling inverse transformations between image space and world space.
- If points have an existing affine transformation, the class computes and
applies the required delta affine transformation.

"""

def __init__(
self,
dtype: DtypeLike | torch.dtype | None = None,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
) -> None:
self.dtype = dtype
self.affine = affine
self.invert_affine = invert_affine
self.affine_lps_to_ras = affine_lps_to_ras

def transform_coordinates(
self, data: torch.Tensor, affine: torch.Tensor | None = None
) -> tuple[torch.Tensor, dict]:
"""
Transform coordinates using an affine transformation matrix.

Args:
data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
where C represents the number of channels and N denotes the number of points.
affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation,
which can be computationally expensive when applied to a large number of points.

Returns:
Transformed coordinates.
"""
data = convert_to_tensor(data, track_meta=get_track_meta())
# applied_affine is the affine transformation matrix that has already been applied to the point data
applied_affine = getattr(data, "affine", None)

if affine is None and self.invert_affine:
raise ValueError("affine must be provided when invert_affine is True.")

affine = applied_affine if affine is None else affine
affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine
ericspod marked this conversation as resolved.
Show resolved Hide resolved
original_affine: torch.Tensor = affine
if self.affine_lps_to_ras:
affine = orientation_ras_lps(affine)

# the final affine transformation matrix that will be applied to the point data
_affine: torch.Tensor = affine
if self.invert_affine:
_affine = linalg_inv(affine)
if applied_affine is not None:
# consider the affine transformation already applied to the data in the world space
# and compute delta affine
_affine = _affine @ linalg_inv(applied_affine)
out = apply_affine_to_points(data, _affine, dtype=self.dtype)

extra_info = {
"invert_affine": self.invert_affine,
"dtype": get_dtype_string(self.dtype),
"image_affine": original_affine, # record for inverse operation
"affine_lps_to_ras": self.affine_lps_to_ras,
}
xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine)
mingxin-zheng marked this conversation as resolved.
Show resolved Hide resolved
meta_info = TraceableTransform.track_transform_meta(
data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
)

return out, meta_info

def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):
"""
Args:
data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
where C represents the number of channels and N denotes the number of points.
affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``.
"""
if data.ndim != 3 or data.shape[-1] not in (2, 3):
raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.")
affine = self.affine if affine is None else affine
if affine is not None and affine.shape not in ((3, 3), (4, 4)):
raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.")

out, meta_info = self.transform_coordinates(data, affine)

return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out

def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
# Create inverse transform
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"]
affine = transform[TraceKeys.EXTRA_INFO]["image_affine"]
affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"]
inverse_transform = ApplyTransformToPoints(
dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)
# Apply inverse
with inverse_transform.trace_transform(False):
data = inverse_transform(data, affine)

return data
76 changes: 76 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from monai.transforms.utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
Expand Down Expand Up @@ -180,6 +181,9 @@
"ClassesToIndicesd",
"ClassesToIndicesD",
"ClassesToIndicesDict",
"ApplyTransformToPointsd",
"ApplyTransformToPointsD",
"ApplyTransformToPointsDict",
]

DEFAULT_POST_FIX = PostFix.meta()
Expand Down Expand Up @@ -1740,6 +1744,77 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`.
The input coordinates are assumed to be in the shape (C, N, 2 or 3),
where C represents the number of channels and N denotes the number of points.
The output has the same shape as the input.

Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
refer_key: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived.
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
mingxin-zheng marked this conversation as resolved.
Show resolved Hide resolved
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when
applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.
The matrix is always converted to float64 for computation, which can be computationally
expensive when applied to a large number of points.
If None, will try to use the affine matrix from the refer data.
invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
Typically, the affine matrix is derived from the image, while the points are in world coordinates.
If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``.
affine_lps_to_ras: Defaults to ``False``. Set this to ``True`` if all of the following are true:
vikashg marked this conversation as resolved.
Show resolved Hide resolved
1) The image is read by `ITKReader`,
2) The `ITKReader` has `affine_lps_to_ras=True`,
3) The data is in world coordinates.
This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
matrix are in the same coordinate system.
allow_missing_keys: Don't raise exception if key is missing.
"""

def __init__(
self,
keys: KeysCollection,
refer_key: str | None = None,
dtype: DtypeLike | torch.dtype = torch.float64,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.refer_key = refer_key
self.converter = ApplyTransformToPoints(
dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)

def __call__(self, data: Mapping[Hashable, torch.Tensor]):
d = dict(data)
if self.refer_key is not None:
if self.refer_key in d:
refer_data = d[self.refer_key]
else:
raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
else:
refer_data = None
affine = getattr(refer_data, "affine", refer_data)
for key in self.key_iterator(d):
coords = d[key]
d[key] = self.converter(coords, affine)
return d

def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter.inverse(d[key])
return d


RandImageFilterD = RandImageFilterDict = RandImageFilterd
ImageFilterD = ImageFilterDict = ImageFilterd
IdentityD = IdentityDict = Identityd
Expand Down Expand Up @@ -1780,3 +1855,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
RandCuCIMD = RandCuCIMDict = RandCuCIMd
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
23 changes: 23 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import monai
from monai.config import DtypeLike, IndexSelection
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.utils import to_affine_nd
from monai.networks.layers import GaussianFilter
from monai.networks.utils import meshgrid_ij
from monai.transforms.compose import Compose
Expand All @@ -35,6 +36,7 @@
from monai.transforms.utils_pytorch_numpy_unification import (
any_np_pt,
ascontiguousarray,
concatenate,
cumsum,
isfinite,
nonzero,
Expand Down Expand Up @@ -2509,5 +2511,26 @@ def distance_transform_edt(
return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0]


def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None):
"""
apply affine transformation to a set of points.

Args:
data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3),
where C represents the number of channels and N denotes the number of points.
affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4).
dtype: output data dtype.
"""
data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64)
affine = to_affine_nd(data_.shape[-1], affine)

homogeneous: torch.Tensor = concatenate((data_, torch.ones((data_.shape[0], data_.shape[1], 1))), axis=2) # type: ignore
transformed_homogeneous = torch.matmul(homogeneous, affine.T)
transformed_coordinates = transformed_homogeneous[:, :, :-1]
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)

return out


if __name__ == "__main__":
print_transform_backends()
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
dtype_numpy_to_torch,
dtype_torch_to_numpy,
get_dtype,
get_dtype_string,
get_equivalent_dtype,
get_numpy_dtype_from_string,
get_torch_dtype_from_string,
Expand Down
8 changes: 8 additions & 0 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"get_equivalent_dtype",
"convert_data_type",
"get_dtype",
"get_dtype_string",
"convert_to_cupy",
"convert_to_numpy",
"convert_to_tensor",
Expand Down Expand Up @@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype:
return type(data)


def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str:
"""Get a string representation of the dtype."""
if isinstance(dtype, torch.dtype):
return str(dtype)[6:]
return str(dtype)[3:]


def convert_to_tensor(
data: Any,
dtype: DtypeLike | torch.dtype = None,
Expand Down
Loading
Loading