-
Notifications
You must be signed in to change notification settings - Fork 152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Save state in dataset_iter_state when dataset is also an iterator #1279
Changes from 8 commits
fbf6913
cd20fca
18ac5ce
f5db1fb
d0df5d4
0a4432a
35b72a3
2fb7597
3641c8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1285,7 +1285,7 @@ def test(self): | |||||||||
for _ in range((num_workers + 1) * 2): | ||||||||||
next(it) | ||||||||||
state_dict = dl.state_dict() | ||||||||||
worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["dataset_state"] | ||||||||||
worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"] | ||||||||||
self.assertEqual(len(worker_state), 7) | ||||||||||
deep_copy_state_dict = deepcopy(state_dict) | ||||||||||
|
||||||||||
|
@@ -1295,7 +1295,9 @@ def test(self): | |||||||||
next_state_dict = dl.state_dict() | ||||||||||
self.assertEqual(state_dict, deep_copy_state_dict) | ||||||||||
self.assertFalse(state_dict == next_state_dict) | ||||||||||
worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["dataset_state"] | ||||||||||
worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"][ | ||||||||||
"dataset_iter_state" | ||||||||||
] | ||||||||||
self.assertEqual(len(worker_state), 11) | ||||||||||
|
||||||||||
dl = StatefulDataLoader( | ||||||||||
|
@@ -1311,7 +1313,7 @@ def test(self): | |||||||||
exp.extend(next(it)) | ||||||||||
state_dict = dl.state_dict() | ||||||||||
self.assertEqual(exp, [3, 3]) | ||||||||||
worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["dataset_state"] | ||||||||||
worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"] | ||||||||||
self.assertEqual(len(worker_state), 9) | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -1334,16 +1336,15 @@ def test(self): | |||||||||
if num_workers > 0: | ||||||||||
for i in range(num_workers): | ||||||||||
# Ensure worker state is stored only once if the dataset is also the iterator | ||||||||||
self.assertTrue(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"]) | ||||||||||
self.assertEqual( | ||||||||||
self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None) | ||||||||||
self.assertTrue( | ||||||||||
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][ | ||||||||||
"dataset_iter_state" | ||||||||||
], | ||||||||||
None, | ||||||||||
] | ||||||||||
) | ||||||||||
else: | ||||||||||
self.assertTrue(state_dict["dataset_state"]) | ||||||||||
self.assertEqual(state_dict["fetcher_state"]["dataset_iter_state"], None) | ||||||||||
self.assertEqual(state_dict["dataset_state"], None) | ||||||||||
self.assertTrue(state_dict["fetcher_state"]["dataset_iter_state"]) | ||||||||||
|
||||||||||
|
||||||||||
class PeriodicStateIterableDataset(torch.utils.data.IterableDataset): | ||||||||||
|
@@ -1511,25 +1512,26 @@ def load_state_dict(self, state_dict): | |||||||||
class CountIterCallsIter(torch.utils.data.IterableDataset):
    """Iterable dataset that is its own iterator and counts ``__iter__`` calls.

    ``__iter__`` returns ``self``, so the dataset object and the iterator are
    the same object. The reported ``iter_calls`` is the number of ``__iter__``
    calls made on this instance plus whatever count was restored through
    ``load_state_dict``.
    """

    def __init__(self, length):
        """Create a dataset yielding ``0 .. length - 1`` once iterated."""
        self.length = length
        # __iter__ calls observed on this instance.
        self.local_iter_calls = 0
        # Count carried over from a state passed to load_state_dict.
        self.prev_state_iter_calls = 0
        self.items = []

    def __iter__(self):
        """Refill the pending items, bump the call counter and return ``self``."""
        self.items = list(range(self.length))
        self.local_iter_calls += 1
        return self

    def __next__(self):
        """Return the next pending item; raise ``StopIteration`` when none remain."""
        if not self.items:
            raise StopIteration
        # ``items`` is a plain list: pop from the front and return the value.
        return self.items.pop(0)

    def state_dict(self):
        """Return the total iter-call count and a copy of the pending items."""
        return {
            "iter_calls": self.local_iter_calls + self.prev_state_iter_calls,
            "items": deepcopy(self.items),
        }

    def load_state_dict(self, state_dict):
        """Restore the pending items and remember the previously recorded count."""
        self.prev_state_iter_calls = state_dict["iter_calls"]
        self.items = state_dict["items"]
