Skip to content

Commit

Permalink
pass args through superclass
Browse files Browse the repository at this point in the history
  • Loading branch information
byi8220 committed Sep 4, 2024
1 parent 073d7b3 commit ee681ed
Showing 1 changed file with 8 additions and 32 deletions.
40 changes: 8 additions & 32 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,9 @@ def __iter__(self):
self.end()

def __reduce__(self):
return (
DataLoaderShard,
(
self.base_dataloader.dataset,
self.device,
self.rng_types,
self.synchronized_generator,
self.skip_batches,
self.use_stateful_dataloader,
self._drop_last,
self._non_blocking,
),
self.__dict__,
)
args = super().__reduce__()
return (DataLoaderShard, *args[1:])


def set_epoch(self, epoch: int):
# In case it is manually passed in, the user can set it to what they like
Expand Down Expand Up @@ -881,19 +870,8 @@ def __len__(self):
return math.ceil(whole_length / self.state.num_processes)

def __reduce__(self):
return (
DataLoaderDispatcher,
(
self.base_dataloader.dataset,
self.split_batches,
self.skip_batches,
self.use_stateful_dataloader,
self._drop_last,
self._non_blocking,
self.slice_fn,
),
self.__dict__,
)
args = super().__reduce__()
return (DataLoaderDispatcher, *args[1:])

@property
def total_batch_size(self):
Expand Down Expand Up @@ -1238,11 +1216,9 @@ def __len__(self):
return len(self.base_dataloader) - self.skip_batches

def __reduce__(self):
return (
SkipDataLoader,
(self.base_dataloader.dataset, self.skip_batches, self.use_stateful_dataloader),
self.__dict__,
)
args = super().__reduce__()
return (SkipDataLoader, *args[1:])



def skip_first_batches(dataloader, num_batches=0):
Expand Down

0 comments on commit ee681ed

Please sign in to comment.