Skip to content

Commit

Permalink
Add multi-channel conversion transforms for brats23
Browse files Browse the repository at this point in the history
  • Loading branch information
dani-capellan committed Sep 25, 2024
1 parent fa1c1af commit 4df1d61
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 0 deletions.
8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,8 @@
CastToType,
ClassesToIndices,
ConvertToMultiChannelBasedOnBratsClasses,
ConvertToMultiChannelBasedOnBrats23Classes,
ConvertToMultiChannelBasedOnBrats23ClassesNoReg,
CuCIM,
DataStats,
EnsureChannelFirst,
Expand Down Expand Up @@ -569,6 +571,12 @@
ConvertToMultiChannelBasedOnBratsClassesd,
ConvertToMultiChannelBasedOnBratsClassesD,
ConvertToMultiChannelBasedOnBratsClassesDict,
ConvertToMultiChannelBasedOnBrats23ClassesD,
ConvertToMultiChannelBasedOnBrats23Classesd,
ConvertToMultiChannelBasedOnBrats23ClassesDict,
ConvertToMultiChannelBasedOnBrats23ClassesNoRegD,
ConvertToMultiChannelBasedOnBrats23ClassesNoRegd,
ConvertToMultiChannelBasedOnBrats23ClassesNoRegDict,
CopyItemsd,
CopyItemsD,
CopyItemsDict,
Expand Down
44 changes: 44 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@
"FgBgToIndices",
"ClassesToIndices",
"ConvertToMultiChannelBasedOnBratsClasses",
"ConvertToMultiChannelBasedOnBrats23Classes",
"ConvertToMultiChannelBasedOnBrats23ClassesNoReg",
"AddExtremePointsChannel",
"TorchVision",
"MapLabelValue",
Expand Down Expand Up @@ -1070,7 +1072,49 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
# label 4 is ET
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


class ConvertToMultiChannelBasedOnBrats23Classes(Transform):
"""
Convert labels to multi channels based on brats23 classes:
label 1 is the necrotic and non-enhancing tumor core (NCR)
label 2 is the peritumoral edema (ED)
label 3 is the GD-enhancing tumor (ET)
NOTE: REGION-BASED CONVERSION to TC, WT, ET
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# if img has channel dim, squeeze it
if img.ndim == 4 and img.shape[0] == 1:
img = img.squeeze(0)

result = [(img == 1) | (img == 3), (img == 1) | (img == 3) | (img == 2), img == 3] # -> tc, wt, et
# merge labels 1 (ncr) and 3 (et) and 2 (ed) to WT
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


class ConvertToMultiChannelBasedOnBrats23ClassesNoReg(Transform):
"""
Convert labels to multi channels based on brats23 classes:
label 1 is the necrotic and non-enhancing tumor core (NCR)
label 2 is the peritumoral edema (ED)
label 3 is the GD-enhancing tumor (ET)
NOTE: LABEL-BASED CONVERSION
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# if img has channel dim, squeeze it
if img.ndim == 4 and img.shape[0] == 1:
img = img.squeeze(0)

result = [(img == 1), (img == 2), (img == 3)]

return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


class AddExtremePointsChannel(Randomizable, Transform):
"""
Expand Down
61 changes: 61 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
CastToType,
ClassesToIndices,
ConvertToMultiChannelBasedOnBratsClasses,
ConvertToMultiChannelBasedOnBrats23Classes,
ConvertToMultiChannelBasedOnBrats23ClassesNoReg,
CuCIM,
DataStats,
EnsureChannelFirst,
Expand Down Expand Up @@ -89,6 +91,12 @@
"ConvertToMultiChannelBasedOnBratsClassesD",
"ConvertToMultiChannelBasedOnBratsClassesDict",
"ConvertToMultiChannelBasedOnBratsClassesd",
"ConvertToMultiChannelBasedOnBrats23ClassesD",
"ConvertToMultiChannelBasedOnBrats23ClassesDict",
"ConvertToMultiChannelBasedOnBrats23Classesd",
"ConvertToMultiChannelBasedOnBrats23NoRegClassesD",
"ConvertToMultiChannelBasedOnBrats23NoRegClassesDict",
"ConvertToMultiChannelBasedOnBrats23NoRegClassesd",
"CopyItemsD",
"CopyItemsDict",
"CopyItemsd",
Expand Down Expand Up @@ -1305,6 +1313,53 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
return d


class ConvertToMultiChannelBasedOnBrats23Classesd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBrats23Classes`.
Convert labels to multi channels based on brats23 classes:
label 1 is the necrotic and non-enhancing tumor core
label 2 is the peritumoral edema
label 3 is the GD-enhancing tumor
The possible classes are TC (Tumor core), WT (Whole tumor)
and ET (Enhancing tumor).
"""

backend = ConvertToMultiChannelBasedOnBrats23Classes.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
super().__init__(keys, allow_missing_keys)
self.converter = ConvertToMultiChannelBasedOnBrats23Classes()

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
return d


class ConvertToMultiChannelBasedOnBrats23ClassesNoRegd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.
Convert labels to multi channels based on brats23 classes:
label 1 is the necrotic and non-enhancing tumor core
label 2 is the peritumoral edema
label 4 is the GD-enhancing tumor
In this case, labels are converted to multi channels. No regions are involved.
"""

backend = ConvertToMultiChannelBasedOnBrats23ClassesNoReg.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
super().__init__(keys, allow_missing_keys)
self.converter = ConvertToMultiChannelBasedOnBrats23ClassesNoReg()

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
return d


class AddExtremePointsChanneld(Randomizable, MapTransform):
Expand Down Expand Up @@ -1870,6 +1925,12 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = (
ConvertToMultiChannelBasedOnBratsClassesd
)
ConvertToMultiChannelBasedOnBrats23ClassesD = ConvertToMultiChannelBasedOnBrats23ClassesDict = (
ConvertToMultiChannelBasedOnBrats23Classesd
)
ConvertToMultiChannelBasedOnBrats23ClassesNoRegD = ConvertToMultiChannelBasedOnBrats23ClassesNoRegDict = (
ConvertToMultiChannelBasedOnBrats23ClassesNoRegd
)
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchVisionD = TorchVisionDict = TorchVisiond
RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
Expand Down

0 comments on commit 4df1d61

Please sign in to comment.