From 5a2002ebae32d91d125886c43bdb30f9b10b1ca6 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:51:58 +0900
Subject: [PATCH 01/12] feature: implement Adalite optimizer

---
 pytorch_optimizer/optimizer/adalite.py | 168 +++++++++++++++++++++++++
 1 file changed, 168 insertions(+)
 create mode 100644 pytorch_optimizer/optimizer/adalite.py

diff --git a/pytorch_optimizer/optimizer/adalite.py b/pytorch_optimizer/optimizer/adalite.py
new file mode 100644
index 000000000..3c514b1b1
--- /dev/null
+++ b/pytorch_optimizer/optimizer/adalite.py
@@ -0,0 +1,168 @@
+import torch
+from torch.nn.functional import softmax
+from torch.optim.optimizer import Optimizer
+
+from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+
+
+class Adalite(Optimizer, BaseOptimizer):
+    r"""Adalite optimizer.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param betas: BETAS. coefficients used for computing running averages of the gradient and its variance.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+    :param fixed_decay: bool. fix weight decay.
+    :param g_norm_min: float. lower bound on the gradient norm when computing the trust ratio.
+    :param ratio_min: float. lower bound on the trust ratio.
+    :param tau: float. temperature of the softmax used to compute the row/column importance weights.
+    :param eps1: float. term added to the denominator to improve numerical stability.
+    :param eps2: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        betas: BETAS = (0.9, 0.999),
+        weight_decay: float = 1e-2,
+        weight_decouple: bool = False,
+        fixed_decay: bool = False,
+        g_norm_min: float = 1e-10,
+        ratio_min: float = 1e-4,
+        tau: float = 1.0,
+        eps1: float = 1e-6,
+        eps2: float = 1e-10,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_betas(betas)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps1, 'eps1')
+        self.validate_non_negative(eps2, 'eps2')
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'betas': betas,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'fixed_decay': fixed_decay,
+            'g_norm_min': g_norm_min,
+            'ratio_min': ratio_min,
+            'tau': tau,
+            'eps1': eps1,
+            'eps2': eps2,
+        }
+        super().__init__(params, defaults)
+
+    def __str__(self) -> str:
+        return 'Adalite'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                state = self.state[p]
+
+                if len(p.shape) < 2:
+                    state['m_avg'] = torch.zeros_like(p)
+                    state['v_avg'] = torch.zeros_like(p)
+                else:
+                    state['v_avg_0'] = torch.zeros_like(p.shape(dim=1))
+                    state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))
+
+                    state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
+                    state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
+                    state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            beta1, beta2 = group['betas']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    if len(p.shape) < 2:
+                        state['m_avg'] = torch.zeros_like(p)
+                        state['v_avg'] = torch.zeros_like(p)
+                    else:
+                        state['v_avg_0'] = torch.zeros_like(p.shape(dim=1))
+                        state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))
+
+                        state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
+                        state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
+                        state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))
+
+                if sum(grad.shape) > 1:
+                    trust_ratio = (p.norm() / grad.norm().clip(min=group['g_norm_min'])).clip(min=group['ratio_min'])
+                    grad.mul_(trust_ratio)
+
+                if len(grad.shape) < 2:
+                    m = state['m_avg']
+                    v = state['v_avg']
+                else:
+                    r, c = state['v_avg_0'][:, None], state['v_avg_1'][None, :]
+                    v = (r * c) / r.sum().clamp(min=group['eps2'])
+                    m = state['m_avg_c'] @ state['m_avg_u'] @ state['m_avg_r']
+
+                m.lerp_(grad, 1.0 - beta1)
+                v.lerp_((grad - m).square(), 1.0 - beta2)
+
+                v_avg = v / (1.0 - beta2 ** group['step'])
+
+                if len(grad.shape) == 2:
+                    imp_c = softmax(v.mean(dim=1), dim=0)[:, None]
+                    imp_r = softmax(v.mean(dim=0), dim=0)[None, :]
+                    m.lerp_(grad, 1.0 - imp_c * imp_r)
+
+                u = m.lerp(grad, 1.0 - beta1)
+
+                if len(grad.shape) < 2:
+                    state['m_avg'] = m
+                    state['v_avg'] = v
+                else:
+                    state['v_avg_0'] = v.sum(dim=1)
+                    state['v_avg_1'] = v.sum(dim=0) / v.sum().clamp(min=group['eps2'])
+
+                    imp_c = softmax(v.mean(dim=1) / group['tau'], dim=-1)[:, None]
+                    imp_r = softmax(v.mean(dim=0) / group['tau'], dim=-1)[None, :]
+
+                    c = ((m * imp_r).sum(dim=1))[:, None]
+                    r = ((m * imp_c).sum(dim=0))[None, :]
+
+                    s = (c.T @ m @ r.T) / (c.T @ c @ r @ r.T).clamp(min=group['eps2'])
+
+                    state['m_avg_c'] = c
+                    state['m_avg_r'] = r
+                    state['m_avg_u'] = s
+
+                u.div_((v_avg + group['eps1']).sqrt())
+
+                u = u.reshape(p.shape)
+                u.add_(p, alpha=group['weight_decay'])
+
+                p.add_(u, alpha=-group['lr'])
+
+        return loss

From 41d7a46bba8cee543b959d835cb489419728117f Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:53:14 +0900
Subject: [PATCH 02/12] chore: keywords

---
 pyproject.toml | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 646597102..72654dbe5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,13 +11,14 @@ repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
 keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
-    "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
-    "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
-    "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS", "Lamb", "Lion",
-    "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam",
-    "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
-    "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
-    "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
+    "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
+    "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad",
+    "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS",
+    "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
+    "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo",
+    "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
+    "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
+    "bitsandbytes",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

From d7466baeb47f94ac2eb42c9d48dd7c5b77056ebd Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:53:30 +0900
Subject: [PATCH 03/12] update: Adalite optimizer

---
 pytorch_optimizer/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py
index f0f1c24ef..6c3b1c2c3 100644
--- a/pytorch_optimizer/__init__.py
+++ b/pytorch_optimizer/__init__.py
@@ -34,6 +34,7 @@
 from pytorch_optimizer.optimizer.adafactor import AdaFactor
 from pytorch_optimizer.optimizer.adahessian import AdaHessian
 from pytorch_optimizer.optimizer.adai import Adai
+from pytorch_optimizer.optimizer.adalite import Adalite
 from pytorch_optimizer.optimizer.adamax import AdaMax
 from pytorch_optimizer.optimizer.adamod import AdaMod
 from pytorch_optimizer.optimizer.adamp import AdamP
@@ -184,6 +185,7 @@
     DAdaptLion,
     Aida,
     GaLore,
+    Adalite,
 ]
 
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

From c3856a73ee53b760968bd4859b5ae42d509d8731 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:56:16 +0900
Subject: [PATCH 04/12] docs: Adalite optimizer

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index dbfc1a1d2..a3fb2103a 100644
--- a/README.md
+++ b/README.md
@@ -161,6 +161,7 @@ supported_optimizers = get_supported_optimizers()
 | WSAM | *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* | [github](https://github.com/intelligent-machine-learning/dlrover/blob/master/atorch/atorch/optimizers/wsam.py) | | [cite](https://github.com/intelligent-machine-learning/dlrover) |
 | Aida | *A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range* | [github](https://github.com/guoqiang-zhang-x/Aida-Optimizer) | | [cite](https://github.com/guoqiang-zhang-x/Aida-Optimizer?tab=readme-ov-file#1-brief-description-of-aida) |
 | GaLore | *Memory-Efficient LLM Training by Gradient Low-Rank Projection* | [github](https://github.com/jiaweizzhao/GaLore) | | [cite](https://github.com/jiaweizzhao/GaLore/tree/master?tab=readme-ov-file#citation) |
+| Adalite | *Adalite optimizer* | [github](https://github.com/VatsaDev/adalite) | | [cite](https://github.com/VatsaDev/adalite) |
 
 ## Supported LR Scheduler
 

From d3597dd124dcea79c3f09ff1c12e4f2990ab9c8b Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:56:56 +0900
Subject: [PATCH 05/12] docs: v3.0.0 changelog

---
 docs/changelogs/v3.0.0.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/changelogs/v3.0.0.md b/docs/changelogs/v3.0.0.md
index 0de82f3c9..eb1f6aedc 100644
--- a/docs/changelogs/v3.0.0.md
+++ b/docs/changelogs/v3.0.0.md
@@ -12,6 +12,7 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
     * [Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term](https://arxiv.org/abs/2305.15817)
 * Implement `GaLore` optimizer. (#224, #228)
     * [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
+* Implement `Adalite` optimizer. (#225, #229)
 
 ### Fix
 

From ae1f936c4593a8b273c919f0a1c3a3a15ad28a18 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:57:32 +0900
Subject: [PATCH 06/12] docs: Adalite optimizer

---
 docs/optimizer.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/docs/optimizer.md b/docs/optimizer.md
index 14d326b4e..6c08ba865 100644
--- a/docs/optimizer.md
+++ b/docs/optimizer.md
@@ -28,6 +28,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.Adalite
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AdaMax
     :docstring:
     :members:

From 30427e7ef53c2da03152d7e529693e9db93e3658 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:57:39 +0900
Subject: [PATCH 07/12] update: number of optimizers

---
 tests/test_load_modules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py
index b7812180c..e6662b13c 100644
--- a/tests/test_load_modules.py
+++ b/tests/test_load_modules.py
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 62
+    assert len(get_supported_optimizers()) == 63
 
 
 def test_get_supported_lr_schedulers():

From ae8ea0f76d876b2e47d586e00663063d784765d1 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:59:30 +0900
Subject: [PATCH 08/12] update: recipes

---
 tests/test_general_optimizer_parameters.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_general_optimizer_parameters.py b/tests/test_general_optimizer_parameters.py
index b6bf4491d..19266f098 100644
--- a/tests/test_general_optimizer_parameters.py
+++ b/tests/test_general_optimizer_parameters.py
@@ -46,6 +46,7 @@ def test_epsilon(optimizer_name):
         'lomo',
         'tiger',
         'came',
+        'adalite',
     ):
         pytest.skip(f'skip {optimizer_name} optimizer')
 

From e0acc05373093760219996707e8c7e12682acd93 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 15:59:33 +0900
Subject: [PATCH 09/12] update: recipes

---
 tests/constants.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/constants.py b/tests/constants.py
index f7ffc4b46..4b322bdec 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -23,6 +23,7 @@
     AdaFactor,
     AdaHessian,
     Adai,
+    Adalite,
     AdaMax,
     AdaMod,
     AdamP,
@@ -120,6 +121,8 @@
     'padam',
     'came',
     'aida',
+    'galore',
+    'adalite',
 ]
 
 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -434,6 +437,7 @@
         {'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'},
         5,
     ),
+    (Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

From 96d6ac6d13e0b827263cab49ded04dbc0a86989f Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 16:05:14 +0900
Subject: [PATCH 10/12] fix: typo

---
 pytorch_optimizer/optimizer/adalite.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_optimizer/optimizer/adalite.py b/pytorch_optimizer/optimizer/adalite.py
index 3c514b1b1..71a5f7cb4 100644
--- a/pytorch_optimizer/optimizer/adalite.py
+++ b/pytorch_optimizer/optimizer/adalite.py
@@ -71,7 +71,7 @@ def reset(self):
                     state['m_avg'] = torch.zeros_like(p)
                     state['v_avg'] = torch.zeros_like(p)
                 else:
-                    state['v_avg_0'] = torch.zeros_like(p.shape(dim=1))
+                    state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
                     state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))
 
                     state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
@@ -108,7 +108,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         state['m_avg'] = torch.zeros_like(p)
                         state['v_avg'] = torch.zeros_like(p)
                     else:
-                        state['v_avg_0'] = torch.zeros_like(p.shape(dim=1))
+                        state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
                         state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))
 
                         state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])

From bab6541895daccd1de272eec19aab513a8c20876 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 16:06:29 +0900
Subject: [PATCH 11/12] docs: README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index a3fb2103a..e8f6267dc 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **63 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
+Currently, **64 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
 

From 6abce120188c22870c368696d4990049310ae2dd Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 7 Apr 2024 16:10:06 +0900
Subject: [PATCH 12/12] update: test_adalite_reset

---
 tests/test_optimizers.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index b7e7c99d0..0546f8fd1 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -464,6 +464,11 @@ def test_prodigy_reset():
     assert str(optimizer) == 'Prodigy'
 
 
+def test_adalite_reset():
+    optimizer = load_optimizer('adalite')([simple_zero_rank_parameter(True)])
+    optimizer.reset()
+
+
 @pytest.mark.parametrize('pre_conditioner_type', [0, 1, 2])
 def test_scalable_shampoo_pre_conditioner_with_svd(pre_conditioner_type, environment):
     (x_data, y_data), _, loss_fn = environment
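For reference, a minimal usage sketch of the optimizer added by this series, assuming the package is installed with these patches applied. The toy model, data, and training loop are illustrative placeholders; only the `Adalite` constructor arguments and the `pytorch_optimizer.Adalite` export come from the patches above.

```python
import torch
from torch import nn

from pytorch_optimizer import Adalite  # exported by PATCH 03/12

# Placeholder model and data; any small regression setup works here.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()

# Arguments mirror the defaults of the constructor introduced in PATCH 01/12.
optimizer = Adalite(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

x, y = torch.randn(64, 16), torch.randn(64, 1)

for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```

The optimizer can also be resolved by name via `load_optimizer('adalite')`, as exercised in `test_adalite_reset` (PATCH 12/12). Matrix-shaped parameters use the factored `v_avg_0`/`v_avg_1`/`m_avg_*` state, while vectors fall back to the dense `m_avg`/`v_avg` buffers.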