
Commit f5971fe
Created duplicate Stateful Dataloader derivative classes
byi8220 committed Jul 29, 2024
1 parent 35977ca commit f5971fe
Showing 4 changed files with 594 additions and 123 deletions.
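The duplicated stateful classes themselves live in the files not expanded below; only src/accelerate/accelerator.py is shown in this view. As a rough orientation, here is a minimal, hypothetical sketch of what a StatefulDataLoader-backed duplicate of one of accelerate's dataloader classes could look like; the class name and structure are illustrative assumptions, not taken from this commit.

    from torchdata.stateful_dataloader import StatefulDataLoader


    class StatefulDataLoaderShard(StatefulDataLoader):
        """Illustrative shard-style dataloader built on torchdata's StatefulDataLoader,
        which exposes state_dict()/load_state_dict() for mid-epoch resumption."""

        def __init__(self, dataset, device=None, **kwargs):
            super().__init__(dataset, **kwargs)
            self.device = device

        def __iter__(self):
            for batch in super().__iter__():
                # The real accelerate classes also handle device placement and
                # end-of-dataloader bookkeeping here; omitted in this sketch.
                yield batch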
11 changes: 6 additions & 5 deletions src/accelerate/accelerator.py
@@ -34,7 +34,7 @@
from huggingface_hub import split_torch_state_dict_into_shards

from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
- from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
+ from .data_loader import prepare_data_loader, skip_first_batches, is_data_loader_dispatcher
from .hooks import AlignDevicesHook
from .logging import get_logger
from .optimizer import AcceleratedOptimizer
@@ -90,6 +90,7 @@
is_npu_available,
is_torch_version,
is_torch_xla_available,
+ is_torchdata_stateful_dataloader_available,
is_xpu_available,
load_fsdp_model,
load_fsdp_optimizer,
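The new `is_torchdata_stateful_dataloader_available` import suggests an availability gate for torchdata's StatefulDataLoader. A sketch of how such a check is commonly written, assuming an importlib-based pattern like accelerate's other `is_*_available` helpers; the minimum version used here is an assumption, not taken from this commit.

    import importlib.metadata
    import importlib.util

    from packaging.version import parse


    def is_torchdata_stateful_dataloader_available() -> bool:
        """Best-effort check that an installed torchdata ships StatefulDataLoader."""
        if importlib.util.find_spec("torchdata") is None:
            return False
        try:
            torchdata_version = importlib.metadata.version("torchdata")
        except importlib.metadata.PackageNotFoundError:
            return False
        # StatefulDataLoader only exists in recent torchdata releases; the exact
        # threshold below is illustrative.
        return parse(torchdata_version) >= parse("0.8.0")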
@@ -155,7 +156,6 @@
_even_batches = object()
_use_seedable_sampler = object()


class Accelerator:
"""
Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training.
@@ -1137,7 +1137,7 @@ def join_uneven_inputs(self, joinables, even_batches=None):
iterable_dl_seen = False
# override value in batch sampler for map-style datasets
for dl_idx, dl in enumerate(self._dataloaders):
- if isinstance(dl, DataLoaderDispatcher):
+ if is_data_loader_dispatcher(dl):
iterable_dl_seen = True
continue
dl_even_batches_values.append((dl_idx, dl.batch_sampler.even_batches))
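Swapping the direct `isinstance(dl, DataLoaderDispatcher)` check for an `is_data_loader_dispatcher` helper makes sense once a stateful duplicate of the dispatcher exists: one predicate can then cover both variants. A hedged sketch of what such a helper could look like; `StatefulDataLoaderDispatcher` is a hypothetical name, not confirmed by this diff.

    from accelerate.data_loader import DataLoaderDispatcher

    try:
        # Hypothetical stateful counterpart created by this commit; the name is
        # illustrative and the import may not exist as written.
        from accelerate.data_loader import StatefulDataLoaderDispatcher

        _DISPATCHER_CLASSES = (DataLoaderDispatcher, StatefulDataLoaderDispatcher)
    except ImportError:
        _DISPATCHER_CLASSES = (DataLoaderDispatcher,)


    def is_data_loader_dispatcher(dataloader) -> bool:
        """Drop-in replacement for `isinstance(dl, DataLoaderDispatcher)` that also
        recognizes the stateful dispatcher variant when it is available."""
        return isinstance(dataloader, _DISPATCHER_CLASSES)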
@@ -1994,11 +1994,12 @@ def _prepare_msamp(self, *args):
return tuple(result)

def prepare_data_loader(
- self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
+ self, data_loader, device_placement=None, slice_fn_for_dispatch=None
):
"""
Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
[`Accelerator.prepare`] instead.
+ If config.use_stateful_dataloader is set, prepares a torchdata StatefulDataLoader instead.
Args:
data_loader (`torch.utils.data.DataLoader`):
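The updated docstring says a torchdata StatefulDataLoader is prepared when config.use_stateful_dataloader is set. A usage sketch, assuming the flag is exposed through `DataLoaderConfiguration` as in released accelerate versions; the exact configuration surface at this commit may differ.

    import torch
    from accelerate import Accelerator
    from accelerate.utils import DataLoaderConfiguration

    # Opt in to StatefulDataLoader-backed dataloaders (assumed configuration path).
    dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
    accelerator = Accelerator(dataloader_config=dataloader_config)

    dataset = torch.utils.data.TensorDataset(torch.randn(64, 4))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

    # The prepared loader wraps torchdata's StatefulDataLoader, so it can be
    # checkpointed and resumed mid-epoch via state_dict()/load_state_dict().
    dataloader = accelerator.prepare(dataloader)
    state = dataloader.state_dict()
    dataloader.load_state_dict(state)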
