[Refactor] code #261

Merged
14 commits merged on Jul 21, 2024
4 changes: 2 additions & 2 deletions Makefile
@@ -16,8 +16,8 @@ check:
ruff check pytorch_optimizer examples tests hubconf.py

requirements:
python -m poetry export -f requirements.txt --output requirements.txt --without-hashes
python -m poetry export -f requirements.txt --output requirements-dev.txt --without-hashes --with dev
poetry export -f requirements.txt --output requirements.txt --without-hashes
poetry export -f requirements.txt --output requirements-dev.txt --without-hashes --with dev

docs:
mkdocs serve
3 changes: 3 additions & 0 deletions docs/changelogs/v3.1.0.md
@@ -11,13 +11,16 @@
* `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`.
* Improve `power_iteration()` speed up to 40%. (#259)
* Improve `reg_noise()` (E-MCMC) speed up to 120%. (#260)
* Support `disable_lr_scheduler` parameter for `Ranger21` optimizer to disable built-in learning rate scheduler. (#261)
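
A hedged usage sketch of the new flag, not taken from this PR's tests: the import path and the `num_iterations` argument are assumed from the library's existing `Ranger21` API, and `disable_lr_scheduler=True` simply keeps `lr` fixed instead of applying the built-in warm-up/warm-down schedule.

```python
import torch

from pytorch_optimizer import Ranger21  # assumed top-level import path

model = torch.nn.Linear(10, 1)

# `num_iterations` is assumed to be part of Ranger21's existing constructor;
# `disable_lr_scheduler=True` is the new flag added in #261.
optimizer = Ranger21(
    model.parameters(),
    num_iterations=1_000,
    lr=1e-3,
    disable_lr_scheduler=True,
)
```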

### Refactor

* Refactor `AdamMini` optimizer. (#258)
* Deprecate optional dependency, `bitsandbytes`. (#258)
* Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258)
* Refactor `shampoo_utils.py`. (#259)
* Add `debias`, `debias_beta` methods in `BaseOptimizer`. (#261)
* Refactor to use `BaseOptimizer` only, not inherit multiple classes. (#261)

### Bug

48 changes: 38 additions & 10 deletions pytorch_optimizer/base/optimizer.py
@@ -3,17 +3,21 @@
from typing import List, Optional, Tuple, Union

import torch
from torch.optim import Optimizer

from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G, PARAMETERS, STATE
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS, STATE


class BaseOptimizer(ABC):
r"""Base optimizer class."""
class BaseOptimizer(ABC, Optimizer):
r"""Base optimizer class. Provides common functionalities for the optimizers."""

def __init__(self, params: PARAMETERS, defaults: DEFAULTS) -> None:
super().__init__(params, defaults)

@staticmethod
@torch.no_grad()
def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None:
r"""Set hessian to state from external source. Generally useful when using functorch as a base.

Example:
@@ -45,7 +49,7 @@ def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tens
i += 1

@staticmethod
def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True):
def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True) -> None:
r"""Zero-out hessian.

:param param_groups: PARAMETERS. parameter groups.
@@ -68,7 +72,7 @@ def compute_hutchinson_hessian(
num_samples: int = 1,
alpha: float = 1.0,
distribution: HUTCHINSON_G = 'gaussian',
):
) -> None:
r"""Hutchinson's approximate hessian, added to the state under key `hessian`.

:param param_groups: PARAMETERS. parameter groups.
@@ -110,7 +114,7 @@ def apply_weight_decay(
weight_decouple: bool,
fixed_decay: bool,
ratio: Optional[float] = None,
):
) -> None:
r"""Apply weight decay.

:param p: torch.Tensor. parameter.
@@ -145,6 +149,27 @@ def apply_ams_bound(

return de_nom.sqrt_().add_(eps)

@staticmethod
def debias(beta: float, step: int) -> float:
r"""Adam-style debias correction. Returns `1.0 - beta ** step`.

:param beta: float. beta.
:param step: int. number of step.
"""
return 1.0 - math.pow(beta, step) # fmt: skip

@staticmethod
def debias_beta(beta: float, step: int) -> float:
r"""Apply the Adam-style debias correction into beta.

Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`

:param beta: float. beta.
:param step: int. number of step.
"""
beta_n: float = math.pow(beta, step)
return (beta_n - beta) / (beta_n - 1.0) # fmt: skip
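
A quick standalone check of the simplification claimed in the `debias_beta` docstring; the helper body is copied from the diff above, and the reference value is the unsimplified formula.

```python
import math


def debias_beta(beta: float, step: int) -> float:
    beta_n: float = math.pow(beta, step)
    return (beta_n - beta) / (beta_n - 1.0)


beta, step = 0.9, 10
simplified = debias_beta(beta, step)
reference = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)
assert math.isclose(simplified, reference)  # both yield the debiased beta for this step
```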

@staticmethod
def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
r"""Apply AdamD variant.
@@ -205,14 +230,14 @@ def get_adanorm_gradient(
:param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
:param r: float. Optional[float]. momentum (ratio).
"""
if not adanorm:
if not adanorm or exp_grad_norm is None:
return grad

grad_norm = torch.linalg.norm(grad)

exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad
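
Two behavioral details of this change are easy to miss: the early return now also fires when `exp_grad_norm is None`, and the scaling uses an out-of-place `grad.mul(...)` so the caller's gradient tensor is never mutated. A minimal sketch of the same scaling logic as a free function, with illustrative names:

```python
import torch


def adanorm_gradient(grad: torch.Tensor, exp_grad_norm: torch.Tensor, r: float = 0.95) -> torch.Tensor:
    # maintain an exponential moving average of the gradient norm (in-place on the state tensor)
    grad_norm = torch.linalg.norm(grad)
    exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

    # when the current norm is below the running average, rescale the gradient up to that average;
    # grad.mul(...) allocates a new tensor, leaving the original grad untouched
    return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad
```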

@staticmethod
def get_rms(x: torch.Tensor) -> float:
@@ -299,5 +324,8 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]')

@abstractmethod
def reset(self): # pragma: no cover
def reset(self) -> None: # pragma: no cover
raise NotImplementedError

def step(self, closure: CLOSURE = None) -> LOSS: # pragma: no cover
raise NotImplementedError
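
With `BaseOptimizer` now deriving from `torch.optim.Optimizer` itself, a concrete optimizer subclasses `BaseOptimizer` alone and implements `reset` plus `step`. A minimal, hypothetical sketch of what that looks like after this refactor (`PlainSGD` is not part of the PR, and `validate_learning_rate` is assumed to be one of the base class's existing validators):

```python
import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class PlainSGD(BaseOptimizer):
    r"""Hypothetical example: plain SGD on top of the refactored base class."""

    def __init__(self, params: PARAMETERS, lr: float = 1e-3):
        self.validate_learning_rate(lr)  # assumed existing validator on BaseOptimizer

        defaults: DEFAULTS = {'lr': lr}
        super().__init__(params, defaults)  # forwarded to torch.optim.Optimizer

    def reset(self) -> None:  # the remaining abstract method
        pass

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:  # overrides the NotImplementedError stub
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])

        return loss
```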
3 changes: 1 addition & 2 deletions pytorch_optimizer/optimizer/a2grad.py
@@ -2,14 +2,13 @@
from typing import Optional

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class A2Grad(Optimizer, BaseOptimizer):
class A2Grad(BaseOptimizer):
r"""Optimal Adaptive and Accelerated Stochastic Gradient Descent.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
7 changes: 3 additions & 4 deletions pytorch_optimizer/optimizer/adabelief.py
@@ -1,14 +1,13 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdaBelief(Optimizer, BaseOptimizer):
class AdaBelief(BaseOptimizer):
r"""Adapting Step-sizes by the Belief in Observed Gradients.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -101,8 +100,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

beta1, beta2 = group['betas']

bias_correction1: float = 1.0 - beta1 ** group['step']
bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

step_size, n_sma = self.get_rectify_step_size(
is_rectify=group['rectify'],
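
The same mechanical substitution (`1.0 - beta ** step` → `self.debias(beta, step)`) repeats in the optimizer files below. For context, a hedged sketch of where those corrections sit in a generic, textbook Adam-style update; this is not AdaBelief's exact kernel (AdaBelief tracks `(grad - exp_avg)²` rather than `grad²`):

```python
import math

import torch


def adam_style_update(
    p: torch.Tensor,
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    lr: float,
    beta1: float,
    beta2: float,
    step: int,
    eps: float = 1e-8,
) -> None:
    # exponential moving averages of the gradient and its square
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    # what BaseOptimizer.debias(beta, step) returns
    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2_sq = math.sqrt(1.0 - beta2 ** step)

    # de-biased denominator and step size, then the parameter update
    de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(eps)
    p.addcdiv_(exp_avg, de_nom, value=-(lr / bias_correction1))
```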
7 changes: 3 additions & 4 deletions pytorch_optimizer/optimizer/adabound.py
@@ -2,14 +2,13 @@
from typing import List

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdaBound(Optimizer, BaseOptimizer):
class AdaBound(BaseOptimizer):
r"""Adaptive Gradient Methods with Dynamic Bound of Learning Rate.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -90,8 +89,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

beta1, beta2 = group['betas']

bias_correction1: float = 1.0 - beta1 ** group['step']
bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

final_lr: float = group['final_lr'] * group['lr'] / base_lr
lower_bound: float = final_lr * (1 - 1 / (group['gamma'] * group['step'] + 1))
3 changes: 1 addition & 2 deletions pytorch_optimizer/optimizer/adadelta.py
@@ -1,12 +1,11 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdaDelta(Optimizer, BaseOptimizer):
class AdaDelta(BaseOptimizer):
r"""An Adaptive Learning Rate Method.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
3 changes: 1 addition & 2 deletions pytorch_optimizer/optimizer/adafactor.py
@@ -2,14 +2,13 @@
from typing import Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdaFactor(Optimizer, BaseOptimizer):
class AdaFactor(BaseOptimizer):
r"""Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
7 changes: 3 additions & 4 deletions pytorch_optimizer/optimizer/adahessian.py
@@ -1,14 +1,13 @@
from typing import List, Optional

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS


class AdaHessian(Optimizer, BaseOptimizer):
class AdaHessian(BaseOptimizer):
r"""An Adaptive Second Order Optimizer for Machine Learning.

Requires `loss.backward(create_graph=True)` in order to calculate hessians.
@@ -104,8 +103,8 @@ def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] =

beta1, beta2 = group['betas']

bias_correction1: float = 1.0 - beta1 ** group['step']
bias_correction2: float = 1.0 - beta2 ** group['step']
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2: float = self.debias(beta2, group['step'])

for p in group['params']:
if p.grad is None:
7 changes: 3 additions & 4 deletions pytorch_optimizer/optimizer/adai.py
@@ -1,15 +1,14 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.gc import centralize_gradient


class Adai(Optimizer, BaseOptimizer):
class Adai(BaseOptimizer):
r"""Disentangling the Effects of Adaptive Learning Rate and Momentum.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -105,7 +104,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
if self.use_gc:
centralize_gradient(grad, gc_conv_only=False)

bias_correction2: float = 1.0 - beta2 ** state['step']
bias_correction2: float = self.debias(beta2, state['step'])

if not group['stable_weight_decay'] and group['weight_decay'] > 0.0:
self.apply_weight_decay(
@@ -148,7 +147,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
fixed_decay=group['fixed_decay'],
)

bias_correction2: float = 1.0 - beta2 ** state['step']
bias_correction2: float = self.debias(beta2, state['step'])

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

3 changes: 1 addition & 2 deletions pytorch_optimizer/optimizer/adalite.py
@@ -1,13 +1,12 @@
import torch
from torch.nn.functional import softmax
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Adalite(Optimizer, BaseOptimizer):
class Adalite(BaseOptimizer):
r"""Adalite optimizer.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
7 changes: 3 additions & 4 deletions pytorch_optimizer/optimizer/adam_mini.py
@@ -4,14 +4,13 @@
import torch
from torch import distributed as dist
from torch import nn
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS


class AdamMini(Optimizer, BaseOptimizer): # pragma: no cover
class AdamMini(BaseOptimizer): # pragma: no cover
r"""Use Fewer Learning Rates To Gain More.

:param model: nn.Module. model instance.
@@ -276,8 +275,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

beta1, beta2 = group['betas']

bias_correction1: float = 1.0 - beta1 ** group['step']
bias_correction2: float = 1.0 - beta2 ** group['step']
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2: float = self.debias(beta2, group['step'])
bias_correction2_sq: float = math.sqrt(bias_correction2)

for p in group['params']:
5 changes: 2 additions & 3 deletions pytorch_optimizer/optimizer/adamax.py
@@ -1,12 +1,11 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdaMax(Optimizer, BaseOptimizer):
class AdaMax(BaseOptimizer):
r"""An Adaptive and Momental Bound Method for Stochastic Learning.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -84,7 +83,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

beta1, beta2 = group['betas']

bias_correction1: float = 1.0 - beta1 ** group['step']
bias_correction1: float = self.debias(beta1, group['step'])

for p in group['params']:
if p.grad is None:
7 changes: 3 additions & 4 deletions pytorch_optimizer/optimizer/adamod.py
@@ -1,14 +1,13 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdaMod(Optimizer, BaseOptimizer):
class AdaMod(BaseOptimizer):
r"""An Adaptive and Momental Bound Method for Stochastic Learning.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -78,8 +77,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

beta1, beta2, beta3 = group['betas']

bias_correction1: float = 1.0 - beta1 ** group['step']
bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

for p in group['params']:
if p.grad is None: