Merge pull request #228 from kozistr/feature/galore-optimizer
[Feature] Implement GaLore optimizer
kozistr authored Apr 7, 2024
2 parents 523f140 + f2d6f14 commit b1b5ed4
Showing 12 changed files with 489 additions and 204 deletions.
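For reference, a minimal usage sketch of the optimizer added here, assuming the public `GaLore` export registered in `pytorch_optimizer/__init__.py` below; the model and hyperparameter values are illustrative only. GaLore-specific settings (`rank`, `update_proj_gap`, `scale`, `projection_type`) are passed per parameter group and read inside `GaLore.step()`:

import torch

from pytorch_optimizer import GaLore

model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10))

optimizer = GaLore(
    [
        {
            'params': model.parameters(),
            'rank': 32,              # low-rank dimension of the projected gradient
            'update_proj_gap': 50,   # refresh the SVD-based projection every 50 steps
            'scale': 1.0,
            'projection_type': 'std',
        }
    ],
    lr=1e-3,
    weight_decay=0.0,
)

loss = model(torch.randn(8, 128)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

Parameters with `p.dim() == 1` (e.g. biases) in such a group fall back to plain AdamW updates, since the projector is only attached to matrices.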
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

**pytorch-optimizer** is a collection of optimizers and LR schedulers for PyTorch.
The algorithms are re-implemented from their original papers, with speed and memory tweaks and plug-ins, and the library also includes useful, practical optimization ideas.
Currently, **62 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
Currently, **63 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -160,6 +160,7 @@ supported_optimizers = get_supported_optimizers()
| CAME | *Confidence-guided Adaptive Memory Efficient Optimization* | [github](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME) | <https://aclanthology.org/2023.acl-long.243/> | [cite](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME#citation) |
| WSAM | *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* | [github](https://github.com/intelligent-machine-learning/dlrover/blob/master/atorch/atorch/optimizers/wsam.py) | <https://arxiv.org/abs/2305.15817> | [cite](https://github.com/intelligent-machine-learning/dlrover) |
| Aida | *A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range* | [github](https://github.com/guoqiang-zhang-x/Aida-Optimizer) | <https://arxiv.org/abs/2203.13273> | [cite](https://github.com/guoqiang-zhang-x/Aida-Optimizer?tab=readme-ov-file#1-brief-description-of-aida) |
| GaLore | *Memory-Efficient LLM Training by Gradient Low-Rank Projection* | [github](https://github.com/jiaweizzhao/GaLore) | <https://arxiv.org/abs/2403.03507> | [cite](https://github.com/jiaweizzhao/GaLore/tree/master?tab=readme-ov-file#citation) |

## Supported LR Scheduler

17 changes: 12 additions & 5 deletions docs/changelogs/v3.0.0.md
@@ -10,21 +10,28 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
* [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273)
* Implement `WSAM` optimizer. (#213, #216)
* [Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term](https://arxiv.org/abs/2305.15817)
* Implement `GaLore` optimizer. (#224, #228)
* [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)

## Dependency
### Fix

* Fix SRMM to allow operation beyond memory_length. (#227)

### Dependency

* Drop `Python 3.7` support officially. (#221)
* Please check the [README](https://github.com/kozistr/pytorch_optimizer?tab=readme-ov-file#getting-started).
* Update `bitsandbytes` to `0.43.0`. (#228)

## Docs
### Docs

* Add missing parameters in `Ranger21 optimizer` document. (#214, #215)
* Fix `WSAM` optimizer paper link. (#219)

### Contributions
## Contributions

thanks to @sdbds
thanks to @sdbds, @i404788

### Diff
## Diff

[2.12.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.12.0...v3.0.0)
4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -132,6 +132,10 @@
    :docstring:
    :members:

::: pytorch_optimizer.GaLoreProjector
    :docstring:
    :members:

::: pytorch_optimizer.centralize_gradient
    :docstring:
    :members:
313 changes: 140 additions & 173 deletions poetry.lock

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions pyproject.toml
@@ -13,11 +13,11 @@ keywords = [
"pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
"AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
"DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO",
"Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM",
"RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
"SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
"LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
"DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS", "Lamb", "Lion",
"LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam",
"QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
"SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
"Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
@@ -45,14 +45,14 @@ classifiers = [
python = ">=3.8,<4.0.0"
numpy = { version = "*", python = ">=3.8" }
torch = { version = ">=1.10", python = ">=3.8", source = "torch" }
bitsandbytes = { version = "^0.42", optional = true }
bitsandbytes = { version = "^0.43", optional = true }

[tool.poetry.dev-dependencies]
isort = { version = "^5", python = ">=3.8" }
black = { version = "^24", python = ">=3.8"}
ruff = "^0.3"
pytest = "^8"
pytest-cov = "^4"
ruff = "*"
pytest = "*"
pytest-cov = "*"

[tool.poetry.extras]
bitsandbytes = ["bitsandbytes"]
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -55,6 +55,7 @@
from pytorch_optimizer.optimizer.diffgrad import DiffGrad
from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
from pytorch_optimizer.optimizer.fromage import Fromage
from pytorch_optimizer.optimizer.galore import GaLore, GaLoreProjector
from pytorch_optimizer.optimizer.gc import centralize_gradient
from pytorch_optimizer.optimizer.gravity import Gravity
from pytorch_optimizer.optimizer.lamb import Lamb
@@ -182,6 +183,7 @@
    CAME,
    DAdaptLion,
    Aida,
    GaLore,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
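The registry keys optimizers by lower-cased class name, so the new entry is also reachable as the string 'galore'; a small illustrative sketch using the `OPTIMIZERS` mapping defined above:

import torch

from pytorch_optimizer import OPTIMIZERS

# 'galore' maps to the GaLore class registered in OPTIMIZER_LIST
optimizer_cls = OPTIMIZERS['galore']
optimizer = optimizer_cls([torch.nn.Parameter(torch.randn(16, 16))], lr=1e-3)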

249 changes: 249 additions & 0 deletions pytorch_optimizer/optimizer/galore.py
@@ -0,0 +1,249 @@
import math
from typing import Literal, Optional, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS

PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full']


class GaLoreProjector:
    r"""Memory-Efficient LLM Training by Gradient Low-Rank Projection.

    :param rank: int. low rank to project.
    :param update_proj_gap: int. num steps to update the projection.
    :param scale: float. scale factor.
    :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are
        supported.
    """

    def __init__(
        self, rank: int = 128, update_proj_gap: int = 50, scale: float = 1.0, projection_type: PROJECTION_TYPE = 'std'
    ):
        self.rank = rank
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.projection_type = projection_type

        self.ortho_matrix: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None

    @staticmethod
    def get_orthogonal_matrix(
        weights: torch.Tensor, rank: int, projection_type: str
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if projection_type not in {'right', 'left', 'full'}:
            raise ValueError('projection_type should be one of left, right or full')

        original_type = weights.data.dtype
        original_device = weights.data.device
        is_float: bool = original_type == torch.float

        # SVD is computed in float32 for stability; factors are cast back to the original dtype and device.
        u, s, vh = torch.linalg.svd(weights if is_float else weights.float(), full_matrices=False)

        if projection_type == 'right':
            b = vh[:rank, :]
            return b if is_float else b.to(original_device).type(original_type)
        if projection_type == 'left':
            a = u[:, :rank]
            return a if is_float else a.to(original_device).type(original_type)

        a = u[:, :rank]
        b = vh[:rank, :]

        return (
            (a, b)
            if is_float
            else (a.to(original_device).type(original_type), b.to(original_device).type(original_type))
        )

    def get_low_rank_grad_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        # project along the smaller dimension: right projection for tall matrices, left projection for wide ones
        if grad.shape[0] >= grad.shape[1]:
            if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
            return torch.matmul(grad, self.ortho_matrix.t())

        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')

        return torch.matmul(self.ortho_matrix.t(), grad)

    def get_low_rank_grad_reverse_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if grad.shape[0] >= grad.shape[1]:
            if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
            return torch.matmul(self.ortho_matrix.t(), grad)

        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')

        return torch.matmul(grad, self.ortho_matrix.t())

    def get_low_rank_grad_right(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
        return torch.matmul(grad, self.ortho_matrix.t())

    def get_low_rank_grad_left(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
        return torch.matmul(self.ortho_matrix.t(), grad)

    def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full')
        return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t()

    def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.projection_type == 'std':
            return self.get_low_rank_grad_std(full_rank_grad, steps)
        if self.projection_type == 'reverse_std':
            return self.get_low_rank_grad_reverse_std(full_rank_grad, steps)
        if self.projection_type == 'right':
            return self.get_low_rank_grad_right(full_rank_grad, steps)
        if self.projection_type == 'left':
            return self.get_low_rank_grad_left(full_rank_grad, steps)
        if self.projection_type == 'full':
            return self.get_low_rank_grad_full(full_rank_grad, steps)
        raise NotImplementedError

    def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
        if self.projection_type == 'std':
            return (
                torch.matmul(low_rank_grad, self.ortho_matrix)
                if low_rank_grad.shape[0] >= low_rank_grad.shape[1]
                else torch.matmul(self.ortho_matrix, low_rank_grad)
            ) * self.scale
        if self.projection_type == 'reverse_std':
            return (
                torch.matmul(self.ortho_matrix, low_rank_grad.t())
                if low_rank_grad.shape[0] <= low_rank_grad.shape[1]
                else torch.matmul(low_rank_grad, self.ortho_matrix.t())
            ) * self.scale
        if self.projection_type == 'right':
            return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale
        if self.projection_type == 'left':
            return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
        if self.projection_type == 'full':
            return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale

        raise NotImplementedError


class GaLore(Optimizer, BaseOptimizer):
    r"""AdamW optimizer with GaLore projector.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        eps: float = 1e-6,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'GaLore'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if 'rank' in group and p.dim() > 1:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            rank=group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            projection_type=group['projection_type'],
                        )

                    grad = state['projector'].project(grad, group['step'])

                if 'exp_avg' not in state:
                    # moment buffers follow the (possibly low-rank) projected gradient shape, not the parameter shape
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                norm_grad = exp_avg / de_nom

                if 'rank' in group and p.dim() > 1:
                    norm_grad = state['projector'].project_back(norm_grad)

                p.add_(norm_grad, alpha=-step_size)

        return loss
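To make the projector's behaviour concrete, here is a short standalone sketch of `GaLoreProjector` with the default 'std' projection (shapes are illustrative): a tall gradient is right-projected down to `rank` columns, and `project_back` restores the original shape, scaled by `scale`.

import torch

from pytorch_optimizer import GaLoreProjector

projector = GaLoreProjector(rank=4, update_proj_gap=50, scale=1.0, projection_type='std')

grad = torch.randn(64, 16)                   # tall matrix -> the 'right' projection branch is taken
low_rank = projector.project(grad, steps=1)  # (64, 16) @ (16, 4) -> (64, 4)
restored = projector.project_back(low_rank)  # (64, 4) @ (4, 16) -> (64, 16)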