Commit
Merge pull request #42 from kozistr/feature/load-optimizers
[Test] Add test case for load_optimizers
kozistr authored Jan 28, 2022
2 parents 6343cca + b974419 commit 42655e0
Showing 6 changed files with 64 additions and 18 deletions.
4 changes: 2 additions & 2 deletions pytorch_optimizer/adabound.py
@@ -129,8 +129,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

 beta1, beta2 = group['betas']

-exp_avg.mul_(beta1).add_(1 - beta1, grad)
-exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 if group['amsbound']:
     max_exp_avg_sq = torch.max(max_exp_avg_sq, exp_avg_sq)
     denom = max_exp_avg_sq.sqrt().add_(group['eps'])
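The change above is the core of this commit's API cleanup, repeated in the files below: the positional-scalar overloads of PyTorch's in-place ops (`add_(scalar, tensor)`, `addcmul_(scalar, t1, t2)`, `addcdiv_(scalar, t1, t2)`) are deprecated in favor of the keyword forms `alpha=` and `value=`, which compute the same Adam-style moment updates. A minimal standalone sketch of the new-style calls (toy values, not taken from the repository):

    import torch

    grad = torch.tensor([1.0, 2.0])
    exp_avg, exp_avg_sq = torch.zeros(2), torch.zeros(2)
    beta1, beta2, eps, step_size = 0.9, 0.999, 1e-8, 1e-3

    # exp_avg <- beta1 * exp_avg + (1 - beta1) * grad
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

    # exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad * grad
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # p <- p - step_size * exp_avg / (sqrt(exp_avg_sq) + eps)
    p = torch.ones(2)
    denom = exp_avg_sq.sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-step_size)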
8 changes: 4 additions & 4 deletions pytorch_optimizer/diffgrad.py
@@ -90,14 +90,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 exp_avg, exp_avg_sq, previous_grad = state['exp_avg'], state['exp_avg_sq'], state['previous_grad']

 if group['weight_decay'] != 0:
-    grad.add_(group['weight_decay'], p.data)
+    grad.add_(p.data, alpha=group['weight_decay'])

 state['step'] += 1
 beta1, beta2 = group['betas']

 # Decay the first and second moment running average coefficient
-exp_avg.mul_(beta1).add_(1 - beta1, grad)
-exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 denom = exp_avg_sq.sqrt().add_(group['eps'])

 bias_correction1 = 1 - beta1 ** state['step']
@@ -116,6 +116,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 else:
     step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

-p.data.addcdiv_(-step_size, exp_avg1, denom)
+p.data.addcdiv_(exp_avg1, denom, value=-step_size)

return loss
12 changes: 6 additions & 6 deletions pytorch_optimizer/diffrgrad.py
@@ -121,8 +121,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

 bias_correction1 = 1 - beta1 ** state['step']

-exp_avg.mul_(beta1).add_(1 - beta1, grad)
-exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

 # compute diffGrad coefficient (dfc)
 diff = abs(previous_grad - grad)
@@ -164,18 +164,18 @@

 if n_sma >= self.n_sma_threshold:
     if group['weight_decay'] != 0:
-        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

     denom = exp_avg_sq.sqrt().add_(group['eps'])

     # update momentum with dfc
-    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg * dfc.float(), denom)
+    p_data_fp32.addcdiv_(exp_avg * dfc.float(), denom, value=-step_size * group['lr'])
     p.data.copy_(p_data_fp32)
 elif step_size > 0:
     if group['weight_decay'] != 0:
-        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

-    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
     p.data.copy_(p_data_fp32)

return loss
3 changes: 3 additions & 0 deletions pytorch_optimizer/optimizers.py
@@ -5,6 +5,7 @@
 from pytorch_optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.diffrgrad import DiffRGrad
 from pytorch_optimizer.fp16 import SafeFP16Optimizer
+from pytorch_optimizer.lamb import Lamb
 from pytorch_optimizer.madgrad import MADGRAD
 from pytorch_optimizer.radam import RAdam
 from pytorch_optimizer.ranger import Ranger
@@ -39,6 +40,8 @@ def load_optimizers(optimizer: str, use_fp16: bool = False):
     opt = DiffGrad
 elif optimizer == 'adahessian':
     opt = AdaHessian
+elif optimizer == 'lamb':
+    opt = Lamb
 else:
     raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')

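With `Lamb` now importable and registered in the name lookup, the loader can be exercised by name. A rough usage sketch, assuming `load_optimizers` returns the optimizer class (as the `opt = Lamb` branch and the test file below suggest); the model and hyperparameters are placeholders, not part of this commit:

    import torch
    from pytorch_optimizer import load_optimizers

    model = torch.nn.Linear(10, 1)  # placeholder model

    opt_class = load_optimizers('lamb')                  # newly supported name
    optimizer = opt_class(model.parameters(), lr=1e-3)   # assumes Lamb accepts the usual lr kwarg

    # unrecognized names fall through to the NotImplementedError branch:
    # load_optimizers('unknown')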
12 changes: 6 additions & 6 deletions pytorch_optimizer/radam.py
@@ -118,8 +118,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

 bias_correction1 = 1 - beta1 ** state['step']

-exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-exp_avg.mul_(beta1).add_(1 - beta1, grad)
+exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

 state['step'] += 1
 buffered = group['buffer'][int(state['step'] % 10)]
@@ -155,14 +155,14 @@

 if n_sma >= self.n_sma_threshold:
     if group['weight_decay'] != 0:
-        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
     denom = exp_avg_sq.sqrt().add_(group['eps'])
-    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
     p.data.copy_(p_data_fp32)
 elif step_size > 0:
     if group['weight_decay'] != 0:
-        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
-    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
     p.data.copy_(p_data_fp32)

return loss
43 changes: 43 additions & 0 deletions tests/test_load_optimizers.py
@@ -0,0 +1,43 @@
from typing import List

import pytest

from pytorch_optimizer import load_optimizers

VALID_OPTIMIZER_NAMES: List[str] = [
'adamp',
'sgdp',
'madgrad',
'ranger',
'ranger21',
'radam',
'adabound',
'adahessian',
'adabelief',
'diffgrad',
'diffrgrad',
'lamb',
]

INVALID_OPTIMIZER_NAMES: List[str] = [
'asam',
'sam',
'pcgrad',
'adamd',
'lookahead',
'chebyshev_schedule',
]


@pytest.mark.parametrize('valid_optimizer_names', VALID_OPTIMIZER_NAMES)
def test_load_optimizers_valid(valid_optimizer_names):
load_optimizers(valid_optimizer_names)


@pytest.mark.parametrize('invalid_optimizer_names', INVALID_OPTIMIZER_NAMES)
def test_load_optimizers_invalid(invalid_optimizer_names):
try:
load_optimizers(invalid_optimizer_names)
except NotImplementedError:
return True
return False
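
One caveat on the invalid-name test above: pytest ignores return values from test functions, so it passes even when no exception is raised. A stricter variant (a suggestion, not part of this commit) could assert the failure explicitly with `pytest.raises`:

    @pytest.mark.parametrize('invalid_optimizer_name', INVALID_OPTIMIZER_NAMES)
    def test_load_optimizers_invalid(invalid_optimizer_name):
        with pytest.raises(NotImplementedError):
            load_optimizers(invalid_optimizer_name)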
