diff --git a/README.md b/README.md
index 0633c3746..dbfc1a1d2 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
**pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **62 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
+Currently, **63 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
@@ -160,6 +160,7 @@ supported_optimizers = get_supported_optimizers()
| CAME | *Confidence-guided Adaptive Memory Efficient Optimization* | [github](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME) | | [cite](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME#citation) |
| WSAM | *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* | [github](https://github.com/intelligent-machine-learning/dlrover/blob/master/atorch/atorch/optimizers/wsam.py) | | [cite](https://github.com/intelligent-machine-learning/dlrover) |
| Aida | *A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range* | [github](https://github.com/guoqiang-zhang-x/Aida-Optimizer) | | [cite](https://github.com/guoqiang-zhang-x/Aida-Optimizer?tab=readme-ov-file#1-brief-description-of-aida) |
+| GaLore | *Memory-Efficient LLM Training by Gradient Low-Rank Projection* | [github](https://github.com/jiaweizzhao/GaLore) | | [cite](https://github.com/jiaweizzhao/GaLore/tree/master?tab=readme-ov-file#citation) |
## Supported LR Scheduler
diff --git a/docs/changelogs/v3.0.0.md b/docs/changelogs/v3.0.0.md
index 7d4980be1..0de82f3c9 100644
--- a/docs/changelogs/v3.0.0.md
+++ b/docs/changelogs/v3.0.0.md
@@ -10,21 +10,28 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
* [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273)
* Implement `WSAM` optimizer. (#213, #216)
* [Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term](https://arxiv.org/abs/2305.15817)
+* Implement `GaLore` optimizer. (#224, #228)
+ * [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
-## Dependency
+### Fix
+
+* Fix SRMM to allow operation beyond memory_length. (#227)
+
+### Dependency
* Drop `Python 3.7` support officially. (#221)
* Please check the [README](https://github.com/kozistr/pytorch_optimizer?tab=readme-ov-file#getting-started).
+* Update `bitsandbytes` to `0.43.0`. (#228)
-## Docs
+### Docs
* Add missing parameters in `Ranger21 optimizer` document. (#214, #215)
* Fix `WSAM` optimizer paper link. (#219)
-### Contributions
+## Contributions
-thanks to @sdbds
+thanks to @sdbds, @i404788
-### Diff
+## Diff
[2.12.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.12.0...v3.0.0)
diff --git a/docs/optimizer.md b/docs/optimizer.md
index 776ea8b70..14d326b4e 100644
--- a/docs/optimizer.md
+++ b/docs/optimizer.md
@@ -132,6 +132,10 @@
:docstring:
:members:
+::: pytorch_optimizer.GaLoreProjector
+ :docstring:
+ :members:
+
::: pytorch_optimizer.centralize_gradient
:docstring:
:members:
diff --git a/poetry.lock b/poetry.lock
index f40e104c6..f8cc69312 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2,47 +2,52 @@
[[package]]
name = "bitsandbytes"
-version = "0.42.0"
+version = "0.43.0"
description = "k-bit optimizers and matrix multiplication routines."
optional = true
python-versions = "*"
files = [
- {file = "bitsandbytes-0.42.0-py3-none-any.whl", hash = "sha256:63798680912cc63bb77b535a2d0860af024e290a52e157f777ad2a52e2585967"},
- {file = "bitsandbytes-0.42.0.tar.gz", hash = "sha256:fc1505f184f0d275766f2a6c663f1a43b734c1409b5c5a406f3a6073d9f329fd"},
+ {file = "bitsandbytes-0.43.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:b2626ada0ae447ae0cf3dd0be8f5b0abad7abdec7056c7fb738aa13a5a862007"},
+ {file = "bitsandbytes-0.43.0-py3-none-win_amd64.whl", hash = "sha256:6fa7f3255fe9f3e549fb110bc60794079761a4e608b5fb86ebe7b4047467dd99"},
]
[package.dependencies]
-scipy = "*"
+numpy = "*"
+torch = "*"
+
+[package.extras]
+benchmark = ["matplotlib", "pandas"]
+test = ["scipy"]
[[package]]
name = "black"
-version = "24.2.0"
+version = "24.3.0"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.8"
files = [
- {file = "black-24.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6981eae48b3b33399c8757036c7f5d48a535b962a7c2310d19361edeef64ce29"},
- {file = "black-24.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d533d5e3259720fdbc1b37444491b024003e012c5173f7d06825a77508085430"},
- {file = "black-24.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61a0391772490ddfb8a693c067df1ef5227257e72b0e4108482b8d41b5aee13f"},
- {file = "black-24.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:992e451b04667116680cb88f63449267c13e1ad134f30087dec8527242e9862a"},
- {file = "black-24.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:163baf4ef40e6897a2a9b83890e59141cc8c2a98f2dda5080dc15c00ee1e62cd"},
- {file = "black-24.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e37c99f89929af50ffaf912454b3e3b47fd64109659026b678c091a4cd450fb2"},
- {file = "black-24.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9de21bafcba9683853f6c96c2d515e364aee631b178eaa5145fc1c61a3cc92"},
- {file = "black-24.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:9db528bccb9e8e20c08e716b3b09c6bdd64da0dd129b11e160bf082d4642ac23"},
- {file = "black-24.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d84f29eb3ee44859052073b7636533ec995bd0f64e2fb43aeceefc70090e752b"},
- {file = "black-24.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e08fb9a15c914b81dd734ddd7fb10513016e5ce7e6704bdd5e1251ceee51ac9"},
- {file = "black-24.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:810d445ae6069ce64030c78ff6127cd9cd178a9ac3361435708b907d8a04c693"},
- {file = "black-24.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:ba15742a13de85e9b8f3239c8f807723991fbfae24bad92d34a2b12e81904982"},
- {file = "black-24.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7e53a8c630f71db01b28cd9602a1ada68c937cbf2c333e6ed041390d6968faf4"},
- {file = "black-24.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93601c2deb321b4bad8f95df408e3fb3943d85012dddb6121336b8e24a0d1218"},
- {file = "black-24.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0057f800de6acc4407fe75bb147b0c2b5cbb7c3ed110d3e5999cd01184d53b0"},
- {file = "black-24.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:faf2ee02e6612577ba0181f4347bcbcf591eb122f7841ae5ba233d12c39dcb4d"},
- {file = "black-24.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:057c3dc602eaa6fdc451069bd027a1b2635028b575a6c3acfd63193ced20d9c8"},
- {file = "black-24.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:08654d0797e65f2423f850fc8e16a0ce50925f9337fb4a4a176a7aa4026e63f8"},
- {file = "black-24.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca610d29415ee1a30a3f30fab7a8f4144e9d34c89a235d81292a1edb2b55f540"},
- {file = "black-24.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:4dd76e9468d5536abd40ffbc7a247f83b2324f0c050556d9c371c2b9a9a95e31"},
- {file = "black-24.2.0-py3-none-any.whl", hash = "sha256:e8a6ae970537e67830776488bca52000eaa37fa63b9988e8c487458d9cd5ace6"},
- {file = "black-24.2.0.tar.gz", hash = "sha256:bce4f25c27c3435e4dace4815bcb2008b87e167e3bf4ee47ccdc5ce906eb4894"},
+ {file = "black-24.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395"},
+ {file = "black-24.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995"},
+ {file = "black-24.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7"},
+ {file = "black-24.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0"},
+ {file = "black-24.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9"},
+ {file = "black-24.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597"},
+ {file = "black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d"},
+ {file = "black-24.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5"},
+ {file = "black-24.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f"},
+ {file = "black-24.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11"},
+ {file = "black-24.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4"},
+ {file = "black-24.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5"},
+ {file = "black-24.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:79dcf34b33e38ed1b17434693763301d7ccbd1c5860674a8f871bd15139e7837"},
+ {file = "black-24.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e19cb1c6365fd6dc38a6eae2dcb691d7d83935c10215aef8e6c38edee3f77abd"},
+ {file = "black-24.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b76c275e4c1c5ce6e9870911384bff5ca31ab63d19c76811cb1fb162678213"},
+ {file = "black-24.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:b5991d523eee14756f3c8d5df5231550ae8993e2286b8014e2fdea7156ed0959"},
+ {file = "black-24.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c45f8dff244b3c431b36e3224b6be4a127c6aca780853574c00faf99258041eb"},
+ {file = "black-24.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6905238a754ceb7788a73f02b45637d820b2f5478b20fec82ea865e4f5d4d9f7"},
+ {file = "black-24.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7de8d330763c66663661a1ffd432274a2f92f07feeddd89ffd085b5744f85e7"},
+ {file = "black-24.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:7bb041dca0d784697af4646d3b62ba4a6b028276ae878e53f6b4f74ddd6db99f"},
+ {file = "black-24.3.0-py3-none-any.whl", hash = "sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93"},
+ {file = "black-24.3.0.tar.gz", hash = "sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f"},
]
[package.dependencies]
@@ -87,63 +92,63 @@ files = [
[[package]]
name = "coverage"
-version = "7.4.3"
+version = "7.4.4"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "coverage-7.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8580b827d4746d47294c0e0b92854c85a92c2227927433998f0d3320ae8a71b6"},
- {file = "coverage-7.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:718187eeb9849fc6cc23e0d9b092bc2348821c5e1a901c9f8975df0bc785bfd4"},
- {file = "coverage-7.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:767b35c3a246bcb55b8044fd3a43b8cd553dd1f9f2c1eeb87a302b1f8daa0524"},
- {file = "coverage-7.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae7f19afe0cce50039e2c782bff379c7e347cba335429678450b8fe81c4ef96d"},
- {file = "coverage-7.4.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba3a8aaed13770e970b3df46980cb068d1c24af1a1968b7818b69af8c4347efb"},
- {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ee866acc0861caebb4f2ab79f0b94dbfbdbfadc19f82e6e9c93930f74e11d7a0"},
- {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:506edb1dd49e13a2d4cac6a5173317b82a23c9d6e8df63efb4f0380de0fbccbc"},
- {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd6545d97c98a192c5ac995d21c894b581f1fd14cf389be90724d21808b657e2"},
- {file = "coverage-7.4.3-cp310-cp310-win32.whl", hash = "sha256:f6a09b360d67e589236a44f0c39218a8efba2593b6abdccc300a8862cffc2f94"},
- {file = "coverage-7.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:18d90523ce7553dd0b7e23cbb28865db23cddfd683a38fb224115f7826de78d0"},
- {file = "coverage-7.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cbbe5e739d45a52f3200a771c6d2c7acf89eb2524890a4a3aa1a7fa0695d2a47"},
- {file = "coverage-7.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:489763b2d037b164846ebac0cbd368b8a4ca56385c4090807ff9fad817de4113"},
- {file = "coverage-7.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:451f433ad901b3bb00184d83fd83d135fb682d780b38af7944c9faeecb1e0bfe"},
- {file = "coverage-7.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fcc66e222cf4c719fe7722a403888b1f5e1682d1679bd780e2b26c18bb648cdc"},
- {file = "coverage-7.4.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3ec74cfef2d985e145baae90d9b1b32f85e1741b04cd967aaf9cfa84c1334f3"},
- {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:abbbd8093c5229c72d4c2926afaee0e6e3140de69d5dcd918b2921f2f0c8baba"},
- {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:35eb581efdacf7b7422af677b92170da4ef34500467381e805944a3201df2079"},
- {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8249b1c7334be8f8c3abcaaa996e1e4927b0e5a23b65f5bf6cfe3180d8ca7840"},
- {file = "coverage-7.4.3-cp311-cp311-win32.whl", hash = "sha256:cf30900aa1ba595312ae41978b95e256e419d8a823af79ce670835409fc02ad3"},
- {file = "coverage-7.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:18c7320695c949de11a351742ee001849912fd57e62a706d83dfc1581897fa2e"},
- {file = "coverage-7.4.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b51bfc348925e92a9bd9b2e48dad13431b57011fd1038f08316e6bf1df107d10"},
- {file = "coverage-7.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d6cdecaedea1ea9e033d8adf6a0ab11107b49571bbb9737175444cea6eb72328"},
- {file = "coverage-7.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b2eccb883368f9e972e216c7b4c7c06cabda925b5f06dde0650281cb7666a30"},
- {file = "coverage-7.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c00cdc8fa4e50e1cc1f941a7f2e3e0f26cb2a1233c9696f26963ff58445bac7"},
- {file = "coverage-7.4.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9a4a8dd3dcf4cbd3165737358e4d7dfbd9d59902ad11e3b15eebb6393b0446e"},
- {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:062b0a75d9261e2f9c6d071753f7eef0fc9caf3a2c82d36d76667ba7b6470003"},
- {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:ebe7c9e67a2d15fa97b77ea6571ce5e1e1f6b0db71d1d5e96f8d2bf134303c1d"},
- {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c0a120238dd71c68484f02562f6d446d736adcc6ca0993712289b102705a9a3a"},
- {file = "coverage-7.4.3-cp312-cp312-win32.whl", hash = "sha256:37389611ba54fd6d278fde86eb2c013c8e50232e38f5c68235d09d0a3f8aa352"},
- {file = "coverage-7.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:d25b937a5d9ffa857d41be042b4238dd61db888533b53bc76dc082cb5a15e914"},
- {file = "coverage-7.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:28ca2098939eabab044ad68850aac8f8db6bf0b29bc7f2887d05889b17346454"},
- {file = "coverage-7.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:280459f0a03cecbe8800786cdc23067a8fc64c0bd51dc614008d9c36e1659d7e"},
- {file = "coverage-7.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0cdedd3500e0511eac1517bf560149764b7d8e65cb800d8bf1c63ebf39edd2"},
- {file = "coverage-7.4.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a9babb9466fe1da12417a4aed923e90124a534736de6201794a3aea9d98484e"},
- {file = "coverage-7.4.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dec9de46a33cf2dd87a5254af095a409ea3bf952d85ad339751e7de6d962cde6"},
- {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:16bae383a9cc5abab9bb05c10a3e5a52e0a788325dc9ba8499e821885928968c"},
- {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2c854ce44e1ee31bda4e318af1dbcfc929026d12c5ed030095ad98197eeeaed0"},
- {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ce8c50520f57ec57aa21a63ea4f325c7b657386b3f02ccaedeccf9ebe27686e1"},
- {file = "coverage-7.4.3-cp38-cp38-win32.whl", hash = "sha256:708a3369dcf055c00ddeeaa2b20f0dd1ce664eeabde6623e516c5228b753654f"},
- {file = "coverage-7.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:1bf25fbca0c8d121a3e92a2a0555c7e5bc981aee5c3fdaf4bb7809f410f696b9"},
- {file = "coverage-7.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b253094dbe1b431d3a4ac2f053b6d7ede2664ac559705a704f621742e034f1f"},
- {file = "coverage-7.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77fbfc5720cceac9c200054b9fab50cb2a7d79660609200ab83f5db96162d20c"},
- {file = "coverage-7.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6679060424faa9c11808598504c3ab472de4531c571ab2befa32f4971835788e"},
- {file = "coverage-7.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4af154d617c875b52651dd8dd17a31270c495082f3d55f6128e7629658d63765"},
- {file = "coverage-7.4.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8640f1fde5e1b8e3439fe482cdc2b0bb6c329f4bb161927c28d2e8879c6029ee"},
- {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:69b9f6f66c0af29642e73a520b6fed25ff9fd69a25975ebe6acb297234eda501"},
- {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0842571634f39016a6c03e9d4aba502be652a6e4455fadb73cd3a3a49173e38f"},
- {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a78ed23b08e8ab524551f52953a8a05d61c3a760781762aac49f8de6eede8c45"},
- {file = "coverage-7.4.3-cp39-cp39-win32.whl", hash = "sha256:c0524de3ff096e15fcbfe8f056fdb4ea0bf497d584454f344d59fce069d3e6e9"},
- {file = "coverage-7.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:0209a6369ccce576b43bb227dc8322d8ef9e323d089c6f3f26a597b09cb4d2aa"},
- {file = "coverage-7.4.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:7cbde573904625509a3f37b6fecea974e363460b556a627c60dc2f47e2fffa51"},
- {file = "coverage-7.4.3.tar.gz", hash = "sha256:276f6077a5c61447a48d133ed13e759c09e62aff0dc84274a68dc18660104d52"},
+ {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
+ {file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
+ {file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
+ {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
+ {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
+ {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
+ {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
+ {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
+ {file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
+ {file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
+ {file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
+ {file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
+ {file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
+ {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
+ {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
+ {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
+ {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
+ {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
+ {file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
+ {file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
+ {file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
+ {file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
+ {file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
+ {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
+ {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
+ {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
+ {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
+ {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
+ {file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
+ {file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
+ {file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
+ {file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
+ {file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
+ {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
+ {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
+ {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
+ {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
+ {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
+ {file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
+ {file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
+ {file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
+ {file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
+ {file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
+ {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
+ {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
+ {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
+ {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
+ {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
+ {file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
+ {file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
+ {file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
+ {file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
]
[package.dependencies]
@@ -168,29 +173,29 @@ test = ["pytest (>=6)"]
[[package]]
name = "filelock"
-version = "3.13.1"
+version = "3.13.3"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.8"
files = [
- {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"},
- {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"},
+ {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"},
+ {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"},
]
[package.extras]
-docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"]
-testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
+docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
typing = ["typing-extensions (>=4.8)"]
[[package]]
name = "fsspec"
-version = "2024.2.0"
+version = "2024.3.1"
description = "File-system specification"
optional = false
python-versions = ">=3.8"
files = [
- {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"},
- {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"},
+ {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
+ {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
]
[package.extras]
@@ -413,13 +418,13 @@ files = [
[[package]]
name = "packaging"
-version = "23.2"
+version = "24.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.7"
files = [
- {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
- {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
+ {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"},
+ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"},
]
[[package]]
@@ -465,13 +470,13 @@ testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pytest"
-version = "8.0.2"
+version = "8.1.1"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "pytest-8.0.2-py3-none-any.whl", hash = "sha256:edfaaef32ce5172d5466b5127b42e0d6d35ebbe4453f0e3505d96afd93f6b096"},
- {file = "pytest-8.0.2.tar.gz", hash = "sha256:d4051d623a2e0b7e51960ba963193b09ce6daeb9759a451844a21e4ddedfc1bd"},
+ {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"},
+ {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"},
]
[package.dependencies]
@@ -479,21 +484,21 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
-pluggy = ">=1.3.0,<2.0"
-tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
+pluggy = ">=1.4,<2.0"
+tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
-testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-cov"
-version = "4.1.0"
+version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
- {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
+ {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
+ {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[package.dependencies]
@@ -501,72 +506,34 @@ coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
-testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
+testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "ruff"
-version = "0.3.0"
+version = "0.3.5"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
- {file = "ruff-0.3.0-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7deb528029bacf845bdbb3dbb2927d8ef9b4356a5e731b10eef171e3f0a85944"},
- {file = "ruff-0.3.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e1e0d4381ca88fb2b73ea0766008e703f33f460295de658f5467f6f229658c19"},
- {file = "ruff-0.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f7dbba46e2827dfcb0f0cc55fba8e96ba7c8700e0a866eb8cef7d1d66c25dcb"},
- {file = "ruff-0.3.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23dbb808e2f1d68eeadd5f655485e235c102ac6f12ad31505804edced2a5ae77"},
- {file = "ruff-0.3.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ef655c51f41d5fa879f98e40c90072b567c666a7114fa2d9fe004dffba00932"},
- {file = "ruff-0.3.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d0d3d7ef3d4f06433d592e5f7d813314a34601e6c5be8481cccb7fa760aa243e"},
- {file = "ruff-0.3.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b08b356d06a792e49a12074b62222f9d4ea2a11dca9da9f68163b28c71bf1dd4"},
- {file = "ruff-0.3.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9343690f95710f8cf251bee1013bf43030072b9f8d012fbed6ad702ef70d360a"},
- {file = "ruff-0.3.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1f3ed501a42f60f4dedb7805fa8d4534e78b4e196f536bac926f805f0743d49"},
- {file = "ruff-0.3.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:cc30a9053ff2f1ffb505a585797c23434d5f6c838bacfe206c0e6cf38c921a1e"},
- {file = "ruff-0.3.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5da894a29ec018a8293d3d17c797e73b374773943e8369cfc50495573d396933"},
- {file = "ruff-0.3.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:755c22536d7f1889be25f2baf6fedd019d0c51d079e8417d4441159f3bcd30c2"},
- {file = "ruff-0.3.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:dd73fe7f4c28d317855da6a7bc4aa29a1500320818dd8f27df95f70a01b8171f"},
- {file = "ruff-0.3.0-py3-none-win32.whl", hash = "sha256:19eacceb4c9406f6c41af806418a26fdb23120dfe53583df76d1401c92b7c14b"},
- {file = "ruff-0.3.0-py3-none-win_amd64.whl", hash = "sha256:128265876c1d703e5f5e5a4543bd8be47c73a9ba223fd3989d4aa87dd06f312f"},
- {file = "ruff-0.3.0-py3-none-win_arm64.whl", hash = "sha256:e3a4a6d46aef0a84b74fcd201a4401ea9a6cd85614f6a9435f2d33dd8cefbf83"},
- {file = "ruff-0.3.0.tar.gz", hash = "sha256:0886184ba2618d815067cf43e005388967b67ab9c80df52b32ec1152ab49f53a"},
-]
-
-[[package]]
-name = "scipy"
-version = "1.9.3"
-description = "Fundamental algorithms for scientific computing in Python"
-optional = true
-python-versions = ">=3.8"
-files = [
- {file = "scipy-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0"},
- {file = "scipy-1.9.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:83b89e9586c62e787f5012e8475fbb12185bafb996a03257e9675cd73d3736dd"},
- {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a72d885fa44247f92743fc20732ae55564ff2a519e8302fb7e18717c5355a8b"},
- {file = "scipy-1.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d01e1dd7b15bd2449c8bfc6b7cc67d630700ed655654f0dfcf121600bad205c9"},
- {file = "scipy-1.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:68239b6aa6f9c593da8be1509a05cb7f9efe98b80f43a5861cd24c7557e98523"},
- {file = "scipy-1.9.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b41bc822679ad1c9a5f023bc93f6d0543129ca0f37c1ce294dd9d386f0a21096"},
- {file = "scipy-1.9.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:90453d2b93ea82a9f434e4e1cba043e779ff67b92f7a0e85d05d286a3625df3c"},
- {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83c06e62a390a9167da60bedd4575a14c1f58ca9dfde59830fc42e5197283dab"},
- {file = "scipy-1.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abaf921531b5aeaafced90157db505e10345e45038c39e5d9b6c7922d68085cb"},
- {file = "scipy-1.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:06d2e1b4c491dc7d8eacea139a1b0b295f74e1a1a0f704c375028f8320d16e31"},
- {file = "scipy-1.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a04cd7d0d3eff6ea4719371cbc44df31411862b9646db617c99718ff68d4840"},
- {file = "scipy-1.9.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:545c83ffb518094d8c9d83cce216c0c32f8c04aaf28b92cc8283eda0685162d5"},
- {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d54222d7a3ba6022fdf5773931b5d7c56efe41ede7f7128c7b1637700409108"},
- {file = "scipy-1.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cff3a5295234037e39500d35316a4c5794739433528310e117b8a9a0c76d20fc"},
- {file = "scipy-1.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:2318bef588acc7a574f5bfdff9c172d0b1bf2c8143d9582e05f878e580a3781e"},
- {file = "scipy-1.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d644a64e174c16cb4b2e41dfea6af722053e83d066da7343f333a54dae9bc31c"},
- {file = "scipy-1.9.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:da8245491d73ed0a994ed9c2e380fd058ce2fa8a18da204681f2fe1f57f98f95"},
- {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4db5b30849606a95dcf519763dd3ab6fe9bd91df49eba517359e450a7d80ce2e"},
- {file = "scipy-1.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c68db6b290cbd4049012990d7fe71a2abd9ffbe82c0056ebe0f01df8be5436b0"},
- {file = "scipy-1.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:5b88e6d91ad9d59478fafe92a7c757d00c59e3bdc3331be8ada76a4f8d683f58"},
- {file = "scipy-1.9.3.tar.gz", hash = "sha256:fbc5c05c85c1a02be77b1ff591087c83bc44579c6d2bd9fb798bb64ea5e1a027"},
+ {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:aef5bd3b89e657007e1be6b16553c8813b221ff6d92c7526b7e0227450981eac"},
+ {file = "ruff-0.3.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:89b1e92b3bd9fca249153a97d23f29bed3992cff414b222fcd361d763fc53f12"},
+ {file = "ruff-0.3.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e55771559c89272c3ebab23326dc23e7f813e492052391fe7950c1a5a139d89"},
+ {file = "ruff-0.3.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dabc62195bf54b8a7876add6e789caae0268f34582333cda340497c886111c39"},
+ {file = "ruff-0.3.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a05f3793ba25f194f395578579c546ca5d83e0195f992edc32e5907d142bfa3"},
+ {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:dfd3504e881082959b4160ab02f7a205f0fadc0a9619cc481982b6837b2fd4c0"},
+ {file = "ruff-0.3.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87258e0d4b04046cf1d6cc1c56fadbf7a880cc3de1f7294938e923234cf9e498"},
+ {file = "ruff-0.3.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:712e71283fc7d9f95047ed5f793bc019b0b0a29849b14664a60fd66c23b96da1"},
+ {file = "ruff-0.3.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a532a90b4a18d3f722c124c513ffb5e5eaff0cc4f6d3aa4bda38e691b8600c9f"},
+ {file = "ruff-0.3.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:122de171a147c76ada00f76df533b54676f6e321e61bd8656ae54be326c10296"},
+ {file = "ruff-0.3.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d80a6b18a6c3b6ed25b71b05eba183f37d9bc8b16ace9e3d700997f00b74660b"},
+ {file = "ruff-0.3.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a7b6e63194c68bca8e71f81de30cfa6f58ff70393cf45aab4c20f158227d5936"},
+ {file = "ruff-0.3.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a759d33a20c72f2dfa54dae6e85e1225b8e302e8ac655773aff22e542a300985"},
+ {file = "ruff-0.3.5-py3-none-win32.whl", hash = "sha256:9d8605aa990045517c911726d21293ef4baa64f87265896e491a05461cae078d"},
+ {file = "ruff-0.3.5-py3-none-win_amd64.whl", hash = "sha256:dc56bb16a63c1303bd47563c60482a1512721053d93231cf7e9e1c6954395a0e"},
+ {file = "ruff-0.3.5-py3-none-win_arm64.whl", hash = "sha256:faeeae9905446b975dcf6d4499dc93439b131f1443ee264055c5716dd947af55"},
+ {file = "ruff-0.3.5.tar.gz", hash = "sha256:a067daaeb1dc2baf9b82a32dae67d154d95212080c80435eb052d95da647763d"},
]
-[package.dependencies]
-numpy = ">=1.18.5,<1.26.0"
-
-[package.extras]
-dev = ["flake8", "mypy", "pycodestyle", "typing_extensions"]
-doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-panels (>=0.5.2)", "sphinx-tabs"]
-test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
-
[[package]]
name = "sympy"
version = "1.12"
@@ -594,21 +561,21 @@ files = [
[[package]]
name = "torch"
-version = "2.2.1+cpu"
+version = "2.2.2+cpu"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "torch-2.2.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:5d82422cf04797f1b2a8574b64a916070ec83eef58ad4900615ee0218d7b8b8e"},
- {file = "torch-2.2.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:f8914dd0f5f0e5c66fdecd9559403eea9feac82d1ea639b672fde0073c6addbd"},
- {file = "torch-2.2.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:6bc973d5632374b92b4b293817b4d2ff8c8ce1c784c748b471dba1fffcd9c333"},
- {file = "torch-2.2.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:abdec34b0ade8fca0520055e72c3094425ae0ef210718e9c0278121cd3608c32"},
- {file = "torch-2.2.1+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:d7339580135da4105c1244a8621faa076990409afeab5a7b642c3c1ee70a5622"},
- {file = "torch-2.2.1+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:039128fcb5548122465b15f679b8831c47d14f0d6c28c1f1b631f8019c104720"},
- {file = "torch-2.2.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:2b447f7bb50b393b4544b4036d587e39ab524d4353e77c197f6a2727f22b0d47"},
- {file = "torch-2.2.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:2ccdf3e5f71e6426ea9e34d21c3cc333b29d4f48299b981d28aeb5112b5495e1"},
- {file = "torch-2.2.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:2fb340b289760040a16a77a6d70b8a48961abba1822e6f58705c97c80befa03e"},
- {file = "torch-2.2.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:e03dc4654ecceeb5b03f0a6f60b342c0e0d267b3ebc61e4f672cace1df8cd930"},
+ {file = "torch-2.2.2+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:02c4fac3c964e73f5f49003e0060c697f73b67c10cc23f51c592facb29e1bd53"},
+ {file = "torch-2.2.2+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:fc29dda2795dd7220d769c5926b1c50ddac9b4827897e30a10467063691cdf54"},
+ {file = "torch-2.2.2+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:90089cae572672fb449c8ff1dc1b29daaffa117bf97ede7463dcd2fd1b991e4c"},
+ {file = "torch-2.2.2+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:88e63c916e3275fa30a220ee736423a95573b96072ded85e5c0171fd8f37a755"},
+ {file = "torch-2.2.2+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:431a747b5a880cf8e1fb6d58db6bfafa6768cbec76517d046854537c03323edf"},
+ {file = "torch-2.2.2+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:2b0cf041f878607a361116945f82ce2dba4b7a747151da7619a63cb5fccb72df"},
+ {file = "torch-2.2.2+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:8914ce932168e572a09b4a7e5b0806d279f771dfe58d7e1d8de2291fac4ce69b"},
+ {file = "torch-2.2.2+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:4ef2911ffde6d86f643c23aa99f25f1a1df8bee93bf8d0c69cf1b9ba0ca521dc"},
+ {file = "torch-2.2.2+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:6e3d323a21df22415770e88d39e13591079b9356dabb8b394d1ee29ac6c92481"},
+ {file = "torch-2.2.2+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:c2c9e7d5e3c7d58e4b78d6aebfa8002af7cda16cde08d0e3ed00300dc21a8efc"},
]
[package.dependencies]
@@ -630,13 +597,13 @@ reference = "torch"
[[package]]
name = "typing-extensions"
-version = "4.10.0"
+version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
- {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"},
- {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
+ {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
+ {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
]
[extras]
@@ -645,4 +612,4 @@ bitsandbytes = ["bitsandbytes"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8,<4.0.0"
-content-hash = "d7d8dc0d66f37e2f167551d9be0775c778f251db9faeef427a5d66c2a2396d38"
+content-hash = "d51586f8352db14a18dd407b19285c9649564b029e6e6aae52a0d566515e5c81"
diff --git a/pyproject.toml b/pyproject.toml
index 57eadffc3..646597102 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,11 +13,11 @@ keywords = [
"pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
"AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
- "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO",
- "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM",
- "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
- "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
- "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
+ "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS", "Lamb", "Lion",
+ "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam",
+ "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
+ "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
+ "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
@@ -45,14 +45,14 @@ classifiers = [
python = ">=3.8,<4.0.0"
numpy = { version = "*", python = ">=3.8" }
torch = { version = ">=1.10", python = ">=3.8", source = "torch" }
-bitsandbytes = { version = "^0.42", optional = true }
+bitsandbytes = { version = "^0.43", optional = true }
[tool.poetry.dev-dependencies]
isort = { version = "^5", python = ">=3.8" }
black = { version = "^24", python = ">=3.8"}
-ruff = "^0.3"
-pytest = "^8"
-pytest-cov = "^4"
+ruff = "*"
+pytest = "*"
+pytest-cov = "*"
[tool.poetry.extras]
bitsandbytes = ["bitsandbytes"]
diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py
index 8ccdb7c36..f0f1c24ef 100644
--- a/pytorch_optimizer/__init__.py
+++ b/pytorch_optimizer/__init__.py
@@ -55,6 +55,7 @@
from pytorch_optimizer.optimizer.diffgrad import DiffGrad
from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
from pytorch_optimizer.optimizer.fromage import Fromage
+from pytorch_optimizer.optimizer.galore import GaLore, GaLoreProjector
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.gravity import Gravity
from pytorch_optimizer.optimizer.lamb import Lamb
@@ -182,6 +183,7 @@
CAME,
DAdaptLion,
Aida,
+ GaLore,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
diff --git a/pytorch_optimizer/optimizer/galore.py b/pytorch_optimizer/optimizer/galore.py
new file mode 100644
index 000000000..ced850c2d
--- /dev/null
+++ b/pytorch_optimizer/optimizer/galore.py
@@ -0,0 +1,249 @@
+import math
+from typing import Literal, Optional, Tuple, Union
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+
+PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full']
+
+
+class GaLoreProjector:
+ r"""Memory-Efficient LLM Training by Gradient Low-Rank Projection.
+
+ :param rank: int. low rank to project.
+ :param update_proj_gap: int. num steps to update the projection.
+ :param scale: float. scale factor.
+ :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are
+ supported.
+ """
+
+ def __init__(
+ self, rank: int = 128, update_proj_gap: int = 50, scale: float = 1.0, projection_type: PROJECTION_TYPE = 'std'
+ ):
+ self.rank = rank
+ self.update_proj_gap = update_proj_gap
+ self.scale = scale
+ self.projection_type = projection_type
+
+ self.ortho_matrix: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None
+
+ @staticmethod
+ def get_orthogonal_matrix(
+ weights: torch.Tensor, rank: int, projection_type: str
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ if projection_type not in {'right', 'left', 'full'}:
+ raise ValueError('projection_type should be one of left, right or full')
+
+ original_type = weights.data.dtype
+ original_device = weights.data.device
+ is_float: bool = original_type == torch.float
+
+ u, s, vh = torch.linalg.svd(weights if is_float else weights.float(), full_matrices=False)
+
+ if projection_type == 'right':
+ b = vh[:rank, :]
+ return b if is_float else b.to(original_device).type(original_type)
+ if projection_type == 'left':
+ a = u[:, :rank]
+ return a if is_float else a.to(original_device).type(original_type)
+
+ a = u[:, :rank]
+ b = vh[:rank, :]
+
+ return (
+ (a, b)
+ if is_float
+ else (a.to(original_device).type(original_type), b.to(original_device).type(original_type))
+ )
+
+ def get_low_rank_grad_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
+ if grad.shape[0] >= grad.shape[1]:
+ if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
+ return torch.matmul(grad, self.ortho_matrix.t())
+
+ if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
+
+ return torch.matmul(self.ortho_matrix.t(), grad)
+
+ def get_low_rank_grad_reverse_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
+ if grad.shape[0] >= grad.shape[1]:
+ if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
+ return torch.matmul(self.ortho_matrix.t(), grad)
+
+ if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
+
+ return torch.matmul(grad, self.ortho_matrix.t())
+
+ def get_low_rank_grad_right(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
+ if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
+ return torch.matmul(grad, self.ortho_matrix.t())
+
+ def get_low_rank_grad_left(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
+ if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
+ return torch.matmul(self.ortho_matrix.t(), grad)
+
+ def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
+ if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
+ self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full')
+ return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t()
+
+ def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
+ if self.projection_type == 'std':
+ return self.get_low_rank_grad_std(full_rank_grad, steps)
+ if self.projection_type == 'reverse_std':
+ return self.get_low_rank_grad_reverse_std(full_rank_grad, steps)
+ if self.projection_type == 'right':
+ return self.get_low_rank_grad_right(full_rank_grad, steps)
+ if self.projection_type == 'left':
+ return self.get_low_rank_grad_left(full_rank_grad, steps)
+ if self.projection_type == 'full':
+ return self.get_low_rank_grad_full(full_rank_grad, steps)
+ raise NotImplementedError
+
+ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
+ if self.projection_type == 'std':
+ return (
+ torch.matmul(low_rank_grad, self.ortho_matrix)
+ if low_rank_grad.shape[0] >= low_rank_grad.shape[1]
+ else torch.matmul(self.ortho_matrix, low_rank_grad)
+ ) * self.scale
+ if self.projection_type == 'reverse_std':
+ return (
+ torch.matmul(self.ortho_matrix, low_rank_grad.t())
+ if low_rank_grad.shape[0] <= low_rank_grad.shape[1]
+ else torch.matmul(low_rank_grad, self.ortho_matrix.t())
+ ) * self.scale
+ if self.projection_type == 'right':
+ return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale
+ if self.projection_type == 'left':
+ return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
+ if self.projection_type == 'full':
+ return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale
+
+ raise NotImplementedError
+
+
+class GaLore(Optimizer, BaseOptimizer):
+ r"""AdamW optimizer with GaLore projector.
+
+ :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+ :param lr: float. learning rate.
+ :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+ :param weight_decay: float. weight decay (L2 penalty).
+ :param eps: float. term added to the denominator to improve numerical stability.
+ """
+
+ def __init__(
+ self,
+ params: PARAMETERS,
+ lr: float = 1e-3,
+ betas: BETAS = (0.9, 0.999),
+ weight_decay: float = 0.0,
+ eps: float = 1e-6,
+ **kwargs,
+ ):
+ self.validate_learning_rate(lr)
+ self.validate_betas(betas)
+ self.validate_non_negative(weight_decay, 'weight_decay')
+ self.validate_non_negative(eps, 'eps')
+
+ defaults: DEFAULTS = {
+ 'lr': lr,
+ 'betas': betas,
+ 'weight_decay': weight_decay,
+ 'eps': eps,
+ **kwargs,
+ }
+
+ super().__init__(params, defaults)
+
+ def __str__(self) -> str:
+ return 'GaLore'
+
+ @torch.no_grad()
+ def reset(self):
+ for group in self.param_groups:
+ for p in group['params']:
+ state = self.state[p]
+
+ state['exp_avg'] = torch.zeros_like(p)
+ state['exp_avg_sq'] = torch.zeros_like(p)
+
+ @torch.no_grad()
+ def step(self, closure: CLOSURE = None) -> LOSS:
+ loss: LOSS = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ if 'step' in group:
+ group['step'] += 1
+ else:
+ group['step'] = 1
+
+ beta1, beta2 = group['betas']
+
+ bias_correction1: float = 1.0 - beta1 ** group['step']
+ bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
+
+ step_size: float = group['lr'] * bias_correction2_sq / bias_correction1
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ grad = p.grad
+ if grad.is_sparse:
+ raise NoSparseGradientError(str(self))
+
+ state = self.state[p]
+
+ if len(state) == 0:
+ state['exp_avg'] = torch.zeros_like(p)
+ state['exp_avg_sq'] = torch.zeros_like(p)
+
+ if 'rank' in group and p.dim() > 1:
+ if 'projector' not in state:
+ state['projector'] = GaLoreProjector(
+ rank=group['rank'],
+ update_proj_gap=group['update_proj_gap'],
+ scale=group['scale'],
+ projection_type=group['projection_type'],
+ )
+
+ grad = state['projector'].project(grad, group['step'])
+
+ self.apply_weight_decay(
+ p=p,
+ grad=None,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ weight_decouple=True,
+ fixed_decay=False,
+ )
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
+ de_nom = exp_avg_sq.sqrt().add_(group['eps'])
+
+ norm_grad = exp_avg / de_nom
+
+ if 'rank' in group and p.dim() > 1:
+ norm_grad = state['projector'].project_back(norm_grad)
+
+ p.add_(norm_grad, alpha=-step_size)
+
+ return loss
diff --git a/requirements-dev.txt b/requirements-dev.txt
index c6157e8fc..6ed086a44 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,12 +1,12 @@
--extra-index-url https://download.pytorch.org/whl/cpu
-black==24.2.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+black==24.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
click==8.1.7 ; python_version >= "3.8" and python_full_version < "4.0.0"
colorama==0.4.6 ; python_version >= "3.8" and python_full_version < "4.0.0" and (sys_platform == "win32" or platform_system == "Windows")
-coverage[toml]==7.4.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
+coverage[toml]==7.4.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
exceptiongroup==1.2.0 ; python_version >= "3.8" and python_version < "3.11"
-filelock==3.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
-fsspec==2024.2.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+filelock==3.13.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
+fsspec==2024.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
iniconfig==2.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
isort==5.13.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
jinja2==3.1.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
@@ -15,14 +15,14 @@ mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
mypy-extensions==1.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
-packaging==23.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
+packaging==24.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pathspec==0.12.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
platformdirs==4.2.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
pluggy==1.4.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
-pytest-cov==4.1.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
-pytest==8.0.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
-ruff==0.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+pytest-cov==5.0.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+pytest==8.1.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
+ruff==0.3.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0"
tomli==2.0.1 ; python_version >= "3.8" and python_full_version <= "3.11.0a6"
-torch==2.2.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
-typing-extensions==4.10.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+torch==2.2.2+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
+typing-extensions==4.11.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
diff --git a/requirements.txt b/requirements.txt
index b6e6a1df3..fe27e8f9a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,12 @@
--extra-index-url https://download.pytorch.org/whl/cpu
-filelock==3.13.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
-fsspec==2024.2.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+filelock==3.13.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
+fsspec==2024.3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
jinja2==3.1.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
markupsafe==2.1.5 ; python_version >= "3.8" and python_full_version < "4.0.0"
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
networkx==3.1 ; python_version >= "3.8" and python_full_version < "4.0.0"
numpy==1.24.4 ; python_version >= "3.8" and python_full_version < "4.0.0"
sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0"
-torch==2.2.1+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
-typing-extensions==4.10.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
+torch==2.2.2+cpu ; python_version >= "3.8" and python_full_version < "4.0.0"
+typing-extensions==4.11.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
diff --git a/tests/constants.py b/tests/constants.py
index 926dfb9b0..f7ffc4b46 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -45,6 +45,7 @@
DAdaptSGD,
DiffGrad,
Fromage,
+ GaLore,
Gravity,
Lamb,
Lion,
@@ -401,6 +402,38 @@
(CAME, {'lr': 7.5e-1, 'weight_decay': 1e-3}, 75),
(CAME, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 75),
(Aida, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 5),
+ (
+ GaLore,
+ {'lr': 1e0, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 1, 'projection_type': 'std'},
+ 5,
+ ),
+ (
+ GaLore,
+ {
+ 'lr': 1e0,
+ 'weight_decay': 1e-3,
+ 'rank': 2,
+ 'scale': 1.0,
+ 'update_proj_gap': 1,
+ 'projection_type': 'reverse_std',
+ },
+ 5,
+ ),
+ (
+ GaLore,
+ {'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'left'},
+ 5,
+ ),
+ (
+ GaLore,
+ {'lr': 1e0, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 1, 'projection_type': 'right'},
+ 5,
+ ),
+ (
+ GaLore,
+ {'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'},
+ 5,
+ ),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py
index 57b4fd8d9..b7812180c 100644
--- a/tests/test_load_modules.py
+++ b/tests/test_load_modules.py
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
def test_get_supported_optimizers():
- assert len(get_supported_optimizers()) == 61
+ assert len(get_supported_optimizers()) == 62
def test_get_supported_lr_schedulers():
diff --git a/tests/test_optimizer_parameters.py b/tests/test_optimizer_parameters.py
index 26d7ff59c..433b3d988 100644
--- a/tests/test_optimizer_parameters.py
+++ b/tests/test_optimizer_parameters.py
@@ -2,7 +2,16 @@
import torch
from torch import nn
-from pytorch_optimizer import SAM, WSAM, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer
+from pytorch_optimizer import (
+ SAM,
+ WSAM,
+ GaLoreProjector,
+ Lookahead,
+ PCGrad,
+ Ranger21,
+ SafeFP16Optimizer,
+ load_optimizer,
+)
from tests.constants import PULLBACK_MOMENTUM
from tests.utils import Example, simple_parameter, simple_zero_rank_parameter
@@ -254,3 +263,16 @@ def test_ranger_parameters():
# test lookahead step `k`
with pytest.raises(ValueError):
opt(None, k=-1)
+
+
+def test_galore_projection_type():
+ p = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
+
+ with pytest.raises(NotImplementedError):
+ GaLoreProjector(projection_type='invalid').project(p, 1)
+
+ with pytest.raises(NotImplementedError):
+ GaLoreProjector(projection_type='invalid').project_back(p)
+
+ with pytest.raises(ValueError):
+ GaLoreProjector.get_orthogonal_matrix(p, 1, projection_type='std')