Merge pull request #53 from kozistr/refactor/cov
[Refactor] Coverage
kozistr authored Feb 20, 2022
2 parents 358ff43 + a23d40b commit 6db0d49
Showing 31 changed files with 555 additions and 297 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -1 +1 @@
@kozistr
* @kozistr
2 changes: 1 addition & 1 deletion README.rst
@@ -42,7 +42,7 @@ or you can use optimizer loader, simply passing a name of the optimizer.

...
model = YourModel()
opt = load_optimizers(optimizer='adamp', use_fp16=True)
opt = load_optimizers(optimizer='adamp')
optimizer = opt(model.parameters())
...

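A hedged usage sketch of the loader call after this README change: the fp16 flag is no longer passed to load_optimizers. The import path and the tiny nn.Linear standing in for YourModel are assumptions for illustration, not taken from the diff.

import torch
from torch import nn

from pytorch_optimizer import load_optimizers  # assumed top-level import path

model = nn.Linear(4, 1)                          # stand-in for `YourModel`
opt_class = load_optimizers(optimizer='adamp')   # fp16 flag dropped from the README example
optimizer = opt_class(model.parameters())

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()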
18 changes: 15 additions & 3 deletions pytorch_optimizer/adabound.py
@@ -84,7 +84,18 @@ def __setstate__(self, state: STATE):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('amsbound', False)
group.setdefault('adamd_debias_term', False)

@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
if group['amsbound']:
state['max_exp_avg_sq'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -127,14 +138,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:

exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

if group['amsbound']:
max_exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
de_nom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
de_nom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']

step_size = group['lr'] * math.sqrt(bias_correction2)
if not group['adamd_debias_term']:
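A hedged usage sketch of the reset() hook this PR adds across the optimizers: it rebuilds the per-parameter state in place (step counter and moment buffers), which also makes that code reachable from unit tests, in line with the coverage goal of the PR. The class name and import path below are assumed.

import torch
from torch import nn

from pytorch_optimizer.adabound import AdaBound  # assumed class name / import path

model = nn.Linear(2, 1)
optimizer = AdaBound(model.parameters(), lr=1e-3)

model(torch.randn(4, 2)).sum().backward()
optimizer.step()   # state now holds 'step', 'exp_avg', 'exp_avg_sq'

optimizer.reset()  # zero the step counter and moment buffers in place
assert optimizer.state[next(model.parameters())]['step'] == 0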
21 changes: 15 additions & 6 deletions pytorch_optimizer/adahessian.py
@@ -138,6 +148,16 @@ def set_hessian(self):
# approximate the expected values of z * (H@z)
p.hess += h_z * z / self.num_samples

@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_hessian_diag_sq'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
@@ -171,14 +181,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
beta1, beta2 = group['betas']

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
exp_avg.mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1.0 - beta2)

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']

hessian_power = group['hessian_power']
de_nom = (exp_hessian_diag_sq / bias_correction2).pow_(hessian_power / 2.0).add_(group['eps'])
de_nom = (exp_hessian_diag_sq / bias_correction2).pow_(group['hessian_power'] / 2.0).add_(group['eps'])

step_size = group['lr']
if not group['adamd_debias_term']:
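For context on the "z * (H@z)" comment in the hunk above: AdaHessian estimates the Hessian diagonal with Hutchinson's method, E[z ⊙ (Hz)] over Rademacher probes z. A minimal standalone sketch, independent of the optimizer's own set_hessian() implementation:

import torch

def hutchinson_hessian_diag(loss: torch.Tensor, params: list, num_samples: int = 1) -> list:
    # first-order grads with create_graph=True so we can differentiate again
    grads = torch.autograd.grad(loss, params, create_graph=True)
    diag = [torch.zeros_like(p) for p in params]
    for _ in range(num_samples):
        # Rademacher probes: each entry is +1 or -1
        zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]
        # Hessian-vector products H @ z via a second backward pass
        h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=True)
        for d, h_z, z in zip(diag, h_zs, zs):
            d += h_z * z / num_samples  # running average of z * (H @ z)
    return diag

w = torch.randn(3, requires_grad=True)
estimate = hutchinson_hessian_diag((w ** 4).sum(), [w], num_samples=4)[0]
print(estimate)  # equals 12 * w**2 exactly for this separable loss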
54 changes: 11 additions & 43 deletions pytorch_optimizer/adamp.py
@@ -1,13 +1,12 @@
import math
from typing import Callable, List, Tuple

import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base_optimizer import BaseOptimizer
from pytorch_optimizer.gc import centralize_gradient
from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.utils import projection


class AdamP(Optimizer, BaseOptimizer):
@@ -80,46 +79,15 @@ def validate_parameters(self):
self.validate_weight_decay_ratio(self.wd_ratio)
self.validate_epsilon(self.eps)

@staticmethod
def channel_view(x: torch.Tensor) -> torch.Tensor:
return x.view(x.size()[0], -1)

@staticmethod
def layer_view(x: torch.Tensor) -> torch.Tensor:
return x.view(1, -1)

@staticmethod
def cosine_similarity(
x: torch.Tensor,
y: torch.Tensor,
eps: float,
view_func: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
x = view_func(x)
y = view_func(y)
return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()

def projection(
self,
p,
grad,
perturb: torch.Tensor,
delta: float,
wd_ratio: float,
eps: float,
) -> Tuple[torch.Tensor, float]:
wd: float = 1.0
expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
for view_func in (self.channel_view, self.layer_view):
cosine_sim = self.cosine_similarity(grad, p, eps, view_func)

if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
wd = wd_ratio
return perturb, wd

return perturb, wd
@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -166,7 +134,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

wd_ratio: float = 1
if len(p.shape) > 1:
perturb, wd_ratio = self.projection(
perturb, wd_ratio = projection(
p,
grad,
perturb,
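The channel/layer-view projection removed from AdamP above is presumably what now lives in pytorch_optimizer.utils as the imported projection function. A sketch reconstructed from the removed methods; the actual utils version may differ in signature or details.

import math
from typing import Callable, List, Tuple

import torch
import torch.nn.functional as F


def channel_view(x: torch.Tensor) -> torch.Tensor:
    return x.view(x.size()[0], -1)


def layer_view(x: torch.Tensor) -> torch.Tensor:
    return x.view(1, -1)


def cosine_similarity(
    x: torch.Tensor, y: torch.Tensor, eps: float, view_func: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
    return F.cosine_similarity(view_func(x), view_func(y), dim=1, eps=eps).abs_()


def projection(p, grad, perturb: torch.Tensor, delta: float, wd_ratio: float, eps: float) -> Tuple[torch.Tensor, float]:
    wd: float = 1.0
    expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
    for view_func in (channel_view, layer_view):
        cosine_sim = cosine_similarity(grad, p, eps, view_func)

        # gradient (nearly) orthogonal to the weight: project the perturbation
        # onto the tangent space and scale weight decay by wd_ratio
        if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
            p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
            perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
            wd = wd_ratio
            return perturb, wd

    return perturb, wd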
28 changes: 26 additions & 2 deletions pytorch_optimizer/base_optimizer.py
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod

import torch

from pytorch_optimizer.types import BETAS


@@ -48,8 +50,8 @@ def validate_momentum(momentum: float):

@staticmethod
def validate_lookahead_k(k: int):
if k < 0:
raise ValueError(f'[-] k {k} must be non-negative')
if k < 1:
raise ValueError(f'[-] k {k} must be positive')

@staticmethod
def validate_rho(rho: float):
@@ -61,6 +63,28 @@ def validate_epsilon(epsilon: float):
if epsilon < 0.0:
raise ValueError(f'[-] epsilon {epsilon} must be non-negative')

@staticmethod
def validate_alpha(alpha: float):
if not 0.0 <= alpha < 1.0:
raise ValueError(f'[-] alpha {alpha} must be in the range [0, 1)')

@staticmethod
def validate_pullback_momentum(pullback_momentum: str):
if pullback_momentum not in ('none', 'reset', 'pullback'):
raise ValueError(
f'[-] pullback_momentum {pullback_momentum} must be one of (\'none\' or \'reset\' or \'pullback\')'
)

@staticmethod
def validate_reduction(reduction: str):
if reduction not in ('mean', 'sum'):
raise ValueError(f'[-] reduction {reduction} must be one of (\'mean\' or \'sum\')')

@abstractmethod
def validate_parameters(self):
raise NotImplementedError

@abstractmethod
@torch.no_grad()
def reset(self):
raise NotImplementedError
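A hedged sketch of how the new and tightened validators behave: each raises ValueError with the '[-] ...' message when its argument falls outside the accepted range, and since they are static they can be exercised directly on the abstract class.

from pytorch_optimizer.base_optimizer import BaseOptimizer

BaseOptimizer.validate_lookahead_k(5)              # ok: k must now be >= 1
BaseOptimizer.validate_alpha(0.5)                  # ok: alpha must lie in [0, 1)
BaseOptimizer.validate_pullback_momentum('reset')  # ok: 'none', 'reset' or 'pullback'
BaseOptimizer.validate_reduction('mean')           # ok: 'mean' or 'sum'

try:
    BaseOptimizer.validate_lookahead_k(0)  # rejected after this change (k < 1)
except ValueError as e:
    print(e)  # [-] k 0 must be positive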
12 changes: 10 additions & 2 deletions pytorch_optimizer/diffgrad.py
@@ -58,8 +58,16 @@ def validate_parameters(self):
self.validate_weight_decay(self.weight_decay)
self.validate_epsilon(self.eps)

def __setstate__(self, state: STATE):
super().__setstate__(state)
@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
state['previous_grad'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
21 changes: 14 additions & 7 deletions pytorch_optimizer/diffrgrad.py
@@ -71,8 +71,16 @@ def validate_parameters(self):
self.validate_weight_decay(self.weight_decay)
self.validate_epsilon(self.eps)

def __setstate__(self, state: STATE):
super().__setstate__(state)
@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
state['previous_grad'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -123,7 +131,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
dfc = 1.0 / (1.0 + torch.exp(-diff))
state['previous_grad'] = grad.clone()

buffered = group['buffer'][int(state['step'] % 10)]
buffered = group['buffer'][state['step'] % 10]
if state['step'] == buffered[0]:
n_sma, step_size = buffered[1], buffered[2]
else:
@@ -144,10 +152,9 @@
/ (n_sma_max - 2)
)

if group['adamd_debias_term']:
step_size = rt
else:
step_size = rt / bias_correction1
step_size = rt
if not group['adamd_debias_term']:
step_size /= bias_correction1
elif self.degenerated_to_sgd:
step_size = 1.0 / bias_correction1
else:
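For the step-size branch reshaped above: rt is the RAdam-style rectification factor whose closing lines (... / (n_sma_max - 2)) are visible in the hunk, with bias_correction2 folded in, and adamd_debias_term decides whether to also divide by bias_correction1. A standalone sketch under the usual RAdam definitions; the SMA threshold of 5 and the SGD-like fallback are assumptions based on common implementations, not taken from this diff.

import math

def rectified_step_factor(beta1: float, beta2: float, step: int, adamd_debias_term: bool) -> float:
    beta2_t = beta2 ** step
    bias_correction1 = 1.0 - beta1 ** step
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
    n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)

    if n_sma < 5.0:
        # variance of the adaptive term is not yet tractable: fall back to a
        # plain debiased step, matching `step_size = 1.0 / bias_correction1`
        return 1.0 / bias_correction1

    rt = math.sqrt(
        (1.0 - beta2_t)
        * (n_sma - 4.0) / (n_sma_max - 4.0)
        * (n_sma - 2.0) / n_sma
        * n_sma_max / (n_sma_max - 2.0)
    )

    step_size = rt
    if not adamd_debias_term:  # same reshaped branch as in the diff above
        step_size /= bias_correction1
    return step_size

print(rectified_step_factor(0.9, 0.999, step=100, adamd_debias_term=False))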
40 changes: 23 additions & 17 deletions pytorch_optimizer/fp16.py
@@ -1,9 +1,10 @@
from typing import Dict, List, Optional, Union

import torch
from torch import nn
from torch.optim import Optimizer

from pytorch_optimizer.types import CLOSURE
from pytorch_optimizer.types import CLOSURE, PARAMETERS
from pytorch_optimizer.utils import clip_grad_norm, has_overflow

__AUTHOR__ = 'Facebook'
@@ -114,26 +115,29 @@ def get_parameters(cls, optimizer: Optimizer):
return params

@classmethod
def build_fp32_params(cls, parameters, flatten: bool = True) -> Union[torch.Tensor, List[torch.Tensor]]:
def build_fp32_params(
cls, parameters: PARAMETERS, flatten: bool = True
) -> Union[torch.Tensor, List[torch.Tensor]]:
# create FP32 copy of parameters and grads
if flatten:
total_param_size = sum(p.data.numel() for p in parameters)
total_param_size: int = sum(p.numel() for p in parameters)
fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=parameters[0].device)

offset: int = 0
for p in parameters:
p_num_el = p.data.numel()
fp32_params[offset : offset + p_num_el].copy_(p.data.view(-1))
p_num_el = p.numel()
fp32_params[offset : offset + p_num_el].copy_(p.view(-1))
offset += p_num_el

fp32_params = torch.nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.data.new(total_param_size)
fp32_params = nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.new(total_param_size)

return fp32_params

fp32_params = []
for p in parameters:
p32 = torch.nn.Parameter(p.data.float())
p32.grad = torch.zeros_like(p32.data)
p32 = nn.Parameter(p.float())
p32.grad = torch.zeros_like(p32)
fp32_params.append(p32)

return fp32_params
@@ -181,25 +185,25 @@ def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
continue

if p.grad is not None:
p32.grad.data.copy_(p.grad.data)
p32.grad.data.mul_(multiply_grads)
p32.grad.copy_(p.grad)
p32.grad.mul_(multiply_grads)
else:
p32.grad = torch.zeros_like(p.data, dtype=torch.float)
p32.grad = torch.zeros_like(p, dtype=torch.float)

self.needs_sync = False

def multiply_grads(self, c):
def multiply_grads(self, c: float):
"""Multiplies grads by a constant c."""
if self.needs_sync:
self.sync_fp16_grads_to_fp32(c)
else:
for p32 in self.fp32_params:
p32.grad.data.mul_(c)
p32.grad.mul_(c)

def update_main_grads(self):
self.sync_fp16_grads_to_fp32()

def clip_main_grads(self, max_norm):
def clip_main_grads(self, max_norm: float):
"""Clips gradient norm and updates dynamic loss scaler."""
self.sync_fp16_grads_to_fp32()

@@ -208,8 +212,10 @@ def clip_main_grads(self, max_norm):
# detect overflow and adjust loss scale
if self.scaler is not None:
overflow: bool = has_overflow(grad_norm)
prev_scale = self.scaler.loss_scale
prev_scale: float = self.scaler.loss_scale

self.scaler.update_scale(overflow)

if overflow:
self.zero_grad()
if self.scaler.loss_scale <= self.min_loss_scale:
Expand All @@ -235,7 +241,7 @@ def step(self, closure: CLOSURE = None):
for p, p32 in zip(self.fp16_params, self.fp32_params):
if not p.requires_grad:
continue
p.data.copy_(p32.data)
p.data.copy_(p32)

def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
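A self-contained sketch of the flattened FP32 master-copy construction that build_fp32_params performs when flatten=True: every (typically FP16) parameter is copied into one contiguous FP32 buffer, which then gets a pre-allocated grad of the same size. The two toy parameters below are illustrative only.

import torch
from torch import nn

params = [nn.Parameter(torch.randn(3, 4).half()), nn.Parameter(torch.randn(5).half())]

total_param_size = sum(p.numel() for p in params)
fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=params[0].device)

with torch.no_grad():
    offset = 0
    for p in params:
        p_num_el = p.numel()
        fp32_params[offset:offset + p_num_el].copy_(p.view(-1))  # copy_ upcasts half -> float
        offset += p_num_el

fp32_params = nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.new(total_param_size)  # uninitialized grad buffer, as in the diff

print(fp32_params.shape, fp32_params.grad.shape)  # torch.Size([17]) torch.Size([17])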
10 changes: 10 additions & 0 deletions pytorch_optimizer/lamb.py
@@ -65,6 +65,16 @@ def validate_parameters(self):
self.validate_weight_decay(self.weight_decay)
self.validate_epsilon(self.eps)

@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]

state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

def get_gradient_norm(self) -> float:
norm_sq: float = 0.0
for group in self.param_groups:
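The truncated get_gradient_norm above appears to accumulate a global gradient norm over all parameter groups, which LAMB-style optimizers can use before computing trust ratios. A hedged sketch of that computation; the remainder of the method is not shown in this diff, so the body below is an assumption.

import math
import torch

def global_grad_norm(param_groups) -> float:
    norm_sq: float = 0.0
    for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                norm_sq += torch.linalg.norm(p.grad).item() ** 2
    return math.sqrt(norm_sq)

# usage: global_grad_norm(optimizer.param_groups)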