Fixup dataloader state dict bugs + incorporate load/save_state API #3034
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -539,6 +553,7 @@ def __iter__(self):
                 current_batch = next_batch
             except StopIteration:
                 self.end_of_dataloader = True
+                self._update_state_dict()
```
Additionally, we need to update the state dict when we hit the `StopIteration`, so we know we've reached the end of the iterator.
```python
new_linear1 = prepared_model.linear1.weight
new_batchnorm = prepared_model.batchnorm.weight
new_linear2 = prepared_model.linear2.weight
unwrapped_model_2 = accelerator.unwrap_model(prepared_model)
```
During any distributed run, we need to unwrap the model first before we can toy with the layers.
Thanks for investigating the issues and providing a fix, as well as hardening the tests.
Generally this LGTM; I have a few comments though, please check.
src/accelerate/checkpointing.py (outdated)
```python
if getattr(dataloader, "use_stateful_dataloader", False):
    dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
    input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
    state_dict = torch.load(input_dataloader_state_dict_file)
```
Probably not something for this PR, but this line could cause trouble when the `weights_only` switch to `torch.load` comes.
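For reference, a minimal future-proofing sketch (not part of this PR): passing the flag explicitly pins the behavior regardless of the default, and the dataloader state dict is plain tensors and containers, so the restricted unpickler should accept it.

```python
# Hedged sketch: pin `weights_only` explicitly so a future flip of the
# torch.load default doesn't change behavior for these files.
state_dict = torch.load(input_dataloader_state_dict_file, weights_only=True)
```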
```diff
@@ -442,8 +442,21 @@ def state_dict(self):
         return self.dl_state_dict

+    def load_state_dict(self, state_dict):
+        # The state dict will be off by a factor of `n-1` batch too many during DDP,
+        # so we need to adjust it here
+        if PartialState().distributed_type != DistributedType.NO:
```
Should this be fixed during loading or rather during saving?
If we do it during saving, that isolates it to users who just do `Accelerator.save_state`/`Accelerator.load_state`, which (especially in the Trainer) might not be what users end up doing, since it's always entirely optional. I'd rather it happen in load IMO, but I'll think about whether there's a better way.
The other (probably more correct) option is to fix it in `state_dict()`.
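To make that alternative concrete, here is a minimal sketch of fixing it at save time (a hypothetical helper, not the merged code; the field names are the ones discussed in this thread):

```python
from accelerate.state import PartialState
from accelerate.utils import DistributedType


def corrected_state_dict(dl_state_dict: dict) -> dict:
    """Hypothetical helper: strip the prefetch offset when *producing* the
    state dict, instead of compensating for it at load time."""
    sd = dict(dl_state_dict)
    if PartialState().distributed_type != DistributedType.NO:
        # Per the discussion above, the yield counters run `num_processes - 1`
        # steps ahead during DDP because of prefetching.
        factor = PartialState().num_processes - 1
        for key in ("_iter_yielded", "_num_yielded"):
            if key in sd:
                sd[key] -= factor
    return sd
```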
Adding onto this, is it appropriate to hack around with the state dict's fields such as `_iter_yielded` and `_num_yielded`?
IIUC this works in the basic case, but `torchdata` seems to support custom state functions.
This makes me feel like this needs to be fixed during saving, and maybe in a way that doesn't make assumptions about the contents of the state dict (make sure to save the right state dict?).
We must; it's not a matter of whether it's appropriate or not. If we don't, the sampler/resuming simply won't work :) as we need to modify the values in their sampler.
This is a naive implementation to start, and if we hit edge cases with more things to adjust later, we can. But the base case is supported.
The only point where that would matter is a non-multi/distributed setup. Otherwise they're likely not using Accelerate dataloaders.
Regarding the question of custom state functions: would it be possible to optionally allow custom dataloaders to make the update themselves and otherwise fall back on the solution provided here? For instance:

```python
if hasattr(self.base_data_loader, "correct_state_for_prefetch"):  # or whatever fitting name
    self.dl_state_dict = self.base_data_loader.correct_state_for_prefetch(self.dl_state_dict, PartialState())
else:
    ...  # existing code
```

Of course, this needs to be documented so that users can correctly implement this method.
I'm more a fan of this, but I'm not sure if we can get there in our current state, since it's either a `StatefulDataLoader` or a native `torch.utils.data.DataLoader` that gets built.
Got confused; the type of the dataloader is fixed.
I went with an `adjust_state_dict_for_prefetch` func in `DataLoaderAdapter`, with documentation on how overriding it should work.
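For illustration, a hedged sketch of how overriding such a hook could look from a subclass (class and field names are assumptions based on this thread, not the exact merged implementation):

```python
from accelerate.data_loader import DataLoaderShard  # one of the DataLoaderAdapter subclasses
from accelerate.state import PartialState
from accelerate.utils import DistributedType


class MyLoader(DataLoaderShard):
    def adjust_state_dict_for_prefetch(self):
        # Walk the yield counters back by the batches the other ranks already
        # prefetched, so resuming starts from the batch we actually want next.
        if PartialState().distributed_type != DistributedType.NO:
            factor = PartialState().num_processes - 1
            self.dl_state_dict["_iter_yielded"] -= factor
            self.dl_state_dict["_num_yielded"] -= factor
```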
```diff
     if iterable:
-        dataset = DummyIterableDataset(torch.as_tensor(range(dataset_size)))
+        dataset = DummyIterableDataset(values)
```
Why not the same for `TensorDataset`?
`TensorDataset` can be shuffled, basically. Our iterable here can't.
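For context, a small self-contained illustration of that distinction in plain PyTorch (unrelated to this PR's code):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

# Map-style datasets support samplers, so they can be shuffled.
loader = DataLoader(TensorDataset(torch.arange(10)), shuffle=True)  # fine


class Stream(IterableDataset):
    def __iter__(self):
        yield from range(10)


# Iterable datasets define their own iteration order, so DataLoader rejects
# shuffle=True for them with a ValueError.
loader = DataLoader(Stream(), shuffle=True)
```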
Thanks for catching this. Just to get an understanding of what's going on: this is an issue when training on multiple GPUs, where the inner state dict (for the base dataloader on the dispatcher) is fetching 1 extra item for each additional shard beyond the first? Did I miss this because I didn't have a test in the distributed test scripts?
@byi8220 exactly. If it's not in those scripts, it's only ever run on a single GPU, since those scripts are the ones that test during DDP/multiple GPUs :) And it's not fetching an extra item, it's one complete batch ahead (due to prefetching).
Ah, that's annoying. I kinda just guessed. IMO missing DDP coverage seems like a mistake that's very easy to make when writing features/tests. I wonder if there's either a way to…
It's why we have the scripts; it's not a perfect solution, but it's what we have to do :) (And also, it does. They are found in …)
Ah, so it runs these tests? I think I get the workflow; I still feel it's something that's really easy to miss if you're not paying attention or are missing context (as I have literally just done).
Thanks for addressing my comments.
I have added one suggestion on how to possibly deal with custom stateful dataloaders, otherwise this LGTM.
@byi8220 it's not possible, as these are only CPU runners. We then run them on merge to `main`.
Thanks for the latest changes, LGTM. `adjust_state_dict_for_prefetch` is implemented even more elegantly than I was thinking of.
Well, I guess if it's caught on merge it's not that big of an issue.
What does this PR do?
This PR builds on #2895, finishing it and solving a few issues that come up during distributed training.
1. `_update_state_dict`: technically we're 1 ahead of what we want to yield. This causes issues later on when we want to restore, as during DDP our `_iter_yielded`, `_num_yielded`, and `samples_yielded` will be off by exactly one full batch pulled - 1 (or `num_processes - 1`). With this fix in `load_state_dict`, we can properly restore the batch to the right spot.
2. I noticed that I could only get reproducible behavior if we enabled deterministic algorithms. Unrelated to the prior PR, it's a flaky thing I noticed that goes away when enabling them.
3. Reworked the test to fix a few issues, namely that `skip_first_batches` is meant to be used on raw PyTorch dataloaders. This API technically removes the need for it (and we have tests for this elsewhere). The real test we want is: if we resume from that state, the fully trained model should equal the partially trained model resumed from that state, which the tests now do.
4. Enables `save_state`/`load_state` to actually load in the dataloader states and adds tests for them (a usage sketch follows this list).
5. Ensures that this can actually run on multiple GPUs, namely by removing the use of `skip_first_batches`, since it was being used on a prepared dataloader.

cc @byi8220
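To show what point 4 enables, a hedged usage sketch (the checkpoint directory, model, and data are placeholders; the stateful dataloader requires `torchdata` to be installed):

```python
import torch
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)
model = torch.nn.Linear(4, 4)
dataloader = torch.utils.data.DataLoader(torch.randn(32, 4), batch_size=8)
model, dataloader = accelerator.prepare(model, dataloader)

for i, batch in enumerate(dataloader):
    if i == 1:
        # Mid-epoch checkpoint: the dataloader's state dict is saved alongside
        # the model state.
        accelerator.save_state("ckpt")

# Later (or in a fresh process): restore model + dataloader progress together,
# so iteration resumes from the saved batch.
accelerator.load_state("ckpt")
```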
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@BenjaminBossan @SunMarc