Add early support for torchdata.stateful_dataloader.StatefulDataLoader within the Accelerator (#2895)
@@ -30,6 +30,7 @@ | |
get_data_structure, | ||
initialize_tensors, | ||
is_torch_version, | ||
is_torchdata_stateful_dataloader_available, | ||
send_to_device, | ||
slice_tensors, | ||
synchronize_rng_states, | ||
|
@@ -388,9 +389,75 @@ def end(self): | |
self.gradient_state._remove_dataloader(self) | ||
|
||
|
||
class DataLoaderShard(DataLoader, DataLoaderStateMixin): | ||
class DataLoaderAdapter: | ||
""" | ||
Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. | ||
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For | ||
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. | ||
""" | ||
|
||
def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): | ||
self.use_stateful_dataloader = use_stateful_dataloader | ||
if is_torchdata_stateful_dataloader_available(): | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
|
||
if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): | ||
raise ImportError( | ||
"StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it." | ||
) | ||
if use_stateful_dataloader: | ||
self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) | ||
else: | ||
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) | ||
|
||
# Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 | ||
# In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or | ||
# StatefulDataLoader | ||
# | ||
# The same functionality could be achieved by directly creating the required subclasses for both {DataLoader, | ||
# StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional | ||
# dispatching scattered throughout various functions and files. | ||
# | ||
# This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work | ||
# transparently. | ||
# | ||
# A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit), | ||
# but this would not be backwards compatible with existing code which assumes | ||
# DataLoaderShard/DataLoaderDispatcher are DataLoaders. | ||
base_cls = self.__class__ | ||
base_cls_name = self.__class__.__name__ | ||
parent_cls_name = self.base_dataloader.__class__ | ||
self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {}) | ||
|
||
if hasattr(self.base_dataloader, "state_dict"): | ||
self.dl_state_dict = self.base_dataloader.state_dict() | ||
|
||
def __getattr__(self, name): | ||
# Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute. | ||
if name == "base_dataloader": | ||
raise AttributeError() | ||
# Delegate attribute access to the internal dataloader | ||
return getattr(self.base_dataloader, name) | ||
A bit of an edge case: let's also check if the name is not …

Could you give a code example of how infinite recursion would happen here? If I'm reading the python3 docs for … If I add the following block into …

Yes, that would be for the edge case of an attribute getting called on the class, i.e. before it is instantiated. In that case, the …

I am still not entirely sure how this could happen, but I added a check in …
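To make the question above concrete, here is a minimal, self-contained sketch of the pattern under discussion. The `Adapter` class and names are hypothetical stand-ins, not the accelerate implementation, and only the basic `base_dataloader` guard is shown (not any additional check added later in the PR). It illustrates both the dynamic mix-in that keeps `isinstance` checks working and why `__getattr__` would recurse forever on an instance whose `base_dataloader` was never set, e.g. one created via `__new__` without running `__init__` (as unpickling does).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class Adapter:
    """Toy version of the wrap-and-mix-in pattern (not the accelerate code)."""

    def __init__(self, dataset, **kwargs):
        self.base_dataloader = DataLoader(dataset, **kwargs)

        # Dynamically mix the wrapped loader's class into this instance's MRO
        # so `isinstance(adapter, DataLoader)` stays True.
        base_cls = self.__class__
        wrapped_cls = self.base_dataloader.__class__
        self.__class__ = type(base_cls.__name__, (base_cls, wrapped_cls), {})

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If `base_dataloader`
        # itself is missing (the instance skipped __init__), delegating would
        # call __getattr__("base_dataloader") again and recurse forever --
        # hence the guard.
        if name == "base_dataloader":
            raise AttributeError(name)
        return getattr(self.base_dataloader, name)


dataset = TensorDataset(torch.arange(8).float())
adapter = Adapter(dataset, batch_size=4)

print(isinstance(adapter, DataLoader))  # True, thanks to the dynamic mix-in
print(adapter.batch_size)               # delegated to the wrapped DataLoader

# An instance that never ran __init__ has no `base_dataloader`; with the guard
# we get a clean AttributeError instead of a RecursionError.
bare = object.__new__(Adapter)
try:
    bare.batch_size
except AttributeError:
    print("clean AttributeError, no infinite recursion")
```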
||
|
||
def state_dict(self): | ||
return self.dl_state_dict | ||
|
||
def load_state_dict(self, state_dict): | ||
self.base_dataloader.load_state_dict(state_dict) | ||
self.dl_state_dict = self.state_dict | ||
|
||
def _update_state_dict(self): | ||
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded. | ||
# E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of | ||
# what it wants to yield. | ||
# | ||
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter. | ||
if hasattr(self.base_dataloader, "state_dict"): | ||
Let's add a comment here about when this needs to be called, with the context on why it's required.

Added a comment here, kinda clunky though.
||
self.dl_state_dict = self.base_dataloader.state_dict() | ||
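To make the "one element ahead" timing concrete, here is a toy illustration with a hypothetical `TinyStatefulLoader` (not torchdata's implementation): once the next batch has been prefetched, the live `state_dict()` of the inner loader already points past the batch currently being yielded, which is why the snapshot has to be taken at fetch time.

```python
class TinyStatefulLoader:
    """Toy stand-in for a stateful dataloader over a list of batches."""

    def __init__(self, batches):
        self.batches = batches
        self._index = 0

    def __iter__(self):
        self._index = 0
        while self._index < len(self.batches):
            batch = self.batches[self._index]
            self._index += 1
            yield batch

    def state_dict(self):
        # Resuming from this state replays everything not yet fetched.
        return {"index": self._index}


loader = TinyStatefulLoader(["b0", "b1", "b2"])
it = iter(loader)

current = next(it)              # the batch about to be yielded downstream
snapshot = loader.state_dict()  # {'index': 1}: resuming replays from b1
_ = next(it)                    # prefetch one batch ahead, as DataLoaderShard does
too_late = loader.state_dict()  # {'index': 2}: resuming here would skip b1

print(current, snapshot, too_late)
```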
|
||
|
||
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): | ||
""" | ||
Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. | ||
|
||
Args: | ||
dataset (`torch.utils.data.dataset.Dataset`): | ||
|
@@ -409,6 +476,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin): | |
A random number generator to keep synchronized across processes. | ||
skip_batches (`int`, *optional*, defaults to 0): | ||
The number of batches to skip at the beginning. | ||
use_stateful_dataloader (`bool`, *optional*, defaults to `False`): | ||
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. | ||
**kwargs (additional keyword arguments, *optional*): | ||
All other keyword arguments to pass to the regular `DataLoader` initialization. | ||
|
||
|
@@ -428,11 +497,12 @@ def __init__( | |
rng_types=None, | ||
synchronized_generator=None, | ||
skip_batches=0, | ||
use_stateful_dataloader=False, | ||
_drop_last: bool = False, | ||
_non_blocking: bool = False, | ||
**kwargs, | ||
): | ||
super().__init__(dataset, **kwargs) | ||
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) | ||
self.device = device | ||
self.rng_types = rng_types | ||
self.synchronized_generator = synchronized_generator | ||
|
@@ -448,7 +518,7 @@ def __iter__(self): | |
self.begin() | ||
|
||
self.set_epoch(self.iteration) | ||
dataloader_iter = super().__iter__() | ||
dataloader_iter = self.base_dataloader.__iter__() | ||
# We iterate one batch ahead to check when we are at the end | ||
try: | ||
current_batch = next(dataloader_iter) | ||
|
@@ -461,6 +531,7 @@ def __iter__(self): | |
# But we still move it to the device so it is done before `StopIteration` is reached | ||
if self.device is not None: | ||
current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking) | ||
self._update_state_dict() | ||
next_batch = next(dataloader_iter) | ||
if batch_index >= self.skip_batches: | ||
yield current_batch | ||
|
@@ -564,10 +635,10 @@ def dataloader(self): | |
return self._loader | ||
|
||
|
||
class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): | ||
class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): | ||
""" | ||
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each | ||
process their part of the batch. | ||
Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process | ||
their part of the batch. | ||
|
||
Args: | ||
split_batches (`bool`, *optional*, defaults to `False`): | ||
|
@@ -579,6 +650,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): | |
size of the `dataloader` is a round multiple of `batch_size`. | ||
skip_batches (`int`, *optional*, defaults to 0): | ||
The number of batches to skip at the beginning of an iteration. | ||
use_stateful_dataloader (`bool`, *optional*, defaults to `False`): | ||
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. | ||
|
||
**Available attributes:** | ||
|
||
|
@@ -594,6 +667,7 @@ def __init__( | |
dataset, | ||
split_batches: bool = False, | ||
skip_batches=0, | ||
use_stateful_dataloader=False, | ||
_drop_last: bool = False, | ||
_non_blocking: bool = False, | ||
slice_fn=None, | ||
|
@@ -606,7 +680,7 @@ def __init__( | |
# We need to save the shuffling state of the DataPipe | ||
if isinstance(dataset, ShufflerIterDataPipe): | ||
shuffle = dataset._shuffle_enabled | ||
super().__init__(dataset, **kwargs) | ||
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) | ||
self.split_batches = split_batches | ||
if shuffle: | ||
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) | ||
|
@@ -627,12 +701,14 @@ def _fetch_batches(self, iterator): | |
try: | ||
if self.split_batches: | ||
# One batch of the main iterator is dispatched and split. | ||
self._update_state_dict() | ||
batch = next(iterator) | ||
else: | ||
# num_processes batches of the main iterator are concatenated then dispatched and split. | ||
# We add the batches one by one so we have the remainder available when drop_last=False. | ||
batches = [] | ||
for _ in range(self.state.num_processes): | ||
self._update_state_dict() | ||
batches.append(next(iterator)) | ||
try: | ||
batch = concatenate(batches, dim=0) | ||
|
@@ -673,9 +749,9 @@ def __iter__(self): | |
# NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts | ||
# shared seed to all dist processes. Thus, we need to create iterator for all dist processes. | ||
# But, we only iterate through the DataLoader on process 0. | ||
main_iterator = super().__iter__() | ||
main_iterator = self.base_dataloader.__iter__() | ||
elif self.state.process_index == 0: | ||
main_iterator = super().__iter__() | ||
main_iterator = self.base_dataloader.__iter__() | ||
stop_iteration = False | ||
self._stop_iteration = False | ||
first_batch = None | ||
|
@@ -812,6 +888,7 @@ def prepare_data_loader( | |
slice_fn_for_dispatch: Optional[Callable] = None, | ||
use_seedable_sampler: bool = False, | ||
non_blocking: bool = False, | ||
use_stateful_dataloader: bool = False, | ||
) -> DataLoader: | ||
""" | ||
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. | ||
|
@@ -873,6 +950,10 @@ def prepare_data_loader( | |
non_blocking (`bool`, *optional*, defaults to `False`): | ||
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has | ||
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations. | ||
use_stateful_dataloader (`bool`, *optional*, defaults to `False`): | ||
"If set to true, the dataloader prepared by the Accelerator will be backed by " | ||
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). | ||
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed." | ||
|
||
|
||
Returns: | ||
|
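As a rough usage sketch of the new argument documented above (assuming `torchdata>=0.8.0` is installed and a single process; everything beyond what the diff itself shows is an assumption):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import prepare_data_loader

dataset = TensorDataset(torch.arange(16).float())
dataloader = DataLoader(dataset, batch_size=4)

# With use_stateful_dataloader=True, the prepared loader is backed by
# torchdata's StatefulDataLoader and exposes state_dict/load_state_dict.
prepared = prepare_data_loader(dataloader, use_stateful_dataloader=True)

it = iter(prepared)
_ = next(it)

state = prepared.state_dict()  # snapshot after consuming one batch
prepared.load_state_dict(state)  # restore to resume from that point later
```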
@@ -1006,6 +1087,7 @@ def prepare_data_loader( | |
_drop_last=dataloader.drop_last, | ||
_non_blocking=non_blocking, | ||
slice_fn=slice_fn_for_dispatch, | ||
use_stateful_dataloader=use_stateful_dataloader, | ||
**kwargs, | ||
) | ||
elif sampler_is_batch_sampler: | ||
|
@@ -1018,6 +1100,7 @@ def prepare_data_loader( | |
_drop_last=dataloader.drop_last, | ||
_non_blocking=non_blocking, | ||
synchronized_generator=synchronized_generator, | ||
use_stateful_dataloader=use_stateful_dataloader, | ||
**kwargs, | ||
) | ||
else: | ||
|
@@ -1029,6 +1112,7 @@ def prepare_data_loader( | |
synchronized_generator=synchronized_generator, | ||
_drop_last=dataloader.drop_last, | ||
_non_blocking=non_blocking, | ||
use_stateful_dataloader=use_stateful_dataloader, | ||
**kwargs, | ||
) | ||
|
||
|
@@ -1046,6 +1130,7 @@ class SkipBatchSampler(BatchSampler): | |
|
||
def __init__(self, batch_sampler, skip_batches=0): | ||
self.batch_sampler = batch_sampler | ||
self.sampler = batch_sampler.sampler | ||
self.skip_batches = skip_batches | ||
|
||
def __iter__(self): | ||
|
@@ -1061,7 +1146,7 @@ def __len__(self): | |
return len(self.batch_sampler) - self.skip_batches | ||
|
||
|
||
class SkipDataLoader(DataLoader): | ||
class SkipDataLoader(DataLoaderAdapter): | ||
""" | ||
Subclass of a PyTorch `DataLoader` that will skip the first batches. | ||
|
||
|
@@ -1070,24 +1155,30 @@ class SkipDataLoader(DataLoader): | |
The dataset to use to build this dataloader. | ||
skip_batches (`int`, *optional*, defaults to 0): | ||
The number of batches to skip at the beginning. | ||
use_stateful_dataloader (`bool`, *optional*, defaults to `False`): | ||
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. | ||
kwargs: | ||
All other keyword arguments to pass to the regular `DataLoader` initialization. | ||
""" | ||
|
||
def __init__(self, dataset, skip_batches=0, **kwargs): | ||
super().__init__(dataset, **kwargs) | ||
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs): | ||
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs) | ||
self.skip_batches = skip_batches | ||
|
||
def __iter__(self): | ||
for index, batch in enumerate(super().__iter__()): | ||
for index, batch in enumerate(self.base_dataloader.__iter__()): | ||
if index >= self.skip_batches: | ||
self._update_state_dict() | ||
yield batch | ||
|
||
|
||
def skip_first_batches(dataloader, num_batches=0): | ||
""" | ||
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. | ||
""" | ||
if is_torchdata_stateful_dataloader_available(): | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
|
||
state = PartialState() | ||
if state.distributed_type == DistributedType.XLA: | ||
device = dataloader.device | ||
|
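For context on the helper being modified here, a minimal usage sketch with a plain single-process `DataLoader` (the dataset and batch size are made up for illustration); with the changes in this PR, a dataloader built with `use_stateful_dataloader=True` would be routed through `StatefulDataLoader` instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import skip_first_batches

dataset = TensorDataset(torch.arange(8).float())
dataloader = DataLoader(dataset, batch_size=2)  # 4 batches in total

# Skip the two batches that were already consumed before resuming.
resumed = skip_first_batches(dataloader, num_batches=2)
for (batch,) in resumed:
    print(batch)  # only the last two batches are yielded
```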
@@ -1131,6 +1222,7 @@ def skip_first_batches(dataloader, num_batches=0): | |
split_batches=dataloader.split_batches, | ||
batch_sampler=new_batch_sampler, | ||
_drop_last=dataloader._drop_last, | ||
use_stateful_dataloader=dataloader.use_stateful_dataloader, | ||
**kwargs, | ||
) | ||
elif isinstance(dataloader, DataLoaderShard): | ||
|
@@ -1147,12 +1239,17 @@ def skip_first_batches(dataloader, num_batches=0): | |
device=dataloader.device, | ||
rng_types=dataloader.rng_types, | ||
synchronized_generator=dataloader.synchronized_generator, | ||
use_stateful_dataloader=dataloader.use_stateful_dataloader, | ||
**kwargs, | ||
) | ||
else: | ||
if new_batch_sampler is None: | ||
# Need to manually skip batches in the dataloader | ||
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) | ||
dataloader = SkipDataLoader( | ||
dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs | ||
) | ||
elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader): | ||
dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) | ||
else: | ||
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) | ||
|
||
|
Let me just bring up (again) that another solution could be monkey-patching `__instancecheck__` on `DataLoader`. Not saying that it's less hacky, just wanted to raise awareness :)
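For reference, a sketch of the mechanism that suggestion relies on, illustrative only: `isinstance` looks up `__instancecheck__` on the metaclass, so hooking the real `DataLoader` this way would mean replacing or wrapping its metaclass, which is arguably no less invasive than the dynamic mix-in used in this PR. The classes below are hypothetical stand-ins.

```python
class _DelegatingInstanceCheck(type):
    # isinstance(obj, SomeClass) calls type(SomeClass).__instancecheck__, so
    # the hook must live on the metaclass rather than on the class itself.
    def __instancecheck__(cls, instance):
        if super().__instancecheck__(instance):
            return True
        wrapped = getattr(instance, "base_dataloader", None)
        return wrapped is not None and super().__instancecheck__(wrapped)


class FakeDataLoader(metaclass=_DelegatingInstanceCheck):
    """Stand-in for DataLoader; the real class's metaclass is not ours to choose."""


class Wrapper:
    def __init__(self, inner):
        self.base_dataloader = inner


print(isinstance(Wrapper(FakeDataLoader()), FakeDataLoader))  # True
print(isinstance(Wrapper(object()), FakeDataLoader))          # False
```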