optim_schedule.py (forked from QData/C-Tran)
'''A wrapper class for the optimizer.'''
from torch.optim.lr_scheduler import LambdaLR


class WarmupLinearSchedule(LambdaLR):
    """Linear warmup followed by linear decay.

    Linearly increases the learning-rate multiplier from 0 to 1 over
    `warmup_steps` training steps, then linearly decreases it from 1 to 0
    over the remaining `t_total - warmup_steps` steps.
    """

    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super(WarmupLinearSchedule, self).__init__(
            optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        # Warmup phase: the multiplier grows linearly from 0 to 1.
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        # Decay phase: the multiplier shrinks linearly from 1 to 0,
        # clamped at 0 in case training runs past t_total.
        return max(0.0, float(self.t_total - step) / float(
            max(1.0, self.t_total - self.warmup_steps)))
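

# A minimal usage sketch (not part of the original file): the dummy
# parameter, base lr, and step counts below are illustrative assumptions,
# not values taken from C-Tran. It shows the multiplier ramping up over
# warmup_steps and then decaying linearly toward 0 by t_total.
if __name__ == "__main__":
    import torch
    from torch.optim import Adam

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = Adam([param], lr=0.1)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=4, t_total=10)

    for step in range(10):
        optimizer.step()    # update parameters first,
        scheduler.step()    # then advance the schedule
        print(step, optimizer.param_groups[0]['lr'])
    # lr ramps 0.025 -> 0.05 -> 0.075 -> 0.1 during warmup,
    # then decays linearly, reaching 0 at step t_total.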