
Add early support for torchdata.stateful_dataloader.StatefulDataLoader within the Accelerator #2895

Merged

74 commits merged on Aug 22, 2024

Changes from 67 commits

Commits (74)
79a8fa2
temporary commit
byi8220 Jun 22, 2024
efa1e7d
checkout?
byi8220 Jun 22, 2024
8dc107d
dataloader wrapper
byi8220 Jun 22, 2024
f342f4c
tmp
byi8220 Jun 22, 2024
065849a
weird failing test
byi8220 Jun 22, 2024
1e3fad1
trying multiple inheritance
byi8220 Jun 22, 2024
a41cf38
DataLoaderAdapter
byi8220 Jun 23, 2024
8831488
make style
byi8220 Jun 23, 2024
140f1e6
Some dark magic dynamic reflection (for backwards compat)
byi8220 Jun 23, 2024
727afeb
typo
byi8220 Jun 23, 2024
73683b4
some tests
byi8220 Jun 25, 2024
32c318e
more mixin stuff
byi8220 Jun 25, 2024
57c6f57
maybe found broken test?
byi8220 Jun 25, 2024
ed612d1
this is a very invasive feature
byi8220 Jun 25, 2024
511050e
i think the feature is done?
byi8220 Jun 26, 2024
8dbc1a3
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jun 26, 2024
df43960
add xpu support (#2864)
faaany Jun 26, 2024
c778e32
Merge branch 'stateful-dataloader' of https://github.com/byi8220/acce…
byi8220 Jun 26, 2024
4e00055
better tests
byi8220 Jun 26, 2024
0471fe3
discovered a bug
byi8220 Jun 26, 2024
3036b7f
maybe fixed bug?
byi8220 Jun 26, 2024
9ade2e9
make style
byi8220 Jun 26, 2024
ba0f5c6
hopefully this is PR ready
byi8220 Jun 26, 2024
b774291
properly skip tests
byi8220 Jun 26, 2024
fde597d
parameterize
byi8220 Jun 26, 2024
f273abc
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 3, 2024
8a46eb6
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 5, 2024
e4e1cac
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 13, 2024
8bf2fe2
temporary commit
byi8220 Jun 22, 2024
ca4338d
checkout?
byi8220 Jun 22, 2024
c38f317
dataloader wrapper
byi8220 Jun 22, 2024
17a2a19
tmp
byi8220 Jun 22, 2024
b39a606
weird failing test
byi8220 Jun 22, 2024
d1e82e0
trying multiple inheritance
byi8220 Jun 22, 2024
d99d734
DataLoaderAdapter
byi8220 Jun 23, 2024
39b2866
make style
byi8220 Jun 23, 2024
f58f609
Some dark magic dynamic reflection (for backwards compat)
byi8220 Jun 23, 2024
f2119cf
typo
byi8220 Jun 23, 2024
7adec94
some tests
byi8220 Jun 25, 2024
8850af3
more mixin stuff
byi8220 Jun 25, 2024
6ff0f68
maybe found broken test?
byi8220 Jun 25, 2024
4f28d2e
this is a very invasive feature
byi8220 Jun 25, 2024
a9b637d
i think the feature is done?
byi8220 Jun 26, 2024
0384543
better tests
byi8220 Jun 26, 2024
0e0515d
discovered a bug
byi8220 Jun 26, 2024
809aca0
maybe fixed bug?
byi8220 Jun 26, 2024
5145c2d
make style
byi8220 Jun 26, 2024
ca74ff2
hopefully this is PR ready
byi8220 Jun 26, 2024
a8f8bf3
properly skip tests
byi8220 Jun 26, 2024
59738f4
parameterize
byi8220 Jun 26, 2024
d264939
Merge branch 'stateful-dataloader' of https://github.com/byi8220/acce…
byi8220 Jul 15, 2024
8f04c1e
Update src/accelerate/utils/dataclasses.py
byi8220 Jul 15, 2024
45db4b9
Update src/accelerate/data_loader.py
byi8220 Jul 15, 2024
0ffc64b
merge conflicts
byi8220 Jul 15, 2024
03a7774
Merge branch 'stateful-dataloader' of https://github.com/byi8220/acce…
byi8220 Jul 15, 2024
8d2c6c3
move imports
byi8220 Jul 15, 2024
6bfe871
make style
byi8220 Jul 15, 2024
7a344e4
merge conflicts?
byi8220 Jul 17, 2024
6ff997e
merges are breaking tests
byi8220 Jul 17, 2024
4739524
fix test name
byi8220 Jul 17, 2024
4de9159
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 22, 2024
abf815a
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 24, 2024
06597d4
Require safetensors>=0.4.3
byi8220 Jul 24, 2024
4142c7f
undo last commit
byi8220 Jul 24, 2024
f02f18c
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Jul 29, 2024
35977ca
minor style
byi8220 Jul 29, 2024
597e910
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 Aug 5, 2024
419f607
Merge branch 'main' into stateful-dataloader
byi8220 Aug 20, 2024
4188d4c
address pr comments
byi8220 Aug 20, 2024
51377a4
Torchdata version 0.8.0 is stable now
byi8220 Aug 20, 2024
f4b6bb5
added docs and require torchdata>=0.8.0 for testing
byi8220 Aug 20, 2024
d02dfcc
test base_dataloader attr doesn't cause infinite recursion
byi8220 Aug 21, 2024
21bc420
address pr
byi8220 Aug 21, 2024
74e2f53
replace super().__iter__ with self.base_dataloader.__iter__
byi8220 Aug 21, 2024
9 changes: 7 additions & 2 deletions src/accelerate/accelerator.py
@@ -577,6 +577,10 @@ def use_seedable_sampler(self):
def non_blocking(self):
return self.dataloader_config.non_blocking

@property
def use_stateful_dataloader(self):
return self.dataloader_config.use_stateful_dataloader

@property
def project_dir(self):
return self.project_configuration.project_dir
@@ -1593,9 +1597,9 @@ def _prepare_deepspeed(self, *args):

deepspeed_plugin = self.state.deepspeed_plugin

is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
is_dataloader_present = any((isinstance(obj, torch.utils.data.DataLoader)) for obj in args)
result = [
self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
self._prepare_one(obj, first_pass=True) if (isinstance(obj, torch.utils.data.DataLoader)) else obj
for obj in args
]

@@ -2038,6 +2042,7 @@ def prepare_data_loader(
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
110 changes: 98 additions & 12 deletions src/accelerate/data_loader.py
@@ -30,6 +30,7 @@
get_data_structure,
initialize_tensors,
is_torch_version,
is_torchdata_stateful_dataloader_available,
send_to_device,
slice_tensors,
synchronize_rng_states,
@@ -388,9 +389,65 @@ def end(self):
self.gradient_state._remove_dataloader(self)


class DataLoaderShard(DataLoader, DataLoaderStateMixin):
class DataLoaderAdapter:
"""
Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
"""

def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
self.use_stateful_dataloader = use_stateful_dataloader
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
raise ImportError("StatefulDataLoader is not available. Please install torchdata to use it.")
if use_stateful_dataloader:
self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
else:
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)

# Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
# In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
# StatefulDataLoader
#
# The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
# StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
# dispatching scattered throughout various functions and files.
#
# This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
# transparently.
#
# A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
# but this would not be backwards compatible with existing code which assumes
# DataLoaderShard/DataLoaderDispatcher are DataLoaders.
base_cls = self.__class__
base_cls_name = self.__class__.__name__
parent_cls_name = self.base_dataloader.__class__
self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
Review comment (Member):

Let me just bring up (again) that another solution could be monkey-patching `__instancecheck__` on `DataLoader`. Not saying that it's less hacky, just wanted to raise awareness :)

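For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the dynamic mixin discussed above. The `Adapter` and `FancyLoader` names are illustrative stand-ins rather than classes from this PR; the point is only that reassigning `self.__class__` to a freshly built type makes `isinstance` checks against the wrapped loader class pass, which is what keeps the existing `isinstance(obj, torch.utils.data.DataLoader)` checks in the Accelerator working unchanged.

```python
# Minimal sketch of the dynamic-mixin trick; illustrative names, not accelerate code.
from torch.utils.data import DataLoader


class FancyLoader(DataLoader):
    """Stand-in for torchdata's StatefulDataLoader."""


class Adapter:
    def __init__(self, dataset, use_fancy=False, **kwargs):
        loader_cls = FancyLoader if use_fancy else DataLoader
        self.base_dataloader = loader_cls(dataset, **kwargs)
        # Rebuild this instance's class as type("Adapter", (Adapter, loader_cls), {}),
        # i.e. `Adapter<T> : T` in C++ terms, so isinstance(self, loader_cls) holds.
        base_cls = self.__class__
        self.__class__ = type(base_cls.__name__, (base_cls, loader_cls), {})


adapter = Adapter(list(range(8)), use_fancy=True, batch_size=2)
assert isinstance(adapter, FancyLoader) and isinstance(adapter, DataLoader)
```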
# Allow this class to transparently pass through attributes from the underlying class
if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()

for attr in self.base_dataloader.__dict__.keys():
setattr(self, attr, getattr(self.base_dataloader, attr))
Review comment (Member):

Kinda looks dangerous. For example, this skips `@property`, is that intended? We could instead use `__getattr__` to dispatch to `self.base_dataloader`.

If we want to stick with this, more succinct code could be: `self.__dict__.update(self.base_dataloader.__dict__)` or `vars(self).update(self.base_dataloader.__dict__)`

Reply (byi8220, Contributor Author):

> Kinda looks dangerous.

Kinda agree with you, but all dynamic reflection looks dangerous to me.

I did write up an alternative which avoids the wizardry and just duplicates all the code required over here in: byi8220/accelerate@stateful-dataloader...byi8220:accelerate:stateful-dataloader-2

That code is messier and involves way more duplication, but it is much more explicit about what it does. If enough people feel the reflection approach is way too hacky and this feature doesn't justify it, I'm fine with doing that instead.

> We could instead use `__getattr__` to dispatch to `self.base_dataloader`.

I updated the PR to do that instead.

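For context, a rough sketch of the `__getattr__` delegation mentioned above (an illustration of the idea, not the exact code that was merged): lookups that miss on the wrapper fall through to the wrapped dataloader, which also covers `@property` attributes, and a small guard keeps a missing `base_dataloader` from causing infinite recursion (compare the later commit "test base_dataloader attr doesn't cause infinite recursion").

```python
# Illustration of __getattr__ delegation to a wrapped dataloader; not the merged code.
class DelegatingAdapter:
    def __init__(self, base_dataloader):
        self.base_dataloader = base_dataloader

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails on the wrapper itself,
        # so attributes set directly on DelegatingAdapter still take precedence.
        if name == "base_dataloader":
            # Avoid infinite recursion if this is accessed before __init__ has run
            # (e.g. during unpickling or copy).
            raise AttributeError(name)
        return getattr(self.base_dataloader, name)
```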

def state_dict(self):
return self.dl_state_dict

def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
self.dl_state_dict = self.state_dict

def _save_state_dict(self):
Review comment (Member):

IMO, the name is not quite fitting, isn't it more like `update_state_dict` or so? Also, maybe we can avoid this all by not having a static `self.dl_state_dict` attribute but instead the `state_dict` method just returns `self.base_dataloader.state_dict()`.

Reply (byi8220, Contributor Author, Aug 20, 2024):

> IMO, the name is not quite fitting, isn't it more like `update_state_dict` or so?

Changed to `_update_state_dict`

> Also, maybe we can avoid this all by not having a static `self.dl_state_dict` attribute but instead the `state_dict` method just returns `self.base_dataloader.state_dict()`.

I'm not sure if we can. The base dataloader's state dict is one ahead of what we're yielding, so we couldn't do a passthrough. Some additional context in the comments of a6e192c#r1704736815

if hasattr(self.base_dataloader, "state_dict"):
Review comment (Member):

Let's add a comment here when this needs to be called and with the context on why it's required.

Reply (byi8220, Contributor Author):

Added a comment here, kinda clunky though.

self.dl_state_dict = super().state_dict()
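To make the "one ahead" point from the exchange above concrete, here is a toy, hypothetical example in plain Python (not accelerate or torchdata code) of why the state is snapshotted before the next prefetch instead of being read from the base dataloader at yield time:

```python
# Toy illustration: the raw state runs one step ahead of the batch being yielded.
class ToyLoader:
    """Hypothetical stand-in for a StatefulDataLoader-like object."""

    def __init__(self, data):
        self.data = data
        self.produced = 0  # number of items handed out by the underlying iterator

    def state_dict(self):
        return {"produced": self.produced}

    def __iter__(self):
        for item in self.data:
            self.produced += 1
            yield item


loader = ToyLoader(["a", "b", "c"])
it = iter(loader)

current = next(it)               # prefetch "a"; raw state becomes {"produced": 1}
snapshot = loader.state_dict()   # saved *before* the next prefetch
_ = next(it)                     # prefetch "b"; raw state is now {"produced": 2}
# At the moment "a" is handed to the caller, loader.state_dict() already counts "b",
# so the adapter has to report the earlier snapshot to allow an exact resume.
assert snapshot == {"produced": 1} and loader.state_dict() == {"produced": 2}
```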


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.

Args:
dataset (`torch.utils.data.dataset.Dataset`):
@@ -409,6 +466,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
A random number generator to keep synchronized across processes.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
**kwargs (additional keyword arguments, *optional*):
All other keyword arguments to pass to the regular `DataLoader` initialization.

@@ -428,11 +487,12 @@ def __init__(
rng_types=None,
synchronized_generator=None,
skip_batches=0,
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
**kwargs,
):
super().__init__(dataset, **kwargs)
super().__init__(dataset, use_stateful_dataloader, **kwargs)
self.device = device
self.rng_types = rng_types
self.synchronized_generator = synchronized_generator
@@ -461,6 +521,7 @@ def __iter__(self):
# But we still move it to the device so it is done before `StopIteration` is reached
if self.device is not None:
current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
self._save_state_dict()
next_batch = next(dataloader_iter)
if batch_index >= self.skip_batches:
yield current_batch
@@ -559,10 +620,10 @@ def batch_sampler(self):
return self._loader.batch_sampler


class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
process their part of the batch.
Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
their part of the batch.

Args:
split_batches (`bool`, *optional*, defaults to `False`):
@@ -574,6 +635,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
size of the `dataloader` is a round multiple of `batch_size`.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning of an iteration.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.

**Available attributes:**

@@ -589,6 +652,7 @@ def __init__(
dataset,
split_batches: bool = False,
skip_batches=0,
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
@@ -601,7 +665,7 @@ def __init__(
# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
super().__init__(dataset, **kwargs)
super().__init__(dataset, use_stateful_dataloader, **kwargs)
self.split_batches = split_batches
if shuffle:
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
@@ -622,12 +686,14 @@ def _fetch_batches(self, iterator):
try:
if self.split_batches:
# One batch of the main iterator is dispatched and split.
self._save_state_dict()
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
self._save_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
@@ -807,6 +873,7 @@ def prepare_data_loader(
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -868,6 +935,10 @@
non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
If set to `True`, the dataloader prepared by the Accelerator will be backed by
[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This
requires a version of `torchdata` with `StatefulDataLoader` to be installed.


Returns:
@@ -1001,6 +1072,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
elif sampler_is_batch_sampler:
@@ -1013,6 +1085,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
synchronized_generator=synchronized_generator,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
else:
@@ -1024,6 +1097,7 @@ def prepare_data_loader(
synchronized_generator=synchronized_generator,
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)

@@ -1041,6 +1115,7 @@ class SkipBatchSampler(BatchSampler):

def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
self.sampler = batch_sampler.sampler
self.skip_batches = skip_batches

def __iter__(self):
Expand All @@ -1056,7 +1131,7 @@ def __len__(self):
return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoader):
class SkipDataLoader(DataLoaderAdapter):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.

@@ -1065,24 +1140,30 @@ class SkipDataLoader(DataLoader):
The dataset to use to build this dataloader.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
kwargs:
All other keyword arguments to pass to the regular `DataLoader` initialization.
"""

def __init__(self, dataset, skip_batches=0, **kwargs):
super().__init__(dataset, **kwargs)
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
super().__init__(dataset, use_stateful_dataloader, **kwargs)
self.skip_batches = skip_batches

def __iter__(self):
for index, batch in enumerate(super().__iter__()):
if index >= self.skip_batches:
self._save_state_dict()
yield batch


def skip_first_batches(dataloader, num_batches=0):
"""
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
"""
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = dataloader.dataset
sampler_is_batch_sampler = False
if isinstance(dataset, IterableDataset):
@@ -1121,6 +1202,7 @@ def skip_first_batches(dataloader, num_batches=0):
split_batches=dataloader.split_batches,
batch_sampler=new_batch_sampler,
_drop_last=dataloader._drop_last,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
elif isinstance(dataloader, DataLoaderShard):
@@ -1137,13 +1219,17 @@ def skip_first_batches(dataloader, num_batches=0):
device=dataloader.device,
rng_types=dataloader.rng_types,
synchronized_generator=dataloader.synchronized_generator,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
else:
if new_batch_sampler is None:
# Need to manually skip batches in the dataloader
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
dataloader = SkipDataLoader(
dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs
)
elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader):
dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
else:
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

return dataloader
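For anyone who wants to try the feature out, below is a hedged end-to-end sketch based on the names introduced in this diff (`DataLoaderConfiguration(use_stateful_dataloader=True)` plus the `state_dict`/`load_state_dict` methods on the prepared dataloader). It assumes `torchdata>=0.8.0` is installed, and the exact behaviour may still shift while the PR is under review.

```python
# Sketch: enabling the stateful dataloader through the Accelerator and checkpointing it.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

dataset = TensorDataset(torch.arange(64, dtype=torch.float32).unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=8)

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)
dataloader = accelerator.prepare(dataloader)

# The prepared dataloader adapts torchdata's StatefulDataLoader, so it exposes
# state_dict()/load_state_dict() for mid-epoch checkpointing.
state = None
for step, batch in enumerate(dataloader):
    if step == 3:
        state = dataloader.state_dict()
        break

# In a resumed run, restore the snapshot and continue from where iteration stopped.
dataloader.load_state_dict(state)
for batch in dataloader:
    ...
```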
2 changes: 1 addition & 1 deletion src/accelerate/test_utils/scripts/test_sync.py
@@ -305,12 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler(

def test_dataloader_break():
accelerator = Accelerator()

first_dset = RegressionDataset(length=80)
first_dataloader = DataLoader(first_dset, batch_size=16)
second_dset = RegressionDataset(length=96)
second_dataloader = DataLoader(second_dset, batch_size=16)
first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)

assert accelerator.gradient_state.active_dataloader is None
for iteration, _ in enumerate(first_dataloader):
assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
13 changes: 13 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -52,6 +52,7 @@
is_timm_available,
is_torch_version,
is_torch_xla_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_transformers_available,
is_triton_available,
@@ -420,6 +421,18 @@ def require_trackers(test_case):
)(test_case)


def require_torchdata_stateful_dataloader(test_case):
"""
Decorator marking a test that requires torchdata.stateful_dataloader.

These tests are skipped when torchdata with stateful_dataloader module isn't installed.

"""
return unittest.skipUnless(
is_torchdata_stateful_dataloader_available(), "test requires torchdata.stateful_dataloader"
)(test_case)


class TempDirTestCase(unittest.TestCase):
"""
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
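As a usage note, this is roughly how the new decorator would be applied in the test suite; the test class and body below are illustrative, not tests added by this PR.

```python
# Illustrative use of the new decorator; the test itself is hypothetical.
import unittest

from accelerate.test_utils.testing import require_torchdata_stateful_dataloader


class StatefulDataLoaderSmokeTest(unittest.TestCase):
    @require_torchdata_stateful_dataloader
    def test_state_dict_is_a_dict(self):
        # Import inside the test so collection works even without torchdata installed.
        from torchdata.stateful_dataloader import StatefulDataLoader

        loader = StatefulDataLoader(list(range(16)), batch_size=4)
        next(iter(loader))
        self.assertIsInstance(loader.state_dict(), dict)
```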
2 changes: 2 additions & 0 deletions src/accelerate/utils/__init__.py
@@ -107,6 +107,8 @@
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
is_torchdata_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_transformer_engine_available,
is_transformers_available,
10 changes: 9 additions & 1 deletion src/accelerate/utils/dataclasses.py
@@ -692,7 +692,7 @@ class DataLoaderConfiguration:
metadata={
"help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
" and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
" underlying dataset is an `IterableDataslet`, `False` otherwise."
" underlying dataset is an `IterableDataset`, `False` otherwise."
},
)
even_batches: bool = field(
@@ -720,6 +720,14 @@ class DataLoaderConfiguration:
" prepared dataloader has `pin_memory` set to `True` to work properly."
},
)
use_stateful_dataloader: bool = field(
default=False,
metadata={
"help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires a version"
" of `torchdata` with StatefulDataLoader to be installed."
},
)


@dataclass