diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index c7823b1c5d2..1660adf13cd 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -340,7 +340,7 @@ def test_stateful_dataloader_save_state(accelerator): finally: accelerator.dataloader_config = old_dataloader_config -def test_pickled_dataloader(accelerator): +def test_pickled_dataloader(data_loader, accelerator): # Prepare the DataLoader data_loader = accelerator.prepare(data_loader) # Pickle then reload the dataloader @@ -412,7 +412,8 @@ def main(): # Dataloader after pickling loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) - test_pickled_dataloader(accelerator) + test_pickled_dataloader(loader, accelerator) + accelerator.end_training()