From 0b8d2ecaf7f3e8559712401660019a4553e4a072 Mon Sep 17 00:00:00 2001
From: andrewkh
Date: Thu, 9 May 2024 13:28:04 -0700
Subject: [PATCH] isolate some tests

---
 test/stateful_dataloader/test_state_dict.py | 110 ++++++++++++++++++++-
 1 file changed, 108 insertions(+), 2 deletions(-)

diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py
index f22440bac..a6f67e9c0 100644
--- a/test/stateful_dataloader/test_state_dict.py
+++ b/test/stateful_dataloader/test_state_dict.py
@@ -225,7 +225,7 @@ def test_random_state(self):
         )
 
 
-class TestStatefulDataLoaderMap(TestStatefulDataLoaderIterable):
+class TestStatefulDataLoaderMap(TestCase):
     def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
         if num_workers == 0:
             return
@@ -277,8 +277,61 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
 
         self.assertEqual(batches, exp)
 
+    def test_no_mp(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            with self.subTest(batch_size=batch_size, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=0,
+                    batch_size=batch_size,
+                    pw=False,
+                    interrupt=interrupt,
+                )
+
+    def test_mp_x(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            with self.subTest(batch_size=batch_size, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=3,
+                    batch_size=batch_size,
+                    pw=False,
+                    interrupt=interrupt,
+                )
+
+    def test_mp_pw(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            with self.subTest(batch_size=batch_size, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=3,
+                    batch_size=batch_size,
+                    pw=True,
+                    interrupt=interrupt,
+                )
+
+    def test_mp_every_n_steps(self):
+        batch_size = 7
+        for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
+            with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=3,
+                    batch_size=batch_size,
+                    pw=True,
+                    interrupt=interrupt,
+                    every_n_steps=every_n_steps,
+                )
+
+    def test_random_state(self):
+        for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
+            with self.subTest(num_workers=num_workers, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=num_workers,
+                    batch_size=7,
+                    pw=False,
+                    interrupt=interrupt,
+                    shuffle=True,
+                )
+
 
-class TestStatefulSampler(TestStatefulDataLoaderIterable):
+class TestStatefulSampler(TestCase):
     def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
         dataset = DummyMapDataset(100, shuffle=shuffle)
         sampler = DummySampler(len(dataset))
@@ -324,6 +377,59 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
 
         self.assertEqual(batches, exp)
 
+    def test_no_mp(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            with self.subTest(batch_size=batch_size, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=0,
+                    batch_size=batch_size,
+                    pw=False,
+                    interrupt=interrupt,
+                )
+
+    def test_mp_x(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            with self.subTest(batch_size=batch_size, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=3,
+                    batch_size=batch_size,
+                    pw=False,
+                    interrupt=interrupt,
+                )
+
+    def test_mp_pw(self):
+        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
+            with self.subTest(batch_size=batch_size, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=3,
+                    batch_size=batch_size,
+                    pw=True,
+                    interrupt=interrupt,
+                )
+
+    def test_mp_every_n_steps(self):
+        batch_size = 7
+        for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
+            with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=3,
+                    batch_size=batch_size,
+                    pw=True,
+                    interrupt=interrupt,
+                    every_n_steps=every_n_steps,
+                )
+
+    def test_random_state(self):
+        for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
+            with self.subTest(num_workers=num_workers, interrupt=interrupt):
+                self._run_and_checkpoint(
+                    num_workers=num_workers,
+                    batch_size=7,
+                    pw=False,
+                    interrupt=interrupt,
+                    shuffle=True,
+                )
+
 
 # class GeneratorIterable(torch.utils.data.IterableDataset):
 #     def __init__(self, sizes_for_all_workers):