diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 7e6b48a31..e92fa4a85 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -853,9 +853,8 @@ def test_two_dataloaders(self) -> None: class TestFastStateDictRequest(unittest.TestCase): - def test_fast_state_dict_request(self) -> None: + def _run_test(self, snapshot_every_n_steps, interrupt): num_workers = 4 - interrupt = 11 # because of round robin, this should stop after worker 2 dataset = DummyIterableDataset([25, 25, 25, 25], shuffle=True) dl = StatefulDataLoader( @@ -865,6 +864,7 @@ def test_fast_state_dict_request(self) -> None: collate_fn=identity, persistent_workers=True, multiprocessing_context="forkserver" if IS_MACOS else None, + snapshot_every_n_steps=snapshot_every_n_steps, ) it = iter(dl) for _ in range(interrupt): @@ -896,6 +896,12 @@ def test_fast_state_dict_request(self) -> None: self.assertEqual(data, exp) + def test_fast_state_dict_request(self) -> None: + self._run_test(0, 11) + + def test_fast_state_dict_request_skip_steps(self) -> None: + self._run_test(17, 19) + class TestJsonSerDe(unittest.TestCase): def _run_test_iterable(self, num_workers):