diff --git a/README.md b/README.md index 0633c3746..dbfc1a1d2 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **62 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported! +Currently, **63 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -160,6 +160,7 @@ supported_optimizers = get_supported_optimizers() | CAME | *Confidence-guided Adaptive Memory Efficient Optimization* | [github](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME) | | [cite](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME#citation) | | WSAM | *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* | [github](https://github.com/intelligent-machine-learning/dlrover/blob/master/atorch/atorch/optimizers/wsam.py) | | [cite](https://github.com/intelligent-machine-learning/dlrover) | | Aida | *A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range* | [github](https://github.com/guoqiang-zhang-x/Aida-Optimizer) | | [cite](https://github.com/guoqiang-zhang-x/Aida-Optimizer?tab=readme-ov-file#1-brief-description-of-aida) | +| GaLore | *Memory-Efficient LLM Training by Gradient Low-Rank Projection* | [github](https://github.com/jiaweizzhao/GaLore) | | [cite](https://github.com/jiaweizzhao/GaLore/tree/master?tab=readme-ov-file#citation) | ## Supported LR Scheduler diff --git a/docs/changelogs/v3.0.0.md b/docs/changelogs/v3.0.0.md index 7d4980be1..0de82f3c9 100644 --- a/docs/changelogs/v3.0.0.md +++ b/docs/changelogs/v3.0.0.md @@ -10,21 +10,28 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164) * [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273) * Implement `WSAM` optimizer. (#213, #216) * [Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term](https://arxiv.org/abs/2305.15817) +* Implement `GaLore` optimizer. (#224, #228) + * [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507) -## Dependency +### Fix + +* Fix SRMM to allow operation beyond memory_length. (#227) + +### Dependency * Drop `Python 3.7` support officially. (#221) * Please check the [README](https://github.com/kozistr/pytorch_optimizer?tab=readme-ov-file#getting-started). +* Update `bitsandbytes` to `0.43.0`. (#228) -## Docs +### Docs * Add missing parameters in `Ranger21 optimizer` document. (#214, #215) * Fix `WSAM` optimizer paper link. (#219) -### Contributions +## Contributions -thanks to @sdbds +thanks to @sdbds, @i404788 -### Diff +## Diff [2.12.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.12.0...v3.0.0) diff --git a/docs/optimizer.md b/docs/optimizer.md index 776ea8b70..14d326b4e 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -132,6 +132,10 @@ :docstring: :members: +::: pytorch_optimizer.GaLoreProjector + :docstring: + :members: + ::: pytorch_optimizer.centralize_gradient :docstring: :members: diff --git a/poetry.lock b/poetry.lock index f40e104c6..f8cc69312 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,47 +2,52 @@ [[package]] name = "bitsandbytes" -version = "0.42.0" +version = "0.43.0" description = "k-bit optimizers and matrix multiplication routines." optional = true python-versions = "*" files = [ - {file = "bitsandbytes-0.42.0-py3-none-any.whl", hash = "sha256:63798680912cc63bb77b535a2d0860af024e290a52e157f777ad2a52e2585967"}, - {file = "bitsandbytes-0.42.0.tar.gz", hash = "sha256:fc1505f184f0d275766f2a6c663f1a43b734c1409b5c5a406f3a6073d9f329fd"}, + {file = "bitsandbytes-0.43.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:b2626ada0ae447ae0cf3dd0be8f5b0abad7abdec7056c7fb738aa13a5a862007"}, + {file = "bitsandbytes-0.43.0-py3-none-win_amd64.whl", hash = "sha256:6fa7f3255fe9f3e549fb110bc60794079761a4e608b5fb86ebe7b4047467dd99"}, ] [package.dependencies] -scipy = "*" +numpy = "*" +torch = "*" + +[package.extras] +benchmark = ["matplotlib", "pandas"] +test = ["scipy"] [[package]] name = "black" -version = "24.2.0" +version = "24.3.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-24.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6981eae48b3b33399c8757036c7f5d48a535b962a7c2310d19361edeef64ce29"}, - {file = "black-24.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d533d5e3259720fdbc1b37444491b024003e012c5173f7d06825a77508085430"}, - {file = "black-24.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61a0391772490ddfb8a693c067df1ef5227257e72b0e4108482b8d41b5aee13f"}, - {file = "black-24.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:992e451b04667116680cb88f63449267c13e1ad134f30087dec8527242e9862a"}, - {file = "black-24.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:163baf4ef40e6897a2a9b83890e59141cc8c2a98f2dda5080dc15c00ee1e62cd"}, - {file = "black-24.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e37c99f89929af50ffaf912454b3e3b47fd64109659026b678c091a4cd450fb2"}, - {file = "black-24.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9de21bafcba9683853f6c96c2d515e364aee631b178eaa5145fc1c61a3cc92"}, - {file = "black-24.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:9db528bccb9e8e20c08e716b3b09c6bdd64da0dd129b11e160bf082d4642ac23"}, - {file = "black-24.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d84f29eb3ee44859052073b7636533ec995bd0f64e2fb43aeceefc70090e752b"}, - {file = "black-24.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e08fb9a15c914b81dd734ddd7fb10513016e5ce7e6704bdd5e1251ceee51ac9"}, - {file = "black-24.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:810d445ae6069ce64030c78ff6127cd9cd178a9ac3361435708b907d8a04c693"}, - {file = "black-24.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:ba15742a13de85e9b8f3239c8f807723991fbfae24bad92d34a2b12e81904982"}, - {file = "black-24.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7e53a8c630f71db01b28cd9602a1ada68c937cbf2c333e6ed041390d6968faf4"}, - {file = "black-24.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93601c2deb321b4bad8f95df408e3fb3943d85012dddb6121336b8e24a0d1218"}, - {file = "black-24.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0057f800de6acc4407fe75bb147b0c2b5cbb7c3ed110d3e5999cd01184d53b0"}, - {file = "black-24.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:faf2ee02e6612577ba0181f4347bcbcf591eb122f7841ae5ba233d12c39dcb4d"}, - {file = "black-24.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:057c3dc602eaa6fdc451069bd027a1b2635028b575a6c3acfd63193ced20d9c8"}, - {file = "black-24.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:08654d0797e65f2423f850fc8e16a0ce50925f9337fb4a4a176a7aa4026e63f8"}, - {file = "black-24.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca610d29415ee1a30a3f30fab7a8f4144e9d34c89a235d81292a1edb2b55f540"}, - {file = "black-24.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:4dd76e9468d5536abd40ffbc7a247f83b2324f0c050556d9c371c2b9a9a95e31"}, - {file = "black-24.2.0-py3-none-any.whl", hash = "sha256:e8a6ae970537e67830776488bca52000eaa37fa63b9988e8c487458d9cd5ace6"}, - {file = "black-24.2.0.tar.gz", hash = "sha256:bce4f25c27c3435e4dace4815bcb2008b87e167e3bf4ee47ccdc5ce906eb4894"}, + {file = "black-24.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395"}, + {file = "black-24.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995"}, + {file = "black-24.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7"}, + {file = "black-24.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0"}, + {file = "black-24.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9"}, + {file = "black-24.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597"}, + {file = "black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d"}, + {file = "black-24.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5"}, + {file = "black-24.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f"}, + {file = "black-24.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11"}, + {file = "black-24.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4"}, + {file = "black-24.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5"}, + {file = "black-24.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:79dcf34b33e38ed1b17434693763301d7ccbd1c5860674a8f871bd15139e7837"}, + {file = "black-24.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e19cb1c6365fd6dc38a6eae2dcb691d7d83935c10215aef8e6c38edee3f77abd"}, + {file = "black-24.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b76c275e4c1c5ce6e9870911384bff5ca31ab63d19c76811cb1fb162678213"}, + {file = "black-24.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:b5991d523eee14756f3c8d5df5231550ae8993e2286b8014e2fdea7156ed0959"}, + {file = "black-24.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c45f8dff244b3c431b36e3224b6be4a127c6aca780853574c00faf99258041eb"}, + {file = "black-24.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6905238a754ceb7788a73f02b45637d820b2f5478b20fec82ea865e4f5d4d9f7"}, + {file = "black-24.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7de8d330763c66663661a1ffd432274a2f92f07feeddd89ffd085b5744f85e7"}, + {file = "black-24.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:7bb041dca0d784697af4646d3b62ba4a6b028276ae878e53f6b4f74ddd6db99f"}, + {file = "black-24.3.0-py3-none-any.whl", hash = "sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93"}, + {file = "black-24.3.0.tar.gz", hash = "sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f"}, ] [package.dependencies] @@ -87,63 +92,63 @@ files = [ [[package]] name = "coverage" -version = "7.4.3" +version = "7.4.4" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" files = [ - {file = "coverage-7.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8580b827d4746d47294c0e0b92854c85a92c2227927433998f0d3320ae8a71b6"}, - {file = "coverage-7.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:718187eeb9849fc6cc23e0d9b092bc2348821c5e1a901c9f8975df0bc785bfd4"}, - {file = "coverage-7.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:767b35c3a246bcb55b8044fd3a43b8cd553dd1f9f2c1eeb87a302b1f8daa0524"}, - {file = "coverage-7.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae7f19afe0cce50039e2c782bff379c7e347cba335429678450b8fe81c4ef96d"}, - {file = "coverage-7.4.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba3a8aaed13770e970b3df46980cb068d1c24af1a1968b7818b69af8c4347efb"}, - {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ee866acc0861caebb4f2ab79f0b94dbfbdbfadc19f82e6e9c93930f74e11d7a0"}, - {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:506edb1dd49e13a2d4cac6a5173317b82a23c9d6e8df63efb4f0380de0fbccbc"}, - {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd6545d97c98a192c5ac995d21c894b581f1fd14cf389be90724d21808b657e2"}, - {file = "coverage-7.4.3-cp310-cp310-win32.whl", hash = "sha256:f6a09b360d67e589236a44f0c39218a8efba2593b6abdccc300a8862cffc2f94"}, - {file = "coverage-7.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:18d90523ce7553dd0b7e23cbb28865db23cddfd683a38fb224115f7826de78d0"}, - {file = "coverage-7.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cbbe5e739d45a52f3200a771c6d2c7acf89eb2524890a4a3aa1a7fa0695d2a47"}, - {file = "coverage-7.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:489763b2d037b164846ebac0cbd368b8a4ca56385c4090807ff9fad817de4113"}, - {file = "coverage-7.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:451f433ad901b3bb00184d83fd83d135fb682d780b38af7944c9faeecb1e0bfe"}, - {file = "coverage-7.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fcc66e222cf4c719fe7722a403888b1f5e1682d1679bd780e2b26c18bb648cdc"}, - {file = "coverage-7.4.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3ec74cfef2d985e145baae90d9b1b32f85e1741b04cd967aaf9cfa84c1334f3"}, - {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:abbbd8093c5229c72d4c2926afaee0e6e3140de69d5dcd918b2921f2f0c8baba"}, - {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:35eb581efdacf7b7422af677b92170da4ef34500467381e805944a3201df2079"}, - {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8249b1c7334be8f8c3abcaaa996e1e4927b0e5a23b65f5bf6cfe3180d8ca7840"}, - {file = "coverage-7.4.3-cp311-cp311-win32.whl", hash = "sha256:cf30900aa1ba595312ae41978b95e256e419d8a823af79ce670835409fc02ad3"}, - {file = "coverage-7.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:18c7320695c949de11a351742ee001849912fd57e62a706d83dfc1581897fa2e"}, - {file = "coverage-7.4.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b51bfc348925e92a9bd9b2e48dad13431b57011fd1038f08316e6bf1df107d10"}, - {file = "coverage-7.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d6cdecaedea1ea9e033d8adf6a0ab11107b49571bbb9737175444cea6eb72328"}, - {file = "coverage-7.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b2eccb883368f9e972e216c7b4c7c06cabda925b5f06dde0650281cb7666a30"}, - {file = "coverage-7.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c00cdc8fa4e50e1cc1f941a7f2e3e0f26cb2a1233c9696f26963ff58445bac7"}, - {file = "coverage-7.4.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9a4a8dd3dcf4cbd3165737358e4d7dfbd9d59902ad11e3b15eebb6393b0446e"}, - {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:062b0a75d9261e2f9c6d071753f7eef0fc9caf3a2c82d36d76667ba7b6470003"}, - {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:ebe7c9e67a2d15fa97b77ea6571ce5e1e1f6b0db71d1d5e96f8d2bf134303c1d"}, - {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c0a120238dd71c68484f02562f6d446d736adcc6ca0993712289b102705a9a3a"}, - {file = "coverage-7.4.3-cp312-cp312-win32.whl", hash = "sha256:37389611ba54fd6d278fde86eb2c013c8e50232e38f5c68235d09d0a3f8aa352"}, - {file = "coverage-7.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:d25b937a5d9ffa857d41be042b4238dd61db888533b53bc76dc082cb5a15e914"}, - {file = "coverage-7.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:28ca2098939eabab044ad68850aac8f8db6bf0b29bc7f2887d05889b17346454"}, - {file = "coverage-7.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:280459f0a03cecbe8800786cdc23067a8fc64c0bd51dc614008d9c36e1659d7e"}, - {file = "coverage-7.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0cdedd3500e0511eac1517bf560149764b7d8e65cb800d8bf1c63ebf39edd2"}, - {file = "coverage-7.4.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a9babb9466fe1da12417a4aed923e90124a534736de6201794a3aea9d98484e"}, - {file = "coverage-7.4.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dec9de46a33cf2dd87a5254af095a409ea3bf952d85ad339751e7de6d962cde6"}, - {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:16bae383a9cc5abab9bb05c10a3e5a52e0a788325dc9ba8499e821885928968c"}, - {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2c854ce44e1ee31bda4e318af1dbcfc929026d12c5ed030095ad98197eeeaed0"}, - {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ce8c50520f57ec57aa21a63ea4f325c7b657386b3f02ccaedeccf9ebe27686e1"}, - {file = "coverage-7.4.3-cp38-cp38-win32.whl", hash = "sha256:708a3369dcf055c00ddeeaa2b20f0dd1ce664eeabde6623e516c5228b753654f"}, - {file = "coverage-7.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:1bf25fbca0c8d121a3e92a2a0555c7e5bc981aee5c3fdaf4bb7809f410f696b9"}, - {file = "coverage-7.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b253094dbe1b431d3a4ac2f053b6d7ede2664ac559705a704f621742e034f1f"}, - {file = "coverage-7.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77fbfc5720cceac9c200054b9fab50cb2a7d79660609200ab83f5db96162d20c"}, - {file = "coverage-7.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6679060424faa9c11808598504c3ab472de4531c571ab2befa32f4971835788e"}, - {file = "coverage-7.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4af154d617c875b52651dd8dd17a31270c495082f3d55f6128e7629658d63765"}, - {file = "coverage-7.4.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8640f1fde5e1b8e3439fe482cdc2b0bb6c329f4bb161927c28d2e8879c6029ee"}, - {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:69b9f6f66c0af29642e73a520b6fed25ff9fd69a25975ebe6acb297234eda501"}, - {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0842571634f39016a6c03e9d4aba502be652a6e4455fadb73cd3a3a49173e38f"}, - {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a78ed23b08e8ab524551f52953a8a05d61c3a760781762aac49f8de6eede8c45"}, - {file = "coverage-7.4.3-cp39-cp39-win32.whl", hash = "sha256:c0524de3ff096e15fcbfe8f056fdb4ea0bf497d584454f344d59fce069d3e6e9"}, - {file = "coverage-7.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:0209a6369ccce576b43bb227dc8322d8ef9e323d089c6f3f26a597b09cb4d2aa"}, - {file = "coverage-7.4.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:7cbde573904625509a3f37b6fecea974e363460b556a627c60dc2f47e2fffa51"}, - {file = "coverage-7.4.3.tar.gz", hash = "sha256:276f6077a5c61447a48d133ed13e759c09e62aff0dc84274a68dc18660104d52"}, + {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"}, + {file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"}, + {file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"}, + {file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"}, + {file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"}, + {file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"}, + {file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"}, + {file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"}, + {file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"}, + {file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"}, + {file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"}, + {file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"}, + {file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"}, + {file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"}, + {file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"}, + {file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"}, + {file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"}, + {file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"}, + {file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"}, + {file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"}, + {file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"}, + {file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"}, ] [package.dependencies] @@ -168,29 +173,29 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.13.1" +version = "3.13.3" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, - {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, + {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"}, + {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] name = "fsspec" -version = "2024.2.0" +version = "2024.3.1" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, - {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, + {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, + {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, ] [package.extras] @@ -413,13 +418,13 @@ files = [ [[package]] name = "packaging" -version = "23.2" +version = "24.0" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] [[package]] @@ -465,13 +470,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pytest" -version = "8.0.2" +version = "8.1.1" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.0.2-py3-none-any.whl", hash = "sha256:edfaaef32ce5172d5466b5127b42e0d6d35ebbe4453f0e3505d96afd93f6b096"}, - {file = "pytest-8.0.2.tar.gz", hash = "sha256:d4051d623a2e0b7e51960ba963193b09ce6daeb9759a451844a21e4ddedfc1bd"}, + {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, + {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"}, ] [package.dependencies] @@ -479,21 +484,21 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=1.3.0,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} +pluggy = ">=1.4,<2.0" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-cov" -version = "4.1.0" +version = "5.0.0" description = "Pytest plugin for measuring coverage." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, - {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, + {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, + {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, ] [package.dependencies] @@ -501,72 +506,34 @@ coverage = {version = ">=5.2.1", extras = ["toml"]} pytest = ">=4.6" [package.extras] -testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.3.0" +version = "0.3.5" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.3.0-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7deb528029bacf845bdbb3dbb2927d8ef9b4356a5e731b10eef171e3f0a85944"}, - {file = "ruff-0.3.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e1e0d4381ca88fb2b73ea0766008e703f33f460295de658f5467f6f229658c19"}, - {file = "ruff-0.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f7dbba46e2827dfcb0f0cc55fba8e96ba7c8700e0a866eb8cef7d1d66c25dcb"}, - {file = "ruff-0.3.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23dbb808e2f1d68eeadd5f655485e235c102ac6f12ad31505804edced2a5ae77"}, - {file = "ruff-0.3.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ef655c51f41d5fa879f98e40c90072b567c666a7114fa2d9fe004dffba00932"}, - {file = "ruff-0.3.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d0d3d7ef3d4f06433d592e5f7d813314a34601e6c5be8481cccb7fa760aa243e"}, - {file = "ruff-0.3.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b08b356d06a792e49a12074b62222f9d4ea2a11dca9da9f68163b28c71bf1dd4"}, - {file = "ruff-0.3.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9343690f95710f8cf251bee1013bf43030072b9f8d012fbed6ad702ef70d360a"}, - {file = "ruff-0.3.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1f3ed501a42f60f4dedb7805fa8d4534e78b4e196f536bac926f805f0743d49"}, - {file = "ruff-0.3.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:cc30a9053ff2f1ffb505a585797c23434d5f6c838bacfe206c0e6cf38c921a1e"}, - {file = "ruff-0.3.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5da894a29ec018a8293d3d17c797e73b374773943e8369cfc50495573d396933"}, - {file = "ruff-0.3.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:755c22536d7f1889be25f2baf6fedd019d0c51d079e8417d4441159f3bcd30c2"}, - {file = "ruff-0.3.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:dd73fe7f4c28d317855da6a7bc4aa29a1500320818dd8f27df95f70a01b8171f"}, - {file = "ruff-0.3.0-py3-none-win32.whl", hash = "sha256:19eacceb4c9406f6c41af806418a26fdb23120dfe53583df76d1401c92b7c14b"}, - {file = "ruff-0.3.0-py3-none-win_amd64.whl", hash = "sha256:128265876c1d703e5f5e5a4543bd8be47c73a9ba223fd3989d4aa87dd06f312f"}, - {file = "ruff-0.3.0-py3-none-win_arm64.whl", hash = "sha256:e3a4a6d46aef0a84b74fcd201a4401ea9a6cd85614f6a9435f2d33dd8cefbf83"}, - {file = "ruff-0.3.0.tar.gz", hash = "sha256:0886184ba2618d815067cf43e005388967b67ab9c80df52b32ec1152ab49f53a"}, -] - -[[package]] -name = "scipy" -version = "1.9.3" -description = "Fundamental algorithms for scientific computing in Python" -optional = true -python-versions = ">=3.8" -files = [ - {file = "scipy-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0"}, - {file = "scipy-1.9.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:83b89e9586c62e787f5012e8475fbb12185bafb996a03257e9675cd73d3736dd"}, - {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a72d885fa44247f92743fc20732ae55564ff2a519e8302fb7e18717c5355a8b"}, - {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d01e1dd7b15bd2449c8bfc6b7cc67d630700ed655654f0dfcf121600bad205c9"}, - {file = "scipy-1.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:68239b6aa6f9c593da8be1509a05cb7f9efe98b80f43a5861cd24c7557e98523"}, - {file = "scipy-1.9.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b41bc822679ad1c9a5f023bc93f6d0543129ca0f37c1ce294dd9d386f0a21096"}, - {file = "scipy-1.9.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:90453d2b93ea82a9f434e4e1cba043e779ff67b92f7a0e85d05d286a3625df3c"}, - {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83c06e62a390a9167da60bedd4575a14c1f58ca9dfde59830fc42e5197283dab"}, - {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abaf921531b5aeaafced90157db505e10345e45038c39e5d9b6c7922d68085cb"}, - {file = "scipy-1.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:06d2e1b4c491dc7d8eacea139a1b0b295f74e1a1a0f704c375028f8320d16e31"}, - {file = "scipy-1.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a04cd7d0d3eff6ea4719371cbc44df31411862b9646db617c99718ff68d4840"}, - {file = "scipy-1.9.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:545c83ffb518094d8c9d83cce216c0c32f8c04aaf28b92cc8283eda0685162d5"}, - {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d54222d7a3ba6022fdf5773931b5d7c56efe41ede7f7128c7b1637700409108"}, - {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cff3a5295234037e39500d35316a4c5794739433528310e117b8a9a0c76d20fc"}, - {file = "scipy-1.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:2318bef588acc7a574f5bfdff9c172d0b1bf2c8143d9582e05f878e580a3781e"}, - {file = "scipy-1.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d644a64e174c16cb4b2e41dfea6af722053e83d066da7343f333a54dae9bc31c"}, - {file = "scipy-1.9.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:da8245491d73ed0a994ed9c2e380fd058ce2fa8a18da204681f2fe1f57f98f95"}, - {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4db5b30849606a95dcf519763dd3ab6fe9bd91df49eba517359e450a7d80ce2e"}, - {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c68db6b290cbd4049012990d7fe71a2abd9ffbe82c0056ebe0f01df8be5436b0"}, - {file = "scipy-1.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:5b88e6d91ad9d59478fafe92a7c757d00c59e3bdc3331be8ada76a4f8d683f58"}, - {file = "scipy-1.9.3.tar.gz", hash = "sha256:fbc5c05c85c1a02be77b1ff591087c83bc44579c6d2bd9fb798bb64ea5e1a027"}, + {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:aef5bd3b89e657007e1be6b16553c8813b221ff6d92c7526b7e0227450981eac"}, + {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:89b1e92b3bd9fca249153a97d23f29bed3992cff414b222fcd361d763fc53f12"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e55771559c89272c3ebab23326dc23e7f813e492052391fe7950c1a5a139d89"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dabc62195bf54b8a7876add6e789caae0268f34582333cda340497c886111c39"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a05f3793ba25f194f395578579c546ca5d83e0195f992edc32e5907d142bfa3"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:dfd3504e881082959b4160ab02f7a205f0fadc0a9619cc481982b6837b2fd4c0"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87258e0d4b04046cf1d6cc1c56fadbf7a880cc3de1f7294938e923234cf9e498"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:712e71283fc7d9f95047ed5f793bc019b0b0a29849b14664a60fd66c23b96da1"}, + {file = "ruff-0.3.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a532a90b4a18d3f722c124c513ffb5e5eaff0cc4f6d3aa4bda38e691b8600c9f"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:122de171a147c76ada00f76df533b54676f6e321e61bd8656ae54be326c10296"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d80a6b18a6c3b6ed25b71b05eba183f37d9bc8b16ace9e3d700997f00b74660b"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a7b6e63194c68bca8e71f81de30cfa6f58ff70393cf45aab4c20f158227d5936"}, + {file = "ruff-0.3.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a759d33a20c72f2dfa54dae6e85e1225b8e302e8ac655773aff22e542a300985"}, + {file = "ruff-0.3.5-py3-none-win32.whl", hash = "sha256:9d8605aa990045517c911726d21293ef4baa64f87265896e491a05461cae078d"}, + {file = "ruff-0.3.5-py3-none-win_amd64.whl", hash = "sha256:dc56bb16a63c1303bd47563c60482a1512721053d93231cf7e9e1c6954395a0e"}, + {file = "ruff-0.3.5-py3-none-win_arm64.whl", hash = "sha256:faeeae9905446b975dcf6d4499dc93439b131f1443ee264055c5716dd947af55"}, + {file = "ruff-0.3.5.tar.gz", hash = "sha256:a067daaeb1dc2baf9b82a32dae67d154d95212080c80435eb052d95da647763d"}, ] -[package.dependencies] -numpy = ">=1.18.5,<1.26.0" - -[package.extras] -dev = ["flake8", "mypy", "pycodestyle", "typing_extensions"] -doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-panels (>=0.5.2)", "sphinx-tabs"] -test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] - [[package]] name = "sympy" version = "1.12" @@ -594,21 +561,21 @@ files = [ [[package]] name = "torch" -version = "2.2.1+cpu" +version = "2.2.2+cpu" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.2.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:5d82422cf04797f1b2a8574b64a916070ec83eef58ad4900615ee0218d7b8b8e"}, - {file = "torch-2.2.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:f8914dd0f5f0e5c66fdecd9559403eea9feac82d1ea639b672fde0073c6addbd"}, - {file = "torch-2.2.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:6bc973d5632374b92b4b293817b4d2ff8c8ce1c784c748b471dba1fffcd9c333"}, - {file = "torch-2.2.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:abdec34b0ade8fca0520055e72c3094425ae0ef210718e9c0278121cd3608c32"}, - {file = "torch-2.2.1+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:d7339580135da4105c1244a8621faa076990409afeab5a7b642c3c1ee70a5622"}, - {file = "torch-2.2.1+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:039128fcb5548122465b15f679b8831c47d14f0d6c28c1f1b631f8019c104720"}, - {file = "torch-2.2.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:2b447f7bb50b393b4544b4036d587e39ab524d4353e77c197f6a2727f22b0d47"}, - {file = "torch-2.2.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:2ccdf3e5f71e6426ea9e34d21c3cc333b29d4f48299b981d28aeb5112b5495e1"}, - {file = "torch-2.2.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:2fb340b289760040a16a77a6d70b8a48961abba1822e6f58705c97c80befa03e"}, - {file = "torch-2.2.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:e03dc4654ecceeb5b03f0a6f60b342c0e0d267b3ebc61e4f672cace1df8cd930"}, + {file = "torch-2.2.2+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:02c4fac3c964e73f5f49003e0060c697f73b67c10cc23f51c592facb29e1bd53"}, + {file = "torch-2.2.2+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:fc29dda2795dd7220d769c5926b1c50ddac9b4827897e30a10467063691cdf54"}, + {file = "torch-2.2.2+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:90089cae572672fb449c8ff1dc1b29daaffa117bf97ede7463dcd2fd1b991e4c"}, + {file = "torch-2.2.2+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:88e63c916e3275fa30a220ee736423a95573b96072ded85e5c0171fd8f37a755"}, + {file = "torch-2.2.2+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:431a747b5a880cf8e1fb6d58db6bfafa6768cbec76517d046854537c03323edf"}, + {file = "torch-2.2.2+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:2b0cf041f878607a361116945f82ce2dba4b7a747151da7619a63cb5fccb72df"}, + {file = "torch-2.2.2+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:8914ce932168e572a09b4a7e5b0806d279f771dfe58d7e1d8de2291fac4ce69b"}, + {file = "torch-2.2.2+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:4ef2911ffde6d86f643c23aa99f25f1a1df8bee93bf8d0c69cf1b9ba0ca521dc"}, + {file = "torch-2.2.2+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:6e3d323a21df22415770e88d39e13591079b9356dabb8b394d1ee29ac6c92481"}, + {file = "torch-2.2.2+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:c2c9e7d5e3c7d58e4b78d6aebfa8002af7cda16cde08d0e3ed00300dc21a8efc"}, ] [package.dependencies] @@ -630,13 +597,13 @@ reference = "torch" [[package]] name = "typing-extensions" -version = "4.10.0" +version = "4.11.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, - {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, + {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"}, + {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, ] [extras] @@ -645,4 +612,4 @@ bitsandbytes = ["bitsandbytes"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0.0" -content-hash = "d7d8dc0d66f37e2f167551d9be0775c778f251db9faeef427a5d66c2a2396d38" +content-hash = "d51586f8352db14a18dd407b19285c9649564b029e6e6aae52a0d566515e5c81" diff --git a/pyproject.toml b/pyproject.toml index 57eadffc3..646597102 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,11 +13,11 @@ keywords = [ "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad", "DAdaptAdam", - "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO", - "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", - "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", - "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", - "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", + "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS", "Lamb", "Lion", + "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", + "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", + "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", + "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", ] classifiers = [ "License :: OSI Approved :: Apache Software License", @@ -45,14 +45,14 @@ classifiers = [ python = ">=3.8,<4.0.0" numpy = { version = "*", python = ">=3.8" } torch = { version = ">=1.10", python = ">=3.8", source = "torch" } -bitsandbytes = { version = "^0.42", optional = true } +bitsandbytes = { version = "^0.43", optional = true } [tool.poetry.dev-dependencies] isort = { version = "^5", python = ">=3.8" } black = { version = "^24", python = ">=3.8"} -ruff = "^0.3" -pytest = "^8" -pytest-cov = "^4" +ruff = "*" +pytest = "*" +pytest-cov = "*" [tool.poetry.extras] bitsandbytes = ["bitsandbytes"] diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 8ccdb7c36..f0f1c24ef 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -55,6 +55,7 @@ from pytorch_optimizer.optimizer.diffgrad import DiffGrad from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer from pytorch_optimizer.optimizer.fromage import Fromage +from pytorch_optimizer.optimizer.galore import GaLore, GaLoreProjector from pytorch_optimizer.optimizer.gc import centralize_gradient from pytorch_optimizer.optimizer.gravity import Gravity from pytorch_optimizer.optimizer.lamb import Lamb @@ -182,6 +183,7 @@ CAME, DAdaptLion, Aida, + GaLore, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/pytorch_optimizer/optimizer/galore.py b/pytorch_optimizer/optimizer/galore.py new file mode 100644 index 000000000..ced850c2d --- /dev/null +++ b/pytorch_optimizer/optimizer/galore.py @@ -0,0 +1,249 @@ +import math +from typing import Literal, Optional, Tuple, Union + +import torch +from torch.optim.optimizer import Optimizer + +from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS + +PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full'] + + +class GaLoreProjector: + r"""Memory-Efficient LLM Training by Gradient Low-Rank Projection. + + :param rank: int. low rank to project. + :param update_proj_gap: int. num steps to update the projection. + :param scale: float. scale factor. + :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are + supported. + """ + + def __init__( + self, rank: int = 128, update_proj_gap: int = 50, scale: float = 1.0, projection_type: PROJECTION_TYPE = 'std' + ): + self.rank = rank + self.update_proj_gap = update_proj_gap + self.scale = scale + self.projection_type = projection_type + + self.ortho_matrix: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None + + @staticmethod + def get_orthogonal_matrix( + weights: torch.Tensor, rank: int, projection_type: str + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if projection_type not in {'right', 'left', 'full'}: + raise ValueError('projection_type should be one of left, right or full') + + original_type = weights.data.dtype + original_device = weights.data.device + is_float: bool = original_type == torch.float + + u, s, vh = torch.linalg.svd(weights if is_float else weights.float(), full_matrices=False) + + if projection_type == 'right': + b = vh[:rank, :] + return b if is_float else b.to(original_device).type(original_type) + if projection_type == 'left': + a = u[:, :rank] + return a if is_float else a.to(original_device).type(original_type) + + a = u[:, :rank] + b = vh[:rank, :] + + return ( + (a, b) + if is_float + else (a.to(original_device).type(original_type), b.to(original_device).type(original_type)) + ) + + def get_low_rank_grad_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor: + if grad.shape[0] >= grad.shape[1]: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') + return torch.matmul(grad, self.ortho_matrix.t()) + + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') + + return torch.matmul(self.ortho_matrix.t(), grad) + + def get_low_rank_grad_reverse_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor: + if grad.shape[0] >= grad.shape[1]: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') + return torch.matmul(self.ortho_matrix.t(), grad) + + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') + + return torch.matmul(grad, self.ortho_matrix.t()) + + def get_low_rank_grad_right(self, grad: torch.Tensor, steps: int) -> torch.Tensor: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') + return torch.matmul(grad, self.ortho_matrix.t()) + + def get_low_rank_grad_left(self, grad: torch.Tensor, steps: int) -> torch.Tensor: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') + return torch.matmul(self.ortho_matrix.t(), grad) + + def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full') + return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t() + + def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor: + if self.projection_type == 'std': + return self.get_low_rank_grad_std(full_rank_grad, steps) + if self.projection_type == 'reverse_std': + return self.get_low_rank_grad_reverse_std(full_rank_grad, steps) + if self.projection_type == 'right': + return self.get_low_rank_grad_right(full_rank_grad, steps) + if self.projection_type == 'left': + return self.get_low_rank_grad_left(full_rank_grad, steps) + if self.projection_type == 'full': + return self.get_low_rank_grad_full(full_rank_grad, steps) + raise NotImplementedError + + def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: + if self.projection_type == 'std': + return ( + torch.matmul(low_rank_grad, self.ortho_matrix) + if low_rank_grad.shape[0] >= low_rank_grad.shape[1] + else torch.matmul(self.ortho_matrix, low_rank_grad) + ) * self.scale + if self.projection_type == 'reverse_std': + return ( + torch.matmul(self.ortho_matrix, low_rank_grad.t()) + if low_rank_grad.shape[0] <= low_rank_grad.shape[1] + else torch.matmul(low_rank_grad, self.ortho_matrix.t()) + ) * self.scale + if self.projection_type == 'right': + return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale + if self.projection_type == 'left': + return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale + if self.projection_type == 'full': + return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale + + raise NotImplementedError + + +class GaLore(Optimizer, BaseOptimizer): + r"""AdamW optimizer with GaLore projector. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. + :param weight_decay: float. weight decay (L2 penalty). + :param eps: float. term added to the denominator to improve numerical stability. + """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 1e-3, + betas: BETAS = (0.9, 0.999), + weight_decay: float = 0.0, + eps: float = 1e-6, + **kwargs, + ): + self.validate_learning_rate(lr) + self.validate_betas(betas) + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_non_negative(eps, 'eps') + + defaults: DEFAULTS = { + 'lr': lr, + 'betas': betas, + 'weight_decay': weight_decay, + 'eps': eps, + **kwargs, + } + + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'GaLore' + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + beta1, beta2 = group['betas'] + + bias_correction1: float = 1.0 - beta1 ** group['step'] + bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + + step_size: float = group['lr'] * bias_correction2_sq / bias_correction1 + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + state = self.state[p] + + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + if 'rank' in group and p.dim() > 1: + if 'projector' not in state: + state['projector'] = GaLoreProjector( + rank=group['rank'], + update_proj_gap=group['update_proj_gap'], + scale=group['scale'], + projection_type=group['projection_type'], + ) + + grad = state['projector'].project(grad, group['step']) + + self.apply_weight_decay( + p=p, + grad=None, + lr=group['lr'], + weight_decay=group['weight_decay'], + weight_decouple=True, + fixed_decay=False, + ) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + de_nom = exp_avg_sq.sqrt().add_(group['eps']) + + norm_grad = exp_avg / de_nom + + if 'rank' in group and p.dim() > 1: + norm_grad = state['projector'].project_back(norm_grad) + + p.add_(norm_grad, alpha=-step_size) + + return loss diff --git a/requirements-dev.txt b/requirements-dev.txt index c6157e8fc..6ed086a44 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,12 +1,12 @@ --extra-index-url https://download.pytorch.org/whl/cpu -black==24.2.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +black==24.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" click==8.1.7 ; python_version >= "3.8" and python_full_version < "4.0.0" colorama==0.4.6 ; python_version >= "3.8" and python_full_version < "4.0.0" and (sys_platform == "win32" or platform_system == "Windows") -coverage[toml]==7.4.3 ; python_version >= "3.8" and python_full_version < "4.0.0" +coverage[toml]==7.4.4 ; python_version >= "3.8" and python_full_version < "4.0.0" exceptiongroup==1.2.0 ; python_version >= "3.8" and python_version < "3.11" -filelock==3.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -fsspec==2024.2.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +filelock==3.13.3 ; python_version >= "3.8" and python_full_version < "4.0.0" +fsspec==2024.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" iniconfig==2.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" isort==5.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0" jinja2==3.1.3 ; python_version >= "3.8" and python_full_version < "4.0.0" @@ -15,14 +15,14 @@ mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" mypy-extensions==1.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0" -packaging==23.2 ; python_version >= "3.8" and python_full_version < "4.0.0" +packaging==24.0 ; python_version >= "3.8" and python_full_version < "4.0.0" pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0" platformdirs==4.2.0 ; python_version >= "3.8" and python_full_version < "4.0.0" pluggy==1.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0" -pytest-cov==4.1.0 ; python_version >= "3.8" and python_full_version < "4.0.0" -pytest==8.0.2 ; python_version >= "3.8" and python_full_version < "4.0.0" -ruff==0.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +pytest==8.1.1 ; python_version >= "3.8" and python_full_version < "4.0.0" +ruff==0.3.5 ; python_version >= "3.8" and python_full_version < "4.0.0" sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0" tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6" -torch==2.2.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" -typing-extensions==4.10.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +torch==2.2.2+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" +typing-extensions==4.11.0 ; python_version >= "3.8" and python_full_version < "4.0.0" diff --git a/requirements.txt b/requirements.txt index b6e6a1df3..fe27e8f9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ --extra-index-url https://download.pytorch.org/whl/cpu -filelock==3.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0" -fsspec==2024.2.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +filelock==3.13.3 ; python_version >= "3.8" and python_full_version < "4.0.0" +fsspec==2024.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" jinja2==3.1.3 ; python_version >= "3.8" and python_full_version < "4.0.0" markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0" mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0" networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0" numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0" sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0" -torch==2.2.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" -typing-extensions==4.10.0 ; python_version >= "3.8" and python_full_version < "4.0.0" +torch==2.2.2+cpu ; python_version >= "3.8" and python_full_version < "4.0.0" +typing-extensions==4.11.0 ; python_version >= "3.8" and python_full_version < "4.0.0" diff --git a/tests/constants.py b/tests/constants.py index 926dfb9b0..f7ffc4b46 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -45,6 +45,7 @@ DAdaptSGD, DiffGrad, Fromage, + GaLore, Gravity, Lamb, Lion, @@ -401,6 +402,38 @@ (CAME, {'lr': 7.5e-1, 'weight_decay': 1e-3}, 75), (CAME, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 75), (Aida, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 5), + ( + GaLore, + {'lr': 1e0, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 1, 'projection_type': 'std'}, + 5, + ), + ( + GaLore, + { + 'lr': 1e0, + 'weight_decay': 1e-3, + 'rank': 2, + 'scale': 1.0, + 'update_proj_gap': 1, + 'projection_type': 'reverse_std', + }, + 5, + ), + ( + GaLore, + {'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'left'}, + 5, + ), + ( + GaLore, + {'lr': 1e0, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 1, 'projection_type': 'right'}, + 5, + ), + ( + GaLore, + {'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'}, + 5, + ), ] ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10), diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index 57b4fd8d9..b7812180c 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 61 + assert len(get_supported_optimizers()) == 62 def test_get_supported_lr_schedulers(): diff --git a/tests/test_optimizer_parameters.py b/tests/test_optimizer_parameters.py index 26d7ff59c..433b3d988 100644 --- a/tests/test_optimizer_parameters.py +++ b/tests/test_optimizer_parameters.py @@ -2,7 +2,16 @@ import torch from torch import nn -from pytorch_optimizer import SAM, WSAM, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer +from pytorch_optimizer import ( + SAM, + WSAM, + GaLoreProjector, + Lookahead, + PCGrad, + Ranger21, + SafeFP16Optimizer, + load_optimizer, +) from tests.constants import PULLBACK_MOMENTUM from tests.utils import Example, simple_parameter, simple_zero_rank_parameter @@ -254,3 +263,16 @@ def test_ranger_parameters(): # test lookahead step `k` with pytest.raises(ValueError): opt(None, k=-1) + + +def test_galore_projection_type(): + p = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + + with pytest.raises(NotImplementedError): + GaLoreProjector(projection_type='invalid').project(p, 1) + + with pytest.raises(NotImplementedError): + GaLoreProjector(projection_type='invalid').project_back(p) + + with pytest.raises(ValueError): + GaLoreProjector.get_orthogonal_matrix(p, 1, projection_type='std')