Commit 5538b6b

Merge pull request #47 from kozistr/feature/madgrad
[Fix] sparse gradient for MADGRAD
kozistr authored Jan 29, 2022
2 parents afa33ed + d3fc1dc commit 5538b6b
Showing 9 changed files with 34 additions and 18 deletions.
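For context, the behavior this commit pins down can be reproduced in a few lines. The sketch below mirrors the updated tests/test_sparse_gradient.py shown further down; it assumes 'madgrad' is among the names accepted by load_optimizers and that MADGRAD's sparse path is only taken with momentum disabled.

import torch

from pytorch_optimizer import load_optimizers

param = torch.zeros(1, 1, requires_grad=True)
param.grad = torch.randn(1, 1).to_sparse(1)  # sparse COO gradient, as in the tests

# Sparse gradients are expected to work when momentum is disabled ...
optimizer = load_optimizers('madgrad')([param], momentum=0.0)
optimizer.step()

# ... but combining a sparse gradient with weight decay should raise RuntimeError.
try:
    optimizer = load_optimizers('madgrad')([param], momentum=0.0, weight_decay=1e-3)
    optimizer.step()
except RuntimeError as exc:
    print(f'rejected as expected: {exc}')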
4 changes: 2 additions & 2 deletions pytorch_optimizer/adamp.py
@@ -38,9 +38,9 @@ def __init__(
         adamd_debias_term: bool = False,
         eps: float = 1e-8,
     ):
-        """
+        """AdamP optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param delta: float. threshold that determines whether a set of parameters is scale invariant or not
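Not part of the diff, but as a quick illustration of the constructor arguments documented above, a minimal AdamP usage sketch via load_optimizers (hyperparameter values are examples only):

import torch
from torch import nn

from pytorch_optimizer import load_optimizers

model = nn.Linear(2, 1)  # placeholder model for illustration

# AdamP takes the Adam-style arguments described in the docstring above.
optimizer = load_optimizers('adamp')(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
)

x, y = torch.randn(8, 2), torch.randn(8, 1)
nn.functional.mse_loss(model(x), y).backward()
optimizer.step()
optimizer.zero_grad()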
4 changes: 2 additions & 2 deletions pytorch_optimizer/diffgrad.py
@@ -31,9 +31,9 @@ def __init__(
         weight_decay: float = 0.0,
         adamd_debias_term: bool = False,
     ):
-        """
+        """DiffGrad
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
4 changes: 2 additions & 2 deletions pytorch_optimizer/diffrgrad.py
@@ -38,11 +38,11 @@ def __init__(
     ):
         """Blend RAdam with DiffGrad
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: int. (recommended is 5)
-        :param degenerated_to_sgd: bool..
+        :param degenerated_to_sgd: bool. degenerated to SGD
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """
4 changes: 2 additions & 2 deletions pytorch_optimizer/lookahead.py
@@ -31,8 +31,8 @@ def __init__(
         alpha: float = 0.5,
         pullback_momentum: str = 'none',
     ):
-        """
-        :param optimizer: Optimizer.
+        """Lookahead
+        :param optimizer: Optimizer. base optimizer
         :param k: int. number of lookahead steps
         :param alpha: float. linear interpolation factor
         :param pullback_momentum: str. change to inner optimizer momentum on interpolation update
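Because Lookahead wraps a base optimizer rather than taking parameters directly, here is a hedged usage sketch (k is an example value; it assumes the wrapper exposes the usual step/zero_grad interface):

import torch
from torch import nn

from pytorch_optimizer import Lookahead, load_optimizers

model = nn.Linear(2, 1)  # placeholder model for illustration

base_optimizer = load_optimizers('adamp')(model.parameters(), lr=1e-3)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5, pullback_momentum='none')

x, y = torch.randn(8, 2), torch.randn(8, 1)
nn.functional.mse_loss(model(x), y).backward()
optimizer.step()       # slow weights are interpolated toward fast weights every k steps
optimizer.zero_grad()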
4 changes: 2 additions & 2 deletions pytorch_optimizer/sgdp.py
@@ -36,9 +36,9 @@ def __init__(
         wd_ratio: float = 0.1,
         nesterov: bool = False,
     ):
-        """
+        """SGDP optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param momentum: float. momentum factor
         :param dampening: float. dampening for momentum
         :param eps: float. term added to the denominator to improve numerical stability
2 changes: 1 addition & 1 deletion pytorch_optimizer/version.py
@@ -1 +1 @@
-__VERSION__ = '0.3.3'
+__VERSION__ = '0.3.4'
15 changes: 13 additions & 2 deletions tests/test_optimizer_parameters.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from pytorch_optimizer import SAM, load_optimizers
+from pytorch_optimizer import SAM, Lookahead, load_optimizers
 
 OPTIMIZER_NAMES: List[str] = [
     'adamp',
@@ -67,6 +67,17 @@ def test_betas(optimizer_names):
         optimizer(None, betas=(0.1, -0.1))
 
 
-def test_rho():
+def test_sam_parameters():
     with pytest.raises(ValueError):
         SAM(None, load_optimizers('adamp'), rho=-0.1)
+
+
+def test_lookahead_parameters():
+    with pytest.raises(ValueError):
+        Lookahead(load_optimizers('adamp'), k=0)
+
+    with pytest.raises(ValueError):
+        Lookahead(load_optimizers('adamp'), alpha=0)
+
+    with pytest.raises(ValueError):
+        Lookahead(load_optimizers('adamp'), pullback_momentum='asdf')
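These tests only assert that invalid arguments raise ValueError; the checks inside Lookahead presumably look roughly like the hypothetical sketch below (the helper name, the accepted pullback_momentum values, and the messages are assumptions, not taken from the diff):

# Hypothetical validation sketch; the actual implementation may differ.
def validate_lookahead_parameters(k: int, alpha: float, pullback_momentum: str) -> None:
    if k < 1:
        raise ValueError(f'invalid k : {k}')          # k=0 must fail, per the test
    if alpha <= 0.0:
        raise ValueError(f'invalid alpha : {alpha}')  # alpha=0 must fail, per the test
    if pullback_momentum not in ('none', 'reset', 'pullback'):
        raise ValueError(f'invalid pullback_momentum : {pullback_momentum}')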
5 changes: 3 additions & 2 deletions tests/test_optimizers.py
@@ -197,8 +197,9 @@ def test_f16_optimizers(optimizer_fp16_config):
     assert init_loss - 0.01 > loss
 
 
+@pytest.mark.parametrize('adaptive', (False, True))
 @pytest.mark.parametrize('optimizer_sam_config', FP32_OPTIMIZERS, ids=ids)
-def test_sam_optimizers(optimizer_sam_config):
+def test_sam_optimizers(adaptive, optimizer_sam_config):
     torch.manual_seed(42)
 
     x_data, y_data = make_dataset()
@@ -207,7 +208,7 @@ def test_sam_optimizers(optimizer_sam_config):
     loss_fn: nn.Module = nn.BCEWithLogitsLoss()
 
     optimizer_class, config, iterations = optimizer_sam_config
-    optimizer = SAM(model.parameters(), optimizer_class, **config)
+    optimizer = SAM(model.parameters(), optimizer_class, **config, adaptive=adaptive)
 
     loss: float = np.inf
     init_loss: float = np.inf
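The new parametrization runs every SAM configuration with adaptive both off and on (the ASAM-style, scale-aware variant). For reference, a hedged usage sketch, assuming this SAM port follows the common two-step first_step/second_step interface:

import torch
from torch import nn

from pytorch_optimizer import SAM, load_optimizers

model = nn.Linear(2, 1)  # placeholder model for illustration
criterion = nn.MSELoss()
x, y = torch.randn(8, 2), torch.randn(8, 1)

# adaptive=True rescales the perturbation per parameter (ASAM-style).
optimizer = SAM(model.parameters(), load_optimizers('adamp'), lr=1e-3, adaptive=True)

criterion(model(x), y).backward()
optimizer.first_step(zero_grad=True)   # ascend to the locally worst-case weights

criterion(model(x), y).backward()      # second forward/backward at the perturbed point
optimizer.second_step(zero_grad=True)  # descend with the wrapped base optimizer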
10 changes: 7 additions & 3 deletions tests/test_sparse_gradient.py
@@ -31,10 +31,9 @@ def test_sparse_not_supported(no_sparse_optimizer):
     grad = torch.randn(1, 1).to_sparse(1)
     param.grad = grad
 
-    optimizer = load_optimizers(optimizer=no_sparse_optimizer)([param])
-    optimizer.zero_grad()
-
     with pytest.raises(RuntimeError):
+        optimizer = load_optimizers(optimizer=no_sparse_optimizer)([param])
+        optimizer.zero_grad()
         optimizer.step()
 
 
@@ -47,3 +46,8 @@ def test_sparse_supported(sparse_optimizer):
     optimizer = load_optimizers(optimizer=sparse_optimizer)([param], momentum=0.0)
     optimizer.zero_grad()
     optimizer.step()
+
+    with pytest.raises(RuntimeError):
+        optimizer = load_optimizers(optimizer=sparse_optimizer)([param], momentum=0.0, weight_decay=1e-3)
+        optimizer.zero_grad()
+        optimizer.step()
