diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 8bcc01f3b..ea447faa5 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -351,20 +351,18 @@ def load_checkpoint( ): """Load a model checkpoint and return the iteration.""" if neox_args.deepspeed: - load_optim_and_scheduler = ( + load_optim= ( not neox_args.no_load_optim ) # TODO: These should be configured by separate args - if neox_args.finetune: - load_optim_and_scheduler = False if iteration is not None: tag = get_checkpoint_tag(iteration) else: tag = None checkpoint_name, state_dict = model.load_checkpoint( neox_args.load, - load_optimizer_states=load_optim_and_scheduler, - load_lr_scheduler_states=load_optim_and_scheduler, - load_module_only=not load_optim_and_scheduler, + load_optimizer_states=load_optim, + load_lr_scheduler_states=False, + load_module_only=not load_optim, tag=tag, ) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 77f4bcd84..b257800e4 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -451,6 +451,11 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate): Minimum value for learning rate. The scheduler clips values below this threshold. """ + warmup_ratio: float = None + """ + Proportion of steps to warm up for + """ + warmup_iters: int = None """ Number of warmup iterations diff --git a/megatron/training.py b/megatron/training.py index 779b15839..a5bfdb7b7 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -606,10 +606,16 @@ def get_learning_rate_scheduler(optimizer, neox_args): init_step = 0 + assert not (neox_args.warmup_ratio and neox_args.warmup_iters) + if neox_args.warmup_ratio: + warmup_iters = neox_args.warmup_ratio * num_iters + else: + warmup_iters = neox_args.warmup_iters + lr_scheduler = AnnealingLR( optimizer, start_lr=neox_args.lr, - warmup_iter=neox_args.warmup_iters, + warmup_iter=warmup_iters, total_iters=num_iters, decay_style=neox_args.lr_decay_style, last_iter=init_step,