Fixup dataloader state dict bugs + incorporate load/save_state API #3034

Merged: 19 commits into main from muellerzr-save-load-state-dl on Aug 23, 2024

Conversation

muellerzr (Collaborator) commented Aug 22, 2024:

What does this PR do?

This PR builds on #2895, finishing it and solving a few issues that come up during distributed training.

  1. As pointed out in _update_state_dict:
    def _update_state_dict(self):
        # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
        # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
        # what it wants to yield.
        #
        # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
        if hasattr(self.base_dataloader, "state_dict"):
            self.dl_state_dict = self.base_dataloader.state_dict()

Technically we're one batch ahead of what we want to yield. This causes issues later when we want to restore, because during DDP our _iter_yielded, _num_yielded, and samples_yielded will be off by exactly one full batch pull times num_processes - 1 (i.e., one extra prefetched batch per additional process). With this fix in load_state_dict, we can properly restore the dataloader to the right spot.
  2. I noticed that I could only get reproducible behavior if we enabled deterministic algorithms. This is unrelated to the prior PR; it's a flaky thing I noticed that goes away once they are enabled.
  3. Reworked the test to fix a few issues, namely that skip_first_batches is meant to be used on raw PyTorch dataloaders. This API technically removes the need for it (and we have tests for it elsewhere). The real check we want is that, if we resume from a saved state, the fully trained model equals the partially trained model resumed from that state, which the tests now verify.
  4. Enables save_state/load_state to actually load in the dataloader states and adds tests for them (see the usage sketch after this list).
  5. Ensures that this can actually run on multiple GPUs, namely by removing the skip_first_batches portion, since it was being used on a prepared dataloader.
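
As a rough usage sketch of the save_state/load_state flow wired up in point 4 (the toy model, directory name, and hyperparameters below are placeholders rather than the PR's actual test fixtures, and torchdata's StatefulDataLoader must be installed):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator
    from accelerate.utils import DataLoaderConfiguration

    # Toy setup (placeholder model/data, not the PR's test fixtures)
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    dataloader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 2)), batch_size=8)

    # Opt into torchdata's StatefulDataLoader so the dataloader position becomes part of the checkpoint
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for step, (x, y) in enumerate(dataloader):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        if step == 3:
            # Also writes dl_state_dict.bin (one per stateful dataloader) next to the model/optimizer state
            accelerator.save_state("checkpoint_dir")
            break

    # Later: restores the model, optimizer, RNG states, and the dataloader position
    accelerator.load_state("checkpoint_dir")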

cc @byi8220

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@BenjaminBossan @SunMarc

HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -539,6 +553,7 @@ def __iter__(self):
             current_batch = next_batch
         except StopIteration:
             self.end_of_dataloader = True
+            self._update_state_dict()
muellerzr (Collaborator, Author):

Additionally, we need to update the state dict when we've hit the StopIteration, so we know we've reached the end of the iterator.

new_linear1 = prepared_model.linear1.weight
new_batchnorm = prepared_model.batchnorm.weight
new_linear2 = prepared_model.linear2.weight
unwrapped_model_2 = accelerator.unwrap_model(prepared_model)
muellerzr (Collaborator, Author):

During any distributed run, we need to unwrap the model first before we can toy with the layers.
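
For illustration, a minimal sketch of that pattern; accelerator, fully_trained_model, and resumed_model are assumed to come from an earlier prepare/train/resume setup like the one sketched in the PR description:

    import torch

    # Under DDP the prepared model is wrapped (e.g. in DistributedDataParallel),
    # so unwrap it before touching submodules such as linear1.
    model_a = accelerator.unwrap_model(fully_trained_model)
    model_b = accelerator.unwrap_model(resumed_model)

    assert torch.allclose(model_a.linear1.weight, model_b.linear1.weight)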

BenjaminBossan (Member) left a comment:

Thanks for investigating the issues and providing a fix, as well as hardening the tests.

Generally, this LGTM, I have a few comments though, please check.

[Resolved review thread on src/accelerate/checkpointing.py]
if getattr(dataloader, "use_stateful_dataloader", False):
    dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
    input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
    state_dict = torch.load(input_dataloader_state_dict_file)
BenjaminBossan (Member):

Probably not something for this PR, but this line could cause trouble once torch.load switches to weights_only by default.
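
Not something this PR changes, but a sketch of the hardening being hinted at, assuming the dataloader state dict only holds types the restricted unpickler accepts:

    # Opt into weights_only explicitly so the call keeps working once torch.load flips its default
    state_dict = torch.load(input_dataloader_state_dict_file, weights_only=True)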

@@ -442,8 +442,21 @@ def state_dict(self):
        return self.dl_state_dict

    def load_state_dict(self, state_dict):
        # The state dict will be off by a factor of `n-1` batch too many during DDP,
        # so we need to adjust it here
        if PartialState().distributed_type != DistributedType.NO:
BenjaminBossan (Member):

Should this be fixed during loading or rather during saving?

muellerzr (Collaborator, Author):

If we do it during saving, that limits it to users who go through Accelerator.save_state/Accelerator.load_state, which (especially in the Trainer) might not be what users end up doing, since that's always entirely optional. I'd rather it happen in load, IMO, but I'll think about whether there's a better way.

muellerzr (Collaborator, Author):

The other (probably more correct) option is to fix it in state_dict().

byi8220 (Contributor) commented Aug 23, 2024:

Adding onto this, is it appropriate to hack around with the state_dict's fields such as _iter_yielded, _num_yielded?

IIUC this works in the basic case, but torchdata seems to support custom state functions.

This makes me feel like this needs to be fixed during saving, and maybe in a way that doesn't make assumptions about the contents of state_dict (make sure to save the right state dict?).

muellerzr (Collaborator, Author):

We must; it's not a matter of whether it's appropriate or not. If we don't, the sampler/resuming simply won't work :) since we need to modify the values in their sampler.

This is a naive implementation to start, and if we hit edge cases with more things to adjust later, we can. But the base case is supported.

muellerzr (Collaborator, Author) commented Aug 23, 2024:

The only case where that would come up is a non-multi/distributed setup. Otherwise they're likely not using Accelerate dataloaders.

BenjaminBossan (Member):

Regarding the question of custom state functions: Would it be possible to optionally allow custom data loaders to make the update themselves and otherwise fall back on the solution provided here? So for instance:

if hasattr(self.base_data_loader, "correct_state_for_prefetch"):  # or whatever fitting name
    self.dl_state_dict = self.base_data_loader.correct_state_for_prefetch(self.dl_state_dict, PartialState())
else:
    ...  # existing code

Of course, this needs to be documented so that users can correctly implement this method.

muellerzr (Collaborator, Author):

I'm more of a fan of this, but I'm not sure we can get there in our current state, since what gets built is either a StatefulDataLoader or a native torch.utils.data.DataLoader.

BenjaminBossan (Member):

Got confused, the type of the data loader is fixed.

muellerzr (Collaborator, Author):

I went with an adjust_state_dict_for_prefetch func in DataLoaderAdapter, with documentation on how overriding it should work.
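
For context, a rough sketch of what overriding such a hook could look like; the subclass name, the no-argument signature, and the state dict keys below are illustrative assumptions, not the merged implementation:

    from accelerate.data_loader import DataLoaderAdapter
    from accelerate.state import PartialState
    from accelerate.utils import DistributedType

    class MyStatefulAdapter(DataLoaderAdapter):  # hypothetical subclass
        def adjust_state_dict_for_prefetch(self):
            # Idea: under distributed training the snapshot sits one prefetched batch
            # ahead per extra process, so walk batch-level counters back by num_processes - 1.
            state = PartialState()
            if state.distributed_type == DistributedType.NO:
                return
            factor = state.num_processes - 1
            for key in ("_iter_yielded", "_num_yielded"):
                if key in self.dl_state_dict:
                    self.dl_state_dict[key] = max(0, self.dl_state_dict[key] - factor)
            # A sample-level counter such as samples_yielded would instead move back
            # by a full batch per extra process (batch_size * factor).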

[Resolved review thread on src/accelerate/data_loader.py]
 if iterable:
-    dataset = DummyIterableDataset(torch.as_tensor(range(dataset_size)))
+    dataset = DummyIterableDataset(values)
BenjaminBossan (Member):

Why not the same for TensorDataset?

muellerzr (Collaborator, Author):

TensorDataset can be shuffled, basically. Our iterable here can't
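
For context, a small sketch of the two dataset flavors in that test (the class body is an assumption based on the name in the diff above): an IterableDataset yields values in a fixed order and can't be shuffled by the DataLoader, whereas a map-style TensorDataset can be.

    import torch
    from torch.utils.data import IterableDataset, TensorDataset

    class DummyIterableDataset(IterableDataset):
        def __init__(self, values):
            self.values = values

        def __iter__(self):
            yield from self.values

    values = torch.as_tensor(range(16))
    iterable_dataset = DummyIterableDataset(values)   # fixed iteration order, no shuffling
    map_style_dataset = TensorDataset(values)         # index-based, so DataLoader(shuffle=True) works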

[Outdated, resolved review thread on tests/test_accelerator.py]
byi8220 (Contributor) commented Aug 23, 2024:

Thanks for catching this. Just to get an understanding of what's going on, this is an issue when training on multiple GPUs, where the inner state_dict (for the base dataloader on the dispatcher) is fetching 1 extra item for each additional shard beyond the first?

Did I miss this because I didn't have a test in test_distributed_data_loop.py, and my assumption that my test in test_accelerator.py would test this didn't hold because I didn't unwrap the model?

muellerzr (Collaborator, Author):

@byi8220 exactly. If it's not in those scripts, it's only ever run on a single GPU, since those scripts are the ones that test during DDP/multiple GPUs :)

It's not fetching an extra item; it's one complete batch ahead (due to prefetching).

byi8220 (Contributor) commented Aug 23, 2024:

> If it's not in those scripts, it's only ever run on a single GPU, since those scripts are the ones that test during DDP/multiple GPUs :)

Ah that's annoying. I kinda just guessed test_accelerator.py would cover multi_gpu due to the usage of the require_multi_gpu annotation in some of its tests.

IMO missing DDP coverage seems like a mistake that's very easy to make when writing features/tests.

I wonder if there's a way to either:

  1. Have make test automatically run the unit tests on multi-GPU as well (if more than 1 GPU is present), or
  2. Somehow have the CI catch this (something like a LINT.IfChange rule).

muellerzr (Collaborator, Author) commented Aug 23, 2024:

That's why we have the scripts; it's not a perfect solution, but it's what we have to do :)

(And also, it does. They are found in tests/test_multigpu.py)

byi8220 (Contributor) commented Aug 23, 2024:

> (And also, it does. They are found in tests/test_multigpu.py)

Ah, so it runs these tests? Should the test_accelerator.py suite be integrated somewhere here, or is that not possible?

I think I get the workflow; I still feel it's something that's really easy to miss if you're not paying attention or are missing context (as I have literally just done).

BenjaminBossan (Member) left a comment:

Thanks for addressing my comments.

I have added one suggestion on how to possibly deal with custom stateful dataloaders, otherwise this LGTM.

muellerzr (Collaborator, Author):

@byi8220 it's not possible, as these are only CPU runners. We then run them on merge to main, since those are GPU runners and can do so (that's how I flagged/saw that this had issues).

BenjaminBossan (Member) left a comment:

Thanks for the latest changes, LGTM. adjust_state_dict_for_prefetch is implemented even more elegantly than I was thinking of.

byi8220 (Contributor) commented Aug 23, 2024:

> We then run them on merge to main since those are GPU runners and can do so (it's how I flagged/saw this had issues)

Well, I guess if it's caught on merge it's not that big of an issue

muellerzr merged commit 726140c into main on Aug 23, 2024
28 checks passed
muellerzr deleted the muellerzr-save-load-state-dl branch on August 23, 2024 at 19:13