Skip to content

Commit

Permalink
Merge pull request #160 from kozistr/feature/adadelta-optimizer
Browse files Browse the repository at this point in the history
[Feature] Implement AdaDelta optimizer
  • Loading branch information
kozistr authored May 6, 2023
2 parents 2e97e5f + 4cc9bd7 commit 4dbfc23
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 21 deletions.
9 changes: 2 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,9 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
restore-keys: ${{ runner.os }}-pip-
cache: 'pip'
- name: Install dependencies
run: pip install -r requirements-dev.txt
run: pip --disable-pip-version-check install --no-compile -r requirements-dev.txt
- name: Check lint
run: make check
- name: Check test
Expand Down
11 changes: 3 additions & 8 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,11 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Cache pip
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
restore-keys: ${{ runner.os }}-pip-
cache: 'pip'
- name: Install dependencies
run: |
python3 -m pip install poetry
python3 -m pip install -r requirements.txt
pip install poetry
pip install -r requirements.txt
- name: Publish package to PyPI
env:
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
Expand Down
4 changes: 3 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pytorch-optimizer

| **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
| I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
| Currently, 49 optimizers, 6 lr schedulers are supported!
| Currently, 50 optimizers, 6 lr schedulers are supported!
|
| Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.
Expand Down Expand Up @@ -211,6 +211,8 @@ You can check the supported optimizers & lr schedulers.
+--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
| AdaShift | *Decorrelation and Convergence of Adaptive Learning Rate Methods* | `github <https://github.com/MichaelKonobeev/adashift>`__ | `https://arxiv.org/abs/1810.00143v4 <https://arxiv.org/abs/1810.00143v4>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2018arXiv181000143Z/exportcitation>`__ |
+--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
| AdaDelta | *An Adaptive Learning Rate Method* | | `https://arxiv.org/abs/1212.5701v1 <https://arxiv.org/abs/1212.5701v1>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2012arXiv1212.5701Z/exportcitation>`__ |
+--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+

Useful Resources
----------------
Expand Down
8 changes: 8 additions & 0 deletions docs/optimizer_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,11 @@ AdaShift

.. autoclass:: pytorch_optimizer.AdaShift
:members:

.. _AdaDelta:

AdaDelta
--------

.. autoclass:: pytorch_optimizer.AdaDelta
:members:
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "2.8.0"
version = "2.9.0"
description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
Expand All @@ -9,7 +9,7 @@ readme = "README.rst"
homepage = "https://github.com/kozistr/pytorch_optimizer"
repository = "https://github.com/kozistr/pytorch_optimizer"
documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
keywords = ["pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "Adai", "AdamP", "AdamS", "Adan", "AggMo", "AliG", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PID", "PNM", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SM3", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad"]
keywords = ["pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "Adai", "AdamP", "AdamS", "Adan", "AggMo", "AliG", "Apollo", "AvaGrad", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "LARS", "Lamb", "Lion", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PID", "PNM", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "SGDP", "SGDW", "SM3", "SRMM", "SWATS", "ScalableShampoo", "Shampoo", "Yogi", "SAM", "GSAM", "PCGrad", "RotoGrad"]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 5 - Production/Stable",
Expand Down
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pytorch_optimizer.optimizer.a2grad import A2Grad
from pytorch_optimizer.optimizer.adabelief import AdaBelief
from pytorch_optimizer.optimizer.adabound import AdaBound
from pytorch_optimizer.optimizer.adadelta import AdaDelta
from pytorch_optimizer.optimizer.adafactor import AdaFactor
from pytorch_optimizer.optimizer.adai import Adai
from pytorch_optimizer.optimizer.adamax import AdaMax
Expand Down Expand Up @@ -143,6 +144,7 @@
SRMM,
AvaGrad,
AdaShift,
AdaDelta,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

Expand Down
106 changes: 106 additions & 0 deletions pytorch_optimizer/optimizer/adadelta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdaDelta(Optimizer, BaseOptimizer):
r"""An Adaptive Learning Rate Method.
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param rho: float. coefficient used for computing a running average of squared gradients.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param fixed_decay: bool. fix weight decay.
:param eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1.0,
rho: float = 0.9,
weight_decay: float = 0.0,
weight_decouple: bool = False,
fixed_decay: bool = False,
eps: float = 1e-6,
):
self.validate_learning_rate(lr)
self.validate_range(rho, 'rho', 0.0, 1.0)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

defaults: DEFAULTS = {
'lr': lr,
'rho': rho,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'fixed_decay': fixed_decay,
'eps': eps,
}
super().__init__(params, defaults)

def __str__(self) -> str:
return 'AdaDelta'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

state['square_avg'] = torch.zeros_like(p)
state['acc_delta'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

rho: float = group['rho']

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(str(self))

state = self.state[p]

if len(state) == 0:
state['square_avg'] = torch.zeros_like(p)
state['acc_delta'] = torch.zeros_like(p)

self.apply_weight_decay(
p=p,
grad=grad,
lr=group['lr'],
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=group['fixed_decay'],
)

square_avg, acc_delta = state['square_avg'], state['acc_delta']
square_avg.mul_(rho).addcmul_(grad, grad, value=1.0 - rho)

std = square_avg.add(group['eps']).sqrt_()
delta = acc_delta.add(group['eps']).sqrt_().div_(std).mul_(grad)

acc_delta.mul_(rho).addcmul_(delta, delta, value=1.0 - rho)
p.add_(delta, alpha=-group['lr'])

return loss
2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/dadapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

# it's not Adam Debias
d_lr: float = self.apply_adam_debias(
group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
not group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
)

sk_l1 = torch.tensor([0.0], device=device)
Expand Down
4 changes: 3 additions & 1 deletion tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AccSGD,
AdaBelief,
AdaBound,
AdaDelta,
AdaFactor,
Adai,
AdaMax,
Expand Down Expand Up @@ -312,7 +313,7 @@
(Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
(DAdaptAdaGrad, {'lr': 3e0, 'weight_decay': 1e-3}, 30),
(DAdaptAdaGrad, {'lr': 5e0, 'weight_decay': 1e-3, 'momentum': 0.1}, 20),
(DAdaptAdam, {'lr': 5e4, 'weight_decay': 1e-1}, 10),
(DAdaptAdam, {'lr': 5e4, 'weight_decay': 1e-3}, 5),
(DAdaptSGD, {'lr': 2e0, 'weight_decay': 1e-3}, 25),
(DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3}, 20),
(DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 20),
Expand Down Expand Up @@ -363,6 +364,7 @@
(SRMM, {'lr': 5e-1}, 5),
(AvaGrad, {'lr': 1e1}, 5),
(AdaShift, {'lr': 1e0, 'keep_num': 1}, 5),
(AdaDelta, {'lr': 5e1}, 5),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_load_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


def test_get_supported_optimizers():
assert len(get_supported_optimizers()) == 49
assert len(get_supported_optimizers()) == 50

0 comments on commit 4dbfc23

Please sign in to comment.