From faf6f74aa54486d110ac19192568bd957dcbcfef Mon Sep 17 00:00:00 2001 From: elitap Date: Sun, 13 Aug 2023 22:12:24 +0200 Subject: [PATCH] fix for #6841 lazy argument has no effect in Compose.__call__ (#6862) Fixes #6841 ### Description added check if lazy flag is set, then use it! ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). No tests added yet, will do asp. However unsure where. I think something similar as TestComposeExecuteWithLogging would work as the results of the flag could be made visible there. I dont like the string compares, but applying the flag is not changing the output, so thats probalby the best option anyway. Signed-off-by: elitap --- monai/transforms/compose.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 1614913f5e..0e0093e1bc 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -319,6 +319,7 @@ def __len__(self): return len(self.flatten().transforms) def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None): + _lazy = self.lazy if lazy is None else lazy result = execute_compose( input_, transforms=self.transforms, @@ -326,7 +327,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None end=end, map_items=self.map_items, unpack_items=self.unpack_items, - lazy=self.lazy, + lazy=_lazy, overrides=self.overrides, threading=threading, log_stats=self.log_stats, @@ -341,7 +342,7 @@ def inverse(self, data): if not invertible_transforms: warnings.warn("inverse has been called but no invertible transforms have been supplied") - if self.lazy is not False: + if self.lazy is True: warnings.warn( f"'lazy' is set to {self.lazy} but lazy execution is not supported when inverting. " f"'lazy' has been overridden to False for the call to inverse" @@ -447,7 +448,7 @@ def flatten(self): weights.append(w) return OneOf(transforms, weights, self.map_items, self.unpack_items) - def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool | None = None): + def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None): if start != 0: raise ValueError(f"OneOf requires 'start' parameter to be 0 (start set to {start})") if end is not None: @@ -458,6 +459,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool | index = self.R.multinomial(1, self.weights).argmax() _transform = self.transforms[index] + _lazy = self.lazy if lazy is None else lazy data = execute_compose( data, @@ -466,7 +468,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool | end=end, map_items=self.map_items, unpack_items=self.unpack_items, - lazy=self.lazy, + lazy=_lazy, overrides=self.overrides, threading=threading, log_stats=self.log_stats, @@ -553,6 +555,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None num = len(self.transforms) applied_order = self.R.permutation(range(num)) + _lazy = self.lazy if lazy is None else lazy input_ = execute_compose( input_, @@ -561,7 +564,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None end=end, map_items=self.map_items, unpack_items=self.unpack_items, - lazy=self.lazy, + lazy=_lazy, threading=threading, log_stats=self.log_stats, ) @@ -718,6 +721,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1) applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist() + _lazy = self.lazy if lazy is None else lazy data = execute_compose( data, @@ -726,7 +730,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = end=end, map_items=self.map_items, unpack_items=self.unpack_items, - lazy=self.lazy, + lazy=_lazy, overrides=self.overrides, threading=threading, log_stats=self.log_stats,