diff --git a/docs/changelogs/v3.2.0.md b/docs/changelogs/v3.2.0.md index 812dcc41..830bf53d 100644 --- a/docs/changelogs/v3.2.0.md +++ b/docs/changelogs/v3.2.0.md @@ -6,3 +6,7 @@ * [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321) * Support `AdEMAMix` variants. (#276) * `bnb_ademamix8bit`, `bnb_ademamix32bit`, `bnb_paged_ademamix8bit`, `bnb_paged_ademamix32bit` + +### Bug + +* Fix `should_grokfast` condition when initialization. (#279, #280) diff --git a/docs/qa.md b/docs/qa.md index a6bd431c..d0dccc4b 100644 --- a/docs/qa.md +++ b/docs/qa.md @@ -3,3 +3,7 @@ ## Q1) SophiaH, AdaHessian optimizers give ```RuntimeError: ~ tensors does not require grad and does not have a grad_fn``` in `compute_hutchinson_hessian()`. `create_graph` must be set `True` when calling `backward()`. here's [an example](https://github.com/kozistr/pytorch_optimizer/issues/194#issuecomment-1723167466). + +## Q2) Memory leak happens when using SophiaH, AdaHessian optimizers. + +`torch.autograd.grad` with complex gradient flows sometimes leads memory leak issues, and you might encounter OOM issue. [related issue](https://github.com/kozistr/pytorch_optimizer/issues/278) diff --git a/poetry.lock b/poetry.lock index c13cf1a2..21c9237e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "bitsandbytes" -version = "0.44.0" +version = "0.44.1" description = "k-bit optimizers and matrix multiplication routines." optional = true python-versions = "*" files = [ - {file = "bitsandbytes-0.44.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:f31b32ace5d2da0fc7f55b8ed205364298769daaa34d61a45e1f7f2bfd1b3622"}, - {file = "bitsandbytes-0.44.0-py3-none-win_amd64.whl", hash = "sha256:fb3dae427e2c07ecc2bd847e4bb49941093b88480d85ba207d5ac4db8d3ff42f"}, + {file = "bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:b2f24c6cbf11fc8c5d69b3dcecee9f7011451ec59d6ac833e873c9f105259668"}, + {file = "bitsandbytes-0.44.1-py3-none-win_amd64.whl", hash = "sha256:8e68e12aa25d2cf9a1730ad72890a5d1a19daa23f459a6a4679331f353d58cb4"}, ] [package.dependencies] @@ -589,29 +589,29 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "ruff" -version = "0.6.8" +version = "0.6.9" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.8-py3-none-linux_armv6l.whl", hash = "sha256:77944bca110ff0a43b768f05a529fecd0706aac7bcce36d7f1eeb4cbfca5f0f2"}, - {file = "ruff-0.6.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:27b87e1801e786cd6ede4ada3faa5e254ce774de835e6723fd94551464c56b8c"}, - {file = "ruff-0.6.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd48f945da2a6334f1793d7f701725a76ba93bf3d73c36f6b21fb04d5338dcf5"}, - {file = "ruff-0.6.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:677e03c00f37c66cea033274295a983c7c546edea5043d0c798833adf4cf4c6f"}, - {file = "ruff-0.6.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9f1476236b3eacfacfc0f66aa9e6cd39f2a624cb73ea99189556015f27c0bdeb"}, - {file = "ruff-0.6.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f5a2f17c7d32991169195d52a04c95b256378bbf0de8cb98478351eb70d526f"}, - {file = "ruff-0.6.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5fd0d4b7b1457c49e435ee1e437900ced9b35cb8dc5178921dfb7d98d65a08d0"}, - {file = "ruff-0.6.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8034b19b993e9601f2ddf2c517451e17a6ab5cdb1c13fdff50c1442a7171d87"}, - {file = "ruff-0.6.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6cfb227b932ba8ef6e56c9f875d987973cd5e35bc5d05f5abf045af78ad8e098"}, - {file = "ruff-0.6.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef0411eccfc3909269fed47c61ffebdcb84a04504bafa6b6df9b85c27e813b0"}, - {file = "ruff-0.6.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:007dee844738c3d2e6c24ab5bc7d43c99ba3e1943bd2d95d598582e9c1b27750"}, - {file = "ruff-0.6.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ce60058d3cdd8490e5e5471ef086b3f1e90ab872b548814e35930e21d848c9ce"}, - {file = "ruff-0.6.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1085c455d1b3fdb8021ad534379c60353b81ba079712bce7a900e834859182fa"}, - {file = "ruff-0.6.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:70edf6a93b19481affd287d696d9e311388d808671bc209fb8907b46a8c3af44"}, - {file = "ruff-0.6.8-py3-none-win32.whl", hash = "sha256:792213f7be25316f9b46b854df80a77e0da87ec66691e8f012f887b4a671ab5a"}, - {file = "ruff-0.6.8-py3-none-win_amd64.whl", hash = "sha256:ec0517dc0f37cad14a5319ba7bba6e7e339d03fbf967a6d69b0907d61be7a263"}, - {file = "ruff-0.6.8-py3-none-win_arm64.whl", hash = "sha256:8d3bb2e3fbb9875172119021a13eed38849e762499e3cfde9588e4b4d70968dc"}, - {file = "ruff-0.6.8.tar.gz", hash = "sha256:a5bf44b1aa0adaf6d9d20f86162b34f7c593bfedabc51239953e446aefc8ce18"}, + {file = "ruff-0.6.9-py3-none-linux_armv6l.whl", hash = "sha256:064df58d84ccc0ac0fcd63bc3090b251d90e2a372558c0f057c3f75ed73e1ccd"}, + {file = "ruff-0.6.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:140d4b5c9f5fc7a7b074908a78ab8d384dd7f6510402267bc76c37195c02a7ec"}, + {file = "ruff-0.6.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53fd8ca5e82bdee8da7f506d7b03a261f24cd43d090ea9db9a1dc59d9313914c"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645d7d8761f915e48a00d4ecc3686969761df69fb561dd914a773c1a8266e14e"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eae02b700763e3847595b9d2891488989cac00214da7f845f4bcf2989007d577"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d5ccc9e58112441de8ad4b29dcb7a86dc25c5f770e3c06a9d57e0e5eba48829"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:417b81aa1c9b60b2f8edc463c58363075412866ae4e2b9ab0f690dc1e87ac1b5"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c866b631f5fbce896a74a6e4383407ba7507b815ccc52bcedabb6810fdb3ef7"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b118afbb3202f5911486ad52da86d1d52305b59e7ef2031cea3425142b97d6f"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67267654edc23c97335586774790cde402fb6bbdb3c2314f1fc087dee320bfa"}, + {file = "ruff-0.6.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3ef0cc774b00fec123f635ce5c547dac263f6ee9fb9cc83437c5904183b55ceb"}, + {file = "ruff-0.6.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:12edd2af0c60fa61ff31cefb90aef4288ac4d372b4962c2864aeea3a1a2460c0"}, + {file = "ruff-0.6.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:55bb01caeaf3a60b2b2bba07308a02fca6ab56233302406ed5245180a05c5625"}, + {file = "ruff-0.6.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:925d26471fa24b0ce5a6cdfab1bb526fb4159952385f386bdcc643813d472039"}, + {file = "ruff-0.6.9-py3-none-win32.whl", hash = "sha256:eb61ec9bdb2506cffd492e05ac40e5bc6284873aceb605503d8494180d6fc84d"}, + {file = "ruff-0.6.9-py3-none-win_amd64.whl", hash = "sha256:785d31851c1ae91f45b3d8fe23b8ae4b5170089021fbb42402d811135f0b7117"}, + {file = "ruff-0.6.9-py3-none-win_arm64.whl", hash = "sha256:a9641e31476d601f83cd602608739a0840e348bda93fec9f1ee816f8b6798b93"}, + {file = "ruff-0.6.9.tar.gz", hash = "sha256:b076ef717a8e5bc819514ee1d602bbdca5b4420ae13a9cf61a0c0a4f53a2baa2"}, ] [[package]] @@ -633,13 +633,13 @@ dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] [[package]] name = "tomli" -version = "2.0.1" +version = "2.0.2" description = "A lil' TOML parser" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, + {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"}, + {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, ] [[package]] diff --git a/pytorch_optimizer/optimizer/grokfast.py b/pytorch_optimizer/optimizer/grokfast.py index ef733a0b..df2ed50e 100644 --- a/pytorch_optimizer/optimizer/grokfast.py +++ b/pytorch_optimizer/optimizer/grokfast.py @@ -185,7 +185,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) should_grokfast: bool = ( - group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0 + group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0.0 ) for p in group['params']: @@ -201,7 +201,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: if len(state) == 0: state['exp_avg'] = torch.zeros_like(p) state['exp_avg_sq'] = torch.zeros_like(p) - if should_grokfast: + if group['grokfast'] and group['grokfast_lamb'] > 0.0: state['grok_exp_avg'] = grad.clone() self.apply_weight_decay( diff --git a/requirements-dev.txt b/requirements-dev.txt index 4b11462a..74595c3f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -22,8 +22,8 @@ platformdirs==4.3.6 ; python_version >= "3.8" pluggy==1.5.0 ; python_version >= "3.8" pytest-cov==5.0.0 ; python_version >= "3.8" pytest==8.3.3 ; python_version >= "3.8" -ruff==0.6.8 ; python_version >= "3.8" +ruff==0.6.9 ; python_version >= "3.8" sympy==1.13.3 ; python_version >= "3.8" -tomli==2.0.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8" +tomli==2.0.2 ; python_full_version <= "3.11.0a6" and python_version >= "3.8" torch==2.4.1+cpu ; python_version >= "3.8" typing-extensions==4.12.2 ; python_version >= "3.8" diff --git a/tests/constants.py b/tests/constants.py index 15cd66ab..1d97fbca 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -471,7 +471,7 @@ (ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5), - (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10), + (GrokFastAdamW, {'lr': 5e0, 'weight_decay': 1e-3, 'grokfast_after_step': 1}, 5), (Kate, {'lr': 5e-2}, 10), (StableAdamW, {'lr': 1e0}, 5), (AdamG, {'lr': 1e0}, 20),