From bc4cdef898b7d4d64278629925948382b06d4cd1 Mon Sep 17 00:00:00 2001 From: Andrew Ho Date: Mon, 11 Nov 2024 17:30:38 -0500 Subject: [PATCH] Changes default behaviour of StopIteration immediately after load_state_dict to be false (#1358) --- test/nodes/utils.py | 18 ++++++++++-- torchdata/nodes/base_node.py | 55 ++++++++++++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/test/nodes/utils.py b/test/nodes/utils.py index cfe70206d..2bfe4c0ae 100644 --- a/test/nodes/utils.py +++ b/test/nodes/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import random import time from typing import Any, Dict, Iterator, Optional @@ -79,7 +80,8 @@ def __getitem__(self, i: int) -> dict: def run_test_save_load_state(test, x: BaseNode, midpoint: int): - # Test before iter call + ############################## + # Generate initial, midpoint, and end state_dict's initial_state_dict = x.state_dict() it = iter(x) results = [] @@ -94,6 +96,8 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int): # store epoch 1's results results_1 = list(x) + ############################## + # Test restoring from midpoint x.load_state_dict(state_dict) results_after = list(x) test.assertEqual(results_after, results[midpoint:]) @@ -102,6 +106,7 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int): results_after_1 = list(x) test.assertEqual(results_after_1, results_1) + ############################## # Test initialize from beginning after resume x.load_state_dict(initial_state_dict) full_results = list(x) @@ -109,7 +114,14 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int): full_results_1 = list(x) test.assertEqual(full_results_1, results_1) - # Test restoring from end of epoch 0 - x.load_state_dict(state_dict_0_end) + ############################## + # Test restoring from end-of-epoch 0 + x.load_state_dict(state_dict_0_end, restart_on_stop_iteration=False) + results_after_dict_0_with_restart_false = list(x) + test.assertEqual(results_after_dict_0_with_restart_false, []) + + ############################## + # Test restoring from end of epoch 0 with restart_on_stop_iteration=True + x.load_state_dict(copy.deepcopy(state_dict_0_end), restart_on_stop_iteration=True) results_after_dict_0 = list(x) test.assertEqual(results_after_dict_0, results_1) diff --git a/torchdata/nodes/base_node.py b/torchdata/nodes/base_node.py index 96b4f683b..cb43b12f7 100644 --- a/torchdata/nodes/base_node.py +++ b/torchdata/nodes/base_node.py @@ -27,6 +27,7 @@ def finished(self) -> bool: class BaseNode(Iterable[T]): __it: Optional[BaseNodeIterator[T]] = None # holds pointer to last iter() requested __initial_state: Optional[Dict[str, Any]] = None + __restart_on_stop_iteration: bool = False def iterator(self, initial_state: Optional[dict]) -> Iterator[T]: """Override this method to implement the iterator. @@ -47,14 +48,15 @@ def get_state(self) -> Dict[str, Any]: def __iter__(self) -> BaseNodeIterator[T]: if self.__it is not None and not self.__it.started(): - # Only create a new iter if the last requested one did not start + # Only create a new iter if the last requested one has already started return self.__it if self.__initial_state is not None: self.__it = _EagerIter(self, self.__initial_state) self.__initial_state = None - if not self.__it.has_next(): + if self.__restart_on_stop_iteration and not self.__it.has_next(): self.__it = _EagerIter(self, self.__initial_state) + self.__restart_on_stop_iteration = False # reset this for subsequent calls else: self.__it = _EagerIter(self, self.__initial_state) return self.__it @@ -62,8 +64,55 @@ def __iter__(self) -> BaseNodeIterator[T]: def state_dict(self) -> Dict[str, Any]: return self.get_state() - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict( + self, + state_dict: Dict[str, Any], + restart_on_stop_iteration: bool = False, + ) -> None: + """ + When iter() is next requested from this node, it will be instantiated with state_dict. + state_dict will be passed directly to the .iterator(...) method of this node which will + handle proper initialization. + + NOTE [ state_dict end of iteration handling ] + Special care must be taken when state_dict is requested after StopIteration. + Consider the following common example of saving state_dict after one epoch. + + ```python + node: BaseNode = ... + for batch in node: + # do something + + state_dict = node.state_dict() + ``` + + Technically, since state_dict() was called before a new iterator was requested from `node`, + you should expect the following behaviour: + + ```python + node: BaseNode = ... + node.load_state_dict(state_dict) # Load state_dict from above + next(iter(node)) # Throws StopIteration immediately + ``` + + You can avoid the above (default) behaviour by passing `restart_on_stop_iteration=True` when + calling `load_state_dict`, eg + + ```python + node: BaseNode = ... + node.load_state_dict(state_dict, restart_on_stop_iteration=True) + next(iter(node)) # Catches StopIteration, creates a new iterator, and returns next() + ``` + + Note: we can not make `True` the default for restart_on_stop_iteration because it would + prevent StopIteration thrown in leaves from propogating up to the node where load_state_dict is called. + + :param state_dict: state_dict to load in next __iter__ requested + :param restart_on_stop_iteration: (default False) - whether to restart the iterator automatically + when the first next() call would throw StopIteration. + """ self.__initial_state = state_dict + self.__restart_on_stop_iteration = restart_on_stop_iteration class _EagerIter(BaseNodeIterator[T]):