"""This file includes CosineDecayRestarts learning rate scheduler, which is
not supported by current tensorflow version I used. This version is copied from
Tensorflow v2.8.0.
(https://github.com/keras-team/keras/blob/d8fcb9d4d4dad45080ecfdd575483653028f8eda/keras/optimizer_v2/learning_rate_schedule.py#L646-L768)
"""
import tensorflow as tf
import math
class CosineDecayRestarts(tf.keras.optimizers.schedules.LearningRateSchedule):
  """A LearningRateSchedule that uses a cosine decay schedule with restarts.

  See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),
  SGDR: Stochastic Gradient Descent with Warm Restarts.

  When training a model, it is often useful to lower the learning rate as
  the training progresses. This schedule applies a cosine decay function with
  restarts to an optimizer step, given a provided initial learning rate.
  It requires a `step` value to compute the decayed learning rate. You can
  just pass a TensorFlow variable that you increment at each training step.

  The schedule is a 1-arg callable that produces a decayed learning
  rate when passed the current optimizer step. This can be useful for changing
  the learning rate value across different invocations of optimizer functions.

  The learning rate multiplier first decays from 1 to `alpha` over
  `first_decay_steps` steps. Then a warm restart is performed. Each new warm
  restart runs for `t_mul` times more steps and starts from `m_mul` times the
  previous restart's initial learning rate.

  Example usage:
  ```python
  first_decay_steps = 1000
  lr_decayed_fn = (
      tf.keras.optimizers.schedules.CosineDecayRestarts(
          initial_learning_rate,
          first_decay_steps))
  ```

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate. The learning rate schedule is also serializable and
  deserializable using `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as `initial_learning_rate`.
  """
  def __init__(
      self,
      initial_learning_rate,
      first_decay_steps,
      t_mul=2.0,
      m_mul=1.0,
      alpha=0.0,
      name=None):
    """Applies cosine decay with restarts to the learning rate.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number. The initial learning rate.
      first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python
        number. Number of steps to decay over in the first period.
      t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
        Used to derive the number of iterations in the i-th period.
      m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
        Used to derive the initial learning rate of the i-th period.
      alpha: A scalar `float32` or `float64` `Tensor` or a Python number.
        Minimum learning rate value as a fraction of `initial_learning_rate`.
      name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
    """
    super(CosineDecayRestarts, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.first_decay_steps = first_decay_steps
    self._t_mul = t_mul
    self._m_mul = m_mul
    self.alpha = alpha
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or "SGDRDecay") as name:
      initial_learning_rate = tf.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      dtype = initial_learning_rate.dtype
      first_decay_steps = tf.cast(self.first_decay_steps, dtype)
      alpha = tf.cast(self.alpha, dtype)
      t_mul = tf.cast(self._t_mul, dtype)
      m_mul = tf.cast(self._m_mul, dtype)

      global_step_recomp = tf.cast(step, dtype)
      # Fraction of the first decay period completed so far (can exceed 1.0
      # once restarts have occurred).
      completed_fraction = global_step_recomp / first_decay_steps

      def compute_step(completed_fraction, geometric=False):
        """Helper for `cond` operation."""
        if geometric:
          # Period lengths grow geometrically by `t_mul`; invert the
          # geometric-series sum to find the current restart index, then
          # rescale the remainder into a fraction of that period.
          i_restart = tf.floor(
              tf.math.log(1.0 - completed_fraction * (1.0 - t_mul)) /
              tf.math.log(t_mul))

          sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
          completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart

        else:
          # All periods have the same length, so the restart index is just
          # the integer part of the completed fraction.
          i_restart = tf.floor(completed_fraction)
          completed_fraction -= i_restart

        return i_restart, completed_fraction

      i_restart, completed_fraction = tf.cond(
          tf.equal(t_mul, 1.0),
          lambda: compute_step(completed_fraction, geometric=False),
          lambda: compute_step(completed_fraction, geometric=True))

      # Scale the amplitude of the cosine by `m_mul` for each completed
      # restart, then decay from that amplitude down to `alpha`.
      m_fac = m_mul**i_restart
      cosine_decayed = 0.5 * m_fac * (1.0 + tf.cos(
          tf.constant(math.pi, dtype=dtype) * completed_fraction))
      decayed = (1 - alpha) * cosine_decayed + alpha

      return tf.multiply(initial_learning_rate, decayed, name=name)

  def get_config(self):
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "first_decay_steps": self.first_decay_steps,
        "t_mul": self._t_mul,
        "m_mul": self._m_mul,
        "alpha": self.alpha,
        "name": self.name
    }
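

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the copied Keras source): the values below
# are illustrative assumptions. It mirrors the "Example usage" in the class
# docstring, wires the schedule into a Keras optimizer, and prints the
# learning rate around the first restart to show the warm-restart jump.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  initial_learning_rate = 0.1   # illustrative starting learning rate
  first_decay_steps = 1000      # length of the first decay period

  lr_schedule = CosineDecayRestarts(
      initial_learning_rate,
      first_decay_steps,
      t_mul=2.0,   # each restart period lasts twice as long as the previous
      m_mul=1.0,   # each restart begins from the full initial learning rate
      alpha=0.0)   # decay all the way to zero before each restart

  # The schedule can be passed directly to an optimizer as the learning rate.
  optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

  # Near the end of the first period the rate approaches
  # alpha * initial_learning_rate; at step `first_decay_steps` it jumps back
  # up because a warm restart begins.
  for step in (0, 500, 999, 1000, 1500):
    print(step, float(lr_schedule(step)))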