diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 42116552d..c3d378890 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -219,9 +219,8 @@ def test_random_state(self): shuffle=True, ) - -class TestStatefulDataLoaderMap(TestCase): - def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): + # class TestStatefulDataLoaderMap(TestCase): + def _run_and_checkpoint3(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): if num_workers == 0: return dataset = DummyMapDataset(100, shuffle=shuffle) @@ -272,46 +271,46 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st self.assertEqual(batches, exp) - def test_no_mp(self): + def test_no_mp3(self): for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint3( num_workers=0, batch_size=batch_size, pw=False, interrupt=interrupt, ) - def test_mp_x(self): + def test_mp_x3(self): for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint3( num_workers=3, batch_size=batch_size, pw=False, interrupt=interrupt, ) - def test_mp_pw(self): + def test_mp_pw3(self): for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint3( num_workers=3, batch_size=batch_size, pw=True, interrupt=interrupt, ) - def test_mp_every_n_steps(self): + def test_mp_every_n_steps3(self): batch_size = 7 for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint3( num_workers=3, batch_size=batch_size, pw=True, interrupt=interrupt, ) - def test_random_state(self): + def test_random_state3(self): for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint3( num_workers=num_workers, batch_size=7, pw=False, @@ -319,9 +318,8 @@ def test_random_state(self): shuffle=True, ) - -class TestStatefulSampler(TestCase): - def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): + # class TestStatefulSampler(TestCase): + def _run_and_checkpoint2(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): dataset = DummyMapDataset(100, shuffle=shuffle) sampler = DummySampler(len(dataset)) dl = StatefulDataLoader( @@ -366,46 +364,46 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st self.assertEqual(batches, exp) - def test_no_mp(self): + def test_no_mp2(self): for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint2( num_workers=0, batch_size=batch_size, pw=False, interrupt=interrupt, ) - def test_mp_x(self): + def test_mp_x2(self): for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint2( num_workers=3, batch_size=batch_size, pw=False, interrupt=interrupt, ) - def test_mp_pw(self): + def test_mp_pw2(self): for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint2( num_workers=3, batch_size=batch_size, pw=True, interrupt=interrupt, ) - def test_mp_every_n_steps(self): + def test_mp_every_n_steps2(self): batch_size = 7 for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint2( num_workers=3, batch_size=batch_size, pw=True, interrupt=interrupt, ) - def test_random_state(self): + def test_random_state2(self): for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]): - self._run_and_checkpoint( + self._run_and_checkpoint2( num_workers=num_workers, batch_size=7, pw=False,