Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-channel conversion transforms for brats23 #8112

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,8 @@
CastToType,
ClassesToIndices,
ConvertToMultiChannelBasedOnBratsClasses,
ConvertToMultiChannelBasedOnBrats23Classes,
ConvertToMultiChannelBasedOnBrats23ClassesNoReg,
CuCIM,
DataStats,
EnsureChannelFirst,
Expand Down Expand Up @@ -569,6 +571,12 @@
ConvertToMultiChannelBasedOnBratsClassesd,
ConvertToMultiChannelBasedOnBratsClassesD,
ConvertToMultiChannelBasedOnBratsClassesDict,
ConvertToMultiChannelBasedOnBrats23ClassesD,
ConvertToMultiChannelBasedOnBrats23Classesd,
ConvertToMultiChannelBasedOnBrats23ClassesDict,
ConvertToMultiChannelBasedOnBrats23ClassesNoRegD,
ConvertToMultiChannelBasedOnBrats23ClassesNoRegd,
ConvertToMultiChannelBasedOnBrats23ClassesNoRegDict,
CopyItemsd,
CopyItemsD,
CopyItemsDict,
Expand Down
44 changes: 44 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@
"FgBgToIndices",
"ClassesToIndices",
"ConvertToMultiChannelBasedOnBratsClasses",
"ConvertToMultiChannelBasedOnBrats23Classes",
"ConvertToMultiChannelBasedOnBrats23ClassesNoReg",
"AddExtremePointsChannel",
"TorchVision",
"MapLabelValue",
Expand Down Expand Up @@ -1072,6 +1074,48 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


class ConvertToMultiChannelBasedOnBrats23Classes(Transform):
"""
Convert labels to multi channels based on brats23 classes:
label 1 is the necrotic and non-enhancing tumor core (NCR)
label 2 is the peritumoral edema (ED)
label 3 is the GD-enhancing tumor (ET)
NOTE: REGION-BASED CONVERSION to TC, WT, ET
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# if img has channel dim, squeeze it
if img.ndim == 4 and img.shape[0] == 1:
img = img.squeeze(0)

result = [(img == 1) | (img == 3), (img == 1) | (img == 3) | (img == 2), img == 3] # -> tc, wt, et
# merge labels 1 (ncr) and 3 (et) and 2 (ed) to WT
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


class ConvertToMultiChannelBasedOnBrats23ClassesNoReg(Transform):
"""
Convert labels to multi channels based on brats23 classes:
label 1 is the necrotic and non-enhancing tumor core (NCR)
label 2 is the peritumoral edema (ED)
label 3 is the GD-enhancing tumor (ET)
NOTE: LABEL-BASED CONVERSION
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# if img has channel dim, squeeze it
if img.ndim == 4 and img.shape[0] == 1:
img = img.squeeze(0)

result = [(img == 1), (img == 2), (img == 3)]

return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


class AddExtremePointsChannel(Randomizable, Transform):
"""
Add extreme points of label to the image as a new channel. This transform generates extreme
Expand Down
61 changes: 61 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
CastToType,
ClassesToIndices,
ConvertToMultiChannelBasedOnBratsClasses,
ConvertToMultiChannelBasedOnBrats23Classes,
ConvertToMultiChannelBasedOnBrats23ClassesNoReg,
CuCIM,
DataStats,
EnsureChannelFirst,
Expand Down Expand Up @@ -89,6 +91,12 @@
"ConvertToMultiChannelBasedOnBratsClassesD",
"ConvertToMultiChannelBasedOnBratsClassesDict",
"ConvertToMultiChannelBasedOnBratsClassesd",
"ConvertToMultiChannelBasedOnBrats23ClassesD",
"ConvertToMultiChannelBasedOnBrats23ClassesDict",
"ConvertToMultiChannelBasedOnBrats23Classesd",
"ConvertToMultiChannelBasedOnBrats23NoRegClassesD",
"ConvertToMultiChannelBasedOnBrats23NoRegClassesDict",
"ConvertToMultiChannelBasedOnBrats23NoRegClassesd",
"CopyItemsD",
"CopyItemsDict",
"CopyItemsd",
Expand Down Expand Up @@ -1307,6 +1315,53 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ConvertToMultiChannelBasedOnBrats23Classesd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBrats23Classes`.
Convert labels to multi channels based on brats23 classes:
label 1 is the necrotic and non-enhancing tumor core
label 2 is the peritumoral edema
label 3 is the GD-enhancing tumor
The possible classes are TC (Tumor core), WT (Whole tumor)
and ET (Enhancing tumor).
"""

backend = ConvertToMultiChannelBasedOnBrats23Classes.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
super().__init__(keys, allow_missing_keys)
self.converter = ConvertToMultiChannelBasedOnBrats23Classes()

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
return d


class ConvertToMultiChannelBasedOnBrats23ClassesNoRegd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.
Convert labels to multi channels based on brats23 classes:
label 1 is the necrotic and non-enhancing tumor core
label 2 is the peritumoral edema
label 4 is the GD-enhancing tumor
In this case, labels are converted to multi channels. No regions are involved.
"""

backend = ConvertToMultiChannelBasedOnBrats23ClassesNoReg.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
super().__init__(keys, allow_missing_keys)
self.converter = ConvertToMultiChannelBasedOnBrats23ClassesNoReg()

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
return d


class AddExtremePointsChanneld(Randomizable, MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`.
Expand Down Expand Up @@ -1870,6 +1925,12 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = (
ConvertToMultiChannelBasedOnBratsClassesd
)
ConvertToMultiChannelBasedOnBrats23ClassesD = ConvertToMultiChannelBasedOnBrats23ClassesDict = (
ConvertToMultiChannelBasedOnBrats23Classesd
)
ConvertToMultiChannelBasedOnBrats23ClassesNoRegD = ConvertToMultiChannelBasedOnBrats23ClassesNoRegDict = (
ConvertToMultiChannelBasedOnBrats23ClassesNoRegd
)
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchVisionD = TorchVisionDict = TorchVisiond
RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
Expand Down
Loading