diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index c70efe7c6..d5d2640c9 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -82,7 +82,8 @@ def get_lr(self): ) elif self.decay_style == "exponential": # exp(-0.693) = 1/2 - lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) + end_iter = self.end_iter - self.warmup_iter + lr = self.start_lr * math.exp(-0.693 * num_iters_ / end_iter) else: lr = self.start_lr return max(lr, self.min_lr)