Skip to content

Commit

Permalink
Merge pull request #277 from kozistr/update/uv
Browse files Browse the repository at this point in the history
[CI] to Python 3.12
  • Loading branch information
kozistr authored Sep 29, 2024
2 parents 3cdf496 + c76c1ee commit 09a2b83
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 121 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.11']
python-version: ['3.12']

steps:
- uses: actions/checkout@v4
Expand All @@ -40,4 +40,4 @@ jobs:
files: ./coverage.xml
env_vars: OS,PYTHON
fail_ci_if_error: true
verbose: true
verbose: false
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Setup Python 3.11
- name: Setup Python 3.12
uses: actions/setup-python@v5
with:
python-version: 3.11
python-version: 3.12
cache: 'pip'
- name: Install dependencies
run: |
Expand Down
9 changes: 0 additions & 9 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,3 @@ _build/
.netrwhist
.vscode/*
.history
*.pt
*.pth
*.ckpt
*.pkl
*.onnx
*.pb
*.csv
*.tsv
*.ftr
42 changes: 21 additions & 21 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

75 changes: 38 additions & 37 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ classifiers = [
]

[tool.poetry.dependencies]
python = ">=3.8,<4.0.0"
python = ">=3.8"
numpy = [
{ version = ">1.24.4", python = ">=3.9" },
{ version = "<=1.24.4", python = "<3.9" },
Expand All @@ -68,20 +68,13 @@ url = "https://download.pytorch.org/whl/cpu"
priority = "explicit"

[tool.ruff]
lint.select = [
"A", "B", "C4", "D", "E", "F", "G", "I", "N", "S", "T", "ISC", "ICN", "W", "INP", "PIE", "T20", "RET", "SIM",
"TID", "ARG", "ERA", "RUF", "YTT", "PL", "Q"
]
lint.ignore = [
"B905", "D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413", "PIE790", "PLR0912", "PLR0913", "PLR0915",
"PLR2004", "RUF013", "Q003", "ARG002",
src = [
"pytorch_optimizer",
"tests",
"examples",
]
lint.fixable = ["ALL"]
lint.unfixable = ["F401"]
lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
lint.flake8-quotes.docstring-quotes = "double"
lint.flake8-quotes.inline-quotes = "single"
lint.pylint.max-args = 7
target-version = "py312"
line-length = 119
exclude = [
".git",
".github",
Expand All @@ -95,32 +88,40 @@ exclude = [
".venv",
"__pypackages__",
]
line-length = 119
target-version = "py311"

[tool.ruff.lint]
select = [
"A", "B", "C4", "D", "E", "F", "G", "I", "N", "S", "T", "ISC", "ICN", "W", "INP", "PIE", "T20", "RET", "SIM",
"TID", "ARG", "ERA", "RUF", "YTT", "PL", "Q"
]
ignore = [
"B905", "D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413", "PIE790", "PLR0912", "PLR0913", "PLR0915",
"PLR2004", "RUF013", "Q003", "ARG002",
]
fixable = ["ALL"]
unfixable = ["F401"]
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
flake8-quotes.docstring-quotes = "double"
flake8-quotes.inline-quotes = "single"

[tool.ruff.lint.extend-per-file-ignores]
"hubconf.py" = ["D", "INP001"]
"examples/visualize_optimizers.py" = ["D103", "D400", "D415"]
"**/__init__.py" = ["F401"]
"{tests,examples}/*.py" = ["D", "S101"]

[tool.ruff.lint.isort]
combine-as-imports = false
detect-same-package = true
force-sort-within-sections = false
known-first-party = ["pytorch_optimizer"]

[tool.ruff.lint.pylint]
max-args = 7

[tool.ruff.format]
quote-style = "single"

[tool.ruff.lint.per-file-ignores]
"./pytorch_optimizer/__init__.py" = ["F401"]
"./pytorch_optimizer/lr_scheduler/__init__.py" = ["F401"]
"./hubconf.py" = ["D", "INP001"]
"./examples/visualize_optimizers.py" = ["D103", "D400", "D415"]
"./tests/__init__.py" = ["D"]
"./tests/constants.py" = ["D"]
"./tests/utils.py" = ["D"]
"./tests/test_base.py" = ["D", "S101"]
"./tests/test_utils.py" = ["D", "S101", "ERA001"]
"./tests/test_gradients.py" = ["D", "S101"]
"./tests/test_optimizers.py" = ["D", "S101"]
"./tests/test_optimizer_parameters.py" = ["D", "S101"]
"./tests/test_general_optimizer_parameters.py" = ["D", "S101"]
"./tests/test_lr_schedulers.py" = ["D", "S101"]
"./tests/test_lr_scheduler_parameters.py" = ["D", "S101"]
"./tests/test_create_optimizer.py" = ["D"]
"./tests/test_loss_functions.py" = ["D", "S101"]
"./tests/test_load_modules.py" = ["D", "S101"]

