adding cosine rewarmed scheduler #243

Open · wants to merge 6 commits into base: main
19 changes: 17 additions & 2 deletions open_lm/main.py
@@ -51,7 +51,7 @@
from open_lm.distributed import is_master, init_distributed_device, broadcast_object
from open_lm.logger import setup_logging
from open_lm.params import parse_args
-from open_lm.scheduler import cosine_lr, const_lr
+from open_lm.scheduler import cosine_lr, const_lr, cosine_rewarmed_lr
from open_lm.train import train_one_epoch
from open_lm.evaluate import evaluate_loop
from open_lm.file_utils import (
@@ -691,8 +691,23 @@ def main(args):
# args.lr_cooldown_end,
# args.force_min_lr,
)
elif args.lr_scheduler == "cosine-rewarmed":
resumed_step = (args.train_num_samples * start_epoch) // args.global_batch_size
scheduler = cosine_rewarmed_lr(
optimizer,
args.lr,
args.warmup,
total_steps,
args.lr_cooldown_end,
args.force_min_lr,
args.cosine_rewarmed_target_steps,
args.cosine_rewarmed_original_warmup,
resumed_step,
)
else:
raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const.")
raise ValueError(
f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, cosine-rewarned."
)

# determine if this worker should save logs and checkpoints. only do so if it is rank == 0
args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
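
For orientation (not part of the diff): resumed_step in the new branch above is just the number of optimizer steps the checkpointed run has already taken, recovered from the samples seen per epoch and the global batch size. A toy calculation with made-up numbers:

    # Illustrative values only, not from the PR.
    train_num_samples = 1_000_000   # args.train_num_samples (samples per epoch)
    global_batch_size = 256         # args.global_batch_size
    start_epoch = 2                 # epoch the checkpoint was saved after

    resumed_step = (train_num_samples * start_epoch) // global_batch_size
    print(resumed_step)  # 7812 optimizer steps already consumed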
18 changes: 15 additions & 3 deletions open_lm/params.py
@@ -237,9 +237,9 @@ def check_args(args):
if args.remote_sync_protocol != "s3":
raise ValueError("Sync protocol not supported when using resume latest.")

-    if args.lr_scheduler not in {"cosine", "const", "const-cooldown"}:
+    if args.lr_scheduler not in {"cosine", "const", "cosine-rewarmed"}:
        raise ValueError(
-            f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown."
+            f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, cosine-rewarmed."
        )
)

if args.experimental_meta_device:
@@ -391,7 +391,19 @@ def parse_args(args):
"--lr-scheduler",
type=str,
default="cosine",
help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine",
help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown), 'cosine-rewarmed'. Default: cosine",
)
parser.add_argument(
"--cosine-rewarmed-target-steps",
type=int,
default=None,
help="for cosine rewarmed, the target steps for the cosine schedule. Default: cosine",
)
parser.add_argument(
"--cosine-rewarmed-original-warmup",
type=int,
default=1000,
help="for cosine rewarmed, the original warmup steps. Default: 1000",
)
parser.add_argument(
"--lr-cooldown-end",
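
Taken together, the two new flags describe the run being resumed, while the existing --warmup flag controls the fresh re-warmup. A hypothetical flag combination (values are illustrative, not from the PR) that would be appended to an otherwise normal open_lm resume command:

    rewarmed_flags = [
        "--lr-scheduler", "cosine-rewarmed",
        "--cosine-rewarmed-target-steps", "100000",   # total steps of the original cosine schedule
        "--cosine-rewarmed-original-warmup", "5000",  # warmup steps used by that original schedule
        "--warmup", "200",                            # length of the re-warmup in the resumed run
    ]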
36 changes: 36 additions & 0 deletions open_lm/scheduler.py
@@ -10,6 +10,17 @@ def _warmup_lr(base_lr, warmup_length, step):
return base_lr * (step + 1) / warmup_length


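# Cosine-with-warmup helper: linear warmup to base_lr, then cosine decay from base_lr
# toward min_lr over the remaining steps, never dropping below force_min_lr.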
def _cosine_lr(step, base_lr, warmup_length, steps, min_lr, force_min_lr):
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
e = step - warmup_length
es = steps - warmup_length
lr = min_lr + 0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - min_lr)
lr = max(lr, force_min_lr)
return lr


def const_lr(optimizer, base_lr, warmup_length):
def _lr_adjuster(step):
if step < warmup_length:
@@ -63,3 +74,28 @@ def _lr_adjuster(step):
return lr

return _lr_adjuster


def cosine_rewarmed_lr(
    optimizer, base_lr, warmup_length, steps, min_lr, force_min_lr, target_steps, original_warmup, resumed_step
):
    def _lr_adjuster(step):
        # Re-express the global step relative to the point where this run resumed.
        step -= resumed_step
        # Learning rate that the original target_steps-long cosine schedule (with its
        # original_warmup) would be at when steps - warmup_length steps remain;
        # the re-warmup climbs back up to this value.
        new_base_lr = _cosine_lr(
            target_steps - steps + warmup_length, base_lr, original_warmup, target_steps, min_lr, force_min_lr
        )
        if step < warmup_length:
            # Linear re-warmup from near zero up to the recovered learning rate.
            lr = _warmup_lr(new_base_lr, warmup_length, step)
        else:
            # After the re-warmup, track the tail of a cosine decay aligned with the
            # original schedule, reaching min_lr by the end of this run.
            lr = _cosine_lr(
                target_steps - steps + step - warmup_length,
                base_lr,
                warmup_length,
                target_steps - warmup_length,
                min_lr,
                force_min_lr,
            )
        assign_learning_rate(optimizer, lr)
        return lr

    return _lr_adjuster
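
Not part of the PR: a small, hypothetical harness for eyeballing the shape this scheduler produces, assuming cosine_rewarmed_lr is importable from open_lm.scheduler and substituting a dummy optimizer for a real one. All numbers are made up.

    from open_lm.scheduler import cosine_rewarmed_lr

    class _DummyOpt:
        # Minimal stand-in exposing the param_groups interface assign_learning_rate expects.
        param_groups = [{"lr": 0.0}]

    sched = cosine_rewarmed_lr(
        _DummyOpt(),
        base_lr=3e-4,          # args.lr
        warmup_length=200,     # args.warmup: length of the re-warmup in this run
        steps=10_000,          # total_steps of the resumed run
        min_lr=3e-5,           # args.lr_cooldown_end
        force_min_lr=0.0,      # args.force_min_lr
        target_steps=100_000,  # length of the original cosine schedule being tracked
        original_warmup=1000,  # warmup used by that original schedule
        resumed_step=0,        # pretend the step counter starts at 0 for this run
    )

    for s in (0, 100, 200, 5_000, 10_000):
        print(s, sched(s))
    # Expected shape: the LR climbs linearly over the first 200 steps to roughly the value
    # the original 100k-step schedule would have 10k steps before its end, then follows
    # that schedule's tail down to min_lr at step 10_000.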