[Feature] Cautious optimizer, improve the stability of ADOPT optimizer, a new projector type random for GaLore optimizer #294

Merged · 13 commits · Nov 27, 2024
README.md (1 addition, 0 deletions)
@@ -183,6 +183,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| SOAP | *Improving and Stabilizing Shampoo using Adam* | [github](https://github.com/nikhilvyas/SOAP) | <https://arxiv.org/abs/2409.11321> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240911321V/exportcitation) |
| ADOPT | *Modified Adam Can Converge with Any β2 with the Optimal Rate* | [github](https://github.com/iShohei220/adopt) | <https://arxiv.org/abs/2411.02853> | [cite](https://github.com/iShohei220/adopt?tab=readme-ov-file#citation) |
| FTRL | *Follow The Regularized Leader* | | <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf> | |
| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |

## Supported LR Scheduler

docs/changelogs/v3.3.0.md (6 additions, 0 deletions)
@@ -8,6 +8,12 @@
    * [Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853)
* Implement `FTRL` optimizer. (#291)
    * [Follow The Regularized Leader](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
* Implement `Cautious optimizer` feature. (#294)
    * [Improving Training with One Line of Code](https://arxiv.org/pdf/2411.16085v1)
    * You can use it by setting `cautious=True` for the `Lion`, `AdaFactor`, and `AdEMAMix` optimizers.
* Improve the stability of the `ADOPT` optimizer. (#294)
    * [Note](https://github.com/iShohei220/adopt?tab=readme-ov-file#update-on-nov-22-2024)
* Support a new projection type `random` for `GaLoreProjector`. (#294)

### Refactor

docs/index.md (1 addition, 0 deletions)
@@ -183,6 +183,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| SOAP | *Improving and Stabilizing Shampoo using Adam* | [github](https://github.com/nikhilvyas/SOAP) | <https://arxiv.org/abs/2409.11321> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240911321V/exportcitation) |
| ADOPT | *Modified Adam Can Converge with Any β2 with the Optimal Rate* | [github](https://github.com/iShohei220/adopt) | <https://arxiv.org/abs/2411.02853> | [cite](https://github.com/iShohei220/adopt?tab=readme-ov-file#citation) |
| FTRL | *Follow The Regularized Leader* | | <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf> | |
| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |

## Supported LR Scheduler

pytorch_optimizer/base/optimizer.py (11 additions, 0 deletions)
@@ -255,6 +255,17 @@ def approximate_sq_grad(
c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
torch.mul(r_factor, c_factor, out=output)

@staticmethod
def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
r"""Apply the Cautious Optimizer feature.

:param update: torch.Tensor. update. it will be masked in-place.
:param grad: torch.Tensor. gradient.
"""
mask = (update * grad > 0).to(grad.dtype)
mask.mul_(mask.numel() / (mask.sum() + 1))
update.mul_(mask)

@staticmethod
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
if range_type == '[)' and not low <= x < high:
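To make the masking concrete, here is a small hand-computed illustration of `apply_cautious` (not part of the diff; the tensors are arbitrary):

```python
import torch

update = torch.tensor([0.5, -0.2, 0.1])
grad = torch.tensor([1.0, 0.3, -0.4])

# keep only the components where the update and the gradient agree in sign,
# then rescale the surviving entries to compensate for the dropped ones
mask = (update * grad > 0).to(grad.dtype)   # [1., 0., 0.]: only the first pair agrees in sign
mask.mul_(mask.numel() / (mask.sum() + 1))  # scale by 3 / (1 + 1) = 1.5
update.mul_(mask)                           # first entry becomes 0.75, the rest are zeroed
```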
pytorch_optimizer/optimizer/adafactor.py (6 additions, 1 deletion)
@@ -30,6 +30,7 @@ class AdaFactor(BaseOptimizer):
:param momentum_dtype: torch.dtype. type of the momentum variable. The ViT paper observed that storing momentum in
half precision (bfloat16) does not affect training dynamics or the outcome, while reducing the optimizer
overhead from 2-fold to 1.5-fold.
:param cautious: bool. whether to use the Cautious variant.
"""

def __init__(
@@ -49,6 +50,7 @@ def __init__(
eps1: float = 1e-30,
eps2: float = 1e-3,
momentum_dtype: torch.dtype = torch.bfloat16,
cautious: bool = False,
**kwargs,
):
self.validate_learning_rate(lr)
@@ -62,6 +64,7 @@ def __init__(
self.eps1 = eps1
self.eps2 = eps2
self.momentum_dtype = momentum_dtype
self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
@@ -214,7 +217,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg = state['exp_avg']
exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)

-update = exp_avg
+update = exp_avg.clone()
+if self.cautious:
+    self.apply_cautious(update, grad)

self.apply_weight_decay(
p=p,
pytorch_optimizer/optimizer/ademamix.py (12 additions, 6 deletions)
@@ -19,6 +19,7 @@ class AdEMAMix(BaseOptimizer):
:param fixed_decay: bool. fix weight decay.
:param alpha: float. values between 4 and 10 usually work well.
:param t_alpha_beta3: Optional[float]. the total number of training iterations is preferred when it is needed.
:param cautious: bool. whether to use the Cautious variant.
:param eps: float. term added to the denominator to improve numerical stability.
"""

@@ -32,6 +33,7 @@ def __init__(
fixed_decay: bool = False,
alpha: float = 5.0,
t_alpha_beta3: Optional[float] = None,
cautious: bool = False,
eps: float = 1e-8,
**kwargs,
):
@@ -42,6 +44,8 @@
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
@@ -71,9 +75,7 @@ def reset(self):

@staticmethod
def schedule_alpha(t_alpha_beta3: Optional[float], step: int, alpha: float) -> float:
-if t_alpha_beta3 is None:
-    return alpha
-return min(step * alpha / t_alpha_beta3, alpha)
+return alpha if t_alpha_beta3 is None else min(step * alpha / t_alpha_beta3, alpha)

@staticmethod
def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta3: float) -> float:
@@ -107,6 +109,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

step_size: float = group['lr'] / bias_correction1

alpha_t: float = self.schedule_alpha(group['t_alpha_beta3'], group['step'], group['alpha'])
beta3_t: float = self.schedule_beta3(group['t_alpha_beta3'], group['step'], beta1, beta3)

@@ -140,10 +144,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1.0 - beta3_t)

-de_nom = (exp_avg_sq.sqrt() / bias_correction2_sq).add_(group['eps'])
+de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

-step_size = group['lr'] / bias_correction1
+update = (exp_avg + alpha_t * exp_avg_slow).div_(de_nom)
+if self.cautious:
+    self.apply_cautious(update, grad)

-p.addcdiv_(exp_avg + alpha_t * exp_avg_slow, de_nom, value=-step_size)
+p.add_(update, alpha=-step_size)

return loss
pytorch_optimizer/optimizer/adopt.py (13 additions, 4 deletions)
@@ -1,3 +1,6 @@
import math
from typing import Callable, Optional

import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
@@ -22,6 +25,7 @@ def __init__(
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.9, 0.9999),
clip_lambda: Optional[Callable[[float], float]] = lambda step: math.pow(step, 0.25),
weight_decay: float = 0.0,
weight_decouple: bool = False,
fixed_decay: bool = False,
@@ -33,6 +37,8 @@ def __init__(
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

self.clip_lambda = clip_lambda

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
@@ -104,10 +110,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)

de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])
-if group['step'] == 2:
-    exp_avg.addcdiv_(grad, de_nom)
-else:
-    exp_avg.mul_(beta1).addcdiv_(grad, de_nom, value=1.0 - beta1)

+normed_grad = grad.div(de_nom)
+if self.clip_lambda is not None:
+    clip = self.clip_lambda(group['step'])
+    normed_grad.clamp_(-clip, clip)

+exp_avg.lerp_(normed_grad, weight=1.0 - beta1)

p.add_(exp_avg, alpha=-group['lr'])

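For readers who want the stabilized update in one place, here is a standalone sketch of the logic above (not the library class; the function name, tensor arguments, and hyper-parameter defaults are illustrative):

```python
import math

import torch


def adopt_update(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3, beta1=0.9, eps=1e-6,
                 clip_lambda=lambda step: math.pow(step, 0.25)):
    """One stabilized ADOPT-style parameter update, mirroring the diff above."""
    de_nom = exp_avg_sq.sqrt().clamp_(min=eps)

    # normalize the gradient by the second-moment estimate, then clip it to
    # [-step ** 0.25, step ** 0.25] instead of special-casing the first steps
    normed_grad = grad.div(de_nom)
    if clip_lambda is not None:
        clip = clip_lambda(step)
        normed_grad.clamp_(-clip, clip)

    exp_avg.lerp_(normed_grad, weight=1.0 - beta1)
    p.add_(exp_avg, alpha=-lr)


p = torch.ones(3)
grad = torch.tensor([0.1, -0.2, 0.3])
exp_avg, exp_avg_sq = torch.zeros(3), torch.full((3,), 0.04)
adopt_update(p, grad, exp_avg, exp_avg_sq, step=1)
```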
pytorch_optimizer/optimizer/galore.py (19 additions, 3 deletions)
@@ -7,7 +7,7 @@
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS

-PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full']
+PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full', 'random']


class GaLoreProjector:
@@ -16,8 +16,8 @@ class GaLoreProjector:
:param rank: int. low rank to project.
:param update_proj_gap: int. num steps to update the projection.
:param scale: float. scale factor.
-:param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are
-    supported.
+:param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' and
+    'random' are supported.
"""

def __init__(
@@ -101,6 +101,14 @@ def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor
self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full')
return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t()

def get_low_rank_grad_random(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
is_right: bool = grad.size(0) >= grad.size(1)
if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
grad, self.rank, projection_type='right' if is_right else 'left'
)
return torch.matmul(grad, self.ortho_matrix.t()) if is_right else torch.matmul(self.ortho_matrix.t(), grad)

def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
if self.projection_type == 'std':
return self.get_low_rank_grad_std(full_rank_grad, steps)
@@ -112,6 +120,8 @@ def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
return self.get_low_rank_grad_left(full_rank_grad, steps)
if self.projection_type == 'full':
return self.get_low_rank_grad_full(full_rank_grad, steps)
if self.projection_type == 'random':
return self.get_low_rank_grad_random(full_rank_grad, steps)
raise NotImplementedError

def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
@@ -133,6 +143,12 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
if self.projection_type == 'full':
return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale
if self.projection_type == 'random':
return (
torch.matmul(low_rank_grad, self.ortho_matrix.t())
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]
else torch.matmul(self.ortho_matrix, low_rank_grad)
) * self.scale

raise NotImplementedError

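A minimal way to try the new projection type, mirroring the test case added below (assuming `GaLore` is importable from the top-level `pytorch_optimizer` package as in earlier releases; the hyper-parameter values are illustrative, not recommendations):

```python
import torch
from pytorch_optimizer import GaLore

model = torch.nn.Linear(16, 16)

optimizer = GaLore(
    model.parameters(),
    lr=1e-2,
    weight_decay=1e-3,
    rank=2,
    scale=1.0,
    update_proj_gap=1,
    projection_type='random',  # new in this PR
)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
```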
pytorch_optimizer/optimizer/lion.py (6 additions, 0 deletions)
@@ -18,6 +18,7 @@ class Lion(BaseOptimizer):
:param use_gc: bool. use gradient centralization.
:param r: float. EMA factor. values between 0.9 and 0.99 are preferred.
:param adanorm: bool. whether to use the AdaNorm variant.
:param cautious: bool. whether to use the Cautious variant.
"""

def __init__(
@@ -31,13 +32,15 @@ def __init__(
use_gc: bool = False,
r: float = 0.95,
adanorm: bool = False,
cautious: bool = False,
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')

self.use_gc = use_gc
self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
@@ -114,6 +117,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
update.mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
exp_avg.mul_(beta2).add_(s_grad, alpha=1.0 - beta2)

if self.cautious:
self.apply_cautious(update, grad)

p.add_(update, alpha=-group['lr'])

return loss
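As a quick usage sketch for the new flag (assuming the optimizer classes are importable from the top-level `pytorch_optimizer` package as in earlier releases; the hyper-parameter values are placeholders):

```python
import torch
from pytorch_optimizer import Lion

model = torch.nn.Linear(8, 8)

# one extra keyword argument enables the Cautious variant
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-3, cautious=True)

loss = model(torch.randn(2, 8)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

`AdaFactor` and `AdEMAMix` accept the same `cautious=True` keyword.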
tests/constants.py (10 additions, 2 deletions)
@@ -375,6 +375,7 @@
(AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120),
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'cautious': True}, 70),
(AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40),
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
@@ -383,6 +384,7 @@
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 5),
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 10),
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'cautious': True}, 5),
(AliG, {'max_lr': 5e-1, 'momentum': 0.9}, 5),
(AliG, {'max_lr': 5e-1, 'momentum': 0.9, 'adjusted_momentum': True}, 5),
(SM3, {'lr': 5e-1, 'momentum': 0.9, 'beta': 0.9}, 5),
@@ -469,6 +471,11 @@
{'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'},
5,
),
(
GaLore,
{'lr': 1e0, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 1, 'projection_type': 'random'},
5,
),
(Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
@@ -478,8 +485,9 @@
(Kate, {'lr': 5e-2}, 10),
(StableAdamW, {'lr': 1e0}, 5),
(AdamG, {'lr': 1e0}, 20),
-(AdEMAMix, {'lr': 1e0}, 5),
-(AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 5),
+(AdEMAMix, {'lr': 1e0}, 3),
+(AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 3),
+(AdEMAMix, {'lr': 1e0, 'cautious': True}, 2),
(
SOAP,
{'lr': 1e0, 'shampoo_beta': 0.95, 'precondition_frequency': 1, 'merge_dims': False, 'precondition_1d': True},