Merge pull request #95 from kozistr/refactor/optimizers
[Feature] Implement & Optimize a few optimizer options
kozistr authored Jan 28, 2023
2 parents f6baa63 + 27d6b99 commit ce56167
Showing 24 changed files with 234 additions and 168 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "2.2.0"
version = "2.2.1"
description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
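The only substantive change here is the version bump from 2.2.0 to 2.2.1. As a quick sanity check after upgrading, the installed release can be read from package metadata; a minimal sketch, assuming the distribution is installed under the name declared above:

```python
# Minimal sketch: confirm the installed release matches the bumped version.
# Assumes the distribution is installed under the name declared in pyproject.toml.
from importlib.metadata import version

print(version("pytorch_optimizer"))  # expected to print 2.2.1 after this release
```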
5 changes: 5 additions & 0 deletions pytorch_optimizer/base/optimizer.py
@@ -92,6 +92,11 @@ def validate_update_frequency(update_frequency: int):
if update_frequency < 1:
raise ValueError(f'[-] update_frequency {update_frequency} must be positive')

@staticmethod
def validate_norm(norm: float):
if norm < 0.0:
raise ValueError(f'[-] norm {norm} must be positive')

@abstractmethod
def validate_parameters(self):
raise NotImplementedError
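The new validate_norm guard is wired into Adan's validate_parameters further down in this diff. A minimal sketch of how it surfaces to users, assuming the top-level Adan export and that the constructor runs validate_parameters() like the package's other optimizers:

```python
import torch
from pytorch_optimizer import Adan  # assumed top-level export

# A negative max_grad_norm should be rejected by the new validate_norm() check,
# assuming Adan's constructor calls validate_parameters() like the other optimizers here.
params = [torch.nn.Parameter(torch.zeros(2, 2))]

try:
    Adan(params, max_grad_norm=-1.0)
except ValueError as err:
    print(err)  # e.g. "[-] norm -1.0 must be positive"
```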
15 changes: 7 additions & 8 deletions pytorch_optimizer/optimizer/adabelief.py
Expand Up @@ -94,6 +94,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
loss = closure()

for group in self.param_groups:
beta1, beta2 = group['betas']
if self.rectify:
n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0

for p in group['params']:
if p.grad is None:
continue
@@ -128,18 +132,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

state['step'] += 1
beta1, beta2 = group['betas']

bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']

exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
grad_residual = grad - exp_avg
exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)

exp_avg_var = exp_avg_var.add_(group['eps'])
exp_avg_var.add_(group['eps'])
if group['amsgrad']:
exp_avg_var = torch.max(state['max_exp_avg_var'], exp_avg_var)
torch.max(state['max_exp_avg_var'], exp_avg_var, out=exp_avg_var)

de_nom = (exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

@@ -155,12 +157,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
n_sma_max = 2 / (1 - beta2) - 1
n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = n_sma

if n_sma >= self.n_sma_threshold:
rt = math.sqrt(
step_size = math.sqrt(
(1 - beta2_t)
* (n_sma - 4)
/ (n_sma_max - 4)
@@ -169,8 +170,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
* n_sma_max
/ (n_sma_max - 2)
)

step_size = rt
if not group['adamd_debias_term']:
step_size /= bias_correction1
elif self.degenerated_to_sgd:
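The AdaBelief changes hoist the betas and n_sma_max out of the per-parameter loop, switch the eps and amsgrad updates to in-place ops, and fold rt directly into step_size. A standalone sketch of the rectification arithmetic as it now reads, with illustrative values that are not taken from the repo:

```python
import math

# Illustrative values only; n_sma_threshold = 5 is an assumption, not read from this diff.
beta2, step, n_sma_threshold = 0.999, 100, 5

n_sma_max = 2.0 / (1.0 - beta2) - 1.0            # hoisted: depends only on beta2
beta2_t = beta2 ** step
n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)

if n_sma >= n_sma_threshold:
    step_size = math.sqrt(
        (1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
    )
```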
8 changes: 3 additions & 5 deletions pytorch_optimizer/optimizer/adabound.py
@@ -92,6 +92,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
loss = closure()

for group, base_lr in zip(self.param_groups, self.base_lrs):
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue
@@ -112,21 +113,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
state['step'] += 1
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

if group['weight_decay'] != 0:
if group['weight_decay'] > 0.0:
if self.weight_decouple:
p.mul_(
1.0 - (group['weight_decay'] if self.fixed_decay else group['lr'] * group['weight_decay'])
)
else:
grad.add_(p, alpha=group['weight_decay'])

beta1, beta2 = group['betas']

exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

if group['amsbound']:
exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
torch.max(state['max_exp_avg_sq'], exp_avg_sq, out=exp_avg_sq)

de_nom = exp_avg_sq.sqrt().add_(group['eps'])

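Besides hoisting the betas and tightening the weight-decay check to > 0.0, the amsbound branch now updates the running maximum in place via the out= argument instead of rebinding a fresh tensor. A small standalone illustration of that pattern:

```python
import torch

# Element-wise maximum written back into exp_avg_sq, avoiding an extra allocation.
max_exp_avg_sq = torch.tensor([1.0, 5.0, 2.0])
exp_avg_sq = torch.tensor([3.0, 4.0, 2.5])

torch.max(max_exp_avg_sq, exp_avg_sq, out=exp_avg_sq)
print(exp_avg_sq)  # tensor([3.0000, 5.0000, 2.5000])
```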
9 changes: 5 additions & 4 deletions pytorch_optimizer/optimizer/adai.py
@@ -85,6 +85,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq_hat_sum: float = 0.0

for group in self.param_groups:
_, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue
@@ -106,14 +107,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
state['step'] += 1

exp_avg_sq = state['exp_avg_sq']
_, beta2 = group['betas']

if self.use_gc:
grad = centralize_gradient(grad, gc_conv_only=False)

bias_correction2 = 1.0 - beta2 ** state['step']

if group['weight_decay'] != 0:
if group['weight_decay'] > 0.0:
if self.weight_decouple:
p.mul_(1.0 - group['lr'] * group['weight_decay'])
else:
@@ -129,6 +129,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size

for group in self.param_groups:
beta0, beta2 = group['betas']
beta0_dp = math.pow(beta0, 1.0 - group['dampening'])
for p in group['params']:
if p.grad is None:
continue
@@ -138,7 +140,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1_prod = state['beta1_prod']
beta0, beta2 = group['betas']

bias_correction2 = 1.0 - beta2 ** state['step']

@@ -152,7 +153,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
bias_correction1 = 1.0 - beta1_prod

exp_avg.mul_(beta1).addcmul_(beta3, grad)
exp_avg_hat = exp_avg / bias_correction1 * math.pow(beta0, 1.0 - group['dampening'])
exp_avg_hat = exp_avg / bias_correction1 * beta0_dp

p.add_(exp_avg_hat, alpha=-group['lr'])

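The Adai change moves the beta unpacking and the dampened beta0 factor to group level, since both depend only on group hyperparameters. A minimal sketch with hypothetical values, not the repo defaults:

```python
import math

# Hypothetical group settings; the real defaults live in the repo.
group = {'betas': (0.1, 0.99), 'dampening': 1.0}

beta0, beta2 = group['betas']
beta0_dp = math.pow(beta0, 1.0 - group['dampening'])  # computed once, reused for every parameter
```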
15 changes: 8 additions & 7 deletions pytorch_optimizer/optimizer/adamp.py
@@ -89,6 +89,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
loss = closure()

for group in self.param_groups:
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue
@@ -103,10 +104,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

state['step'] += 1
beta1, beta2 = group['betas']
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']
@@ -117,12 +116,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

de_nom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
inv_de_nom = 1.0 / (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

perturb = exp_avg.clone()
if group['nesterov']:
perturb = (beta1 * exp_avg + (1.0 - beta1) * grad) / de_nom
# perturb = beta1 * exp_avg + (1.0 - beta1) * grad / de_nom
perturb.mul_(beta1).addcmul_(grad, inv_de_nom, value=1.0 - beta1)
else:
perturb = exp_avg / de_nom
perturb.mul_(inv_de_nom)

wd_ratio: float = 1
if len(p.shape) > 1:
@@ -135,7 +136,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
group['eps'],
)

if group['weight_decay'] > 0:
if group['weight_decay'] > 0.0:
p.mul_(1.0 - group['lr'] * group['weight_decay'] * wd_ratio)

step_size = group['lr']
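In AdamP the denominator is now inverted once and reused with in-place ops, and the perturb tensor is updated in place rather than rebound; per the inline comment kept in the diff, the Nesterov form becomes beta1 * exp_avg + (1 - beta1) * grad * inv_de_nom. A standalone sketch with illustrative tensors only:

```python
import math
import torch

# Illustrative tensors; the names mirror the diff, not a full optimizer step.
beta1, eps = 0.9, 1e-8
grad = torch.tensor([0.10, -0.20])
exp_avg = torch.tensor([0.05, -0.10])
exp_avg_sq = torch.tensor([0.01, 0.04])
bias_correction2 = 0.19

# Divide once, then reuse the reciprocal in the in-place updates below.
inv_de_nom = 1.0 / (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

perturb = exp_avg.clone()
perturb.mul_(beta1).addcmul_(grad, inv_de_nom, value=1.0 - beta1)  # Nesterov branch
```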
56 changes: 43 additions & 13 deletions pytorch_optimizer/optimizer/adan.py
@@ -1,4 +1,5 @@
import math
from typing import Union

import torch
from torch.optim.optimizer import Optimizer
@@ -17,6 +18,7 @@ class Adan(Optimizer, BaseOptimizer):
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. decoupled weight decay.
:param max_grad_norm: float. max gradient norm to clip.
:param use_gc: bool. use gradient centralization.
:param eps: float. term added to the denominator to improve numerical stability.
"""
@@ -28,13 +30,15 @@ def __init__(
betas: BETAS = (0.98, 0.92, 0.99),
weight_decay: float = 0.0,
weight_decouple: bool = False,
max_grad_norm: float = 0.0,
use_gc: bool = False,
eps: float = 1e-8,
):
self.lr = lr
self.betas = betas
self.weight_decay = weight_decay
self.weight_decouple = weight_decouple
self.max_grad_norm = max_grad_norm
self.use_gc = use_gc
self.eps = eps

@@ -46,6 +50,7 @@ def __init__(
eps=eps,
weight_decay=weight_decay,
weight_decouple=weight_decouple,
max_grad_norm=max_grad_norm,
)
super().__init__(params, defaults)

@@ -54,6 +59,7 @@ def validate_parameters(self):
self.validate_betas(self.betas)
self.validate_weight_decay(self.weight_decay)
self.validate_epsilon(self.eps)
self.validate_norm(self.max_grad_norm)

@property
def __name__(self) -> str:
@@ -62,23 +68,54 @@ def __name__(self) -> str:
@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_diff'] = torch.zeros_like(p)
state['exp_avg_nest'] = torch.zeros_like(p)
state['previous_grad'] = torch.zeros_like(p)

@torch.no_grad()
def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
if self.defaults['max_grad_norm'] == 0.0:
return 1.0

device = self.param_groups[0]['params'][0].device

global_grad_norm = torch.zeros(1, device=device)
max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)

for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))

global_grad_norm = torch.sqrt(global_grad_norm)

return torch.clamp(max_grad_norm / (global_grad_norm + self.eps), max=1.0)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

clip_global_grad_norm = self.get_global_gradient_norm()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

beta1, beta2, beta3 = group['betas']
bias_correction1 = 1.0 - beta1 ** group['step']
bias_correction2 = 1.0 - beta2 ** group['step']
bias_correction3_sq = math.sqrt(1.0 - beta3 ** group['step'])

for p in group['params']:
if p.grad is None:
continue
@@ -89,35 +126,28 @@ def step(self, closure: CLOSURE = None) -> LOSS:

state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_diff'] = torch.zeros_like(p)
state['exp_avg_nest'] = torch.zeros_like(p)
state['previous_grad'] = torch.zeros_like(p)
state['previous_grad'] = grad.clone()

exp_avg, exp_avg_diff, exp_avg_nest = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_nest']
prev_grad = state['previous_grad']

state['step'] += 1
beta1, beta2, beta3 = group['betas']

bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']
bias_correction3 = 1.0 - beta3 ** state['step']
grad.mul_(clip_global_grad_norm)

if self.use_gc:
grad = centralize_gradient(grad, gc_conv_only=False)

grad_diff = grad - prev_grad
state['previous_grad'] = grad.clone()
grad_diff = grad - state['previous_grad']
state['previous_grad'].copy_(grad)

update = grad + beta2 * grad_diff

exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
exp_avg_nest.mul_(beta3).addcmul_(update, update, value=1.0 - beta3)

de_nom = (exp_avg_nest.sqrt_() / math.sqrt(bias_correction3)).add_(self.eps)
de_nom = (exp_avg_nest.sqrt_() / bias_correction3_sq).add_(self.eps)
perturb = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(de_nom)

if group['weight_decouple']:
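The Adan rewrite adds optional global gradient-norm clipping via max_grad_norm, moves the step counter and bias corrections to group level, and seeds previous_grad with the first gradient instead of zeros. A standalone sketch of the clipping factor, mirroring get_global_gradient_norm with illustrative tensors:

```python
import torch

# Gradients are scaled by min(1, max_grad_norm / (||g||_2 + eps)); illustrative tensors only.
eps, max_grad_norm = 1e-8, 1.0
grads = [torch.tensor([3.0, 4.0]), torch.tensor([0.0, 12.0])]

global_norm = torch.sqrt(sum(torch.linalg.norm(g).pow(2) for g in grads))  # sqrt(25 + 144) = 13
clip = torch.clamp(torch.tensor(max_grad_norm) / (global_norm + eps), max=1.0)

for g in grads:
    g.mul_(clip)  # same in-place scaling as grad.mul_(clip_global_grad_norm) in the diff
```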
13 changes: 6 additions & 7 deletions pytorch_optimizer/optimizer/adapnm.py
@@ -84,6 +84,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
loss = closure()

for group in self.param_groups:
beta1, beta2, beta3 = group['betas']
noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2) # fmt: skip
for p in group['params']:
if p.grad is None:
continue
@@ -107,7 +109,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
state['max_exp_avg_sq'] = torch.zeros_like(p)

state['step'] += 1
beta1, beta2, beta3 = group['betas']

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
@@ -120,18 +121,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:

exp_avg.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2) # fmt: skip
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

if group['amsgrad']:
exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
torch.max(state['max_exp_avg_sq'], exp_avg_sq, out=exp_avg_sq)

denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
de_nom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

step_size = group['lr']
if not group['adamd_debias_term']:
step_size /= bias_correction1

noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2) # fmt: skip
pn_momentum = exp_avg.mul(1 + beta3).add(neg_exp_avg, alpha=-beta3).mul(1.0 / noise_norm)
p.addcdiv_(pn_momentum, denom, value=-step_size)
pn_momentum = exp_avg.mul(1.0 + beta3).add(neg_exp_avg, alpha=-beta3).mul(1.0 / noise_norm)
p.addcdiv_(pn_momentum, de_nom, value=-step_size)

return loss
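
In AdaPNM the betas and noise_norm are hoisted to group level (noise_norm depends only on beta3), the amsgrad maximum is taken in place, and denom is renamed de_nom for consistency with the other optimizers. A small sketch of the hoisted term and the normalized positive-negative momentum, with illustrative values that are assumptions rather than the repo defaults:

```python
import math
import torch

# Illustrative values; beta3 = 1.0 is an assumption, not read from this diff.
beta3 = 1.0
noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)  # sqrt(5) ~ 2.2361, computed once per group

exp_avg = torch.tensor([0.2, -0.1])
neg_exp_avg = torch.tensor([0.1, -0.3])
pn_momentum = exp_avg.mul(1.0 + beta3).add(neg_exp_avg, alpha=-beta3).mul(1.0 / noise_norm)
```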