Commit beea7da
update: AdamMini optimizer
kozistr committed Jul 6, 2024
1 parent 337bb0e commit beea7da
Showing 5 changed files with 17 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -40,6 +40,7 @@
 from pytorch_optimizer.optimizer.adahessian import AdaHessian
 from pytorch_optimizer.optimizer.adai import Adai
 from pytorch_optimizer.optimizer.adalite import Adalite
+from pytorch_optimizer.optimizer.adam_mini import AdamMini
 from pytorch_optimizer.optimizer.adamax import AdaMax
 from pytorch_optimizer.optimizer.adamod import AdaMod
 from pytorch_optimizer.optimizer.adamp import AdamP
@@ -203,6 +204,7 @@
     GrokFastAdamW,
     Kate,
     StableAdamW,
+    AdamMini,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

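Note: the OPTIMIZERS registry above keys each optimizer class by its lowercased __name__, so this hunk is what makes the new class loadable by the string 'adammini'. A minimal lookup sketch (load_optimizer is the accessor used in the tests below; the print line is illustrative):

```python
from pytorch_optimizer import load_optimizer

# OPTIMIZER_LIST now contains AdamMini, and the registry comprehension
# lowercases AdamMini.__name__, so 'adammini' is a valid key.
optimizer_class = load_optimizer('adammini')
print(optimizer_class.__name__)  # -> 'AdamMini'
```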
5 changes: 5 additions & 0 deletions pytorch_optimizer/base/optimizer.py
@@ -258,6 +258,11 @@ def validate_learning_rate(learning_rate: Optional[float]) -> None:
         if learning_rate is not None and learning_rate < 0.0:
             raise NegativeLRError(learning_rate)

+    @staticmethod
+    def validate_mod(x: int, y: int) -> None:
+        if x % y != 0:
+            raise ValueError(f'[-] {x} must be divisible by {y}')
+
     def validate_betas(self, betas: BETAS) -> None:
         if betas[0] is not None:
             self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
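For context, a standalone sketch of the new divisibility check. The example arguments (a hidden size that must split evenly across attention heads, in the spirit of Adam-mini's per-head partitioning) are illustrative assumptions, not taken from this diff:

```python
def validate_mod(x: int, y: int) -> None:
    # Same check as the new static method: x must be evenly divisible by y.
    if x % y != 0:
        raise ValueError(f'[-] {x} must be divisible by {y}')

validate_mod(768, 12)  # ok: 768 == 12 * 64 (hypothetical embed dim / heads)
validate_mod(768, 10)  # raises ValueError: 768 % 10 == 8
```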
1 change: 1 addition & 0 deletions tests/constants.py
@@ -134,6 +134,7 @@
     'fadam',
     'grokfastadamw',
     'stableadamw',
+    'adammini',
 ]

 VALID_LR_SCHEDULER_NAMES: List[str] = [
2 changes: 1 addition & 1 deletion tests/test_load_modules.py
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 70
+    assert len(get_supported_optimizers()) == 71


 def test_get_supported_lr_schedulers():
8 changes: 8 additions & 0 deletions tests/test_optimizers.py
@@ -650,3 +650,11 @@ def test_stableadamw_optimizer(environment):
     optimizer = load_optimizer('StableAdamW')(model.parameters())
     optimizer.reset()
     optimizer.step()
+
+
+def test_adam_mini_optimizer(environment):
+    _, model, _ = environment
+
+    optimizer = load_optimizer('AdamMini')(model)
+    optimizer.reset()
+    optimizer.step()
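Note that, unlike the other optimizer tests, AdamMini is constructed with the model itself rather than model.parameters(). A minimal usage sketch under that assumption (the model and training step are illustrative):

```python
import torch

from pytorch_optimizer import AdamMini

model = torch.nn.Linear(4, 2)
optimizer = AdamMini(model)  # the module itself, not model.parameters()

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```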
