Skip to content

Commit

Permalink
Add an additional test for fast-state-dict request problem when snaps…
Browse files Browse the repository at this point in the history
…hot_every_n_steps > 1 (#1252)

Summary:
#1251 <- See this PR for more context on the bug. Adds an additional test to check that this still works when snapshot_every_n_steps > 1

### Changes
* New unit test
-
-

Pull Request resolved: #1252

Reviewed By: gokulavasan

Differential Revision: D57135043

Pulled By: andrewkho

fbshipit-source-id: 7119aac2e9773f9dc104d0a9359c0d0cdb304644
  • Loading branch information
andrewkho authored and facebook-github-bot committed May 8, 2024
1 parent c1f8b66 commit 4e70d26
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,9 +853,8 @@ def test_two_dataloaders(self) -> None:


class TestFastStateDictRequest(unittest.TestCase):
def test_fast_state_dict_request(self) -> None:
def _run_test(self, snapshot_every_n_steps, interrupt):
num_workers = 4
interrupt = 11 # because of round robin, this should stop after worker 2
dataset = DummyIterableDataset([25, 25, 25, 25], shuffle=True)

dl = StatefulDataLoader(
Expand All @@ -865,6 +864,7 @@ def test_fast_state_dict_request(self) -> None:
collate_fn=identity,
persistent_workers=True,
multiprocessing_context="forkserver" if IS_MACOS else None,
snapshot_every_n_steps=snapshot_every_n_steps,
)
it = iter(dl)
for _ in range(interrupt):
Expand Down Expand Up @@ -896,6 +896,12 @@ def test_fast_state_dict_request(self) -> None:

self.assertEqual(data, exp)

def test_fast_state_dict_request(self) -> None:
self._run_test(0, 11)

def test_fast_state_dict_request_skip_steps(self) -> None:
self._run_test(17, 19)


class TestJsonSerDe(unittest.TestCase):
def _run_test_iterable(self, num_workers):
Expand Down

0 comments on commit 4e70d26

Please sign in to comment.