diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 9b4e73842..1f601bed7 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -255,7 +255,7 @@ def consume_deepy_args(cls):
             type=int,
         )
         group.add_argument(
-            "--warmup_iter",
+            "--warmup_iters",
             type=int,
         )
         group.add_argument(
@@ -270,6 +270,10 @@ def consume_deepy_args(cls):
             "--save",
             type=str,
         )
+        group.add_argument(
+            "--load",
+            type=str
+        )
         group.add_argument(
             "--log_dir",
             type=str,
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index aeac958a2..77f4bcd84 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -451,12 +451,7 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
     Minimum value for learning rate. The scheduler clips values below this threshold.
     """
 
-    warmup: float = None
-    """
-    Percentage of total iterations to warmup on (.01 = 1 percent of all training iters).
-    """
-
-    warmup_iter: int = None
+    warmup_iters: int = None
     """
     Number of warmup iterations
     """
diff --git a/megatron/training.py b/megatron/training.py
index e920725f8..779b15839 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -604,18 +604,12 @@ def get_learning_rate_scheduler(optimizer, neox_args):
     num_iters = max(1, num_iters)
-    assert not (neox_args.warmup_iter and neox_args.warmup)
-    if neox_args.warmup:
-        warmup_iter = neox_args.warmup*num_iters
-    else:
-        warmup_iter = neox_args.warmup_iter
-
     init_step = 0
     lr_scheduler = AnnealingLR(
         optimizer,
         start_lr=neox_args.lr,
-        warmup_iter=warmup_iter,
+        warmup_iter=neox_args.warmup_iters,
         total_iters=num_iters,
         decay_style=neox_args.lr_decay_style,
         last_iter=init_step,
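
For anyone migrating existing configs, here is a minimal sketch (not part of the diff above) of how the removed fractional `warmup` setting maps onto the new explicit `warmup_iters` value. It assumes the old behavior shown in the deleted `get_learning_rate_scheduler` code, where the warmup count was computed as `warmup * num_iters`; the helper name below is hypothetical.

```python
# Hypothetical helper, not part of this PR: translate the removed fractional
# `warmup` setting into the explicit `warmup_iters` count that AnnealingLR
# now receives directly via `neox_args.warmup_iters`.
def warmup_fraction_to_iters(warmup_fraction: float, total_iters: int) -> int:
    # Old behavior (removed above): warmup_iter = warmup * num_iters
    return max(1, int(warmup_fraction * total_iters))


# Example: a config that used `warmup: 0.01` over 320,000 decay iterations
# would now set `warmup_iters: 3200` instead.
print(warmup_fraction_to_iters(0.01, 320_000))  # -> 3200
```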