Skip to content

Commit

Permalink
catch exceptions in state generation
Browse files: browse the repository at this point in the history
  • Loading branch information
andrewkho committed Apr 4, 2024
1 parent 9c9bd54 commit 0e46b4e
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions torchdata/stateful_dataloader/worker.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -183,36 +183,37 @@ def _worker_loop(
data = init_exception
init_exception = None
else:
assert fetcher is not None
try:
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
except Exception as e:
if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
try:
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
except StopIteration:
if not dataset_kind == _DatasetKind.Iterable:
raise
data = _IterableDatasetStopIteration(worker_id)
# Set `iteration_end`
# (1) to save future `next(...)` calls, and
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
iteration_end = True
else:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}")
if snapshot or iteration_end:
if dataset_kind == _DatasetKind.Iterable:
fetcher_state = {
"dataset_iter": try_to_serialize(fetcher.dataset_iter),
"ended": fetcher.ended,
if snapshot or iteration_end:
if dataset_kind == _DatasetKind.Iterable:
fetcher_state = {
"dataset_iter": try_to_serialize(fetcher.dataset_iter),
"ended": fetcher.ended,
}
else:
fetcher_state = None
# Pick up any user-defined dataset state, for both map/iterable style datasets
dataset_state = try_to_serialize(dataset)
state_dict = {
"worker_id": worker_id,
"fetcher_state": fetcher_state,
"dataset_state": dataset_state,
}
else:
fetcher_state = None
# Pick up any user-defined dataset state, for both map/iterable style datasets
dataset_state = try_to_serialize(dataset)
state_dict = {
"worker_id": worker_id,
"fetcher_state": fetcher_state,
"dataset_state": dataset_state,
}
except Exception:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}")
data_queue.put((idx, (data, worker_id, state_dict)))
del data, idx, index, r, state_dict # save memory
except KeyboardInterrupt:
Expand Down

0 comments on commit 0e46b4e

Please sign in to comment.