
Commit b7ddc4a
update: StableAdamW optimizer
kozistr committed Jul 6, 2024
1 parent cfe887f commit b7ddc4a
Showing 4 changed files with 16 additions and 1 deletion.
pytorch_optimizer/__init__.py: 2 additions & 0 deletions

@@ -44,6 +44,7 @@
 from pytorch_optimizer.optimizer.adamod import AdaMod
 from pytorch_optimizer.optimizer.adamp import AdamP
 from pytorch_optimizer.optimizer.adams import AdamS
+from pytorch_optimizer.optimizer.adamw import StableAdamW
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adanorm import AdaNorm
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
@@ -201,6 +202,7 @@
     FAdam,
     GrokFastAdamW,
     Kate,
+    StableAdamW,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

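With this registration, the new class becomes reachable through the OPTIMIZERS mapping by its lowercased class name. A minimal usage sketch, assuming StableAdamW accepts the usual lr keyword; the model and values are illustrative, not part of this commit:

import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(2, 1)

# OPTIMIZERS keys are lowercased class names; the test added below passes 'StableAdamW',
# which suggests load_optimizer normalizes the case of the requested name
optimizer = load_optimizer('stableadamw')(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 2)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()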
tests/constants.py: 3 additions & 0 deletions

@@ -67,6 +67,7 @@
     Shampoo,
     SignSGD,
     SophiaH,
+    StableAdamW,
     Tiger,
     Yogi,
 )
@@ -132,6 +133,7 @@
     'schedulefreeadamw',
     'fadam',
     'grokfastadamw',
+    'stableadamw',
 ]

 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -463,6 +465,7 @@
     (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
     (Kate, {'lr': 5e-2}, 10),
+    (StableAdamW, {'lr': 1e0}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
tests/test_load_modules.py: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 69
+    assert len(get_supported_optimizers()) == 70


 def test_get_supported_lr_schedulers():
tests/test_optimizers.py: 10 additions & 0 deletions

@@ -640,3 +640,13 @@ def test_grokfast_ema(environment):
     model.fc2.bias.grad = torch.randn(1)

     _ = gradfilter_ema(model, None)
+
+
+def test_stableadamw_optimizer(environment):
+    _, model, _ = environment
+
+    model.fc1.weight.data = torch.randn(2, 2, dtype=torch.float16)
+
+    optimizer = load_optimizer('StableAdamW')(model.parameters())
+    optimizer.reset()
+    optimizer.step()
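The new test swaps the first layer's weights to float16, in line with StableAdamW's motivation of stable low-precision training. The optimizer itself lives in pytorch_optimizer/optimizer/adamw.py and is not shown in this diff; as rough orientation, below is a minimal single-tensor sketch of the update-clipping rule described in the StableAdamW paper (Wortsman et al., 2023). The stable_adamw_step helper, its defaults, and the toy usage are illustrative assumptions, not the repository's implementation:

import torch


def stable_adamw_step(p, g, m, v, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    # Sketch of StableAdamW: AdamW whose step size is clipped per tensor
    # by the RMS of g^2 / v_hat ("update clipping").
    beta1, beta2 = betas

    m.mul_(beta1).add_(g, alpha=1.0 - beta1)         # first moment (EMA of gradients)
    v.mul_(beta2).addcmul_(g, g, value=1.0 - beta2)  # second moment (EMA of squared gradients)

    m_hat = m / (1.0 - beta1 ** step)                # bias-corrected moments
    v_hat = v / (1.0 - beta2 ** step)

    # update clipping: shrink the learning rate when gradients are large relative to v_hat
    rms = (g.pow(2) / v_hat.clamp(min=eps ** 2)).mean().sqrt()
    lr_t = lr / max(1.0, rms.item())

    p.mul_(1.0 - lr_t * weight_decay)                # decoupled weight decay, as in AdamW
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr_t)


# toy usage: one update on a random tensor
p = torch.randn(4, 4)
m, v = torch.zeros_like(p), torch.zeros_like(p)
stable_adamw_step(p, torch.randn_like(p), m, v, step=1)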
