
Comparing changes

base repository: lucidrains/adam-atan2-pytorch
base: 0.1.4
head repository: lucidrains/adam-atan2-pytorch
compare: main
  • 10 commits
  • 6 files changed
  • 1 contributor

Commits on Nov 21, 2024

  1. f6ab117

Commits on Nov 22, 2024

  1. 57ffaf3
  2. facepalm

    lucidrains committed Nov 22, 2024
    0177579
  3. 661aed2
  4. 4ba86bc
  5. 7137f5d
  6. ef23c81
  7. e3a0aff

Commits on Nov 27, 2024

  1. 5a48ed9
  2. add the proposed cautious optimizer from https://arxiv.org/abs/2411.16085, but allow for attenuating unaligned updates with some factor instead of zeroing completely

    lucidrains committed Nov 27, 2024
    8f14cf5
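
As a quick illustration of the attenuation described in this commit message, here is a minimal, self-contained sketch (tensor values are made up; the real logic appears in the `adam_atan2.py` and `adopt_atan2.py` diffs below): entries whose update disagrees in sign with the gradient are scaled by `cautious_factor` rather than zeroed, then the mask is renormalized by its mean.

```python
import torch

# made-up example tensors; inside the optimizer these are the update and p.grad
update = torch.randn(4)
grad = torch.randn(4)
cautious_factor = 0.1  # 0. reproduces the paper's hard zeroing, 1. disables the mask

# scale sign-mismatched entries instead of zeroing them outright
# (variation on algorithm 2 of https://arxiv.org/abs/2411.16085)
align_mask = (update * grad) > 0
scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
cautious_update = update * (scale / scale.mean().clamp(min = 1e-5))
```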
Showing with 172 additions and 14 deletions.
  1. +9 −0 README.md
  2. +2 −1 adam_atan2_pytorch/__init__.py
  3. +11 −1 adam_atan2_pytorch/adam_atan2.py
  4. +13 −11 adam_atan2_pytorch/adopt.py
  5. +136 −0 adam_atan2_pytorch/adopt_atan2.py
  6. +1 −1 pyproject.toml
9 changes: 9 additions & 0 deletions README.md
@@ -81,3 +81,12 @@ for _ in range(100):
url = {https://api.semanticscholar.org/CorpusID:273822148}
}
```

```bibtex
@inproceedings{Liang2024CautiousOI,
title = {Cautious Optimizers: Improving Training with One Line of Code},
author = {Kaizhao Liang and Lizhang Chen and Bo Liu and Qiang Liu},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:274234738}
}
```
3 changes: 2 additions & 1 deletion adam_atan2_pytorch/__init__.py
@@ -1,4 +1,5 @@
from adam_atan2_pytorch.adam_atan2 import AdamAtan2
from adam_atan2_pytorch.adopt import Adopt
from adam_atan2_pytorch.adopt_atan2 import AdoptAtan2

Adam = AdamAtan2
Adopt = AdoptAtan2
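
A small usage note, assuming the final `__init__.py` matches the block above: the package-level `Adopt` name now resolves to the atan2 variant.

```python
from adam_atan2_pytorch import Adopt, AdoptAtan2

assert Adopt is AdoptAtan2  # the root-level alias points at the atan2 variant
```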
12 changes: 11 additions & 1 deletion adam_atan2_pytorch/adam_atan2.py
@@ -21,6 +21,7 @@ def __init__(
weight_decay = 0.,
regen_reg_rate = 0.,
decoupled_wd = False,
cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085
a = 1.27,
b = 1.
):
@@ -29,6 +30,7 @@ def __init__(
assert weight_decay >= 0.
assert regen_reg_rate >= 0.
assert not (weight_decay > 0. and regen_reg_rate > 0.)
assert 0. <= cautious_factor <= 1.

self._init_lr = lr
self.decoupled_wd = decoupled_wd
@@ -40,6 +42,7 @@ def __init__(
b = b,
weight_decay = weight_decay,
regen_reg_rate = regen_reg_rate,
cautious_factor = cautious_factor
)

super().__init__(params, defaults)
@@ -58,7 +61,7 @@ def step(
for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

# maybe decoupled weight decay

@@ -109,6 +112,13 @@ def step(
den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
update = exp_avg.mul(1. / bias_correct1).atan2_(den)

# maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085

if cautious_factor < 1.:
align_mask = (update * grad) > 0
scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
update *= (scale / scale.mean().clamp(min = 1e-5))

# update parameters

p.add_(update, alpha = -lr * a)
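
For reference, a hedged out-of-place restatement of the update rule that the hunks above feed into; the function name and signature are illustrative and not part of the package.

```python
import torch

def adam_atan2_update(p, exp_avg, exp_avg_sq, lr, beta1, beta2, step, a = 1.27, b = 1.):
    # atan2 replaces the usual m_hat / (sqrt(v_hat) + eps) division,
    # which is what removes the eps hyperparameter from Adam
    bias_correct1 = 1. - beta1 ** step
    bias_correct2 = 1. - beta2 ** step
    m_hat = exp_avg / bias_correct1
    v_hat = exp_avg_sq / bias_correct2
    update = torch.atan2(m_hat, b * v_hat.sqrt())
    return p - lr * a * update
```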
24 changes: 13 additions & 11 deletions adam_atan2_pytorch/adopt.py
@@ -16,14 +16,14 @@ class Adopt(Optimizer):
"""
the proposed Adam substitute from University of Tokyo
Algorithm 2 in https://arxiv.org/abs/2411.02853
Algorithm 3 in https://arxiv.org/abs/2411.02853
"""

def __init__(
self,
params,
lr = 1e-4,
betas: tuple[float, float] = (0.9, 0.9999),
betas: tuple[float, float] = (0.9, 0.99),
eps = 1e-6,
weight_decay = 0.,
decoupled_wd = True
@@ -74,7 +74,7 @@ def step(

if len(state) == 0:
state['steps'] = 0
state['m'] = torch.empty_like(grad)
state['m'] = torch.zeros_like(grad)
state['v'] = grad * grad

# get some of the states
@@ -87,18 +87,20 @@ def step(
state['steps'] += 1
continue

# logic

steps += 1

# calculate m

grad_sq = grad * grad

next_m = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon
update = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon

# clip with t ^ 0.25 as in Algorithm 3

clip_value = steps ** 0.25
update.clamp_(min = -clip_value, max = clip_value)

# update m

if steps > 1:
m.lerp_(next_m, 1. - beta1)
m.lerp_(update, 1. - beta1)

# then update parameters

@@ -110,6 +112,6 @@ def step(

# increment steps

state['steps'] = steps
state['steps'] += 1

return loss
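
To make the new clipping step concrete, a small hedged example (values are made up): Algorithm 3 clips the normalized gradient at `t ** 0.25` before it is folded into the first moment, which bounds early updates while `v` is still noisy.

```python
import torch

# made-up normalized-gradient values; in the optimizer this is grad / sqrt(v)
steps = 16
update = torch.tensor([-5.0, 0.3, 7.0])

clip_value = steps ** 0.25                                 # 2.0 for steps = 16
clipped = update.clamp(min = -clip_value, max = clip_value)
print(clipped)                                             # tensor([-2.0000, 0.3000, 2.0000])
```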
136 changes: 136 additions & 0 deletions adam_atan2_pytorch/adopt_atan2.py
@@ -0,0 +1,136 @@
from __future__ import annotations
from typing import Callable

import torch
from torch import atan2, sqrt
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
return val is not None

# class

class AdoptAtan2(Optimizer):
"""
the proposed Adam substitute from University of Tokyo
combined with the proposed atan2 method for ridding of the eps from Google
Algorithm 3 in https://arxiv.org/abs/2411.02853
"""

def __init__(
self,
params,
lr = 1e-4,
betas: tuple[float, float] = (0.9, 0.99),
weight_decay = 0.,
regen_reg_rate = 0.,
decoupled_wd = True,
cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085
a = 1.27,
b = 1.
):
assert lr > 0.
assert all([0. <= beta <= 1. for beta in betas])
assert weight_decay >= 0.
assert not (weight_decay > 0. and regen_reg_rate > 0.)

self._init_lr = lr
self.decoupled_wd = decoupled_wd

defaults = dict(
lr = lr,
betas = betas,
a = a,
b = b,
weight_decay = weight_decay,
regen_reg_rate = regen_reg_rate,
cautious_factor = cautious_factor
)

super().__init__(params, defaults)

@torch.no_grad()
def step(
self,
closure: Callable | None = None
):

loss = None
if exists(closure):
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in filter(lambda p: exists(p.grad), group['params']):

grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

# maybe decoupled weight decay

if self.decoupled_wd:
wd /= init_lr

# regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958

if regen_rate > 0. and 'param_init' in state:
param_init = state['param_init']
p.lerp_(param_init, lr / init_lr * regen_rate)

# weight decay

if wd > 0.:
p.mul_(1. - lr * wd)

# init state if needed

if len(state) == 0:
state['steps'] = 0
state['m'] = torch.zeros_like(grad)
state['v'] = grad * grad

if regen_rate > 0.:
state['param_init'] = p.clone()

# get some of the states

m, v, steps = state['m'], state['v'], state['steps']

# for the first step do nothing

if steps == 0:
state['steps'] += 1
continue

# calculate m

grad_sq = grad * grad

update = grad.atan2(b * v.sqrt())

m.lerp_(update, 1. - beta1)

# maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085

scale = 1.

if cautious_factor < 1.:
align_mask = (update * grad) > 0
scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
scale /= scale.mean().clamp(min = 1e-5)

# then update parameters

p.add_(m * scale, alpha = -lr * a)

# update exp grad sq (v)

v.lerp_(grad_sq, 1. - beta2)

# increment steps

state['steps'] += 1

return loss
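
A hedged usage sketch of the new `AdoptAtan2` class, mirroring the training-loop pattern the README hunk above extends; the toy model, data, and loss are purely illustrative.

```python
import torch
from adam_atan2_pytorch import AdoptAtan2

model = torch.nn.Linear(16, 1)  # toy model for illustration

opt = AdoptAtan2(
    model.parameters(),
    lr = 1e-4,
    cautious_factor = 0.1  # attenuate sign-mismatched updates; 0. zeroes them, 1. disables the mask
)

for _ in range(100):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```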
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "adam-atan2-pytorch"
version = "0.1.4"
version = "0.1.18"
description = "Adam-atan2 for Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }