diff --git a/README.md b/README.md index 2567280d..acd0f1bd 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **75 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported! +Currently, **76 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -173,6 +173,7 @@ supported_optimizers = get_supported_optimizers() | AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) | | TRAC | *Adaptive Parameter-free Optimization* | [github](https://github.com/ComputationalRobotics/TRAC) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240516642M/exportcitation) | | AdamG | *Towards Stability of Parameter-free Optimization* | | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240504376P/exportcitation) | +| AdEMAMix | *Better, Faster, Older* | [github](https://github.com/nanowell/AdEMAMix-Optimizer-Pytorch) | | [cite](https://github.com/nanowell/AdEMAMix-Optimizer-Pytorch?tab=readme-ov-file#reference) | ## Supported LR Scheduler diff --git a/docs/changelogs/v3.1.2.md b/docs/changelogs/v3.1.2.md index 70cff062..3e2b454b 100644 --- a/docs/changelogs/v3.1.2.md +++ b/docs/changelogs/v3.1.2.md @@ -1,5 +1,10 @@ ## Change Log +### Feature + +* Implement `AdEMAMix` optimizer. (#272) + * [THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER](https://arxiv.org/pdf/2409.03137) + ### Bug * Add `**kwargs` to the parameters for dummy placeholder. (#270, #271) diff --git a/docs/index.md b/docs/index.md index 2567280d..acd0f1bd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **75 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported! +Currently, **76 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). 
@@ -173,6 +173,7 @@ supported_optimizers = get_supported_optimizers() | AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) | | TRAC | *Adaptive Parameter-free Optimization* | [github](https://github.com/ComputationalRobotics/TRAC) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240516642M/exportcitation) | | AdamG | *Towards Stability of Parameter-free Optimization* | | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240504376P/exportcitation) | +| AdEMAMix | *Better, Faster, Older* | [github](https://github.com/nanowell/AdEMAMix-Optimizer-Pytorch) | | [cite](https://github.com/nanowell/AdEMAMix-Optimizer-Pytorch?tab=readme-ov-file#reference) | ## Supported LR Scheduler diff --git a/docs/optimizer.md b/docs/optimizer.md index e6f5d5b8..35f1d293 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -80,6 +80,10 @@ :docstring: :members: +::: pytorch_optimizer.AdEMAMix + :docstring: + :members: + ::: pytorch_optimizer.agc :docstring: :members: diff --git a/poetry.lock b/poetry.lock index 7bef65d2..9d9f4079 100644 --- a/poetry.lock +++ b/poetry.lock @@ -193,29 +193,29 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.15.4" +version = "3.16.0" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, - {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, + {file = "filelock-3.16.0-py3-none-any.whl", hash = "sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609"}, + {file = "filelock-3.16.0.tar.gz", hash = "sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] -typing = ["typing-extensions (>=4.8)"] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.1.1)", "pytest (>=8.3.2)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.3)"] +typing = ["typing-extensions (>=4.12.2)"] [[package]] name = "fsspec" -version = "2024.6.1" +version = "2024.9.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.6.1-py3-none-any.whl", hash = "sha256:3cb443f8bcd2efb31295a5b9fdb02aee81d8452c80d28f97a6d0959e6cee101e"}, - {file = "fsspec-2024.6.1.tar.gz", hash = "sha256:fad7d7e209dd4c1208e3bbfda706620e0da5142bebbd9c384afb95b07e798e49"}, + {file = "fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b"}, + {file = "fsspec-2024.9.0.tar.gz", hash = "sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8"}, ] [package.extras] @@ -464,19 +464,19 @@ files = [ [[package]] name = "platformdirs" -version = "4.2.2" +version = "4.3.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, - {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, + {file = "platformdirs-4.3.2-py3-none-any.whl", hash = "sha256:eb1c8582560b34ed4ba105009a4badf7f6f85768b30126f351328507b2beb617"}, + {file = "platformdirs-4.3.2.tar.gz", hash = "sha256:9e5e27a08aa095dd127b9f2e764d74254f482fef22b0970773bfba79d091ab8c"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] -type = ["mypy (>=1.8)"] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] [[package]] name = "pluggy" @@ -535,29 +535,29 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.5.7" +version = "0.6.4" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"}, - {file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"}, - {file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"}, - {file = 
"ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"}, - {file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"}, - {file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"}, - {file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"}, + {file = "ruff-0.6.4-py3-none-linux_armv6l.whl", hash = "sha256:c4b153fc152af51855458e79e835fb6b933032921756cec9af7d0ba2aa01a258"}, + {file = "ruff-0.6.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bedff9e4f004dad5f7f76a9d39c4ca98af526c9b1695068198b3bda8c085ef60"}, + {file = "ruff-0.6.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d02a4127a86de23002e694d7ff19f905c51e338c72d8e09b56bfb60e1681724f"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7862f42fc1a4aca1ea3ffe8a11f67819d183a5693b228f0bb3a531f5e40336fc"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eebe4ff1967c838a1a9618a5a59a3b0a00406f8d7eefee97c70411fefc353617"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:932063a03bac394866683e15710c25b8690ccdca1cf192b9a98260332ca93408"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:50e30b437cebef547bd5c3edf9ce81343e5dd7c737cb36ccb4fe83573f3d392e"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c44536df7b93a587de690e124b89bd47306fddd59398a0fb12afd6133c7b3818"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ea086601b22dc5e7693a78f3fcfc460cceabfdf3bdc36dc898792aba48fbad6"}, + {file = "ruff-0.6.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b52387d3289ccd227b62102c24714ed75fbba0b16ecc69a923a37e3b5e0aaaa"}, + {file = "ruff-0.6.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0308610470fcc82969082fc83c76c0d362f562e2f0cdab0586516f03a4e06ec6"}, + {file = "ruff-0.6.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:803b96dea21795a6c9d5bfa9e96127cc9c31a1987802ca68f35e5c95aed3fc0d"}, + {file = "ruff-0.6.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:66dbfea86b663baab8fcae56c59f190caba9398df1488164e2df53e216248baa"}, + {file = "ruff-0.6.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:34d5efad480193c046c86608dbba2bccdc1c5fd11950fb271f8086e0c763a5d1"}, + {file = "ruff-0.6.4-py3-none-win32.whl", hash = "sha256:f0f8968feea5ce3777c0d8365653d5e91c40c31a81d95824ba61d871a11b8523"}, + {file = "ruff-0.6.4-py3-none-win_amd64.whl", hash = "sha256:549daccee5227282289390b0222d0fbee0275d1db6d514550d65420053021a58"}, + {file = "ruff-0.6.4-py3-none-win_arm64.whl", hash = "sha256:ac4b75e898ed189b3708c9ab3fc70b79a433219e1e87193b4f2b77251d058d14"}, + {file = "ruff-0.6.4.tar.gz", hash = "sha256:ac3b5bfbee99973f80aa1b7cbd1c9cbce200883bdd067300c22a6cc1c7fba212"}, ] [[package]] @@ -590,21 +590,21 @@ files = [ [[package]] name = "torch" -version = "2.4.0+cpu" +version = "2.4.1+cpu" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.4.0+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:0e59377b27823dda6d26528febb7ca06fc5b77816eaa58b4420cc8785e33d4ce"}, - {file = 
"torch-2.4.0+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:53c3f75fa4ef0726e494ebef003b17d8a61c3c9fa4630b465610b462bf06c3de"}, - {file = "torch-2.4.0+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:14a7a8b595347dddca594f9e448b93ce68ce4f871acbd32cf04bda7c03664c0c"}, - {file = "torch-2.4.0+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:3b3cb9a6c17b5a4cea42bb37a243bfbad7659cef6d9b4ee29cb793bdf20f482c"}, - {file = "torch-2.4.0+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:78dbf5f2789933a7ea2dabeead4daa44679b1e0d8eb35ddb7071c8ab7b181eb3"}, - {file = "torch-2.4.0+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:f59c53a1c3247efb3700f9f78bdd289712177037a85b5519b9ecdef7c77c1fee"}, - {file = "torch-2.4.0+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:08753c3d776ae49dc9ddbae02e26720a513a4dc7997e41d95392bca71623a0cd"}, - {file = "torch-2.4.0+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:9f376f5a14eb04a44974c3a9dfd857a68090acb435b98e62bbf523baeefac85e"}, - {file = "torch-2.4.0+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:040abaee8affa1bb0f3ca14ca693ba81d0d90d88df5b8a839af96933a7fa2d29"}, - {file = "torch-2.4.0+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:441fbf517c46fee6782a4289ffe49f701d0a52e3533ab5397ce395da165d921d"}, + {file = "torch-2.4.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:833490a28ac156762ed6adaa7c695879564fa2fd0dc51bcf3fdb2c7b47dc55e6"}, + {file = "torch-2.4.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:1dd062d296fb78aa7cfab8690bf03704995a821b5ef69cfc807af5c0831b4202"}, + {file = "torch-2.4.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:2b03e20f37557d211d14e3fb3f71709325336402db132a1e0dd8b47392185baf"}, + {file = "torch-2.4.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:76a6fe7b10491b650c630bc9ae328df40f79a948296b41d3b087b29a8a63cbad"}, + {file = "torch-2.4.1+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:8800deef0026011d502c0c256cc4b67d002347f63c3a38cd8e45f1f445c61364"}, + {file = "torch-2.4.1+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:3a570e5c553415cdbddfe679207327b3a3806b21c6adea14fba77684d1619e97"}, + {file = "torch-2.4.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:0c0a7cc4f7c74ff024d5a5e21230a01289b65346b27a626f6c815d94b4b8c955"}, + {file = "torch-2.4.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:330e780f478707478f797fdc82c2a96e9b8c5f60b6f1f57bb6ad1dd5b1e7e97e"}, + {file = "torch-2.4.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:3c99506980a2fb4b634008ccb758f42dd82f93ae2830c1e41f64536e310bf562"}, + {file = "torch-2.4.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:c4f2c3c026e876d4dad7629170ec14fff48c076d6c2ae0e354ab3fdc09024f00"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 9aae7371..c214a2e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pytorch_optimizer" -version = "3.1.1" +version = "3.1.2" description = "optimizer & lr scheduler & objective function collections in PyTorch" license = "Apache-2.0" authors = ["kozistr "] @@ -11,15 +11,15 @@ repository = "https://github.com/kozistr/pytorch_optimizer" documentation = "https://pytorch-optimizers.readthedocs.io/en/latest" keywords = [ "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", - "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", - "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", - "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", 
"DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", - "Fromage", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", - "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", - "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", - "SignSGD", "SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", - "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", - "LovaszHinge", "bitsandbytes", "WSD", "QGaLore", + "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "AdaHessian", + "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", + "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", + "FAdam", "Fromage", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", + "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", + "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", + "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", + "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", + "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore", ] classifiers = [ "License :: OSI Approved :: Apache Software License", diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 9b0ba60f..c9279b17 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -54,6 +54,7 @@ from pytorch_optimizer.optimizer.adapnm import AdaPNM from pytorch_optimizer.optimizer.adashift import AdaShift from pytorch_optimizer.optimizer.adasmooth import AdaSmooth +from pytorch_optimizer.optimizer.ademamix import AdEMAMix from pytorch_optimizer.optimizer.agc import agc from pytorch_optimizer.optimizer.aggmo import AggMo from pytorch_optimizer.optimizer.aida import Aida @@ -208,6 +209,7 @@ AdamMini, AdaLOMO, AdamG, + AdEMAMix, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/pytorch_optimizer/optimizer/ademamix.py b/pytorch_optimizer/optimizer/ademamix.py new file mode 100644 index 00000000..1d2cdac2 --- /dev/null +++ b/pytorch_optimizer/optimizer/ademamix.py @@ -0,0 +1,149 @@ +import math +from typing import Optional + +import torch + +from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS + + +class AdEMAMix(BaseOptimizer): + r"""Better, Faster, Older. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. + :param fixed_decay: bool. fix weight decay. + :param alpha: float. usually between 4 and 10 would work well. 
+    :param t_alpha_beta3: Optional[float]. number of warmup steps for alpha and beta3; the total number of training
+        steps is recommended when the schedule is used.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        betas: BETAS = (0.9, 0.999, 0.9999),
+        weight_decay: float = 0.0,
+        weight_decouple: bool = False,
+        fixed_decay: bool = False,
+        alpha: float = 5.0,
+        t_alpha_beta3: Optional[float] = None,
+        eps: float = 1e-8,
+        **kwargs,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_betas(betas)
+        self.validate_non_negative(alpha, 'alpha')
+        self.validate_non_negative(t_alpha_beta3, 't_alpha_beta3')
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'betas': betas,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'fixed_decay': fixed_decay,
+            'alpha': alpha,
+            't_alpha_beta3': t_alpha_beta3,
+            'eps': eps,
+        }
+
+        super().__init__(params, defaults)
+
+    def __str__(self) -> str:
+        return 'AdEMAMix'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                state = self.state[p]
+
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+                state['exp_avg_slow'] = torch.zeros_like(p)
+
+    @staticmethod
+    def schedule_alpha(t_alpha_beta3: Optional[float], step: int, alpha: float) -> float:
+        if t_alpha_beta3 is None:
+            return alpha
+        return min(step * alpha / t_alpha_beta3, alpha)
+
+    @staticmethod
+    def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta3: float) -> float:
+        if t_alpha_beta3 is None:
+            return beta3
+
+        log_beta1, log_beta3 = math.log(beta1), math.log(beta3)
+
+        return min(
+            math.exp(
+                log_beta1 * log_beta3 / ((1.0 - step / t_alpha_beta3) * log_beta3 + (step / t_alpha_beta3) * log_beta1)
+            ),
+            beta3,
+        )
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            beta1, beta2, beta3 = group['betas']
+
+            bias_correction1: float = self.debias(beta1, group['step'])
+            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
+
+            alpha_t: float = self.schedule_alpha(group['t_alpha_beta3'], group['step'], group['alpha'])
+            beta3_t: float = self.schedule_beta3(group['t_alpha_beta3'], group['step'], beta1, beta3)
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+                    state['exp_avg_slow'] = torch.zeros_like(p)
+
+                self.apply_weight_decay(
+                    p=p,
+                    grad=grad,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=group['weight_decouple'],
+                    fixed_decay=group['fixed_decay'],
+                )
+
+                exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']
+
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+                exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1.0 - beta3_t)
+
+                de_nom = (exp_avg_sq.sqrt() / bias_correction2_sq).add_(group['eps'])
+
+                step_size = group['lr'] / bias_correction1
+
+                p.addcdiv_(exp_avg + alpha_t * exp_avg_slow, de_nom, value=-step_size)
+
+        return loss
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 9d2c0561..d78222be 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -5,8 +5,8 @@ click==8.1.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
 colorama==0.4.6 ; python_version >= "3.8" and python_full_version < "4.0.0" and (sys_platform == "win32" or platform_system == "Windows")
 coverage[toml]==7.6.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
 exceptiongroup==1.2.2 ; python_version >= "3.8" and python_version < "3.11"
