Skip to content

Commit

Permalink
isolate some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho committed May 9, 2024
1 parent fb6b709 commit 0b8d2ec
Showing 1 changed file with 106 additions and 2 deletions.
108 changes: 106 additions & 2 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_random_state(self):
)


class TestStatefulDataLoaderMap(TestStatefulDataLoaderIterable):
class TestStatefulDataLoaderMap(TestCase):
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
if num_workers == 0:
return
Expand Down Expand Up @@ -277,8 +277,60 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st

self.assertEqual(batches, exp)

def test_no_mp(self):
    """Run checkpoint/restore with in-process loading (num_workers=0).

    Sweeps batch_size (unbatched vs. 7) against interrupt points
    (0, 1, 10 steps), one subTest per combination.
    """
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=0,
                    batch_size=batch_size,
                    pw=False,
                    interrupt=interrupt,
                )

def test_mp_x(self):
    """Run checkpoint/restore with 3 non-persistent worker processes.

    Sweeps batch_size (unbatched vs. 7) against interrupt points
    (0, 1, 10 steps), one subTest per combination.
    """
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=3,
                    batch_size=batch_size,
                    pw=False,
                    interrupt=interrupt,
                )

def test_mp_pw(self):
    """Run checkpoint/restore with 3 persistent workers (pw=True).

    Sweeps batch_size (unbatched vs. 7) against interrupt points
    (0, 1, 10 steps), one subTest per combination.
    """
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=3,
                    batch_size=batch_size,
                    pw=True,
                    interrupt=interrupt,
                )

def test_mp_every_n_steps(self):
    """Run checkpoint/restore snapshotting every N steps with 3 persistent workers.

    Sweeps every_n_steps (2, 5) against interrupt points (0, 1, 10 steps).

    Fix: ``every_n_steps`` was iterated and reported in the subTest label
    but never forwarded to ``_run_and_checkpoint``, so every iteration
    silently ran with the default frequency (1) and the parameterization
    was a no-op. It is now passed through.
    """
    batch_size = 7
    for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
        with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
            self._run_and_checkpoint(
                num_workers=3,
                batch_size=batch_size,
                pw=True,
                interrupt=interrupt,
                every_n_steps=every_n_steps,
            )

def test_random_state(self):
    """Run checkpoint/restore with shuffling enabled to exercise RNG state.

    Sweeps num_workers (0 = in-process, 3 = multiprocess) against
    interrupt points (0, 1, 10 steps); shuffle=True so the restored run
    must reproduce the interrupted run's random ordering.
    """
    for num_workers in (0, 3):
        for interrupt in (0, 1, 10):
            with self.subTest(num_workers=num_workers, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=num_workers,
                    batch_size=7,
                    pw=False,
                    interrupt=interrupt,
                    shuffle=True,
                )


class TestStatefulSampler(TestStatefulDataLoaderIterable):
class TestStatefulSampler(TestCase):
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
dataset = DummyMapDataset(100, shuffle=shuffle)
sampler = DummySampler(len(dataset))
Expand Down Expand Up @@ -324,6 +376,58 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st

self.assertEqual(batches, exp)

def test_no_mp(self):
    """Run sampler checkpoint/restore with in-process loading (num_workers=0).

    Sweeps batch_size (unbatched vs. 7) against interrupt points
    (0, 1, 10 steps), one subTest per combination.
    """
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=0,
                    batch_size=batch_size,
                    pw=False,
                    interrupt=interrupt,
                )

def test_mp_x(self):
    """Run sampler checkpoint/restore with 3 non-persistent worker processes.

    Sweeps batch_size (unbatched vs. 7) against interrupt points
    (0, 1, 10 steps), one subTest per combination.
    """
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=3,
                    batch_size=batch_size,
                    pw=False,
                    interrupt=interrupt,
                )

def test_mp_pw(self):
    """Run sampler checkpoint/restore with 3 persistent workers (pw=True).

    Sweeps batch_size (unbatched vs. 7) against interrupt points
    (0, 1, 10 steps), one subTest per combination.
    """
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=3,
                    batch_size=batch_size,
                    pw=True,
                    interrupt=interrupt,
                )

def test_mp_every_n_steps(self):
    """Run sampler checkpoint/restore snapshotting every N steps with 3 persistent workers.

    Sweeps every_n_steps (2, 5) against interrupt points (0, 1, 10 steps).

    Fix: ``every_n_steps`` was iterated and reported in the subTest label
    but never forwarded to ``_run_and_checkpoint``, so every iteration
    silently ran with the default frequency (1) and the parameterization
    was a no-op. It is now passed through.
    """
    batch_size = 7
    for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
        with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
            self._run_and_checkpoint(
                num_workers=3,
                batch_size=batch_size,
                pw=True,
                interrupt=interrupt,
                every_n_steps=every_n_steps,
            )

def test_random_state(self):
    """Run sampler checkpoint/restore with shuffling enabled to exercise RNG state.

    Sweeps num_workers (0 = in-process, 3 = multiprocess) against
    interrupt points (0, 1, 10 steps); shuffle=True so the restored run
    must reproduce the interrupted run's random ordering.
    """
    for num_workers in (0, 3):
        for interrupt in (0, 1, 10):
            with self.subTest(num_workers=num_workers, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=num_workers,
                    batch_size=7,
                    pw=False,
                    interrupt=interrupt,
                    shuffle=True,
                )


# class GeneratorIterable(torch.utils.data.IterableDataset):
# def __init__(self, sizes_for_all_workers):
Expand Down

0 comments on commit 0b8d2ec

Please sign in to comment.