Skip to content

Commit

Permalink
test_pickle_accelerator
Browse files Browse the repository at this point in the history
  • Loading branch information
byi8220 committed Sep 4, 2024
1 parent 40ec962 commit 1964011
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/accelerate/test_utils/scripts/test_distributed_data_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,14 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches():
assert issubclass(w[-1].category, UserWarning)
assert "only supported for map-style datasets" in str(w[-1].message)

def test_pickle_accelerator():
accelerator = create_accelerator()
data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
_ = accelerator.prepare(data_loader)
pickled_accelerator = pickle.dumps(accelerator)
unpickled_accelerator = pickle.loads(pickled_accelerator)
assert accelerator.state == unpickled_accelerator.state


def test_data_loader(data_loader, accelerator):
# Prepare the DataLoader
Expand Down Expand Up @@ -340,11 +348,6 @@ def test_stateful_dataloader_save_state(accelerator):
finally:
accelerator.dataloader_config = old_dataloader_config

def test_pickled_dataloader(data_loader, accelerator):
# Pickle then reload the dataloader
prepared_model_dumps = pickle.dumps(accelerator)
loaded_accelerator = pickle.loads(prepared_model_dumps)

def main():
accelerator = create_accelerator()
torch.manual_seed(accelerator.process_index)
Expand Down Expand Up @@ -373,6 +376,9 @@ def main():
test_join_raises_warning_for_non_ddp_distributed(accelerator)
accelerator.state.distributed_type = original_state

accelerator.print("Test pickling an accelerator")
test_pickle_accelerator()

dataset = DummyDataset()
# Conventional Dataloader with shuffle=False
loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
Expand All @@ -394,10 +400,6 @@ def main():
test_stateful_dataloader(accelerator)
test_stateful_dataloader_save_state(accelerator)

# Test pickling an accelerator works
iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True)
test_pickled_dataloader(iterable_loader, accelerator)

accelerator.end_training()


Expand Down

0 comments on commit 1964011

Please sign in to comment.