-
Notifications
You must be signed in to change notification settings - Fork 152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Save state in dataset_iter_state when dataset is also an iterator #1279
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1279
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 3641c8f with merge base 958eeb0 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@gokulavasan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@gokulavasan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@@ -1529,7 +1530,8 @@ def state_dict(self): | |||
return {"iter_calls": self.iter_calls, "items": deepcopy(self.items)} | |||
|
|||
def load_state_dict(self, state_dict): | |||
self.iter_calls = state_dict["iter_calls"] | |||
# sequence of calls for this : iter is called first and then load_state_dict is called and thus we don't want state to override calls to iter that has happened before load_state_dict was called | |||
self.iter_calls += state_dict["iter_calls"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes load_state_dict non-idempotent, so eg if load_state_dict is called multiple times it will result in incorrect behaviour
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@andrewkho Updated the unit test to check specifically the iter_calls alone and not restore it value which is unnecessary
@gokulavasan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
self.local_iter_calls = 0 | ||
self.prev_state_iter_calls = 0 | ||
self.items = [] | ||
|
||
def __iter__(self): | ||
self.items = list(range(self.length)) | ||
self.iter_calls += 1 | ||
self.local_iter_calls += 1 | ||
return self | ||
|
||
def __next__(self): | ||
if len(self.items) > 0: | ||
self.items.popleft() | ||
return self.items.pop(0) | ||
else: | ||
raise StopIteration | ||
|
||
def state_dict(self): | ||
return {"iter_calls": self.iter_calls, "items": deepcopy(self.items)} | ||
return {"iter_calls": self.local_iter_calls + self.prev_state_iter_calls, "items": deepcopy(self.items)} | ||
|
||
def load_state_dict(self, state_dict): | ||
self.iter_calls = state_dict["iter_calls"] | ||
self.prev_state_iter_calls = state_dict["iter_calls"] | ||
self.items = state_dict["items"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's revert this, except delete the self.iter_calls
in load_state_dict and just use it to track calls in the test
# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs. | ||
dataset_copy = deepcopy(dataset) if num_workers == 0 else dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs. | |
dataset_copy = deepcopy(dataset) if num_workers == 0 else dataset | |
# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs. | |
dataset = deepcopy(dataset) |
Let's not have the test behave differently for different num_workers as much as possible, a test that is modified just to pass is not really a good test and I'd prefer we delete it
dl = StatefulDataLoader( | ||
dataset=dataset, | ||
num_workers=num_workers, | ||
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), | ||
persistent_workers=True if num_workers else False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The behaviour should be the same for both PW = True and False
# Ensure that iter is called only once per worker even when dataloader resumes from a state | ||
self.assertEqual(self._get_iter_calls(state2), [3] * max(1, num_workers)) | ||
|
||
for _ in range(10): | ||
next(it) | ||
state = dl2.state_dict() | ||
# Ensure that iter has not been invoked again | ||
self.assertEqual(self._get_iter_calls(state2), [3] * max(1, num_workers)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's change this test to just assert that iter is only called once for new dataloaders with and without state-resumes, but not enforce the actual value of state["iter_calls"] after resume state
@andrewkho Addressed comments, please take a look. The only comment I couldn't resolve was around persistent workers being True. It has to be set to True if we want to match the behavior of single process and multiprocess test cases in the TestSingleIterCalled_shard0 unit test. This is because self.iter_calls is local dataset variable. In case of single process, the modified value persists beyond the lifetime of the first dataloader created. In case of multiprocess without persistent workers, the value is modified only the copy of that variable in the multiprocessing worker. Let me know how you want to me to address it - options are: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gogogo
Context:
For iterable datasets, state can be defined both at dataset level and also in the iterator. In the state_dict that is vended, this dataset_state is stored in dataset_state and iterator state is stored in fetcher_state::dataset_iter_state.
For a dataset which also acts as an iterator, initially the identical state was saved in both places which was wasteful. This was fixed in #1258 where the state was saved only once in fetcher_state::dataset_iter_state.
But this was swapped in #1273.
This is being swapped back to its original state to make sure the state variables can be initialized even in the iter method (which is the only place where it can access variables such as WorkerInfo).
Fixes #{issue number}
Changes