
Add early support for torchdata.stateful_dataloader.StatefulDataLoader within the Accelerator #2895

Merged: 74 commits, merged Aug 22, 2024
Commits (74)
79a8fa2
temporary commit
byi8220 Jun 22, 2024
efa1e7d
checkout?
byi8220 Jun 22, 2024
8dc107d
dataloader wrapper
byi8220 Jun 22, 2024
f342f4c
tmp
byi8220 Jun 22, 2024
065849a
weird failing test
byi8220 Jun 22, 2024
1e3fad1
trying multiple inheritance
byi8220 Jun 22, 2024
a41cf38
DataLoaderAdapter
byi8220 Jun 23, 2024
8831488
make style
byi8220 Jun 23, 2024
140f1e6
Some dark magic dynamic reflection (for backwards compat)
byi8220 Jun 23, 2024
727afeb
typo
byi8220 Jun 23, 2024
73683b4
some tests
byi8220 Jun 25, 2024
32c318e
more mixin stuff
byi8220 Jun 25, 2024
57c6f57
maybe found broken test?
byi8220 Jun 25, 2024
ed612d1
this is a very invasive feature
byi8220 Jun 25, 2024
511050e
i think the feature is done?
byi8220 Jun 26, 2024
8dbc1a3
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jun 26, 2024
df43960
add xpu support (#2864)
faaany Jun 26, 2024
c778e32
Merge branch 'stateful-dataloader' of https://github.com/byi8220/acce…
byi8220 Jun 26, 2024
4e00055
better tests
byi8220 Jun 26, 2024
0471fe3
discovered a bug
byi8220 Jun 26, 2024
3036b7f
maybe fixed bug?
byi8220 Jun 26, 2024
9ade2e9
make style
byi8220 Jun 26, 2024
ba0f5c6
hopefully this is PR ready
byi8220 Jun 26, 2024
b774291
properly skip tests
byi8220 Jun 26, 2024
fde597d
parameterize
byi8220 Jun 26, 2024
f273abc
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 3, 2024
8a46eb6
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 5, 2024
e4e1cac
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 13, 2024
8bf2fe2
temporary commit
byi8220 Jun 22, 2024
ca4338d
checkout?
byi8220 Jun 22, 2024
c38f317
dataloader wrapper
byi8220 Jun 22, 2024
17a2a19
tmp
byi8220 Jun 22, 2024
b39a606
weird failing test
byi8220 Jun 22, 2024
d1e82e0
trying multiple inheritance
byi8220 Jun 22, 2024
d99d734
DataLoaderAdapter
byi8220 Jun 23, 2024
39b2866
make style
byi8220 Jun 23, 2024
f58f609
Some dark magic dynamic reflection (for backwards compat)
byi8220 Jun 23, 2024
f2119cf
typo
byi8220 Jun 23, 2024
7adec94
some tests
byi8220 Jun 25, 2024
8850af3
more mixin stuff
byi8220 Jun 25, 2024
6ff0f68
maybe found broken test?
byi8220 Jun 25, 2024
4f28d2e
this is a very invasive feature
byi8220 Jun 25, 2024
a9b637d
i think the feature is done?
byi8220 Jun 26, 2024
0384543
better tests
byi8220 Jun 26, 2024
0e0515d
discovered a bug
byi8220 Jun 26, 2024
809aca0
maybe fixed bug?
byi8220 Jun 26, 2024
5145c2d
make style
byi8220 Jun 26, 2024
ca74ff2
hopefully this is PR ready
byi8220 Jun 26, 2024
a8f8bf3
properly skip tests
byi8220 Jun 26, 2024
59738f4
parameterize
byi8220 Jun 26, 2024
d264939
Merge branch 'stateful-dataloader' of https://github.com/byi8220/acce…
byi8220 Jul 15, 2024
8f04c1e
Update src/accelerate/utils/dataclasses.py
byi8220 Jul 15, 2024
45db4b9
Update src/accelerate/data_loader.py
byi8220 Jul 15, 2024
0ffc64b
merge conflicts
byi8220 Jul 15, 2024
03a7774
Merge branch 'stateful-dataloader' of https://github.com/byi8220/acce…
byi8220 Jul 15, 2024
8d2c6c3
move imports
byi8220 Jul 15, 2024
6bfe871
make style
byi8220 Jul 15, 2024
7a344e4
merge conflicts?
byi8220 Jul 17, 2024
6ff997e
merges are breaking tests
byi8220 Jul 17, 2024
4739524
fix test name
byi8220 Jul 17, 2024
4de9159
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 22, 2024
abf815a
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 24, 2024
06597d4
Require safetensors>=0.4.3
byi8220 Jul 24, 2024
4142c7f
undo last commit
byi8220 Jul 24, 2024
f02f18c
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 29, 2024
35977ca
minor style
byi8220 Jul 29, 2024
597e910
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Aug 5, 2024
419f607
Merge branch 'main' into stateful-dataloader
byi8220 Aug 20, 2024
4188d4c
address pr comments
byi8220 Aug 20, 2024
51377a4
Torchdata version 0.8.0 is stable now
byi8220 Aug 20, 2024
f4b6bb5
added docs and require torchdata>=0.8.0 for testing
byi8220 Aug 20, 2024
d02dfcc
test base_dataloader attr doesn't cause infinite recursion
byi8220 Aug 21, 2024
21bc420
address pr
byi8220 Aug 21, 2024
74e2f53
replace super().__iter__ with self.base_dataloader.__iter__
byi8220 Aug 21, 2024
6 changes: 6 additions & 0 deletions docs/source/basic_tutorials/migration.md
@@ -219,3 +219,9 @@ During training, you may want to save the current state of the model, optimizer,
To further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.

Any other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function.

<Note>

If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with `load_state_dict` and `state_dict` methods, and makes it so that `Accelerator.save_state` and `Accelerator.load_state` also track how far into the training dataset the dataloader has read when persisting the model.

</Note>
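As a rough usage sketch of the option described in this note (the toy model, dataset, and the `ckpt` directory name are placeholders; this assumes `torchdata>=0.8.0` is installed and a single-process setup):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Enable the stateful dataloader backend (requires torchdata>=0.8.0).
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(32, 1), torch.randn(32, 1))
dataloader = DataLoader(dataset, batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for step, (x, y) in enumerate(dataloader):
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if step == 3:
        # The checkpoint now also records how far into the dataloader we are.
        accelerator.save_state("ckpt")
        break

# Restoring later resumes model, optimizer, and dataloader position together.
accelerator.load_state("ckpt")
```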
6 changes: 6 additions & 0 deletions docs/source/concept_guides/internal_mechanism.md
@@ -69,4 +69,10 @@ setting the same seed in the main random number generator in all processes.

</Tip>

<Note>

If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`.

</Note>

For more details about the internals, see the [Internals page](package_reference/torch_wrappers).
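A small sketch of what that inheritance means in practice (assuming `torchdata>=0.8.0`, `use_stateful_dataloader=True`, and a single process; the dataset is a placeholder):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)
dl = accelerator.prepare(DataLoader(TensorDataset(torch.ones(16)), batch_size=4))

# The prepared object is still an Accelerate dataloader, but because the wrapped
# class is mixed in dynamically it also passes isinstance checks for
# StatefulDataLoader and exposes state_dict()/load_state_dict().
assert isinstance(dl, StatefulDataLoader)
print(dl.state_dict())
```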
1 change: 1 addition & 0 deletions setup.py
@@ -27,6 +27,7 @@
"datasets",
"diffusers",
"evaluate",
"torchdata>=0.8.0",
"torchpippy>=0.2.0",
"transformers",
"scipy",
7 changes: 7 additions & 0 deletions src/accelerate/accelerator.py
@@ -583,6 +583,12 @@ def use_seedable_sampler(self):
def non_blocking(self):
return self.dataloader_config.non_blocking

@property
def use_stateful_dataloader(self):
if hasattr(self.dataloader_config, "use_stateful_dataloader"):
return self.dataloader_config.use_stateful_dataloader
return False

@property
def project_dir(self):
return self.project_configuration.project_dir
@@ -2068,6 +2074,7 @@ def prepare_data_loader(
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
127 changes: 112 additions & 15 deletions src/accelerate/data_loader.py
@@ -30,6 +30,7 @@
get_data_structure,
initialize_tensors,
is_torch_version,
is_torchdata_stateful_dataloader_available,
send_to_device,
slice_tensors,
synchronize_rng_states,
@@ -388,9 +389,75 @@ def end(self):
self.gradient_state._remove_dataloader(self)


class DataLoaderShard(DataLoader, DataLoaderStateMixin):
class DataLoaderAdapter:
"""
Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
"""

def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
self.use_stateful_dataloader = use_stateful_dataloader
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
raise ImportError(
"StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
)
if use_stateful_dataloader:
self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
else:
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)

# Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
# In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
# StatefulDataLoader
#
# The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
# StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
# dispatching scattered throughout various functions and files.
#
# This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
# transparently.
#
# A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
# but this would not be backwards compatible with existing code which assumes
# DataLoaderShard/DataLoaderDispatcher are DataLoaders.
base_cls = self.__class__
base_cls_name = self.__class__.__name__
parent_cls_name = self.base_dataloader.__class__
self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
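For readers unfamiliar with the trick, a standalone toy sketch of the same dynamic-mixin pattern (made-up `Base`/`SpecialBase`/`Adapter` classes, not Accelerate code):

```python
# Rebuilding the instance's class at __init__ time makes isinstance() checks
# against the wrapped class succeed on the wrapper.
class Base:
    pass


class SpecialBase(Base):
    pass


class Adapter:
    def __init__(self, use_special=False):
        self.inner = SpecialBase() if use_special else Base()
        # Create a throwaway subclass `Adapter<T>` where T is the wrapped class,
        # mirroring `type(base_cls_name, (base_cls, parent_cls_name), {})` above.
        self.__class__ = type(
            self.__class__.__name__, (self.__class__, type(self.inner)), {}
        )


assert isinstance(Adapter(use_special=True), SpecialBase)
assert not isinstance(Adapter(use_special=False), SpecialBase)
print("dynamic mixin makes isinstance checks transparent")
```

The PR applies the same idea with `DataLoader` and `StatefulDataLoader` as the wrapped classes.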
Member: Let me just bring up (again) that another solution could be monkey-patching __instancecheck__ on DataLoader. Not saying that it's less hacky, just wanted to raise awareness :)


if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()

def __getattr__(self, name):
# Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
if name == "base_dataloader":
raise AttributeError()
# Delegate attribute access to the internal dataloader
return getattr(self.base_dataloader, name)
Member: A bit of an edge case: let's also check if the name is not "base_dataloader", and if it is, raise an AttributeError, to avoid an infinite recursion.

Contributor Author: Could you give a code example of how infinite recursion would happen here?

If I'm reading the python3 docs for __getattr__() correctly, it states "Note that if the attribute is found through the normal mechanism, __getattr__() is not called." IIUC, base_dataloader should always be retrievable through the normal mechanism.

If I add the following block into test_dataloader_inheritance() in test_data_loader.py (without making any changes), the tests pass without causing an infinite recursion:

        assert isinstance(skip_dl.base_dataloader, DataLoader)
        assert isinstance(dl_shard.base_dataloader, DataLoader)
        assert isinstance(dl_dispatcher.base_dataloader, DataLoader)

Member: > Could you give a code example of how infinite recursion would happen here?

Yes, that would be for the edge case of an attribute getting called on the class, i.e. before it is instantiated. In that case, the base_dataloader attribute does not exist. Now you could say "who would do such a pernicious thing?", but it's a bug that actually happened in another project, and for some reason DeepSpeed would do this (on a module, not a data loader, but let's rather be safe than sorry).

Contributor Author: I am still not entirely sure how this could happen, but I added a check in __getattr__.
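A minimal toy reproduction of the failure mode being discussed (not Accelerate code; `Delegator` and `inner` are made-up names): if an attribute is looked up on an instance whose `__init__` never ran, for example during unpickling, the delegate attribute itself is missing, and without the guard `__getattr__` would call itself forever.

```python
class Delegator:
    def __init__(self):
        self.inner = object()

    def __getattr__(self, name):
        # Guard: without this check, looking up `inner` on a half-initialized
        # instance would re-enter __getattr__ and recurse until RecursionError.
        if name == "inner":
            raise AttributeError(name)
        return getattr(self.inner, name)


obj = Delegator.__new__(Delegator)  # __init__ never runs, so `inner` is missing
try:
    obj.some_attribute
except AttributeError:
    print("clean AttributeError instead of infinite recursion")
```

With the guard in place, the lookup fails fast with a plain `AttributeError`, which is what pickling and framework introspection code expect.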


def state_dict(self):
return self.dl_state_dict

def load_state_dict(self, state_dict):
self.base_dataloader.load_state_dict(state_dict)
self.dl_state_dict = state_dict

def _update_state_dict(self):
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
# E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
# what it wants to yield.
#
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
if hasattr(self.base_dataloader, "state_dict"):
Member: Let's add a comment here on when this needs to be called, with context on why it's required.

Contributor Author: Added a comment here, kinda clunky though.

self.dl_state_dict = self.base_dataloader.state_dict()
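To make the timing concrete, a toy model of the one-batch look-ahead used by `DataLoaderShard.__iter__` (plain Python, not Accelerate code; the integer `consumed` stands in for the base loader's `state_dict`):

```python
def iterate_with_lookahead(batches):
    """Yield each batch while the inner iterator stays one element ahead,
    recording the snapshot that _update_state_dict() would capture."""
    it = iter(batches)
    consumed = 0                      # stands in for the base loader's state_dict
    current = next(it)
    consumed += 1
    while True:
        # Snapshot taken here, before the look-ahead next(): it counts `current`
        # (about to be yielded) but not the batch fetched ahead of time.
        snapshot = consumed
        try:
            nxt = next(it)            # base loader moves one batch ahead
            consumed += 1
        except StopIteration:
            yield current, snapshot
            return
        yield current, snapshot
        current = nxt


for batch, state in iterate_with_lookahead(["b0", "b1", "b2"]):
    print(batch, state)               # b0 1, b1 2, b2 3
```

Snapshotting after the look-ahead `next()` would record one batch too many, so restoring from it would silently skip a batch.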


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.

Args:
dataset (`torch.utils.data.dataset.Dataset`):
@@ -409,6 +476,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
A random number generator to keep synchronized across processes.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
**kwargs (additional keyword arguments, *optional*):
All other keyword arguments to pass to the regular `DataLoader` initialization.

@@ -428,11 +497,12 @@ def __init__(
rng_types=None,
synchronized_generator=None,
skip_batches=0,
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
**kwargs,
):
super().__init__(dataset, **kwargs)
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.device = device
self.rng_types = rng_types
self.synchronized_generator = synchronized_generator
@@ -448,7 +518,7 @@ def __iter__(self):
self.begin()

self.set_epoch(self.iteration)
dataloader_iter = super().__iter__()
dataloader_iter = self.base_dataloader.__iter__()
# We iterate one batch ahead to check when we are at the end
try:
current_batch = next(dataloader_iter)
@@ -461,6 +531,7 @@ def __iter__(self):
# But we still move it to the device so it is done before `StopIteration` is reached
if self.device is not None:
current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
self._update_state_dict()
next_batch = next(dataloader_iter)
if batch_index >= self.skip_batches:
yield current_batch
@@ -564,10 +635,10 @@ def dataloader(self):
return self._loader


class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
process their part of the batch.
Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch to each
process its part of the batch.

Args:
split_batches (`bool`, *optional*, defaults to `False`):
@@ -579,6 +650,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
size of the `dataloader` is a round multiple of `batch_size`.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning of an iteration.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.

**Available attributes:**

@@ -594,6 +667,7 @@ def __init__(
dataset,
split_batches: bool = False,
skip_batches=0,
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
@@ -606,7 +680,7 @@ def __init__(
# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
super().__init__(dataset, **kwargs)
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.split_batches = split_batches
if shuffle:
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
@@ -627,12 +701,14 @@ def _fetch_batches(self, iterator):
try:
if self.split_batches:
# One batch of the main iterator is dispatched and split.
self._update_state_dict()
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
self._update_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
@@ -673,9 +749,9 @@ def __iter__(self):
# NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
# shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
# But, we only iterate through the DataLoader on process 0.
main_iterator = super().__iter__()
main_iterator = self.base_dataloader.__iter__()
elif self.state.process_index == 0:
main_iterator = super().__iter__()
main_iterator = self.base_dataloader.__iter__()
stop_iteration = False
self._stop_iteration = False
first_batch = None
@@ -812,6 +888,7 @@ def prepare_data_loader(
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -873,6 +950,10 @@
non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
"If set to true, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."


Returns:
Expand Down Expand Up @@ -1006,6 +1087,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
elif sampler_is_batch_sampler:
Expand All @@ -1018,6 +1100,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
synchronized_generator=synchronized_generator,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
else:
@@ -1029,6 +1112,7 @@ def prepare_data_loader(
synchronized_generator=synchronized_generator,
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)

@@ -1046,6 +1130,7 @@ class SkipBatchSampler(BatchSampler):

def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
self.sampler = batch_sampler.sampler
self.skip_batches = skip_batches

def __iter__(self):
@@ -1061,7 +1146,7 @@ def __len__(self):
return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoader):
class SkipDataLoader(DataLoaderAdapter):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.

@@ -1070,24 +1155,30 @@ class SkipDataLoader(DataLoader):
The dataset to use to build this dataloader.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
kwargs:
All other keyword arguments to pass to the regular `DataLoader` initialization.
"""

def __init__(self, dataset, skip_batches=0, **kwargs):
super().__init__(dataset, **kwargs)
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.skip_batches = skip_batches

def __iter__(self):
for index, batch in enumerate(super().__iter__()):
for index, batch in enumerate(self.base_dataloader.__iter__()):
if index >= self.skip_batches:
self._update_state_dict()
yield batch


def skip_first_batches(dataloader, num_batches=0):
"""
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
"""
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

state = PartialState()
if state.distributed_type == DistributedType.XLA:
device = dataloader.device
@@ -1131,6 +1222,7 @@ def skip_first_batches(dataloader, num_batches=0):
split_batches=dataloader.split_batches,
batch_sampler=new_batch_sampler,
_drop_last=dataloader._drop_last,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
elif isinstance(dataloader, DataLoaderShard):
@@ -1147,12 +1239,17 @@ def skip_first_batches(dataloader, num_batches=0):
device=dataloader.device,
rng_types=dataloader.rng_types,
synchronized_generator=dataloader.synchronized_generator,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
else:
if new_batch_sampler is None:
# Need to manually skip batches in the dataloader
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
dataloader = SkipDataLoader(
dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs
)
elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader):
dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
else:
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

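For completeness, a hedged sketch contrasting the two resumption paths that coexist after this change (single process; assumes `torchdata>=0.8.0` and that the prepared loader was built with `use_stateful_dataloader=True`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator, skip_first_batches
from accelerate.utils import DataLoaderConfiguration

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)
dataloader = accelerator.prepare(
    DataLoader(TensorDataset(torch.arange(16.0)), batch_size=4)
)

# Replay-based resumption: build a loader that silently skips the first N batches.
resumed = skip_first_batches(dataloader, num_batches=2)

# State-based resumption: capture the exact position and restore it later,
# without replaying the skipped batches.
it = iter(dataloader)
next(it)
next(it)
state = dataloader.state_dict()
dataloader.load_state_dict(state)
```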
2 changes: 1 addition & 1 deletion src/accelerate/test_utils/scripts/test_sync.py
@@ -305,12 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler(

def test_dataloader_break():
accelerator = Accelerator()

first_dset = RegressionDataset(length=80)
first_dataloader = DataLoader(first_dset, batch_size=16)
second_dset = RegressionDataset(length=96)
second_dataloader = DataLoader(second_dset, batch_size=16)
first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)

assert accelerator.gradient_state.active_dataloader is None
for iteration, _ in enumerate(first_dataloader):
assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)