Merge pull request #258 from kozistr/feature/adalomo-optimizer
[Feature] Implement `AdaLOMO` optimizer and others
kozistr authored Jul 14, 2024
2 parents acd218b + 2acade3 commit 3d4d440
Showing 20 changed files with 675 additions and 318 deletions.
158 changes: 80 additions & 78 deletions README.md

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions docs/changelogs/v3.1.0.md
@@ -0,0 +1,25 @@
## Change Log

### Feature

* Implement `AdaLomo` optimizer. (#258)
* [Low-memory Optimization with Adaptive Learning Rate](https://arxiv.org/abs/2310.10195)
* Support `Q-GaLore` optimizer. (#258)
* [Q-GaLore: Quantized GaLore with INT4 Projection and Layer-Adaptive Low-Rank Gradients.](https://arxiv.org/abs/2407.08296)
* You can use it via `optimizer = load_optimizer('q_galore_adamw8bit')` (see the sketch after this list).
* Support more bnb optimizers. (#258)
* `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`.
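
For context, here is a minimal usage sketch of the string-based loader (not part of this commit); the toy model, the learning rate, and the CPU-only fallback are placeholder assumptions for illustration.

```python
import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(128, 10)  # placeholder model, for illustration only

# Pure-PyTorch optimizers are resolved directly by their lower-cased class name.
adalomo_cls = load_optimizer('adalomo')

# `q_galore_*` and `bnb_*` names are resolved lazily and require CUDA plus the
# corresponding package; otherwise `load_optimizer` raises ImportError.
try:
    optimizer_cls = load_optimizer('bnb_paged_adamw8bit')
except ImportError:
    optimizer_cls = torch.optim.AdamW  # CPU-only fallback, for this sketch only

optimizer = optimizer_cls(model.parameters(), lr=1e-3)
```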

### Refactor

* Refactor `AdamMini`. (#258)
* Deprecate the `bitsandbytes` optional dependency (package extra); it is now detected at runtime instead. (#258)
* Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258)

### Bug

* Fix several bugs in the `AdamMini` optimizer. (#257)

## Contributions

Thanks to @sdbds.
164 changes: 85 additions & 79 deletions docs/index.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/optimizer.md
@@ -24,6 +24,10 @@
:docstring:
:members:

::: pytorch_optimizer.AdaLOMO
:docstring:
:members:

::: pytorch_optimizer.Adai
:docstring:
:members:
2 changes: 1 addition & 1 deletion examples/visualize_optimizers.py
@@ -158,7 +158,7 @@ def main():
]

for optimizer_name, optimizer in OPTIMIZERS.items():
if optimizer_name.lower() in {'alig', 'lomo', 'bsam', 'adammini'}:
if optimizer_name.lower() in {'alig', 'lomo', 'adalomo', 'bsam', 'adammini'}:
continue

optimizers.append((optimizer, -6, 0.2))
146 changes: 65 additions & 81 deletions poetry.lock

Large diffs are not rendered by default.

21 changes: 9 additions & 12 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch_optimizer"
version = "3.0.2"
version = "3.1.0"
description = "optimizer & lr scheduler & objective function collections in PyTorch"
license = "Apache-2.0"
authors = ["kozistr <[email protected]>"]
@@ -12,13 +12,14 @@ documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
keywords = [
"pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
"AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
"AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME",
"DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore",
"Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero",
"NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad",
"SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
"SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine",
"SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
"AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM",
"CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage",
"GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG",
"Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21",
"RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
"SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
"bitsandbytes", "WSD", "QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
@@ -46,7 +47,6 @@ classifiers = [
python = ">=3.8,<4.0.0"
numpy = { version = "*", python = ">=3.8" }
torch = { version = ">=1.10", python = ">=3.8", source = "torch" }
bitsandbytes = { version = "^0.43", optional = true }

[tool.poetry.dev-dependencies]
isort = { version = "^5", python = ">=3.8" }
@@ -55,9 +55,6 @@ ruff = "*"
pytest = "*"
pytest-cov = "*"

[tool.poetry.extras]
bitsandbytes = ["bitsandbytes"]

[[tool.poetry.source]]
name = "torch"
url = "https://download.pytorch.org/whl/cpu"
70 changes: 54 additions & 16 deletions pytorch_optimizer/__init__.py
@@ -1,4 +1,5 @@
# ruff: noqa
from importlib.util import find_spec
from typing import Dict, List

import torch.cuda
@@ -72,7 +73,7 @@
from pytorch_optimizer.optimizer.lamb import Lamb
from pytorch_optimizer.optimizer.lars import LARS
from pytorch_optimizer.optimizer.lion import Lion
from pytorch_optimizer.optimizer.lomo import LOMO
from pytorch_optimizer.optimizer.lomo import LOMO, AdaLOMO
from pytorch_optimizer.optimizer.lookahead import Lookahead
from pytorch_optimizer.optimizer.madgrad import MADGRAD
from pytorch_optimizer.optimizer.msvag import MSVAG
@@ -126,12 +127,8 @@
)
from pytorch_optimizer.optimizer.yogi import Yogi

try:
import bitsandbytes as bnb

HAS_BNB: bool = True # pragma: no cover
except ImportError:
HAS_BNB: bool = False
HAS_BNB: bool = find_spec('bitsandbytes') is not None
HAS_Q_GALORE: bool = find_spec('q-galore-torch') is not None

OPTIMIZER_LIST: List[OPTIMIZER] = [
AdaBelief,
@@ -205,6 +202,7 @@
Kate,
StableAdamW,
AdamMini,
AdaLOMO,
]
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

@@ -252,22 +250,58 @@

def load_bnb_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover
r"""load bnb optimizer instance."""
from bitsandbytes import optim

if 'sgd8bit' in optimizer:
return bnb.optim.SGD8bit
return optim.SGD8bit
if 'adam8bit' in optimizer:
return bnb.optim.Adam8bit
return optim.Adam8bit
if 'paged_adam8bit' in optimizer:
return optim.PagedAdam8bit
if 'adamw8bit' in optimizer:
return bnb.optim.AdamW8bit
return optim.AdamW8bit
if 'paged_adamw8bit' in optimizer:
return optim.PagedAdamW8bit
if 'lamb8bit' in optimizer:
return bnb.optim.LAMB8bit
return optim.LAMB8bit
if 'lars8bit' in optimizer:
return bnb.optim.LARS8bit
return optim.LARS8bit
if 'lion8bit' in optimizer:
return bnb.optim.Lion8bit
return optim.Lion8bit
if 'adagrad8bit' in optimizer:
return bnb.optim.Adagrad8bit
return optim.Adagrad8bit
if 'rmsprop8bit' in optimizer:
return bnb.optim.RMSprop8bit
return optim.RMSprop8bit
if 'adagrad32bit' in optimizer:
return optim.Adagrad32bit
if 'adam32bit' in optimizer:
return optim.Adam32bit
if 'paged_adam32bit' in optimizer:
return optim.PagedAdam32bit
if 'adamw32bit' in optimizer:
return optim.AdamW32bit
if 'lamb32bit' in optimizer:
return optim.LAMB32bit
if 'lars32bit' in optimizer:
return optim.LARS32bit
if 'lion32bit' in optimizer:
return optim.Lion32bit
if 'paged_lion32bit' in optimizer:
return optim.PagedLion32bit
if 'rmsprop32bit' in optimizer:
return optim.RMSprop32bit
if 'sgd32bit' in optimizer:
return optim.SGD32bit
raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')


def load_q_galore_optimizer(optimizer: str) -> OPTIMIZER: # pragma: no cover
r"""load Q-GaLore optimizer instance."""
import q_galore_torch

if 'adamw8bit' in optimizer:
return q_galore_torch.QGaLoreAdamW8bit

raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')


@@ -277,7 +311,11 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
if optimizer.startswith('bnb'):
if HAS_BNB and torch.cuda.is_available():
return load_bnb_optimizer(optimizer) # pragma: no cover
raise ImportError(f'[-] bitsandbytes and CUDA required for bnb optimizers : {optimizer}')
raise ImportError(f'[-] bitsandbytes and CUDA required for the optimizer {optimizer}')
if optimizer.startswith('q_galore'):
if HAS_Q_GALORE and torch.cuda.is_available():
return load_q_galore_optimizer(optimizer) # pragma: no cover
raise ImportError(f'[-] bitsandbytes, q-galore-torch, and CUDA required for the optimizer {optimizer}')
if optimizer not in OPTIMIZERS:
raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')

16 changes: 16 additions & 0 deletions pytorch_optimizer/base/optimizer.py
@@ -214,6 +214,22 @@ def get_adanorm_gradient(

return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad

@staticmethod
def get_rms(x: torch.Tensor) -> float:
r"""Get RMS."""
return x.norm(2) / math.sqrt(x.numel())

@staticmethod
def approximate_sq_grad(
exp_avg_sq_row: torch.Tensor,
exp_avg_sq_col: torch.Tensor,
output: torch.Tensor,
) -> None:
r"""Get approximation of EMA of squared gradient."""
r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
torch.mul(r_factor, c_factor, out=output)

@staticmethod
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
if range_type == '[)' and not low <= x < high:
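The two static methods added above now live on `BaseOptimizer` so that the AdaFactor-style factored second-moment machinery can be shared (the `AdaFactor` copy is deleted in the next file, and `AdaLOMO` presumably reuses it as well). Below is a hedged, standalone sketch of what they compute; the shapes and dummy gradients are illustrative assumptions, not code from this commit.

```python
import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer

grad = torch.randn(4, 8)  # dummy gradient for a (4, 8) weight matrix

# Row/column running averages of squared gradients, as kept by AdaFactor-style
# optimizers instead of a full (4, 8) second-moment tensor.
exp_avg_sq_row = grad.pow(2).mean(dim=-1)  # shape (4,)
exp_avg_sq_col = grad.pow(2).mean(dim=-2)  # shape (8,)

# approximate_sq_grad writes the rsqrt of the rank-1 reconstruction into `output`,
# i.e. roughly 1 / sqrt(EMA[g^2]) per element, recovered from the two factors.
update = torch.empty_like(grad)
BaseOptimizer.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)

# get_rms returns ||x||_2 / sqrt(numel(x)), typically used for update clipping.
rms = BaseOptimizer.get_rms(update)
print(update.shape, float(rms))
```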
16 changes: 0 additions & 16 deletions pytorch_optimizer/optimizer/adafactor.py
@@ -127,22 +127,6 @@ def get_options(shape: Tuple[int, ...]) -> bool:
r"""Get `factored`."""
return len(shape) >= 2

@staticmethod
def get_rms(x: torch.Tensor) -> float:
r"""Get RMS."""
return x.norm(2) / math.sqrt(x.numel())

@staticmethod
def approximate_sq_grad(
exp_avg_sq_row: torch.Tensor,
exp_avg_sq_col: torch.Tensor,
output: torch.Tensor,
):
r"""Get approximation of EMA of squared gradient."""
r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
torch.mul(r_factor, c_factor, out=output)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
21 changes: 7 additions & 14 deletions pytorch_optimizer/optimizer/adam_mini.py
@@ -1,5 +1,5 @@
import math
from typing import Optional
from typing import Optional, Set

import torch
from torch import distributed as dist
@@ -57,6 +57,9 @@ def __init__(
self.num_embeds = num_embeds
self.num_heads = num_heads

self.embed_blocks: Set[str] = {'embed', 'embd', 'wte', 'lm_head.weight', 'output.weight'}
self.qk_blocks: Set[str] = {'k_proj.weight', 'q_proj.weight', 'wq.weight', 'wk.weight'}

groups = self.get_optimizer_groups(weight_decay)

defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'eps': eps}
@@ -77,12 +80,7 @@ def get_optimizer_groups(self, weight_decay: float):
'weight_decay': 0.0 if ('norm' in name or 'ln_f' in name) else weight_decay,
}

if (
'self_attn.k_proj.weight' in name
or 'self_attn.q_proj.weight' in name
or 'attn.wq.weight' in name
or 'attn.wk.weight' in name
):
if any(block in name for block in self.qk_blocks):
group['parameter_per_head'] = self.num_embeds * self.num_embeds // self.num_heads

if 'attn.attn.weight' in name or 'attn.qkv.weight' in name:
@@ -303,16 +301,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
fixed_decay=False,
)

if 'embed_tokens' in name or 'wte' in name or 'lm_head' in name:
if any(block in name for block in self.embed_blocks):
self.step_embed(
p, grad, state, group['lr'], beta1, beta2, bias_correction1, bias_correction2_sq, group['eps']
)
elif (
'self_attn.k_proj.weight' in name
or 'self_attn.q_proj.weight' in name
or 'attn.wq.weight' in name
or 'attn.wk.weight' in name
):
elif any(block in name for block in self.qk_blocks):
self.step_attn_proj(
p,
grad,
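To make the routing refactor above concrete, here is a small self-contained illustration (not from this commit) of how the new set-based checks classify parameter names; the example names are hypothetical.

```python
from typing import Set

# Mirrors the sets introduced in AdamMini.__init__ above.
embed_blocks: Set[str] = {'embed', 'embd', 'wte', 'lm_head.weight', 'output.weight'}
qk_blocks: Set[str] = {'k_proj.weight', 'q_proj.weight', 'wq.weight', 'wk.weight'}


def route(name: str) -> str:
    """Reproduce the `any(block in name ...)` checks used in `step`."""
    if any(block in name for block in embed_blocks):
        return 'step_embed'
    if any(block in name for block in qk_blocks):
        return 'step_attn_proj'
    return 'default path'


print(route('model.embed_tokens.weight'))         # step_embed ('embed' matches)
print(route('layers.0.self_attn.q_proj.weight'))  # step_attn_proj
print(route('layers.0.mlp.up_proj.weight'))       # default path
```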