diff --git a/megatron/training.py b/megatron/training.py index 0f9f7a0e0..548f81cb0 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -686,7 +686,9 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): neox_args.iteration = 0 # need this for correct lr scheduling resume from ckpt - lr_scheduler.optimizer = model.optimizer + # but it will not exist if this is being called for inference + if lr_scheduler is not None: + lr_scheduler.optimizer = model.optimizer return model, optimizer, lr_scheduler