From 4df1d61b4d291403a2ea32a45880e0b3fa5c1a6d Mon Sep 17 00:00:00 2001 From: Daniel Capellan-Martin Date: Wed, 25 Sep 2024 14:34:35 +0000 Subject: [PATCH] Add multi-channel conversion transforms for brats23 --- monai/transforms/__init__.py | 8 ++++ monai/transforms/utility/array.py | 44 +++++++++++++++++++ monai/transforms/utility/dictionary.py | 61 ++++++++++++++++++++++++++ 3 files changed, 113 insertions(+) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2cdd965c91..dde0b4cc60 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -516,6 +516,8 @@ CastToType, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, + ConvertToMultiChannelBasedOnBrats23Classes, + ConvertToMultiChannelBasedOnBrats23ClassesNoReg, CuCIM, DataStats, EnsureChannelFirst, @@ -569,6 +571,12 @@ ConvertToMultiChannelBasedOnBratsClassesd, ConvertToMultiChannelBasedOnBratsClassesD, ConvertToMultiChannelBasedOnBratsClassesDict, + ConvertToMultiChannelBasedOnBrats23ClassesD, + ConvertToMultiChannelBasedOnBrats23Classesd, + ConvertToMultiChannelBasedOnBrats23ClassesDict, + ConvertToMultiChannelBasedOnBrats23ClassesNoRegD, + ConvertToMultiChannelBasedOnBrats23ClassesNoRegd, + ConvertToMultiChannelBasedOnBrats23ClassesNoRegDict, CopyItemsd, CopyItemsD, CopyItemsDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 72dd189009..fd1c987902 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -97,6 +97,8 @@ "FgBgToIndices", "ClassesToIndices", "ConvertToMultiChannelBasedOnBratsClasses", + "ConvertToMultiChannelBasedOnBrats23Classes", + "ConvertToMultiChannelBasedOnBrats23ClassesNoReg", "AddExtremePointsChannel", "TorchVision", "MapLabelValue", @@ -1070,7 +1072,49 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT # label 4 is ET return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) + +class ConvertToMultiChannelBasedOnBrats23Classes(Transform): + """ + Convert labels to multi channels based on brats23 classes: + label 1 is the necrotic and non-enhancing tumor core (NCR) + label 2 is the peritumoral edema (ED) + label 3 is the GD-enhancing tumor (ET) + NOTE: REGION-BASED CONVERSION to TC, WT, ET + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + # if img has channel dim, squeeze it + if img.ndim == 4 and img.shape[0] == 1: + img = img.squeeze(0) + + result = [(img == 1) | (img == 3), (img == 1) | (img == 3) | (img == 2), img == 3] # -> tc, wt, et + # merge labels 1 (ncr) and 3 (et) and 2 (ed) to WT + return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) + + +class ConvertToMultiChannelBasedOnBrats23ClassesNoReg(Transform): + """ + Convert labels to multi channels based on brats23 classes: + label 1 is the necrotic and non-enhancing tumor core (NCR) + label 2 is the peritumoral edema (ED) + label 3 is the GD-enhancing tumor (ET) + NOTE: LABEL-BASED CONVERSION + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + # if img has channel dim, squeeze it + if img.ndim == 4 and img.shape[0] == 1: + img = img.squeeze(0) + + result = [(img == 1), (img == 2), (img == 3)] + + return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) + class AddExtremePointsChannel(Randomizable, Transform): """ diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 79d0be522d..a87944a563 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -40,6 +40,8 @@ CastToType, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, + ConvertToMultiChannelBasedOnBrats23Classes, + ConvertToMultiChannelBasedOnBrats23ClassesNoReg, CuCIM, DataStats, EnsureChannelFirst, @@ -89,6 +91,12 @@ "ConvertToMultiChannelBasedOnBratsClassesD", "ConvertToMultiChannelBasedOnBratsClassesDict", "ConvertToMultiChannelBasedOnBratsClassesd", + "ConvertToMultiChannelBasedOnBrats23ClassesD", + "ConvertToMultiChannelBasedOnBrats23ClassesDict", + "ConvertToMultiChannelBasedOnBrats23Classesd", + "ConvertToMultiChannelBasedOnBrats23NoRegClassesD", + "ConvertToMultiChannelBasedOnBrats23NoRegClassesDict", + "ConvertToMultiChannelBasedOnBrats23NoRegClassesd", "CopyItemsD", "CopyItemsDict", "CopyItemsd", @@ -1305,6 +1313,53 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d + + +class ConvertToMultiChannelBasedOnBrats23Classesd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBrats23Classes`. + Convert labels to multi channels based on brats23 classes: + label 1 is the necrotic and non-enhancing tumor core + label 2 is the peritumoral edema + label 3 is the GD-enhancing tumor + The possible classes are TC (Tumor core), WT (Whole tumor) + and ET (Enhancing tumor). + """ + + backend = ConvertToMultiChannelBasedOnBrats23Classes.backend + + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): + super().__init__(keys, allow_missing_keys) + self.converter = ConvertToMultiChannelBasedOnBrats23Classes() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + +class ConvertToMultiChannelBasedOnBrats23ClassesNoRegd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`. + Convert labels to multi channels based on brats23 classes: + label 1 is the necrotic and non-enhancing tumor core + label 2 is the peritumoral edema + label 4 is the GD-enhancing tumor + In this case, labels are converted to multi channels. No regions are involved. + """ + + backend = ConvertToMultiChannelBasedOnBrats23ClassesNoReg.backend + + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): + super().__init__(keys, allow_missing_keys) + self.converter = ConvertToMultiChannelBasedOnBrats23ClassesNoReg() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d class AddExtremePointsChanneld(Randomizable, MapTransform): @@ -1870,6 +1925,12 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = ( ConvertToMultiChannelBasedOnBratsClassesd ) +ConvertToMultiChannelBasedOnBrats23ClassesD = ConvertToMultiChannelBasedOnBrats23ClassesDict = ( + ConvertToMultiChannelBasedOnBrats23Classesd +) +ConvertToMultiChannelBasedOnBrats23ClassesNoRegD = ConvertToMultiChannelBasedOnBrats23ClassesNoRegDict = ( + ConvertToMultiChannelBasedOnBrats23ClassesNoRegd +) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond