Merge pull request #77 from kozistr/feature/adan-optimizer
[Feature, Fix] Adan optimizer
kozistr authored Sep 2, 2022
2 parents 34fd10b + 504f13c commit f51dead
Showing 3 changed files with 35 additions and 16 deletions.
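For context, a minimal usage sketch of the optimizer this PR touches (hypothetical model and data; it assumes only the constructor arguments visible in the adan.py diff below, including the weight_decouple flag added here):

    import torch
    import torch.nn.functional as F

    from pytorch_optimizer import Adan

    model = torch.nn.Linear(10, 1)  # hypothetical toy model
    optimizer = Adan(
        model.parameters(),
        lr=1e-3,
        weight_decay=1e-3,
        weight_decouple=True,  # flag introduced in this PR
        use_gc=True,
    )

    x, y = torch.randn(8, 10), torch.randn(8, 1)  # hypothetical data
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()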
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "1.3.1"
version = "1.3.2"
description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
43 changes: 30 additions & 13 deletions pytorch_optimizer/adan.py
@@ -1,3 +1,5 @@
import math

import torch
from torch.optim.optimizer import Optimizer

@@ -8,7 +10,7 @@

class Adan(Optimizer, BaseOptimizer):
"""
Reference : x
Reference : https://github.com/sail-sg/Adan/blob/main/adan.py
Example :
from pytorch_optimizer import Adan
...
@@ -27,21 +29,24 @@ def __init__(
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.98, 0.92, 0.99),
weight_decay: float = 0.02,
weight_decay: float = 0.0,
weight_decouple: bool = False,
use_gc: bool = False,
eps: float = 1e-16,
eps: float = 1e-8,
):
"""Adan optimizer
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
:param weight_decay: float. weight decay (L2 penalty)
:param weight_decouple: bool. decoupled weight decay
:param use_gc: bool. use gradient centralization
:param eps: float. term added to the denominator to improve numerical stability
"""
self.lr = lr
self.betas = betas
self.weight_decay = weight_decay
self.weight_decouple = weight_decouple
self.use_gc = use_gc
self.eps = eps

@@ -52,6 +57,7 @@ def __init__(
betas=betas,
eps=eps,
weight_decay=weight_decay,
weight_decouple=weight_decouple,
)
super().__init__(params, defaults)

@@ -69,7 +75,7 @@ def reset(self):

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_var'] = torch.zeros_like(p)
state['exp_avg_diff'] = torch.zeros_like(p)
state['exp_avg_nest'] = torch.zeros_like(p)
state['previous_grad'] = torch.zeros_like(p)

@@ -93,29 +99,40 @@ def step(self, closure: CLOSURE = None) -> LOSS:
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_var'] = torch.zeros_like(p)
state['exp_avg_diff'] = torch.zeros_like(p)
state['exp_avg_nest'] = torch.zeros_like(p)
state['previous_grad'] = torch.zeros_like(p)

exp_avg, exp_avg_var, exp_avg_nest = state['exp_avg'], state['exp_avg_var'], state['exp_avg_nest']
exp_avg, exp_avg_diff, exp_avg_nest = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_nest']
prev_grad = state['previous_grad']

state['step'] += 1
beta1, beta2, beta3 = group['betas']

bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']
bias_correction3 = 1.0 - beta3 ** state['step']

if self.use_gc:
grad = centralize_gradient(grad, gc_conv_only=False)

grad_diff = grad - prev_grad
state['previous_grad'] = grad.clone()

exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_var.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
exp_avg_nest.mul_(beta3).add_((grad + beta2 * grad_diff) ** 2, alpha=1.0 - beta3)
update = grad + beta2 * grad_diff

step_size = group['lr'] / exp_avg_nest.add_(self.eps).sqrt_()

p.sub_(step_size * (exp_avg + beta2 * exp_avg_var))
p.div_(1.0 + group['weight_decay'])
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
exp_avg_nest.mul_(beta3).addcmul_(update, update, value=1.0 - beta3)

de_nom = (exp_avg_nest.sqrt_() / math.sqrt(bias_correction3)).add_(self.eps)
perturb = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(de_nom)

if group['weight_decouple']:
p.mul_(1.0 - group['lr'] * group['weight_decay'])
p.add_(perturb, alpha=-group['lr'])
else:
p.add_(perturb, alpha=-group['lr'])
p.div_(1.0 + group['lr'] * group['weight_decay'])

return loss
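
For readers skimming the diff, the new step() amounts to the following update rule (a plain restatement of the code above, writing t for state['step'], wd for weight_decay, g_t for the optionally centralized gradient, and d_t = g_t - g_{t-1}):

    m_t = beta1 * m_{t-1} + (1 - beta1) * g_t                      # exp_avg
    v_t = beta2 * v_{t-1} + (1 - beta2) * d_t                      # exp_avg_diff
    n_t = beta3 * n_{t-1} + (1 - beta3) * (g_t + beta2 * d_t)**2   # exp_avg_nest, element-wise

    denom   = sqrt(n_t) / sqrt(1 - beta3**t) + eps
    perturb = (m_t / (1 - beta1**t) + beta2 * v_t / (1 - beta2**t)) / denom

    # weight_decouple=True:   p = p * (1 - lr * wd) - lr * perturb
    # weight_decouple=False:  p = (p - lr * perturb) / (1 + lr * wd)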
6 changes: 4 additions & 2 deletions tests/test_optimizers.py
@@ -80,8 +80,9 @@
(AdaPNM, {'lr': 3e-1, 'weight_decay': 1e-3, 'amsgrad': False}, 500),
(Nero, {'lr': 5e-1}, 200),
(Nero, {'lr': 5e-1, 'constraints': False}, 200),
(Adan, {'lr': 2e-1}, 200),
(Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True}, 500),
(Adan, {'lr': 5e-1}, 300),
(Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True}, 300),
(Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True, 'weight_decouple': True}, 300),
]

ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
@@ -163,6 +164,7 @@ def test_safe_f16_optimizers(optimizer_fp16_config):
or (optimizer_name == 'RaLamb' and 'pre_norm' in config)
or (optimizer_name == 'PNM')
or (optimizer_name == 'Nero')
or (optimizer_name == 'Adan' and 'weight_decay' not in config)
):
return True
