diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 639a5daf2..7dec66da2 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -14,14 +14,19 @@ LR Scheduler Arguments

     Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'.

-
 - **lr_decay_iters**: int

     Default = None

-    Number of iterations to decay learning rate over, If None defaults to --train-iters
+    Number of iterations to decay learning rate over. If None, defaults to
+    --train-iters or the equivalent inferred value from train_epochs.
+
+- **lr_decay_fraction**: float
+    Default = None
+    Effective fraction of training over which to decay lr. Overrides lr_decay_iters.
+    Useful when specifying train_epochs.



 - **min_lr**: float
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index b6286a5eb..dbf5eb0bf 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -573,7 +573,13 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):

     lr_decay_iters: int = None
     """
-    Number of iterations to decay learning rate over, If None defaults to --train-iters
+    Number of iterations to decay learning rate over. If None, defaults to
+    --train-iters or the equivalent inferred value from train_epochs.
+    """
+
+    lr_decay_fraction: float = None
+    """
+    Effective fraction of training over which to decay lr. Overrides lr_decay_iters. Useful when specifying train_epochs.
     """

     min_lr: float = 0.0
diff --git a/megatron/training.py b/megatron/training.py
index 7150484c0..2b8af0470 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -1090,6 +1090,8 @@ def get_learning_rate_scheduler(optimizer, neox_args):
     # Add linear learning rate scheduler.
     if neox_args.lr_decay_iters is not None:
         num_iters = neox_args.lr_decay_iters
+    elif neox_args.lr_decay_fraction is not None:
+        num_iters = math.floor(neox_args.train_iters * neox_args.lr_decay_fraction)
     else:
         num_iters = neox_args.train_iters
     num_iters = max(1, num_iters)
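
For reference, a minimal sketch of how the scheduler length is resolved after this change. The standalone helper name resolve_lr_decay_iters is illustrative only (the patch does this inline in get_learning_rate_scheduler); the precedence and the max(1, ...) clamp mirror the patched code.

import math


def resolve_lr_decay_iters(train_iters, lr_decay_iters=None, lr_decay_fraction=None):
    """Resolution order as in the patched get_learning_rate_scheduler:
    an explicit lr_decay_iters is used as-is, otherwise lr_decay_fraction
    scales train_iters, otherwise decay runs over all of train_iters."""
    if lr_decay_iters is not None:
        num_iters = lr_decay_iters
    elif lr_decay_fraction is not None:
        num_iters = math.floor(train_iters * lr_decay_fraction)
    else:
        num_iters = train_iters
    # Scheduler length is clamped to at least one iteration.
    return max(1, num_iters)


# Example: with train_iters inferred from train_epochs as 10_000,
# lr_decay_fraction=0.9 yields a decay horizon of 9_000 iterations.
# resolve_lr_decay_iters(10_000, lr_decay_fraction=0.9) -> 9000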