Skip to content

Commit

Permalink
Changes default behaviour of StopIteration immediately after load_state_dict to be false (#1358)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho authored Nov 11, 2024
1 parent 0263d58 commit bc4cdef
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 6 deletions.
18 changes: 15 additions & 3 deletions test/nodes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import random
import time
from typing import Any, Dict, Iterator, Optional
Expand Down Expand Up @@ -79,7 +80,8 @@ def __getitem__(self, i: int) -> dict:


def run_test_save_load_state(test, x: BaseNode, midpoint: int):
# Test before iter call
##############################
# Generate initial, midpoint, and end state_dicts
initial_state_dict = x.state_dict()
it = iter(x)
results = []
Expand All @@ -94,6 +96,8 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int):
# store epoch 1's results
results_1 = list(x)

##############################
# Test restoring from midpoint
x.load_state_dict(state_dict)
results_after = list(x)
test.assertEqual(results_after, results[midpoint:])
Expand All @@ -102,14 +106,22 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int):
results_after_1 = list(x)
test.assertEqual(results_after_1, results_1)

##############################
# Test initialize from beginning after resume
x.load_state_dict(initial_state_dict)
full_results = list(x)
test.assertEqual(full_results, results)
full_results_1 = list(x)
test.assertEqual(full_results_1, results_1)

# Test restoring from end of epoch 0
x.load_state_dict(state_dict_0_end)
##############################
# Test restoring from end-of-epoch 0
x.load_state_dict(state_dict_0_end, restart_on_stop_iteration=False)
results_after_dict_0_with_restart_false = list(x)
test.assertEqual(results_after_dict_0_with_restart_false, [])

##############################
# Test restoring from end of epoch 0 with restart_on_stop_iteration=True
x.load_state_dict(copy.deepcopy(state_dict_0_end), restart_on_stop_iteration=True)
results_after_dict_0 = list(x)
test.assertEqual(results_after_dict_0, results_1)
55 changes: 52 additions & 3 deletions torchdata/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def finished(self) -> bool:
class BaseNode(Iterable[T]):
__it: Optional[BaseNodeIterator[T]] = None # holds pointer to last iter() requested
__initial_state: Optional[Dict[str, Any]] = None
__restart_on_stop_iteration: bool = False

def iterator(self, initial_state: Optional[dict]) -> Iterator[T]:
"""Override this method to implement the iterator.
Expand All @@ -47,23 +48,71 @@ def get_state(self) -> Dict[str, Any]:

def __iter__(self) -> BaseNodeIterator[T]:
if self.__it is not None and not self.__it.started():
# Only create a new iter if the last requested one did not start
# Only create a new iter if the last requested one has already started
return self.__it

if self.__initial_state is not None:
self.__it = _EagerIter(self, self.__initial_state)
self.__initial_state = None
if not self.__it.has_next():
if self.__restart_on_stop_iteration and not self.__it.has_next():
self.__it = _EagerIter(self, self.__initial_state)
self.__restart_on_stop_iteration = False # reset this for subsequent calls
else:
self.__it = _EagerIter(self, self.__initial_state)
return self.__it

def state_dict(self) -> Dict[str, Any]:
return self.get_state()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(
self,
state_dict: Dict[str, Any],
restart_on_stop_iteration: bool = False,
) -> None:
"""
When iter() is next requested from this node, it will be instantiated with state_dict.
state_dict will be passed directly to the .iterator(...) method of this node which will
handle proper initialization.
NOTE [ state_dict end of iteration handling ]
Special care must be taken when state_dict is requested after StopIteration.
Consider the following common example of saving state_dict after one epoch.
```python
node: BaseNode = ...
for batch in node:
# do something
state_dict = node.state_dict()
```
Technically, since state_dict() was called before a new iterator was requested from `node`,
you should expect the following behaviour:
```python
node: BaseNode = ...
node.load_state_dict(state_dict) # Load state_dict from above
next(iter(node)) # Throws StopIteration immediately
```
You can avoid the above (default) behaviour by passing `restart_on_stop_iteration=True` when
calling `load_state_dict`, e.g.:
```python
node: BaseNode = ...
node.load_state_dict(state_dict, restart_on_stop_iteration=True)
next(iter(node)) # Catches StopIteration, creates a new iterator, and returns next()
```
Note: we cannot make `True` the default for restart_on_stop_iteration because it would
prevent StopIteration thrown in leaves from propagating up to the node where load_state_dict is called.
:param state_dict: state_dict to load in next __iter__ requested
:param restart_on_stop_iteration: (default False) - whether to restart the iterator automatically
when the first next() call would throw StopIteration.
"""
self.__initial_state = state_dict
self.__restart_on_stop_iteration = restart_on_stop_iteration


class _EagerIter(BaseNodeIterator[T]):
Expand Down

0 comments on commit bc4cdef

Please sign in to comment.