From 1d78c33a5f5bed1a3698d493dbdff7893e3231a7 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 13 Oct 2023 17:03:26 +0100 Subject: [PATCH] Initial commit of a solution for issue #7130 Signed-off-by: Ben Murray --- monai/data/dataset.py | 10 +++++----- monai/transforms/compose.py | 21 +++++++++++++++++++++ monai/transforms/transform.py | 12 ++++++++++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 4f2061426e..629a30257c 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -288,7 +288,7 @@ def set_transform_hash(self, hash_xform_func: Callable[..., bytes]): inherit from MONAI's `Transform` class.""" hashable_transforms = [] for _tr in self.transform.flatten().transforms: - if isinstance(_tr, RandomizableTrait) or not isinstance(_tr, Transform): + if (isinstance(_tr, RandomizableTrait) and _tr.is_random() is True) or not isinstance(_tr, Transform): break hashable_transforms.append(_tr) # Try to hash. Fall back to a hash of their names @@ -327,7 +327,7 @@ def _pre_transform(self, item_transformed): raise ValueError("transform must be an instance of monai.transforms.Compose.") first_random = self.transform.get_index_of_first( - lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + lambda t: (isinstance(t, RandomizableTrait) and t.is_random() is True) or not isinstance(t, Transform) ) item_transformed = self.transform(item_transformed, end=first_random, threading=True) @@ -350,7 +350,7 @@ def _post_transform(self, item_transformed): raise ValueError("transform must be an instance of monai.transforms.Compose.") first_random = self.transform.get_index_of_first( - lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + lambda t: isinstance(t, RandomizableTrait) and t.is_random() is True or not isinstance(t, Transform) ) if first_random is not None: item_transformed = self.transform(item_transformed, start=first_random) @@ -887,7 +887,7 @@ def _load_cache_item(self, idx: int): item = self.data[idx] first_random = self.transform.get_index_of_first( - lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + lambda t: (isinstance(t, RandomizableTrait) and t.is_random() is True) or not isinstance(t, Transform) ) item = self.transform(item, end=first_random, threading=True) @@ -921,7 +921,7 @@ def _transform(self, index: int): raise ValueError("transform must be an instance of monai.transforms.Compose.") first_random = self.transform.get_index_of_first( - lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + lambda t: (isinstance(t, RandomizableTrait) and t.is_random() is True) or not isinstance(t, Transform) ) if first_random is not None: data = deepcopy(data) if self.copy_cache is True else data diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 1614913f5e..816d94537d 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -334,6 +334,18 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None return result + def is_random(self): + def _recursive_check(tx): + if isinstance(Randomizable) and tx.is_random(): + return True + return False + + for t in self.transforms: + if _recursive_check(t) is True: + return True + + return False + def inverse(self, data): self._raise_if_not_invertible(data) @@ -481,6 +493,9 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool | self.push_transform(data[key], extra_info={"index": index}) return data + def is_random(self): + return True + def inverse(self, data): if len(self.transforms) == 0: return data @@ -575,6 +590,9 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None self.push_transform(input_[key], extra_info={"applied_order": applied_order}) return input_ + def is_random(self): + return True + def inverse(self, data): if len(self.transforms) == 0: return data @@ -740,6 +758,9 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = return data + def is_random(self): + return True + # From RandomOrder def inverse(self, data): if len(self.transforms) == 0: diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e35335ba0e..a77b6098b7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -230,6 +230,18 @@ def randomize(self, data: Any) -> None: NotImplementedError: When the subclass does not override this method. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def is_random(self): + """ + This method indicates whether this particular instance of a Randomizable is operating randomly or not. + This allows objects that have the capacity to act in a random fashion to indicate whether they are acting + randomly or not, given their current state. + This method should be overridden by objects that have this capacity. ``Compose`` is an example of this. + + Returns: + True if the object is acting randomly, False otherwise + """ + return True class Transform(ABC):