
Improve skip_first_batches method to efficiently support IterableDataset and StatefulDataloader #2859

Closed
yzhangcs opened this issue Jun 15, 2024 · 9 comments · Fixed by #2895

Comments

@yzhangcs

yzhangcs commented Jun 15, 2024

Hi all, thank you for developing this great project.
Currently, skip_first_batches naively iterates through all batches until the specified number has been consumed, which can be extremely slow for very large datasets.
The latest version of the datasets library now supports resumable iterable datasets, and, together with StatefulDataLoader, allows efficient resumption of training state.
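For reference, here is a minimal sketch of the kind of resumption these new APIs enable (the dataset name is just a placeholder, and exact behavior depends on your datasets / torchdata versions):

```python
# Minimal sketch (not Accelerate code): resuming a streaming dataset with
# StatefulDataLoader instead of re-iterating the skipped batches.
# "my_corpus" is a placeholder dataset name.
from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

ds = load_dataset("my_corpus", split="train", streaming=True)  # IterableDataset
dl = StatefulDataLoader(ds, batch_size=8, num_workers=2)

state = None
for step, batch in enumerate(dl):
    # ... training step ...
    if step == 999:
        state = dl.state_dict()  # snapshot of the dataset/dataloader position
        break

# Later (e.g. after a restart): rebuild the pipeline and resume in place
# instead of skipping the first 1000 batches one by one.
dl = StatefulDataLoader(ds, batch_size=8, num_workers=2)
dl.load_state_dict(state)
```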

I'm wondering if there are any plans to leverage these new features in Accelerate to make skip_first_batches more efficient and compatible with the latest datasets capabilities?
If not, are there plans to add support for this in the future?
Efficiently skipping batches on huge datasets would significantly speed up resuming interrupted training runs. Let me know if you need any additional information or have thoughts on the best way to approach this.

Thanks for considering this suggestion!

@yzhangcs yzhangcs changed the title How to skip first batches efficiently. Improve skip_first_batches method to efficiently support IterableDataset and StatefulDataloader Jun 15, 2024
@muellerzr
Collaborator

You can ping me (@muellerzr) or @SunMarc on these things; Sylvain hasn't worked at HF for well over a year now :)

@muellerzr
Collaborator

Yes, we are indeed actively looking into this!

@byi8220
Contributor

byi8220 commented Jun 22, 2024

Ran into something annoying while looking at this. Merely importing StatefulDataLoader (i.e. putting the line from torchdata.stateful_dataloader import StatefulDataLoader anywhere in the code) causes one of the unit tests, check_seedable_sampler, to fail.

I suspect it has something to do with torchdata overriding torch's BatchSampler in this code. This is supported by the fact that if I import it and add some logging, SeedableRandomSampler.__iter__() seems to be called one time fewer than expected:

# We should see the epoch and seed sequence [(0, 42), (1, 43), (2, 44)] twice, but the first call with seed 42 is missing
# It looks like the first sample is being drawn without setting a seed

stdout: Shuffled central dataloader passing.
stdout: {'x': tensor([-1.3022,  0.1278], device='cuda:0'), 'y': tensor([0.3097, 3.2926], device='cuda:0')}
stdout: {'x': tensor([-1.0400,  0.7505], device='cuda:0'), 'y': tensor([0.9978, 4.5075], device='cuda:0')}
stdout: {'x': tensor([ 0.3047, -0.0168], device='cuda:0'), 'y': tensor([3.6974, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.9406], device='cuda:0'), 'y': tensor([1.2889, 4.9939], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -1.9510], device='cuda:0'), 'y': tensor([ 2.2716, -0.8553], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 1 43
stdout: {'x': tensor([-0.0168, -1.9510], device='cuda:0'), 'y': tensor([ 3.0542, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.3047], device='cuda:0'), 'y': tensor([1.2889, 3.6974], device='cuda:0')}
stdout: {'x': tensor([-1.3022, -1.0400], device='cuda:0'), 'y': tensor([0.3097, 0.9978], device='cuda:0')}
stdout: {'x': tensor([ 0.1278, -0.3162], device='cuda:0'), 'y': tensor([3.2926, 2.2716], device='cuda:0')}
stdout: {'x': tensor([0.7505, 0.9406], device='cuda:0'), 'y': tensor([4.5075, 4.9939], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 2 44
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.0168], device='cuda:0'), 'y': tensor([4.9939, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -0.8530], device='cuda:0'), 'y': tensor([2.2716, 1.2889], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -1.9510], device='cuda:0'), 'y': tensor([ 0.9978, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-1.3022,  0.3047], device='cuda:0'), 'y': tensor([0.3097, 3.6974], device='cuda:0')}
stdout: --
stdout: Resetting epoch and seed
stdout: Setting seed at epoch 0 42
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -0.0168], device='cuda:0'), 'y': tensor([0.9978, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-1.9510, -1.3022], device='cuda:0'), 'y': tensor([-0.8553,  0.3097], device='cuda:0')}
stdout: {'x': tensor([ 0.3047, -0.8530], device='cuda:0'), 'y': tensor([3.6974, 1.2889], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.3162], device='cuda:0'), 'y': tensor([4.9939, 2.2716], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 1 43
stdout: {'x': tensor([-0.0168, -1.9510], device='cuda:0'), 'y': tensor([ 3.0542, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-0.8530,  0.3047], device='cuda:0'), 'y': tensor([1.2889, 3.6974], device='cuda:0')}
stdout: {'x': tensor([-1.3022, -1.0400], device='cuda:0'), 'y': tensor([0.3097, 0.9978], device='cuda:0')}
stdout: {'x': tensor([ 0.1278, -0.3162], device='cuda:0'), 'y': tensor([3.2926, 2.2716], device='cuda:0')}
stdout: {'x': tensor([0.7505, 0.9406], device='cuda:0'), 'y': tensor([4.5075, 4.9939], device='cuda:0')}
stdout: --
stdout: Setting seed at epoch 2 44
stdout: {'x': tensor([0.7505, 0.1278], device='cuda:0'), 'y': tensor([4.5075, 3.2926], device='cuda:0')}
stdout: {'x': tensor([ 0.9406, -0.0168], device='cuda:0'), 'y': tensor([4.9939, 3.0542], device='cuda:0')}
stdout: {'x': tensor([-0.3162, -0.8530], device='cuda:0'), 'y': tensor([2.2716, 1.2889], device='cuda:0')}
stdout: {'x': tensor([-1.0400, -1.9510], device='cuda:0'), 'y': tensor([ 0.9978, -0.8553], device='cuda:0')}
stdout: {'x': tensor([-1.3022,  0.3047], device='cuda:0'), 'y': tensor([0.3097, 3.6974], device='cuda:0')}
stdout: --
stdout: original_items:
stdout:  tensor([-1.3022,  0.1278, -1.0400,  0.7505,  0.3047, -0.0168, -0.8530,  0.9406,
stdout:         -0.3162, -1.9510, -0.0168, -1.9510, -0.8530,  0.3047, -1.3022, -1.0400,
stdout:          0.1278, -0.3162,  0.7505,  0.9406,  0.7505,  0.1278,  0.9406, -0.0168,
stdout:         -0.3162, -0.8530, -1.0400, -1.9510, -1.3022,  0.3047], device='cuda:0')
stdout: new_items:
stdout:  tensor([ 0.7505,  0.1278, -1.0400, -0.0168, -1.9510, -1.3022,  0.3047, -0.8530,
stdout:          0.9406, -0.3162, -0.0168, -1.9510, -0.8530,  0.3047, -1.3022, -1.0400,
stdout:          0.1278, -0.3162,  0.7505,  0.9406,  0.7505,  0.1278,  0.9406, -0.0168,
stdout:         -0.3162, -0.8530, -1.0400, -1.9510, -1.3022,  0.3047], device='cuda:0')

At the moment I don't know why this happens, so I can't tell whether it's a misconfiguration in my local workspace, a bug somewhere in the torchdata library itself, or just a sharp edge that could be worked around.
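A quick (hypothetical) way to check the override suspicion, in case anyone wants to reproduce it:

```python
# Hypothetical diagnostic: see whether merely importing
# torchdata.stateful_dataloader replaces anything on torch.utils.data
# that the seedable sampler test relies on.
import torch.utils.data as tud

before = {name: getattr(tud, name) for name in ("BatchSampler", "DataLoader")}
import torchdata.stateful_dataloader  # noqa: F401  # imported only for its side effects

after = {name: getattr(tud, name) for name in ("BatchSampler", "DataLoader")}
for name in before:
    if before[name] is not after[name]:
        print(f"{name} was replaced by the import")
```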

Assuming there aren't other traps, writing the rest of the feature doesn't feel like too much work. That said, the most immediate solution I can think of that isn't a big refactor, creating subclasses such as StatefulDataLoaderShard(DataLoaderShard, StatefulDataLoader, DataLoaderStateMixin) and letting duck typing do the rest, feels kinda hacky imo.

@yzhangcs
Author

@byi8220 Hi, it seems that datasets does not support saving state for datasets with a shuffle buffer right now:
huggingface/datasets#6658 (comment)
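If I understand the linked comment correctly, the problematic case is roughly the following (dataset name is a placeholder):

```python
# Roughly the case the linked comment describes: a streaming dataset with a
# shuffle buffer. Fully resuming would require restoring the buffer contents,
# which (per the linked comment) the current state tracking does not cover.
from datasets import load_dataset

ds = load_dataset("my_corpus", split="train", streaming=True)
ds = ds.shuffle(seed=42, buffer_size=1000)  # the buffer is part of the "true" iteration state
```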

@byi8220
Contributor

byi8220 commented Jun 24, 2024

Gah, this feature is getting more complicated every second. We're also at the mercy of how StatefulDataLoader is implemented (ran into a tricky problem here) 😞

Hi, it seems that datasets does not support saving state for datasets with a shuffle buffer right now

Thank you for mentioning it. Is it accurate to call this a related but separate issue? Please correct me if I'm wrong; my understanding of the problem scope is that:

  1. The datasets library is responsible for supporting state_dict/load_state_dict when a dataset has a buffer
  2. The accelerate library is responsible for utilizing state_dict/load_state_dict to save and load checkpoints (scope of this issue)
  3. The trainer library is responsible for being aware of when to call skip_first_batches

But regarding the breaking test I mentioned above, I'm unsure whether it is related. The test that breaks when importing StatefulDataLoader is check_seedable_sampler. What is very strange is that it breaks without any changes to the code, simply from importing the torchdata.stateful_dataloader package. The test itself is unchanged and uses a non-stateful PyTorch DataLoader; it's as if the import alone causes something to break.

@byi8220
Contributor

byi8220 commented Jun 24, 2024

Also, just to elaborate on the problem with StatefulDataLoader I'm running into, in case it's helpful info:

DataLoaderShard.__iter__() (https://github.com/huggingface/accelerate/blob/main/src/accelerate/data_loader.py#L445-L476) works by wrapping the underlying DataLoader.__iter__() and advancing it. The problem is that we eagerly pick up next_batch before yielding current_batch. This appears to be done to support self.end_of_dataloader.

Here's a crude mock of how I think this behavior works: https://pastebin.com/Sk1DfDYz

The problem here is that when the wrapper w1 is yielding 1, the inner itr has already yielded 2. In real code, w1 would be DataLoaderShard and itr would be the underlying DataLoader.
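In case the pastebin goes away, here's roughly what that mock looks like (a sketch of the pattern, not the actual DataLoaderShard code):

```python
# Sketch of the lookahead pattern: the wrapper fetches batch N+1 before it
# yields batch N so it can flag end_of_dataloader on the last batch. That means
# the inner iterator (and its state) is always one step ahead of the caller.
def lookahead_iter(inner):
    it = iter(inner)
    try:
        current = next(it)
    except StopIteration:
        return
    end_of_dataloader = False
    while not end_of_dataloader:
        try:
            nxt = next(it)  # inner iterator is now one batch ahead
        except StopIteration:
            end_of_dataloader = True  # only known after the lookahead fails
            nxt = None
        # If `inner` were a StatefulDataLoader, inner.state_dict() here would
        # already describe the batch *after* the one we are about to yield.
        yield current
        current = nxt

print(list(lookahead_iter([1, 2, 3])))  # [1, 2, 3], but the inner iterator runs ahead by one
```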

One solution could be to implement DataLoaderShard.state_dict() so that it just keeps the previous state_dict around. But that would introduce the overhead of calling StatefulDataLoader.state_dict() on every iteration, which might be expensive. Or maybe the semantics of end_of_dataloader could be refactored (which seems like a big refactor). Or maybe I'm missing a really obvious solution.
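A rough sketch of that first option, just to make the tradeoff concrete (the class name is made up, this is not the actual implementation):

```python
# Sketch of "keep the previous state_dict around": snapshot the inner
# StatefulDataLoader right after each fetch, and report the snapshot that
# matches the batch the caller is currently holding (one step behind the
# lookahead). The cost is one state_dict() call per batch.
class StatefulLookaheadShard:
    def __init__(self, stateful_dl):
        self.stateful_dl = stateful_dl
        self.end_of_dataloader = False
        self._state_for_yielded = None

    def __iter__(self):
        it = iter(self.stateful_dl)
        try:
            current = next(it)
        except StopIteration:
            return
        current_state = self.stateful_dl.state_dict()  # state after consuming `current`
        while True:
            try:
                nxt = next(it)  # lookahead: inner iterator moves one batch ahead
                next_state = self.stateful_dl.state_dict()
            except StopIteration:
                self.end_of_dataloader = True
                self._state_for_yielded = current_state
                yield current
                return
            self._state_for_yielded = current_state
            yield current
            current, current_state = nxt, next_state

    def state_dict(self):
        # Reflects the last batch actually yielded, not the lookahead batch.
        return self._state_for_yielded
```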

@byi8220
Contributor

byi8220 commented Jun 26, 2024

Took a shot at getting StatefulDataLoader connected to this library in #2895. Seems like way more work than I would have imagined, and admittedly it's experimental and there may be issues.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@byi8220
Contributor

byi8220 commented Jul 31, 2024

Don't know if this is closed considering the PR is still open...
