Skip to content

Commit

Permalink
isolate some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho committed May 9, 2024
1 parent 0b8d2ec commit c3c9bcd
Showing 1 changed file with 93 additions and 108 deletions.
201 changes: 93 additions & 108 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,55 +174,50 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st

def test_no_mp(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=0,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=0,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_x(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_pw(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_mp_every_n_steps(self):
batch_size = 7
for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_random_state(self):
for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
with self.subTest(num_workers=num_workers, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=num_workers,
batch_size=7,
pw=False,
interrupt=interrupt,
shuffle=True,
)
self._run_and_checkpoint(
num_workers=num_workers,
batch_size=7,
pw=False,
interrupt=interrupt,
shuffle=True,
)


class TestStatefulDataLoaderMap(TestCase):
Expand Down Expand Up @@ -279,55 +274,50 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st

def test_no_mp(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=0,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=0,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_x(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_pw(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_mp_every_n_steps(self):
batch_size = 7
for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_random_state(self):
for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
with self.subTest(num_workers=num_workers, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=num_workers,
batch_size=7,
pw=False,
interrupt=interrupt,
shuffle=True,
)
self._run_and_checkpoint(
num_workers=num_workers,
batch_size=7,
pw=False,
interrupt=interrupt,
shuffle=True,
)


class TestStatefulSampler(TestCase):
Expand Down Expand Up @@ -378,55 +368,50 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st

def test_no_mp(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=0,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=0,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_x(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=False,
interrupt=interrupt,
)

def test_mp_pw(self):
for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
with self.subTest(batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_mp_every_n_steps(self):
batch_size = 7
for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)
self._run_and_checkpoint(
num_workers=3,
batch_size=batch_size,
pw=True,
interrupt=interrupt,
)

def test_random_state(self):
for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
with self.subTest(num_workers=num_workers, interrupt=interrupt):
self._run_and_checkpoint(
num_workers=num_workers,
batch_size=7,
pw=False,
interrupt=interrupt,
shuffle=True,
)
self._run_and_checkpoint(
num_workers=num_workers,
batch_size=7,
pw=False,
interrupt=interrupt,
shuffle=True,
)


# class GeneratorIterable(torch.utils.data.IterableDataset):
Expand Down

0 comments on commit c3c9bcd

Please sign in to comment.