|
||||||||||
|
||||||||||
|
@@ -1540,26 +1542,52 @@ def _get_iter_calls(self, state): | |||||||||
else: | ||||||||||
w_states = list(state["_snapshot"]["_worker_snapshots"].values()) | ||||||||||
|
||||||||||
return [x["dataset_state"]["iter_calls"] for x in w_states] | ||||||||||
if w_states[0]["dataset_state"] is not None: | ||||||||||
return [x["dataset_state"]["iter_calls"] for x in w_states] | ||||||||||
return [x["fetcher_state"]["dataset_iter_state"]["iter_calls"] for x in w_states] | ||||||||||
|
||||||||||
def _run_test(self, num_workers, dataset):
    """Check how often the dataset's ``__iter__`` runs for fresh and resumed loaders.

    Args:
        num_workers: Worker count passed to ``StatefulDataLoader``; ``0`` runs
            the dataset in the current process.
        dataset: Dataset whose state dict exposes an ``"iter_calls"`` counter
            (read back through ``self._get_iter_calls``).
    """
    # One counter is reported per worker, or a single one when running inline.
    num_states = max(1, num_workers)

    def expected(calls):
        return [calls] * num_states

    # Need to make a copy here as iter calls is tracked and its count is stored in
    # dataset state which persists across calls for single process runs.
    # NOTE(review): a reviewer objected to this num_workers-specific branch: the test
    # should not behave differently for different num_workers, and a test modified
    # just to pass is not a good test.
    dataset_copy = deepcopy(dataset) if num_workers == 0 else dataset

    dl = StatefulDataLoader(
        dataset=dataset,
        num_workers=num_workers,
        multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
        # NOTE(review): reviewer feedback: behaviour should be the same for
        # persistent_workers=True and False.
        persistent_workers=bool(num_workers),
    )
    it = iter(dl)
    state = dl.state_dict()
    # Ensure iter is called only once per worker
    self.assertEqual(self._get_iter_calls(state), expected(1))

    for _ in range(10):
        next(it)
    state = dl.state_dict()
    # Ensure that iter has not been invoked again
    self.assertEqual(self._get_iter_calls(state), expected(1))

    # Call iter on dl to see if iter is called again
    iter(dl)
    state = dl.state_dict()
    self.assertEqual(self._get_iter_calls(state), expected(2))

    dl2 = StatefulDataLoader(
        dataset=dataset_copy,
        num_workers=num_workers,
        multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
        persistent_workers=bool(num_workers),
    )
    dl2.load_state_dict(state)
    it = iter(dl2)
    state2 = dl2.state_dict()
    # Ensure that iter is called only once per worker even when dataloader resumes from a state
    self.assertEqual(self._get_iter_calls(state2), expected(3))

    for _ in range(10):
        next(it)
    # Take a fresh snapshot and assert on it; checking the snapshot captured before
    # the loop would pass no matter what happened while iterating.
    state2 = dl2.state_dict()
    # Ensure that iter has not been invoked again
    self.assertEqual(self._get_iter_calls(state2), expected(3))
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Let's change this test to just assert that iter is only called once for new dataloaders, both with and without resuming from a state, but not enforce the actual value of state["iter_calls"] after resuming. |
||||||||||
|
||||||||||
def test_inline(self):
    """Run the iter-call check in-process (``num_workers=0``) on ``CountIterCalls(100)``."""
    self._run_test(0, CountIterCalls(100))
