Commit

Add early support for `torchdata.stateful_dataloader.StatefulDataLoader` within the `Accelerator` (#2895)

* temporary commit

* checkout?

* dataloader wrapper

* tmp

* weird failing test

* trying multiple inheritance

* DataLoaderAdapter

* make style

* Some dark magic dynamic reflection (for backwards compat)

* typo

* some tests

* more mixin stuff

* maybe found broken test?

* this is a very invasive feature

* i think the feature is done?

* add xpu support (#2864)

* better tests

* discovered a bug

* maybe fixed bug?

* make style

* hopefully this is PR ready

* properly skip tests

* parameterize

* Update src/accelerate/utils/dataclasses.py

Co-authored-by: Zach Mueller <[email protected]>

* Update src/accelerate/data_loader.py

Co-authored-by: Zach Mueller <[email protected]>

* merge conflicts

* move imports

* make style

* merges are breaking tests

* fix test name

* Require safetensors>=0.4.3

* undo last commit

* minor style

* address pr comments

* Torchdata version 0.8.0 is stable now

* added docs and require torchdata>=0.8.0 for testing

* test base_dataloader attr doesn't cause infinite recursion

* address pr

* replace super().__iter__ with self.base_dataloader.__iter__

---------

Co-authored-by: Fanli Lin <[email protected]>
Co-authored-by: Zach Mueller <[email protected]>
3 people authored Aug 22, 2024
1 parent 1a6af0b commit ad3f574
Showing 12 changed files with 605 additions and 21 deletions.
6 changes: 6 additions & 0 deletions docs/source/basic_tutorials/migration.md
@@ -219,3 +219,9 @@ During training, you may want to save the current state of the model, optimizer,
To further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.

Any other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function.

<Note>

If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with `load_state_dict` and `state_dict` methods, and makes `Accelerator.save_state` and `Accelerator.load_state` also record how far into the training dataloader the run has progressed when persisting the model.

</Note>
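
A rough sketch of what this note describes in practice, assuming `torchdata>=0.8.0` is installed; the dataset, batch size, and checkpoint directory below are placeholders, not part of this commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Opt in to the stateful dataloader backend.
accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))

dataloader = accelerator.prepare(DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=8))

for step, batch in enumerate(dataloader):
    # ... training step ...
    if step == 3:
        # The checkpoint now also records how far into the dataloader the run is.
        accelerator.save_state("checkpoint_step_3")
        break

# Later: restores the registered state *and* the dataloader position,
# so iteration resumes at batch 4 instead of starting the epoch over.
accelerator.load_state("checkpoint_step_3")
```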
6 changes: 6 additions & 0 deletions docs/source/concept_guides/internal_mechanism.md
@@ -69,4 +69,10 @@ setting the same seed in the main random number generator in all processes.

</Tip>

<Note>

If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`.

</Note>

For more details about the internals, see the [Internals page](package_reference/torch_wrappers).
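
A quick, hypothetical check of the inheritance behavior described in the note above (single process, `torchdata>=0.8.0` installed; dataset and batch size are arbitrary):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
dl = accelerator.prepare(DataLoader(TensorDataset(torch.arange(32).float()), batch_size=8))

# The prepared object still passes isinstance checks against StatefulDataLoader,
# because the wrapper dynamically mixes the wrapped class into its own bases.
assert isinstance(dl, StatefulDataLoader)

state = dl.state_dict()    # snapshot of iteration progress
dl.load_state_dict(state)  # restore it later to resume from the same position
```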
1 change: 1 addition & 0 deletions setup.py
@@ -27,6 +27,7 @@
"datasets",
"diffusers",
"evaluate",
"torchdata>=0.8.0",
"torchpippy>=0.2.0",
"transformers",
"scipy",
7 changes: 7 additions & 0 deletions src/accelerate/accelerator.py
@@ -583,6 +583,12 @@ def use_seedable_sampler(self):
def non_blocking(self):
return self.dataloader_config.non_blocking

@property
def use_stateful_dataloader(self):
if hasattr(self.dataloader_config, "use_stateful_dataloader"):
return self.dataloader_config.use_stateful_dataloader
return False

@property
def project_dir(self):
return self.project_configuration.project_dir
@@ -2068,6 +2074,7 @@ def prepare_data_loader(
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
127 changes: 112 additions & 15 deletions src/accelerate/data_loader.py
@@ -30,6 +30,7 @@
get_data_structure,
initialize_tensors,
is_torch_version,
is_torchdata_stateful_dataloader_available,
send_to_device,
slice_tensors,
synchronize_rng_states,
@@ -388,9 +389,75 @@ def end(self):
self.gradient_state._remove_dataloader(self)


class DataLoaderShard(DataLoader, DataLoaderStateMixin):
class DataLoaderAdapter:
"""
Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in replacement.
"""

def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
self.use_stateful_dataloader = use_stateful_dataloader
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
raise ImportError(
"StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
)
if use_stateful_dataloader:
self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
else:
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)

# Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
# In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
# StatefulDataLoader
#
# The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
# StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
# dispatching scattered throughout various functions and files.
#
# This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
# transparently.
#
# A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
# but this would not be backwards compatible with existing code which assumes
# DataLoaderShard/DataLoaderDispatcher are DataLoaders.
base_cls = self.__class__
base_cls_name = self.__class__.__name__
parent_cls_name = self.base_dataloader.__class__
self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})

if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()

def __getattr__(self, name):
# Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
if name == "base_dataloader":
raise AttributeError()
# Delegate attribute access to the internal dataloader
return getattr(self.base_dataloader, name)

def state_dict(self):
return self.dl_state_dict

def load_state_dict(self, state_dict):
self.base_dataloader.load_state_dict(state_dict)
self.dl_state_dict = self.state_dict

def _update_state_dict(self):
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
# E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
# what it wants to yield.
#
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()
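
The dynamic re-parenting in `__init__` above is easier to see in isolation. A stripped-down, Accelerate-free sketch of the same trick (toy classes, illustrative only):

```python
class Wrapped:
    """Toy stand-in for DataLoader / StatefulDataLoader."""

    def __init__(self, payload):
        self.payload = payload


class Adapter:
    """Wraps a `Wrapped` instance but still passes `isinstance(..., Wrapped)` checks."""

    def __init__(self, payload):
        self.base = Wrapped(payload)
        # Re-parent this instance onto a class derived from both Adapter and the
        # wrapped object's class, mirroring what DataLoaderAdapter does above.
        self.__class__ = type(self.__class__.__name__, (self.__class__, type(self.base)), {})

    def __getattr__(self, name):
        # Guard against infinite recursion before `base` is set, as in __getattr__ above.
        if name == "base":
            raise AttributeError(name)
        return getattr(self.base, name)


adapter = Adapter(payload=42)
assert isinstance(adapter, Wrapped)  # True, despite Adapter not being declared a subclass
assert adapter.payload == 42         # unknown attributes fall through to the wrapped object
```

This is the same trade-off the comment above spells out: composition would be cleaner, but re-parenting keeps backwards compatibility with code that expects a real `DataLoader` subclass.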


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
Args:
dataset (`torch.utils.data.dataset.Dataset`):
@@ -409,6 +476,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
A random number generator to keep synchronized across processes.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
**kwargs (additional keyword arguments, *optional*):
All other keyword arguments to pass to the regular `DataLoader` initialization.
@@ -428,11 +497,12 @@ def __init__(
rng_types=None,
synchronized_generator=None,
skip_batches=0,
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
**kwargs,
):
super().__init__(dataset, **kwargs)
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.device = device
self.rng_types = rng_types
self.synchronized_generator = synchronized_generator
@@ -448,7 +518,7 @@ def __iter__(self):
self.begin()

self.set_epoch(self.iteration)
dataloader_iter = super().__iter__()
dataloader_iter = self.base_dataloader.__iter__()
# We iterate one batch ahead to check when we are at the end
try:
current_batch = next(dataloader_iter)
@@ -461,6 +531,7 @@ def __iter__(self):
# But we still move it to the device so it is done before `StopIteration` is reached
if self.device is not None:
current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
self._update_state_dict()
next_batch = next(dataloader_iter)
if batch_index >= self.skip_batches:
yield current_batch
@@ -564,10 +635,10 @@ def dataloader(self):
return self._loader


class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
process their part of the batch.
Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch to each process
its part of the batch.
Args:
split_batches (`bool`, *optional*, defaults to `False`):
@@ -579,6 +650,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
size of the `dataloader` is a round multiple of `batch_size`.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning of an iteration.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
**Available attributes:**
@@ -594,6 +667,7 @@ def __init__(
dataset,
split_batches: bool = False,
skip_batches=0,
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
@@ -606,7 +680,7 @@ def __init__(
# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
super().__init__(dataset, **kwargs)
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.split_batches = split_batches
if shuffle:
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
@@ -627,12 +701,14 @@ def _fetch_batches(self, iterator):
try:
if self.split_batches:
# One batch of the main iterator is dispatched and split.
self._update_state_dict()
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
self._update_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
@@ -673,9 +749,9 @@ def __iter__(self):
# NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
# shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
# But, we only iterate through the DataLoader on process 0.
main_iterator = super().__iter__()
main_iterator = self.base_dataloader.__iter__()
elif self.state.process_index == 0:
main_iterator = super().__iter__()
main_iterator = self.base_dataloader.__iter__()
stop_iteration = False
self._stop_iteration = False
first_batch = None
@@ -812,6 +888,7 @@ def prepare_data_loader(
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -873,6 +950,10 @@ def prepare_data_loader(
non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
"If set to true, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
Returns:
@@ -1006,6 +1087,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
elif sampler_is_batch_sampler:
@@ -1018,6 +1100,7 @@
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
synchronized_generator=synchronized_generator,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
else:
@@ -1029,6 +1112,7 @@
synchronized_generator=synchronized_generator,
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)

@@ -1046,6 +1130,7 @@ class SkipBatchSampler(BatchSampler):

def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
self.sampler = batch_sampler.sampler
self.skip_batches = skip_batches

def __iter__(self):
@@ -1061,7 +1146,7 @@ def __len__(self):
return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoader):
class SkipDataLoader(DataLoaderAdapter):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.
@@ -1070,24 +1155,30 @@ class SkipDataLoader(DataLoader):
The dataset to use to build this dataloader.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
kwargs:
All other keyword arguments to pass to the regular `DataLoader` initialization.
"""

def __init__(self, dataset, skip_batches=0, **kwargs):
super().__init__(dataset, **kwargs)
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.skip_batches = skip_batches

def __iter__(self):
for index, batch in enumerate(super().__iter__()):
for index, batch in enumerate(self.base_dataloader.__iter__()):
if index >= self.skip_batches:
self._update_state_dict()
yield batch


def skip_first_batches(dataloader, num_batches=0):
"""
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
"""
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

state = PartialState()
if state.distributed_type == DistributedType.XLA:
device = dataloader.device
@@ -1131,6 +1222,7 @@ def skip_first_batches(dataloader, num_batches=0):
split_batches=dataloader.split_batches,
batch_sampler=new_batch_sampler,
_drop_last=dataloader._drop_last,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
elif isinstance(dataloader, DataLoaderShard):
@@ -1147,12 +1239,17 @@
device=dataloader.device,
rng_types=dataloader.rng_types,
synchronized_generator=dataloader.synchronized_generator,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
else:
if new_batch_sampler is None:
# Need to manually skip batches in the dataloader
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
dataloader = SkipDataLoader(
dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs
)
elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader):
dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
else:
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

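Pulling the `data_loader.py` changes together, a minimal sketch of the lower-level round trip the adapter enables (single process, `torchdata>=0.8.0` installed; the dataset and batch size are arbitrary placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
dl = accelerator.prepare(DataLoader(TensorDataset(torch.arange(24).float()), batch_size=4))

snapshot = None
for step, batch in enumerate(dl):
    if step == 2:
        # Because _update_state_dict snapshots the state *before* the shard fetches one
        # batch ahead, this reflects exactly the three batches yielded so far.
        snapshot = dl.state_dict()
        break

dl.load_state_dict(snapshot)
resumed = list(dl)
print(f"resumed {len(resumed)} of {len(dl)} batches")  # expected: 3 of 6
```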
2 changes: 1 addition & 1 deletion src/accelerate/test_utils/scripts/test_sync.py
@@ -305,12 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler(

def test_dataloader_break():
accelerator = Accelerator()

first_dset = RegressionDataset(length=80)
first_dataloader = DataLoader(first_dset, batch_size=16)
second_dset = RegressionDataset(length=96)
second_dataloader = DataLoader(second_dset, batch_size=16)
first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)

assert accelerator.gradient_state.active_dataloader is None
for iteration, _ in enumerate(first_dataloader):
assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)