
Commit

Merge pull request #97 from kozistr/feature/shampoo-optimizer
[Feature] Re-Implement Shampoo Optimizer w/ Grafting & Partitioner
kozistr authored Jan 30, 2023
2 parents 0567ae9 + e792181 commit 5df1281
Showing 45 changed files with 1,038 additions and 1,148 deletions.
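
As context for the file-level changes below, a minimal usage sketch of the re-implemented optimizer. It assumes only that Shampoo is exported from pytorch_optimizer and follows the usual (params, lr=...) constructor convention; the exact grafting and partitioner arguments are not visible in this diff, so they are left at their defaults here.

import torch
from pytorch_optimizer import Shampoo  # assumed top-level export

model = torch.nn.Linear(10, 2)
optimizer = Shampoo(model.parameters(), lr=1e-3)  # grafting/partitioner options left at defaults

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).sum()
    loss.backward()
    optimizer.step()
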
433 changes: 0 additions & 433 deletions .pylintrc

This file was deleted.

3 changes: 1 addition & 2 deletions Makefile
@@ -12,9 +12,8 @@ test:
python -m pytest -p no:pastebin -p no:nose -p no:doctest -sv -vv --cov=pytorch_optimizer --cov-report=xml ./tests

check:
isort --check-only --profile black -l 119 pytorch_optimizer tests hubconf.py
black -S -l 119 --check pytorch_optimizer tests hubconf.py
pylint --fail-under=10.0 pytorch_optimizer
ruff pytorch_optimizer tests hubconf.py

requirements:
python -m poetry export -f requirements.txt --output requirements.txt --without-hashes
8 changes: 0 additions & 8 deletions docs/util_api.rst
@@ -17,14 +17,6 @@ get_optimizer_parameters
.. autoclass:: pytorch_optimizer.get_optimizer_parameters
:members:

.. _matrix_power:

matrix_power
------------

.. autoclass:: pytorch_optimizer.matrix_power
:members:

.. _normalize_gradient:

normalize_gradient
365 changes: 99 additions & 266 deletions poetry.lock

Large diffs are not rendered by default.

45 changes: 43 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "2.2.1"
version = "2.3.0"
description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
@@ -42,7 +42,7 @@ torch = { version = "^1.10", source = "torch"}
[tool.poetry.dev-dependencies]
isort = "^5.11.4"
black = "^22.12.0"
pylint = "^2.15.9"
ruff = "^0.0.237"
pytest = "^7.2.0"
pytest-cov = "^4.0.0"

@@ -51,9 +51,50 @@ name = "torch"
url = "https://download.pytorch.org/whl/cpu"
secondary = true

[tool.ruff]
select = ["A", "B", "C4", "E", "F", "G", "I", "N", "S", "T", "ISC", "W", "INP", "PIE", "T20", "RET", "SIM", "ARG"]
ignore = []
fixable = ["A", "B", "C", "D", "E", "F"]
unfixable = ["F401"]
exclude = [
".eggs",
".git",
".mypy_cache",
".ruff_cache",
".github",
".venv",
"__pypackages__",
"_build",
"build",
"dist",
"node_modules",
"venv",
"docs",
"assets",
]
line-length = 119
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
target-version = "py39"

[tool.ruff.per-file-ignores]
"./hubconf.py" = ["INP001"]
"./tests/test_utils.py" = ["S101"]
"./tests/test_gradients.py" = ["S101"]
"./tests/test_optimizers.py" = ["S101"]
"./tests/test_optimizer_parameters.py" = ["S101"]
"./tests/test_load_optimizers.py" = ["S101"]
"./tests/test_load_lr_schedulers.py" = ["S101"]
"./tests/test_lr_scheduler_parameters.py" = ["S101"]
"./pytorch_optimizer/__init__.py" = ["F401"]
"./pytorch_optimizer/lr_scheduler/__init__.py" = ["F401"]

[tool.ruff.mccabe]
max-complexity = 10

[tool.coverage.run]
omit = [
"./pytorch_optimizer/optimizer/gsam.py",
"./pytorch_optimizer/optimizer/fp16.py",
]

[build-system]
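
The per-file ignores above are worth a brief illustration: S101 ("assert used", from the flake8-bandit rule set) is disabled only for the test modules because pytest tests are built on bare assert statements, and F401 ("imported but unused") is disabled for the __init__.py files because they exist to re-export names. A hypothetical pytest-style test showing where S101 would otherwise fire:

import torch


def test_unit_norm_shape():
    x = torch.ones(4, 3)
    # A bare assert like this is exactly what rule S101 flags outside ./tests.
    assert x.norm(dim=-1).shape == (4,)
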
3 changes: 1 addition & 2 deletions pytorch_optimizer/__init__.py
@@ -1,4 +1,4 @@
# pylint: disable=unused-import
# ruff: noqa
from typing import Dict, List

from pytorch_optimizer.base.types import OPTIMIZER, SCHEDULER
@@ -44,7 +44,6 @@
disable_running_stats,
enable_running_stats,
get_optimizer_parameters,
matrix_power,
normalize_gradient,
unit_norm,
)
4 changes: 2 additions & 2 deletions pytorch_optimizer/base/optimizer.py
@@ -2,7 +2,7 @@

import torch

from pytorch_optimizer.base.exception import NegativeLRError
from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
from pytorch_optimizer.base.types import BETAS


@@ -90,7 +90,7 @@ def validate_reduction(reduction: str):
@staticmethod
def validate_update_frequency(update_frequency: int):
if update_frequency < 1:
raise ValueError(f'[-] update_frequency {update_frequency} must be positive')
raise NegativeStepError(f'[-] update_frequency {update_frequency} must be positive')

@staticmethod
def validate_norm(norm: float):
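
A small sketch of the behavioural change above, assuming the validator lives on a BaseOptimizer class in this module (the class name is not visible in the hunk): the same message is now raised as the library's own NegativeStepError instead of a bare ValueError.

from pytorch_optimizer.base.exception import NegativeStepError
from pytorch_optimizer.base.optimizer import BaseOptimizer  # assumed class name

try:
    BaseOptimizer.validate_update_frequency(0)
except NegativeStepError as e:
    print(e)  # [-] update_frequency 0 must be positive
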
2 changes: 1 addition & 1 deletion pytorch_optimizer/lr_scheduler/__init__.py
@@ -1,2 +1,2 @@
# pylint: disable=unused-import
# ruff: noqa
from torch.optim.lr_scheduler import ConstantLR, CosineAnnealingLR, CosineAnnealingWarmRestarts, CyclicLR, OneCycleLR
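
Because this module simply re-exports the torch schedulers, they stay importable through the package namespace; a short sketch using one of the names from the import line above:

import torch
from pytorch_optimizer.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineAnnealingLR(optimizer, T_max=10)  # same class as torch.optim.lr_scheduler.CosineAnnealingLR

for _ in range(10):
    optimizer.step()
    scheduler.step()
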
8 changes: 3 additions & 5 deletions pytorch_optimizer/lr_scheduler/chebyshev.py
@@ -2,12 +2,12 @@


def chebyshev_steps(small_m: float, big_m: float, num_epochs: int) -> np.ndarray:
"""chebyshev_steps
r"""chebyshev_steps
:param small_m: float. stands for 'm' notation.
:param big_m: float. stands for 'M' notation.
:param num_epochs: int. stands for 'T' notation.
:return: np.array. chebyshev_steps
:return: np.array. chebyshev_steps.
"""

c, r = (big_m + small_m) / 2.0, (big_m - small_m) / 2.0
@@ -26,6 +26,4 @@ def chebyshev_perm(num_epochs: int) -> np.ndarray:
def get_chebyshev_schedule(num_epochs: int) -> np.ndarray:
steps: np.ndarray = chebyshev_steps(0.1, 1, num_epochs - 2)
perm: np.ndarray = chebyshev_perm(num_epochs - 2)
chebyshev_schedule = steps[perm]

return chebyshev_schedule
return steps[perm]
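
For reference, a hedged sketch of how the helper shown above can be used: get_chebyshev_schedule(num_epochs) returns a permuted array of Chebyshev step sizes built from chebyshev_steps(0.1, 1, num_epochs - 2); applying each entry as a per-epoch learning-rate multiplier is an assumption here, not something this diff specifies.

import numpy as np
from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_schedule

base_lr = 1e-3
schedule: np.ndarray = get_chebyshev_schedule(num_epochs=10)

# Assumed usage: scale the base learning rate by each Chebyshev step, one per epoch.
for epoch, step_size in enumerate(schedule):
    lr = base_lr * float(step_size)
    print(epoch, lr)
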
22 changes: 11 additions & 11 deletions pytorch_optimizer/optimizer/adabelief.py
@@ -53,15 +53,15 @@ def __init__(

self.validate_parameters()

defaults: DEFAULTS = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
adamd_debias_term=adamd_debias_term,
buffer=[[None, None, None] for _ in range(10)],
)
defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'eps': eps,
'weight_decay': weight_decay,
'amsgrad': amsgrad,
'adamd_debias_term': adamd_debias_term,
'buffer': [[None, None, None] for _ in range(10)],
}
super().__init__(params, defaults)

def validate_parameters(self):
@@ -71,7 +71,7 @@ def validate_parameters(self):
self.validate_epsilon(self.eps)

@property
def __name__(self) -> str:
def __str__(self) -> str:
return 'AdaBelief'

@torch.no_grad()
@@ -106,7 +106,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(self.__name__)
raise NoSparseGradientError(self.__str__)

state = self.state[p]
if len(state) == 0:
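
The dict(...)-to-literal rewrites in this file (and in the optimizer files below) match the style the newly selected flake8-comprehensions rules push toward (C4, e.g. C408 "unnecessary dict call"), and the __name__ property is renamed to __str__ throughout. A minimal sketch, not taken from the repository, of the underlying pattern: per-parameter defaults handed to torch.optim.Optimizer as a plain dict literal.

import torch


class ToyOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr: float = 1e-3, eps: float = 1e-8):
        defaults = {'lr': lr, 'eps': eps}  # dict literal, as in the diffs above
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])  # plain SGD update, for illustration only
        return loss
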
24 changes: 12 additions & 12 deletions pytorch_optimizer/optimizer/adabound.py
@@ -48,16 +48,16 @@ def __init__(

self.validate_parameters()

defaults: DEFAULTS = dict(
lr=lr,
betas=betas,
final_lr=final_lr,
gamma=gamma,
weight_decay=weight_decay,
amsbound=amsbound,
adamd_debias_term=adamd_debias_term,
eps=eps,
)
defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'final_lr': final_lr,
'gamma': gamma,
'weight_decay': weight_decay,
'amsbound': amsbound,
'adamd_debias_term': adamd_debias_term,
'eps': eps,
}
super().__init__(params, defaults)

self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]
@@ -69,7 +69,7 @@ def validate_parameters(self):
self.validate_epsilon(self.eps)

@property
def __name__(self) -> str:
def __str__(self) -> str:
return 'AdaBound'

@torch.no_grad()
@@ -100,7 +100,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(self.__name__)
raise NoSparseGradientError(self.__str__)

state = self.state[p]

19 changes: 9 additions & 10 deletions pytorch_optimizer/optimizer/adai.py
@@ -44,13 +44,13 @@ def __init__(

self.validate_parameters()

defaults: DEFAULTS = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay,
dampening=dampening,
eps=eps,
)
defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'weight_decay': weight_decay,
'dampening': dampening,
'eps': eps,
}
super().__init__(params, defaults)

def validate_parameters(self):
@@ -60,7 +60,7 @@ def validate_parameters(self):
self.validate_epsilon(self.eps)

@property
def __name__(self) -> str:
def __str__(self) -> str:
return 'Adai'

@torch.no_grad()
@@ -92,7 +92,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(self.__name__)
raise NoSparseGradientError(self.__str__)

param_size += p.numel()

@@ -105,7 +105,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
state['beta1_prod'] = torch.ones_like(p)

state['step'] += 1

exp_avg_sq = state['exp_avg_sq']

if self.use_gc:
24 changes: 12 additions & 12 deletions pytorch_optimizer/optimizer/adamp.py
@@ -48,16 +48,16 @@ def __init__(

self.validate_parameters()

defaults: DEFAULTS = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay,
delta=delta,
wd_ratio=wd_ratio,
nesterov=nesterov,
adamd_debias_term=adamd_debias_term,
eps=eps,
)
defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'weight_decay': weight_decay,
'delta': delta,
'wd_ratio': wd_ratio,
'nesterov': nesterov,
'adamd_debias_term': adamd_debias_term,
'eps': eps,
}
super().__init__(params, defaults)

def validate_parameters(self):
@@ -68,7 +68,7 @@ def validate_parameters(self):
self.validate_epsilon(self.eps)

@property
def __name__(self) -> str:
def __str__(self) -> str:
return 'AdamP'

@torch.no_grad()
@@ -96,7 +96,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(self.__name__)
raise NoSparseGradientError(self.__str__)

state = self.state[p]
if len(state) == 0:
20 changes: 10 additions & 10 deletions pytorch_optimizer/optimizer/adan.py
@@ -44,14 +44,14 @@ def __init__(

self.validate_parameters()

defaults: DEFAULTS = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
weight_decouple=weight_decouple,
max_grad_norm=max_grad_norm,
)
defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'max_grad_norm': max_grad_norm,
'eps': eps,
}
super().__init__(params, defaults)

def validate_parameters(self):
@@ -62,7 +62,7 @@ def validate_parameters(self):
self.validate_norm(self.max_grad_norm)

@property
def __name__(self) -> str:
def __str__(self) -> str:
return 'Adan'

@torch.no_grad()
@@ -122,7 +122,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(self.__name__)
raise NoSparseGradientError(self.__str__)

state = self.state[p]
if len(state) == 0:
The remaining changed files in this commit are not shown here.
