[Fix] Save and load the Lookahead optimizer's state #310

Merged · 17 commits · Dec 14, 2024

7 changes: 5 additions & 2 deletions .github/ISSUE_TEMPLATE/bug_report.md
@@ -15,8 +15,11 @@ A clear and concise description of what the bug is.

* OS : (e.g. Linux, Windows, MacOS)
* PyTorch version : (e.g. 2.0.1, 1.13, >=1.8, <1.10)
- * Python version : (e.g. 3.8, 3.11
- * reproducible codes :
+ * Python version : (e.g. 3.8, 3.11)
+ * pytorch-optimizer version : (e.g. 3.3.0)
+ * reproducible codes : please share your reproducible codes, scripts, or links. If sharing the code is complicated, you can manually write minimal code to reproduce bugs!
+
+ Here's an [example](https://github.com/kozistr/pytorch_optimizer/issues/305#issue-2721453417).

## Log


10 changes: 5 additions & 5 deletions README.md
@@ -10,11 +10,11 @@

## The reasons why you use `pytorch-optimizer`.

- 1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
- 2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
- 3. Easy to use, clean, and tested codes
- 4. Active maintenance
- 5. Somewhat a bit more optimized compared to the original implementation
+ * Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+ * Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
+ * Easy to use, clean, and tested codes
+ * Active maintenance
+ * Somewhat a bit more optimized compared to the original implementation

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).


8 changes: 8 additions & 0 deletions docs/changelogs/v3.3.1.md
@@ -2,9 +2,17 @@

### Feature

+ * Support `Cautious` variant to `AdaShift` optimizer. (#310)
+ * Save the state of the `Lookahead` optimizer too. (#310)

### Bug

* Fix `bias_correction` in `AdamG` optimizer. (#305, #308)
+ * Fix a potential bug when loading the state for `Lookahead` optimizer. (#306, #310)

+ ### Docs
+
+ * Add more visualizations. (#310)

### Contributions


10 changes: 5 additions & 5 deletions docs/index.md
@@ -10,11 +10,11 @@

## The reasons why you use `pytorch-optimizer`.

- 1. Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
- 2. Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
- 3. Easy to use, clean, and tested codes
- 4. Active maintenance
- 5. Somewhat a bit more optimized compared to the original implementation
+ * Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+ * Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
+ * Easy to use, clean, and tested codes
+ * Active maintenance
+ * Somewhat a bit more optimized compared to the original implementation

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).


32 changes: 32 additions & 0 deletions docs/visualization.md
@@ -74,6 +74,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaPNM.png)

+ ### AdaShift
+
+ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaShift.png)

### AdaSmooth

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaSmooth.png)
@@ -170,6 +174,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Lamb.png)

+ ### LaProp
+
+ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_LaProp.png)

### LARS

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_LARS.png)
@@ -186,6 +194,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_MSVAG.png)

+ ### Muon
+
+ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Muon.png)

### Nero

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Nero.png)
@@ -238,6 +250,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeAdamW.png)

+ ### ScheduleFreeRAdam
+
+ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeRAdam.png)

### ScheduleFreeSGD

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_ScheduleFreeSGD.png)
@@ -368,6 +384,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaPNM.png)

+ ### AdaShift
+
+ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaShift.png)

### AdaSmooth

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaSmooth.png)
@@ -464,6 +484,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Lamb.png)

+ ### LaProp
+
+ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_LaProp.png)

### LARS

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_LARS.png)
@@ -480,6 +504,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_MSVAG.png)

+ ### Muon
+
+ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Muon.png)

### Nero

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Nero.png)
@@ -532,6 +560,10 @@

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeAdamW.png)

+ ### ScheduleFreeRAdam
+
+ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeRAdam.png)

### ScheduleFreeSGD

![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_ScheduleFreeSGD.png)

Binary file added docs/visualizations/rastrigin_AdaShift.png
Binary file modified docs/visualizations/rastrigin_AdamG.png
Binary file added docs/visualizations/rastrigin_LaProp.png
Binary file added docs/visualizations/rastrigin_Muon.png
Binary file added docs/visualizations/rosenbrock_AdaShift.png
Binary file modified docs/visualizations/rosenbrock_AdamG.png
Binary file added docs/visualizations/rosenbrock_LaProp.png
Binary file added docs/visualizations/rosenbrock_Muon.png

4 changes: 3 additions & 1 deletion examples/visualize_optimizers.py
@@ -31,6 +31,8 @@ def execute_steps(func, initial_state, optimizer_class, optimizer_config, num_it

if optimizer_class.__name__ == 'Ranger21':
optimizer_config.update({'num_iterations': num_iters})
+ if optimizer_class.__name__ == 'AdaShift':
+     optimizer_config.update({'keep_num': 1})

optimizer = optimizer_class([x], **optimizer_config)

@@ -155,7 +157,7 @@ def main():
optimizers = [
(optimizer, -6, 0.5)
for optimizer_name, optimizer in OPTIMIZERS.items()
- if optimizer_name.lower() not in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini'}
+ if optimizer_name.lower() not in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini', 'demo'}
]
optimizers.extend([(torch.optim.AdamW, -6, 0.5), (torch.optim.Adam, -6, 0.5), (torch.optim.SGD, -6, -1.0)])


6 changes: 4 additions & 2 deletions pytorch_optimizer/optimizer/a2grad.py
@@ -1,12 +1,14 @@
import math
- from typing import Optional
+ from typing import Literal, Optional

import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS

+ VARIANTS = Literal['uni', 'inc', 'exp']


class A2Grad(BaseOptimizer):
r"""Optimal Adaptive and Accelerated Stochastic Gradient Descent.
@@ -26,7 +28,7 @@ def __init__(
beta: float = 10.0,
lips: float = 10.0,
rho: float = 0.5,
- variant: str = 'uni',
+ variant: VARIANTS = 'uni',
**kwargs,
):
self.validate_learning_rate(lr)
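With `variant` now typed as `Literal['uni', 'inc', 'exp']`, a static type checker can catch an unsupported variant string before the optimizer is even constructed. A minimal sketch (the linear model, batch, and choice of `variant='exp'` below are illustrative, and the remaining hyper-parameters are left at their defaults):

```python
import torch

from pytorch_optimizer import A2Grad

model = torch.nn.Linear(4, 1)

# 'uni', 'inc', or 'exp'; a typo such as variant='expp' is now flagged by
# type checkers like mypy/pyright instead of only surfacing at runtime
optimizer = A2Grad(model.parameters(), variant='exp')

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
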
12 changes: 9 additions & 3 deletions pytorch_optimizer/optimizer/adashift.py
@@ -17,6 +17,7 @@ class AdaShift(BaseOptimizer):
:param keep_num: int. number of gradients used to compute first moment estimation.
:param reduce_func: Optional[Callable]. function applied to squared gradients to further reduce the correlation.
If None, no function is applied.
+ :param cautious: bool. whether to use cautious feature.
:param eps: float. term added to the denominator to improve numerical stability.
"""

@@ -27,6 +28,7 @@ def __init__(
betas: BETAS = (0.9, 0.999),
keep_num: int = 10,
reduce_func: Optional[Callable] = torch.max,
+ cautious: bool = False,
eps: float = 1e-10,
**kwargs,
):
@@ -36,6 +38,7 @@
self.validate_non_negative(eps, 'eps')

self.reduce_func: Callable = reduce_func if reduce_func is not None else lambda x: x
+ self.cautious = cautious

defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'keep_num': keep_num, 'eps': eps}
super().__init__(params, defaults)
@@ -101,13 +104,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg = state['exp_avg']
exp_avg.sub_(offset_grad, alpha=first_grad_weight).mul_(beta1).add_(grad, alpha=last_grad_weight)

- reduced_grad_sq = self.reduce_func(offset_grad.mul_(offset_grad))
+ reduced_grad_sq = self.reduce_func(offset_grad.pow_(2))

exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)

- de_nom = exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps'])
+ update = exp_avg.clone()
+ update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))
+ if self.cautious:
+     self.apply_cautious(update, grad)

- p.addcdiv_(exp_avg, de_nom, value=-group['lr'])
+ p.add_(update, alpha=-group['lr'])

return loss
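The new `cautious` flag is set at construction time and reuses the same masking as the repository's other `Cautious` variants: update components whose sign disagrees with the current gradient are suppressed. A minimal sketch (illustrative model and learning rate; `keep_num=1` mirrors the setting used by the visualization script so the optimizer only needs a single buffered gradient before it starts updating):

```python
import torch

from pytorch_optimizer import AdaShift

model = torch.nn.Linear(4, 1)

# cautious=True applies the cautious mask to the AdaShift update
optimizer = AdaShift(model.parameters(), lr=1e-2, keep_num=1, cautious=True)

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
```
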
8 changes: 4 additions & 4 deletions pytorch_optimizer/optimizer/lookahead.py
@@ -31,12 +31,11 @@
self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
self._optimizer_step_post_hooks: Dict[int, Callable] = {}

+ self.optimizer = optimizer
self.alpha = alpha
self.k = k
self.pullback_momentum = pullback_momentum

- self.optimizer = optimizer

self.state: STATE = defaultdict(dict)

for group in self.param_groups:
@@ -93,11 +92,12 @@ def clear_and_load_backup(self):
del state['backup_params']

def state_dict(self) -> STATE:
- return self.optimizer.state_dict()
+ return {'lookahead_state': self.state, 'base_optimizer': self.optimizer.state_dict()}

def load_state_dict(self, state: STATE):
r"""Load state."""
- self.optimizer.load_state_dict(state)
+ self.state = state['lookahead_state']
+ self.optimizer.load_state_dict(state['base_optimizer'])

@torch.no_grad()
def zero_grad(self):
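With this change, `state_dict()` bundles the Lookahead slow-weight/backup state together with the wrapped optimizer's state, and `load_state_dict()` restores both. A minimal save/restore sketch (the model, hyper-parameters, and checkpoint path are illustrative):

```python
import torch

from pytorch_optimizer import Lookahead

model = torch.nn.Linear(4, 1)
base_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# the checkpoint now carries both parts under separate keys
checkpoint = optimizer.state_dict()
assert set(checkpoint) == {'lookahead_state', 'base_optimizer'}
torch.save(checkpoint, 'lookahead.pt')

# later: rebuild the model/optimizer pair as above, then restore both parts
# (weights_only=False because the state holds plain Python containers)
state = torch.load('lookahead.pt', weights_only=False)
optimizer.load_state_dict(state)
```
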
1 change: 1 addition & 0 deletions tests/constants.py
@@ -552,4 +552,5 @@
(LaProp, {'lr': 1e0, 'cautious': True}, 2),
(AdamP, {'lr': 1e0, 'cautious': True}, 2),
(ADOPT, {'lr': 1e1, 'cautious': True}, 3),
+ (AdaShift, {'lr': 1e1, 'keep_num': 1, 'cautious': True}, 3),
]

3 changes: 0 additions & 3 deletions tests/test_optimizer_parameters.py
@@ -77,15 +77,12 @@ def test_lookahead_parameters():

_ = opt.__getstate__()

- # test lookahead step `k`
with pytest.raises(ValueError):
Lookahead(optimizer, k=0)

- # test ema ratio `alpha`
with pytest.raises(ValueError):
Lookahead(optimizer, alpha=-0.1)

- # test invalid pullback momentum type
with pytest.raises(ValueError):
Lookahead(optimizer, pullback_momentum='invalid')
