Add early support for torchdata.stateful_dataloader.StatefulDataLoader within the Accelerator #2895

Merged: 74 commits, Aug 22, 2024

@byi8220 (Contributor) commented Jun 26, 2024

What does this PR do?

Fixes #2859

This PR does the following:

  1. Adds a new field use_stateful_dataloader to DataLoaderConfiguration. When it is set in the config, every DataLoader prepared and returned by the Accelerator is a StatefulDataLoader object from the torchdata library.
  2. Creates a class DataLoaderAdapter that wraps and acts as either PyTorch's DataLoader or a variant of it such as StatefulDataLoader.
  3. Refactors DataLoaderShard, DataLoaderDispatcher, and SkipDataLoader to inherit from DataLoaderAdapter instead of DataLoader.
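
The adapter idea in the list above can be sketched without torch. This is only an illustration under stated assumptions: PlainLoader, StatefulLoader, and LoaderAdapter are hypothetical stand-ins for DataLoader, StatefulDataLoader, and DataLoaderAdapter, and the real classes take many more arguments.

```python
class PlainLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader (stdlib only)."""

    def __init__(self, data):
        self.data = list(data)

    def __iter__(self):
        return iter(self.data)


class StatefulLoader(PlainLoader):
    """Hypothetical stand-in for StatefulDataLoader: tracks how many items were yielded."""

    def __init__(self, data):
        super().__init__(data)
        self._yielded = 0

    def __iter__(self):
        start = self._yielded  # resume from a previously loaded position
        for item in self.data[start:]:
            self._yielded += 1
            yield item
        self._yielded = 0  # a finished epoch starts from scratch next time

    def state_dict(self):
        return {"yielded": self._yielded}

    def load_state_dict(self, state):
        self._yielded = state["yielded"]


class LoaderAdapter:
    """Wraps whichever loader class the flag selects and iterates through it."""

    def __init__(self, data, use_stateful_dataloader=False):
        backing_cls = StatefulLoader if use_stateful_dataloader else PlainLoader
        self.base_dataloader = backing_cls(data)

    def __iter__(self):
        return iter(self.base_dataloader)


plain = LoaderAdapter(range(4))
stateful = LoaderAdapter(range(4), use_stateful_dataloader=True)
```

The flag only changes which class backs the adapter; callers iterate the adapter the same way in both cases.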

Testing

Added unit tests checking that StatefulDataLoader works as a drop-in replacement and that its state can be saved and loaded.

Caveats

  • The torchdata package may have conflicts with accelerate; see #2894 (Importing torchdata.stateful_dataloader causes the test check_seedable_sampler to fail).
    • However, if torchdata is not installed, all existing tests pass, suggesting this is not regressive.
  • torchdata.stateful_dataloader.StatefulDataLoader is only available in the beta build of torchdata; it is not yet a stable feature.
  • Because this adds a dependency on a nightly package, almost none of the tests added in this PR run under the existing images or imports.
  • This has only been tested on my local workstation using a single GPU.
  • The implementation of DataLoaderAdapter is somewhat invasive and uses some questionable reflection.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@muellerzr

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@byi8220 (Contributor Author) commented Jul 31, 2024

TBH I'm not sure if I like either.

It's tricky. My understanding of the problem is this: we want to do something like "If use_stateful_dataloader==True, then create a StatefulDataLoader instead of a DataLoader."

However, based on how classes like DataLoaderShard are currently implemented, what we have to do in practice is "If use_stateful_dataloader==True, then make this class inherit from StatefulDataLoader instead of DataLoader." This is a lot more awkward to code.

I have one more idea for a solution, which is to have those classes create a base_dataloader and manually code the passthrough to methods and properties. I feel like this is a bit fragile, and isn't too different from my original solution.

@byi8220 marked this pull request as draft August 6, 2024 15:24

@muellerzr (Collaborator) left a comment

The more I've been working on this, the more I actually think this is the best solution we can get. Thanks a bunch for doing this, I think even though it's annoying with the patches, there's no other clear way to get there.

@muellerzr (Collaborator) commented

@byi8220 can you resolve the PR's and then I think we're okay to merge this.

@muellerzr (Collaborator) commented

As a final step, we likely want to update save_state/load_state to resume the dataloaders at this point.

@BenjaminBossan (Member) left a comment

Thanks a lot for the PR and the good discussion of possible designs. The end result is still something that I'm afraid will one day cause a hard to debug issue, but I can't say what exactly would be a better solution.

I added a couple of comments, which I think can help clean up this PR a bit, but don't consider them to be blockers. I have to admit I only skimmed the tests but they look very well done, so together with the existing ones they should hopefully avoid regressions.

One thing I would like to see is an addition to the docs to explain what stateful data loaders are, why users may want to use them, and how they can use them.

base_cls = self.__class__
base_cls_name = self.__class__.__name__
parent_cls_name = self.base_dataloader.__class__
self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
Member

Let me just bring up (again) that another solution could be monkey-patching __instancecheck__ on DataLoader. Not saying that it's less hacky, just wanted to raise awareness :)
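
For readers unfamiliar with the trick in the snippet under review, here is a minimal sketch of swapping an instance's class for a dynamically built one. BaseLoader, StatefulLoader, and Adapter are hypothetical stand-ins, not the accelerate classes; the point is only that the instance then passes isinstance checks for both its own class and the backing loader's class.

```python
class BaseLoader:
    """Hypothetical stand-in for DataLoader."""


class StatefulLoader(BaseLoader):
    """Hypothetical stand-in for StatefulDataLoader."""

    def state_dict(self):
        return {"position": 0}


class Adapter:
    def __init__(self, base_dataloader):
        self.base_dataloader = base_dataloader
        base_cls = self.__class__
        backing_cls = base_dataloader.__class__
        # Build a one-off class with the same name that inherits from both the
        # adapter and the backing loader's class, then rebind this instance to it.
        self.__class__ = type(base_cls.__name__, (base_cls, backing_cls), {})


stateful = Adapter(StatefulLoader())
plain = Adapter(BaseLoader())
```

Each instance gets its own generated class, so two adapters built this way do not share a type even though the class names match. That is part of why the approach is described as hacky in this thread.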

Comment on lines 433 to 434
for attr in self.base_dataloader.__dict__.keys():
    setattr(self, attr, getattr(self.base_dataloader, attr))
Member

Kinda looks dangerous. For example, this skips @property, is that intended? We could instead use __getattr__ to dispatch to self.base_dataloader.

If we want to stick with this, more succinct code could be: self.__dict__.update(self.base_dataloader.__dict__) or vars(self).update(self.base_dataloader.__dict__)
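
The difference between the two approaches can be shown with a small sketch. The class names here are hypothetical. Copying __dict__ takes a one-time snapshot of instance attributes and misses anything defined on the class, such as a property, while __getattr__ forwards every failed lookup to the wrapped object at access time.

```python
class Base:
    def __init__(self):
        self.batch_size = 8

    @property
    def num_batches(self):
        # Lives on the class, so it never appears in the instance __dict__
        return 3


class CopyingWrapper:
    def __init__(self, base):
        self.base_dataloader = base
        vars(self).update(vars(base))  # snapshot of instance attributes only


class DelegatingWrapper:
    def __init__(self, base):
        self.base_dataloader = base

    def __getattr__(self, name):
        # Called only when normal lookup fails; forwards to the wrapped object
        return getattr(self.base_dataloader, name)


base = Base()
copying = CopyingWrapper(base)
delegating = DelegatingWrapper(base)

base.batch_size = 16  # later change on the wrapped object
```

After the last line, the copying wrapper still holds the old batch_size, whereas the delegating wrapper sees the new value and also exposes num_batches.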

Contributor Author

Kinda looks dangerous.

Kinda agree with you, but all dynamic reflection looks dangerous to me.

I did write up an alternative which avoids the wizardry and just duplicates all the code required over here in: byi8220/accelerate@stateful-dataloader...byi8220:accelerate:stateful-dataloader-2

That code is messier and involves way more duplication, but it is much more explicit in what it does. If enough people feel the reflection approach is way too hacky and this feature doesn't justify it, I'm fine with doing that instead.

We could instead use __getattr__ to dispatch to self.base_dataloader.

I updated the PR to do that instead.

super().load_state_dict(state_dict)
self.dl_state_dict = self.state_dict

def _save_state_dict(self):
Member

IMO, the name is not quite fitting, isn't it more like update_state_dict or so? Also, maybe we can avoid this all by not having a static self.dl_state_dict attribute but instead the state_dict method just returns self.base_dataloader.state_dict().

@byi8220 (Contributor Author), Aug 20, 2024

IMO, the name is not quite fitting, isn't it more like update_state_dict or so?

Changed to _update_state_dict

Also, maybe we can avoid this all by not having a static self.dl_state_dict attribute but instead the state_dict method just returns self.base_dataloader.state_dict().

I'm not sure if we can. The base dataloader's state dict is one ahead of what we're yielding, so we couldn't do a passthrough. Some additional context in the comments of a6e192c#r1704736815
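
A rough sketch of why a passthrough would report the wrong position, assuming the wrapper fetches one batch ahead of the one it yields (for example, to know whether the current batch is the last). CountingLoader and LookaheadWrapper are hypothetical names, and the real state dict contains far more than a counter.

```python
class CountingLoader:
    """Hypothetical stateful loader whose state is the number of batches fetched."""

    def __init__(self, batches):
        self.batches = list(batches)
        self.fetched = 0

    def __iter__(self):
        for batch in self.batches:
            self.fetched += 1
            yield batch

    def state_dict(self):
        return {"fetched": self.fetched}


class LookaheadWrapper:
    """Yields batch N only after batch N+1 has already been pulled from the base loader."""

    def __init__(self, base):
        self.base_dataloader = base
        self.dl_state_dict = base.state_dict()

    def _update_state_dict(self):
        self.dl_state_dict = self.base_dataloader.state_dict()

    def state_dict(self):
        return self.dl_state_dict

    def __iter__(self):
        iterator = iter(self.base_dataloader)
        try:
            current = next(iterator)
        except StopIteration:
            return
        while True:
            # Snapshot BEFORE prefetching so the saved state matches `current`
            self._update_state_dict()
            try:
                upcoming = next(iterator)
            except StopIteration:
                yield current
                return
            yield current
            current = upcoming


base = CountingLoader(["a", "b", "c"])
wrapper = LookaheadWrapper(base)
batches = iter(wrapper)
first = next(batches)
```

After the first batch is yielded, the base loader has already fetched two batches while the cached snapshot records one, which is the position a resumed run would need.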

@byi8220 (Contributor Author) commented Aug 20, 2024

@BenjaminBossan Thanks for the review. Just addressed the comments on the PR.

The end result is still something that I'm afraid will one day cause a hard to debug issue, but I can't say what exactly would be a better solution.

The ultimate intent of this code is something like "Sometimes I want a DataLoaderDispatcher that inherits from DataLoader, but other times I want a DataLoaderDispatcher that inherits from StatefulDataLoader."

Imo, the less magical alternative would be to explicitly duplicate each DataLoader derivative that accelerate defines into a stateful version. I.e. manually create the classes StatefulDataLoaderDispatcher, StatefulDataLoaderShard, StatefulSkipDataLoader. I wrote up this alternative in a separate branch (diffed by byi8220/accelerate@stateful-dataloader...byi8220:accelerate:stateful-dataloader-2), but it leads to quite a lot of code duplication and also looks messy.

I have to admit I only skimmed the tests but they look very well done, so together with the existing ones they should hopefully avoid regressions.

I've tested this locally on my 1 GPU home workstation + a 2xGPU cloud instance (which costs me a few dollars every time I want to run the test suite 😞 ...)
The fact that all tests pass for me when not using this feature, regardless of whether the required torchdata version is installed, gives confidence that it's not causing a breaking regression.

This is my first real PR into accelerate, so I added the sanity and happy test cases I could think of based on my limited context, so I might have just been guessing on what's sufficient.

The tests highlighted one small thing though, the fact that to fully stop using a dataloader in the middle you have to call dataloader.end(), but this might just be unavoidable. If the use case of StatefulDataLoader is to restart the entire program from a checkpoint maybe it's not a big issue: https://github.com/huggingface/accelerate/pull/2895/files#diff-68b278b14afa2e1ea337bb5e13d122f6d074c8bf0f0b83bef779eac6f4ba7f9aR724-R726

One thing I would like to see is an addition to the docs to explain what stateful data loaders are, why users may want to use them, and how they can use them.

Imo this might be better in a separate PR, once the code is checked in?

@byi8220 marked this pull request as ready for review August 20, 2024 17:10

@muellerzr (Collaborator) commented Aug 20, 2024

The tests highlighted one small thing though, the fact that to fully stop using a dataloader in the middle you have to call dataloader.end(), but this might just be unavoidable. If the use case of StatefulDataLoader is to restart the entire program from a checkpoint maybe it's not a big issue

I believe this has been a known "issue" in accelerate (I've seen it pop up in other issues sparingly). Agreed that it's less of an issue here, since this is pretty much just called once at the start of training. As long as we have the state properly (which your tests check!) it's a different bug/issue to solve

Imo this might be better in a separate PR, once the code is checked in?

We tend to like full FC PR's that include doc updates. Less likely it'll be forgotten about and it's done all at once so users who want the bleeding edge can read immediately :)

@byi8220 (Contributor Author) commented Aug 20, 2024

I believe this has been a known "issue" in accelerate (I've seen it pop up in other issues sparingly).

Well, I have no idea how to solve such a problem in python. In the C++ world this is what destructors and RAII are for, I guess.

We tend to like full FC PR's that include doc updates. Less likely it'll be forgotten about and it's done all at once so users who want the bleeding edge can read immediately :)

Sure, added a footnote in the docs about this feature.

Also since this feature is now stable in torchdata I added a requirement for torchdata>=0.8.0 in setup.py

@BenjaminBossan (Member) left a comment

Thanks for the updates.

I'm not sure if we can. The base dataloader's state dict is one ahead of what we're yielding, so we couldn't do a passthrough. Some additional context in the comments of a6e192c#r1704736815

I see. If Zach is fine with the proposed solution, then we're good.

Well, I have no idea how to solve such a problem in python. In the C++ world this is what destructors and RAII are for, I guess.

There is the __del__ magic method in Python but let's not touch it.
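
For what it's worth, the usual Python counterpart to RAII is a context manager rather than __del__. The sketch below is not something this PR does; TrackedLoader and iterating are hypothetical names, and end() here only mimics the bookkeeping call mentioned above.

```python
from contextlib import contextmanager


class TrackedLoader:
    """Hypothetical loader with begin()/end() bookkeeping around iteration."""

    def __init__(self, batches):
        self.batches = list(batches)
        self.active = False

    def begin(self):
        self.active = True

    def end(self):
        self.active = False

    def __iter__(self):
        self.begin()
        for batch in self.batches:
            yield batch
        self.end()  # only reached if iteration runs to completion


@contextmanager
def iterating(loader):
    """Guarantee end() is called even if the caller stops iterating early."""
    try:
        yield iter(loader)
    finally:
        loader.end()


abandoned = TrackedLoader([1, 2, 3])
next(iter(abandoned))  # stop after one batch without calling end()

managed = TrackedLoader([1, 2, 3])
with iterating(managed) as batches:
    for batch in batches:
        break  # leaving the with-block still triggers end()
```

The trade-off is that callers have to opt in with a with-statement, which a plain for-loop over the dataloader does not do.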


def __getattr__(self, name):
    # Delegate attribute access to the internal dataloader
    return getattr(self.base_dataloader, name)
Member

A bit of an edge case: let's also check whether the name is "base_dataloader" and, if it is, raise an AttributeError to avoid an infinite recursion.

Contributor Author

Could you give a code example of how infinite recursion would happen here?

If I'm reading the python3 docs for __getattr__() correctly, it states "Note that if the attribute is found through the normal mechanism, __getattr__() is not called." IIUC, base_dataloader should always be retrievable through the normal mechanism.

If I add the following block into test_dataloader_inheritance() in test_data_loader.py (without making any changes), the tests pass without causing an infinite recursion:

        assert isinstance(skip_dl.base_dataloader, DataLoader)
        assert isinstance(dl_shard.base_dataloader, DataLoader)
        assert isinstance(dl_dispatcher.base_dataloader, DataLoader)

Member

Could you give a code example of how infinite recursion would happen here?

Yes, that would be for the edge case of an attribute getting called on the class, i.e. before it is instantiated. In that case, the base_dataloader attribute does not exist. Now you could say "who would do such a pernicious thing?", but it's a bug that actually happened in another project and for some reason DeepSpeed would do this (on a module, not a data loader, but let's rather be safe than sorry).
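
One way to reproduce the recursion, sketched with hypothetical classes: create the object without running __init__ (here via __new__, as copying or unpickling may do), so base_dataloader is missing. Looking it up inside __getattr__ then calls __getattr__ again, and so on. This is an assumed trigger for illustration, not necessarily the DeepSpeed case mentioned above.

```python
class Unguarded:
    def __init__(self, base):
        self.base_dataloader = base

    def __getattr__(self, name):
        # If base_dataloader itself is missing, this lookup re-enters __getattr__
        return getattr(self.base_dataloader, name)


class Guarded:
    def __init__(self, base):
        self.base_dataloader = base

    def __getattr__(self, name):
        if name == "base_dataloader":
            # Normal lookup already failed, so the attribute really is absent
            raise AttributeError(name)
        return getattr(self.base_dataloader, name)


def probe(cls):
    """Access an attribute on an instance whose __init__ never ran."""
    obj = cls.__new__(cls)  # bypasses __init__, so base_dataloader is not set
    try:
        obj.batch_size
    except RecursionError:
        return "RecursionError"
    except AttributeError:
        return "AttributeError"
    return "ok"


unguarded_result = probe(Unguarded)
guarded_result = probe(Guarded)
```

With the guard, the half-initialized object fails with an ordinary AttributeError, which callers such as hasattr already know how to handle.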

Contributor Author

I am still not entirely sure how this could happen, but I added a check in __getattr__.

@byi8220 (Contributor Author) commented Aug 21, 2024

I see. If Zach is fine with the proposed solution, then we're good.

sgtm

There is the __del__ magic method in Python but let's not touch it.

I see. Destructors in python don't seem very reliable, but my knowledge of the python memory model isn't great.

@BenjaminBossan (Member) left a comment

Did a final review and also more carefully reviewed the tests. I didn't find anything big, but a few minor things that could be improved. After that, this can be merged from my POV.

return _is_package_available("torchdata")


# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
Member

Can this now be adjusted to use a version check?

Contributor Author

Yes, it is now modeled after the other version checks in this file.
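
As an illustration of what such a version check can look like, here is a stdlib-only sketch. The helper names parse_release and is_package_at_least are hypothetical and are not the functions in accelerate. The comparison deliberately ignores pre-release and local suffixes, so a proper version parser would be preferable in real code.

```python
import importlib.metadata
import re


def parse_release(version):
    """Return the leading numeric release segment as a tuple of ints."""
    match = re.match(r"\d+(\.\d+)*", version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))


def is_package_at_least(package, minimum):
    """True if `package` is installed with a release segment >= `minimum`."""
    try:
        installed = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return False
    return parse_release(installed) >= parse_release(minimum)


# Simplification: a pre-release such as 0.8.0a0 compares equal to 0.8.0 here.
nightly = parse_release("0.8.0a0+git")
```

A call such as is_package_at_least("torchdata", "0.8.0") would then replace a plain "is the package importable" check.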

self.dl_state_dict = self.state_dict

def _update_state_dict(self):
if hasattr(self.base_dataloader, "state_dict"):
Member

Let's add a comment here when this needs to be called and with the context on why it's required.

Contributor Author

Added a comment here, kinda clunky though.

@byi8220 (Contributor Author) commented Aug 21, 2024

Thanks!

I didn't find anything big, but a few minor things that could be improved.

I fixed the nits above, but I also made one, maybe important, change, done in 74e2f53

Basically, I literally realized just now that I have been delegating the work of iteration to the superclass, instead of the backing dataloader. That felt wrong, so I did the commit above.

To confirm, replacing super() with self.base_dataloader here is the sensible thing to do, right? Like I want to fully delegate everything I can to the base_dataloader, and the only reason it worked as written before is because of getattr spaghetti?

@muellerzr (Collaborator) left a comment

Very nicely done! 🤩

Thanks so much for all your hard work on this, let's get it merged in ✅

(And yes, I think it is a sensible thing to do rather than the getattr spaghetti)

@muellerzr merged commit ad3f574 into huggingface:main Aug 22, 2024
25 checks passed

@muellerzr (Collaborator) commented

As a next step I'll work on getting this working with Accelerator.save_state/Accelerator.load_state today.

Successfully merging this pull request may close these issues.

Improve skip_first_batches method to efficiently support IterableDataset and StatefulDataloader
5 participants