release 0.0.1
lucidrains committed Jul 30, 2024
1 parent 28f3b2f commit 4700a8c
Showing 5 changed files with 225 additions and 0 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/python-publish.yml
@@ -0,0 +1,36 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
35 changes: 35 additions & 0 deletions README.md
@@ -4,6 +4,41 @@ Implementation of the proposed <a href="https://arxiv.org/abs/2407.05872">Adam-atan2</a> in Pytorch

A multi-million dollar paper out of Google DeepMind basically proposes a small change to Adam (using `atan2`) for greater stability.
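For intuition, here is a minimal sketch (illustrative values, not code from this repository) of how the proposed update differs from the usual epsilon-guarded Adam step, with `a` and `b` standing in for the scaling hyperparameters explored in the paper:

```python
import torch

# bias-corrected first / second moment estimates (made-up numbers for illustration)
m_hat = torch.tensor([0.02, -0.10])
v_hat = torch.tensor([1e-12, 4e-2])

a, b, eps = 1., 1., 1e-8

# standard Adam direction - needs eps to avoid dividing by ~zero
adam_update = m_hat / (v_hat.sqrt() + eps)

# Adam-atan2 direction - no eps, and each component is bounded by a * pi/2 in magnitude
atan2_update = a * torch.atan2(m_hat, b * v_hat.sqrt())
```

Because `atan2(y, x)` depends only on the ratio of its arguments, rescaling both moments leaves the update unchanged, which is what removes the need for a tiny epsilon.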

## Install

```bash
$ pip install adam-atan2-pytorch
```

## Usage

```python
# toy model

import torch
from torch import nn

model = nn.Linear(10, 1)

# import AdamAtan2 and instantiate with parameters

from adam_atan2_pytorch import AdamAtan2

opt = AdamAtan2(model.parameters(), lr = 1e-4)

# forward and backwards

for _ in range(100):
    loss = model(torch.randn(10))
    loss.backward()

    # optimizer step

    opt.step()
    opt.zero_grad()

```
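The optimizer also accepts decoupled `weight_decay` and the paper's `a` / `b` scaling hyperparameters (see `adam_atan2.py` below); a small sketch with them passed explicitly (values are illustrative):

```python
opt = AdamAtan2(
    model.parameters(),
    lr = 1e-4,
    betas = (0.9, 0.99),
    weight_decay = 1e-2,  # decoupled weight decay, scaled by lr / initial lr
    a = 1.,               # scales the whole atan2 update
    b = 1.                # scales the second-moment (denominator) argument of atan2
)
```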

## Citations

```bibtex
1 change: 1 addition & 0 deletions adam_atan2_pytorch/__init__.py
@@ -0,0 +1 @@
from adam_atan2_pytorch.adam_atan2 import AdamAtan2
96 changes: 96 additions & 0 deletions adam_atan2_pytorch/adam_atan2.py
@@ -0,0 +1,96 @@
from __future__ import annotations
from typing import Tuple, Callable

import torch
from torch import atan2, sqrt
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class AdamAtan2(Optimizer):
    def __init__(
        self,
        params,
        lr = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay = 0.,
        a = 1.,
        b = 1.
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert weight_decay >= 0.

        self._init_lr = lr

        defaults = dict(
            lr = lr,
            betas = betas,
            a = a,
            b = b,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

                # decoupled weight decay

                if wd > 0.:
                    p.mul_(1. - lr / init_lr * wd)

                # init state if needed

                if len(state) == 0:
                    state['steps'] = 0
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                # get some of the states

                exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']

                steps += 1

                # bias corrections

                bias_correct1 = 1. - beta1 ** steps
                bias_correct2 = 1. - beta2 ** steps

                # decay running averages

                exp_avg.lerp_(grad, 1. - beta1)
                exp_avg_sq.lerp_(grad * grad, 1. - beta2)

                # the following line is the proposed change to the update rule
                # using atan2 instead of a division with epsilons - they also suggest the hyperparameters `a` and `b` be explored beyond their defaults of 1.

                update = a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))

                p.add_(update, alpha = -lr)

                # increment steps

                state['steps'] = steps

        return loss
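Since `step` accepts an optional `closure` (re-evaluated under `torch.enable_grad()` before the parameter update, following the standard PyTorch optimizer contract), closure-style usage might look like the following sketch; the model, data, and loss here are illustrative, not from the repository:

```python
import torch
from torch import nn
from adam_atan2_pytorch import AdamAtan2

model = nn.Linear(10, 1)
opt = AdamAtan2(model.parameters(), lr = 1e-4)

data = torch.randn(16, 10)

def closure():
    opt.zero_grad()
    loss = model(data).pow(2).mean()  # illustrative squared-output loss
    loss.backward()
    return loss

# step() calls the closure to recompute gradients, then applies the atan2 update
loss = opt.step(closure)
```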
57 changes: 57 additions & 0 deletions pyproject.toml
@@ -0,0 +1,57 @@
[project]
name = "adam-atan2-pytorch"
version = "0.0.1"
description = "Adam-atan2 for Pytorch"
authors = [
    { name = "Phil Wang", email = "[email protected]" }
]
readme = "README.md"
requires-python = ">= 3.9"
license = { file = "LICENSE" }
keywords = [
    'artificial intelligence',
    'deep learning',
    'adam',
    'optimizers'
]

classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.9',
]

dependencies = [
    "torch>=2.0",
]

[project.urls]
Homepage = "https://pypi.org/project/adam_atan2_pytorch/"
Repository = "https://github.com/lucidrains/adam_atan2_pytorch"

[project.optional-dependencies]
examples = []
test = [
    "pytest"
]

[tool.pytest.ini_options]
pythonpath = [
    "."
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = []

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["adam_atan2_pytorch"]
