Merge pull request #304 from kozistr/feature/optimizers
[Feature] Implement `ScheduleFreeRAdam`, `LaProp` optimizers and lots of things
kozistr authored Dec 4, 2024
2 parents a980dc0 + 5326483 commit aee5fc4
Showing 15 changed files with 438 additions and 35 deletions.
11 changes: 8 additions & 3 deletions README.md
@@ -8,9 +8,13 @@
| Status | [![PyPi download](https://static.pepy.tech/badge/pytorch-optimizer)](https://pepy.tech/project/pytorch-optimizer) [![PyPi month download](https://static.pepy.tech/badge/pytorch-optimizer/month)](https://pepy.tech/project/pytorch-optimizer) |
| License | [![apache](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) |

**pytorch-optimizer** is a collection of optimizers and lr schedulers for PyTorch.
The algorithms are re-implemented from the original papers, with speed & memory tweaks and plug-ins, along with useful and practical optimization ideas.
Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
## Reasons to use `pytorch-optimizer`

1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
2. Many variants, such as `Cautious`, `AdamD`, and `Gradient Centralization`
3. Easy-to-use, clean, and tested code
4. Active maintenance
5. Somewhat more optimized than the original implementations

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -187,6 +191,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |

## Supported LR Scheduler

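For reference, a minimal usage sketch of the two optimizers added by this PR. The import names come from the `pytorch_optimizer/__init__.py` change below; the model and learning rate are illustrative only.

```python
import torch
from pytorch_optimizer import LaProp, ScheduleFreeRAdam

model = torch.nn.Linear(10, 1)

# LaProp decouples momentum from the adaptive (second-moment) normalization.
optimizer = LaProp(model.parameters(), lr=1e-3)

# ScheduleFreeRAdam is a schedule-free variant of RAdam, so no lr scheduler is attached.
# optimizer = ScheduleFreeRAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```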
4 changes: 4 additions & 0 deletions docs/changelogs/v3.3.1.md
@@ -6,3 +6,7 @@
* [Decoupled Momentum Optimization](https://arxiv.org/abs/2411.19870)
* Implement `Muon` optimizer. (#302)
* [MomentUm Orthogonalized by Newton-schulz](https://github.com/KellerJordan/Muon)
* Implement `ScheduleFreeRAdam` optimizer. (#304)
* Implement `LaProp` optimizer. (#304)
* [Separating Momentum and Adaptivity in Adam](https://arxiv.org/abs/2002.04839)
* Support the `Cautious` variant for the `LaProp`, `AdamP`, and `Adopt` optimizers. (#304)
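The new `Cautious` support is exposed as a constructor flag named `cautious` in the `adamp.py` and `adopt.py` diffs further down; a minimal sketch of enabling it (for `LaProp` the same flag name is assumed, since only the changelog entry covers it here):

```python
import torch
from pytorch_optimizer import ADOPT, AdamP, LaProp

model = torch.nn.Linear(10, 1)

# Cautious variant: update components whose sign disagrees with the current
# gradient are masked out before the parameter step.
optimizer = AdamP(model.parameters(), lr=1e-3, cautious=True)
# optimizer = ADOPT(model.parameters(), lr=1e-3, cautious=True)
# optimizer = LaProp(model.parameters(), lr=1e-3, cautious=True)  # flag name assumed
```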
11 changes: 8 additions & 3 deletions docs/index.md
@@ -8,9 +8,13 @@
| Status | [![PyPi download](https://static.pepy.tech/badge/pytorch-optimizer)](https://pepy.tech/project/pytorch-optimizer) [![PyPi month download](https://static.pepy.tech/badge/pytorch-optimizer/month)](https://pepy.tech/project/pytorch-optimizer) |
| License | [![apache](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) |

**pytorch-optimizer** is a collection of optimizers and lr schedulers for PyTorch.
The algorithms are re-implemented from the original papers, with speed & memory tweaks and plug-ins, along with useful and practical optimization ideas.
Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
## Reasons to use `pytorch-optimizer`

1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
2. Many variants, such as `Cautious`, `AdamD`, and `Gradient Centralization`
3. Easy-to-use, clean, and tested code
4. Active maintenance
5. Somewhat more optimized than the original implementations

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -187,6 +191,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |

## Supported LR Scheduler

8 changes: 8 additions & 0 deletions docs/optimizer.md
@@ -204,6 +204,10 @@
:docstring:
:members:

::: pytorch_optimizer.LaProp
:docstring:
:members:

::: pytorch_optimizer.LARS
:docstring:
:members:
@@ -296,6 +300,10 @@
:docstring:
:members:

::: pytorch_optimizer.ScheduleFreeRAdam
:docstring:
:members:

::: pytorch_optimizer.StableAdamW
:docstring:
:members:
14 changes: 7 additions & 7 deletions pyproject.toml
@@ -14,13 +14,13 @@ keywords = [
"AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
"AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
"Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion",
"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS",
"Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
"Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
"ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM",
"StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"QGaLore",
"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp",
"LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID",
"PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
"ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP",
"SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
"bitsandbytes", "WSD", "QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -107,6 +107,7 @@
GrokFastAdamW,
Kate,
Lamb,
LaProp,
Lion,
Lookahead,
Muon,
@@ -123,6 +124,7 @@
SafeFP16Optimizer,
ScalableShampoo,
ScheduleFreeAdamW,
ScheduleFreeRAdam,
ScheduleFreeSGD,
Shampoo,
SignSGD,
5 changes: 4 additions & 1 deletion pytorch_optimizer/optimizer/__init__.py
@@ -50,6 +50,7 @@
from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW
from pytorch_optimizer.optimizer.kate import Kate
from pytorch_optimizer.optimizer.lamb import Lamb
from pytorch_optimizer.optimizer.laprop import LaProp
from pytorch_optimizer.optimizer.lars import LARS
from pytorch_optimizer.optimizer.lion import Lion
from pytorch_optimizer.optimizer.lomo import LOMO, AdaLOMO
@@ -71,7 +72,7 @@
from pytorch_optimizer.optimizer.ranger21 import Ranger21
from pytorch_optimizer.optimizer.rotograd import RotoGrad
from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeSGD
from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
from pytorch_optimizer.optimizer.sgdp import SGDP
from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
@@ -275,6 +276,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
FTRL,
DeMo,
Muon,
ScheduleFreeRAdam,
LaProp,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
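Because the `OPTIMIZERS` keys are lowercased class names, the two new entries are also reachable by name. A sketch, assuming `load_optimizer` is re-exported at the package top level and simply returns the class registered in this dict:

```python
import torch
from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(10, 1)

laprop_cls = load_optimizer('laprop')               # -> LaProp
sf_radam_cls = load_optimizer('schedulefreeradam')  # -> ScheduleFreeRAdam

optimizer = laprop_cls(model.parameters(), lr=1e-3)
```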

2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/adalite.py
@@ -41,7 +41,7 @@ def __init__(
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps1, 'eps1')
self.validate_non_negative(eps2, 'eps1')
self.validate_non_negative(eps2, 'eps2')

defaults: DEFAULTS = {
'lr': lr,
6 changes: 6 additions & 0 deletions pytorch_optimizer/optimizer/adamp.py
@@ -22,6 +22,7 @@ class AdamP(BaseOptimizer):
:param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
on scale-variant parameters.
:param use_gc: bool. use gradient centralization.
:param cautious: bool. whether to use the Cautious variant.
:param nesterov: bool. enables Nesterov momentum.
:param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
:param adanorm: bool. whether to use the AdaNorm variant.
@@ -40,6 +41,7 @@
delta: float = 0.1,
wd_ratio: float = 0.1,
use_gc: bool = False,
cautious: bool = False,
nesterov: bool = False,
r: float = 0.95,
adanorm: bool = False,
@@ -54,6 +56,7 @@
self.validate_non_negative(eps, 'eps')

self.use_gc = use_gc
self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
@@ -170,6 +173,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
bias_correction1=bias_correction1,
)

if self.cautious:
self.apply_cautious(perturb, grad)

p.add_(perturb, alpha=-step_size)

return loss
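In the `AdamP` change above, `apply_cautious(perturb, grad)` modifies the computed update in place before the parameter step. A sketch of the standard Cautious masking it presumably performs, following the Cautious Optimizers formulation; the library's actual `apply_cautious` may differ in details such as the clamping constant:

```python
import torch

def apply_cautious_sketch(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Zero update components whose sign disagrees with the gradient, in place."""
    mask = (update * grad > 0).to(grad.dtype)  # 1 where update and gradient agree in sign
    mask.div_(mask.mean().clamp_(min=1e-3))    # rescale so the mean update size is preserved
    update.mul_(mask)                          # in-place, matching how AdamP applies it
```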
11 changes: 10 additions & 1 deletion pytorch_optimizer/optimizer/adopt.py
@@ -17,6 +17,7 @@ class ADOPT(BaseOptimizer):
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param fixed_decay: bool. fix weight decay.
:param cautious: bool. whether to use the Cautious variant.
:param eps: float. term added to the denominator to improve numerical stability.
"""

@@ -29,6 +30,7 @@
weight_decay: float = 0.0,
weight_decouple: bool = False,
fixed_decay: bool = False,
cautious: bool = False,
eps: float = 1e-6,
**kwargs,
):
@@ -38,6 +40,7 @@
self.validate_non_negative(eps, 'eps')

self.clip_lambda = clip_lambda
self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
@@ -118,6 +121,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:

exp_avg.lerp_(normed_grad, weight=1.0 - beta1)

p.add_(exp_avg, alpha=-group['lr'])
if self.cautious:
update = exp_avg.clone()
self.apply_cautious(update, normed_grad)
else:
update = exp_avg

p.add_(update, alpha=-group['lr'])

return loss
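Unlike `AdamP`, which masks the freshly built `perturb` tensor, the `ADOPT` change clones `exp_avg` first. Assuming the masking is applied in place (as the `AdamP` call pattern suggests), the clone keeps the Cautious mask from corrupting the persistent first-moment state. A small self-contained illustration with made-up values:

```python
import torch

exp_avg = torch.tensor([0.5, -0.2, 0.1])      # persistent first-moment state (illustrative)
normed_grad = torch.tensor([1.0, 1.0, -1.0])  # illustrative gradient

update = exp_avg.clone()                       # clone so masking cannot touch the state
mask = (update * normed_grad > 0).to(update.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update.mul_(mask)                              # only this step's update is masked

assert torch.equal(exp_avg, torch.tensor([0.5, -0.2, 0.1]))  # state left intact
```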