[tool.pytest.ini_options]
testpaths = "tests"

Expand All @@ -131,5 +132,5 @@ omit = [
]

[build-system]
requires = ["poetry-core>=1.0.0"]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"
6 changes: 6 additions & 0 deletions pytorch_optimizer/optimizer/soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class SOAP(BaseOptimizer):
:param merge_dims: bool. whether to merge dimensions of the pre-conditioner
:param precondition_1d: bool. whether to precondition 1D gradients.
:param correct_bias: bool. whether to correct bias in Adam.
:param normalize_gradient: bool. whether to normalize the gradients.
:param eps: float. term added to the denominator to improve numerical stability.
"""

Expand All @@ -40,6 +41,7 @@ def __init__(
merge_dims: bool = False,
precondition_1d: bool = False,
correct_bias: bool = True,
normalize_gradient: bool = False,
data_format: DATA_FORMAT = 'channels_first',
eps: float = 1e-8,
**kwargs,
Expand All @@ -64,6 +66,7 @@ def __init__(
'merge_dims': merge_dims,
'precondition_1d': precondition_1d,
'correct_bias': correct_bias,
'normalize_gradient': normalize_gradient,
'eps': eps,
}
super().__init__(params, defaults)
Expand Down Expand Up @@ -312,6 +315,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
project_type='backward',
)

if group['normalize_gradient']:
norm_grad.div_(torch.mean(norm_grad.square()).sqrt_().add_(group['eps']))

p.add_(norm_grad, alpha=-step_size)

self.apply_weight_decay(
Expand Down
54 changes: 27 additions & 27 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
--extra-index-url https://download.pytorch.org/whl/cpu

black==24.8.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
click==8.1.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
colorama==0.4.6 ; python_version >= "3.8" and python_full_version < "4.0.0" and (sys_platform == "win32" or platform_system == "Windows")
coverage[toml]==7.6.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
exceptiongroup==1.2.2 ; python_version >= "3.8" and python_version < "3.11"
filelock==3.16.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
fsspec==2024.9.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
iniconfig==2.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
isort==5.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
jinja2==3.1.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
mypy-extensions==1.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_version < "3.9"
numpy==2.0.2 ; python_version >= "3.9" and python_full_version < "4.0.0"
packaging==24.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
platformdirs==4.3.6 ; python_version >= "3.8" and python_full_version < "4.0.0"
pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pytest==8.3.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
ruff==0.6.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.13.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6"
torch==2.4.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
black==24.8.0 ; python_version >= "3.8"
click==8.1.7 ; python_version >= "3.8"
colorama==0.4.6 ; python_version >= "3.8" and (sys_platform == "win32" or platform_system == "Windows")
coverage[toml]==7.6.1 ; python_version >= "3.8"
exceptiongroup==1.2.2 ; python_version < "3.11" and python_version >= "3.8"
filelock==3.16.1 ; python_version >= "3.8"
fsspec==2024.9.0 ; python_version >= "3.8"
iniconfig==2.0.0 ; python_version >= "3.8"
isort==5.13.2 ; python_version >= "3.8"
jinja2==3.1.4 ; python_version >= "3.8"
markupsafe==2.1.5 ; python_version >= "3.8"
mpmath==1.3.0 ; python_version >= "3.8"
mypy-extensions==1.0.0 ; python_version >= "3.8"
networkx==3.1 ; python_version >= "3.8"
numpy==1.24.4 ; python_version < "3.9" and python_version >= "3.8"
numpy==2.0.2 ; python_version >= "3.9"
packaging==24.1 ; python_version >= "3.8"
pathspec==0.12.1 ; python_version >= "3.8"
platformdirs==4.3.6 ; python_version >= "3.8"
pluggy==1.5.0 ; python_version >= "3.8"
pytest-cov==5.0.0 ; python_version >= "3.8"
pytest==8.3.3 ; python_version >= "3.8"
ruff==0.6.8 ; python_version >= "3.8"
sympy==1.13.3 ; python_version >= "3.8"
tomli==2.0.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
torch==2.4.1+cpu ; python_version >= "3.8"
typing-extensions==4.12.2 ; python_version >= "3.8"
22 changes: 11 additions & 11 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
--extra-index-url https://download.pytorch.org/whl/cpu

filelock==3.16.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
fsspec==2024.9.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
jinja2==3.1.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_version < "3.9"
numpy==2.0.2 ; python_version >= "3.9" and python_full_version < "4.0.0"
sympy==1.13.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
torch==2.4.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
filelock==3.16.1 ; python_version >= "3.8"
fsspec==2024.9.0 ; python_version >= "3.8"
jinja2==3.1.4 ; python_version >= "3.8"
markupsafe==2.1.5 ; python_version >= "3.8"
mpmath==1.3.0 ; python_version >= "3.8"
networkx==3.1 ; python_version >= "3.8"
numpy==1.24.4 ; python_version < "3.9" and python_version >= "3.8"
numpy==2.0.2 ; python_version >= "3.9"
sympy==1.13.3 ; python_version >= "3.8"
torch==2.4.1+cpu ; python_version >= "3.8"
typing-extensions==4.12.2 ; python_version >= "3.8"
2 changes: 1 addition & 1 deletion tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@
(
SOAP,
{'lr': 1e0, 'shampoo_beta': 0.95, 'precondition_frequency': 1, 'merge_dims': False, 'precondition_1d': True},
5,
3,
),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
Expand Down
8 changes: 7 additions & 1 deletion tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,13 @@ def test_trac_optimizer_erf_imag():
'params',
[
{'merge_dims': True, 'precondition_1d': True, 'max_precondition_dim': 4, 'precondition_frequency': 1},
{'merge_dims': True, 'precondition_1d': False, 'max_precondition_dim': 1, 'precondition_frequency': 1},
{
'merge_dims': True,
'precondition_1d': False,
'max_precondition_dim': 1,
'precondition_frequency': 1,
'normalize_gradient': True,
},
],
)
def test_soap_parameters(params, environment):
Expand Down
10 changes: 0 additions & 10 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,37 +136,28 @@ def test_running_stats():


def test_compute_power():
# case 1 : len(x.shape) == 1
x = compute_power_schur_newton(torch.zeros((1,)), p=1)
assert torch.tensor([1000000.0]) == x

# case 2 : len(x.shape) != 1 and x.shape[0] == 1
x = compute_power_schur_newton(torch.zeros((1, 2)), p=1)
assert torch.tensor([1.0]) == x

# case 3 : len(x.shape) != 1 and x.shape[0] != 1, n&n-1 != 0
# it doesn't work on torch 2.1.1+cpu
_ = compute_power_schur_newton(torch.ones((2, 2)), p=3)

# case 4 : p=1
x = compute_power_schur_newton(torch.ones((2, 2)), p=1)
assert np.sum(x.numpy() - np.asarray([[252206.4062, -252205.8750], [-252205.8750, 252206.4062]])) < 200

# case 5 : p=8
_ = compute_power_schur_newton(torch.ones((2, 2)), p=8)

# case 6 : p=16
_ = compute_power_schur_newton(torch.ones((2, 2)), p=16)

# case 7 : max_error_ratio=0
x = compute_power_schur_newton(torch.ones((2, 2)), p=16, max_error_ratio=0.0)
np.testing.assert_array_almost_equal(
np.asarray([[1.0946, 0.0000], [0.0000, 1.0946]]),
x.numpy(),
decimal=2,
)

# case 8 : p=2
x = compute_power_schur_newton(torch.ones((2, 2)), p=2)
assert np.sum(x.numpy() - np.asarray([[359.1108, -358.4036], [-358.4036, 359.1108]])) < 50

Expand Down Expand Up @@ -217,7 +208,6 @@ def test_pre_conditioner_type(pre_conditioner_type):
if pre_conditioner_type in (0, 1, 2):
PreConditioner(var, 0.9, 0, 128, 1, 8192, True, pre_conditioner_type=pre_conditioner_type)
else:
# invalid pre-conditioner type
with pytest.raises(ValueError):
PreConditioner(var, 0.9, 0, 128, 1, 8192, True, pre_conditioner_type=pre_conditioner_type)

Expand Down

0 comments on commit 09a2b83

Please sign in to comment.