Commit 9ba50aa

Merge branch 'math-scaling-laws' of https://github.com/EleutherAI/gpt-neox into math-scaling-laws
zhangir-azerbayev committed Oct 26, 2023
2 parents 44c41e2 + 2facca5 commit 9ba50aa
Showing 3 changed files with 16 additions and 7 deletions.
10 changes: 4 additions & 6 deletions megatron/checkpointing.py
@@ -351,20 +351,18 @@ def load_checkpoint(
 ):
     """Load a model checkpoint and return the iteration."""
     if neox_args.deepspeed:
-        load_optim_and_scheduler = (
+        load_optim = (
             not neox_args.no_load_optim
         )  # TODO: These should be configured by separate args
-        if neox_args.finetune:
-            load_optim_and_scheduler = False
         if iteration is not None:
             tag = get_checkpoint_tag(iteration)
         else:
             tag = None
         checkpoint_name, state_dict = model.load_checkpoint(
             neox_args.load,
-            load_optimizer_states=load_optim_and_scheduler,
-            load_lr_scheduler_states=load_optim_and_scheduler,
-            load_module_only=not load_optim_and_scheduler,
+            load_optimizer_states=load_optim,
+            load_lr_scheduler_states=False,
+            load_module_only=not load_optim,
             tag=tag,
         )
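In effect, this hunk collapses checkpoint loading onto a single `load_optim` switch: LR-scheduler state is now never restored, a weights-only load is requested whenever optimizer state is skipped, and the old `neox_args.finetune` override goes away, so skipping optimizer state now goes through `no_load_optim` alone. A minimal sketch of the resulting flag mapping (the helper name `checkpoint_load_flags` is illustrative and not part of the codebase; the keyword names are the DeepSpeed `load_checkpoint` arguments used above):

# Illustrative sketch only: mirrors the flag logic in the hunk above.
def checkpoint_load_flags(no_load_optim: bool) -> dict:
    load_optim = not no_load_optim
    return {
        "load_optimizer_states": load_optim,  # optimizer state only when requested
        "load_lr_scheduler_states": False,    # scheduler state is never restored
        "load_module_only": not load_optim,   # weights-only load when skipping the optimizer
    }

# Skipping the optimizer implies a module-only (weights-only) load.
assert checkpoint_load_flags(no_load_optim=True) == {
    "load_optimizer_states": False,
    "load_lr_scheduler_states": False,
    "load_module_only": True,
}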
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -451,6 +451,11 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
     Minimum value for learning rate. The scheduler clips values below this threshold.
     """

+    warmup_ratio: float = None
+    """
+    Proportion of steps to warm up for
+    """
+
     warmup_iters: int = None
     """
     Number of warmup iterations
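The new `warmup_ratio` field gives a size-independent way to specify warmup; `megatron/training.py` below converts it to an iteration count and asserts it is not combined with `warmup_iters`. A hypothetical illustration of the two mutually exclusive settings (plain Python dicts standing in for a real NeoX config, which this page does not show):

# Hypothetical config fragments: the field names come from the diff above,
# but the dict form is illustrative, not a verbatim gpt-neox config.
schedule_by_ratio = {"warmup_ratio": 0.01}  # warm up over 1% of total iterations
schedule_by_iters = {"warmup_iters": 1430}  # warm up for a fixed iteration count
# Setting both would trip the assert added in megatron/training.py below.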
8 changes: 7 additions & 1 deletion megatron/training.py
@@ -606,10 +606,16 @@ def get_learning_rate_scheduler(optimizer, neox_args):

     init_step = 0

+    assert not (neox_args.warmup_ratio and neox_args.warmup_iters)
+    if neox_args.warmup_ratio:
+        warmup_iters = neox_args.warmup_ratio * num_iters
+    else:
+        warmup_iters = neox_args.warmup_iters
+
     lr_scheduler = AnnealingLR(
         optimizer,
         start_lr=neox_args.lr,
-        warmup_iter=neox_args.warmup_iters,
+        warmup_iter=warmup_iters,
         total_iters=num_iters,
         decay_style=neox_args.lr_decay_style,
         last_iter=init_step,
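A worked example of the new selection logic, assuming `num_iters` is the total training iteration count computed earlier in the surrounding function (the 143,000 figure is arbitrary and purely illustrative):

# Sketch of the warmup selection above, pulled out of get_learning_rate_scheduler.
warmup_ratio = 0.01      # stands in for neox_args.warmup_ratio
warmup_iters_arg = None  # stands in for neox_args.warmup_iters
num_iters = 143_000      # assumed total iteration count, illustrative only

assert not (warmup_ratio and warmup_iters_arg)  # mutually exclusive, as above
if warmup_ratio:
    warmup_iters = warmup_ratio * num_iters  # 0.01 * 143000
else:
    warmup_iters = warmup_iters_arg

print(warmup_iters)  # 1430.0

Note that `warmup_ratio * num_iters` yields a float; that is presumably harmless if `AnnealingLR` only compares against and divides by `warmup_iter`, but an explicit `int(...)` cast would make the intent clearer.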
