Merge pull request #20 from kozistr/feature/adabound-optimizer
[Feature] Implement AdaBound/AdaBoundW optimizers
kozistr authored Sep 22, 2021
2 parents 29a9dd3 + c3eab9b commit 278c29e
Showing 5 changed files with 321 additions and 3 deletions.
17 changes: 17 additions & 0 deletions README.md
@@ -33,6 +33,7 @@ for input, output in data:

| Optimizer | Description | Official Code | Paper |
| :---: | :---: | :---: | :---: |
| AdaBound | *Adaptive Gradient Methods with Dynamic Bound of Learning Rate* | [github](https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py) | [https://openreview.net/forum?id=Bkg3g2R9FX](https://openreview.net/forum?id=Bkg3g2R9FX) |
| AdaHessian | *An Adaptive Second Order Optimizer for Machine Learning* | [github](https://github.com/amirgholami/adahessian) | [https://arxiv.org/abs/2006.00719](https://arxiv.org/abs/2006.00719) |
| AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | [github](https://github.com/clovaai/AdamP) | [https://arxiv.org/abs/2006.08217](https://arxiv.org/abs/2006.08217) |
| MADGRAD | *A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization* | [github](https://github.com/facebookresearch/madgrad) | [https://arxiv.org/abs/2101.11075](https://arxiv.org/abs/2101.11075) |
@@ -336,6 +337,22 @@ Acceleration via Fractal Learning Rate Schedules

</details>

<details>

<summary>AdaBound</summary>

```
@inproceedings{Luo2019AdaBound,
  author    = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
  title     = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
  booktitle = {Proceedings of the 7th International Conference on Learning Representations},
  month     = {May},
  year      = {2019},
  address   = {New Orleans, Louisiana}
}
```

</details>
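
As a quick intuition for the "dynamic bound" in the title, here is a small illustrative sketch (an editorial aside, not part of this commit) that evaluates the clipping bounds the implementation below applies to the Adam-style step size, using the library's default `final_lr` and `gamma`:

```python
# Illustrative only: the per-step clipping bounds AdaBound applies to the
# Adam-style step size, using the same formulas as the optimizer below.
final_lr, gamma = 0.1, 1e-3  # defaults from AdaBound.__init__

for step in (1, 10, 100, 1_000, 10_000):
    lower = final_lr * (1.0 - 1.0 / (gamma * step + 1.0))
    upper = final_lr * (1.0 + 1.0 / (gamma * step))
    print(f'step={step:>6} lower={lower:.6f} upper={upper:.6f}')

# Early on the interval (lower, upper) is very wide (Adam-like behaviour);
# as step grows both bounds converge to final_lr (SGD-like behaviour).
```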

## Author

3 changes: 2 additions & 1 deletion pytorch_optimizer/__init__.py
@@ -1,3 +1,4 @@
from pytorch_optimizer.adabound import AdaBound, AdaBoundW
from pytorch_optimizer.adahessian import AdaHessian
from pytorch_optimizer.adamp import AdamP
from pytorch_optimizer.agc import agc
@@ -10,4 +11,4 @@
from pytorch_optimizer.ranger21 import Ranger21
from pytorch_optimizer.sgdp import SGDP

__VERSION__ = '0.0.3'
__VERSION__ = '0.0.4'
298 changes: 298 additions & 0 deletions pytorch_optimizer/adabound.py
@@ -0,0 +1,298 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.types import (
    BETAS,
    CLOSURE,
    DEFAULT_PARAMETERS,
    LOSS,
    PARAMS,
    STATE,
)


class AdaBound(Optimizer):
"""
Reference : https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py
Example :
from pytorch_optimizer import AdaBound
...
model = YourModel()
optimizer = AdaBound(model.parameters())
...
for input, output in data:
optimizer.zero_grad()
loss = loss_function(output, model(input))
loss.backward()
optimizer.step()
"""

    def __init__(
        self,
        params: PARAMS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        final_lr: float = 0.1,
        gamma: float = 1e-3,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        amsbound: bool = False,
    ):
        """AdaBound optimizer
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of the gradient and its square
        :param final_lr: float. final (SGD-like) learning rate that the bound functions converge to
        :param gamma: float. convergence speed of the bound functions
        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
        :param amsbound: bool. whether to use the AMSBound variant
        """
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay

        self.check_valid_parameters()

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = [group['lr'] for group in self.param_groups]

    def check_valid_parameters(self):
        if 0.0 > self.lr:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if 0.0 > self.eps:
            raise ValueError(f'Invalid eps : {self.eps}')
        if 0.0 > self.weight_decay:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')

    def __setstate__(self, state: STATE):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsbound', False)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBound does not support sparse gradients'
                    )

                amsbound = group['amsbound']

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsbound:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsbound:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = (
                    group['lr']
                    * math.sqrt(bias_correction2)
                    / bias_correction1
                )

                # Clip the Adam-style step size into the dynamic bounds,
                # which converge towards final_lr as the step count grows.
                final_lr = group['final_lr'] * group['lr'] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group['gamma'] * state['step'] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group['gamma'] * state['step'])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss
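
The `AdaBoundW` class that follows differs from `AdaBound` only in how weight decay enters the update. A minimal sketch of the two conventions on a single tensor (an editorial aside; names and values are made up for illustration):

```python
import torch

lr, weight_decay = 1e-3, 1e-2
p, grad = torch.randn(3), torch.randn(3)

# AdaBound above: classic L2 penalty, folded into the gradient before the
# exponential moving averages are updated.
l2_grad = grad + weight_decay * p

# AdaBoundW below: decoupled weight decay, applied multiplicatively to the
# weights themselves (the implementation scales by the group's base lr),
# while the gradient driving the adaptive update is left untouched.
p_decayed = p * (1 - lr * weight_decay)
```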


class AdaBoundW(Optimizer):
"""
Reference : https://github.com/Luolc/AdaBound
Example :
from pytorch_optimizer import AdaBoundW
...
model = YourModel()
optimizer = AdaBoundW(model.parameters())
...
for input, output in data:
optimizer.zero_grad()
loss = loss_function(output, model(input))
loss.backward()
optimizer.step()
"""

    def __init__(
        self,
        params: PARAMS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        final_lr: float = 0.1,
        gamma: float = 1e-3,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        amsbound: bool = False,
    ):
        """AdaBound optimizer with decoupled weight decay
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of the gradient and its square
        :param final_lr: float. final (SGD-like) learning rate that the bound functions converge to
        :param gamma: float. convergence speed of the bound functions
        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (decoupled, applied directly to the weights)
        :param amsbound: bool. whether to use the AMSBound variant
        """
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay

        self.check_valid_parameters()

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = [group['lr'] for group in self.param_groups]

    def check_valid_parameters(self):
        if 0.0 > self.lr:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if 0.0 > self.eps:
            raise ValueError(f'Invalid eps : {self.eps}')
        if 0.0 > self.weight_decay:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')

    def __setstate__(self, state: STATE):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsbound', False)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group['params']:
                if p.grad is None:
                    continue

                # Decoupled weight decay: shrink the weights directly instead
                # of adding an L2 term to the gradient.
                p.data.mul_(1 - base_lr * group['weight_decay'])

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBound does not support sparse gradients'
                    )

                amsbound = group['amsbound']

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsbound:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsbound:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = (
                    group['lr']
                    * math.sqrt(bias_correction2)
                    / bias_correction1
                )

                # Clip the Adam-style step size into the dynamic bounds,
                # which converge towards final_lr as the step count grows.
                final_lr = group['final_lr'] * group['lr'] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group['gamma'] * state['step'] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group['gamma'] * state['step'])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss
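
To round off the new module, a hedged usage sketch: the toy model and data here are invented for illustration, while the import path, class names, and constructor arguments come from this commit.

```python
import torch
from pytorch_optimizer import AdaBound  # AdaBoundW exposes the same signature

model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()

optimizer = AdaBound(
    model.parameters(),
    lr=1e-3,
    final_lr=0.1,      # SGD-like rate the bounds converge to
    gamma=1e-3,        # convergence speed of the bound functions
    weight_decay=1e-2,
    amsbound=True,     # use the AMSBound variant
)

for _ in range(5):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```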
2 changes: 1 addition & 1 deletion pytorch_optimizer/adamp.py
@@ -21,7 +21,7 @@ class AdamP(Optimizer):
from pytorch_optimizer import AdamP
...
model = YourModel()
optimizer = AdaHessian(model.parameters())
optimizer = AdamP(model.parameters())
...
for input, output in data:
optimizer.zero_grad()
4 changes: 3 additions & 1 deletion setup.py
@@ -34,7 +34,6 @@ def read_version() -> str:
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Operating System :: OS Independent',
@@ -55,6 +54,9 @@ def read_version() -> str:
'chebyshev_schedule',
'lookahead',
'radam',
'adabound',
'adaboundw',
'adahessian',
]
)
