Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

fix: ensure last checkpoint is always saved, refactor training stop conditions to be computed in single location #729

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions metaseq/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,40 +74,47 @@ def save_checkpoint(

checkpoint_conds[f"checkpoint{epoch}{suffix}.pt"] = save_for_epoch
checkpoint_conds[f"checkpoint_{updates}{suffix}.pt"] = save_for_updates
checkpoint_conds[f"checkpoint_last{suffix}.pt"] = (
(training_finished and cfg.save_last_checkpoint)
or save_for_epoch
or save_for_updates
)
checkpoint_last_file_name = f"checkpoint_last{suffix}.pt"

extra_state = {"train_iterator": epoch_itr.state_dict()}

checkpoints = [
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
checkpoint_file_paths = [
os.path.join(cfg.save_dir, checkpoint_file_name)
for checkpoint_file_name, cond in checkpoint_conds.items()
if cond
]

if len(checkpoints) > 0:
if PathManager.islink(checkpoints[0]):
PathManager.rm(checkpoints[0])
def _save_checkpoint(checkpoint_file_path: str):
if PathManager.islink(checkpoint_file_path):
PathManager.rm(checkpoint_file_path)

trainer.save_checkpoint(
checkpoints[0],
checkpoint_file_path,
extra_state,
training_finished=training_finished,
async_callback_fn=async_callback_fn if save_to_NFS else None,
files_to_symlink_to=checkpoints[1:] if len(checkpoints) > 1 else None,
async_callback_fn=async_callback_fn,
)

write_timer.stop()
logger.info(
f"Saved checkpoint {checkpoints[0]} (epoch {epoch} @ {updates} updates) "
f"Saved checkpoint {checkpoint_file_path} (epoch {epoch} @ {updates} updates) "
f"(writing took {write_timer.sum} seconds)"
)

# See if there's any older checkpoints to delete after saving a new one.
# Only deletes if keep_last_updates > 0 or keep_last_epochs > 0 (default -1 for both).
delete_old_checkpoint_files(cfg, end_of_epoch, suffix)

# If there are checkpoints to save, save the first in the list
if len(checkpoint_file_paths) > 0:
_save_checkpoint(checkpoint_file_paths[0])

if training_finished and cfg.save_last_checkpoint:
checkpoint_last_file_path = os.path.join(
cfg.save_dir, checkpoint_last_file_name
)
_save_checkpoint(checkpoint_last_file_path)


def delete_old_checkpoint_files(cfg: CheckpointConfig, end_of_epoch: bool, suffix: str):
if not end_of_epoch and cfg.keep_last_updates > 0:
Expand Down
56 changes: 37 additions & 19 deletions metaseq/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ def main(cfg: DictConfig) -> None:
disable_iterator_cache=True,
)

max_epoch = cfg.optimization.max_epoch or math.inf
train_meter = meters.StopwatchMeter()
train_meter.start()
while epoch_itr.next_epoch_idx <= max_epoch:

while True:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually it is practice to avoid while True anywhere since it relies on other code to stop the loop and it's easy to make mistakes. However, the alternative of splitting logic between loop and validate function is more complex and thus more likely for us to have issues in future.

Also, another reason the while true is not be so bad is because the original could potentially be the same condition when max-epochs was not provided / defined.

If not defined

while epoch_itr.next_epoch_idx <= max_epoch
while epoch_itr.next_epoch_idx <= cfg.optimization.max_epoch or math.inf
while epoch_itr.next_epoch_idx <= math.inf
while true

# train for one epoch
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
if should_stop:
Expand All @@ -221,6 +221,9 @@ def main(cfg: DictConfig) -> None:
disable_iterator_cache=True,
)
train_meter.stop()

# make sure every process finishes before exiting...
distributed_utils.global_barrier()
logger.info("done training in {:.1f} seconds".format(train_meter.sum))


Expand Down Expand Up @@ -433,6 +436,8 @@ def validate_and_save(
end_of_epoch: bool,
was_successful_step: bool,
) -> Tuple[List[Optional[float]], bool]:
num_epoch = epoch_itr.epoch
max_epoch = cfg.optimization.max_epoch or math.inf
num_updates = trainer.get_num_updates()
max_update = cfg.optimization.max_update or math.inf

Expand All @@ -444,44 +449,57 @@ def validate_and_save(
# Stopping conditions (and an additional one based on validation loss later
# on)
should_stop = False
if num_updates >= max_update:

if num_epoch > max_epoch:
should_stop = True
logger.info(
f"Stopping training due to "
f"num_epoch: {num_epoch} > max_epoch: {max_epoch}"
)
elif num_updates > max_update:
should_stop = True
logger.info(
f"Stopping training due to "
f"num_updates: {num_updates} >= max_update: {max_update}"
f"num_updates: {num_updates} > max_update: {max_update}"
)

save_locally = (
is_epoch_save_interval = (
end_of_epoch
and cfg.checkpoint.save_interval_epochs > 0
and num_epoch % cfg.checkpoint.save_interval_epochs == 0
)
is_successful_update_local_save_interval = (
cfg.checkpoint.local_save_interval_updates > 0
and num_updates > 0
and num_updates % cfg.checkpoint.local_save_interval_updates == 0
and was_successful_step
)
save_to_NFS = (
is_successful_update_save_interval = (
cfg.checkpoint.save_interval_updates > 0
and num_updates > 0
and num_updates % cfg.checkpoint.save_interval_updates == 0
and was_successful_step
)
is_successful_update_validate_interval = (
cfg.checkpoint.validate_interval_updates > 0
and num_updates > 0
and num_updates % cfg.checkpoint.validate_interval_updates == 0
and was_successful_step
)

do_save = (
(
end_of_epoch
and cfg.checkpoint.save_interval_epochs > 0
and epoch_itr.epoch % cfg.checkpoint.save_interval_epochs == 0
)
is_epoch_save_interval
or (
(save_locally or save_to_NFS)
and num_updates >= cfg.dataset.validate_after_updates
and was_successful_step
is_successful_update_local_save_interval
or is_successful_update_save_interval
)
or should_stop
)
do_validate = (
should_stop
or (
cfg.dataset.validate_interval_updates > 0
and num_updates > 0
and num_updates % cfg.dataset.validate_interval_updates == 0
and was_successful_step
is_successful_update_validate_interval
and num_updates >= cfg.dataset.validate_after_updates
)
) and not cfg.dataset.disable_validation

Expand All @@ -494,7 +512,7 @@ def validate_and_save(
training_finished=should_stop,
async_callback_fn=functools.partial(
post_checkpoint_callback, cfg, num_updates, should_stop
)
),
)

valid_losses = [None]
Expand Down