Commit
Merge pull request #50 from kozistr/feature/lars-optimizer
[Feature] Implement LARS optimizer
kozistr authored Feb 1, 2022
2 parents 16aeb2c + 9fadf57 commit 3cd5158
Showing 9 changed files with 117 additions and 1 deletion.
1 change: 1 addition & 0 deletions pytorch_optimizer/__init__.py
@@ -10,6 +10,7 @@
from pytorch_optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
from pytorch_optimizer.gc import centralize_gradient
from pytorch_optimizer.lamb import Lamb
from pytorch_optimizer.lars import LARS
from pytorch_optimizer.lookahead import Lookahead
from pytorch_optimizer.madgrad import MADGRAD
from pytorch_optimizer.optimizers import load_optimizers
106 changes: 106 additions & 0 deletions pytorch_optimizer/lars.py
@@ -0,0 +1,106 @@
import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class LARS(Optimizer):
"""
Reference : https://github.com/facebookresearch/mae/blob/main/util/lars.py
Example :
from pytorch_optimizer import LARS
...
model = YourModel()
optimizer = LARS(model.parameters())
...
for input, output in data:
optimizer.zero_grad()
loss = loss_function(output, model(input))
loss.backward()
optimizer.step()
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
weight_decay: float = 0.0,
momentum: float = 0.9,
trust_coefficient: float = 0.001,
eps: float = 1e-6,
):
"""LARS optimizer, no rate scaling or weight decay for parameters <= 1D
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate
:param weight_decay: float. weight decay (L2 penalty)
:param momentum: float. momentum
:param trust_coefficient: float. trust_coefficient
:param eps: float. epsilon
"""
self.lr = lr
self.weight_decay = weight_decay
self.momentum = momentum
self.trust_coefficient = trust_coefficient
self.eps = eps

self.check_valid_parameters()

defaults: DEFAULTS = dict(
lr=lr,
weight_decay=weight_decay,
momentum=momentum,
trust_coefficient=trust_coefficient,
)
super().__init__(params, defaults)

def check_valid_parameters(self):
if self.lr < 0.0:
raise ValueError(f'Invalid learning rate : {self.lr}')
if self.weight_decay < 0.0:
raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
if self.momentum < 0.0:
raise ValueError(f'Invalid momentum : {self.momentum}')
if self.trust_coefficient < 0.0:
raise ValueError(f'Invalid trust_coefficient : {self.trust_coefficient}')
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
loss = closure()

for g in self.param_groups:
for p in g['params']:
if p.grad is None:
continue

if p.grad.data.is_sparse:
raise RuntimeError('LARS does not support sparse gradients')

dp = p.grad

if p.ndim > 1: # if not normalization gamma/beta or bias
dp = dp.add(p, alpha=g['weight_decay'])
param_norm = torch.norm(p)
update_norm = torch.norm(dp)
one = torch.ones_like(param_norm)

q = torch.where(
param_norm > 0.0,
torch.where(update_norm > 0.0, (g['trust_coefficient'] * param_norm / update_norm), one),
one,
)
dp = dp.mul(q)

param_state = self.state[p]
if 'mu' not in param_state:
param_state['mu'] = torch.zeros_like(p)

mu = param_state['mu']
mu.mul_(g['momentum']).add_(dp)

p.add_(mu, alpha=-g['lr'])

return loss
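For readers skimming the diff, the update implemented in step() above can be summarized as follows. This is a sketch read directly off the code; the symbols (w for the parameter, g for its gradient, lambda = weight_decay, eta = trust_coefficient, m = momentum) are introduced here for exposition only and do not appear in the code:

\[
\begin{aligned}
d &= g + \lambda w, \\
q &= \begin{cases}
\eta \, \lVert w \rVert / \lVert d \rVert, & \text{if } \lVert w \rVert > 0 \text{ and } \lVert d \rVert > 0, \\
1, & \text{otherwise},
\end{cases} \\
\mu &\leftarrow m \mu + q \, d, \\
w &\leftarrow w - \mathrm{lr} \cdot \mu .
\end{aligned}
\]

Parameters with ndim <= 1 (normalization gamma/beta and biases) skip the weight decay and the trust-ratio scaling, i.e. d = g and q = 1 for them. Note that eps is accepted and validated in check_valid_parameters() but is not used in the update itself.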
3 changes: 3 additions & 0 deletions pytorch_optimizer/optimizers.py
@@ -6,6 +6,7 @@
from pytorch_optimizer.diffrgrad import DiffRGrad
from pytorch_optimizer.fp16 import SafeFP16Optimizer
from pytorch_optimizer.lamb import Lamb
from pytorch_optimizer.lars import LARS
from pytorch_optimizer.madgrad import MADGRAD
from pytorch_optimizer.radam import RAdam
from pytorch_optimizer.ralamb import RaLamb
@@ -45,6 +46,8 @@ def load_optimizers(optimizer: str, use_fp16: bool = False):
        opt = Lamb
    elif optimizer == 'ralamb':
        opt = RaLamb
    elif optimizer == 'lars':
        opt = LARS
    else:
        raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')

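With the branch added above, the new optimizer can also be obtained by name. A minimal sketch, assuming (as the tests further down suggest) that load_optimizers returns the optimizer class when use_fp16 is False; the model and hyperparameter values here are illustrative only, not part of the change:

import torch

from pytorch_optimizer import load_optimizers

optimizer_class = load_optimizers('lars')  # resolves to the LARS class registered above

model = torch.nn.Linear(10, 2)  # placeholder model, for illustration only
optimizer = optimizer_class(model.parameters(), lr=1e-1, weight_decay=1e-3)

optimizer.zero_grad()
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()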
2 changes: 1 addition & 1 deletion pytorch_optimizer/version.py
@@ -1 +1 @@
__VERSION__ = '0.3.6'
__VERSION__ = '0.3.7'
1 change: 1 addition & 0 deletions setup.py
@@ -69,6 +69,7 @@ def read_version() -> str:
        'adamd',
        'lamb',
        'ralamb',
        'lars',
    ]
)

1 change: 1 addition & 0 deletions tests/test_load_optimizers.py
@@ -18,6 +18,7 @@
    'diffrgrad',
    'lamb',
    'ralamb',
    'lars',
]

INVALID_OPTIMIZER_NAMES: List[str] = [
1 change: 1 addition & 0 deletions tests/test_optimizer_parameters.py
@@ -30,6 +30,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
    'diffrgrad',
    'lamb',
    'ralamb',
    'lars',
]

BETA_OPTIMIZER_NAMES: List[str] = [
2 changes: 2 additions & 0 deletions tests/test_optimizers.py
@@ -7,6 +7,7 @@
from torch.nn import functional as F

from pytorch_optimizer import (
    LARS,
    MADGRAD,
    SAM,
    SGDP,
@@ -94,6 +95,7 @@ def build_lookahead(*parameters, **kwargs):
    (DiffRGrad, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
    (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 500),
    (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'pre_norm': True, 'eps': 1e-8}, 500),
    (LARS, {'lr': 1e-1, 'weight_decay': 1e-3}, 500),
    (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3}, 500),
    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
1 change: 1 addition & 0 deletions tests/test_sparse_gradient.py
@@ -22,6 +22,7 @@
    'diffrgrad',
    'lamb',
    'ralamb',
    'lars',
]


