
Commit

add lr_decay_fraction
AI-WAIFU committed Sep 27, 2024
1 parent 6c85b3e commit 4e9551b
Showing 3 changed files with 16 additions and 3 deletions.
9 changes: 7 additions & 2 deletions configs/neox_arguments.md
@@ -14,14 +14,19 @@ LR Scheduler Arguments
Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'.



- **lr_decay_iters**: int

Default = None

Number of iterations to decay learning rate over, If None defaults to --train-iters
Number of iterations to decay learning rate over. If None, defaults to
--train-iters or the equivalent inferred value from train_epochs.

- **lr_decay_fraction**: float

Default = None

Effective fraction of training over which to decay lr. Overrides lr_decay_iters.
Useful when specifying train_epochs.

- **min_lr**: float

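Taken together, the documented precedence is: an explicit lr_decay_iters always wins; otherwise lr_decay_fraction scales the effective training length; otherwise decay runs over all of --train-iters. A small worked example with hypothetical values (the numbers are illustrative, not taken from this commit):

# Hypothetical settings, purely illustrative:
train_iters = 100_000        # e.g. inferred from train_epochs by the trainer
lr_decay_fraction = 0.9      # decay the lr over the first 90% of training

# Per the documentation above, this is equivalent to setting
# lr_decay_iters = floor(100_000 * 0.9) = 90_000 directly; lr_decay_fraction
# is ignored whenever lr_decay_iters is supplied explicitly.
effective_decay_iters = int(train_iters * lr_decay_fraction)   # -> 90_000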
8 changes: 7 additions & 1 deletion megatron/neox_arguments/neox_args.py
@@ -573,7 +573,13 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):

lr_decay_iters: int = None
"""
Number of iterations to decay learning rate over, If None defaults to --train-iters
Number of iterations to decay learning rate over. If None, defaults to
--train-iters or the equivalent value inferred from train_epochs.
"""

lr_decay_fraction: float = None
"""
Effective fraction of training over which to decay lr. Overrides lr_decay_iters. Useful when specifying train_epochs.
"""

min_lr: float = 0.0
2 changes: 2 additions & 0 deletions megatron/training.py
@@ -1090,6 +1090,8 @@ def get_learning_rate_scheduler(optimizer, neox_args):
    # Add linear learning rate scheduler.
    if neox_args.lr_decay_iters is not None:
        num_iters = neox_args.lr_decay_iters
    elif neox_args.lr_decay_fraction is not None:
        num_iters = math.floor(neox_args.train_iters * neox_args.lr_decay_fraction)
    else:
        num_iters = neox_args.train_iters
    num_iters = max(1, num_iters)
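The hunk above only shows the branch selection inside get_learning_rate_scheduler. A self-contained sketch of the same rule is given below; the helper name and standalone signature are assumptions for illustration, not part of the NeoX codebase:

import math
from typing import Optional

def effective_decay_iters(
    train_iters: int,
    lr_decay_iters: Optional[int] = None,
    lr_decay_fraction: Optional[float] = None,
) -> int:
    # Mirrors the precedence in the diffed code: an explicit lr_decay_iters
    # wins, then lr_decay_fraction scaled by train_iters, then the full
    # training length; the result is clamped to at least one iteration.
    if lr_decay_iters is not None:
        num_iters = lr_decay_iters
    elif lr_decay_fraction is not None:
        num_iters = math.floor(train_iters * lr_decay_fraction)
    else:
        num_iters = train_iters
    return max(1, num_iters)

assert effective_decay_iters(100_000, lr_decay_fraction=0.9) == 90_000
assert effective_decay_iters(100_000, lr_decay_iters=50_000, lr_decay_fraction=0.9) == 50_000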
