Merge pull request #280 from kozistr/fix/should-grokfast
[Fix] `should_grokfast` condition
kozistr authored Oct 12, 2024
2 parents 09a2b83 + 600899f commit 1687e37
Showing 6 changed files with 39 additions and 31 deletions.
4 changes: 4 additions & 0 deletions docs/changelogs/v3.2.0.md
@@ -6,3 +6,7 @@
 * [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321)
 * Support `AdEMAMix` variants. (#276)
 * `bnb_ademamix8bit`, `bnb_ademamix32bit`, `bnb_paged_ademamix8bit`, `bnb_paged_ademamix32bit`
+
+### Bug
+
+* Fix `should_grokfast` condition at initialization. (#279, #280)
4 changes: 4 additions & 0 deletions docs/qa.md
@@ -3,3 +3,7 @@
 ## Q1) SophiaH, AdaHessian optimizers give ```RuntimeError: ~ tensors does not require grad and does not have a grad_fn``` in `compute_hutchinson_hessian()`.
 
 `create_graph` must be set to `True` when calling `backward()`. Here's [an example](https://github.com/kozistr/pytorch_optimizer/issues/194#issuecomment-1723167466).
+
+## Q2) A memory leak happens when using the SophiaH or AdaHessian optimizers.
+
+`torch.autograd.grad` with complex gradient flows sometimes leads to memory leaks, and you might encounter an OOM issue. [Related issue](https://github.com/kozistr/pytorch_optimizer/issues/278)
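
As a quick illustration of the Q1 answer, here is a minimal sketch of a Hutchinson-based optimizer step with `create_graph=True`; the toy model, data, and learning rate are placeholders rather than values from this repository:

```python
import torch
from pytorch_optimizer import AdaHessian

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = AdaHessian(model.parameters(), lr=1e-1)

x, y = torch.randn(8, 4), torch.randn(8, 1)  # placeholder data

loss = torch.nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
# create_graph=True keeps the backward graph alive so the optimizer can
# build Hessian-vector products inside compute_hutchinson_hessian()
loss.backward(create_graph=True)
optimizer.step()
```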
52 changes: 26 additions & 26 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/grokfast.py
@@ -185,7 +185,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
 
             should_grokfast: bool = (
-                group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0
+                group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0.0
             )
 
             for p in group['params']:
@@ -201,7 +201,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if len(state) == 0:
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_sq'] = torch.zeros_like(p)
-                    if should_grokfast:
+                    if group['grokfast'] and group['grokfast_lamb'] > 0.0:
                         state['grok_exp_avg'] = grad.clone()
 
                 self.apply_weight_decay(
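
The substance of the fix is the split between the two guards: `should_grokfast` (which also requires `group['step'] > group['grokfast_after_step']`) now gates only the application of the filter, while the `grok_exp_avg` buffer is allocated whenever grokfast is enabled with a positive `grokfast_lamb`. Below is a condensed sketch of that control flow, with the rest of `step()` elided; `alpha` is an assumed name for the EMA coefficient, and the filter body follows the standard grokfast formulation rather than quoting this file:

```python
import torch


def filtered_grad(group: dict, state: dict, grad: torch.Tensor, alpha: float = 0.98) -> torch.Tensor:
    # the filter is only *applied* after the configured warm-up step
    should_grokfast: bool = (
        group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0.0
    )

    if 'grok_exp_avg' not in state and group['grokfast'] and group['grokfast_lamb'] > 0.0:
        # allocate the EMA buffer whenever grokfast is enabled; gating this on
        # `should_grokfast` (the pre-fix behavior) left the key missing for runs
        # that start before `grokfast_after_step` (#279)
        state['grok_exp_avg'] = grad.clone()

    if should_grokfast:
        grok_exp_avg = state['grok_exp_avg']
        grok_exp_avg.lerp_(grad, weight=1.0 - alpha)                 # EMA of the raw gradient
        grad = grad.add(grok_exp_avg, alpha=group['grokfast_lamb'])  # amplify the slow component
    return grad
```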
4 changes: 2 additions & 2 deletions requirements-dev.txt
@@ -22,8 +22,8 @@ platformdirs==4.3.6 ; python_version >= "3.8"
 pluggy==1.5.0 ; python_version >= "3.8"
 pytest-cov==5.0.0 ; python_version >= "3.8"
 pytest==8.3.3 ; python_version >= "3.8"
-ruff==0.6.8 ; python_version >= "3.8"
+ruff==0.6.9 ; python_version >= "3.8"
 sympy==1.13.3 ; python_version >= "3.8"
-tomli==2.0.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
+tomli==2.0.2 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
 torch==2.4.1+cpu ; python_version >= "3.8"
 typing-extensions==4.12.2 ; python_version >= "3.8"
2 changes: 1 addition & 1 deletion tests/constants.py
@@ -471,7 +471,7 @@
     (ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
-    (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
+    (GrokFastAdamW, {'lr': 5e0, 'weight_decay': 1e-3, 'grokfast_after_step': 1}, 5),
     (Kate, {'lr': 5e-2}, 10),
     (StableAdamW, {'lr': 1e0}, 5),
     (AdamG, {'lr': 1e0}, 20),
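
The updated test entry exercises the fixed path by constructing `GrokFastAdamW` with `grokfast_after_step` above zero, so the optimizer state is initialized before the filter becomes active. A minimal usage sketch along the same lines; the toy model and data are placeholders, and the learning rate mirrors the test constant rather than a recommended value:

```python
import torch
from pytorch_optimizer import GrokFastAdamW

model = torch.nn.Linear(2, 1)  # placeholder model
# mirrors the updated test entry: the grokfast filter only activates after step 1
optimizer = GrokFastAdamW(model.parameters(), lr=5e0, weight_decay=1e-3, grokfast_after_step=1)

x, y = torch.randn(16, 2), torch.randn(16, 1)  # placeholder data

for _ in range(5):
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    # step 1 runs with grokfast inactive; with this fix the EMA buffer is still
    # allocated there, so the later filtered steps find `grok_exp_avg` in place
    optimizer.step()
```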
