Merge pull request #22 from kozistr/feature/sam-optimizer
[Feature] Implement SAM optimizer
kozistr authored Sep 22, 2021
2 parents 7588ee2 + 2749a08 commit 8ef27c3
Showing 4 changed files with 204 additions and 7 deletions.
51 changes: 45 additions & 6 deletions README.md
@@ -53,7 +53,7 @@ Also, most of the captures are taken from `Ranger21` paper.
This idea was originally proposed in the `NFNet (Normalizer-Free Network)` paper.
AGC (Adaptive Gradient Clipping) clips gradients based on the `unit-wise ratio of gradient norms to parameter norms`.

- * github : [code](https://github.com/deepmind/deepmind-research/tree/master/nfnets)
+ * code : [github](https://github.com/deepmind/deepmind-research/tree/master/nfnets)
* paper : [arXiv](https://arxiv.org/abs/2102.06171)
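
Below is a minimal sketch of this unit-wise clipping rule, not the NFNet repository's implementation; the function names and the defaults `clip_threshold=0.01` and `eps=1e-3` are illustrative assumptions.

```python
import torch


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    # norm over every dim except the first, i.e. per output "unit"
    if x.ndim <= 1:
        return x.abs()
    return x.norm(p=2, dim=tuple(range(1, x.ndim)), keepdim=True)


@torch.no_grad()
def agc_clip_(param: torch.Tensor, clip_threshold: float = 0.01, eps: float = 1e-3) -> None:
    # rescale param.grad in place wherever ||grad|| / ||param|| exceeds clip_threshold
    if param.grad is None:
        return
    p_norm = unitwise_norm(param).clamp(min=eps)  # floor tiny weights so they are not over-clipped
    g_norm = unitwise_norm(param.grad)
    max_norm = p_norm * clip_threshold
    clipped = param.grad * (max_norm / g_norm.clamp(min=1e-6))
    param.grad.copy_(torch.where(g_norm > max_norm, clipped, param.grad))
```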

### Gradient Centralization (GC)
@@ -62,7 +62,7 @@

Gradient Centralization (GC) operates directly on gradients by centralizing the gradient to have zero mean.

- * github : [code](https://github.com/Yonghongwei/Gradient-Centralization)
+ * code : [github](https://github.com/Yonghongwei/Gradient-Centralization)
* paper : [arXiv](https://arxiv.org/abs/2004.01461)
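
As a rough illustration (a sketch, not the linked repository's code), centralization removes the per-slice mean from each multi-dimensional gradient before the optimizer update:

```python
import torch


@torch.no_grad()
def centralize_gradient_(grad: torch.Tensor) -> torch.Tensor:
    # subtract the mean over all dims except the first so each gradient slice has zero mean;
    # typically applied only to multi-dimensional weights (conv / linear), not to biases
    if grad.ndim > 1:
        grad.sub_(grad.mean(dim=tuple(range(1, grad.ndim)), keepdim=True))
    return grad
```

Inside an optimizer step this would be applied to `p.grad` right before the parameter update.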

### Softplus Transformation
@@ -83,7 +83,7 @@

![positive_negative_momentum](assets/positive_negative_momentum.png)

- * github : [code](https://github.com/zeke-xie/Positive-Negative-Momentum)
+ * code : [github](https://github.com/zeke-xie/Positive-Negative-Momentum)
* paper : [arXiv](https://arxiv.org/abs/2103.17182)

### Linear learning-rate warm-up
@@ -96,22 +96,22 @@

![stable_weight_decay](assets/stable_weight_decay.png)

- * github : [code](https://github.com/zeke-xie/stable-weight-decay-regularization)
+ * code : [github](https://github.com/zeke-xie/stable-weight-decay-regularization)
* paper : [arXiv](https://arxiv.org/abs/2011.11152)

### Explore-exploit learning-rate schedule

![explore_exploit_lr_schedule](assets/explore_exploit_lr_schedule.png)

- * github : [code](https://github.com/nikhil-iyer-97/wide-minima-density-hypothesis)
+ * code : [github](https://github.com/nikhil-iyer-97/wide-minima-density-hypothesis)
* paper : [arXiv](https://arxiv.org/abs/2003.03977)

### Lookahead

`k` steps forward, 1 step back. `Lookahead` keeps an exponential moving average of the weights, which is updated and
substituted for the current weights every `k_{lookahead}` steps (5 by default).

- * github : [code](https://github.com/alphadl/lookahead.pytorch)
+ * code : [github](https://github.com/alphadl/lookahead.pytorch)
* paper : [arXiv](https://arxiv.org/abs/1907.08610v2)
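
A minimal sketch of this wrapper, assuming an inner optimizer that exposes `param_groups`; the class name and the defaults `k=5`, `alpha=0.5` are illustrative, not the package's API.

```python
import torch


class LookaheadSketch:
    def __init__(self, optimizer, k: int = 5, alpha: float = 0.5):
        self.optimizer = optimizer  # inner ("fast") optimizer
        self.k = k
        self.alpha = alpha
        self.step_count = 0
        # slow weights start as a copy of the model parameters
        self.slow_weights = [
            [p.clone().detach() for p in group['params']]
            for group in optimizer.param_groups
        ]

    def step(self):
        self.optimizer.step()  # k steps forward with the fast weights
        self.step_count += 1
        if self.step_count % self.k == 0:
            for group, slow_group in zip(self.optimizer.param_groups, self.slow_weights):
                for p, slow in zip(group['params'], slow_group):
                    slow.add_(p.data - slow, alpha=self.alpha)  # 1 step back (EMA-style update)
                    p.data.copy_(slow)  # substitute the slow weights into the model
```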

### Chebyshev learning rate schedule
@@ -120,6 +120,15 @@ Acceleration via Fractal Learning Rate Schedules

* paper : [arXiv](https://arxiv.org/abs/2103.01338v1)

### (Adaptive) Sharpness-Aware Minimization (A/SAM)

Sharpness-Aware Minimization (SAM) simultaneously minimizes loss value and loss sharpness.
In particular, it seeks parameters that lie in neighborhoods having uniformly low loss.

* SAM paper : [arXiv](https://arxiv.org/abs/2010.01412)
* ASAM paper : [arXiv](https://arxiv.org/abs/2102.11600)
* A/SAM code : [github](https://github.com/davda54/sam)
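
For example, a minimal two-pass training loop with the `SAM` class added in this commit; the tiny linear model and the `torch.optim.SGD` base optimizer are only for illustration, and any optimizer class accepting param groups should work.

```python
import torch
from pytorch_optimizer import SAM

model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

for _ in range(10):
    x, y = torch.randn(32, 10), torch.randn(32, 1)

    # first forward-backward pass: gradients at w, then climb to the perturbed point w + e(w)
    criterion(model(x), y).backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass: gradients at w + e(w) drive the actual update
    criterion(model(x), y).backward()
    optimizer.second_step(zero_grad=True)
```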

## Citations

<details>
@@ -370,6 +379,36 @@

</details>

<details>

<summary>Sharpness-Aware Minimization</summary>

```
@article{foret2020sharpness,
    title={Sharpness-aware minimization for efficiently improving generalization},
    author={Foret, Pierre and Kleiner, Ariel and Mobahi, Hossein and Neyshabur, Behnam},
    journal={arXiv preprint arXiv:2010.01412},
    year={2020}
}
```

</details>

<details>

<summary>Adaptive Sharpness-Aware Minimization</summary>

```
@article{kwon2021asam,
    title={ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks},
    author={Kwon, Jungmin and Kim, Jeongseop and Park, Hyunseo and Choi, In Kwon},
    journal={arXiv preprint arXiv:2102.11600},
    year={2021}
}
```

</details>

## Author

Hyeongchan Kim / [@kozistr](http://kozistr.tech/about)
3 changes: 2 additions & 1 deletion pytorch_optimizer/__init__.py
@@ -10,6 +10,7 @@
from pytorch_optimizer.radam import RAdam
from pytorch_optimizer.ranger import Ranger
from pytorch_optimizer.ranger21 import Ranger21
from pytorch_optimizer.sam import SAM
from pytorch_optimizer.sgdp import SGDP

- __VERSION__ = '0.0.5'
+ __VERSION__ = '0.0.6'
155 changes: 155 additions & 0 deletions pytorch_optimizer/sam.py
@@ -0,0 +1,155 @@
from typing import Dict

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.types import (
    CLOSURE,
    DEFAULT_PARAMETERS,
    PARAM_GROUPS,
    PARAMS,
)


class SAM(Optimizer):
    """
    Reference : https://github.com/davda54/sam

    Example :
        from pytorch_optimizer import SAM
        ...
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)
        ...
        for input, output in data:
            # first forward-backward pass
            loss = loss_function(output, model(input))  # use this loss for any training statistics
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass
            loss_function(output, model(input)).backward()  # make sure to do a full forward pass
            optimizer.second_step(zero_grad=True)

    Alternative Example with a single closure-based step function :
        from pytorch_optimizer import SAM
        ...
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)

        def closure():
            loss = loss_function(output, model(input))
            loss.backward()
            return loss
        ...
        for input, output in data:
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step(closure)
            optimizer.zero_grad()
    """

    def __init__(
        self,
        params: PARAMS,
        base_optimizer,
        rho: float = 0.05,
        adaptive: bool = False,
        **kwargs,
    ):
        """(Adaptive) Sharpness-Aware Minimization
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param base_optimizer: Optimizer. base optimizer class (e.g. Ranger21) instantiated on the SAM param groups
        :param rho: float. size of the neighborhood for computing the max loss
        :param adaptive: bool. element-wise Adaptive SAM
        :param kwargs: Dict. parameters for optimizer.
        """
        self.rho = rho

        self.check_valid_parameters()

        defaults: DEFAULT_PARAMETERS = dict(
            rho=rho, adaptive=adaptive, **kwargs
        )
        super().__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups: PARAM_GROUPS = self.base_optimizer.param_groups

    def check_valid_parameters(self):
        if self.rho < 0.0:
            raise ValueError(f'Invalid rho : {self.rho}')

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        grad_norm = self.grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)

            for p in group['params']:
                if p.grad is None:
                    continue
                self.state[p]['old_p'] = p.data.clone()
                e_w = (
                    (torch.pow(p, 2) if group['adaptive'] else 1.0)
                    * p.grad
                    * scale.to(p)
                )
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.data = self.state[p]['old_p']  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        if closure is None:
            raise RuntimeError(
                'Sharpness Aware Minimization requires closure, but it was not provided'
            )

        # the closure should do a full forward-backward pass
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def grad_norm(self) -> torch.Tensor:
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]['params'][0].device
        norm = torch.norm(
            torch.stack(
                [
                    ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad)
                    .norm(p=2)
                    .to(shared_device)
                    for group in self.param_groups
                    for p in group['params']
                    if p.grad is not None
                ]
            ),
            p=2,
        )
        return norm

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
2 changes: 2 additions & 0 deletions setup.py
@@ -57,6 +57,8 @@ def read_version() -> str:
'adabound',
'adahessian',
'adabelief',
'sam',
'asam',
]
)

