Merge pull request #18 from FlyingPumba/linear-lr
Add option for linear decrease of LR
cybershiptrooper authored Sep 15, 2024
2 parents aea84c7 + e6d278c commit 04c5c4c
Showing 1 changed file with 21 additions and 0 deletions.
iit/model_pairs/base_model_pair.py
@@ -261,9 +261,30 @@ def train(
                 scheduler_kwargs['patience'] = 10
             if 'factor' not in scheduler_kwargs:
                 scheduler_kwargs['factor'] = 0.1
 
             print(f"Setting up ReduceLROnPlateau scheduler ({mode}): {scheduler_kwargs}")
             lr_scheduler = scheduler_cls(optimizer, mode=mode, **scheduler_kwargs)
+        elif scheduler_cls == t.optim.lr_scheduler.LambdaLR:
+            # The default behaviour is to linearly reduce the learning rate to 2e-4
+            initial_lr = training_args["lr"]
+            final_lr = 2e-4
+            if 'final_lr' in scheduler_kwargs:
+                final_lr = scheduler_kwargs['final_lr']
+                del scheduler_kwargs['final_lr']
+
+            if "lr_lambda" not in scheduler_kwargs:
+                def linear_lr(step: int) -> float:
+                    return 1 - (step / epochs) * (1 - final_lr / initial_lr)
+
+                scheduler_kwargs["lr_lambda"] = linear_lr
+
+            print(f"Setting up LambdaLR scheduler: {scheduler_kwargs}")
+            lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
         elif scheduler_cls:
             print(f"Setting up {scheduler_cls} scheduler: {scheduler_kwargs}")
             lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
         else:
             print("No LR scheduler set up")
 
         if use_wandb and not wandb.run:
             wandb.init(project="iit", name=wandb_name_suffix,
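For reference, the added lr_lambda scales the initial learning rate by 1 at step 0 and by final_lr / initial_lr at step == epochs, so the rate falls linearly from training_args["lr"] to final_lr over the run. Below is a minimal standalone sketch of that behaviour; the concrete numbers (initial_lr=1e-3, final_lr=2e-4, epochs=100) are illustrative assumptions, not values from the commit.

import torch as t

initial_lr, final_lr, epochs = 1e-3, 2e-4, 100  # assumed example values
param = t.nn.Parameter(t.zeros(1))
optimizer = t.optim.SGD([param], lr=initial_lr)

def linear_lr(step: int) -> float:
    # Same multiplier as the committed helper: 1.0 at step 0,
    # final_lr / initial_lr once step reaches epochs.
    return 1 - (step / epochs) * (1 - final_lr / initial_lr)

lr_scheduler = t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_lr)
for _ in range(epochs):
    optimizer.step()       # no gradients here, so the step is a no-op; only the LR matters
    lr_scheduler.step()

print(optimizer.param_groups[0]["lr"])  # ~2e-4 after `epochs` scheduler steps

Note that any final_lr passed via scheduler_kwargs is popped before the kwargs reach LambdaLR, since LambdaLR itself does not accept that argument.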
