-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #251 from kozistr/feature/kate-optimizer
[Feature] Implement Kate optimizer
- Loading branch information
Showing
11 changed files
with
292 additions
and
171 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import torch | ||
from torch.optim.optimizer import Optimizer | ||
|
||
from pytorch_optimizer.base.exception import NoSparseGradientError | ||
from pytorch_optimizer.base.optimizer import BaseOptimizer | ||
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS | ||
|
||
|
||
class Kate(Optimizer, BaseOptimizer): | ||
r"""Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad. | ||
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. | ||
:param lr: float. learning rate. | ||
:param delta: float. delta. 0.0 or 1e-8. | ||
:param weight_decay: float. weight decay (L2 penalty). | ||
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. | ||
:param fixed_decay: bool. fix weight decay. | ||
:param eps: float. epsilon value. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
params: PARAMETERS, | ||
lr: float = 1e-3, | ||
delta: float = 0.0, | ||
weight_decay: float = 0.0, | ||
weight_decouple: bool = True, | ||
fixed_decay: bool = False, | ||
eps: float = 1e-8, | ||
): | ||
self.validate_learning_rate(lr) | ||
self.validate_range(delta, 'delta', 0.0, 1.0, '[)') | ||
self.validate_non_negative(weight_decay, 'weight_decay') | ||
self.validate_non_negative(eps, 'eps') | ||
|
||
defaults: DEFAULTS = { | ||
'lr': lr, | ||
'delta': delta, | ||
'weight_decay': weight_decay, | ||
'weight_decouple': weight_decouple, | ||
'fixed_decay': fixed_decay, | ||
'eps': eps, | ||
} | ||
|
||
super().__init__(params, defaults) | ||
|
||
def __str__(self) -> str: | ||
return 'Kate' | ||
|
||
@torch.no_grad() | ||
def reset(self): | ||
for group in self.param_groups: | ||
group['step'] = 0 | ||
for p in group['params']: | ||
state = self.state[p] | ||
|
||
state['m'] = torch.zeros_like(p) | ||
state['b'] = torch.zeros_like(p) | ||
|
||
@torch.no_grad() | ||
def step(self, closure: CLOSURE = None) -> LOSS: | ||
loss: LOSS = None | ||
if closure is not None: | ||
with torch.enable_grad(): | ||
loss = closure() | ||
|
||
for group in self.param_groups: | ||
if 'step' in group: | ||
group['step'] += 1 | ||
else: | ||
group['step'] = 1 | ||
|
||
for p in group['params']: | ||
if p.grad is None: | ||
continue | ||
|
||
grad = p.grad | ||
if grad.is_sparse: | ||
raise NoSparseGradientError(str(self)) | ||
|
||
state = self.state[p] | ||
|
||
if len(state) == 0: | ||
state['m'] = torch.zeros_like(p) | ||
state['b'] = torch.zeros_like(p) | ||
|
||
self.apply_weight_decay( | ||
p=p, | ||
grad=p.grad, | ||
lr=group['lr'], | ||
weight_decay=group['weight_decay'], | ||
weight_decouple=group['weight_decouple'], | ||
fixed_decay=group['fixed_decay'], | ||
) | ||
|
||
grad_p2 = torch.mul(grad, grad) | ||
|
||
m, b = state['m'], state['b'] | ||
b.mul_(b).add_(grad_p2).add_(group['eps']) | ||
|
||
m.mul_(m).add_(grad_p2, alpha=group['delta']).add_(grad_p2 / b).sqrt_() | ||
|
||
update = m.mul(grad).div_(b) | ||
|
||
p.add_(update, alpha=-group['lr']) | ||
|
||
b.sqrt_() | ||
|
||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters