From 1ecfb162e44d8c7fae27426ce7a6e1079825f1ff Mon Sep 17 00:00:00 2001
From: Vasudev Gupta <7vasudevgupta@gmail.com>
Date: Mon, 14 Aug 2023 07:36:57 +0000
Subject: [PATCH] save progress

---
 src/accelerate/accelerator.py |  4 ++++
 src/accelerate/data_loader.py | 12 ++++++++----
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index b6382e741f0..13bc0e732dc 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -258,6 +258,7 @@ def __init__(
         step_scheduler_with_optimizer: bool = True,
         kwargs_handlers: list[KwargsHandler] | None = None,
         dynamo_backend: DynamoBackend | str | None = None,
+        slice_fn: Callable | None = None,
     ):
         if project_config is not None:
             self.project_configuration = project_config
@@ -462,6 +463,8 @@ def __init__(
         if self.rng_types is None:
             self.rng_types = ["generator"]
 
+        self.slice_fn = slice_fn
+
     @property
     def use_distributed(self):
         """
@@ -1807,6 +1810,7 @@ def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader, device_p
             rng_types=self.rng_types.copy(),
             dispatch_batches=self.dispatch_batches,
             even_batches=self.even_batches,
+            slice_fn=self.slice_fn,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 86da333cd40..412543e85aa 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -14,7 +14,7 @@
 
 import math
 from contextlib import suppress
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Callable
 
 import torch
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset
@@ -485,7 +485,7 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
         - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
     """
 
-    def __init__(self, dataset, split_batches: bool = False, skip_batches=0, _drop_last: bool = False, **kwargs):
+    def __init__(self, dataset, split_batches: bool = False, skip_batches=0, _drop_last: bool = False, slice_fn=None, **kwargs):
         shuffle = False
         if is_torch_version(">=", "1.11.0"):
             from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
@@ -503,6 +503,8 @@ def __init__(self, dataset, split_batches: bool = False, skip_batches=0, _drop_l
         self._drop_last = _drop_last
         self.skip_batches = skip_batches
 
+        self.slice_fn = slice_fn
+
     def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
@@ -567,7 +569,7 @@ def __iter__(self):
 
             if not self._drop_last and first_batch is None:
                 # We keep at least num processes elements of the first batch to be able to complete the last batch
-                first_batch = slice_tensors(batch, slice(0, self.state.num_processes))
+                first_batch = slice_tensors(batch, slice(0, self.state.num_processes)) if self.slice_fn is None else self.slice_fn(batch, self.state.process_index, self.state.num_processes)
 
             if batch is None:
                 raise ValueError(
@@ -593,7 +595,7 @@ def __iter__(self):
                 batch_size += 1
 
             data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
-            batch = slice_tensors(batch, data_slice)
+            batch = slice_tensors(batch, data_slice) if self.slice_fn is None else self.slice_fn(batch, self.state.process_index, self.state.num_processes)
 
             if stop_iteration:
                 self.end_of_dataloader = True
@@ -633,6 +635,7 @@ def prepare_data_loader(
     rng_types: Optional[List[Union[str, RNGType]]] = None,
     dispatch_batches: Optional[bool] = None,
     even_batches: bool = True,
+    slice_fn: Optional[Callable] = None,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -786,6 +789,7 @@ def prepare_data_loader(
             split_batches=split_batches,
             batch_sampler=new_batch_sampler,
             _drop_last=dataloader.drop_last,
+            slice_fn=slice_fn,
             **kwargs,
         )
     elif sampler_is_batch_sampler:
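
Below is a minimal usage sketch (not part of the patch) of the `slice_fn` hook the diff above introduces. The `(batch, process_index, num_processes)` call signature is taken from the hunks in `data_loader.py`; the dataset, the batch layout, and the `shard_batch` helper are illustrative assumptions, and the hook only takes effect when batches are dispatched from process 0 (`dispatch_batches=True`).

# Illustrative only: assumes batches are lists/tuples of tensors, as produced by
# a TensorDataset with the default collate function. `shard_batch` is a made-up
# name for this sketch, not something defined in the repository.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator


def shard_batch(batch, process_index, num_processes):
    # Give each process a contiguous, equally sized shard of the dispatched batch.
    per_process = batch[0].shape[0] // num_processes
    start = process_index * per_process
    end = start + per_process
    return type(batch)(t[start:end] for t in batch)


accelerator = Accelerator(dispatch_batches=True, slice_fn=shard_batch)

dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=16))

for inputs, labels in dataloader:
    # Each process sees only the slice returned by shard_batch
    # (e.g. 8 of the 16 dispatched examples when running on 2 processes).
    pass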