diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
new file mode 100644
index 0000000..3bfabfc
--- /dev/null
+++ b/.github/workflows/python-publish.yml
@@ -0,0 +1,36 @@
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+jobs:
+  deploy:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.x'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install build
+    - name: Build package
+      run: python -m build
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/README.md b/README.md
index 1f19b12..a37af11 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,41 @@ Implementation of the proposed Adam-a
 
 A multi-million dollar paper out of google deepmind basically proposes a small change to Adam (using `atan2`) for greater stability
 
+## Install
+
+```bash
+$ pip install adam-atan2-pytorch
+```
+
+## Usage
+
+```python
+# toy model
+
+import torch
+from torch import nn
+
+model = nn.Linear(10, 1)
+
+# import AdamAtan2 and instantiate with parameters
+
+from adam_atan2_pytorch import AdamAtan2
+
+opt = AdamAtan2(model.parameters(), lr = 1e-4)
+
+# forward and backwards
+
+for _ in range(100):
+    loss = model(torch.randn(10))
+    loss.backward()
+
+    # optimizer step
+
+    opt.step()
+    opt.zero_grad()
+
+```
+
 ## Citations
 
 ```bibtex
diff --git a/adam_atan2_pytorch/__init__.py b/adam_atan2_pytorch/__init__.py
new file mode 100644
index 0000000..ccb2223
--- /dev/null
+++ b/adam_atan2_pytorch/__init__.py
@@ -0,0 +1 @@
+from adam_atan2_pytorch.adam_atan2 import AdamAtan2
diff --git a/adam_atan2_pytorch/adam_atan2.py b/adam_atan2_pytorch/adam_atan2.py
new file mode 100644
index 0000000..94b123e
--- /dev/null
+++ b/adam_atan2_pytorch/adam_atan2.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+from typing import Tuple, Callable
+
+import torch
+from torch import atan2, sqrt
+from torch.optim.optimizer import Optimizer
+
+# functions
+
+def exists(val):
+    return val is not None
+
+# class
+
+class AdamAtan2(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay = 0.,
+        a = 1.,
+        b = 1.
+    ):
+        assert lr > 0.
+        assert all([0. <= beta <= 1. for beta in betas])
+        assert weight_decay >= 0.
+
+        self._init_lr = lr
+
+        defaults = dict(
+            lr = lr,
+            betas = betas,
+            a = a,
+            b = b,
+            weight_decay = weight_decay
+        )
+
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(
+        self,
+        closure: Callable | None = None
+    ):
+
+        loss = None
+        if exists(closure):
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in filter(lambda p: exists(p.grad), group['params']):
+
+                grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+
+                # decoupled weight decay
+
+                if wd > 0.:
+                    p.mul_(1. - lr / init_lr * wd)
+
+                # init state if needed
+
+                if len(state) == 0:
+                    state['steps'] = 0
+                    state['exp_avg'] = torch.zeros_like(grad)
+                    state['exp_avg_sq'] = torch.zeros_like(grad)
+
+                # get some of the states
+
+                exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
+
+                steps += 1
+
+                # bias corrections
+
+                bias_correct1 = 1. - beta1 ** steps
+                bias_correct2 = 1. - beta2 ** steps
+
+                # decay running averages
+
+                exp_avg.lerp_(grad, 1. - beta1)
+                exp_avg_sq.lerp_(grad * grad, 1. - beta2)
+
+                # the following line is the proposed change to the update rule
+                # using atan2 instead of a division with epsilons - they also suggest hyperparameters `a` and `b` should be explored beyond their defaults of 1.
+
+                update = a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))
+
+                p.add_(update, alpha = -lr)
+
+                # increment steps
+
+                state['steps'] = steps
+
+        return loss
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..be0ec94
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,57 @@
+[project]
+name = "adam-atan2-pytorch"
+version = "0.0.1"
+description = "Adam-atan2 for Pytorch"
+authors = [
+    { name = "Phil Wang", email = "lucidrains@gmail.com" }
+]
+readme = "README.md"
+requires-python = ">= 3.9"
+license = { file = "LICENSE" }
+keywords = [
+    'artificial intelligence',
+    'deep learning',
+    'adam',
+    'optimizers'
+]
+
+classifiers=[
+    'Development Status :: 4 - Beta',
+    'Intended Audience :: Developers',
+    'Topic :: Scientific/Engineering :: Artificial Intelligence',
+    'License :: OSI Approved :: MIT License',
+    'Programming Language :: Python :: 3.9',
+]
+
+dependencies = [
+    "torch>=2.0",
+]
+
+[project.urls]
+Homepage = "https://pypi.org/project/adam_atan2_pytorch/"
+Repository = "https://github.com/lucidrains/adam_atan2_pytorch"
+
+[project.optional-dependencies]
+examples = []
+test = [
+    "pytest"
+]
+
+[tool.pytest.ini_options]
+pythonpath = [
+    "."
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.rye]
+managed = true
+dev-dependencies = []
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.build.targets.wheel]
+packages = ["adam_atan2_pytorch"]
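Note on the core change: the heart of the patch is the single `update = a * atan2(...)` line in `adam_atan2.py`, which replaces Adam's usual epsilon-stabilized division. As a rough standalone sketch of that difference (not part of the diff; the helper names `adam_update` / `adam_atan2_update`, the `eps` value, and the toy tensors below are illustrative assumptions only):

```python
import torch

def adam_update(exp_avg, exp_avg_sq, bias_correct1, bias_correct2, eps = 1e-8):
    # classic Adam: bias-corrected first moment divided by the square root of the
    # bias-corrected second moment, stabilized by a small eps in the denominator
    return (exp_avg / bias_correct1) / (torch.sqrt(exp_avg_sq / bias_correct2) + eps)

def adam_atan2_update(exp_avg, exp_avg_sq, bias_correct1, bias_correct2, a = 1., b = 1.):
    # the patched variant: atan2 replaces the division, so no eps is needed and the
    # result stays bounded even when the second moment underflows toward zero
    return a * torch.atan2(exp_avg / bias_correct1, b * torch.sqrt(exp_avg_sq / bias_correct2))

# toy moments where the second moment is (nearly) zero
exp_avg    = torch.tensor([1e-3, 0.])
exp_avg_sq = torch.tensor([1e-12, 0.])

print(adam_update(exp_avg, exp_avg_sq, 0.1, 0.01))       # grows large as exp_avg_sq -> 0
print(adam_atan2_update(exp_avg, exp_avg_sq, 0.1, 0.01))  # each element bounded by pi / 2
```

Because the second argument to `atan2` is non-negative here, every element of the atan2 update lies in [-pi/2, pi/2], so `a` together with the learning rate caps the per-parameter step and no epsilon hyperparameter is needed.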