-filelock==3.15.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
-fsspec==2024.6.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
+filelock==3.16.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+fsspec==2024.9.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 iniconfig==2.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 isort==5.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
 jinja2==3.1.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
@@ -17,12 +17,12 @@ networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
 numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
 packaging==24.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
 pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
-platformdirs==4.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
+platformdirs==4.3.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
 pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 pytest==8.3.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
-ruff==0.5.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
+ruff==0.6.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
 sympy==1.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
 tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6"
-torch==2.4.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
+torch==2.4.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
 typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
diff --git a/requirements-docs.txt b/requirements-docs.txt
index 8b5c48bb..0be9adb8 100644
--- a/requirements-docs.txt
+++ b/requirements-docs.txt
@@ -9,3 +9,4 @@ mkdocstrings-python==1.10.5
 markdown-include==0.8.1
 mdx_truly_sane_lists==1.3
 mkdocs-awesome-pages-plugin==2.9.2
+griffe<1.0
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index dc544ecd..b530c3cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,12 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-filelock==3.15.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
-fsspec==2024.6.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
+filelock==3.16.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+fsspec==2024.9.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 jinja2==3.1.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
 markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
 mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
 networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
 numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
 sympy==1.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
-torch==2.4.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
+torch==2.4.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
 typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
diff --git a/tests/constants.py b/tests/constants.py
index 7081b2b5..583df16a 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -34,6 +34,7 @@
     AdaPNM,
     AdaShift,
     AdaSmooth,
+    AdEMAMix,
     AggMo,
     Aida,
     AliG,
@@ -138,6 +139,7 @@
     'stableadamw',
     'adammini',
     'adamg',
+    'ademamix',
 ]
 
 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -471,6 +473,8 @@
     (Kate, {'lr': 5e-2}, 10),
     (StableAdamW, {'lr': 1e0}, 5),
     (AdamG, {'lr': 1e0}, 20),
+    (AdEMAMix, {'lr': 1e0}, 5),
+    (AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py
index f1b1d3cd..bf1d1c7a 100644
--- a/tests/test_load_modules.py
+++ b/tests/test_load_modules.py
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 74
+    assert len(get_supported_optimizers()) == 75
 
 
 def test_get_supported_lr_schedulers():