Skip to content

Commit

Permalink
Merge pull request #251 from kozistr/feature/kate-optimizer
Browse files Browse the repository at this point in the history
[Feature] Implement Kate optimizer
  • Loading branch information
kozistr authored Jul 6, 2024
2 parents 908e82e + 862ec9d commit b316ef9
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 171 deletions.
149 changes: 75 additions & 74 deletions README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docs/changelogs/v3.0.2.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
* Implement `WSD` LR Scheduler. (#247, #248)
* [Warmup-Stable-Decay LR Scheduler](https://arxiv.org/abs/2404.06395)
* Add more Pytorch built-in lr schedulers. (#248)
* Implement `Kate` optimizer. (#249, #251)
* [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648)

### Refactor

Expand Down
149 changes: 75 additions & 74 deletions docs/index.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/optimizer.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@
:docstring:
:members:

::: pytorch_optimizer.Kate
:docstring:
:members:

::: pytorch_optimizer.Lamb
:docstring:
:members:
Expand Down
38 changes: 19 additions & 19 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ keywords = [
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
"AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad",
"DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore", "Gravity",
"GrokFast", "GSAM", "LARS", "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam",
"PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
"GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
"PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
"ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH",
"SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM",
"Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
Expand Down
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.gravity import Gravity
from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW, gradfilter_ema, gradfilter_ma
from pytorch_optimizer.optimizer.kate import Kate
from pytorch_optimizer.optimizer.lamb import Lamb
from pytorch_optimizer.optimizer.lars import LARS
from pytorch_optimizer.optimizer.lion import Lion
Expand Down Expand Up @@ -199,6 +200,7 @@
ScheduleFreeAdamW,
FAdam,
GrokFastAdamW,
Kate,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

Expand Down
109 changes: 109 additions & 0 deletions pytorch_optimizer/optimizer/kate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Kate(Optimizer, BaseOptimizer):
r"""Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad.
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param delta: float. delta. 0.0 or 1e-8.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param fixed_decay: bool. fix weight decay.
:param eps: float. epsilon value.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
delta: float = 0.0,
weight_decay: float = 0.0,
weight_decouple: bool = True,
fixed_decay: bool = False,
eps: float = 1e-8,
):
self.validate_learning_rate(lr)
self.validate_range(delta, 'delta', 0.0, 1.0, '[)')
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

defaults: DEFAULTS = {
'lr': lr,
'delta': delta,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'fixed_decay': fixed_decay,
'eps': eps,
}

super().__init__(params, defaults)

def __str__(self) -> str:
return 'Kate'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

state['m'] = torch.zeros_like(p)
state['b'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(str(self))

state = self.state[p]

if len(state) == 0:
state['m'] = torch.zeros_like(p)
state['b'] = torch.zeros_like(p)

self.apply_weight_decay(
p=p,
grad=p.grad,
lr=group['lr'],
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=group['fixed_decay'],
)

grad_p2 = torch.mul(grad, grad)

m, b = state['m'], state['b']
b.mul_(b).add_(grad_p2).add_(group['eps'])

m.mul_(m).add_(grad_p2, alpha=group['delta']).add_(grad_p2 / b).sqrt_()

update = m.mul(grad).div_(b)

p.add_(update, alpha=-group['lr'])

b.sqrt_()

return loss
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ platformdirs==4.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest==8.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
ruff==0.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
ruff==0.5.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows"
tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6"
Expand Down
2 changes: 2 additions & 0 deletions tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
GaLore,
Gravity,
GrokFastAdamW,
Kate,
Lamb,
Lion,
Nero,
Expand Down Expand Up @@ -461,6 +462,7 @@
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
(Kate, {'lr': 5e-2}, 10),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_load_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


def test_get_supported_optimizers():
assert len(get_supported_optimizers()) == 68
assert len(get_supported_optimizers()) == 69


def test_get_supported_lr_schedulers():
Expand Down

0 comments on commit b316ef9

Please sign in to comment.