Save state in dataset_iter_state when dataset is also an iterator #1279

Merged
merged 9 commits into from
Jun 24, 2024

Conversation

@gokulavasan gokulavasan commented Jun 21, 2024

Context:

For iterable datasets, state can be defined both at the dataset level and in the iterator. In the state_dict that is returned, the dataset state is stored in dataset_state and the iterator state is stored in fetcher_state::dataset_iter_state.

For a dataset that also acts as its own iterator, the identical state was initially saved in both places, which was wasteful. #1258 fixed this by saving the state only once, in fetcher_state::dataset_iter_state.

However, #1273 swapped this.

This PR swaps it back, saving the state in fetcher_state::dataset_iter_state again, so that state variables can be initialized even in the __iter__ method, which is the only place with access to values such as WorkerInfo.
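
As an illustration of the layout described above, here is a minimal pure-Python sketch. It is not torchdata's implementation: `CountingDataset`, `_maybe_state`, and `collect_state` are hypothetical names, and only the keys `dataset_state`, `fetcher_state`, and `dataset_iter_state` come from the description. It assumes that "the dataset is also an iterator" means `iter(dataset)` returns the dataset object itself.

```python
from copy import deepcopy


class CountingDataset:
    """Hypothetical dataset that is also its own iterator."""

    def __init__(self, length):
        self.length = length
        self.items = []

    def __iter__(self):
        # State is initialized here rather than in __init__: in a real worker,
        # this is where per-worker information (such as WorkerInfo) is available.
        self.items = list(range(self.length))
        return self

    def __next__(self):
        if self.items:
            return self.items.pop(0)
        raise StopIteration

    def state_dict(self):
        return {"items": deepcopy(self.items)}


def _maybe_state(obj):
    """Return obj.state_dict() if the object is stateful, else None."""
    return obj.state_dict() if hasattr(obj, "state_dict") else None


def collect_state(dataset, dataset_iter):
    """Assumed layout: save the state only once when dataset and iterator coincide."""
    iter_state = _maybe_state(dataset_iter)
    # When the dataset is its own iterator, skip the dataset-level copy.
    ds_state = None if dataset_iter is dataset else _maybe_state(dataset)
    return {
        "dataset_state": ds_state,
        "fetcher_state": {"dataset_iter_state": iter_state},
    }


ds = CountingDataset(3)
it = iter(ds)          # returns ds itself
next(it)               # consume one item
state = collect_state(ds, it)
print(state)
```

Because the state is captured from the iterator after `__iter__` has run, anything initialized inside `__iter__` is included in the snapshot.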

Fixes #{issue number}

Changes

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 21, 2024
pytorch-bot bot commented Jun 21, 2024

✅ No Failures

As of commit 3641c8f with merge base 958eeb0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@gokulavasan gokulavasan marked this pull request as ready for review June 21, 2024 22:58
@facebook-github-bot

@gokulavasan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -1529,7 +1530,8 @@ def state_dict(self):
         return {"iter_calls": self.iter_calls, "items": deepcopy(self.items)}

     def load_state_dict(self, state_dict):
-        self.iter_calls = state_dict["iter_calls"]
+        # sequence of calls for this : iter is called first and then load_state_dict is called and thus we don't want state to override calls to iter that has happened before load_state_dict was called
+        self.iter_calls += state_dict["iter_calls"]
Contributor

This makes load_state_dict non-idempotent, so e.g. if load_state_dict is called multiple times, it will result in incorrect behaviour.
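
The idempotency concern can be shown with a small standalone sketch. The two classes below are hypothetical stand-ins for the test dataset, reduced to the `iter_calls` field; one accumulates in `load_state_dict` (the `+=` form), the other assigns.

```python
class AdditiveCounter:
    """load_state_dict accumulates, as in the `+=` form under review."""

    def __init__(self):
        self.iter_calls = 0

    def load_state_dict(self, state_dict):
        self.iter_calls += state_dict["iter_calls"]


class AssigningCounter:
    """load_state_dict overwrites, so repeated loads give the same result."""

    def __init__(self):
        self.iter_calls = 0

    def load_state_dict(self, state_dict):
        self.iter_calls = state_dict["iter_calls"]


saved = {"iter_calls": 2}
additive, assigning = AdditiveCounter(), AssigningCounter()

# Load the same saved state twice into each object.
for _ in range(2):
    additive.load_state_dict(saved)
    assigning.load_state_dict(saved)

print(additive.iter_calls)   # 4: the second load changed the result
print(assigning.iter_calls)  # 2: the second load was a no-op
```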

Contributor Author

@andrewkho Updated the unit test to check only iter_calls and to no longer restore its value, which is unnecessary.

Comment on lines 1515 to 1535
        self.local_iter_calls = 0
        self.prev_state_iter_calls = 0
        self.items = []

    def __iter__(self):
        self.items = list(range(self.length))
        self.iter_calls += 1
        self.local_iter_calls += 1
        return self

    def __next__(self):
        if len(self.items) > 0:
            self.items.popleft()
            return self.items.pop(0)
        else:
            raise StopIteration

    def state_dict(self):
        return {"iter_calls": self.iter_calls, "items": deepcopy(self.items)}
        return {"iter_calls": self.local_iter_calls + self.prev_state_iter_calls, "items": deepcopy(self.items)}

    def load_state_dict(self, state_dict):
        self.iter_calls = state_dict["iter_calls"]
        self.prev_state_iter_calls = state_dict["iter_calls"]
        self.items = state_dict["items"]
Contributor

Let's revert this, except delete the self.iter_calls assignment in load_state_dict and just use self.iter_calls to track calls in the test.

Comment on lines 1550 to 1551
# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs.
dataset_copy = deepcopy(dataset) if num_workers == 0 else dataset

@andrewkho andrewkho Jun 24, 2024

Suggested change
-# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs.
-dataset_copy = deepcopy(dataset) if num_workers == 0 else dataset
+# Need to make a copy here as iter calls is tracked and its count is stored in dataset state which persists across calls for single process runs.
+dataset = deepcopy(dataset)

Let's avoid having the test behave differently for different num_workers as much as possible. A test that is modified just to pass is not really a good test, and I'd prefer we delete it.

dl = StatefulDataLoader(
    dataset=dataset,
    num_workers=num_workers,
    multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
    persistent_workers=True if num_workers else False,

@andrewkho andrewkho Jun 24, 2024

The behaviour should be the same for both persistent_workers=True and persistent_workers=False.

Comment on lines 1583 to 1590
# Ensure that iter is called only once per worker even when dataloader resumes from a state
self.assertEqual(self._get_iter_calls(state2), [3] * max(1, num_workers))

for _ in range(10):
    next(it)
state = dl2.state_dict()
# Ensure that iter has not been invoked again
self.assertEqual(self._get_iter_calls(state2), [3] * max(1, num_workers))

@andrewkho andrewkho Jun 24, 2024

Let's change this test to just assert that iter is only called once for new dataloaders with and without state-resumes, but not enforce the actual value of state["iter_calls"] after resume state

gokulavasan commented Jun 24, 2024

@andrewkho Addressed comments, please take a look. The only comment I couldn't resolve was the one about persistent workers being True. It has to be set to True if we want the single-process and multiprocess test cases in the TestSingleIterCalled_shard0 unit test to behave the same way. This is because self.iter_calls is a local dataset variable.

In the single-process case, the modified value persists beyond the lifetime of the first dataloader created.

In the multiprocess case without persistent workers, the value is modified only in the worker process's copy of that variable.

Let me know how you want me to address it. The options are:
i) switch to persistent_workers=True and keep the expected iter_calls value the same for single process and multiprocess
ii) make a dataset copy for the single-process case and pass in a different dataset object
iii) pass in different expected values for those two scenarios.
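
The difference between the two cases can be sketched without any dataloader. This is only an illustration under stated assumptions: `CountIterCalls`, `run_in_process`, and `run_in_simulated_worker` are hypothetical names, and `deepcopy` stands in for the copy of the dataset that a worker process receives; no real multiprocessing is involved.

```python
from copy import deepcopy


class CountIterCalls:
    """Hypothetical dataset that counts how often __iter__ is called."""

    def __init__(self, length):
        self.length = length
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        return iter(range(self.length))


def run_in_process(dataset):
    """Single-process case: iterate the caller's own object."""
    return list(dataset)


def run_in_simulated_worker(dataset):
    """Stand-in for a worker process: it only ever sees a copy of the dataset."""
    worker_copy = deepcopy(dataset)
    items = list(worker_copy)
    return items, worker_copy.iter_calls


# Single process: the counter on the caller's object is incremented and persists.
shared = CountIterCalls(3)
run_in_process(shared)
in_process_calls = shared.iter_calls

# Simulated worker: only the copy's counter changes; the caller's stays at 0.
isolated = CountIterCalls(3)
_, worker_calls = run_in_simulated_worker(isolated)
parent_calls = isolated.iter_calls

print(in_process_calls, worker_calls, parent_calls)  # 1 1 0
```

Option ii) above corresponds to making the single-process case look like the simulated worker case by handing the dataloader its own copy.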

@andrewkho andrewkho left a comment

gogogo

@gokulavasan gokulavasan merged commit b0e25e2 into main Jun 24, 2024
44 checks passed
@gokulavasan gokulavasan deleted the save-state-in-fetcher-state branch June 24, 2024 22:48