diff --git a/README.md b/README.md index c06903f99..2567280d7 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **74 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported! +Currently, **75 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -172,6 +172,7 @@ supported_optimizers = get_supported_optimizers() | StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) | | AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) | | TRAC | *Adaptive Parameter-free Optimization* | [github](https://github.com/ComputationalRobotics/TRAC) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240516642M/exportcitation) | +| AdamG | *Towards Stability of Parameter-free Optimization* | | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240504376P/exportcitation) | ## Supported LR Scheduler diff --git a/docs/changelogs/v3.1.1.md b/docs/changelogs/v3.1.1.md index 07a3115df..a9c630d17 100644 --- a/docs/changelogs/v3.1.1.md +++ b/docs/changelogs/v3.1.1.md @@ -5,6 +5,8 @@ * Implement `TRAC` optimizer. (#263) * [Fast TRAC: A Parameter-Free Optimizer for Lifelong Reinforcement Learning](https://arxiv.org/abs/2405.16642) * Support `AdamW` optimizer via `create_optimizer()`. (#263) +* Implement `AdamG` optimizer. (#264, #265) + * [Towards Stability of Parameter-free Optimization](https://arxiv.org/abs/2405.04376) ### Bug diff --git a/docs/index.md b/docs/index.md index c06903f99..2567280d7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **74 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported! +Currently, **75 optimizers (+ `bitsandbytes`, `qgalore`)**, **16 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -172,6 +172,7 @@ supported_optimizers = get_supported_optimizers() | StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) | | AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) | | TRAC | *Adaptive Parameter-free Optimization* | [github](https://github.com/ComputationalRobotics/TRAC) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240516642M/exportcitation) | +| AdamG | *Towards Stability of Parameter-free Optimization* | | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240504376P/exportcitation) | ## Supported LR Scheduler diff --git a/docs/optimizer.md b/docs/optimizer.md index 0fb3628fc..e6f5d5b84 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -44,6 +44,10 @@ :docstring: :members: +::: pytorch_optimizer.AdamG + :docstring: + :members: + ::: pytorch_optimizer.AdaMod :docstring: :members: diff --git a/poetry.lock b/poetry.lock index d4ae9cb20..7bef65d2b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,14 +1,14 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "bitsandbytes" -version = "0.43.1" +version = "0.43.3" description = "k-bit optimizers and matrix multiplication routines." optional = true python-versions = "*" files = [ - {file = "bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a81c826d576d6d691c7b4a7491c8fdc0f37f769795d6ca2e54afa605d2c260a3"}, - {file = "bitsandbytes-0.43.1-py3-none-win_amd64.whl", hash = "sha256:52c1c7189a6ca006555a9663e544e75f40520a97a26e075411f9f9aca0771fcd"}, + {file = "bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:cc99507c352be0715098b2c7577b690dd158972dc4ea10c7495bac104c7c79f0"}, + {file = "bitsandbytes-0.43.3-py3-none-win_amd64.whl", hash = "sha256:257f6552f2144748a84e6c44e1f7a98f3da888f675ed74e18fd7f7eb13c6cafa"}, ] [package.dependencies] @@ -21,33 +21,33 @@ test = ["scipy"] [[package]] name = "black" -version = "24.4.2" +version = "24.8.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-24.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dd1b5a14e417189db4c7b64a6540f31730713d173f0b63e55fabd52d61d8fdce"}, - {file = "black-24.4.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e537d281831ad0e71007dcdcbe50a71470b978c453fa41ce77186bbe0ed6021"}, - {file = "black-24.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaea3008c281f1038edb473c1aa8ed8143a5535ff18f978a318f10302b254063"}, - {file = "black-24.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:7768a0dbf16a39aa5e9a3ded568bb545c8c2727396d063bbaf847df05b08cd96"}, - {file = "black-24.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:257d724c2c9b1660f353b36c802ccece186a30accc7742c176d29c146df6e474"}, - {file = "black-24.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bdde6f877a18f24844e381d45e9947a49e97933573ac9d4345399be37621e26c"}, - {file = "black-24.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e151054aa00bad1f4e1f04919542885f89f5f7d086b8a59e5000e6c616896ffb"}, - {file = "black-24.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:7e122b1c4fb252fd85df3ca93578732b4749d9be076593076ef4d07a0233c3e1"}, - {file = "black-24.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:accf49e151c8ed2c0cdc528691838afd217c50412534e876a19270fea1e28e2d"}, - {file = "black-24.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88c57dc656038f1ab9f92b3eb5335ee9b021412feaa46330d5eba4e51fe49b04"}, - {file = "black-24.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be8bef99eb46d5021bf053114442914baeb3649a89dc5f3a555c88737e5e98fc"}, - {file = "black-24.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:415e686e87dbbe6f4cd5ef0fbf764af7b89f9057b97c908742b6008cc554b9c0"}, - {file = "black-24.4.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bf10f7310db693bb62692609b397e8d67257c55f949abde4c67f9cc574492cc7"}, - {file = "black-24.4.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:98e123f1d5cfd42f886624d84464f7756f60ff6eab89ae845210631714f6db94"}, - {file = "black-24.4.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48a85f2cb5e6799a9ef05347b476cce6c182d6c71ee36925a6c194d074336ef8"}, - {file = "black-24.4.2-cp38-cp38-win_amd64.whl", hash = "sha256:b1530ae42e9d6d5b670a34db49a94115a64596bc77710b1d05e9801e62ca0a7c"}, - {file = "black-24.4.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:37aae07b029fa0174d39daf02748b379399b909652a806e5708199bd93899da1"}, - {file = "black-24.4.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:da33a1a5e49c4122ccdfd56cd021ff1ebc4a1ec4e2d01594fef9b6f267a9e741"}, - {file = "black-24.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef703f83fc32e131e9bcc0a5094cfe85599e7109f896fe8bc96cc402f3eb4b6e"}, - {file = "black-24.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:b9176b9832e84308818a99a561e90aa479e73c523b3f77afd07913380ae2eab7"}, - {file = "black-24.4.2-py3-none-any.whl", hash = "sha256:d36ed1124bb81b32f8614555b34cc4259c3fbc7eec17870e8ff8ded335b58d8c"}, - {file = "black-24.4.2.tar.gz", hash = "sha256:c872b53057f000085da66a19c55d68f6f8ddcac2642392ad3a355878406fbd4d"}, + {file = "black-24.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:09cdeb74d494ec023ded657f7092ba518e8cf78fa8386155e4a03fdcc44679e6"}, + {file = "black-24.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:81c6742da39f33b08e791da38410f32e27d632260e599df7245cccee2064afeb"}, + {file = "black-24.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:707a1ca89221bc8a1a64fb5e15ef39cd755633daa672a9db7498d1c19de66a42"}, + {file = "black-24.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d6417535d99c37cee4091a2f24eb2b6d5ec42b144d50f1f2e436d9fe1916fe1a"}, + {file = "black-24.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fb6e2c0b86bbd43dee042e48059c9ad7830abd5c94b0bc518c0eeec57c3eddc1"}, + {file = "black-24.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:837fd281f1908d0076844bc2b801ad2d369c78c45cf800cad7b61686051041af"}, + {file = "black-24.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62e8730977f0b77998029da7971fa896ceefa2c4c4933fcd593fa599ecbf97a4"}, + {file = "black-24.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:72901b4913cbac8972ad911dc4098d5753704d1f3c56e44ae8dce99eecb0e3af"}, + {file = "black-24.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7c046c1d1eeb7aea9335da62472481d3bbf3fd986e093cffd35f4385c94ae368"}, + {file = "black-24.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:649f6d84ccbae73ab767e206772cc2d7a393a001070a4c814a546afd0d423aed"}, + {file = "black-24.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2b59b250fdba5f9a9cd9d0ece6e6d993d91ce877d121d161e4698af3eb9c1018"}, + {file = "black-24.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:6e55d30d44bed36593c3163b9bc63bf58b3b30e4611e4d88a0c3c239930ed5b2"}, + {file = "black-24.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:505289f17ceda596658ae81b61ebbe2d9b25aa78067035184ed0a9d855d18afd"}, + {file = "black-24.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b19c9ad992c7883ad84c9b22aaa73562a16b819c1d8db7a1a1a49fb7ec13c7d2"}, + {file = "black-24.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f13f7f386f86f8121d76599114bb8c17b69d962137fc70efe56137727c7047e"}, + {file = "black-24.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:f490dbd59680d809ca31efdae20e634f3fae27fba3ce0ba3208333b713bc3920"}, + {file = "black-24.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eab4dd44ce80dea27dc69db40dab62d4ca96112f87996bca68cd75639aeb2e4c"}, + {file = "black-24.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3c4285573d4897a7610054af5a890bde7c65cb466040c5f0c8b732812d7f0e5e"}, + {file = "black-24.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e84e33b37be070ba135176c123ae52a51f82306def9f7d063ee302ecab2cf47"}, + {file = "black-24.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:73bbf84ed136e45d451a260c6b73ed674652f90a2b3211d6a35e78054563a9bb"}, + {file = "black-24.8.0-py3-none-any.whl", hash = "sha256:972085c618ee94f402da1af548a4f218c754ea7e5dc70acb168bfaca4c2542ed"}, + {file = "black-24.8.0.tar.gz", hash = "sha256:2500945420b6784c38b9ee885af039f5e7471ef284ab03fa35ecdde4688cd83f"}, ] [package.dependencies] @@ -92,63 +92,83 @@ files = [ [[package]] name = "coverage" -version = "7.6.0" +version = "7.6.1" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" files = [ - {file = "coverage-7.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dff044f661f59dace805eedb4a7404c573b6ff0cdba4a524141bc63d7be5c7fd"}, - {file = "coverage-7.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8659fd33ee9e6ca03950cfdcdf271d645cf681609153f218826dd9805ab585c"}, - {file = "coverage-7.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7792f0ab20df8071d669d929c75c97fecfa6bcab82c10ee4adb91c7a54055463"}, - {file = "coverage-7.6.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4b3cd1ca7cd73d229487fa5caca9e4bc1f0bca96526b922d61053ea751fe791"}, - {file = "coverage-7.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7e128f85c0b419907d1f38e616c4f1e9f1d1b37a7949f44df9a73d5da5cd53c"}, - {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a94925102c89247530ae1dab7dc02c690942566f22e189cbd53579b0693c0783"}, - {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:dcd070b5b585b50e6617e8972f3fbbee786afca71b1936ac06257f7e178f00f6"}, - {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d50a252b23b9b4dfeefc1f663c568a221092cbaded20a05a11665d0dbec9b8fb"}, - {file = "coverage-7.6.0-cp310-cp310-win32.whl", hash = "sha256:0e7b27d04131c46e6894f23a4ae186a6a2207209a05df5b6ad4caee6d54a222c"}, - {file = "coverage-7.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:54dece71673b3187c86226c3ca793c5f891f9fc3d8aa183f2e3653da18566169"}, - {file = "coverage-7.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7b525ab52ce18c57ae232ba6f7010297a87ced82a2383b1afd238849c1ff933"}, - {file = "coverage-7.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bea27c4269234e06f621f3fac3925f56ff34bc14521484b8f66a580aacc2e7d"}, - {file = "coverage-7.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed8d1d1821ba5fc88d4a4f45387b65de52382fa3ef1f0115a4f7a20cdfab0e94"}, - {file = "coverage-7.6.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01c322ef2bbe15057bc4bf132b525b7e3f7206f071799eb8aa6ad1940bcf5fb1"}, - {file = "coverage-7.6.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03cafe82c1b32b770a29fd6de923625ccac3185a54a5e66606da26d105f37dac"}, - {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0d1b923fc4a40c5832be4f35a5dab0e5ff89cddf83bb4174499e02ea089daf57"}, - {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4b03741e70fb811d1a9a1d75355cf391f274ed85847f4b78e35459899f57af4d"}, - {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a73d18625f6a8a1cbb11eadc1d03929f9510f4131879288e3f7922097a429f63"}, - {file = "coverage-7.6.0-cp311-cp311-win32.whl", hash = "sha256:65fa405b837060db569a61ec368b74688f429b32fa47a8929a7a2f9b47183713"}, - {file = "coverage-7.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:6379688fb4cfa921ae349c76eb1a9ab26b65f32b03d46bb0eed841fd4cb6afb1"}, - {file = "coverage-7.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f7db0b6ae1f96ae41afe626095149ecd1b212b424626175a6633c2999eaad45b"}, - {file = "coverage-7.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bbdf9a72403110a3bdae77948b8011f644571311c2fb35ee15f0f10a8fc082e8"}, - {file = "coverage-7.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cc44bf0315268e253bf563f3560e6c004efe38f76db03a1558274a6e04bf5d5"}, - {file = "coverage-7.6.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:da8549d17489cd52f85a9829d0e1d91059359b3c54a26f28bec2c5d369524807"}, - {file = "coverage-7.6.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0086cd4fc71b7d485ac93ca4239c8f75732c2ae3ba83f6be1c9be59d9e2c6382"}, - {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1fad32ee9b27350687035cb5fdf9145bc9cf0a094a9577d43e909948ebcfa27b"}, - {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:044a0985a4f25b335882b0966625270a8d9db3d3409ddc49a4eb00b0ef5e8cee"}, - {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:76d5f82213aa78098b9b964ea89de4617e70e0d43e97900c2778a50856dac605"}, - {file = "coverage-7.6.0-cp312-cp312-win32.whl", hash = "sha256:3c59105f8d58ce500f348c5b56163a4113a440dad6daa2294b5052a10db866da"}, - {file = "coverage-7.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:ca5d79cfdae420a1d52bf177de4bc2289c321d6c961ae321503b2ca59c17ae67"}, - {file = "coverage-7.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d39bd10f0ae453554798b125d2f39884290c480f56e8a02ba7a6ed552005243b"}, - {file = "coverage-7.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:beb08e8508e53a568811016e59f3234d29c2583f6b6e28572f0954a6b4f7e03d"}, - {file = "coverage-7.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2e16f4cd2bc4d88ba30ca2d3bbf2f21f00f382cf4e1ce3b1ddc96c634bc48ca"}, - {file = "coverage-7.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6616d1c9bf1e3faea78711ee42a8b972367d82ceae233ec0ac61cc7fec09fa6b"}, - {file = "coverage-7.6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad4567d6c334c46046d1c4c20024de2a1c3abc626817ae21ae3da600f5779b44"}, - {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d17c6a415d68cfe1091d3296ba5749d3d8696e42c37fca5d4860c5bf7b729f03"}, - {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9146579352d7b5f6412735d0f203bbd8d00113a680b66565e205bc605ef81bc6"}, - {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:cdab02a0a941af190df8782aafc591ef3ad08824f97850b015c8c6a8b3877b0b"}, - {file = "coverage-7.6.0-cp38-cp38-win32.whl", hash = "sha256:df423f351b162a702c053d5dddc0fc0ef9a9e27ea3f449781ace5f906b664428"}, - {file = "coverage-7.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:f2501d60d7497fd55e391f423f965bbe9e650e9ffc3c627d5f0ac516026000b8"}, - {file = "coverage-7.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7221f9ac9dad9492cecab6f676b3eaf9185141539d5c9689d13fd6b0d7de840c"}, - {file = "coverage-7.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ddaaa91bfc4477d2871442bbf30a125e8fe6b05da8a0015507bfbf4718228ab2"}, - {file = "coverage-7.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4cbe651f3904e28f3a55d6f371203049034b4ddbce65a54527a3f189ca3b390"}, - {file = "coverage-7.6.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:831b476d79408ab6ccfadaaf199906c833f02fdb32c9ab907b1d4aa0713cfa3b"}, - {file = "coverage-7.6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c3d091059ad0b9c59d1034de74a7f36dcfa7f6d3bde782c49deb42438f2450"}, - {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4d5fae0a22dc86259dee66f2cc6c1d3e490c4a1214d7daa2a93d07491c5c04b6"}, - {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:07ed352205574aad067482e53dd606926afebcb5590653121063fbf4e2175166"}, - {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:49c76cdfa13015c4560702574bad67f0e15ca5a2872c6a125f6327ead2b731dd"}, - {file = "coverage-7.6.0-cp39-cp39-win32.whl", hash = "sha256:482855914928c8175735a2a59c8dc5806cf7d8f032e4820d52e845d1f731dca2"}, - {file = "coverage-7.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:543ef9179bc55edfd895154a51792b01c017c87af0ebaae092720152e19e42ca"}, - {file = "coverage-7.6.0-pp38.pp39.pp310-none-any.whl", hash = "sha256:6fe885135c8a479d3e37a7aae61cbd3a0fb2deccb4dda3c25f92a49189f766d6"}, - {file = "coverage-7.6.0.tar.gz", hash = "sha256:289cc803fa1dc901f84701ac10c9ee873619320f2f9aff38794db4a4a0268d51"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61c0abb4c85b095a784ef23fdd4aede7a2628478e7baba7c5e3deba61070a02"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd21f6ae3f08b41004dfb433fa895d858f3f5979e7762d052b12aef444e29afc"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f59d57baca39b32db42b83b2a7ba6f47ad9c394ec2076b084c3f029b7afca23"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a1ac0ae2b8bd743b88ed0502544847c3053d7171a3cff9228af618a068ed9c34"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e6a08c0be454c3b3beb105c0596ebdc2371fab6bb90c0c0297f4e58fd7e1012c"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f5796e664fe802da4f57a168c85359a8fbf3eab5e55cd4e4569fbacecc903959"}, + {file = "coverage-7.6.1-cp310-cp310-win32.whl", hash = "sha256:7bb65125fcbef8d989fa1dd0e8a060999497629ca5b0efbca209588a73356232"}, + {file = "coverage-7.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:3115a95daa9bdba70aea750db7b96b37259a81a709223c8448fa97727d546fe0"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7dea0889685db8550f839fa202744652e87c60015029ce3f60e006f8c4462c93"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed37bd3c3b063412f7620464a9ac1314d33100329f39799255fb8d3027da50d3"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d85f5e9a5f8b73e2350097c3756ef7e785f55bd71205defa0bfdaf96c31616ff"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bc572be474cafb617672c43fe989d6e48d3c83af02ce8de73fff1c6bb3c198d"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c0420b573964c760df9e9e86d1a9a622d0d27f417e1a949a8a66dd7bcee7bc6"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f4aa8219db826ce6be7099d559f8ec311549bfc4046f7f9fe9b5cea5c581c56"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:fc5a77d0c516700ebad189b587de289a20a78324bc54baee03dd486f0855d234"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b48f312cca9621272ae49008c7f613337c53fadca647d6384cc129d2996d1133"}, + {file = "coverage-7.6.1-cp311-cp311-win32.whl", hash = "sha256:1125ca0e5fd475cbbba3bb67ae20bd2c23a98fac4e32412883f9bcbaa81c314c"}, + {file = "coverage-7.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:8ae539519c4c040c5ffd0632784e21b2f03fc1340752af711f33e5be83a9d6c6"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, + {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, + {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a4acd025ecc06185ba2b801f2de85546e0b8ac787cf9d3b06e7e2a69f925b106"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a6d3adcf24b624a7b778533480e32434a39ad8fa30c315208f6d3e5542aeb6e9"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0c212c49b6c10e6951362f7c6df3329f04c2b1c28499563d4035d964ab8e08c"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e81d7a3e58882450ec4186ca59a3f20a5d4440f25b1cff6f0902ad890e6748a"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b260de9790fd81e69401c2dc8b17da47c8038176a79092a89cb2b7d945d060"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a78d169acd38300060b28d600344a803628c3fd585c912cacc9ea8790fe96862"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2c09f4ce52cb99dd7505cd0fc8e0e37c77b87f46bc9c1eb03fe3bc9991085388"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6878ef48d4227aace338d88c48738a4258213cd7b74fd9a3d4d7582bb1d8a155"}, + {file = "coverage-7.6.1-cp313-cp313-win32.whl", hash = "sha256:44df346d5215a8c0e360307d46ffaabe0f5d3502c8a1cefd700b34baf31d411a"}, + {file = "coverage-7.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:8284cf8c0dd272a247bc154eb6c95548722dce90d098c17a883ed36e67cdb129"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d3296782ca4eab572a1a4eca686d8bfb00226300dcefdf43faa25b5242ab8a3e"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:502753043567491d3ff6d08629270127e0c31d4184c4c8d98f92c26f65019962"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a89ecca80709d4076b95f89f308544ec8f7b4727e8a547913a35f16717856cb"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a318d68e92e80af8b00fa99609796fdbcdfef3629c77c6283566c6f02c6d6704"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13b0a73a0896988f053e4fbb7de6d93388e6dd292b0d87ee51d106f2c11b465b"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4421712dbfc5562150f7554f13dde997a2e932a6b5f352edcce948a815efee6f"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:166811d20dfea725e2e4baa71fffd6c968a958577848d2131f39b60043400223"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:225667980479a17db1048cb2bf8bfb39b8e5be8f164b8f6628b64f78a72cf9d3"}, + {file = "coverage-7.6.1-cp313-cp313t-win32.whl", hash = "sha256:170d444ab405852903b7d04ea9ae9b98f98ab6d7e63e1115e82620807519797f"}, + {file = "coverage-7.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b9f222de8cded79c49bf184bdbc06630d4c58eec9459b939b4a690c82ed05657"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6db04803b6c7291985a761004e9060b2bca08da6d04f26a7f2294b8623a0c1a0"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f1adfc8ac319e1a348af294106bc6a8458a0f1633cc62a1446aebc30c5fa186a"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a95324a9de9650a729239daea117df21f4b9868ce32e63f8b650ebe6cef5595b"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b43c03669dc4618ec25270b06ecd3ee4fa94c7f9b3c14bae6571ca00ef98b0d3"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8929543a7192c13d177b770008bc4e8119f2e1f881d563fc6b6305d2d0ebe9de"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a09ece4a69cf399510c8ab25e0950d9cf2b42f7b3cb0374f95d2e2ff594478a6"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9054a0754de38d9dbd01a46621636689124d666bad1936d76c0341f7d71bf569"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0dbde0f4aa9a16fa4d754356a8f2e36296ff4d83994b2c9d8398aa32f222f989"}, + {file = "coverage-7.6.1-cp38-cp38-win32.whl", hash = "sha256:da511e6ad4f7323ee5702e6633085fb76c2f893aaf8ce4c51a0ba4fc07580ea7"}, + {file = "coverage-7.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:3f1156e3e8f2872197af3840d8ad307a9dd18e615dc64d9ee41696f287c57ad8"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abd5fd0db5f4dc9289408aaf34908072f805ff7792632250dcb36dc591d24255"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:547f45fa1a93154bd82050a7f3cddbc1a7a4dd2a9bf5cb7d06f4ae29fe94eaf8"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645786266c8f18a931b65bfcefdbf6952dd0dea98feee39bd188607a9d307ed2"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e0b2df163b8ed01d515807af24f63de04bebcecbd6c3bfeff88385789fdf75a"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609b06f178fe8e9f89ef676532760ec0b4deea15e9969bf754b37f7c40326dbc"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:702855feff378050ae4f741045e19a32d57d19f3e0676d589df0575008ea5004"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2bdb062ea438f22d99cba0d7829c2ef0af1d768d1e4a4f528087224c90b132cb"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9c56863d44bd1c4fe2abb8a4d6f5371d197f1ac0ebdee542f07f35895fc07f36"}, + {file = "coverage-7.6.1-cp39-cp39-win32.whl", hash = "sha256:6e2cd258d7d927d09493c8df1ce9174ad01b381d4729a9d8d4e38670ca24774c"}, + {file = "coverage-7.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:06a737c882bd26d0d6ee7269b20b12f14a8704807a01056c80bb881a4b2ce6ca"}, + {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, + {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, ] [package.dependencies] @@ -237,20 +257,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "intel-openmp" -version = "2021.4.0" -description = "Intel OpenMP* Runtime Library" -optional = false -python-versions = "*" -files = [ - {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, -] - [[package]] name = "isort" version = "5.13.2" @@ -351,24 +357,6 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] -[[package]] -name = "mkl" -version = "2021.4.0" -description = "IntelĀ® oneAPI Math Kernel Library" -optional = false -python-versions = "*" -files = [ - {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, - {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, - {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, -] - -[package.dependencies] -intel-openmp = "==2021.*" -tbb = "==2021.*" - [[package]] name = "mpmath" version = "1.3.0" @@ -507,13 +495,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pytest" -version = "8.3.1" +version = "8.3.2" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.3.1-py3-none-any.whl", hash = "sha256:e9600ccf4f563976e2c99fa02c7624ab938296551f280835ee6516df8bc4ae8c"}, - {file = "pytest-8.3.1.tar.gz", hash = "sha256:7e8e5c5abd6e93cb1cc151f23e57adc31fcf8cfd2a3ff2da63e23f732de35db6"}, + {file = "pytest-8.3.2-py3-none-any.whl", hash = "sha256:4ba08f9ae7dcf84ded419494d229b48d0903ea6407b030eaec46df5e6a73bba5"}, + {file = "pytest-8.3.2.tar.gz", hash = "sha256:c132345d12ce551242c87269de812483f5bcc87cdbb4722e48487ba194f9fdce"}, ] [package.dependencies] @@ -547,40 +535,40 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.5.4" +version = "0.5.7" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.4-py3-none-linux_armv6l.whl", hash = "sha256:82acef724fc639699b4d3177ed5cc14c2a5aacd92edd578a9e846d5b5ec18ddf"}, - {file = "ruff-0.5.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:da62e87637c8838b325e65beee485f71eb36202ce8e3cdbc24b9fcb8b99a37be"}, - {file = "ruff-0.5.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e98ad088edfe2f3b85a925ee96da652028f093d6b9b56b76fc242d8abb8e2059"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c55efbecc3152d614cfe6c2247a3054cfe358cefbf794f8c79c8575456efe19"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f9b85eaa1f653abd0a70603b8b7008d9e00c9fa1bbd0bf40dad3f0c0bdd06793"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0cf497a47751be8c883059c4613ba2f50dd06ec672692de2811f039432875278"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:09c14ed6a72af9ccc8d2e313d7acf7037f0faff43cde4b507e66f14e812e37f7"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:628f6b8f97b8bad2490240aa84f3e68f390e13fabc9af5c0d3b96b485921cd60"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3520a00c0563d7a7a7c324ad7e2cde2355733dafa9592c671fb2e9e3cd8194c1"}, - {file = "ruff-0.5.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93789f14ca2244fb91ed481456f6d0bb8af1f75a330e133b67d08f06ad85b516"}, - {file = "ruff-0.5.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:029454e2824eafa25b9df46882f7f7844d36fd8ce51c1b7f6d97e2615a57bbcc"}, - {file = "ruff-0.5.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9492320eed573a13a0bc09a2957f17aa733fff9ce5bf00e66e6d4a88ec33813f"}, - {file = "ruff-0.5.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a6e1f62a92c645e2919b65c02e79d1f61e78a58eddaebca6c23659e7c7cb4ac7"}, - {file = "ruff-0.5.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:768fa9208df2bec4b2ce61dbc7c2ddd6b1be9fb48f1f8d3b78b3332c7d71c1ff"}, - {file = "ruff-0.5.4-py3-none-win32.whl", hash = "sha256:e1e7393e9c56128e870b233c82ceb42164966f25b30f68acbb24ed69ce9c3a4e"}, - {file = "ruff-0.5.4-py3-none-win_amd64.whl", hash = "sha256:58b54459221fd3f661a7329f177f091eb35cf7a603f01d9eb3eb11cc348d38c4"}, - {file = "ruff-0.5.4-py3-none-win_arm64.whl", hash = "sha256:bd53da65f1085fb5b307c38fd3c0829e76acf7b2a912d8d79cadcdb4875c1eb7"}, - {file = "ruff-0.5.4.tar.gz", hash = "sha256:2795726d5f71c4f4e70653273d1c23a8182f07dd8e48c12de5d867bfb7557eed"}, + {file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"}, + {file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"}, + {file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"}, + {file = "ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"}, + {file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"}, + {file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"}, + {file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"}, ] [[package]] name = "sympy" -version = "1.13.1" +version = "1.13.2" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" files = [ - {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"}, - {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, + {file = "sympy-1.13.2-py3-none-any.whl", hash = "sha256:c51d75517712f1aed280d4ce58506a4a88d635d6b5dd48b39102a7ae1f3fcfe9"}, + {file = "sympy-1.13.2.tar.gz", hash = "sha256:401449d84d07be9d0c7a46a64bd54fe097667d5e7181bfe67ec777be9e01cb13"}, ] [package.dependencies] @@ -589,19 +577,6 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] -[[package]] -name = "tbb" -version = "2021.13.0" -description = "IntelĀ® oneAPI Threading Building Blocks (oneTBB)" -optional = false -python-versions = "*" -files = [ - {file = "tbb-2021.13.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:a2567725329639519d46d92a2634cf61e76601dac2f777a05686fea546c4fe4f"}, - {file = "tbb-2021.13.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aaf667e92849adb012b8874d6393282afc318aca4407fc62f912ee30a22da46a"}, - {file = "tbb-2021.13.0-py3-none-win32.whl", hash = "sha256:6669d26703e9943f6164c6407bd4a237a45007e79b8d3832fe6999576eaaa9ef"}, - {file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"}, -] - [[package]] name = "tomli" version = "2.0.1" @@ -615,35 +590,34 @@ files = [ [[package]] name = "torch" -version = "2.3.1+cpu" +version = "2.4.0+cpu" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.3.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:d679e21d871982b9234444331a26350902cfd2d5ca44ce6f49896af8b3a3087d"}, - {file = "torch-2.3.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:500bf790afc2fd374a15d06213242e517afccc50a46ea5955d321a9a68003335"}, - {file = "torch-2.3.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:a272defe305dbd944aa28a91cc3db0f0149495b3ebec2e39723a7224fa05dc57"}, - {file = "torch-2.3.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:d2965eb54d3c8818e2280a54bd53e8246a6bb34e4b10bd19c59f35b611dd9f05"}, - {file = "torch-2.3.1+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:2141a6cb7021adf2f92a0fd372cfeac524ba460bd39ce3a641d30a561e41f69a"}, - {file = "torch-2.3.1+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:6acdca2530462611095c44fd95af75ecd5b9646eac813452fe0adf31a9bc310a"}, - {file = "torch-2.3.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:cab92d5101e6db686c5525e04d87cedbcf3a556073d71d07fbe7d1ce09630ffb"}, - {file = "torch-2.3.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:dbc784569a367fd425158cf4ae82057dd3011185ba5fc68440432ba0562cb5b2"}, - {file = "torch-2.3.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:a3cb8e61ba311cee1bb7463cbdcf3ebdfd071e2091e74c5785e3687eb02819f9"}, - {file = "torch-2.3.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:df68668056e62c0332e03f43d9da5d4278b39df1ba58d30ec20d34242070955d"}, + {file = "torch-2.4.0+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:0e59377b27823dda6d26528febb7ca06fc5b77816eaa58b4420cc8785e33d4ce"}, + {file = "torch-2.4.0+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:53c3f75fa4ef0726e494ebef003b17d8a61c3c9fa4630b465610b462bf06c3de"}, + {file = "torch-2.4.0+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:14a7a8b595347dddca594f9e448b93ce68ce4f871acbd32cf04bda7c03664c0c"}, + {file = "torch-2.4.0+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:3b3cb9a6c17b5a4cea42bb37a243bfbad7659cef6d9b4ee29cb793bdf20f482c"}, + {file = "torch-2.4.0+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:78dbf5f2789933a7ea2dabeead4daa44679b1e0d8eb35ddb7071c8ab7b181eb3"}, + {file = "torch-2.4.0+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:f59c53a1c3247efb3700f9f78bdd289712177037a85b5519b9ecdef7c77c1fee"}, + {file = "torch-2.4.0+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:08753c3d776ae49dc9ddbae02e26720a513a4dc7997e41d95392bca71623a0cd"}, + {file = "torch-2.4.0+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:9f376f5a14eb04a44974c3a9dfd857a68090acb435b98e62bbf523baeefac85e"}, + {file = "torch-2.4.0+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:040abaee8affa1bb0f3ca14ca693ba81d0d90d88df5b8a839af96933a7fa2d29"}, + {file = "torch-2.4.0+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:441fbf517c46fee6782a4289ffe49f701d0a52e3533ab5397ce395da165d921d"}, ] [package.dependencies] filelock = "*" fsspec = "*" jinja2 = "*" -mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" sympy = "*" typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.9.1)"] +optree = ["optree (>=0.11.0)"] [package.source] type = "legacy" diff --git a/pyproject.toml b/pyproject.toml index cdef27c6f..6fafffb51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,15 +11,15 @@ repository = "https://github.com/kozistr/pytorch_optimizer" documentation = "https://pytorch-optimizers.readthedocs.io/en/latest" keywords = [ "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", - "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite", - "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", - "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", - "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", - "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", - "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", - "SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", - "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", - "bitsandbytes", "WSD", "QGaLore", + "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", + "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", + "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", + "Fromage", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", + "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", + "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", + "SignSGD", "SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", + "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", + "LovaszHinge", "bitsandbytes", "WSD", "QGaLore", ] classifiers = [ "License :: OSI Approved :: Apache Software License", diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 25241f9f4..9b0ba60f3 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -44,6 +44,7 @@ from pytorch_optimizer.optimizer.adalite import Adalite from pytorch_optimizer.optimizer.adam_mini import AdamMini from pytorch_optimizer.optimizer.adamax import AdaMax +from pytorch_optimizer.optimizer.adamg import AdamG from pytorch_optimizer.optimizer.adamod import AdaMod from pytorch_optimizer.optimizer.adamp import AdamP from pytorch_optimizer.optimizer.adams import AdamS @@ -206,6 +207,7 @@ StableAdamW, AdamMini, AdaLOMO, + AdamG, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/pytorch_optimizer/optimizer/adamg.py b/pytorch_optimizer/optimizer/adamg.py new file mode 100644 index 000000000..168914985 --- /dev/null +++ b/pytorch_optimizer/optimizer/adamg.py @@ -0,0 +1,127 @@ +import math + +import torch + +from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS + + +class AdamG(BaseOptimizer): + r"""Towards Stability of Parameter-free Optimization. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. + :param p: float. p for a numerator function `s(x) = p * x^q`. + :param q: float. q for a numerator function `s(x) = p * x^q`. + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. + :param fixed_decay: bool. fix weight decay. + :param eps: float. term added to the denominator to improve numerical stability. + """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 1e-3, + betas: BETAS = (0.95, 0.999, 0.95), + p: float = 0.5, + q: float = 0.25, + weight_decay: float = 0.0, + weight_decouple: bool = False, + fixed_decay: bool = False, + eps: float = 1e-8, + ): + self.validate_learning_rate(lr) + self.validate_betas(betas) + self.validate_positive(p, 'p') + self.validate_positive(q, 'q') + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_non_negative(eps, 'eps') + + self.p = p + self.q = q + + defaults: DEFAULTS = { + 'lr': lr, + 'betas': betas, + 'weight_decay': weight_decay, + 'weight_decouple': weight_decouple, + 'fixed_decay': fixed_decay, + 'eps': eps, + } + + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'AdamG' + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + state['m'] = torch.zeros_like(p) + state['v'] = torch.zeros_like(p) + state['r'] = torch.zeros_like(p) + + def s(self, p: torch.Tensor) -> torch.Tensor: + r"""Numerator function f(x) = p * x^q.""" + return p.pow(self.q).mul_(self.p) + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + beta1, beta2, beta3 = group['betas'] + + bias_correction1: float = 1.0 - self.debias(beta1, group['step']) + bias_correction2: float = 1.0 - self.debias(beta2, group['step']) + step_size: float = min(group['lr'], 1.0 / math.sqrt(group['step'])) + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + state = self.state[p] + + if len(state) == 0: + state['m'] = torch.zeros_like(p) + state['v'] = torch.zeros_like(p) + state['r'] = torch.zeros_like(p) + + self.apply_weight_decay( + p=p, + grad=grad, + lr=group['lr'], + weight_decay=group['weight_decay'], + weight_decouple=group['weight_decouple'], + fixed_decay=group['fixed_decay'], + ) + + m, v, r = state['m'], state['v'], state['r'] + v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + r.mul_(beta3).add_(self.s(v), alpha=1.0 - beta3) + m.mul_(beta1).addcmul_(r, grad, value=1.0 - beta1) + + update = (m / bias_correction1) / (v / bias_correction2).sqrt_().add_(group['eps']) + + p.add_(update, alpha=-step_size) + + return loss diff --git a/pytorch_optimizer/optimizer/trac.py b/pytorch_optimizer/optimizer/trac.py index 7dbc98e5a..7c13d70ac 100644 --- a/pytorch_optimizer/optimizer/trac.py +++ b/pytorch_optimizer/optimizer/trac.py @@ -140,7 +140,7 @@ def reset(self): self.state['trac'] = { 'betas': torch.tensor(self.betas, device=device), - 's': torch.zeros(len(self.betas), device=device), + 's': torch.zeros(1, device=device), 'variance': torch.zeros(len(self.betas), device=device), 'sigma': torch.full((len(self.betas),), 1e-8, device=device), 'step': 0, @@ -182,6 +182,7 @@ def trac_step(self, updates: Dict, grads: Dict) -> None: device = self.param_groups[0]['params'][0].device + s = self.state['trac']['s'] h = torch.zeros((1,), device=device) for group in self.param_groups: for p in group['params']: @@ -191,7 +192,7 @@ def trac_step(self, updates: Dict, grads: Dict) -> None: theta_ref = self.state['trac'][p] update = updates[p] - deltas[p] = (update - theta_ref) / torch.sum(self.state['trac']['s']).add_(self.eps) + deltas[p] = (update - theta_ref) / s.add(self.eps) update.neg_().add_(p) grad, delta = grads[p], deltas[p] @@ -201,7 +202,8 @@ def trac_step(self, updates: Dict, grads: Dict) -> None: delta.add_(update) - s = self.state['trac']['s'] + p.copy_(theta_ref) + betas = self.state['trac']['betas'] variance = self.state['trac']['variance'] sigma = self.state['trac']['sigma'] @@ -209,21 +211,17 @@ def trac_step(self, updates: Dict, grads: Dict) -> None: variance.mul_(betas.pow(2)).add_(h.pow(2)) sigma.mul_(betas).sub_(h) - s_term = self.erf_imag(sigma / (2.0 * variance).sqrt_().add_(self.eps)) - s_term.mul_(self.f_term) - s.copy_(s_term) + term = self.erf_imag(sigma / (2.0 * variance).sqrt_().add_(self.eps)).mul_(self.f_term) + s.copy_(torch.sum(term)) - scale = max(torch.sum(s), 0.0) + scale = max(s, 0.0) for group in self.param_groups: for p in group['params']: if grads[p] is None: continue - delta = deltas[p] - delta.mul_(scale).add_(self.state['trac'][p]) - - p.copy_(delta) + p.add_(deltas[p] * scale) @torch.no_grad() def step(self, closure: CLOSURE = None) -> LOSS: @@ -238,7 +236,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: self.state['trac'] = { 'betas': torch.tensor(self.betas, device=device), - 's': torch.zeros(len(self.betas), device=device), + 's': torch.zeros(1, device=device), 'variance': torch.zeros(len(self.betas), device=device), 'sigma': torch.full((len(self.betas),), 1e-8, device=device), 'step': 0, diff --git a/requirements-dev.txt b/requirements-dev.txt index e4f7097fc..9d2c05616 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,18 +1,16 @@ --extra-index-url https://download.pytorch.org/whl/cpu -black==24.4.2 ; python_version >= "3.8" and python_full_version < "4.0.0" +black==24.8.0 ; python_version >= "3.8" and python_full_version < "4.0.0" click==8.1.7 ; python_version >= "3.8" and python_full_version < "4.0.0" colorama==0.4.6 ; python_version >= "3.8" and python_full_version < "4.0.0" and (sys_platform == "win32" or platform_system == "Windows") -coverage[toml]==7.6.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +coverage[toml]==7.6.1 ; python_version >= "3.8" and python_full_version < "4.0.0" exceptiongroup==1.2.2 ; python_version >= "3.8" and python_version < "3.11" filelock==3.15.4 ; python_version >= "3.8" and python_full_version < "4.0.0" fsspec==2024.6.1 ; python_version >= "3.8" and python_full_version < "4.0.0" iniconfig==2.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" -intel-openmp==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" isort==5.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0" jinja2==3.1.4 ; python_version >= "3.8" and python_full_version < "4.0.0" markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0" -mkl==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" mypy-extensions==1.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" @@ -22,10 +20,9 @@ pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0" platformdirs==4.2.2 ; python_version >= "3.8" and python_full_version < "4.0.0" pluggy==1.5.0 ; python_version >= "3.8" and python_full_version < "4.0.0" pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" -pytest==8.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -ruff==0.5.4 ; python_version >= "3.8" and python_full_version < "4.0.0" -sympy==1.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" +pytest==8.3.2 ; python_version >= "3.8" and python_full_version < "4.0.0" +ruff==0.5.7 ; python_version >= "3.8" and python_full_version < "4.0.0" +sympy==1.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0" tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" -torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" +torch==2.4.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0" diff --git a/requirements.txt b/requirements.txt index ac61ed57e..dc544ecd7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,14 +2,11 @@ filelock==3.15.4 ; python_version >= "3.8" and python_full_version < "4.0.0" fsspec==2024.6.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -intel-openmp==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" jinja2==3.1.4 ; python_version >= "3.8" and python_full_version < "4.0.0" markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0" -mkl==2021.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0" -sympy==1.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -tbb==2021.13.0 ; python_version >= "3.8" and python_full_version < "4.0.0" and platform_system == "Windows" -torch==2.3.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" +sympy==1.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0" +torch==2.4.0+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" typing-extensions==4.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0" diff --git a/tests/constants.py b/tests/constants.py index 072645cee..947667a9f 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -25,6 +25,7 @@ Adai, Adalite, AdaMax, + AdamG, AdaMod, AdamP, AdamS, @@ -136,6 +137,7 @@ 'grokfastadamw', 'stableadamw', 'adammini', + 'adamg', ] VALID_LR_SCHEDULER_NAMES: List[str] = [ @@ -468,6 +470,7 @@ (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10), (Kate, {'lr': 5e-2}, 10), (StableAdamW, {'lr': 1e0}, 5), + (AdamG, {'lr': 1e0}, 20), ] ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10), diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index 9a23043ca..f1b1d3cd5 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 73 + assert len(get_supported_optimizers()) == 74 def test_get_supported_lr_schedulers():