
Commit

mypy
andrewkho committed Apr 4, 2024
1 parent b94e31d commit 2e15563
Showing 3 changed files with 13 additions and 8 deletions.
6 changes: 3 additions & 3 deletions torchdata/stateful_dataloader/stateful.py
@@ -1,10 +1,10 @@
-from typing import Any, Dict, Optional, Protocol, runtime_checkable
+from typing import Any, Dict, Protocol, runtime_checkable
 
 
 @runtime_checkable
 class Stateful(Protocol):
-    def state_dict(self) -> Optional[Dict[str, Any]]:
+    def state_dict(self) -> Dict[str, Any]:
         ...
 
-    def load_state_dict(self, state_dict: Optional[Dict[str, Any]]) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         ...
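
For reference, the tightened protocol can be satisfied as in the minimal sketch below (not part of the commit); the CursorState class and its fields are invented for illustration, and the import path assumes the torchdata package is installed:

from typing import Any, Dict

from torchdata.stateful_dataloader.stateful import Stateful


class CursorState:
    """Hypothetical example: tracks a read position and exposes the Stateful API."""

    def __init__(self) -> None:
        self.position = 0

    def state_dict(self) -> Dict[str, Any]:
        # Must now return a dict unconditionally; returning None is no longer part of the contract.
        return {"position": self.position}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.position = state_dict["position"]


# Stateful is @runtime_checkable, so a structural isinstance check works:
assert isinstance(CursorState(), Stateful)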
13 changes: 9 additions & 4 deletions torchdata/stateful_dataloader/stateful_dataloader.py
@@ -148,6 +148,8 @@ class StatefulDataLoader(DataLoader[T_co]):
         https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
     """
 
+    _iterator: Optional["_StatefulBaseDataLoaderIter"]
+
     def __init__(
         self,
         dataset: Dataset[T_co],
@@ -189,10 +191,11 @@ def __init__(
             pin_memory_device=pin_memory_device,
         )
         self.snapshot_every_n_steps = snapshot_every_n_steps
-        self.next_iter_state = None
+        self.next_iter_state: Optional[Dict[str, Any]] = None
         self.iter_calls = 0
 
-    def _get_iterator(self) -> "_BaseDataLoaderIter":
+    def _get_iterator(self) -> "_StatefulBaseDataLoaderIter":
+        it: _StatefulBaseDataLoaderIter
         if self.num_workers == 0:
             it = _StatefulSingleProcessDataLoaderIter(self, self.next_iter_state)
         else:
@@ -681,6 +684,8 @@ class _StatefulMultiProcessingDataLoaderIter(_StatefulBaseDataLoaderIter):
     # processing indices already in `index_queue` if we are already shutting
     # down.
 
+    _last_yielded_worker_id: int
+
     def __init__(self, loader, next_iter_state):
         super().__init__(loader)
         self._snapshot_interval = loader.snapshot_every_n_steps
@@ -804,7 +809,7 @@ def __init__(self, loader, next_iter_state):
         _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
         _utils.signal_handling._set_SIGCHLD_handler()
         self._worker_pids_set = True
-        self._snapshot, self._worker_snapshots, self._main_snapshots = {}, {}, collections.deque()
+        self._snapshot, self._worker_snapshots, self._main_snapshots = {}, {}, collections.deque()  # type: ignore[var-annotated]
 
         self._main_state_0 = self._get_main_state()
         self._reset(loader, first_iter=True, prime_prefetch=next_iter_state is None)
@@ -1366,4 +1371,4 @@ def __del__(self):
         self._shutdown_workers()
 
 
-torch.utils.data.DataLoader = StatefulDataLoader
+torch.utils.data.DataLoader = StatefulDataLoader  # type: ignore[misc]
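
The annotations added in this file follow a common mypy pattern: declare an attribute or local variable type once, then assign different subclasses to it in different branches. A toy sketch of that pattern, with all names (BaseIter, Loader, and friends) invented for illustration rather than taken from the codebase:

from typing import Optional


class BaseIter:
    ...


class SingleProcessIter(BaseIter):
    ...


class MultiProcessIter(BaseIter):
    ...


class Loader:
    # Mirrors the class-level annotation added above: declaring the attribute
    # type once lets any method assign either subclass (or None) to it.
    _iterator: Optional[BaseIter]

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers
        self._iterator = None

    def _get_iterator(self) -> BaseIter:
        # Without this annotation, mypy infers the type of `it` from the first
        # branch (SingleProcessIter) and rejects the MultiProcessIter assignment.
        it: BaseIter
        if self.num_workers == 0:
            it = SingleProcessIter()
        else:
            it = MultiProcessIter()
        return it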
2 changes: 1 addition & 1 deletion torchdata/stateful_dataloader/worker.py
@@ -36,7 +36,7 @@ def try_to_serialize(obj: Any) -> Union[dict, None]:
     return obj_state
 
 
-def try_to_deserialize(obj: T, state_dict: Union[dict, None]) -> Union[T, None]:
+def try_to_deserialize(obj: T, state_dict: dict) -> T:
     if isinstance(obj, Stateful):
         obj.load_state_dict(state_dict)
         return obj  # type: ignore[return-value]
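To show what the narrowed try_to_deserialize signature buys at call sites, a standalone sketch follows; Counter and restore are hypothetical stand-ins, not part of the commit:

from typing import Any, Dict, TypeVar

T = TypeVar("T")


class Counter:
    """Hypothetical object exposing the Stateful-style methods."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


def restore(obj: T, state_dict: Dict[str, Any]) -> T:
    # Mirrors the narrowed signature: callers always get back an object of the
    # same type, so mypy no longer forces a None check at the call site.
    load = getattr(obj, "load_state_dict", None)
    if callable(load):
        load(state_dict)
    return obj


counter = restore(Counter(), {"count": 5})
print(counter.count)  # prints 5; mypy infers `counter: Counter`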
