Skip to content

Commit

Permalink
skip_first_batches should be used on raw dls
Browse files Browse the repository at this point in the history
  • Loading branch information
byi8220 committed Sep 3, 2024
1 parent 364a7fd commit 8c4f15a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def test_can_pickle_dataloader(self, dispatch_batches):
"""
Test that pickling a prepared dataloader works.
"""
data = torch.arange(10)
data = torch.arange(10).to(torch_device)
ds = torch.utils.data.TensorDataset(data)
dl = torch.utils.data.DataLoader(ds)
# Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality
Expand All @@ -674,8 +674,8 @@ def test_can_pickle_dataloader(self, dispatch_batches):
assert [i for i in model_loaded._dataloaders[0]] == [i for i in original_dl]

# Test skip dataloader works as expected as well
skip_dl = skip_first_batches(original_dl, 2)
assert isinstance(skip_dl, torch.utils.data.DataLoader)
skip_dl = skip_first_batches(dl, 2)
assert isinstance(skip_dl, DataLoader)
assert len(skip_dl) == len(original_dl) - 2
orig_items = [i for i in original_dl]
skip_dl_items = [i for i in skip_dl]
Expand Down

0 comments on commit 8c4f15a

Please sign in to comment.