Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Implement StableAdamW optimizer #252

Merged
merged 7 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
Currently, **70 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

Expand Down Expand Up @@ -167,6 +167,7 @@ supported_optimizers = get_supported_optimizers()
| FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | <https://arxiv.org/abs/2405.12807> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) |
| Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
| Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |

## Supported LR Scheduler

Expand Down
2 changes: 2 additions & 0 deletions docs/changelogs/v3.0.2.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
* Add more Pytorch built-in lr schedulers. (#248)
* Implement `Kate` optimizer. (#249, #251)
* [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648)
* Implement `StableAdamW` optimizer. (#250, #252)
* [Stable and low-precision training for large-scale vision-language models](https://arxiv.org/abs/2304.13013)

### Refactor

Expand Down
3 changes: 2 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
Currently, **70 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

Expand Down Expand Up @@ -167,6 +167,7 @@ supported_optimizers = get_supported_optimizers()
| FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | <https://arxiv.org/abs/2405.12807> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) |
| Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
| Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |

## Supported LR Scheduler

Expand Down
4 changes: 4 additions & 0 deletions docs/optimizer.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@
:docstring:
:members:

::: pytorch_optimizer.StableAdamW
:docstring:
:members:

::: pytorch_optimizer.AccSGD
:docstring:
:members:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ keywords = [
"GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
"PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
"ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH",
"SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM",
"Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
Expand Down
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from pytorch_optimizer.optimizer.adamod import AdaMod
from pytorch_optimizer.optimizer.adamp import AdamP
from pytorch_optimizer.optimizer.adams import AdamS
from pytorch_optimizer.optimizer.adamw import StableAdamW
from pytorch_optimizer.optimizer.adan import Adan
from pytorch_optimizer.optimizer.adanorm import AdaNorm
from pytorch_optimizer.optimizer.adapnm import AdaPNM
Expand Down Expand Up @@ -201,6 +202,7 @@
FAdam,
GrokFastAdamW,
Kate,
StableAdamW,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

Expand Down
135 changes: 135 additions & 0 deletions pytorch_optimizer/optimizer/adamw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.utils import debias_beta


class StableAdamW(Optimizer, BaseOptimizer):
r"""Stable and low-precision training for large-scale vision-language models.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
:param kahan_sum: bool. Enables Kahan summation for more accurate parameter updates when training in low precision
(float16 or bfloat16).
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. decoupled weight decay.
:param eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.9, 0.99),
kahan_sum: bool = True,
weight_decay: float = 1e-2,
weight_decouple: bool = True,
eps: float = 1e-8,
):
self.validate_learning_rate(lr)
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'kahan_sum': kahan_sum,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'eps': eps,
}

super().__init__(params, defaults)

def __str__(self) -> str:
return 'StableAdamW'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

state['kahan_comp'] = (
torch.zeros_like(p) if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16} else None
)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

beta1, beta2 = group['betas']

beta1_comp: float = 1.0 - debias_beta(beta1, group['step'])
beta2_hat: float = debias_beta(beta2, group['step'])

eps_p2: float = math.pow(group['eps'], 2)

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(str(self))

state = self.state[p]
if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

state['kahan_comp'] = (
torch.zeros_like(p)
if (group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16})
else None
)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
exp_avg.lerp_(grad, weight=beta1_comp)
exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1.0 - beta2_hat)

rms = grad.pow(2).div_(exp_avg_sq.clip(min=eps_p2)).mean().sqrt_()

lr = group['lr'] / rms.clip(min=1.0)

self.apply_weight_decay(
p,
p.grad,
lr=lr,
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=False,
)

if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16}:
kahan_comp = state['kahan_comp']
kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

grad.copy_(p.detach())
p.add_(kahan_comp)

kahan_comp.add_(grad.sub_(p))
else:
p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

return loss
11 changes: 11 additions & 0 deletions pytorch_optimizer/optimizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
from pytorch_optimizer.base.types import PARAMETERS


def debias_beta(beta: float, step: int) -> float:
r"""Apply the Adam-style debias correction into beta.

Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`

:param beta: float. beta.
:param step: int. number of step.
"""
return (beta ** step - beta) / (beta ** step - 1.0) # fmt: skip


def is_valid_parameters(parameters: PARAMETERS) -> bool:
r"""Check where the parameters are valid."""
return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict)
Expand Down
3 changes: 3 additions & 0 deletions tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
Shampoo,
SignSGD,
SophiaH,
StableAdamW,
Tiger,
Yogi,
)
Expand Down Expand Up @@ -132,6 +133,7 @@
'schedulefreeadamw',
'fadam',
'grokfastadamw',
'stableadamw',
]

VALID_LR_SCHEDULER_NAMES: List[str] = [
Expand Down Expand Up @@ -463,6 +465,7 @@
(FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
(Kate, {'lr': 5e-2}, 10),
(StableAdamW, {'lr': 1e0}, 5),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_load_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


def test_get_supported_optimizers():
assert len(get_supported_optimizers()) == 69
assert len(get_supported_optimizers()) == 70


def test_get_supported_lr_schedulers():
Expand Down
10 changes: 10 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,13 @@ def test_grokfast_ema(environment):
model.fc2.bias.grad = torch.randn(1)

_ = gradfilter_ema(model, None)


def test_stableadamw_optimizer(environment):
_, model, _ = environment

model.fc1.weight.data = torch.randn(2, 2, dtype=torch.float16)

optimizer = load_optimizer('StableAdamW')(model.parameters())
optimizer.reset()
optimizer.step()