Skip to content

Commit

Permalink
change config parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangir-azerbayev committed Oct 25, 2023
1 parent 00124fc commit 2facca5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,11 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
Minimum value for learning rate. The scheduler clips values below this threshold.
"""

warmup_ratio: float = None
"""
Proportion of steps to warm up for
"""

warmup_iters: int = None
"""
Number of warmup iterations
Expand Down
8 changes: 7 additions & 1 deletion megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,10 +606,16 @@ def get_learning_rate_scheduler(optimizer, neox_args):

init_step = 0

assert not (neox_args.warmup_ratio and neox_args.warmup_iters)
if neox_args.warmup_ratio:
warmup_iters = neox_args.warmup_ratio * num_iters
else:
warmup_iters = neox_args.warmup_iters

lr_scheduler = AnnealingLR(
optimizer,
start_lr=neox_args.lr,
warmup_iter=neox_args.warmup_iters,
warmup_iter=warmup_iters,
total_iters=num_iters,
decay_style=neox_args.lr_decay_style,
last_iter=init_step,
Expand Down

0 comments on commit 2facca5

Please sign in to comment.