From 1964011ee4723e459f7c5d3fc125967db1d27b68 Mon Sep 17 00:00:00 2001
From: byi8220
Date: Tue, 3 Sep 2024 20:51:31 -0400
Subject: [PATCH] test_pickle_accelerator

---
 .../scripts/test_distributed_data_loop.py    | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
index 31338561249..b6aefe0b4fd 100644
--- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
+++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
@@ -247,6 +247,14 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches():
         assert issubclass(w[-1].category, UserWarning)
         assert "only supported for map-style datasets" in str(w[-1].message)
 
 
+def test_pickle_accelerator():
+    accelerator = create_accelerator()
+    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
+    _ = accelerator.prepare(data_loader)
+    pickled_accelerator = pickle.dumps(accelerator)
+    unpickled_accelerator = pickle.loads(pickled_accelerator)
+    assert accelerator.state == unpickled_accelerator.state
+
 def test_data_loader(data_loader, accelerator):
     # Prepare the DataLoader
@@ -340,11 +348,6 @@ def test_stateful_dataloader_save_state(accelerator):
     finally:
         accelerator.dataloader_config = old_dataloader_config
 
-def test_pickled_dataloader(data_loader, accelerator):
-    # Pickle then reload the dataloader
-    prepared_model_dumps = pickle.dumps(accelerator)
-    loaded_accelerator = pickle.loads(prepared_model_dumps)
-
 def main():
     accelerator = create_accelerator()
     torch.manual_seed(accelerator.process_index)
@@ -373,6 +376,9 @@ def main():
         test_join_raises_warning_for_non_ddp_distributed(accelerator)
         accelerator.state.distributed_type = original_state
 
+    accelerator.print("Test pickling an accelerator")
+    test_pickle_accelerator()
+
     dataset = DummyDataset()
     # Conventional Dataloader with shuffle=False
     loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
@@ -394,10 +400,6 @@ def main():
         test_stateful_dataloader(accelerator)
         test_stateful_dataloader_save_state(accelerator)
 
-    # Test pickling an accelerator works
-    iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True)
-    test_pickled_dataloader(iterable_loader, accelerator)
-
     accelerator.end_training()
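
For context, a minimal sketch of the round trip the new test exercises, outside the distributed test harness, is shown below. It assumes accelerate and torch are installed; the test script's helpers create_accelerator and create_dataloader are replaced here with a plain Accelerator() and a small TensorDataset, so this is an illustrative example rather than the repository's own code.

    # Illustrative sketch (not part of the patch): pickle round-trip of an Accelerator
    # after preparing a DataLoader, mirroring what test_pickle_accelerator asserts.
    import pickle

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from accelerate import Accelerator

    accelerator = Accelerator()  # stands in for create_accelerator() in the test script
    dataset = TensorDataset(torch.arange(32).float())  # stands in for the test's dummy dataset
    data_loader = accelerator.prepare(DataLoader(dataset, batch_size=4))

    pickled = pickle.dumps(accelerator)         # should succeed once Accelerator is picklable
    restored = pickle.loads(pickled)
    assert accelerator.state == restored.state  # same comparison as the new test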