Skip to content

Commit

Permalink
fix for #6841 lazy argument has no effect in Compose.__call__ (#6862)
Browse files Browse the repository at this point in the history
Fixes #6841

### Description

added check if lazy flag is set, then use it!

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).

No tests added yet, will do asp. However unsure where. I think something
similar as TestComposeExecuteWithLogging would work as the results of
the flag could be made visible there. I dont like the string compares,
but applying the flag is not changing the output, so thats probalby the
best option anyway.

Signed-off-by: elitap <[email protected]>
  • Loading branch information
elitap authored Aug 13, 2023
1 parent a86c0e0 commit faf6f74
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,14 +319,15 @@ def __len__(self):
return len(self.flatten().transforms)

def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):
_lazy = self.lazy if lazy is None else lazy
result = execute_compose(
input_,
transforms=self.transforms,
start=start,
end=end,
map_items=self.map_items,
unpack_items=self.unpack_items,
lazy=self.lazy,
lazy=_lazy,
overrides=self.overrides,
threading=threading,
log_stats=self.log_stats,
Expand All @@ -341,7 +342,7 @@ def inverse(self, data):
if not invertible_transforms:
warnings.warn("inverse has been called but no invertible transforms have been supplied")

if self.lazy is not False:
if self.lazy is True:
warnings.warn(
f"'lazy' is set to {self.lazy} but lazy execution is not supported when inverting. "
f"'lazy' has been overridden to False for the call to inverse"
Expand Down Expand Up @@ -447,7 +448,7 @@ def flatten(self):
weights.append(w)
return OneOf(transforms, weights, self.map_items, self.unpack_items)

def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool | None = None):
def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None):
if start != 0:
raise ValueError(f"OneOf requires 'start' parameter to be 0 (start set to {start})")
if end is not None:
Expand All @@ -458,6 +459,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool |

index = self.R.multinomial(1, self.weights).argmax()
_transform = self.transforms[index]
_lazy = self.lazy if lazy is None else lazy

data = execute_compose(
data,
Expand All @@ -466,7 +468,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: str | bool |
end=end,
map_items=self.map_items,
unpack_items=self.unpack_items,
lazy=self.lazy,
lazy=_lazy,
overrides=self.overrides,
threading=threading,
log_stats=self.log_stats,
Expand Down Expand Up @@ -553,6 +555,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None

num = len(self.transforms)
applied_order = self.R.permutation(range(num))
_lazy = self.lazy if lazy is None else lazy

input_ = execute_compose(
input_,
Expand All @@ -561,7 +564,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None
end=end,
map_items=self.map_items,
unpack_items=self.unpack_items,
lazy=self.lazy,
lazy=_lazy,
threading=threading,
log_stats=self.log_stats,
)
Expand Down Expand Up @@ -718,6 +721,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None =

sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1)
applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist()
_lazy = self.lazy if lazy is None else lazy

data = execute_compose(
data,
Expand All @@ -726,7 +730,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None =
end=end,
map_items=self.map_items,
unpack_items=self.unpack_items,
lazy=self.lazy,
lazy=_lazy,
overrides=self.overrides,
threading=threading,
log_stats=self.log_stats,
Expand Down

0 comments on commit faf6f74

Please sign in to comment.