Merge pull request #94 from kozistr/refactor/lr_scheduler
[Feature] Implement GSAM optimizer
kozistr authored Jan 24, 2023
2 parents 8a31b1e + a29fd3d commit f6baa63
Showing 42 changed files with 720 additions and 38 deletions.
4 changes: 4 additions & 0 deletions README.rst
@@ -112,6 +112,8 @@ Supported Optimizers
+--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
| Adai | *Disentangling the Effects of Adaptive Learning Rate and Momentum* | `github <https://github.com/zeke-xie/adaptive-inertia-adai>`__ | `https://arxiv.org/abs/2006.15815 <https://arxiv.org/abs/2006.15815>`__ |
+--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
| GSAM | *Surrogate Gap Guided Sharpness-Aware Minimization* | `github <https://github.com/juntang-zhuang/GSAM>`__ | `https://openreview.net/pdf?id=edONMAnhLu- <https://openreview.net/pdf?id=edONMAnhLu->`__ |
+--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+

Useful Resources
----------------
@@ -303,6 +305,8 @@ Citations

`Adai <https://github.com/zeke-xie/adaptive-inertia-adai#citing>`__

`GSAM <https://github.com/juntang-zhuang/GSAM#citation>`__

Citation
--------

12 changes: 10 additions & 2 deletions docs/optimizer_api.rst
@@ -1,5 +1,5 @@
Implemented Optimizers
====================
Optimizers
==========

.. _AdaBelief:

@@ -192,3 +192,11 @@ Shampoo

.. autoclass:: pytorch_optimizer.Shampoo
:members:

.. _GSAM:

GSAM
----

.. autoclass:: pytorch_optimizer.GSAM
:members:
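
To show how the pieces introduced in this commit fit together, here is a hedged training-loop sketch in the style of the upstream juntang-zhuang/GSAM repository. The GSAM constructor arguments and the set_closure / step / update_rho_t methods are assumptions carried over from that upstream implementation and are not confirmed by this diff.

```python
import torch
import torch.nn.functional as F
from torch import nn

from pytorch_optimizer import GSAM, CosineScheduler, ProportionScheduler

model = nn.Linear(10, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0.9)

# linear-warmup cosine lr schedule added in this commit
lr_scheduler = CosineScheduler(base_optimizer, t_max=100, max_lr=1e-1, min_lr=1e-3, warmup_steps=10)
# rho follows the lr proportionally (GSAM's rho schedule, also added in this commit)
rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=1e-1, min_lr=1e-3, max_value=0.1, min_value=0.01)

# constructor arguments below are assumptions based on the upstream GSAM repository
optimizer = GSAM(model.parameters(), base_optimizer=base_optimizer, model=model, rho_scheduler=rho_scheduler, alpha=0.4)

for inputs, targets in [(torch.randn(8, 10), torch.randint(0, 2, (8,)))] * 3:
    optimizer.set_closure(F.cross_entropy, inputs, targets)  # assumed API: closure re-evaluates the loss on perturbed weights
    predictions, loss = optimizer.step()
    lr_scheduler.step()
    optimizer.update_rho_t()                                  # assumed API: pull the next rho from the scheduler
```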
4 changes: 2 additions & 2 deletions docs/scheduler_api.rst
@@ -1,5 +1,5 @@
Implemented LR Schedulers
=========================
LR Schedulers
=============

.. _get_chebyshev_schedule:

21 changes: 19 additions & 2 deletions docs/util_api.rst
@@ -1,5 +1,5 @@
Implemented utilizations
========================
Utilizations
============

.. _clip_grad_norm:

@@ -56,3 +56,20 @@ SafeFP16Optimizer

.. autoclass:: pytorch_optimizer.SafeFP16Optimizer
:members:

.. _enable_running_stats:

enable_running_stats
--------------------

.. autoclass:: pytorch_optimizer.enable_running_stats
:members:


.. _disable_running_stats:

disable_running_stats
---------------------

.. autoclass:: pytorch_optimizer.disable_running_stats
:members:
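
For context, these two helpers are the usual BatchNorm switches used around the two forward passes of SAM/GSAM-style training. A minimal sketch, assuming they take the model as their only argument (this signature follows the SAM reference implementation and is not shown in this diff):

```python
import torch
from torch import nn

from pytorch_optimizer import disable_running_stats, enable_running_stats

model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
criterion = nn.CrossEntropyLoss()
inputs, targets = torch.randn(8, 10), torch.randint(0, 2, (8,))

enable_running_stats(model)                   # first pass updates BN running statistics
criterion(model(inputs), targets).backward()

disable_running_stats(model)                  # second pass (on perturbed weights) must not
criterion(model(inputs), targets).backward()  # update BN running statistics again
```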
9 changes: 7 additions & 2 deletions pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "2.1.1"
description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
version = "2.2.0"
description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
maintainers = ["kozistr <[email protected]>"]
@@ -51,6 +51,11 @@ name = "torch"
url = "https://download.pytorch.org/whl/cpu"
secondary = true

[tool.coverage.run]
omit = [
"./pytorch_optimizer/optimizer/gsam.py",
]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
9 changes: 9 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -11,6 +11,8 @@
)
from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_schedule
from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
from pytorch_optimizer.optimizer.adabelief import AdaBelief
from pytorch_optimizer.optimizer.adabound import AdaBound
from pytorch_optimizer.optimizer.adai import Adai
@@ -22,6 +24,7 @@
from pytorch_optimizer.optimizer.diffrgrad import DiffRGrad
from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.gsam import GSAM
from pytorch_optimizer.optimizer.lamb import Lamb
from pytorch_optimizer.optimizer.lars import LARS
from pytorch_optimizer.optimizer.lookahead import Lookahead
@@ -38,6 +41,8 @@
from pytorch_optimizer.optimizer.shampoo import Shampoo
from pytorch_optimizer.optimizer.utils import (
clip_grad_norm,
disable_running_stats,
enable_running_stats,
get_optimizer_parameters,
matrix_power,
normalize_gradient,
@@ -74,6 +79,10 @@
CosineAnnealingWarmRestarts,
CyclicLR,
OneCycleLR,
CosineScheduler,
PolyScheduler,
LinearScheduler,
ProportionScheduler,
]
LR_SCHEDULERS: Dict[str, SCHEDULER] = {
str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
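
The registry above keys each scheduler class by its lowercased class name, so the newly added schedulers can also be resolved by string. A minimal sketch, assuming LR_SCHEDULERS is importable from the package root as the __init__.py above defines it:

```python
from pytorch_optimizer import LR_SCHEDULERS

print(sorted(LR_SCHEDULERS.keys()))               # includes 'cosinescheduler', 'proportionscheduler', ...
scheduler_cls = LR_SCHEDULERS['linearscheduler']  # -> LinearScheduler
```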
18 changes: 18 additions & 0 deletions pytorch_optimizer/base/exception.py
@@ -25,3 +25,21 @@ class NoClosureError(Exception):
def __init__(self, optimizer_name: str):
self.message: str = f'[-] {optimizer_name} requires closure.'
super().__init__(self.message)


class NegativeLRError(Exception):
"""Raised when learning rate is negative"""

def __init__(self, lr: float, lr_type: str = ''):
self.note: str = 'learning rate' if lr_type == '' else lr_type
self.message: str = f'[-] {self.note} must be positive. ({lr} > 0)'
super().__init__(self.message)


class NegativeStepError(Exception):
"""Raised when step is negative"""

def __init__(self, num_steps: int, step_type: str = ''):
self.note: str = 'step' if step_type == '' else step_type
self.message: str = f'[-] {self.note} must be positive. ({num_steps} > 0)'
super().__init__(self.message)
pytorch_optimizer/base/{base_optimizer.py → optimizer.py}
@@ -2,14 +2,15 @@

import torch

from pytorch_optimizer.base.exception import NegativeLRError
from pytorch_optimizer.base.types import BETAS


class BaseOptimizer(ABC):
@staticmethod
def validate_learning_rate(learning_rate: float):
if learning_rate < 0.0:
raise ValueError(f'[-] learning rate {learning_rate} must be positive')
raise NegativeLRError(learning_rate)

@staticmethod
def validate_beta(beta: float):
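
A small sketch of the new validation path: a negative learning rate now raises NegativeLRError instead of a bare ValueError. The module path pytorch_optimizer.base.optimizer is an assumption based on the renamed imports elsewhere in this commit.

```python
from pytorch_optimizer.base.exception import NegativeLRError
from pytorch_optimizer.base.optimizer import BaseOptimizer

try:
    BaseOptimizer.validate_learning_rate(-1e-3)
except NegativeLRError as e:
    print(e)  # [-] learning rate must be positive. (-0.001 > 0)
```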
91 changes: 91 additions & 0 deletions pytorch_optimizer/base/scheduler.py
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from typing import List

from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
from pytorch_optimizer.base.types import OPTIMIZER


class BaseLinearWarmupScheduler(ABC):
r"""BaseLinearWarmupScheduler class. The LR Scheduler class based on this class has linear warmup strategy.
:param optimizer: Optimizer. OPTIMIZER. It will set learning rate to all trainable parameters in optimizer.
:param t_max: int. total steps to train.
:param max_lr: float. maximum lr.
:param min_lr: float. minimum lr.
:param init_lr: float. initial lr.
:param warmup_steps: int. steps to warm-up.
"""

def __init__(
self,
optimizer: OPTIMIZER,
t_max: int,
max_lr: float,
min_lr: float = 0.0,
init_lr: float = 0.0,
warmup_steps: int = 0,
):
self.optimizer = optimizer
self.total_steps = t_max
self.max_lr = max_lr
self.min_lr = min_lr
self.init_lr = init_lr
self.warmup_steps = warmup_steps

self.step_t: int = 0
self.base_lrs: List[float] = []

        # record the current value in self.last_lr to match the API of torch.optim.lr_scheduler
self.last_lr: List[float] = [init_lr]

self.validate_parameters()

self._init_lr()

def validate_parameters(self):
if self.min_lr < 0:
raise NegativeLRError(self.min_lr, 'min_lr')

if self.max_lr < 0:
raise NegativeLRError(self.max_lr, 'max_lr')

if self.init_lr < 0:
raise NegativeLRError(self.init_lr, 'init_lr')

if self.total_steps < 0:
raise NegativeStepError(self.total_steps, 't_max')

if self.warmup_steps < 0:
raise NegativeStepError(self.warmup_steps, 'warmup_steps')

def _init_lr(self):
self.base_lrs = []
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.min_lr
self.base_lrs.append(self.min_lr)

def step(self):
if self.step_t < self.warmup_steps:
value = self.init_lr + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps
elif self.step_t == self.warmup_steps:
value = self.max_lr
else:
value = self._step()

self.step_t += 1

# apply the lr to optimizer if it's provided
if self.optimizer is not None:
for param_group in self.optimizer.param_groups:
param_group['lr'] = value

self.last_lr = [value]

return value

@abstractmethod
def _step(self) -> float:
raise NotImplementedError

def get_lr(self) -> float:
return self.last_lr[0]
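
To illustrate how the base class above is meant to be extended: subclasses implement only _step(), which step() calls once the linear warmup phase has finished. The constant post-warmup schedule below is purely illustrative and not part of this commit.

```python
import torch

from pytorch_optimizer.base.scheduler import BaseLinearWarmupScheduler


class ConstantAfterWarmupScheduler(BaseLinearWarmupScheduler):
    # only the post-warmup behaviour needs to be defined
    def _step(self) -> float:
        return self.max_lr


optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = ConstantAfterWarmupScheduler(optimizer, t_max=10, max_lr=1e-2, init_lr=1e-4, warmup_steps=5)

for _ in range(10):
    print(scheduler.step())  # linear ramp from init_lr to max_lr, then constant max_lr
```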
36 changes: 36 additions & 0 deletions pytorch_optimizer/lr_scheduler/linear_warmup.py
@@ -0,0 +1,36 @@
import math

import numpy as np

from pytorch_optimizer.base.scheduler import BaseLinearWarmupScheduler


class LinearScheduler(BaseLinearWarmupScheduler):
def _step(self) -> float:
return self.max_lr + (self.min_lr - self.max_lr) * (self.step_t - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)


class CosineScheduler(BaseLinearWarmupScheduler):
def _step(self) -> float:
phase: float = (self.step_t - self.warmup_steps) / (self.total_steps - self.warmup_steps) * math.pi
return self.min_lr + (self.max_lr - self.min_lr) * (np.cos(phase) + 1.0) / 2.0


class PolyScheduler(BaseLinearWarmupScheduler):
r"""Poly LR Scheduler
:param: poly_order: float. lr scheduler decreases with steps.
"""

def __init__(self, poly_order: float = 0.5, **kwargs):
self.poly_order = poly_order

if poly_order <= 0:
raise ValueError(f'[-] poly_order must be positive. {poly_order}')

super().__init__(**kwargs)

def _step(self) -> float:
return self.min_lr + (self.max_lr - self.min_lr) * (self.step_t - self.warmup_steps) ** self.poly_order
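
A short usage sketch of the schedulers above with a plain torch optimizer: each warms up linearly from init_lr to max_lr over warmup_steps, then decays toward min_lr by t_max according to its own rule.

```python
import torch

from pytorch_optimizer import CosineScheduler, LinearScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
linear = LinearScheduler(torch.optim.SGD(params, lr=1e-3), t_max=20, max_lr=1e-2, min_lr=1e-4, warmup_steps=5)
cosine = CosineScheduler(torch.optim.SGD(params, lr=1e-3), t_max=20, max_lr=1e-2, min_lr=1e-4, warmup_steps=5)

for step in range(20):
    # both warm up identically for the first 5 steps, then decay linearly vs. along a half cosine
    print(step, round(float(linear.step()), 6), round(float(cosine.step()), 6))
```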
49 changes: 49 additions & 0 deletions pytorch_optimizer/lr_scheduler/proportion.py
@@ -0,0 +1,49 @@
from typing import List


class ProportionScheduler:
r"""ProportionScheduler (Rho Scheduler of GSAM)
    This scheduler outputs a value that evolves in proportion to the learning rate produced by lr_scheduler.
:param lr_scheduler: learning rate scheduler.
:param max_lr: float. maximum lr.
:param min_lr: float. minimum lr.
:param max_value: float. maximum of rho.
:param min_value: float. minimum of rho.
"""

def __init__(
self, lr_scheduler, max_lr: float, min_lr: float = 0.0, max_value: float = 2.0, min_value: float = 2.0
):
self.lr_scheduler = lr_scheduler
self.max_lr = max_lr
self.min_lr = min_lr
self.max_value = max_value
self.min_value = min_value

self.step_t: int = 0
self.last_lr: List[float] = []

self.step()

def get_lr(self) -> float:
return self.last_lr[0]

def step(self) -> float:
self.step_t += 1

if hasattr(self.lr_scheduler, 'last_lr'):
lr = self.lr_scheduler.last_lr[0]
else:
lr = self.lr_scheduler.optimizer.param_groups[0]['lr']

if self.max_lr > self.min_lr:
value = self.min_value + (self.max_value - self.min_value) * (lr - self.min_lr) / (
self.max_lr - self.min_lr
)
else:
value = self.max_value

self.last_lr = [value]

return value
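
A sketch of the rho schedule in isolation: ProportionScheduler linearly maps where the current learning rate sits between min_lr and max_lr onto the [min_value, max_value] range, so rho shrinks together with the learning rate. The pairing with CosineScheduler below is only an example.

```python
import torch

from pytorch_optimizer import CosineScheduler, ProportionScheduler

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
lr_scheduler = CosineScheduler(optimizer, t_max=20, max_lr=1e-2, min_lr=1e-4, warmup_steps=5)
rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=1e-2, min_lr=1e-4, max_value=0.1, min_value=0.01)

for _ in range(20):
    lr = lr_scheduler.step()    # updates lr_scheduler.last_lr, which rho_scheduler reads
    rho = rho_scheduler.step()
    print(f'lr={lr:.5f} rho={rho:.5f}')
```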
2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/adabelief.py
@@ -3,8 +3,8 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.base_optimizer import BaseOptimizer
from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/adabound.py
@@ -4,8 +4,8 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.base_optimizer import BaseOptimizer
from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/adai.py
@@ -3,8 +3,8 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.base_optimizer import BaseOptimizer
from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.gc import centralize_gradient

2 changes: 1 addition & 1 deletion pytorch_optimizer/optimizer/adamp.py
@@ -3,8 +3,8 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.base_optimizer import BaseOptimizer
from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.utils import projection