Merge pull request #229 from kozistr/feature/adalite-optimizer
[Feature] Implement Adalite optimizer
kozistr authored Apr 7, 2024
2 parents b1b5ed4 + 6abce12 commit 48030b5
Showing 10 changed files with 196 additions and 9 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

**pytorch-optimizer** is a collection of optimizers and learning-rate schedulers for PyTorch.
The algorithms are re-implemented from the original papers, with speed and memory tweaks and plug-ins, along with other useful and practical optimization ideas.
Currently, **63 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
Currently, **64 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -161,6 +161,7 @@ supported_optimizers = get_supported_optimizers()
| WSAM | *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* | [github](https://github.com/intelligent-machine-learning/dlrover/blob/master/atorch/atorch/optimizers/wsam.py) | <https://arxiv.org/abs/2305.15817> | [cite](https://github.com/intelligent-machine-learning/dlrover) |
| Aida | *A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range* | [github](https://github.com/guoqiang-zhang-x/Aida-Optimizer) | <https://arxiv.org/abs/2203.13273> | [cite](https://github.com/guoqiang-zhang-x/Aida-Optimizer?tab=readme-ov-file#1-brief-description-of-aida) |
| GaLore | *Memory-Efficient LLM Training by Gradient Low-Rank Projection* | [github](https://github.com/jiaweizzhao/GaLore) | <https://arxiv.org/abs/2403.03507> | [cite](https://github.com/jiaweizzhao/GaLore/tree/master?tab=readme-ov-file#citation) |
| Adalite | *Adalite optimizer* | [github](https://github.com/VatsaDev/adalite) | <https://github.com/VatsaDev/adalite> | [cite](https://github.com/VatsaDev/adalite) |

## Supported LR Scheduler

1 change: 1 addition & 0 deletions docs/changelogs/v3.0.0.md
@@ -12,6 +12,7 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
* [Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term](https://arxiv.org/abs/2305.15817)
* Implement `GaLore` optimizer. (#224, #228)
* [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
* Implement `Adalite` optimizer. (#225, #229)

### Fix

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -28,6 +28,10 @@
:docstring:
:members:

::: pytorch_optimizer.Adalite
:docstring:
:members:

::: pytorch_optimizer.AdaMax
:docstring:
:members:
15 changes: 8 additions & 7 deletions pyproject.toml
@@ -11,13 +11,14 @@ repository = "https://github.com/kozistr/pytorch_optimizer"
documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
keywords = [
"pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
"AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
"DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS", "Lamb", "Lion",
"LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam",
"QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
"SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
"AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad",
"DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS",
"Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
"Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo",
"SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
"bitsandbytes",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -34,6 +34,7 @@
from pytorch_optimizer.optimizer.adafactor import AdaFactor
from pytorch_optimizer.optimizer.adahessian import AdaHessian
from pytorch_optimizer.optimizer.adai import Adai
from pytorch_optimizer.optimizer.adalite import Adalite
from pytorch_optimizer.optimizer.adamax import AdaMax
from pytorch_optimizer.optimizer.adamod import AdaMod
from pytorch_optimizer.optimizer.adamp import AdamP
@@ -184,6 +185,7 @@
DAdaptLion,
Aida,
GaLore,
Adalite,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
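Since the registry above keys every optimizer class by its lowercase name, the new `Adalite` entry can also be fetched by string via `load_optimizer`. A minimal sketch (the toy model is illustrative):

```python
import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(4, 2)

# 'adalite' resolves to the Adalite class through the OPTIMIZERS registry
optimizer = load_optimizer('adalite')(model.parameters(), lr=1e-3)
```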

168 changes: 168 additions & 0 deletions pytorch_optimizer/optimizer/adalite.py
@@ -0,0 +1,168 @@
import torch
from torch.nn.functional import softmax
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Adalite(Optimizer, BaseOptimizer):
r"""Adalite optimizer.
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of its squared deviation from that average.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param fixed_decay: bool. fix weight decay.
    :param g_norm_min: float. lower bound applied to the gradient norm when computing the trust ratio.
    :param ratio_min: float. lower bound applied to the trust ratio that scales the gradient.
    :param tau: float. softmax temperature used when computing the row/column importance weights.
:param eps1: float. term added to the denominator to improve numerical stability.
:param eps2: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.9, 0.999),
weight_decay: float = 1e-2,
weight_decouple: bool = False,
fixed_decay: bool = False,
g_norm_min: float = 1e-10,
ratio_min: float = 1e-4,
tau: float = 1.0,
eps1: float = 1e-6,
eps2: float = 1e-10,
):
self.validate_learning_rate(lr)
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'fixed_decay': fixed_decay,
'g_norm_min': g_norm_min,
'ratio_min': ratio_min,
'tau': tau,
'eps1': eps1,
'eps2': eps2,
}
super().__init__(params, defaults)

def __str__(self) -> str:
return 'Adalite'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

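                # 1D parameters keep full moment buffers; larger tensors use factored row/column statistics plus a small core for the momentum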
if len(p.shape) < 2:
state['m_avg'] = torch.zeros_like(p)
state['v_avg'] = torch.zeros_like(p)
else:
state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))

state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

beta1, beta2 = group['betas']

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(str(self))

state = self.state[p]

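                # lazily initialize the state on the first step (same layout as in reset())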
if len(state) == 0:
if len(p.shape) < 2:
state['m_avg'] = torch.zeros_like(p)
state['v_avg'] = torch.zeros_like(p)
else:
state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))

state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))

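                # scale the gradient in place by a layer-wise trust ratio (parameter norm / gradient norm), clipped from below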
if sum(grad.shape) > 1:
trust_ratio = (p.norm() / grad.norm().clip(min=group['g_norm_min'])).clip(min=group['ratio_min'])
grad.mul_(trust_ratio)

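                # fetch the first/second moment estimates; for matrices they are reconstructed from the factored state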
if len(grad.shape) < 2:
m = state['m_avg']
v = state['v_avg']
else:
r, c = state['v_avg_0'][:, None], state['v_avg_1'][None, :]
v = (r * c) / r.sum().clamp(min=group['eps2'])
m = state['m_avg_c'] @ state['m_avg_u'] @ state['m_avg_r']

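                # exponential moving averages of the gradient and of the squared residual (grad - m)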
m.lerp_(grad, 1.0 - beta1)
v.lerp_((grad - m).square(), 1.0 - beta2)

v_avg = v / (1.0 - beta2 ** group['step'])

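                # for matrices, blend the momentum toward the current gradient where the row/column importance (softmax over the second moment) is low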
if len(grad.shape) == 2:
imp_c = softmax(v.mean(dim=1), dim=0)[:, None]
imp_r = softmax(v.mean(dim=0), dim=0)[None, :]
m.lerp_(grad, 1.0 - imp_c * imp_r)

u = m.lerp(grad, 1.0 - beta1)

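                # write the statistics back; matrices store only factored row/column sums of the second moment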
if len(grad.shape) < 2:
state['m_avg'] = m
state['v_avg'] = v
else:
state['v_avg_0'] = v.sum(dim=1)
state['v_avg_1'] = v.sum(dim=0) / v.sum().clamp(min=group['eps2'])

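                    # compress the momentum into column (c), row (r), and scalar core (s) factors using temperature-scaled importance weights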
imp_c = softmax(v.mean(dim=1) / group['tau'], dim=-1)[:, None]
imp_r = softmax(v.mean(dim=0) / group['tau'], dim=-1)[None, :]

c = ((m * imp_r).sum(dim=1))[:, None]
r = ((m * imp_c).sum(dim=0))[None, :]

s = (c.T @ m @ r.T) / (c.T @ c @ r @ r.T).clamp(min=group['eps2'])

state['m_avg_c'] = c
state['m_avg_r'] = r
state['m_avg_u'] = s

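                # normalize the update by the bias-corrected second moment, add weight decay, and apply it with the learning rate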
u.div_((v_avg + group['eps1']).sqrt())

u = u.reshape(p.shape)
u.add_(p, alpha=group['weight_decay'])

p.add_(u, alpha=-group['lr'])

return loss
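For reference, a minimal usage sketch of the new optimizer outside this diff (the toy model, data, and hyper-parameters are illustrative, not part of the commit):

```python
import torch

from pytorch_optimizer import Adalite

# toy regression model; any torch.nn.Module works
model = torch.nn.Linear(8, 1)
criterion = torch.nn.MSELoss()

optimizer = Adalite(model.parameters(), lr=1e-3, weight_decay=1e-2)

x, y = torch.randn(16, 8), torch.randn(16, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```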
4 changes: 4 additions & 0 deletions tests/constants.py
@@ -23,6 +23,7 @@
AdaFactor,
AdaHessian,
Adai,
Adalite,
AdaMax,
AdaMod,
AdamP,
@@ -120,6 +121,8 @@
'padam',
'came',
'aida',
'galore',
'adalite',
]

VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -434,6 +437,7 @@
{'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'},
5,
),
(Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
1 change: 1 addition & 0 deletions tests/test_general_optimizer_parameters.py
@@ -46,6 +46,7 @@ def test_epsilon(optimizer_name):
'lomo',
'tiger',
'came',
'adalite',
):
pytest.skip(f'skip {optimizer_name} optimizer')

2 changes: 1 addition & 1 deletion tests/test_load_modules.py
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


def test_get_supported_optimizers():
assert len(get_supported_optimizers()) == 62
assert len(get_supported_optimizers()) == 63


def test_get_supported_lr_schedulers():
5 changes: 5 additions & 0 deletions tests/test_optimizers.py
@@ -464,6 +464,11 @@ def test_prodigy_reset():
assert str(optimizer) == 'Prodigy'


def test_adalite_reset():
optimizer = load_optimizer('adalite')([simple_zero_rank_parameter(True)])
optimizer.reset()


@pytest.mark.parametrize('pre_conditioner_type', [0, 1, 2])
def test_scalable_shampoo_pre_conditioner_with_svd(pre_conditioner_type, environment):
(x_data, y_data), _, loss_fn = environment
