From e1e3d8ebc1c7247aad9f1bffc649c5a20084340f Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Thu, 19 Dec 2024 02:50:42 +0000 Subject: [PATCH] Modify Workflow to Allow IterableDataset Inputs (#8263) ### Description This modifies the behaviour of `Workflow` to permit `IterableDataset` to be used correctly. A check against the `epoch_length` value is removed, to allow that value to be `None`, and a test is added to verify this. The length of a data loader is not defined when using iterable datasets, so a try/except is added to allow that to be queried safely. This is related to my work on the streaming support: in my [prototype gist](https://gist.github.com/ericspod/1904713716b45631260784ac3fcd6fb3) I had to provide a bogus epoch length value and then change it to `None` later, once the evaluator object was created. This PR will remove the need for this hack. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
--------- Signed-off-by: Eric Kerfoot Signed-off-by: Eric Kerfoot Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot --- monai/engines/workflow.py | 22 +++++++++++----------- tests/test_iterable_dataset.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 3629659db1..0c36da6d3d 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -12,7 +12,7 @@ from __future__ import annotations import warnings -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence, Sized from typing import TYPE_CHECKING, Any import torch @@ -121,24 +121,24 @@ def __init__( to_kwargs: dict | None = None, amp_kwargs: dict | None = None, ) -> None: - if iteration_update is not None: - super().__init__(iteration_update) - else: - super().__init__(self._iteration) + super().__init__(self._iteration if iteration_update is None else iteration_update) if isinstance(data_loader, DataLoader): - sampler = data_loader.__dict__["sampler"] + sampler = getattr(data_loader, "sampler", None) + + # set the epoch value for DistributedSampler objects when an epoch starts if isinstance(sampler, DistributedSampler): @self.on(Events.EPOCH_STARTED) def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - if epoch_length is None: + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None and isinstance(data_loader, Sized): + try: epoch_length = len(data_loader) - else: - if epoch_length is None: - raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.") + except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type + pass # deliberately leave epoch_length as None # set all sharable data for the workflow based on Ignite 
engine.state self.state: Any = State( @@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None: iteration=0, epoch=0, max_epochs=max_epochs, - epoch_length=epoch_length, + epoch_length=epoch_length, # None when the dataset is iterable and so has no length output=None, batch=None, metrics={}, diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index cfa711e4c0..fb554e391c 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -18,8 +18,10 @@ import nibabel as nib import numpy as np +import torch.nn as nn from monai.data import DataLoader, Dataset, IterableDataset +from monai.engines import SupervisedEvaluator from monai.transforms import Compose, LoadImaged, SimulateDelayd @@ -59,6 +61,17 @@ def test_shape(self): for d in dataloader: self.assertTupleEqual(d["image"].shape[1:], expected_shape) + def test_supervisedevaluator(self): + """ + Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader. + """ + data = list(range(10)) + dl = DataLoader(IterableDataset(data)) + evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity()) + evaluator.run() # fails if the epoch length or other internal setup is not done correctly + + self.assertEqual(evaluator.state.iteration, len(data)) + if __name__ == "__main__": unittest.main()