Skip to content

Commit

Permalink
uncomment 2
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho committed May 10, 2024
1 parent 7aef39f commit 069ccd7
Showing 1 changed file with 87 additions and 86 deletions.
173 changes: 87 additions & 86 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,92 +132,92 @@ def identity(x):


class TestStatefulDataLoaderIterable(TestCase):
def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
    """Consume `interrupt` batches, snapshot the loader, then verify a fresh
    loader restored from that snapshot yields exactly the remaining batches.

    Args:
        num_workers: DataLoader worker processes (0 = in-process loading).
        batch_size: batch size forwarded to the DataLoader.
        pw: forwarded as ``persistent_workers``.
        interrupt: number of batches to consume before the snapshot;
            ``None`` means consume the whole epoch first.
        every_n_steps: forwarded as ``snapshot_every_n_steps``.
        shuffle: whether the dummy dataset shuffles its data.
    """
    dataset = DummyIterableDataset([0, 100, 37], shuffle=shuffle)
    dl = StatefulDataLoader(
        dataset=dataset,
        num_workers=num_workers,
        batch_size=batch_size,  # fix: was accepted but silently ignored
        collate_fn=identity,
        snapshot_every_n_steps=every_n_steps,
        persistent_workers=pw,
        multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
    )
    # Warm-up pass over a full epoch; its length defines interrupt=None.
    full_epoch = list(dl)

    if interrupt is None:
        # fix: previously read `exp` before it was assigned (NameError).
        interrupt = len(full_epoch)

    it = iter(dl)
    for _ in range(interrupt):
        next(it)

    state_dict = dl.state_dict()
    # Batches expected AFTER the snapshot point.
    exp = list(it)

    # Restore a brand-new loader instance from the snapshot and check it
    # resumes exactly where the first one left off.
    dl = StatefulDataLoader(
        dataset=dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=identity,
        snapshot_every_n_steps=every_n_steps,
        persistent_workers=pw,
        multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
    )
    dl.load_state_dict(state_dict)
    batches = list(dl)

    self.assertEqual(exp, batches)

def test_no_mp(self):
    """Round-trip checkpoint/restore with in-process loading (no workers)."""
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            self._run_and_checkpoint(
                num_workers=0,
                batch_size=batch_size,
                pw=False,
                interrupt=interrupt,
            )

def test_mp_x(self):
    """Round-trip checkpoint/restore with 3 non-persistent workers."""
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            self._run_and_checkpoint(
                num_workers=3,
                batch_size=batch_size,
                pw=False,
                interrupt=interrupt,
            )

def test_mp_pw(self):
    """Round-trip checkpoint/restore with 3 persistent workers."""
    for batch_size in (None, 7):
        for interrupt in (0, 1, 10):
            self._run_and_checkpoint(
                num_workers=3,
                batch_size=batch_size,
                pw=True,
                interrupt=interrupt,
            )

def test_mp_every_n_steps(self):
    """Round-trip checkpoint/restore while varying the snapshot frequency.

    Fix: ``every_n_steps`` was iterated over but never forwarded to
    ``_run_and_checkpoint``, so the loader always used its default of 1 and
    the snapshot-frequency dimension was never actually exercised.
    """
    batch_size = 7
    for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
        self._run_and_checkpoint(
            num_workers=3,
            batch_size=batch_size,
            pw=True,
            interrupt=interrupt,
            every_n_steps=every_n_steps,  # previously dropped on the floor
        )

def test_random_state(self):
    """Checkpoint/restore must also reproduce the dataset's shuffle RNG state."""
    for num_workers in (0, 3):
        for interrupt in (0, 1, 10):
            self._run_and_checkpoint(
                num_workers=num_workers,
                batch_size=7,
                pw=False,
                interrupt=interrupt,
                shuffle=True,
            )
# def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
# dataset = DummyIterableDataset([0, 100, 37], shuffle=shuffle)
# dl = StatefulDataLoader(
# dataset=dataset,
# num_workers=num_workers,
# collate_fn=identity,
# snapshot_every_n_steps=every_n_steps,
# persistent_workers=pw,
# multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
# )
# list(dl)

# if interrupt is None:
# interrupt = len(exp)

# exp = []
# it = iter(dl)
# for _ in range(interrupt):
# next(it)

# state_dict = dl.state_dict()
# for data in it:
# exp.append(data)

# # Restore new instance from state
# batches = []
# dl = StatefulDataLoader(
# dataset=dataset,
# num_workers=num_workers,
# collate_fn=identity,
# snapshot_every_n_steps=every_n_steps,
# persistent_workers=pw,
# multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
# )
# dl.load_state_dict(state_dict)
# for batch in iter(dl):
# batches.append(batch)

# self.assertEqual(exp, batches)

# def test_no_mp(self):
# for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
# self._run_and_checkpoint(
# num_workers=0,
# batch_size=batch_size,
# pw=False,
# interrupt=interrupt,
# )

# def test_mp_x(self):
# for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
# self._run_and_checkpoint(
# num_workers=3,
# batch_size=batch_size,
# pw=False,
# interrupt=interrupt,
# )

# def test_mp_pw(self):
# for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
# self._run_and_checkpoint(
# num_workers=3,
# batch_size=batch_size,
# pw=True,
# interrupt=interrupt,
# )

# def test_mp_every_n_steps(self):
# batch_size = 7
# for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
# self._run_and_checkpoint(
# num_workers=3,
# batch_size=batch_size,
# pw=True,
# interrupt=interrupt,
# )

# def test_random_state(self):
# for num_workers, interrupt in itertools.product([0, 3], [0, 1, 10]):
# self._run_and_checkpoint(
# num_workers=num_workers,
# batch_size=7,
# pw=False,
# interrupt=interrupt,
# shuffle=True,
# )

# class TestStatefulDataLoaderMap(TestCase):
def _run_and_checkpoint3(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
Expand Down Expand Up @@ -1096,6 +1096,7 @@ def test_json_serde_multi_process_map(self):
unittest.main()
# NOTE(review): unittest.main() calls sys.exit() by default, so the psutil
# diagnostics below appear to be dead code unless main(exit=False) is intended
# — confirm. Also, the import belongs at the top of the file per PEP 8.
import psutil

# Debug aid: enumerate any worker processes still alive after the test run.
print("Listing child PIDs")
current_process = psutil.Process()
children = current_process.children(recursive=True)
for child in children:
Expand Down

0 comments on commit 069ccd7

Please sign in to comment.