diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index be7ab6aee3c..b2ea2f667e5 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -34,7 +34,7 @@ from huggingface_hub import split_torch_state_dict_into_shards from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state -from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .data_loader import prepare_data_loader, skip_first_batches, is_data_loader_dispatcher from .hooks import AlignDevicesHook from .logging import get_logger from .optimizer import AcceleratedOptimizer @@ -90,6 +90,7 @@ is_npu_available, is_torch_version, is_torch_xla_available, + is_torchdata_stateful_dataloader_available, is_xpu_available, load_fsdp_model, load_fsdp_optimizer, @@ -155,7 +156,6 @@ _even_batches = object() _use_seedable_sampler = object() - class Accelerator: """ Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training. @@ -1137,7 +1137,7 @@ def join_uneven_inputs(self, joinables, even_batches=None): iterable_dl_seen = False # override value in batch sampler for map-style datasets for dl_idx, dl in enumerate(self._dataloaders): - if isinstance(dl, DataLoaderDispatcher): + if is_data_loader_dispatcher(dl): iterable_dl_seen = True continue dl_even_batches_values.append((dl_idx, dl.batch_sampler.even_batches)) @@ -1994,11 +1994,12 @@ def _prepare_msamp(self, *args): return tuple(result) def prepare_data_loader( - self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None + self, data_loader, device_placement=None, slice_fn_for_dispatch=None ): """ Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use - [`Accelerator.prepare`] instead. + [`Accelerator.prepare`] instead. + If config.use_stateful_dataloader is set, prepares a torchdata StatefulDataLoader instead. Args: data_loader (`torch.utils.data.DataLoader`): diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index a9bb70218f5..7526fa63026 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -14,7 +14,7 @@ import math from contextlib import suppress -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Union, Any import torch from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler @@ -387,67 +387,10 @@ def begin(self): def end(self): "Cleans up the gradient state after exiting the dataloader" self.gradient_state._remove_dataloader(self) - - -class DataLoaderAdapter: - """ - A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For - compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in. - """ - - def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs): - self.use_stateful_dataloader = use_stateful_dataloader - if is_torchdata_stateful_dataloader_available(): - from torchdata.stateful_dataloader import StatefulDataLoader - - if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available(): - raise ImportError("StatefulDataLoader is not available. 
Please install torchdata to use it.") - if use_stateful_dataloader: - self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) - else: - self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) - - # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641 - # In C++ terms, this is analogous to creating `DataLoaderAdapter : T`, where T is a DataLoader or - # StatefulDataLoader - # - # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader, - # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional - # dispatching scattered throughout various functions and files. - # - # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work - # transparently. - # - # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit), - # but this would not be backwards compatible with existing code which assumes - # DataLoaderShard/DataLoaderDispatcher are DataLoaders. - base_cls = self.__class__ - base_cls_name = self.__class__.__name__ - parent_cls_name = self.base_dataloader.__class__ - self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {}) - - # Allow this class to transparently pass through attributes from the underlying class - if hasattr(self.base_dataloader, "state_dict"): - self.dl_state_dict = self.base_dataloader.state_dict() - - for attr in self.base_dataloader.__dict__.keys(): - setattr(self, attr, getattr(self.base_dataloader, attr)) - - def state_dict(self): - return self.dl_state_dict - - def load_state_dict(self, state_dict): - super().load_state_dict(state_dict) - self.dl_state_dict = self.state_dict - - def _save_state_dict(self): - if hasattr(self.base_dataloader, "state_dict"): - self.dl_state_dict = super().state_dict() - - -class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): + +class DataLoaderShard(DataLoader, DataLoaderStateMixin): """ - Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup. + Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. Args: dataset (`torch.utils.data.dataset.Dataset`): @@ -466,8 +409,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): A random number generator to keep synchronized across processes. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning. - use_stateful_dataloader (`bool`, *optional*, defaults to `False`): - Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. **kwargs (additional keyword arguments, *optional*): All other keyword arguments to pass to the regular `DataLoader` initialization. 
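This patch drops the dynamic-mixin trick that `DataLoaderAdapter` used (rewriting `self.__class__` with `type(base_cls_name, (base_cls, parent_cls_name), {})` purely so `isinstance(obj, StatefulDataLoader)` would pass) in favor of plain subclasses chosen by small factory functions. The snippet below is a minimal standalone sketch of that replacement pattern, not accelerate code: `PlainShard`, `StatefulShard` and `make_shard` are hypothetical names standing in for the concrete classes and `create_*` factories introduced later in this diff.

```python
# Illustrative sketch only: choose the concrete base class up front instead of
# rewriting __class__ at runtime, so isinstance() checks work without mixin tricks.
from torch.utils.data import DataLoader

try:
    from torchdata.stateful_dataloader import StatefulDataLoader

    _HAS_STATEFUL = True
except ImportError:
    _HAS_STATEFUL = False


class PlainShard(DataLoader):
    """Regular DataLoader-backed variant (hypothetical name)."""


if _HAS_STATEFUL:

    class StatefulShard(StatefulDataLoader):
        """torchdata-backed variant with a resumable state_dict() (hypothetical name)."""


def make_shard(dataset, use_stateful_dataloader=False, **kwargs):
    """Return the stateful variant when requested and available, else the plain one."""
    if use_stateful_dataloader:
        if not _HAS_STATEFUL:
            raise ImportError("StatefulDataLoader is not available. Please install torchdata to use it.")
        return StatefulShard(dataset, **kwargs)
    return PlainShard(dataset, **kwargs)


dl = make_shard(list(range(16)), batch_size=4)
assert isinstance(dl, DataLoader)  # the concrete type is fixed at construction time
```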
@@ -487,12 +428,11 @@ def __init__( rng_types=None, synchronized_generator=None, skip_batches=0, - use_stateful_dataloader=False, _drop_last: bool = False, _non_blocking: bool = False, **kwargs, ): - super().__init__(dataset, use_stateful_dataloader, **kwargs) + super().__init__(dataset, **kwargs) self.device = device self.rng_types = rng_types self.synchronized_generator = synchronized_generator @@ -521,7 +461,6 @@ def __iter__(self): # But we still move it to the device so it is done before `StopIteration` is reached if self.device is not None: current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking) - self._save_state_dict() next_batch = next(dataloader_iter) if batch_index >= self.skip_batches: yield current_batch @@ -563,6 +502,10 @@ def total_dataset_length(self): else: return len(self.dataset) + @property + def use_stateful_dataloader(self): + return False + def get_sampler(self): return get_sampler(self) @@ -575,6 +518,39 @@ def set_sampler(self, sampler): if hasattr(self.batch_sampler, "batch_sampler"): self.batch_sampler.batch_sampler.sampler = sampler +def create_data_loader_shard( + dataset, + device=None, + rng_types=None, + synchronized_generator=None, + skip_batches=0, + use_stateful_dataloader = False, + _drop_last: bool = False, + _non_blocking: bool = False, + **kwargs, +): + """Create a `DataLoader` or a `DataLoaderShard` depending on the `use_stateful_dataloader` flag.""" + if use_stateful_dataloader and is_torchdata_stateful_dataloader_available(): + from .stateful_data_loader import StatefulDataLoaderShard + return StatefulDataLoaderShard( + dataset, + device=device, + rng_types=rng_types, + synchronized_generator=synchronized_generator, + skip_batches=skip_batches, + _drop_last=_drop_last, + _non_blocking=_non_blocking, + **kwargs) + else: + return DataLoaderShard( + dataset, + device=device, + rng_types=rng_types, + synchronized_generator=synchronized_generator, + skip_batches=skip_batches, + _drop_last=_drop_last, + _non_blocking=_non_blocking, + **kwargs) if is_torch_xla_available(): import torch_xla.distributed.parallel_loader as xpl @@ -620,9 +596,9 @@ def batch_sampler(self): return self._loader.batch_sampler -class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): +class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): """ - Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process + Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each process their part of the batch. Args: @@ -635,8 +611,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): size of the `dataloader` is a round multiple of `batch_size`. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning of an iteration. - use_stateful_dataloader (`bool`, *optional*, defaults to `False`): - Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. 
**Available attributes:** @@ -652,7 +626,6 @@ def __init__( dataset, split_batches: bool = False, skip_batches=0, - use_stateful_dataloader=False, _drop_last: bool = False, _non_blocking: bool = False, slice_fn=None, @@ -665,7 +638,7 @@ def __init__( # We need to save the shuffling state of the DataPipe if isinstance(dataset, ShufflerIterDataPipe): shuffle = dataset._shuffle_enabled - super().__init__(dataset, use_stateful_dataloader, **kwargs) + super().__init__(dataset, **kwargs) self.split_batches = split_batches if shuffle: torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) @@ -686,14 +659,12 @@ def _fetch_batches(self, iterator): try: if self.split_batches: # One batch of the main iterator is dispatched and split. - self._save_state_dict() batch = next(iterator) else: # num_processes batches of the main iterator are concatenated then dispatched and split. # We add the batches one by one so we have the remainder available when drop_last=False. batches = [] for _ in range(self.state.num_processes): - self._save_state_dict() batches.append(next(iterator)) try: batch = concatenate(batches, dim=0) @@ -829,6 +800,10 @@ def total_batch_size(self): def total_dataset_length(self): return len(self.dataset) + @property + def use_stateful_dataloader(self): + return False + def get_sampler(self): return get_sampler(self) @@ -841,6 +816,35 @@ def set_sampler(self, sampler): if hasattr(self.batch_sampler, "batch_sampler"): self.batch_sampler.batch_sampler.sampler = sampler +def create_data_loader_dispatcher( + dataset, + split_batches: bool = False, + skip_batches=0, + use_stateful_dataloader = False, + _drop_last: bool = False, + _non_blocking: bool = False, + slice_fn=None, + **kwargs): + """Create a `DataLoader` or a `DataLoaderDispatcher` depending on the `use_stateful_dataloader` flag.""" + if use_stateful_dataloader and is_torchdata_stateful_dataloader_available(): + from .stateful_data_loader import StatefulDataLoaderDispatcher + return StatefulDataLoaderDispatcher( + dataset, + split_batches=split_batches, + skip_batches=skip_batches, + _drop_last=_drop_last, + _non_blocking=_non_blocking, + slice_fn=slice_fn, + **kwargs) + else: + return DataLoaderDispatcher( + dataset, + split_batches=split_batches, + skip_batches=skip_batches, + _drop_last=_drop_last, + _non_blocking=_non_blocking, + slice_fn=slice_fn, + **kwargs) def get_sampler(dataloader): """ @@ -1065,7 +1069,7 @@ def prepare_data_loader( ) if dispatch_batches: kwargs.pop("generator") - dataloader = DataLoaderDispatcher( + dataloader = create_data_loader_dispatcher( new_dataset, split_batches=split_batches, batch_sampler=new_batch_sampler, @@ -1076,7 +1080,7 @@ def prepare_data_loader( **kwargs, ) elif sampler_is_batch_sampler: - dataloader = DataLoaderShard( + dataloader = create_data_loader_shard( new_dataset, device=device if put_on_device and state.distributed_type != DistributedType.XLA else None, sampler=new_batch_sampler, @@ -1089,7 +1093,7 @@ def prepare_data_loader( **kwargs, ) else: - dataloader = DataLoaderShard( + dataloader = create_data_loader_shard( new_dataset, device=device if put_on_device and state.distributed_type != DistributedType.XLA else None, batch_sampler=new_batch_sampler, @@ -1131,7 +1135,7 @@ def __len__(self): return len(self.batch_sampler) - self.skip_batches -class SkipDataLoader(DataLoaderAdapter): +class SkipDataLoader(DataLoader): """ Subclass of a PyTorch `DataLoader` that will skip the first batches. 
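For orientation, this is how the new factories are meant to be called once the patch is applied. The sketch mirrors the updated tests further down rather than documenting a public API, and it assumes a single-process run with an initialized `Accelerator`.

```python
# Hedged usage sketch of the create_* factories added in this diff.
from accelerate import Accelerator
from accelerate.data_loader import (
    create_data_loader_dispatcher,
    create_data_loader_shard,
    create_skip_data_loader,
)

Accelerator()  # DataLoaderDispatcher reads AcceleratorState, so initialize it first
dataset = list(range(16))

shard = create_data_loader_shard(dataset, batch_size=4)  # every process iterates its own shard
dispatcher = create_data_loader_dispatcher(dataset, batch_size=4)  # process 0 iterates, then dispatches
skipper = create_skip_data_loader(dataset, batch_size=4, skip_batches=2)

# Without use_stateful_dataloader=True these are plain DataLoader subclasses.
assert not shard.use_stateful_dataloader and not dispatcher.use_stateful_dataloader
assert [t.tolist() for t in skipper] == [[8, 9, 10, 11], [12, 13, 14, 15]]
```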
@@ -1140,22 +1144,49 @@ class SkipDataLoader(DataLoaderAdapter): The dataset to use to build this datalaoder. skip_batches (`int`, *optional*, defaults to 0): The number of batches to skip at the beginning. - use_stateful_dataloader (`bool`, *optional*, defaults to `False`): - Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`. kwargs: All other keyword arguments to pass to the regular `DataLoader` initialization. """ - def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs): - super().__init__(dataset, use_stateful_dataloader, **kwargs) + def __init__(self, dataset, skip_batches=0, **kwargs): + super().__init__(dataset, **kwargs) self.skip_batches = skip_batches def __iter__(self): for index, batch in enumerate(super().__iter__()): if index >= self.skip_batches: - self._save_state_dict() yield batch + @property + def use_stateful_dataloader(self): + return False + +def create_skip_data_loader(dataset, skip_batches= 0, use_stateful_dataloader = False, **kwargs): + """Create a `DataLoader` or a `SkipDataLoader` depending on the `use_stateful_dataloader` flag.""" + if use_stateful_dataloader and is_torchdata_stateful_dataloader_available(): + from .stateful_data_loader import StatefulSkipDataLoader + return StatefulSkipDataLoader(dataset, skip_batches=skip_batches, **kwargs) + else: + return SkipDataLoader(dataset, skip_batches=skip_batches, **kwargs) + +def is_data_loader_dispatcher(obj: Any) -> bool: + if is_torchdata_stateful_dataloader_available(): + from .stateful_data_loader import StatefulDataLoaderDispatcher + return isinstance(obj, StatefulDataLoaderDispatcher) or isinstance(obj, DataLoaderDispatcher) + return isinstance(obj, DataLoaderDispatcher) + +def is_data_loader_shard(obj: Any) -> bool: + if is_torchdata_stateful_dataloader_available(): + from .stateful_data_loader import StatefulDataLoaderShard + return isinstance(obj, StatefulDataLoaderShard) or isinstance(obj, DataLoaderShard) + return isinstance(obj, DataLoaderShard) + +def is_skip_data_loader(obj: Any) -> bool: + if is_torchdata_stateful_dataloader_available(): + from .stateful_data_loader import StatefulSkipDataLoader + return isinstance(obj, StatefulSkipDataLoader) or isinstance(obj, SkipDataLoader) + return isinstance(obj, SkipDataLoader) + def skip_first_batches(dataloader, num_batches=0): """ @@ -1193,11 +1224,11 @@ def skip_first_batches(dataloader, num_batches=0): kwargs["drop_last"] = dataloader.drop_last kwargs["batch_size"] = dataloader.batch_size - if isinstance(dataloader, DataLoaderDispatcher): + if is_data_loader_dispatcher(dataloader): if new_batch_sampler is None: # Need to manually skip batches in the dataloader kwargs["skip_batches"] = num_batches - dataloader = DataLoaderDispatcher( + dataloader = create_data_loader_dispatcher( dataset, split_batches=dataloader.split_batches, batch_sampler=new_batch_sampler, @@ -1205,7 +1236,7 @@ def skip_first_batches(dataloader, num_batches=0): use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) - elif isinstance(dataloader, DataLoaderShard): + elif is_data_loader_shard(dataloader): if new_batch_sampler is None: # Need to manually skip batches in the dataloader kwargs["skip_batches"] = num_batches @@ -1214,7 +1245,7 @@ def skip_first_batches(dataloader, num_batches=0): kwargs["batch_size"] = dataloader.batch_size else: kwargs["batch_sampler"] = new_batch_sampler - dataloader = DataLoaderShard( + dataloader = create_data_loader_shard( dataset, 
device=dataloader.device, rng_types=dataloader.rng_types, diff --git a/src/accelerate/stateful_data_loader.py b/src/accelerate/stateful_data_loader.py new file mode 100644 index 00000000000..93c33f29a23 --- /dev/null +++ b/src/accelerate/stateful_data_loader.py @@ -0,0 +1,435 @@ +from torchdata.stateful_dataloader import StatefulDataLoader + +import math +from contextlib import suppress +from typing import Callable, List, Optional, Union + +import torch +from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler + +from .data_loader import DataLoaderStateMixin, DataLoaderDispatcher, DataLoaderShard, SkipDataLoader, get_sampler +from .logging import get_logger +from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available +from .utils import ( + RNGType, + broadcast, + broadcast_object_list, + concatenate, + find_batch_size, + get_data_structure, + initialize_tensors, + is_torch_version, + send_to_device, + slice_tensors, + synchronize_rng_states, +) + +class StatefulDataLoaderShard(StatefulDataLoader, DataLoaderStateMixin): + """ + Subclass of a torchdata `StatefulDataLoader` that will deal with device placement and current distributed setup. + + Args: + dataset (`torch.utils.data.dataset.Dataset`): + The dataset to use to build this datalaoder. + device (`torch.device`, *optional*): + If passed, the device to put all batches on. + rng_types (list of `str` or [`~utils.RNGType`]): + The list of random number generators to synchronize at the beginning of each iteration. Should be one or + several of: + + - `"torch"`: the base torch random number generator + - `"cuda"`: the CUDA random number generator (GPU only) + - `"xla"`: the XLA random number generator (TPU only) + - `"generator"`: an optional `torch.Generator` + synchronized_generator (`torch.Generator`, *optional*): + A random number generator to keep synchronized across processes. + skip_batches (`int`, *optional*, defaults to 0): + The number of batches to skip at the beginning. + **kwargs (additional keyword arguments, *optional*): + All other keyword arguments to pass to the regular `DataLoader` initialization. + + **Available attributes:** + + - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes. + Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total + number of processes + + - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes. 
+ """ + + def __init__( + self, + dataset, + device=None, + rng_types=None, + synchronized_generator=None, + skip_batches=0, + _drop_last: bool = False, + _non_blocking: bool = False, + **kwargs, + ): + super().__init__(dataset, **kwargs) + self.device = device + self.rng_types = rng_types + self.synchronized_generator = synchronized_generator + self.skip_batches = skip_batches + self.gradient_state = GradientState() + self._drop_last = _drop_last + self._non_blocking = _non_blocking + self.iteration = 0 + + def __iter__(self): + if self.rng_types is not None: + synchronize_rng_states(self.rng_types, self.synchronized_generator) + self.begin() + + self.set_epoch(self.iteration) + dataloader_iter = super().__iter__() + # We iterate one batch ahead to check when we are at the end + try: + current_batch = next(dataloader_iter) + except StopIteration: + yield + + batch_index = 0 + while True: + try: + # But we still move it to the device so it is done before `StopIteration` is reached + if self.device is not None: + current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking) + self._save_state_dict() + next_batch = next(dataloader_iter) + if batch_index >= self.skip_batches: + yield current_batch + batch_index += 1 + current_batch = next_batch + except StopIteration: + self.end_of_dataloader = True + if batch_index >= self.skip_batches: + yield current_batch + break + + self.iteration += 1 + self.end() + + def set_epoch(self, epoch: int): + # In case it is manually passed in, the user can set it to what they like + if self.iteration != epoch: + self.iteration = epoch + if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(epoch) + # We support if a custom `Dataset` implementation has `set_epoch` + # or in general HF datasets `Datasets` + elif hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + @property + def total_batch_size(self): + batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler + return ( + batch_sampler.batch_size + if getattr(batch_sampler, "split_batches", False) + else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1)) + ) + + @property + def total_dataset_length(self): + if hasattr(self.dataset, "total_length"): + return self.dataset.total_length + else: + return len(self.dataset) + + @property + def use_stateful_dataloader(self): + return True + + def get_sampler(self): + return get_sampler(self) + + def set_sampler(self, sampler): + sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler) + if sampler_is_batch_sampler: + self.sampler.sampler = sampler + else: + self.batch_sampler.sampler = sampler + if hasattr(self.batch_sampler, "batch_sampler"): + self.batch_sampler.batch_sampler.sampler = sampler + + def state_dict(self): + return self.dl_state_dict + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.dl_state_dict = self.state_dict + + def _save_state_dict(self): + self.dl_state_dict = super().state_dict() + +class StatefulDataLoaderDispatcher(StatefulDataLoader, DataLoaderStateMixin): + """ + Subclass of a torchdata `StatefulDataLoader` that will iterate and preprocess on process 0 only, then dispatch on each process + their part of the batch. 
+ + Args: + split_batches (`bool`, *optional*, defaults to `False`): + Whether the resulting `DataLoader` should split the batches of the original data loader across devices or + yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of + `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be + the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial + `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch + size of the `dataloader` is a round multiple of `batch_size`. + skip_batches (`int`, *optional*, defaults to 0): + The number of batches to skip at the beginning of an iteration. + + **Available attributes:** + + - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes. + Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total + number of processes + + - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes. + """ + + def __init__( + self, + dataset, + split_batches: bool = False, + skip_batches=0, + _drop_last: bool = False, + _non_blocking: bool = False, + slice_fn=None, + **kwargs, + ): + shuffle = False + if is_torch_version(">=", "1.11.0"): + from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe + + # We need to save the shuffling state of the DataPipe + if isinstance(dataset, ShufflerIterDataPipe): + shuffle = dataset._shuffle_enabled + super().__init__(dataset, **kwargs) + self.split_batches = split_batches + if shuffle: + torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) + + self.gradient_state = GradientState() + self.state = AcceleratorState() + self._drop_last = _drop_last + self._non_blocking = _non_blocking + self.skip_batches = skip_batches + + self.slice_fn = slice_tensors if slice_fn is None else slice_fn + self.iteration = 0 + + def _fetch_batches(self, iterator): + batches, batch = None, None + # On process 0, we gather the batch to dispatch. + if self.state.process_index == 0: + try: + if self.split_batches: + # One batch of the main iterator is dispatched and split. + self._save_state_dict() + batch = next(iterator) + else: + # num_processes batches of the main iterator are concatenated then dispatched and split. + # We add the batches one by one so we have the remainder available when drop_last=False. + batches = [] + for _ in range(self.state.num_processes): + self._save_state_dict() + batches.append(next(iterator)) + try: + batch = concatenate(batches, dim=0) + except RuntimeError as e: + raise RuntimeError( + "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`." + "either pass `dispatch_batches=False` and have each process fetch its own batch " + " or pass `split_batches=True`. By doing so, the main process will fetch a full batch and " + "slice it into `num_processes` batches for each process." + ) from e + # In both cases, we need to get the structure of the batch that we will broadcast on other + # processes to initialize the tensors with the right shape. 
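+                # Only the structure is shared here; the tensor payload itself is broadcast to the
+                # other ranks later, in `__iter__`, after they allocate tensors of this shape.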
+ # data_structure, stop_iteration + batch_info = [get_data_structure(batch), False] + except StopIteration: + batch_info = [None, True] + else: + batch_info = [None, self._stop_iteration] + # This is inplace, so after this instruction, every process has the same `batch_info` as process 0. + broadcast_object_list(batch_info) + self._stop_iteration = batch_info[1] + if self._stop_iteration: + # If drop_last is False and split_batches is False, we may have a remainder to take care of. + if not self.split_batches and not self._drop_last: + if self.state.process_index == 0 and len(batches) > 0: + batch = concatenate(batches, dim=0) + batch_info = [get_data_structure(batch), False] + else: + batch_info = [None, True] + broadcast_object_list(batch_info) + return batch, batch_info + + def __iter__(self): + self.begin() + self.set_epoch(self.iteration) + main_iterator = None + if is_torch_version(">=", "2.0.1"): + # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts + # shared seed to all dist processes. Thus, we need to create iterator for all dist processes. + # But, we only iterate through the DataLoader on process 0. + main_iterator = super().__iter__() + elif self.state.process_index == 0: + main_iterator = super().__iter__() + stop_iteration = False + self._stop_iteration = False + first_batch = None + next_batch, next_batch_info = self._fetch_batches(main_iterator) + batch_index = 0 + while not stop_iteration: + batch, batch_info = next_batch, next_batch_info + + if self.state.process_index != 0: + # Initialize tensors on other processes than process 0. + batch = initialize_tensors(batch_info[0]) + batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking) + # Broadcast the batch before splitting it. + batch = broadcast(batch, from_process=0) + + if not self._drop_last and first_batch is None: + # We keep at least num processes elements of the first batch to be able to complete the last batch + first_batch = self.slice_fn( + batch, + slice(0, self.state.num_processes), + process_index=self.state.process_index, + num_processes=self.state.num_processes, + ) + + if batch is None: + raise ValueError( + f"Batch does not contain any data (`{batch}`). At the end of all iterable data available before expected stop iteration." + ) + + observed_batch_size = find_batch_size(batch) + batch_size = observed_batch_size // self.state.num_processes + + stop_iteration = self._stop_iteration + if not stop_iteration: + # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in + # the dataloader since the number of batches is a round multiple of the number of processes. + next_batch, next_batch_info = self._fetch_batches(main_iterator) + # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them. + if self._stop_iteration and next_batch_info[0] is None: + stop_iteration = True + + if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0: + # If the last batch is not complete, let's add the first batch to it. + batch = concatenate([batch, first_batch], dim=0) + # Batch size computation above is wrong, it's off by 1 so we fix it. 
+ batch_size += 1 + + data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size) + batch = self.slice_fn( + batch, + data_slice, + process_index=self.state.process_index, + num_processes=self.state.num_processes, + ) + + if stop_iteration: + self.end_of_dataloader = True + self.remainder = observed_batch_size + if batch_index >= self.skip_batches: + yield batch + batch_index += 1 + self.iteration += 1 + self.end() + + def set_epoch(self, epoch: int): + # In case it is manually passed in, the user can set it to what they like + if self.iteration != epoch: + self.iteration = epoch + if hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(epoch) + elif hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + def __len__(self): + whole_length = super().__len__() + if self.split_batches: + return whole_length + elif self._drop_last: + return whole_length // self.state.num_processes + else: + return math.ceil(whole_length / self.state.num_processes) + + @property + def total_batch_size(self): + return ( + self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes) + ) + + @property + def total_dataset_length(self): + return len(self.dataset) + + @property + def use_stateful_dataloader(self): + return True + + def get_sampler(self): + return get_sampler(self) + + def set_sampler(self, sampler): + sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler) + if sampler_is_batch_sampler: + self.sampler.sampler = sampler + else: + self.batch_sampler.sampler = sampler + if hasattr(self.batch_sampler, "batch_sampler"): + self.batch_sampler.batch_sampler.sampler = sampler + + def state_dict(self): + return self.dl_state_dict + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.dl_state_dict = self.state_dict + + def _save_state_dict(self): + self.dl_state_dict = super().state_dict() + +class StatefulSkipDataLoader(StatefulDataLoader): + """ + Subclass of a torchdata `StatefulDataLoader` that will skip the first batches. + + Args: + dataset (`torch.utils.data.dataset.Dataset`): + The dataset to use to build this datalaoder. + skip_batches (`int`, *optional*, defaults to 0): + The number of batches to skip at the beginning. + kwargs: + All other keyword arguments to pass to the regular `DataLoader` initialization. 
+ """ + + def __init__(self, dataset, skip_batches=0, **kwargs): + super().__init__(dataset, **kwargs) + self.skip_batches = skip_batches + + def __iter__(self): + for index, batch in enumerate(super().__iter__()): + if index >= self.skip_batches: + self._save_state_dict() + yield batch + + @property + def use_stateful_dataloader(self): + return True + + def state_dict(self): + return self.dl_state_dict + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.dl_state_dict = self.state_dict + + def _save_state_dict(self): + self.dl_state_dict = super().state_dict() \ No newline at end of file diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index b1108bf9f22..2815ce4f226 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -29,6 +29,9 @@ SkipBatchSampler, SkipDataLoader, skip_first_batches, + create_data_loader_dispatcher, + create_data_loader_shard, + create_skip_data_loader ) from accelerate.test_utils.testing import require_torchdata_stateful_dataloader from accelerate.utils import is_torchdata_stateful_dataloader_available @@ -39,6 +42,9 @@ from torchdata.stateful_dataloader import ( StatefulDataLoader, ) + from accelerate.stateful_data_loader import ( + StatefulDataLoaderDispatcher, StatefulDataLoaderShard, StatefulSkipDataLoader + ) def parameterized_custom_name_func(func, param_num, param): @@ -394,18 +400,18 @@ def test_dataloader_inheritance(self): are instances of DataLoader and DataLoaderStateMixin. """ Accelerator() - skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2) - dl_shard = DataLoaderShard(range(16), batch_size=4) - dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4) - assert isinstance(skip_dl, DataLoader) - assert isinstance(dl_shard, DataLoader) - assert isinstance(dl_dispatcher, DataLoader) + skip_dl = create_skip_data_loader(range(16), batch_size=4, skip_batches=2) + dl_shard = create_data_loader_shard(range(16), batch_size=4) + dl_dispatcher = create_data_loader_dispatcher(range(16), batch_size=4) + assert isinstance(skip_dl, SkipDataLoader) + assert isinstance(dl_shard, DataLoaderShard) + assert isinstance(dl_dispatcher, DataLoaderDispatcher) assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) def test_skip_data_loader(self): - dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2) + dataloader = create_skip_data_loader(list(range(16)), batch_size=4, skip_batches=2) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] def test_skip_first_batches(self): @@ -414,7 +420,7 @@ def test_skip_first_batches(self): assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] def test_end_of_dataloader(self): - dataloader = DataLoaderShard(list(range(16)), batch_size=4) + dataloader = create_data_loader_shard(list(range(16)), batch_size=4) for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) @@ -424,7 +430,7 @@ def test_end_of_dataloader(self): def test_end_of_dataloader_dispatcher(self): Accelerator() - dataloader = DataLoaderDispatcher(range(16), batch_size=4) + dataloader = create_data_loader_dispatcher(range(16), batch_size=4) for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) @@ -436,7 +442,7 @@ def test_end_of_dataloader_dispatcher(self): class StatefulDataLoaderTester(unittest.TestCase): @require_torchdata_stateful_dataloader def test_skip_data_loader(self): - dataloader = 
SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) + dataloader = create_skip_data_loader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True) assert isinstance(dataloader, StatefulDataLoader) assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]] @@ -449,8 +455,7 @@ def test_skip_first_batches(self): @require_torchdata_stateful_dataloader def test_end_of_dataloader(self): - dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True) - assert dataloader.use_stateful_dataloader + dataloader = create_data_loader_shard(list(range(16)), batch_size=4, use_stateful_dataloader=True) assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) @@ -462,7 +467,7 @@ def test_end_of_dataloader(self): @require_torchdata_stateful_dataloader def test_end_of_dataloader_dispatcher(self): Accelerator() - dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + dataloader = create_data_loader_dispatcher(range(16), batch_size=4, use_stateful_dataloader=True) assert isinstance(dataloader, StatefulDataLoader) for idx, _ in enumerate(dataloader): assert dataloader.end_of_dataloader == (idx == 3) @@ -478,9 +483,8 @@ def test_dataloader_state_dict(self, num_workers): Test that saving a stateful dataloader's state, then loading it back, gives the same results. """ dataset = list(range(16)) - dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) + dataloader = create_data_loader_shard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) - assert dataloader.use_stateful_dataloader assert isinstance(dataloader, StatefulDataLoader) vals = [] for idx, val in enumerate(dataloader): @@ -489,7 +493,7 @@ def test_dataloader_state_dict(self, num_workers): sd = dataloader.state_dict() assert len(vals) == 4 - dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) + dataloader2 = create_data_loader_shard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) dataloader2.load_state_dict(sd) data1 = vals[2:] @@ -506,7 +510,7 @@ def test_dataloader_dispatcher_state_dict(self, num_workers): dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) Accelerator(dataloader_config=dataloader_config) dataset = list(range(16)) - dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) + dataloader = create_data_loader_dispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) assert dataloader.use_stateful_dataloader assert isinstance(dataloader, StatefulDataLoader) @@ -516,7 +520,7 @@ def test_dataloader_dispatcher_state_dict(self, num_workers): if idx == 1: sd = dataloader.state_dict() assert len(vals) == 4 - dataloader2 = DataLoaderDispatcher( + dataloader2 = create_data_loader_dispatcher( dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers ) dataloader2.load_state_dict(sd) @@ -533,12 +537,12 @@ def test_dataloader_inheritance(self): subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin. 
""" Accelerator() - skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) - dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) - dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) - assert isinstance(skip_dl, StatefulDataLoader) - assert isinstance(dl_shard, StatefulDataLoader) - assert isinstance(dl_dispatcher, StatefulDataLoader) + skip_dl = create_skip_data_loader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) + dl_shard = create_data_loader_shard(range(16), batch_size=4, use_stateful_dataloader=True) + dl_dispatcher = create_data_loader_dispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + assert isinstance(skip_dl, StatefulSkipDataLoader) + assert isinstance(dl_shard, StatefulDataLoaderShard) + assert isinstance(dl_dispatcher, StatefulDataLoaderDispatcher) assert isinstance(dl_shard, DataLoaderStateMixin) assert isinstance(dl_dispatcher, DataLoaderStateMixin) @@ -558,13 +562,13 @@ def g(): accelerator = Accelerator() stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) - skip_dl = SkipDataLoader( + skip_dl = create_skip_data_loader( dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True ) - dl_shard = DataLoaderShard( + dl_shard = create_data_loader_shard( dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True ) - dl_dispatcher = DataLoaderDispatcher( + dl_dispatcher = create_data_loader_dispatcher( dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True ) @@ -606,7 +610,7 @@ def get_first_n_batches(dl, n, device): assert expected_state_dict == dl_dispatcher_state_dict # Load the state dict into new dataloaders - manual_skip_dl = SkipDataLoader( + manual_skip_dl = create_skip_data_loader( dataset, batch_size=4, num_workers=num_workers, @@ -616,15 +620,15 @@ def get_first_n_batches(dl, n, device): ) loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) loaded_stateful_dl.load_state_dict(expected_state_dict) - loaded_skip_dl = SkipDataLoader( + loaded_skip_dl = create_skip_data_loader( dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True ) loaded_skip_dl.load_state_dict(expected_state_dict) - loaded_dl_shard = DataLoaderShard( + loaded_dl_shard = create_data_loader_shard( dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True ) loaded_dl_shard.load_state_dict(expected_state_dict) - loaded_dl_dispatcher = DataLoaderDispatcher( + loaded_dl_dispatcher = create_data_loader_dispatcher( dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True ) loaded_dl_dispatcher.load_state_dict(expected_state_dict)