Skip to content

Commit

Permalink
Make StreamingDataset state_dict() more flexible (#90)
Browse files Browse the repository at this point in the history
* Dataset.state_dict(num_samples, from_beginning).

* Elaborate on docstring.

* Add note.

* Rearg.
  • Loading branch information
knighton authored Dec 8, 2022
1 parent 1c867fe commit 1639bc0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
5 changes: 3 additions & 2 deletions streaming/base/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def state_dict(self) -> Optional[Dict[str, Any]]:
"""
if isinstance(self.dataset, StreamingDataset):
world = World()
return self.dataset.state_dict(self.num_samples_yielded * world.num_ranks)
num_samples = self.num_samples_yielded * world.num_ranks
return self.dataset.state_dict(num_samples, False)
return None

def load_state_dict(self, obj: Dict[str, Any]) -> None:
Expand All @@ -84,7 +85,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
obj (Dict[str, Any]): The state.
"""
if isinstance(self.dataset, StreamingDataset):
return self.dataset.load_state_dict(obj)
self.dataset.load_state_dict(obj)

def __del__(self) -> None:
"""Terminate the workers during cleanup."""
Expand Down
15 changes: 12 additions & 3 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,23 +666,32 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
for sample_id in self._each_sample(sample_ids):
yield self[sample_id]

def state_dict(self, sample_in_epoch: int) -> Dict[str, Any]:
def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
"""Get a dict containing training state (called from non-worker process).
This is called on rank zero.
Our stock StreamingDataLoader counts samples from start of training (from_beginning=false).
However, if you are always counting from the start of the epoch, set from_beginning=true.
Args:
sample_in_epoch (int): The number of samples processed so far in the current epoch.
num_samples (int): The number of samples processed so far in the current epoch.
from_beginning (int): Whether we are counting samples from the start of this epoch, or
the start of just this potentially resumed training run this epoch.
Returns:
Dict[str, Any]: The state.
"""
world = World()
epoch = self.next_epoch - 1
epoch, offset = self._resume(world, epoch)
if from_beginning:
sample_in_epoch = num_samples
else:
sample_in_epoch = offset + num_samples
return {
'epoch': epoch,
'sample_in_epoch': offset + sample_in_epoch,
'sample_in_epoch': sample_in_epoch,
'num_canonical_nodes': self.num_canonical_nodes,
'shuffle_seed': self.shuffle_seed
}
Expand Down

0 comments on commit 1639bc0

Please sign in to comment.