|
@@ -1574,5 +1602,94 @@ def test_mp_iter(self): | |||||||||
self._run_test(2, CountIterCallsIter(100)) | ||||||||||
|
||||||||||
|
||||||||||
class IterationState:
    """Mutable iteration cursor: a current position and an end bound."""

    def __init__(self, start, end):
        """Start the cursor at ``start`` with bound ``end``."""
        self.curr = start
        self.end = end

    def set_state(self, state):
        """Overwrite the cursor from a dict with ``"curr"`` and ``"end"`` keys."""
        self.curr = state["curr"]
        self.end = state["end"]

    def get_state(self):
        """Return the cursor as a dict with ``"curr"`` and ``"end"`` keys."""
        return {"curr": self.curr, "end": self.end}
|
||||||||||
|
||||||||||
class StatesInitializationDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that creates its iteration state only inside ``__iter__``.

    The dataset is its own iterator. ``iter_state`` does not exist until
    ``__iter__`` has run, so ``state_dict`` and ``load_state_dict`` are only
    usable afterwards, and every ``__iter__`` call replaces the state with a
    fresh one for the current worker's slice.
    """

    def __init__(self, length):
        """Record the total number of samples; no iteration state is built here."""
        self.length = length

    def __iter__(self):
        """Build this worker's ``IterationState`` and return ``self``.

        Each worker gets a contiguous slice of ``length // num_workers``
        samples; without worker info the whole range is used.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        num_samples = self.length // num_workers
        self.iter_state = IterationState(num_samples * worker_id, num_samples * (worker_id + 1))
        return self

    def __next__(self):
        """Return the current value and advance; raise ``StopIteration`` at the end bound."""
        state = self.iter_state
        if state.curr >= state.end:
            raise StopIteration
        value = state.curr
        state.curr += 1
        return value

    def state_dict(self):
        """Return the iteration state under the ``"state"`` key."""
        return {"state": self.iter_state.get_state()}

    def load_state_dict(self, state_dict):
        """Restore the iteration state from ``state_dict["state"]``.

        Requires ``__iter__`` to have been called first, since that is where
        ``iter_state`` is created.
        """
        self.iter_state.set_state(state_dict["state"])
|
||||||||||
|
||||||||||
class TestStateInitializationDataset(TestCase):
    """Resume tests for a dataset that initializes its state inside ``__iter__``."""

    def _run_test(self, num_workers, dataset):
        """Consume part of the data, resume in a new loader, and check full coverage.

        Args:
            num_workers: Worker count passed to ``StatefulDataLoader``.
            dataset: Dataset exposing a ``length`` attribute; the collected
                values are expected to equal ``range(length)`` as a set.
        """
        length = dataset.length
        tail = 30  # number of batches left for the resumed loader

        # Ensure test is run with compatible parameters as the test and dataset used
        # in the test doesn't cover all the corner cases
        if num_workers > 0:
            self.assertEqual(length % num_workers, 0)
            self.assertGreater(length, tail)

        def make_loader():
            return StatefulDataLoader(
                dataset=dataset,
                num_workers=num_workers,
                collate_fn=identity,
                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
            )

        dl = make_loader()
        it = iter(dl)
        # Initial snapshot; this is the one used if the loop below does not run.
        state = dl.state_dict()
        data = []

        for _ in range(length - tail):
            data.extend(next(it))
            state = dl.state_dict()

        # Resume from state
        dl2 = make_loader()
        dl2.load_state_dict(state)
        it = iter(dl2)

        for _ in range(tail):
            data.extend(next(it))

        # Order could be different for multiworker case as the data comes from different
        # workers, so use set to check equality instead of list
        self.assertEqual(set(data), set(range(length)))

    def test_inline(self):
        """Run the resume check in-process (``num_workers=0``)."""
        self._run_test(0, StatesInitializationDataset(100))

    def test_mp(self):
        """Run the resume check with two worker processes."""
        self._run_test(2, StatesInitializationDataset(100))
|
||||||||||
|
||||||||||
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's revert this, except delete the `self.iter_calls` assignment in `load_state_dict` and just use it to track calls in the test.