Image space and World Space convert transforms #7942

Merged 67 commits on Aug 29, 2024.

Showing changes from 45 commits.

Commits (67)
c4a67d8
add image and world space convert transform
KumoLiu Jul 23, 2024
558a1da
Merge remote-tracking branch 'origin/geometric' into world2image
KumoLiu Jul 24, 2024
87550de
combine to one transform
KumoLiu Jul 25, 2024
46ec145
add dictionary version
KumoLiu Jul 25, 2024
7930bed
add enum
KumoLiu Jul 25, 2024
52f4253
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
25c3686
support invert
KumoLiu Jul 25, 2024
7ed5f97
include demo notebook
KumoLiu Jul 25, 2024
3b29c2e
Merge branch 'world2image' of https://github.com/KumoLiu/MONAI into w…
KumoLiu Jul 25, 2024
6e06d65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
a9f1704
fix format
KumoLiu Jul 25, 2024
5647a8d
upload link
KumoLiu Jul 26, 2024
3d3bacd
remove invert
KumoLiu Jul 26, 2024
4f4942c
ensure affine
KumoLiu Jul 30, 2024
0dfdf7e
update notebook
KumoLiu Jul 30, 2024
be7c897
add coronal and sagittal cases
KumoLiu Aug 1, 2024
e7b3e88
add `apply_affine_to_points`
KumoLiu Aug 1, 2024
9b6e01c
update test cases
KumoLiu Aug 2, 2024
2651de4
add luna pipeline
KumoLiu Aug 8, 2024
a819491
add `get_dtype_string`
KumoLiu Aug 15, 2024
9425248
add type hint
KumoLiu Aug 15, 2024
83de050
add type hint
KumoLiu Aug 15, 2024
c0b3bb9
add shape check
KumoLiu Aug 15, 2024
d196339
add warning
KumoLiu Aug 15, 2024
7d4cd68
Merge branch 'geometric' into world2image
KumoLiu Aug 19, 2024
83f59b4
support multi-channel
KumoLiu Aug 22, 2024
87b0c47
support 2d points
KumoLiu Aug 22, 2024
524d509
enhance coordinate trans
KumoLiu Aug 22, 2024
6421c0b
minor modify
KumoLiu Aug 22, 2024
a4a14c5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
4853f25
update
KumoLiu Aug 22, 2024
608816b
Merge branch 'world2image' of https://github.com/KumoLiu/MONAI into w…
KumoLiu Aug 22, 2024
5c4637a
minor fix
KumoLiu Aug 22, 2024
93afb03
add test
KumoLiu Aug 22, 2024
5e6fa65
minor modify
KumoLiu Aug 22, 2024
5ed2db6
minor fix
KumoLiu Aug 22, 2024
690a3c1
address comments
KumoLiu Aug 23, 2024
54422e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2024
3a18f10
add unittests
KumoLiu Aug 23, 2024
6d676ca
Merge branch 'world2image' of https://github.com/KumoLiu/MONAI into w…
KumoLiu Aug 23, 2024
d2319ad
modify test cases
KumoLiu Aug 23, 2024
5761a7b
fix format
KumoLiu Aug 23, 2024
6eb23b9
address comments
KumoLiu Aug 26, 2024
34d64c1
address comments
KumoLiu Aug 26, 2024
d535b4a
format fix
KumoLiu Aug 26, 2024
16f14a1
Update monai/transforms/utility/array.py
KumoLiu Aug 27, 2024
b9e3984
address comments
KumoLiu Aug 27, 2024
d6cc8f0
Update monai/transforms/utility/array.py
KumoLiu Aug 27, 2024
48f97b5
enhance docstring
KumoLiu Aug 27, 2024
0b200bd
address comments
KumoLiu Aug 27, 2024
7a91b1e
fix format
KumoLiu Aug 27, 2024
1469af7
remove notebook
KumoLiu Aug 27, 2024
13441b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2024
4d258d1
fix format
KumoLiu Aug 27, 2024
b1b5a87
enhance docstring
KumoLiu Aug 27, 2024
815b702
fix mypy
KumoLiu Aug 27, 2024
0724878
fix format
KumoLiu Aug 27, 2024
ffab390
address comments
KumoLiu Aug 28, 2024
5808439
Update monai/transforms/utility/array.py
KumoLiu Aug 28, 2024
729f11b
address comments
KumoLiu Aug 28, 2024
6f34194
Merge branch 'world2image' of https://github.com/KumoLiu/MONAI into w…
KumoLiu Aug 28, 2024
48707ed
fix format
KumoLiu Aug 28, 2024
a7d201b
add in init
KumoLiu Aug 28, 2024
f21f967
fix format
KumoLiu Aug 28, 2024
c078fe2
fix format
KumoLiu Aug 28, 2024
0c548c4
fix docstring format
KumoLiu Aug 28, 2024
ee2ac77
enhance docstring
KumoLiu Aug 28, 2024

117 changes: 113 additions & 4 deletions monai/transforms/utility/array.py
@@ -31,7 +31,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import is_no_channel, no_collation
from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps
from monai.networks.layers.simplelayers import (
ApplyFilter,
EllipticalFilter,
@@ -42,16 +42,17 @@
SharpenFilter,
median_filter,
)
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.inverse import InvertibleTransform, TraceableTransform
from monai.transforms.traits import MultiSampleTrait
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
from monai.transforms.utils import (
apply_affine_to_points,
extreme_points_to_image,
get_extreme_points,
map_binary_to_indices,
map_classes_to_indices,
)
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices
from monai.utils import (
MetaKeys,
TraceKeys,
@@ -66,7 +67,7 @@
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
@@ -1715,3 +1716,111 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None = None) -> Nd
if self._do_transform:
img = self.filter(img)
return img


class ApplyTransformToPoints(InvertibleTransform, Transform):
"""
Transform points between image coordinates and world coordinates.

Args:
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. Typically, this matrix originates from the image.
invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
Typically, the affine matrix is derived from the image, while the points are in world coordinates.
If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``.
affine_lps_to_ras: Defaults to ``False``. Set this to ``True`` if all of the following are true:
1) The image is read by `ITKReader`,
2) The `ITKReader` has `affine_lps_to_ras=True`,
3) The data is in world coordinates.
This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
matrix are in the same coordinate system.
"""

def __init__(
self,
dtype: DtypeLike | torch.dtype = torch.float64,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
) -> None:
self.dtype = dtype
self.affine = affine
self.invert_affine = invert_affine
self.affine_lps_to_ras = affine_lps_to_ras

def transform_coordinates(self, data: torch.Tensor, affine: torch.Tensor | None = None):
"""
Transform coordinates using an affine transformation matrix.

Args:
data: The input coordinates, assumed to be in shape (C, N, 2 or 3).
affine: A 3x3 or 4x4 affine transformation matrix.

Returns:
Transformed coordinates.
"""
data = convert_to_tensor(data, track_meta=get_track_meta())
applied_affine = data.affine if isinstance(data, MetaTensor) else None

if affine is None and self.invert_affine:
raise ValueError("affine must be provided when invert_affine is True.")

affine = applied_affine if affine is None else affine
affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine
original_affine = affine
if self.affine_lps_to_ras:
affine = orientation_ras_lps(affine)

_affine = affine
if self.invert_affine:
_affine = linalg_inv(affine)
# consider the affine transformation already applied to the data in the world space
if applied_affine is not None:
_affine = _affine @ linalg_inv(applied_affine)
out = apply_affine_to_points(data, _affine, self.dtype)

extra_info = {
"invert_affine": self.invert_affine,
"dtype": get_dtype_string(self.dtype),
"image_affine": original_affine,
"affine_lps_to_ras": self.affine_lps_to_ras,
}
xform = original_affine if self.invert_affine else linalg_inv(original_affine)
meta_info = TraceableTransform.track_transform_meta(
data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
)

return out, meta_info

def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None) -> torch.Tensor:
"""
Args:
data: The input coordinates, assumed to be in shape (C, N, 2 or 3).
affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``.
"""
if data.ndim != 3 or data.shape[-1] not in (2, 3):
raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.")
affine = self.affine if affine is None else affine
if affine is not None and affine.shape not in ((3, 3), (4, 4)):
raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.")

out, meta_info = self.transform_coordinates(data, affine)

return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out

def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
# Create inverse transform
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"]
affine = transform[TraceKeys.EXTRA_INFO]["image_affine"]
affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"]
inverse_transform = ApplyTransformToPoints(
dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)
# Apply inverse
with inverse_transform.trace_transform(False):
data = inverse_transform(data, affine)

return data
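
A minimal usage sketch for the new transform, based on the signature and docstring above. The point coordinates and affine values are illustrative, and the import path assumes the transform is exported from `monai.transforms` as the "add in init" commit suggests:

```python
import torch

from monai.transforms import ApplyTransformToPoints

# one channel with two points in world coordinates, shape (C, N, 3)
points = torch.tensor([[[1.5, 2.0, 3.0], [4.0, 5.0, 6.0]]])

# an image-style affine (e.g. taken from image.affine); values are made up for illustration
affine = torch.tensor(
    [[2.0, 0.0, 0.0, 10.0],
     [0.0, 2.0, 0.0, -5.0],
     [0.0, 0.0, 2.0, 0.0],
     [0.0, 0.0, 0.0, 1.0]]
)

# invert_affine=True (the default) maps world-space points into image space
world_to_image = ApplyTransformToPoints(invert_affine=True, affine_lps_to_ras=False)
image_points = world_to_image(points, affine)

# with meta tracking enabled (MONAI's default), the output is a MetaTensor carrying
# the applied operation, so the transform can be undone again
world_points = world_to_image.inverse(image_points)
```

The `inverse` method works by popping the recorded operation from the output `MetaTensor` and rebuilding the transform with `invert_affine` flipped, as shown in the code above.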
64 changes: 64 additions & 0 deletions monai/transforms/utility/dictionary.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import re
import warnings
from collections.abc import Callable, Hashable, Mapping
from copy import deepcopy
from typing import Any, Sequence, cast
@@ -35,6 +36,7 @@
from monai.transforms.utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
@@ -1740,6 +1742,68 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`.

Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
refer_key: the key of the reference item used for transformation.
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. Typically, this matrix originates from the image.
invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
Typically, the affine matrix is derived from the image, while the points are in world coordinates.
If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``.
affine_lps_to_ras: Defaults to ``False``. Set this to ``True`` if all of the following are true:
1) The image is read by `ITKReader`,
2) The `ITKReader` has `affine_lps_to_ras=True`,
3) The data is in world coordinates.
This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
matrix are in the same coordinate system.
allow_missing_keys: Don't raise exception if key is missing.
"""

def __init__(
self,
keys: KeysCollection,
refer_key: str | None = None,
dtype: DtypeLike | torch.dtype = torch.float64,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.refer_key = refer_key
self.affine = affine
if self.refer_key is None and self.affine is None:
warnings.warn("No reference data or affine matrix is provided, will use the affine derived from the data.")
self.converter = ApplyTransformToPoints(
dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)

def __call__(self, data: Mapping[Hashable, torch.Tensor]):
d = dict(data)
refer_data = d[self.refer_key] if self.refer_key is not None else None
if isinstance(refer_data, MetaTensor):
affine = refer_data.affine
else:
warnings.warn("No reference affine find in the refer key, will use the affine derived from the data.")
affine = self.affine
for key in self.key_iterator(d):
coords = d[key]
d[key] = self.converter(coords, affine)
return d

def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter.inverse(d[key])
return d


RandImageFilterD = RandImageFilterDict = RandImageFilterd
ImageFilterD = ImageFilterDict = ImageFilterd
IdentityD = IdentityDict = Identityd
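
A sketch of how the dictionary version might be used with a reference image, again with illustrative shapes and spacing. `refer_key` points at the image whose affine is applied to the points; the import assumes `ApplyTransformToPointsd` is exported from `monai.transforms` alongside the array version:

```python
import torch

from monai.data import MetaTensor
from monai.transforms import ApplyTransformToPointsd

# illustrative sample: an image with 2 mm isotropic spacing and one world-space point
image_affine = torch.eye(4)
image_affine[:3, :3] *= 2.0
data = {
    "image": MetaTensor(torch.zeros(1, 32, 32, 32), affine=image_affine),
    "points": torch.tensor([[[10.0, 12.0, 8.0]]]),  # shape (C, N, 3) = (1, 1, 3)
}

# take the affine from the "image" entry and map the points into image space
transform = ApplyTransformToPointsd(keys="points", refer_key="image", invert_affine=True)
out = transform(data)

# the dictionary transform is also invertible
restored = transform.inverse(out)
```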
644 changes: 644 additions & 0 deletions monai/transforms/utility/temp_test.ipynb

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions monai/transforms/utils.py
@@ -27,6 +27,7 @@
import monai
from monai.config import DtypeLike, IndexSelection
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.utils import to_affine_nd
from monai.networks.layers import GaussianFilter
from monai.networks.utils import meshgrid_ij
from monai.transforms.compose import Compose
@@ -35,6 +36,7 @@
from monai.transforms.utils_pytorch_numpy_unification import (
any_np_pt,
ascontiguousarray,
concatenate,
cumsum,
isfinite,
nonzero,
@@ -2509,5 +2511,25 @@ def distance_transform_edt(
return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0]


def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype):
"""
Apply an affine transformation to a set of points.

Args:
data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3).
affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4).
dtype: output data dtype.
"""
data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64)
affine = to_affine_nd(data_.shape[-1], affine)

homogeneous = concatenate((data_, torch.ones((data_.shape[0], data_.shape[1], 1))), axis=2)
transformed_homogeneous = torch.matmul(homogeneous, affine.T)
transformed_coordinates = transformed_homogeneous[:, :, :-1]
out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)

return out


if __name__ == "__main__":
print_transform_backends()
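
For reference, the homogeneous-coordinate arithmetic that `apply_affine_to_points` performs can be written out directly. This standalone sketch mirrors the steps above (append a 1, multiply by the transposed affine, drop the extra column) without any MONAI dependencies:

```python
import torch


def affine_points_sketch(points: torch.Tensor, affine: torch.Tensor) -> torch.Tensor:
    """Apply a (d+1)x(d+1) affine to points of shape (C, N, d) via homogeneous coordinates."""
    c, n, _ = points.shape
    ones = torch.ones(c, n, 1, dtype=points.dtype)
    homogeneous = torch.cat([points, ones], dim=2)  # (C, N, d+1)
    transformed = homogeneous @ affine.T            # row-vector convention, hence the transpose
    return transformed[..., :-1]                    # drop the homogeneous coordinate


pts = torch.tensor([[[1.0, 2.0, 3.0]]])             # (1, 1, 3)
aff = torch.tensor([[1.0, 0.0, 0.0, 5.0],
                    [0.0, 1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0]])
print(affine_points_sketch(pts, aff))               # tensor([[[6., 2., 3.]]])
```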
1 change: 1 addition & 0 deletions monai/utils/__init__.py
@@ -148,6 +148,7 @@
dtype_numpy_to_torch,
dtype_torch_to_numpy,
get_dtype,
get_dtype_string,
get_equivalent_dtype,
get_numpy_dtype_from_string,
get_torch_dtype_from_string,
8 changes: 8 additions & 0 deletions monai/utils/type_conversion.py
@@ -33,6 +33,7 @@
"get_equivalent_dtype",
"convert_data_type",
"get_dtype",
"get_dtype_string",
"convert_to_cupy",
"convert_to_numpy",
"convert_to_tensor",
@@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype:
return type(data)


def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str:
"""Get a string representation of the dtype."""
if isinstance(dtype, torch.dtype):
return str(dtype)[6:]
return str(dtype)[3:]


def convert_to_tensor(
data: Any,
dtype: DtypeLike | torch.dtype = None,
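
`get_dtype_string` simply strips the module prefix from the dtype's string form; a quick illustration with torch dtypes (where `str(torch.float32)` is `"torch.float32"`, so the first six characters are dropped):

```python
import torch

from monai.utils import get_dtype_string

print(get_dtype_string(torch.float32))  # -> "float32"
print(get_dtype_string(torch.int64))    # -> "int64"
```

In the transform above, this string form is what gets stored in the trace's `extra_info`, so the dtype can be serialized with the applied-operations metadata and reused when `inverse` rebuilds the transform.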