Merge pull request #316 from kozistr/fix/cautious
[Feature] Implement `SGDSaI` optimizer
kozistr authored Dec 21, 2024
2 parents d16a368 + a5e0894 commit 8f538d4
Showing 16 changed files with 176 additions and 19 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **85 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -194,6 +194,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
| MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2412.11768> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |

## Supported LR Scheduler

10 changes: 10 additions & 0 deletions docs/changelogs/v3.3.2.md
@@ -0,0 +1,10 @@
### Change Log

### Feature

* Implement `SGDSaI` optimizer. (#315, #316)
* [No More Adam: Learning Rate Scaling at Initialization is All You Need](https://arxiv.org/abs/2412.11768)

### Bug

* Clone `exp_avg` before calling `apply_cautious` so that `exp_avg` itself is not masked. (#316)
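
For context, a minimal usage sketch of the new optimizer as exposed by this change (the toy model, data, and training loop below are illustrative placeholders, not part of the diff):

```python
import torch

from pytorch_optimizer import SGDSaI  # exported by this PR

# toy setup, purely illustrative
model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()
optimizer = SGDSaI(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-2)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    criterion(model(x), y).backward()
    # the first step() runs warmup_step() internally, caching per-parameter
    # gradient signal-to-noise ratios that scale all later updates
    optimizer.step()
```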
3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **85 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested codes
* Active maintenance
@@ -194,6 +194,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
| MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | <https://arxiv.org/abs/2411.10438> | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) |
| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | <https://arxiv.org/abs/2412.11768> | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) |

## Supported LR Scheduler

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -332,6 +332,10 @@
:docstring:
:members:

::: pytorch_optimizer.SGDSaI
:docstring:
:members:

::: pytorch_optimizer.SGDP
:docstring:
:members:
8 changes: 8 additions & 0 deletions docs/visualization.md
@@ -274,6 +274,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDP.png)

### SGDSaI

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDSaI.png)

### SGDW

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDW.png)
@@ -592,6 +596,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDP.png)

### SGDSaI

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDSaI.png)

### SGDW

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDW.png)
Binary file added docs/visualizations/rastrigin_SGDSaI.png
Binary file added docs/visualizations/rosenbrock_SGDSaI.png
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "3.3.1"
version = "3.3.2"
description = "optimizer & lr scheduler & objective function collections in PyTorch"
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
1 change: 1 addition & 0 deletions pytorch_optimizer/__init__.py
@@ -128,6 +128,7 @@
ScheduleFreeAdamW,
ScheduleFreeRAdam,
ScheduleFreeSGD,
SGDSaI,
Shampoo,
SignSGD,
SophiaH,
3 changes: 2 additions & 1 deletion pytorch_optimizer/optimizer/__init__.py
@@ -74,7 +74,7 @@
from pytorch_optimizer.optimizer.rotograd import RotoGrad
from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
from pytorch_optimizer.optimizer.sgdp import SGDP
from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
from pytorch_optimizer.optimizer.sm3 import SM3
@@ -281,6 +281,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
ScheduleFreeRAdam,
LaProp,
MARS,
SGDSaI,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

3 changes: 2 additions & 1 deletion pytorch_optimizer/optimizer/adashift.py
@@ -110,10 +110,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)

update = exp_avg.clone()
update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))
if self.cautious:
self.apply_cautious(update, grad)

update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))

p.add_(update, alpha=-group['lr'])

return loss
5 changes: 3 additions & 2 deletions pytorch_optimizer/optimizer/ademamix.py
@@ -146,10 +146,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:

de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

update = exp_avg.clone()
if self.cautious:
self.apply_cautious(exp_avg, grad)
self.apply_cautious(update, grad)

update = (exp_avg + alpha_t * exp_avg_slow).div_(de_nom)
update.add_(exp_avg_slow, alpha=alpha_t).div_(de_nom)

p.add_(update, alpha=-step_size)

25 changes: 14 additions & 11 deletions pytorch_optimizer/optimizer/mars.py
@@ -121,26 +121,27 @@ def optimize_mixed(

exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1)

update = exp_avg.clone()
if cautious:
self.apply_cautious(exp_avg, grad)
self.apply_cautious(update, grad)

if mars_type == 'adamw' or (mars_type == 'shampoo' and not is_grad_2d):
exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2)

bias_correction1: float = self.debias(beta1, step)
bias_correction2_sq: float = math.sqrt(self.debias(beta2, step))

update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
update.div_(bias_correction2_sq).mul_(bias_correction1)
de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
de_nom.div_(bias_correction2_sq).mul_(bias_correction1)

return exp_avg.div(update)
return update.div_(de_nom)

if mars_type == 'lion':
return exp_avg.sign()
return update.sign_()

factor: float = max(1.0, grad.size(0) / grad.size(1)) ** 0.5
factor: float = math.sqrt(max(1.0, grad.size(0) / grad.size(1)))

return zero_power_via_newton_schulz_5(exp_avg.mul(1.0 / (1.0 - beta1)), eps=eps).mul_(factor)
return zero_power_via_newton_schulz_5(update.mul_(1.0 / (1.0 - beta1)), eps=eps).mul_(factor)

def optimize_1d(
self,
@@ -162,13 +163,15 @@ def optimize_1d(
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
update.div_(bias_correction2_sq).mul_(bias_correction1)
update = exp_avg.clone()

if cautious:
self.apply_cautious(exp_avg, grad)
self.apply_cautious(update, grad)

return exp_avg.div(update)
de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
de_nom.div_(bias_correction2_sq).mul_(bias_correction1)

return update.div_(de_nom)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
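
The AdaShift, AdEMAMix, and MARS hunks above implement the changelog's cautious fix: AdEMAMix and MARS now mask a clone (`update`) instead of `exp_avg` itself, and AdaShift applies the mask to the momentum before the denominator division. A self-contained sketch of the idea, where `apply_cautious_sketch` is an illustrative stand-in for the library's `apply_cautious` and assumes the usual cautious rule (zero out components whose sign disagrees with the gradient, then renormalize):

```python
import torch


def apply_cautious_sketch(update: torch.Tensor, grad: torch.Tensor) -> None:
    # in-place cautious mask (illustrative): keep only components whose
    # sign agrees with the gradient, rescaled so the mean mask is ~1
    mask = (update * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))
    update.mul_(mask)


exp_avg = torch.randn(4)  # persistent optimizer state (running average)
grad = torch.randn(4)     # current gradient

# before the fix: masking exp_avg directly corrupts the running average
# apply_cautious_sketch(exp_avg, grad)

# after the fix: mask a clone, leaving exp_avg intact for future steps
update = exp_avg.clone()
apply_cautious_sketch(update, grad)
```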
123 changes: 123 additions & 0 deletions pytorch_optimizer/optimizer/sgd.py
@@ -356,6 +356,9 @@ def __init__(
}
super().__init__(params, defaults)

def __str__(self) -> str:
return 'SignSGD'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
@@ -396,3 +399,123 @@ def step(self, closure: CLOSURE = None) -> LOSS:
p.add_(torch.sign(buf), alpha=-group['lr'])

return loss


class SGDSaI(BaseOptimizer):
    r"""No More Adam: Learning Rate Scaling at Initialization is All You Need.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. coefficient used for computing the running average of the gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        momentum: float = 0.9,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'beta', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.has_warmup: bool = False

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDSaI'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                if group['momentum'] > 0.0:
                    state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def warmup_step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                sigma = grad.std().nan_to_num_()
                grad_norm = grad.norm()

                g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm

                self.state[p]['gsnr'] = g_snr

        self.has_warmup = True

        return loss

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        if not self.has_warmup:
            self.warmup_step(closure)

        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum: float = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]

                if momentum > 0.0:
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = grad.clone()

                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
                else:
                    buf = grad

                self.apply_weight_decay(
                    p,
                    grad,
                    group['lr'],
                    group['weight_decay'],
                    group['weight_decouple'],
                    False,
                )

                p.add_(buf, alpha=-group['lr'] * state['gsnr'])

        return loss
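
A note on the two-phase behavior added here: the first call to `step()` triggers `warmup_step()`, which computes and caches a gradient signal-to-noise ratio `gsnr = ||g|| / (std(g) + eps)` per parameter from the initial gradients; every later update is then a plain momentum SGD step (with decoupled weight decay by default) whose learning rate is scaled by that cached ratio. A tiny standalone sketch of the scaling, with illustrative values only:

```python
import torch

lr, eps = 1e-2, 1e-8
grad = torch.randn(1_000) * 0.1 + 0.05  # toy gradient: small noise around a small mean signal

sigma = grad.std().nan_to_num()  # gradient "noise" estimate
g_snr = grad.norm() / (sigma + eps) if sigma != 0.0 else grad.norm()

print(f"g-SNR: {g_snr.item():.3f}")
print(f"effective step size: {(lr * g_snr).item():.5f}")  # lr scaled once, at initialization
```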
3 changes: 3 additions & 0 deletions tests/constants.py
@@ -74,6 +74,7 @@
ScheduleFreeAdamW,
ScheduleFreeRAdam,
ScheduleFreeSGD,
SGDSaI,
Shampoo,
SignSGD,
SophiaH,
@@ -538,6 +539,8 @@
(MARS, {'lr': 1e-1, 'weight_decay': 1e-3, 'mars_type': 'lion', 'optimize_1d': True}, 5),
(MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'mars_type': 'shampoo'}, 5),
(MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'mars_type': 'adamw', 'ams_bound': True}, 5),
(SGDSaI, {'lr': 1e0}, 15),
(SGDSaI, {'lr': 1e0, 'momentum': 0.0}, 15),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
2 changes: 1 addition & 1 deletion tests/test_load_modules.py
@@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


def test_get_supported_optimizers():
assert len(get_supported_optimizers()) == 84
assert len(get_supported_optimizers()) == 85
assert len(get_supported_optimizers('adam*')) == 7
assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 9
