From 50caba70d3f5ffcb4caac9beac98d6f50553d5ae Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 16:58:32 +0900 Subject: [PATCH 01/11] docs: v3.0.2 changelog --- docs/changelogs/v3.0.2.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelogs/v3.0.2.md b/docs/changelogs/v3.0.2.md index dd9764e40..1d9bb937c 100644 --- a/docs/changelogs/v3.0.2.md +++ b/docs/changelogs/v3.0.2.md @@ -9,6 +9,8 @@ * [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648) * Implement `StableAdamW` optimizer. (#250, #252) * [Stable and low-precision training for large-scale vision-language models](https://arxiv.org/abs/2304.13013) +* Implement `AdamMini` optimizer. (#246, #253) + * [Use Fewer Learning Rates To Gain More](https://arxiv.org/abs/2406.16793) ### Refactor From fbbded29a0f1ece3ef87eead6e5ecc1f8b8a523d Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 17:00:38 +0900 Subject: [PATCH 02/11] docs: AdamMini optimizer --- README.md | 3 ++- docs/index.md | 3 ++- docs/optimizer.md | 4 ++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e385dfbfa..70a5acd0f 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported! +Currently, **72 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -168,6 +168,7 @@ supported_optimizers = get_supported_optimizers() | Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) | | Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) | | StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) | +| AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) | ## Supported LR Scheduler diff --git a/docs/index.md b/docs/index.md index e385dfbfa..70a5acd0f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported! +Currently, **72 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). 
@@ -168,6 +168,7 @@ supported_optimizers = get_supported_optimizers() | Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) | | Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) | | StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) | +| AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) | ## Supported LR Scheduler diff --git a/docs/optimizer.md b/docs/optimizer.md index 7af635eab..6c22e0f0d 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -32,6 +32,10 @@ :docstring: :members: +::: pytorch_optimizer.AdamMini + :docstring: + :members: + ::: pytorch_optimizer.AdaMax :docstring: :members: From 2dfa9803dade08aac1eea5872a28417b58491da7 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 17:00:46 +0900 Subject: [PATCH 03/11] chore: keyword --- pyproject.toml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8bcff206b..aacb8e780 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,13 +12,13 @@ documentation = "https://pytorch-optimizers.readthedocs.io/en/latest" keywords = [ "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite", - "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", - "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore", "Gravity", - "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", - "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", - "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH", - "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", - "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", + "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", + "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore", + "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", + "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", + "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", + "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", + "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", ] classifiers = [ "License :: OSI Approved :: Apache Software License", 
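The patches that follow implement the optimizer itself. The paper's core idea is to replace Adam's per-coordinate second moment with a single value per parameter block (for example, per attention head), computed as the running mean of the squared gradients in that block, so every coordinate in the block shares one adaptive step size. A small self-contained sketch of that reduction (plain PyTorch, not the library code; block shapes and hyper-parameters below are illustrative only):

```python
import torch

beta2, eps = 0.999, 1e-8

# Treat one attention projection as `num_heads` blocks of `per_head` weights.
num_heads, per_head = 4, 8
grad = torch.randn(num_heads * per_head)

# Adam: one second-moment entry per coordinate.
v_adam = torch.zeros(num_heads * per_head)
v_adam.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

# Adam-mini: one entry per block -- the running mean of the block's squared gradients.
v_mini = torch.zeros(num_heads)
block_sq_mean = (grad.view(num_heads, per_head) ** 2).mean(dim=1)
v_mini.mul_(beta2).add_(block_sq_mean, alpha=1.0 - beta2)

# The whole block then shares a single adaptive scale.
scale = 1.0 / (v_mini.sqrt() + eps)
print(v_adam.numel(), v_mini.numel())  # 32 vs. 4 second-moment values
```

This is what `step_attn_proj` and `step_attn` in the implementation below do per head (with momentum and bias correction on top); the state saving comes from storing a small `v_mean` per block instead of a full `v` tensor.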
From 337bb0e2db9b3f9055509535df5bfd167d5c3190 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 17:01:03 +0900 Subject: [PATCH 04/11] feature: implement AdamMini optimizer --- pytorch_optimizer/optimizer/adam_mini.py | 353 +++++++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 pytorch_optimizer/optimizer/adam_mini.py diff --git a/pytorch_optimizer/optimizer/adam_mini.py b/pytorch_optimizer/optimizer/adam_mini.py new file mode 100644 index 000000000..9e3eda137 --- /dev/null +++ b/pytorch_optimizer/optimizer/adam_mini.py @@ -0,0 +1,353 @@ +import math +from typing import Optional + +import torch +from torch import distributed as dist +from torch import nn +from torch.optim.optimizer import Optimizer + +from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS + + +class AdamMini(Optimizer, BaseOptimizer): + r"""Use Fewer Learning Rates To Gain More. + + :param model: nn.Module. model instance. + :param model_sharding: bool. set to True if you are using model parallelism with more than 1 GPU, including FSDP + and zero_1, 2, 3 in Deepspeed. Set to False if otherwise. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. + :param weight_decay: float. weight decay (L2 penalty). + :param num_embeds: int. number of embedding dimensions. could be unspecified if you are training non-transformer + models. + :param num_heads: int. number of attention heads. could be unspecified if you are training non-transformer models. + :param num_query_groups: Optional[int]. number of query groups in Group Query Attention (GQA). if not specified, it + will be equal to num_heads. could be unspecified if you are training non-transformer models. + :param eps: float. term added to the denominator to improve numerical stability. 
+ """ + + def __init__( + self, + model: nn.Module, + lr: float = 1.0, + betas: BETAS = (0.9, 0.999), + weight_decay: float = 0.1, + model_sharding: bool = False, + num_embeds: int = 2048, + num_heads: int = 32, + num_query_groups: Optional[int] = None, + eps: float = 1e-8, + ): + self.validate_learning_rate(lr) + self.validate_betas(betas) + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_non_negative(num_embeds, 'num_embeds') + self.validate_non_negative(num_heads, 'num_heads') + self.validate_non_negative(eps, 'eps') + + self.num_query_groups: int = num_query_groups if num_query_groups is not None else num_embeds + self.validate_mod(num_embeds, self.num_query_groups) + + self.world_size: int = torch.cuda.device_count() + + self.model = model + self.model_sharding = model_sharding + self.num_embeds = num_embeds + self.num_heads = num_heads + + groups = self.get_optimizer_groups(weight_decay) + + defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'eps': eps} + super().__init__(groups, defaults) + + def __str__(self) -> str: + return 'AdamMini' + + def get_optimizer_groups(self, weight_decay: float): + groups = [] + for name, param in self.model.named_parameters(): + if not param.requires_grad: + continue + + group = { + 'name': name, + 'params': param, + 'weight_decay': 0.0 if ('norm' in name or 'ln_f' in name) else weight_decay, + } + + if ( + 'self_attn.k_proj.weight' in name + or 'self_attn.q_proj.weight' in name + or 'attn.wq.weight' in name + or 'attn.wk.weight' in name + ): + group['parameter_per_head'] = self.num_embeds * self.num_embeds // self.num_heads + + if 'attn.attn.weight' in name or 'attn.qkv.weight' in name: + group['n_head'] = self.num_heads + group['q_per_kv'] = self.num_embeds // self.num_query_groups + + groups.append(group) + + return groups + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + state['m'] = torch.zeros_like(p, dtype=torch.float32) + state['v'] = torch.zeros_like(p, dtype=torch.float32) + + @staticmethod + def step_embed( + p, + grad, + state, + lr: float, + beta1: float, + beta2: float, + bias_correction1: float, + bias_correction2_sq: float, + eps: float, + ) -> None: + if len(state) == 0: + state['m'] = torch.zeros_like(p, dtype=torch.float32) + state['v'] = torch.zeros_like(p, dtype=torch.float32) + + m, v = state['m'], state['v'] + + m.lerp_(grad, weight=1.0 - beta1) + v.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2) + + h = (v.sqrt() / bias_correction2_sq).add_(eps) + + p.addcdiv_(m, h, value=-lr / bias_correction1) + + @staticmethod + def step_attn_proj( + p, + grad, + state, + parameter_per_head: int, + lr: float, + beta1: float, + beta2: float, + bias_correction1: float, + bias_correction2_sq: float, + eps: float, + ) -> None: + if len(state) == 0: + state['m'] = torch.zeros_like(p, dtype=torch.float32).view(-1, parameter_per_head) + state['head'] = state['m'].shape[0] + state['v_mean'] = torch.zeros(state['head'], device=state['m'].device) + + m, v = state['m'], state['v_mean'] + + head: int = state['head'] + grad = grad.view(head, parameter_per_head) + + m.lerp_(grad, weight=1.0 - beta1) + + tmp_lr = torch.mean(grad * grad, dim=1).to(m.device) + v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2) + + h = (v.sqrt() / bias_correction2_sq).add_(eps) + + update = (1 / (h * bias_correction1)).view(head, 1).mul_(m) + + if p.dim() > 1: + d0, d1 = p.size() + update = update.view(d0, d1) + else: + update = update.view(-1) + + 
p.add_(update, alpha=-lr) + + @staticmethod + def step_attn( + p, + grad, + state, + num_heads: int, + q_per_kv: int, + lr: float, + beta1: float, + beta2: float, + bias_correction1: float, + bias_correction2_sq: float, + eps: float, + ) -> None: + if len(state) == 0: + state['m'] = torch.zeros_like(p, dtype=torch.float32).view(num_heads, q_per_kv + 2, -1) + state['v_mean'] = torch.zeros(num_heads, q_per_kv + 2, device=state['m'].device) + + m, v = state['m'], state['v_mean'] + + grad = grad.view(num_heads, q_per_kv + 2, -1) + + m.lerp_(grad, weight=1.0 - beta1) + + tmp_lr = torch.mean(grad * grad, dim=2).to(m.device) + v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2) + + h = (v.sqrt() / bias_correction2_sq).add_(eps) + + update = (1 / (h * bias_correction1)).view(num_heads, q_per_kv + 2, -1).mul_(m) + + if p.dim() > 1: + d0, d1 = p.size() + update = update.view(d0, d1) + else: + update = update.view(-1) + + p.add_(update, alpha=-lr) + + def step_lefts( + self, + p, + grad, + state, + lr: float, + beta1: float, + beta2: float, + bias_correction1: float, + bias_correction2_sq: float, + eps: float, + ) -> None: # pragma: no cover + if len(state) == 0: + dim = torch.tensor(p.numel(), device=p.device, dtype=torch.float32) + + reduced: bool = False + if self.model_sharding and self.world_size > 1: + tensor_list = [torch.zeros_like(dim) for _ in range(self.world_size)] + dist.all_gather(tensor_list, dim) + + s, dim = 0, 0 + for d in tensor_list: + if d > 0: + s += 1 + dim += d + + if s >= 2: + reduced = True + + state['m'] = torch.zeros_like(p, dtype=torch.float32) + state['v_mean'] = torch.tensor(0.0, device=state['m'].device) + state['dimension'] = dim + state['reduced'] = reduced + + tmp_lr = torch.sum(grad * grad) + + if state['reduced']: + dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM) + + tmp_lr.div_(state['dim']) + + m, v = state['m'], state['v_mean'] + + m.lerp_(grad, weight=1.0 - beta1) + v.mul_(beta2).add_(tmp_lr, value=1.0 - beta2) + + h = (v.sqrt() / bias_correction2_sq).add_(eps) + + update = 1 / (bias_correction1 * h).mul_(m) + + p.add_(update, alpha=-lr) + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + name = group['name'] + + beta1, beta2 = group['betas'] + + bias_correction1: float = 1.0 - beta1 ** group['step'] + bias_correction2: float = 1.0 - beta2 ** group['step'] + bias_correction2_sq: float = math.sqrt(bias_correction2) + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + grad = grad.to(torch.float32) + + state = self.state[p] + + self.apply_weight_decay( + p=p, + grad=grad, + lr=group['lr'], + weight_decay=group['weight_decay'], + weight_decouple=True, + fixed_decay=False, + ) + + if 'embed_tokens' in name or 'wte' in name or 'lm_head' in name: + self.step_embed( + p, grad, state, group['lr'], beta1, beta2, bias_correction1, bias_correction2_sq, group['eps'] + ) + elif ( + 'self_attn.k_proj.weight' in name + or 'self_attn.q_proj.weight' in name + or 'attn.wq.weight' in name + or 'attn.wk.weight' in name + ): + self.step_attn_proj( + p, + grad, + state, + group['parameter_per_head'], + group['lr'], + beta1, + beta2, + bias_correction1, + bias_correction2_sq, + group['eps'], + ) + elif 'attn.attn.weight' in name or 'attn.qkv.weight' in name: + self.step_attn( + p, 
+ grad, + state, + group['n_head'], + group['q_per_kv'], + group['lr'], + beta1, + beta2, + bias_correction1, + bias_correction2_sq, + group['eps'], + ) + else: + self.step_lefts( + p, + grad, + state, + group['lr'], + beta1, + beta2, + bias_correction1, + bias_correction2_sq, + group['eps'], + ) + + return loss From beea7da51c3662965fea9481f2670fdfc28cc0e3 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 17:01:14 +0900 Subject: [PATCH 05/11] update: AdamMini optimizer --- pytorch_optimizer/__init__.py | 2 ++ pytorch_optimizer/base/optimizer.py | 5 +++++ tests/constants.py | 1 + tests/test_load_modules.py | 2 +- tests/test_optimizers.py | 8 ++++++++ 5 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 3e1569a49..1f4dd552d 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -40,6 +40,7 @@ from pytorch_optimizer.optimizer.adahessian import AdaHessian from pytorch_optimizer.optimizer.adai import Adai from pytorch_optimizer.optimizer.adalite import Adalite +from pytorch_optimizer.optimizer.adam_mini import AdamMini from pytorch_optimizer.optimizer.adamax import AdaMax from pytorch_optimizer.optimizer.adamod import AdaMod from pytorch_optimizer.optimizer.adamp import AdamP @@ -203,6 +204,7 @@ GrokFastAdamW, Kate, StableAdamW, + AdamMini, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index 722081a59..e729b13e7 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -258,6 +258,11 @@ def validate_learning_rate(learning_rate: Optional[float]) -> None: if learning_rate is not None and learning_rate < 0.0: raise NegativeLRError(learning_rate) + @staticmethod + def validate_mod(x: int, y: int) -> None: + if x % y != 0: + raise ValueError(f'[-] {x} must be divisible by {y}') + def validate_betas(self, betas: BETAS) -> None: if betas[0] is not None: self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]') diff --git a/tests/constants.py b/tests/constants.py index 55a2e0ba8..bc42a266e 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -134,6 +134,7 @@ 'fadam', 'grokfastadamw', 'stableadamw', + 'adammini', ] VALID_LR_SCHEDULER_NAMES: List[str] = [ diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index bde4f45bc..c5831884b 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 70 + assert len(get_supported_optimizers()) == 71 def test_get_supported_lr_schedulers(): diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 97b195144..07f25c00f 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -650,3 +650,11 @@ def test_stableadamw_optimizer(environment): optimizer = load_optimizer('StableAdamW')(model.parameters()) optimizer.reset() optimizer.step() + + +def test_adam_mini_optimizer(environment): + _, model, _ = environment + + optimizer = load_optimizer('AdamMini')(model) + optimizer.reset() + optimizer.step() From 2c7b60689a65882afe8c11dceb3ce82c5fdf48da Mon Sep 17 00:00:00 2001 From: kozistr Date: Sat, 6 Jul 2024 17:02:37 +0900 Subject: [PATCH 06/11] update: test_mod --- tests/test_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git 
a/tests/test_base.py b/tests/test_base.py index c919d91e8..1a82bf271 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -34,3 +34,8 @@ def test_validate_boundary(): def test_validate_range(range_type): with pytest.raises(ValueError): BaseOptimizer.validate_range(-1.0, 'x', 0.0, 1.0, range_type=range_type) + + +def test_mod(): + with pytest.raises(ValueError): + BaseOptimizer.validate_mod(10, 3) From 4d774fa642bfdb1809a7e4169e563e3a058cdc76 Mon Sep 17 00:00:00 2001 From: kozistr Date: Fri, 21 Jun 2024 17:05:38 +0900 Subject: [PATCH 07/11] update: test_no_gradients --- tests/test_gradients.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gradients.py b/tests/test_gradients.py index a1c29e15b..216075a0e 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize('optimizer_name', [*VALID_OPTIMIZER_NAMES, 'lookahead']) def test_no_gradients(optimizer_name): - if optimizer_name == 'lomo': + if optimizer_name in {'lomo', 'adammini'}: pytest.skip(f'skip {optimizer_name} optimizer.') p1 = simple_parameter(require_grad=True) From 4f497d9433345e00b419fbf52227acc433287f61 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 23 Jun 2024 17:06:35 +0900 Subject: [PATCH 08/11] update: test_sparse_not_supported --- tests/test_gradients.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 216075a0e..53c5033df 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -39,7 +39,7 @@ def test_no_gradients(optimizer_name): @pytest.mark.parametrize('no_sparse_optimizer', NO_SPARSE_OPTIMIZERS) def test_sparse_not_supported(no_sparse_optimizer): - if no_sparse_optimizer in ('lomo', 'bsam'): + if no_sparse_optimizer in ('lomo', 'bsam', 'adammini'): pytest.skip(f'skip {no_sparse_optimizer} optimizer.') param = simple_sparse_parameter()[1] From beb84d48322c9ff6bcfbfb489faedc24c494f876 Mon Sep 17 00:00:00 2001 From: kozistr Date: Mon, 24 Jun 2024 17:07:12 +0900 Subject: [PATCH 09/11] update: test_bf16_gradient --- tests/test_gradients.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 53c5033df..290c628fa 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -113,7 +113,7 @@ def test_sparse_supported(sparse_optimizer): @pytest.mark.parametrize('optimizer_name', VALID_OPTIMIZER_NAMES) def test_bf16_gradient(optimizer_name): - if optimizer_name in ('shampoo', 'lomo', 'bsam'): + if optimizer_name in ('shampoo', 'lomo', 'bsam', 'adammini'): pytest.skip(f'skip {optimizer_name}') param = torch.randn(1, 1).bfloat16().requires_grad_(True) From bede5476550acc64aa0ab55080cb8d1c081c03a7 Mon Sep 17 00:00:00 2001 From: kozistr Date: Thu, 27 Jun 2024 17:10:38 +0900 Subject: [PATCH 10/11] update: disable coverage --- pytorch_optimizer/optimizer/adam_mini.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_optimizer/optimizer/adam_mini.py b/pytorch_optimizer/optimizer/adam_mini.py index 9e3eda137..d38dab100 100644 --- a/pytorch_optimizer/optimizer/adam_mini.py +++ b/pytorch_optimizer/optimizer/adam_mini.py @@ -11,7 +11,7 @@ from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS -class AdamMini(Optimizer, BaseOptimizer): +class AdamMini(Optimizer, BaseOptimizer): # pragma: no cover r"""Use Fewer Learning Rates To Gain More. :param model: nn.Module. model instance. 
@@ -218,7 +218,7 @@ def step_lefts( bias_correction1: float, bias_correction2_sq: float, eps: float, - ) -> None: # pragma: no cover + ) -> None: if len(state) == 0: dim = torch.tensor(p.numel(), device=p.device, dtype=torch.float32) From a9704533b1f83201ddff62edd4a59eb526a8517e Mon Sep 17 00:00:00 2001 From: kozistr Date: Mon, 1 Jul 2024 17:11:03 +0900 Subject: [PATCH 11/11] chore: exclude adam_mini --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index aacb8e780..0680c9fef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,7 @@ testpaths = "tests" [tool.coverage.run] omit = [ "./pytorch_optimizer/optimizer/rotograd.py", + "./pytorch_optimizer/optimizer/adam_mini.py", ] [build-system]
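With the full series applied, the new optimizer is constructed directly from a model, since it groups parameters by name rather than taking `model.parameters()`. A minimal usage sketch (the toy module and hyper-parameters are illustrative and not taken from the patches):

```python
import torch
from torch import nn

from pytorch_optimizer import AdamMini


class TinyLM(nn.Module):
    # Toy model whose parameter names ('wte', 'lm_head') match the name-based routing above.
    def __init__(self, vocab_size: int = 100, dim: int = 32) -> None:
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.wte(x))


model = TinyLM()

# The model itself is passed, not model.parameters().
optimizer = AdamMini(model, lr=1e-3, betas=(0.9, 0.999), weight_decay=0.1)

tokens = torch.randint(0, 100, (4, 8))
logits = model(tokens)
loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), tokens.view(-1))

loss.backward()
optimizer.step()
optimizer.zero_grad()
```

For transformer training, `num_embeds`, `num_heads`, and (for GQA) `num_query_groups` should match the model so that attention weights are split into the intended per-head blocks, and `model_sharding=True` is required when training with FSDP or DeepSpeed ZeRO-style sharding.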