Skip to content

Commit

Permalink
isolate some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho committed May 9, 2024
1 parent c3c9bcd commit e7a143a
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ def test_random_state(self):
shuffle=True,
)


class TestStatefulDataLoaderMap(TestCase):
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
# class TestStatefulDataLoaderMap(TestCase):
def _run_and_checkpoint3(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
if num_workers == 0:
return
dataset = DummyMapDataset(100, shuffle=shuffle)
Expand Down Expand Up @@ -272,56 +271,55 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st

self.assertEqual(batches, exp)

def test_no_mp(self):
def test_no_mp3(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint3(
num_workers=0,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_x(self):
def test_mp_x3(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint3(
num_workers=3,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_pw(self):
def test_mp_pw3(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint3(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_mp_every_n_steps(self):
def test_mp_every_n_steps3(self):
batch_size = 7
for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint3(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_random_state(self):
def test_random_state3(self):
for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint3(
num_workers=num_workers,
batch_size=7,
pw=False,
interrupt=interrupt,
shuffle=True,
)


class TestStatefulSampler(TestCase):
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
# class TestStatefulSampler(TestCase):
def _run_and_checkpoint2(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
dataset = DummyMapDataset(100, shuffle=shuffle)
sampler = DummySampler(len(dataset))
dl = StatefulDataLoader(
Expand Down Expand Up @@ -366,46 +364,46 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st

self.assertEqual(batches, exp)

def test_no_mp(self):
def test_no_mp2(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint2(
num_workers=0,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_x(self):
def test_mp_x2(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint2(
num_workers=3,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_pw(self):
def test_mp_pw2(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint2(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_mp_every_n_steps(self):
def test_mp_every_n_steps2(self):
batch_size = 7
for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint2(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_random_state(self):
def test_random_state2(self):
for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
self._run_and_checkpoint(
self._run_and_checkpoint2(
num_workers=num_workers,
batch_size=7,
pw=False,
Expand Down

0 comments on commit e7a143a

Please sign in to comment.