From dff2666282ff28c9de8ae5a97f1fe5aae2eb324e Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 18:13:04 -0400 Subject: [PATCH] fix typo --- .../test_utils/scripts/test_distributed_data_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index b32eb52b435..8b306617468 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -411,7 +411,8 @@ def main(): test_stateful_dataloader_save_state(accelerator) # Dataloader after pickling - loader = DataLoader(DummyIterableDataset(), shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) + iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True) + loader = DataLoader(DummyIterableDataset(iterable_loader), shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) test_pickled_dataloader(loader, accelerator) accelerator.end_training()