diff --git a/megatron/training.py b/megatron/training.py
index 03491f70a..28b73e6f4 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -665,9 +665,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
             )
         else:
             model.module.set_batch_fn(
-                partial(
-                    get_batch_sequential, neox_args=neox_args
-                )
+                partial(get_batch_sequential, neox_args=neox_args)
             )
 
     else:
@@ -687,6 +685,11 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
         else:
             neox_args.iteration = 0
 
+    # need this for correct lr scheduling resume from ckpt
+    lr_scheduler.optimizer = model.optimizer
+    lr_scheduler.param_groups = model.optimizer.param_groups
+    lr_scheduler.model = model
+
     return model, optimizer, lr_scheduler
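
Context for the second hunk (not part of the patch): the re-binding is presumably needed because, after deepspeed.initialize and a checkpoint load, model.optimizer is the optimizer the engine actually steps, while a scheduler still holding a reference to a stale optimizer object would keep writing learning rates into param groups the engine no longer reads. The sketch below illustrates that failure mode and the re-pointing fix with plain PyTorch; ToyScheduler and wrap_optimizer are hypothetical stand-ins, not GPT-NeoX or DeepSpeed code.

```python
# Minimal sketch (assumed names, not GPT-NeoX/DeepSpeed code): a scheduler that
# caches optimizer/param_groups at construction silently stops affecting training
# once the engine swaps in a rebuilt optimizer, unless it is re-pointed.
import torch


class ToyScheduler:
    """Writes a decayed lr into whatever param_groups it currently holds."""

    def __init__(self, optimizer, start_lr=0.1, decay=0.9):
        self.optimizer = optimizer
        self.param_groups = optimizer.param_groups  # cached at construction time
        self.start_lr = start_lr
        self.decay = decay
        self.num_steps = 0

    def step(self):
        self.num_steps += 1
        lr = self.start_lr * (self.decay ** self.num_steps)
        for group in self.param_groups:
            group["lr"] = lr


def wrap_optimizer(optimizer):
    """Stand-in for an engine that rebuilds the optimizer (fresh param_groups)."""
    params = [p for group in optimizer.param_groups for p in group["params"]]
    return torch.optim.SGD(params, lr=0.1)


params = [torch.nn.Parameter(torch.zeros(2))]
base_opt = torch.optim.SGD(params, lr=0.1)
scheduler = ToyScheduler(base_opt)      # built against the original optimizer
live_opt = wrap_optimizer(base_opt)     # what the "engine" actually steps

scheduler.step()
print(live_opt.param_groups[0]["lr"])   # still 0.1: the live optimizer never saw the decay

# The re-pointing mirrored by the patch: bind the scheduler to the live optimizer.
scheduler.optimizer = live_opt
scheduler.param_groups = live_opt.param_groups
scheduler.step()
print(live_opt.param_groups[0]["lr"])   # ~0.081: now follows the schedule
```

Re-binding param_groups as well as optimizer matters when the scheduler caches the group list at construction time, as the toy class above does; re-binding only self.optimizer would leave the cached list stale.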