From d728f9e70cc52dda75f7ae97d075a32729a6dfe3 Mon Sep 17 00:00:00 2001 From: kozistr Date: Tue, 13 Aug 2024 12:17:45 +0900 Subject: [PATCH] update: AdamG optimizer --- pytorch_optimizer/__init__.py | 2 ++ tests/constants.py | 3 +++ tests/test_load_modules.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 25241f9f4..9b0ba60f3 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -44,6 +44,7 @@ from pytorch_optimizer.optimizer.adalite import Adalite from pytorch_optimizer.optimizer.adam_mini import AdamMini from pytorch_optimizer.optimizer.adamax import AdaMax +from pytorch_optimizer.optimizer.adamg import AdamG from pytorch_optimizer.optimizer.adamod import AdaMod from pytorch_optimizer.optimizer.adamp import AdamP from pytorch_optimizer.optimizer.adams import AdamS @@ -206,6 +207,7 @@ StableAdamW, AdamMini, AdaLOMO, + AdamG, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/tests/constants.py b/tests/constants.py index 072645cee..947667a9f 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -25,6 +25,7 @@ Adai, Adalite, AdaMax, + AdamG, AdaMod, AdamP, AdamS, @@ -136,6 +137,7 @@ 'grokfastadamw', 'stableadamw', 'adammini', + 'adamg', ] VALID_LR_SCHEDULER_NAMES: List[str] = [ @@ -468,6 +470,7 @@ (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10), (Kate, {'lr': 5e-2}, 10), (StableAdamW, {'lr': 1e0}, 5), + (AdamG, {'lr': 1e0}, 20), ] ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10), diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index 9a23043ca..f1b1d3cd5 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 73 + assert len(get_supported_optimizers()) == 74 def test_get_supported_lr_schedulers():