
save progress
thevasudevgupta committed Aug 14, 2023 · 1 parent 4d13e4e · commit 1ecfb16
Showing 2 changed files with 12 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/accelerate/accelerator.py
@@ -258,6 +258,7 @@ def __init__(
         step_scheduler_with_optimizer: bool = True,
         kwargs_handlers: list[KwargsHandler] | None = None,
         dynamo_backend: DynamoBackend | str | None = None,
+        slice_fn: Callable | None = None,
     ):
         if project_config is not None:
             self.project_configuration = project_config
@@ -462,6 +463,8 @@ def __init__(
         if self.rng_types is None:
             self.rng_types = ["generator"]
 
+        self.slice_fn = slice_fn
+
     @property
     def use_distributed(self):
         """
@@ -1807,6 +1810,7 @@ def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader, device_p
             rng_types=self.rng_types.copy(),
             dispatch_batches=self.dispatch_batches,
             even_batches=self.even_batches,
+            slice_fn=self.slice_fn,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
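Taken together, the accelerator.py changes accept a user-supplied slice_fn and forward it to prepare_data_loader. Judging from the call sites added in data_loader.py below, the callable receives the gathered batch, the current process index, and the number of processes, and returns that process's share. A minimal sketch of passing one (my_slice_fn and its strided strategy are hypothetical, not part of this commit, and this assumes dispatch_batches is enabled on this version of Accelerator so that DataLoaderDispatcher is used):

    from accelerate import Accelerator

    def my_slice_fn(batch, process_index, num_processes):
        # Hypothetical strategy: give each rank a strided view of the gathered
        # batch instead of the default contiguous chunk. Assumes `batch` is a
        # single indexable tensor; dict batches would need per-key slicing.
        return batch[process_index::num_processes]

    accelerator = Accelerator(dispatch_batches=True, slice_fn=my_slice_fn)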
12 changes: 8 additions & 4 deletions src/accelerate/data_loader.py
@@ -14,7 +14,7 @@
 
 import math
 from contextlib import suppress
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Callable
 
 import torch
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset
@@ -485,7 +485,7 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
         - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
     """
 
-    def __init__(self, dataset, split_batches: bool = False, skip_batches=0, _drop_last: bool = False, **kwargs):
+    def __init__(self, dataset, split_batches: bool = False, skip_batches=0, _drop_last: bool = False, slice_fn=None, **kwargs):
         shuffle = False
         if is_torch_version(">=", "1.11.0"):
             from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
@@ -503,6 +503,8 @@ def __init__(self, dataset, split_batches: bool = False, skip_batches=0, _drop_l
         self._drop_last = _drop_last
         self.skip_batches = skip_batches
 
+        self.slice_fn = slice_fn
+
    def _fetch_batches(self, iterator):
        batches, batch = None, None
        # On process 0, we gather the batch to dispatch.
@@ -567,7 +569,7 @@ def __iter__(self):
 
             if not self._drop_last and first_batch is None:
                 # We keep at least num processes elements of the first batch to be able to complete the last batch
-                first_batch = slice_tensors(batch, slice(0, self.state.num_processes))
+                first_batch = slice_tensors(batch, slice(0, self.state.num_processes)) if self.slice_fn is None else self.slice_fn(batch, self.state.process_index, self.state.num_processes)
 
             if batch is None:
                 raise ValueError(
@@ -593,7 +595,7 @@
                     batch_size += 1
 
             data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
-            batch = slice_tensors(batch, data_slice)
+            batch = slice_tensors(batch, data_slice) if self.slice_fn is None else self.slice_fn(batch, self.state.process_index, self.state.num_processes)
 
             if stop_iteration:
                 self.end_of_dataloader = True
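Both call sites in __iter__ fall back to slice_tensors only when slice_fn is None, so a custom callable must reproduce the whole contract itself: it receives only (batch, process_index, num_processes) and has to derive any per-process batch size internally. A sketch of a slice_fn equivalent to the default contiguous split (an assumption, not part of this commit; it relies on find_batch_size and slice_tensors being importable from accelerate.utils, and ignores the uneven-last-batch adjustment the default path performs):

    from accelerate.utils import find_batch_size, slice_tensors

    def contiguous_slice_fn(batch, process_index, num_processes):
        # Recompute the per-process share from the gathered batch, since the
        # callable is not handed a precomputed batch_size like the default path.
        batch_size = find_batch_size(batch) // num_processes
        data_slice = slice(process_index * batch_size, (process_index + 1) * batch_size)
        return slice_tensors(batch, data_slice)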
@@ -633,6 +635,7 @@ def prepare_data_loader(
     rng_types: Optional[List[Union[str, RNGType]]] = None,
     dispatch_batches: Optional[bool] = None,
     even_batches: bool = True,
+    slice_fn: Optional[Callable] = None,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -786,6 +789,7 @@
             split_batches=split_batches,
             batch_sampler=new_batch_sampler,
             _drop_last=dataloader.drop_last,
+            slice_fn=slice_fn,
             **kwargs,
         )
     elif sampler_is_batch_sampler:
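At the prepare_data_loader level, the new argument is only forwarded in the dispatch_batches branch, since DataLoaderDispatcher is its sole consumer. A usage sketch (single-process illustration under assumed defaults; the lambda and the toy dataset are hypothetical):

    import torch
    from accelerate.data_loader import prepare_data_loader

    loader = torch.utils.data.DataLoader(list(range(64)), batch_size=8)
    prepared = prepare_data_loader(
        loader,
        dispatch_batches=True,
        slice_fn=lambda batch, process_index, num_processes: batch[process_index::num_processes],
    )
    for batch in prepared:
        ...  # each rank sees whatever slice_fn returned for it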
