From b7ddc4a0587fef438f809fb5de65dcfb3dc986f2 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 6 Jul 2024 15:41:53 +0900
Subject: [PATCH] update: StableAdamW optimizer

---
 pytorch_optimizer/__init__.py | 2 ++
 tests/constants.py            | 3 +++
 tests/test_load_modules.py    | 2 +-
 tests/test_optimizers.py      | 10 ++++++++++
 4 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py
index 340232fe6..3e1569a49 100644
--- a/pytorch_optimizer/__init__.py
+++ b/pytorch_optimizer/__init__.py
@@ -44,6 +44,7 @@
 from pytorch_optimizer.optimizer.adamod import AdaMod
 from pytorch_optimizer.optimizer.adamp import AdamP
 from pytorch_optimizer.optimizer.adams import AdamS
+from pytorch_optimizer.optimizer.adamw import StableAdamW
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adanorm import AdaNorm
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
@@ -201,6 +202,7 @@
     FAdam,
     GrokFastAdamW,
     Kate,
+    StableAdamW,
 ]
 
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
diff --git a/tests/constants.py b/tests/constants.py
index 7daf9d6a8..55a2e0ba8 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -67,6 +67,7 @@
     Shampoo,
     SignSGD,
     SophiaH,
+    StableAdamW,
     Tiger,
     Yogi,
 )
@@ -132,6 +133,7 @@
     'schedulefreeadamw',
     'fadam',
     'grokfastadamw',
+    'stableadamw',
 ]
 
 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -463,6 +465,7 @@
     (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
     (Kate, {'lr': 5e-2}, 10),
+    (StableAdamW, {'lr': 1e0}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py
index d13f0877b..bde4f45bc 100644
--- a/tests/test_load_modules.py
+++ b/tests/test_load_modules.py
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 69
+    assert len(get_supported_optimizers()) == 70
 
 
 def test_get_supported_lr_schedulers():
diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index e0abec4e2..97b195144 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -640,3 +640,13 @@ def test_grokfast_ema(environment):
     model.fc2.bias.grad = torch.randn(1)
 
     _ = gradfilter_ema(model, None)
+
+
+def test_stableadamw_optimizer(environment):
+    _, model, _ = environment
+
+    model.fc1.weight.data = torch.randn(2, 2, dtype=torch.float16)
+
+    optimizer = load_optimizer('StableAdamW')(model.parameters())
+    optimizer.reset()
+    optimizer.step()
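
For reference, a minimal usage sketch of the optimizer this patch registers. It assumes the package is installed and exposes StableAdamW and load_optimizer from pytorch_optimizer, as the __init__.py hunk and the new test above indicate; the learning rate and the tiny model are illustrative choices, not values taken from the library's defaults.

    import torch
    from pytorch_optimizer import StableAdamW, load_optimizer

    model = torch.nn.Linear(2, 1)

    # Either import the class directly or resolve it through the registry
    # built from OPTIMIZER_LIST above; the new test uses the same lookup.
    optimizer = StableAdamW(model.parameters(), lr=1e-3)
    optimizer = load_optimizer('StableAdamW')(model.parameters(), lr=1e-3)

    # One illustrative step: forward, backward, update, clear gradients.
    loss = model(torch.randn(4, 2)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()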