From 069ccd7a9e16315949b0134456669c1b418e1fb4 Mon Sep 17 00:00:00 2001
From: andrewkh
Date: Thu, 9 May 2024 17:18:22 -0700
Subject: [PATCH] uncomment 2

---
 test/stateful_dataloader/test_state_dict.py | 173 ++++++++++----------
 1 file changed, 87 insertions(+), 86 deletions(-)

diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py
index d7797d4fd..94d90aa7d 100644
--- a/test/stateful_dataloader/test_state_dict.py
+++ b/test/stateful_dataloader/test_state_dict.py
@@ -132,92 +132,92 @@ def identity(x):
 
 
 class TestStatefulDataLoaderIterable(TestCase):
-    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
-        dataset = DummyIterableDataset([0, 100, 37], shuffle=shuffle)
-        dl = StatefulDataLoader(
-            dataset=dataset,
-            num_workers=num_workers,
-            collate_fn=identity,
-            snapshot_every_n_steps=every_n_steps,
-            persistent_workers=pw,
-            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
-        )
-        exp = list(dl)
-
-        if interrupt is None:
-            interrupt = len(exp)
-
-        exp = []
-        it = iter(dl)
-        for _ in range(interrupt):
-            next(it)
-
-        state_dict = dl.state_dict()
-        for data in it:
-            exp.append(data)
-
-        # Restore new instance from state
-        batches = []
-        dl = StatefulDataLoader(
-            dataset=dataset,
-            num_workers=num_workers,
-            collate_fn=identity,
-            snapshot_every_n_steps=every_n_steps,
-            persistent_workers=pw,
-            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
-        )
-        dl.load_state_dict(state_dict)
-        for batch in iter(dl):
-            batches.append(batch)
-
-        self.assertEqual(exp, batches)
-
-    def test_no_mp(self):
-        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
-            self._run_and_checkpoint(
-                num_workers=0,
-                batch_size=batch_size,
-                pw=False,
-                interrupt=interrupt,
-            )
-
-    def test_mp_x(self):
-        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
-            self._run_and_checkpoint(
-                num_workers=3,
-                batch_size=batch_size,
-                pw=False,
-                interrupt=interrupt,
-            )
-
-    def test_mp_pw(self):
-        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
-            self._run_and_checkpoint(
-                num_workers=3,
-                batch_size=batch_size,
-                pw=True,
-                interrupt=interrupt,
-            )
-
-    def test_mp_every_n_steps(self):
-        batch_size = 7
-        for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
-            self._run_and_checkpoint(
-                num_workers=3,
-                batch_size=batch_size,
-                pw=True,
-                interrupt=interrupt,
-            )
-
-    def test_random_state(self):
-        for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
-            self._run_and_checkpoint(
-                num_workers=num_workers,
-                batch_size=7,
-                pw=False,
-                interrupt=interrupt,
-                shuffle=True,
-            )
+    # def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+    #     dataset = DummyIterableDataset([0, 100, 37], shuffle=shuffle)
+    #     dl = StatefulDataLoader(
+    #         dataset=dataset,
+    #         num_workers=num_workers,
+    #         collate_fn=identity,
+    #         snapshot_every_n_steps=every_n_steps,
+    #         persistent_workers=pw,
+    #         multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+    #     )
+    #     exp = list(dl)
+
+    #     if interrupt is None:
+    #         interrupt = len(exp)
+
+    #     exp = []
+    #     it = iter(dl)
+    #     for _ in range(interrupt):
+    #         next(it)
+
+    #     state_dict = dl.state_dict()
+    #     for data in it:
+    #         exp.append(data)
+
+    #     # Restore new instance from state
+    #     batches = []
+    #     dl = StatefulDataLoader(
+    #         dataset=dataset,
+    #         num_workers=num_workers,
+    #         collate_fn=identity,
+    #         snapshot_every_n_steps=every_n_steps,
+    #         persistent_workers=pw,
+    #         multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+    #     )
+    #     dl.load_state_dict(state_dict)
+    #     for batch in iter(dl):
+    #         batches.append(batch)
+
+    #     self.assertEqual(exp, batches)
+
+    # def test_no_mp(self):
+    #     for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+    #         self._run_and_checkpoint(
+    #             num_workers=0,
+    #             batch_size=batch_size,
+    #             pw=False,
+    #             interrupt=interrupt,
+    #         )
+
+    # def test_mp_x(self):
+    #     for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+    #         self._run_and_checkpoint(
+    #             num_workers=3,
+    #             batch_size=batch_size,
+    #             pw=False,
+    #             interrupt=interrupt,
+    #         )
+
+    # def test_mp_pw(self):
+    #     for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+    #         self._run_and_checkpoint(
+    #             num_workers=3,
+    #             batch_size=batch_size,
+    #             pw=True,
+    #             interrupt=interrupt,
+    #         )
+
+    # def test_mp_every_n_steps(self):
+    #     batch_size = 7
+    #     for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
+    #         self._run_and_checkpoint(
+    #             num_workers=3,
+    #             batch_size=batch_size,
+    #             pw=True,
+    #             interrupt=interrupt,
+    #         )
+
+    # def test_random_state(self):
+    #     for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
+    #         self._run_and_checkpoint(
+    #             num_workers=num_workers,
+    #             batch_size=7,
+    #             pw=False,
+    #             interrupt=interrupt,
+    #             shuffle=True,
+    #         )
 
 # class TestStatefulDataLoaderMap(TestCase):
     def _run_and_checkpoint3(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
@@ -1096,6 +1096,7 @@ def test_json_serde_multi_process_map(self):
     unittest.main()
 
     import psutil
+    print("Listing child PIDs")
     current_process = psutil.Process()
     children = current_process.children(recursive=True)
     for child in children:
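
Note for reviewers: the round-trip that the temporarily disabled tests above exercise is the
StatefulDataLoader snapshot/resume contract (consume part of an epoch, call state_dict(), and
expect a freshly restored loader to yield exactly the remaining batches). A minimal sketch of
that contract, using a plain list as a stand-in dataset rather than the DummyIterableDataset
from the test file (batch_size and the consumed-batch count are illustrative only):

    from torchdata.stateful_dataloader import StatefulDataLoader

    def identity(x):
        return x

    dataset = list(range(100))          # stand-in dataset, illustrative only
    dl = StatefulDataLoader(dataset, batch_size=4, collate_fn=identity)

    it = iter(dl)
    for _ in range(3):                  # consume a few batches...
        next(it)
    state_dict = dl.state_dict()        # ...then snapshot mid-epoch
    remaining = list(it)                # batches yielded after the snapshot

    # A fresh loader restored from the snapshot should yield exactly the
    # remaining batches, which is what _run_and_checkpoint asserts.
    dl2 = StatefulDataLoader(dataset, batch_size=4, collate_fn=identity)
    dl2.load_state_dict(state_dict)
    assert list(dl2) == remaining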