Correctly support resuming with dataset without length
muupan committed Oct 31, 2024
1 parent 405b562 commit bc56f1c
Showing 2 changed files with 78 additions and 16 deletions.
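
Background: an `IterableDataset` exposes no `__len__`, so its dataloader cannot report a length and the `Trainer` must rely on `args.max_steps`. Before this commit, the length-less branch set `num_update_steps_per_epoch = max_steps`, effectively assuming the whole run fits in one epoch, so on resume `global_step // num_update_steps_per_epoch` could restore the wrong position whenever the stream was shorter than the number of trained steps. A minimal sketch of the missing-length situation, using hypothetical names (`StreamDataset` is illustrative, not from this commit):

import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamDataset(IterableDataset):
    """A finite stream that deliberately does not implement __len__."""

    def __init__(self, length: int):
        self.length = length

    def __iter__(self):
        for i in range(self.length):
            yield torch.tensor([float(i)])


loader = DataLoader(StreamDataset(length=9), batch_size=1)
try:
    len(loader)  # DataLoader.__len__ delegates to the dataset, which has no len()
except TypeError as err:
    print(f"dataloader has no usable length: {err}")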
src/transformers/trainer.py (40 additions, 16 deletions)
@@ -2205,7 +2205,7 @@ def _inner_training_loop(
             max_steps = args.max_steps
             # Setting a very large number of epochs so we go as many times as necessary over the iterator.
             num_train_epochs = sys.maxsize
-            num_update_steps_per_epoch = max_steps
+            num_update_steps_per_epoch = None
             num_examples = total_train_batch_size * args.max_steps
             num_train_samples = args.max_steps * total_train_batch_size
             if args.include_tokens_per_second:
@@ -2355,15 +2355,28 @@ def _inner_training_loop(
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             self.compare_trainer_and_checkpoint_args(self.args, self.state)
             self._load_callback_state()
-            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
-            if not args.ignore_data_skip:
-                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            if num_update_steps_per_epoch is not None:
+                epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+                if not args.ignore_data_skip:
+                    steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                    steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+                else:
+                    steps_trained_in_current_epoch = 0
             else:
-                steps_trained_in_current_epoch = 0
+                # If the dataloader does not have a length, we cannot restore the number of trained epochs.
+                # In the following loop, we repeatedly iterate over the dataloader to skip the first
+                # `steps_trained_in_current_epoch` steps and increment `epochs_trained` accordingly.
+                epochs_trained = 0
+                steps_trained_in_current_epoch = self.state.global_step * args.gradient_accumulation_steps
+                if args.ignore_data_skip:
+                    raise ValueError(
+                        "The dataloader does not have a length, so it is impossible to restore the number of trained"
+                        " epochs. Please disable the `ignore_data_skip` option."
+                    )

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            if num_update_steps_per_epoch is not None:
+                logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
                 logger.info(
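
In the new length-less branch above, the resume position is counted purely in batches: each optimizer step consumed `gradient_accumulation_steps` batches, so `steps_trained_in_current_epoch` starts at `global_step * gradient_accumulation_steps`, and `epochs_trained` is recovered later by replaying the dataloader. A standalone sketch of that arithmetic (hypothetical helper, not Trainer code):

def resume_position_without_length(global_step: int, gradient_accumulation_steps: int):
    """Mirror of the branch above: with no epoch length, count batches, not epochs."""
    epochs_trained = 0  # unknown for now; recovered by replaying the dataloader
    batches_to_skip = global_step * gradient_accumulation_steps
    return epochs_trained, batches_to_skip


# Resuming from checkpoint-10 with gradient_accumulation_steps=2 means 20 batches to skip.
print(resume_position_without_length(10, 2))  # (0, 20)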
@@ -2410,6 +2423,26 @@ def _inner_training_loop(
             if hasattr(epoch_dataloader, "set_epoch"):
                 epoch_dataloader.set_epoch(epoch)

+            steps_skipped = 0
+            rng_to_sync = False
+            if steps_trained_in_current_epoch > 0 and num_update_steps_per_epoch is None:
+                # Since the dataloader does not have a length, we just loop until the required number of steps.
+                # Every time we reach the end of the dataloader, we increment epoch and reset the iterator.
+                epoch_iterator = iter(epoch_iterator)
+                epoch_over = False
+                while steps_trained_in_current_epoch > 0:
+                    try:
+                        next(epoch_iterator)
+                        steps_trained_in_current_epoch -= 1
+                        steps_skipped += 1
+                    except StopIteration:
+                        epoch_over = True
+                        break
+                if epoch_over:
+                    continue
+                assert steps_trained_in_current_epoch == 0
+                rng_to_sync = True
+
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
                 self._past = None
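
The block added above replays the stream batch by batch: `StopIteration` marks an epoch boundary, the `continue` advances the enclosing `for epoch in ...` loop, and skipping resumes on a fresh pass. A self-contained sketch of the same replay technique (hypothetical function; assumes the dataloader restarts from the beginning on each `iter()` call and yields at least one batch):

def skip_batches_across_epochs(dataloader, batches_to_skip: int) -> int:
    """Consume `batches_to_skip` batches, counting how many epoch boundaries are crossed."""
    epochs_consumed = 0
    while batches_to_skip > 0:
        batches_iter = iter(dataloader)  # fresh pass over the stream
        while batches_to_skip > 0:
            try:
                next(batches_iter)
                batches_to_skip -= 1
            except StopIteration:
                epochs_consumed += 1  # epoch boundary crossed; restart the iterator
                break
    return epochs_consumed


# Skipping 20 batches over a 9-batch stream crosses two epoch boundaries
# and stops 2 batches into the third epoch.
print(skip_batches_across_epochs(list(range(9)), 20))  # 2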
@@ -2424,8 +2457,6 @@ def _inner_training_loop(
             if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                 self._load_rng_state(resume_from_checkpoint)

-            rng_to_sync = False
-            steps_skipped = 0
             if steps_trained_in_current_epoch > 0:
                 epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
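
Note that the deleted `rng_to_sync = False` and `steps_skipped = 0` lines are not lost: they moved above the new replay block earlier in the epoch loop, so the replay can set both flags before this `skip_first_batches` path runs.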
@@ -2575,13 +2606,6 @@ def _inner_training_loop(
                 if is_torch_xla_available():
                     xm.mark_step()
                 break
-            if step < 0:
-                logger.warning(
-                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
-                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
-                    f" num_steps ({max_steps}) higher than the number of available samples."
-                )
-                self.control.should_training_stop = True

             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
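
The deleted `if step < 0` fallback was the old escape hatch for exactly this scenario: when resuming overshot the end of a length-less stream, the epoch yielded no samples and training stopped with a warning about `IterableDataset`. With the replay block above handling epoch boundaries explicitly, that empty-epoch dead end no longer arises.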
tests/trainer/test_trainer.py (38 additions, 0 deletions)
@@ -2958,6 +2958,44 @@ def test_resume_training_with_frozen_params(self):
             self.assertEqual(b, b1)
             self.check_trainer_state_are_the_same(state, state1)

+    @parameterized.expand([(9, 1), (10, 1), (11, 1), (20, 1), (21, 1), (9, 2)])
+    def test_resume_training_with_iterable_dataset(self, dataset_length, gradient_accumulation_steps):
+        with tempfile.TemporaryDirectory() as tmpdir:
+
+            def get_trainer():
+                config = RegressionModelConfig()
+                train_dataset = SampleIterableDataset(length=dataset_length)
+                model = RegressionRandomPreTrainedModel(config)
+                args = RegressionTrainingArguments(
+                    output_dir=tmpdir,
+                    learning_rate=0.1,
+                    max_steps=20,
+                    save_steps=10,
+                    per_device_train_batch_size=1,
+                    gradient_accumulation_steps=gradient_accumulation_steps,
+                )
+                return Trainer(model=model, args=args, train_dataset=train_dataset)
+
+            # Train from scratch.
+            trainer = get_trainer()
+            trainer.train()
+            self.assertEqual(trainer.state.global_step, 20)
+            (a, b) = trainer.model.a.item(), trainer.model.b.item()
+            state = dataclasses.asdict(trainer.state)
+
+            # Train from a checkpoint.
+            checkpoint = os.path.join(tmpdir, "checkpoint-10")
+            trainer = get_trainer()
+            trainer.train(resume_from_checkpoint=checkpoint)
+            self.assertEqual(trainer.state.global_step, 20)
+            (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+            state1 = dataclasses.asdict(trainer.state)
+
+            # Check that the resumed model is the same as the original one.
+            self.assertEqual(a, a1)
+            self.assertEqual(b, b1)
+            self.check_trainer_state_are_the_same(state, state1)
+
     def test_load_best_model_at_end(self):
         total = int(self.n_epochs * 64 / self.batch_size)
         with tempfile.TemporaryDirectory() as tmpdir:
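
The parameter grid covers dataset lengths below, at, and above the checkpoint step (9, 10, 11), at and above `max_steps` (20, 21), plus a gradient-accumulation case, so resuming is exercised both within an epoch and across epoch boundaries. A worked check for one pair (illustrative arithmetic, not test code):

# (dataset_length=9, gradient_accumulation_steps=1): checkpoint-10 means
# 10 optimizer steps, i.e. 10 batches at per_device_train_batch_size=1.
batches_consumed = 10 * 1
epochs_trained, batches_into_epoch = divmod(batches_consumed, 9)
print(epochs_trained, batches_into_epoch)  # 1 1 -> resume one batch into the second epoch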