Fixup dataloader state dict bugs + incorporate load/save_state API #3034
Probably not something for this PR, but this line could cause trouble once the `weights_only` switch to `torch.load` lands.
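For illustration (a hypothetical snippet, not code from this PR): once `weights_only=True` becomes the default, any non-tensor Python object pickled into the checkpoint will refuse to load.

```python
import torch
from torch.utils.data import RandomSampler

# A checkpoint carrying an arbitrary Python object alongside tensors.
payload = {"sampler": RandomSampler(range(10)), "step": torch.tensor(3)}
torch.save(payload, "dl_state.bin")

try:
    # The restricted unpickler rejects non-allowlisted globals like RandomSampler.
    torch.load("dl_state.bin", weights_only=True)
except Exception as err:
    print(f"weights_only load failed: {err}")

# The old behavior still works when requested explicitly.
payload = torch.load("dl_state.bin", weights_only=False)
```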
Should this be fixed during loading or rather during saving?
If we make it happen during saving, that isolates it to users who go through `Accelerator.save_state`/`Accelerator.load_state`, which (especially in the Trainer) might not be what users end up doing, since those are entirely optional. I'd rather it happen in `load`, IMO, but I'll think about whether there's a better way.
The other (probably more correct) option is to fix it in `state_dict()`.
Adding onto this: is it appropriate to hack around with the state dict's fields such as `_iter_yielded` and `_num_yielded`? IIUC this works in the basic case, but `torchdata` seems to support custom state functions. That makes me feel like this needs to be fixed during saving, and in a way that doesn't make assumptions about the contents of the state dict (i.e., make sure to save the right state dict in the first place?).
We must; it's not a matter of whether it's appropriate or not. If we don't, the sampler/resuming simply won't work :) since we need to modify the values in their sampler.
This is a naive implementation to start, and if we hit edge cases with more things to adjust later, we can handle them then. But the base case is supported.
The only place that would matter is a non-multi/distributed setup; otherwise they're likely not using Accelerate dataloaders.
Regarding the question of custom state functions: would it be possible to optionally allow custom data loaders to make the update themselves, and otherwise fall back on the solution provided here? So, for instance (a rough sketch; the hook name and internals are just illustrative):
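```python
class DataLoaderAdapter:
    """Illustrative stand-in; attribute and hook names here are assumed."""

    def __init__(self, base_dataloader):
        self.base_dataloader = base_dataloader

    def _adjust_state_dict(self, state_dict):
        # If the wrapped loader brings its own adjustment (e.g. because it
        # uses custom state functions), defer to it...
        if hasattr(self.base_dataloader, "adjust_state_dict"):
            return self.base_dataloader.adjust_state_dict(state_dict)
        # ...otherwise fall back to the generic fix-up from this PR.
        return self._default_adjust_state_dict(state_dict)

    def _default_adjust_state_dict(self, state_dict):
        # Generic counter fix-up would live here.
        return state_dict
```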
Of course, this needs to be documented so that users can correctly implement this method.
I'm more a fan of this, but I'm not sure we can get there in our current state, since what gets built is either a `StatefulDataLoader` or a native `torch.utils.data.DataLoader`.
Got confused; the type of the data loader is fixed.
I went with an `adjust_state_dict_for_prefetch` func in `DataLoaderAdapter`, with documentation on how overriding it should work.
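Roughly this shape (a standalone sketch with assumed attribute and key names, not the exact implementation):

```python
class DataLoaderAdapter:
    """Sketch of the hook; `dl_state_dict` and the counter keys are assumed."""

    def __init__(self, num_processes):
        self.num_processes = num_processes
        self.dl_state_dict = {}

    def adjust_state_dict_for_prefetch(self):
        # Default: in a distributed run each process has prefetched batches the
        # other ranks consumed, so walk the yielded counters back by
        # (num_processes - 1) before the state dict is saved.
        factor = self.num_processes - 1
        for key in ("_sampler_iter_yielded", "_num_yielded"):
            if self.dl_state_dict.get(key, 0) > 0:
                self.dl_state_dict[key] -= factor


class CustomPrefetchAdapter(DataLoaderAdapter):
    def adjust_state_dict_for_prefetch(self):
        # Loaders with custom state functions override this and adjust
        # whatever keys their own state_dict() actually contains.
        key = "batches_yielded"  # hypothetical custom key
        self.dl_state_dict[key] = max(
            0, self.dl_state_dict.get(key, 0) - (self.num_processes - 1)
        )
```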
Additionally, we need to update the state dict when we've hit the `StopIteration`, so we know whether we've reached the end of the iterator.
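Something along these lines (a sketch; the flag name is assumed):

```python
class DataLoaderAdapter:
    def __init__(self, base_dataloader):
        self.base_dataloader = base_dataloader
        self.dl_state_dict = {}

    def __iter__(self):
        for batch in self.base_dataloader:
            yield batch
        # The underlying iterator raised StopIteration: record the exhaustion
        # so a restored run starts a fresh epoch rather than resuming mid-epoch.
        self.dl_state_dict["_iterator_finished"] = True  # assumed flag name
```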
Why not the same for `TensorDataset`?
`TensorDataset` can be shuffled, basically; our iterable here can't.