
Correctly support resuming from checkpoint with a dataset without length #33544

Open · wants to merge 1 commit into base: main

Conversation

@muupan (Author) commented on Sep 17, 2024

What does this PR do?

There is an inconsistency in Trainer's behavior between training from scratch and resuming from a checkpoint when the given dataset has no length, such as datasets.IterableDataset. For a reproducible example, see #26413 (comment). This PR fixes the inconsistency by correctly supporting resuming from a checkpoint with such a dataset.

Fixes #26413

Current behavior

When training starts with a dataset that has no length, Trainer assumes one epoch equals max_steps steps and tries to train for that many steps. There are two possible scenarios:

  • A. If the dataset yields enough samples, the training finishes precisely after one epoch.
  • B. If the dataset raises StopIteration before yielding enough samples for max_steps steps, Trainer increments the current epoch and re-iterates the dataset.

When resuming from a checkpoint, Trainer simply skips the first batches until it reaches the checkpoint's global_step. In scenario A, there is no problem. In scenario B, the dataset raises StopIteration during the skipping, but Trainer does not re-iterate the dataset. Instead, it just finishes training with a warning. This is inconsistent with what happens when training from scratch, and it contradicts what the documentation of max_steps says:

max_steps (`int`, *optional*, defaults to -1):
If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
`max_steps` is reached.
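
For illustration, here is a minimal sketch of scenario B. The tiny checkpoint, the generator, and the hyperparameters below are illustrative choices, not the exact reproduction linked in #26413:

```python
from datasets import IterableDataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

def gen():
    # Only 8 samples, fewer than max_steps, so one pass over the data is not
    # enough and the Trainer has to re-iterate the dataset (scenario B).
    for _ in range(8):
        yield {"input_ids": [0, 1, 2], "labels": [0, 1, 2]}

train_dataset = IterableDataset.from_generator(gen)  # has no length

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
args = TrainingArguments(
    output_dir="out",
    max_steps=20,
    per_device_train_batch_size=1,
    save_steps=10,
    report_to="none",
)

# Training from scratch re-iterates the dataset and reaches step 20.
Trainer(model=model, args=args, train_dataset=train_dataset).train()

# Resuming from step 10 needs to skip 10 batches, but the dataset yields only
# 8 before StopIteration; currently the run then ends early with a warning.
Trainer(model=model, args=args, train_dataset=train_dataset).train(
    resume_from_checkpoint="out/checkpoint-10"
)
```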

Solution

This PR modifies the skipping behavior so that Trainer now re-iterates the dataset until it catches up with the checkpoint's global_step. A caveat is that it does not support the ignore_data_skip option, since Trainer then does not know which epoch to start from. I am also concerned that the logic is becoming too complicated. A rough sketch of the idea is shown below.
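
The sketch below illustrates the intended behavior in simplified form; it is not the actual diff, and the helper name and signature are made up for illustration:

```python
def skip_to_global_step(train_dataloader, num_batches_to_skip):
    """Skip batches already seen before the checkpoint, re-iterating the
    dataloader whenever it is exhausted. Returns the iterator to resume from
    and the number of extra passes over the data started while skipping."""
    epochs_restarted = 0
    skipped = 0
    epoch_iterator = iter(train_dataloader)
    while skipped < num_batches_to_skip:
        try:
            next(epoch_iterator)
            skipped += 1
        except StopIteration:
            # The data ran out before catching up with the checkpoint's
            # global_step, so start another pass over the dataset, mirroring
            # training from scratch instead of ending with a warning.
            epochs_restarted += 1
            epoch_iterator = iter(train_dataloader)
    return epoch_iterator, epochs_restarted
```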

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker

@LysandreJik (Member) commented:

Very impressive PR @muupan!

I'm pinging @muellerzr and @SunMarc to take a look; Zach is off for a few weeks and will take a look as soon as he's back. Thank you for your patience 🙏

LysandreJik requested review from SunMarc and muellerzr and removed the review request for SunMarc on September 18, 2024.
@SunMarc (Member) commented on Sep 27, 2024

Thanks for the PR @muupan! We will review it shortly. There is a new feature in accelerate that enables you to use a stateful dataloader, so that we don't need to iterate over the data to resume training. Feel free to give it a try; note that the support is very experimental for now.
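
Roughly, using that experimental feature might look like the sketch below. This is an assumption-laden illustration: it presumes a recent accelerate release whose DataLoaderConfiguration exposes use_stateful_dataloader (backed by torchdata's StatefulDataLoader), and the toy dataset is made up:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Made-up toy data; any dataset would do.
dataset = TensorDataset(torch.arange(1000).float().unsqueeze(1))

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=8))

it = iter(dataloader)
for _ in range(10):              # consume some batches, as a training loop would
    next(it)
state = dataloader.state_dict()  # checkpoint the dataloader's position

# On resume, restore the position instead of manually skipping batches.
dataloader.load_state_dict(state)
```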

muupan force-pushed the feature/resume-training-with-iterable-dataset branch from 6f83505 to bc56f1c on October 31, 2024.
@muupan (Author) commented on Oct 31, 2024

It seems the code got broken after rebasing onto main, where #34198 renamed the epoch_iterator variable. I will fix it.

@SunMarc (Member) commented on Nov 5, 2024

Let us know when it is done!

Successfully merging this pull request may close these issues.

resume_from_checkpoint function fails because "There seems to be not a single sample in your epoch_iterator"