Skip to content

Commit

Permalink
Initial commit of a solution for issue #7130
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Murray <[email protected]>
  • Loading branch information
atbenmurray committed Oct 13, 2023
1 parent c7a6cca commit 1d78c33
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
10 changes: 5 additions & 5 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
inherit from MONAI's `Transform` class."""
hashable_transforms = []
for _tr in self.transform.flatten().transforms:
if isinstance(_tr, RandomizableTrait) or not isinstance(_tr, Transform):
if (isinstance(_tr, RandomizableTrait) and _tr.is_random() is True) or not isinstance(_tr, Transform):
break
hashable_transforms.append(_tr)
# Try to hash. Fall back to a hash of their names
Expand Down Expand Up @@ -327,7 +327,7 @@ def _pre_transform(self, item_transformed):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
lambda t: (isinstance(t, RandomizableTrait) and t.is_random() is True) or not isinstance(t, Transform)
)
item_transformed = self.transform(item_transformed, end=first_random, threading=True)

Expand All @@ -350,7 +350,7 @@ def _post_transform(self, item_transformed):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
lambda t: isinstance(t, RandomizableTrait) and t.is_random() is True or not isinstance(t, Transform)
)
if first_random is not None:
item_transformed = self.transform(item_transformed, start=first_random)
Expand Down Expand Up @@ -887,7 +887,7 @@ def _load_cache_item(self, idx: int):
item = self.data[idx]

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
lambda t: (isinstance(t, RandomizableTrait) and t.is_random() is True) or not isinstance(t, Transform)
)
item = self.transform(item, end=first_random, threading=True)

Expand Down Expand Up @@ -921,7 +921,7 @@ def _transform(self, index: int):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
lambda t: (isinstance(t, RandomizableTrait) and t.is_random() is True) or not isinstance(t, Transform)
)
if first_random is not None:
data = deepcopy(data) if self.copy_cache is True else data
Expand Down
21 changes: 21 additions & 0 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,18 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None

return result

def is_random(self):
def _recursive_check(tx):
if isinstance(Randomizable) and tx.is_random():
return True
return False

for t in self.transforms:
if _recursive_check(t) is True:
return True

return False

def inverse(self, data):
self._raise_if_not_invertible(data)

Expand Down Expand Up @@ -481,6 +493,9 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool |
self.push_transform(data[key], extra_info={"index": index})
return data

def is_random(self):
return True

def inverse(self, data):
if len(self.transforms) == 0:
return data
Expand Down Expand Up @@ -575,6 +590,9 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None
self.push_transform(input_[key], extra_info={"applied_order": applied_order})
return input_

def is_random(self):
return True

def inverse(self, data):
if len(self.transforms) == 0:
return data
Expand Down Expand Up @@ -740,6 +758,9 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None =

return data

def is_random(self):
return True

# From RandomOrder
def inverse(self, data):
if len(self.transforms) == 0:
Expand Down
12 changes: 12 additions & 0 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,18 @@ def randomize(self, data: Any) -> None:
NotImplementedError: When the subclass does not override this method.
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

def is_random(self):
"""
This method indicates whether this particular instance of a Randomizable is operating randomly or not.
This allows objects that have the capacity to act in a random fashion to indicate whether they are acting
randomly or not, given their current state.
This method should be overridden by objects that have this capacity. ``Compose`` is an example of this.
Returns:
True if the object is acting randomly, False otherwise
"""
return True


class Transform(ABC):
Expand Down

0 comments on commit 1d78c33

Please sign in to comment.