Skip to content

Commit

Permalink
pickling generator issues
Browse files Browse the repository at this point in the history
  • Loading branch information
byi8220 committed Sep 4, 2024
1 parent b43c85d commit 40ec962
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __iter__(self):
def create_accelerator(even_batches=True):
dataloader_config = DataLoaderConfiguration(even_batches=even_batches)
accelerator = Accelerator(dataloader_config=dataloader_config)
# assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
return accelerator


Expand Down Expand Up @@ -341,8 +341,6 @@ def test_stateful_dataloader_save_state(accelerator):
accelerator.dataloader_config = old_dataloader_config

def test_pickled_dataloader(data_loader, accelerator):
# Prepare the DataLoader
data_loader = accelerator.prepare(data_loader)
# Pickle then reload the dataloader
prepared_model_dumps = pickle.dumps(accelerator)
loaded_accelerator = pickle.loads(prepared_model_dumps)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def test_can_pickle_dataloader(self, dispatch_batches):
# TODO: Add support for pickling StatefulDataLoader
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False)
accelerator = Accelerator(dataloader_config=dataloader_config)
torch.manual_seed(accelerator.process_index)

original_dl, prepared_skip_dl = accelerator.prepare(dl, skip_dl)
prepared_model_dumps = pickle.dumps(accelerator)

Expand Down

0 comments on commit 40ec962

Please sign in to comment.