diff --git a/README.md b/README.md index 6f550c2f..b49ec067 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ ## The reasons why you use `pytorch-optimizer`. -* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +* Wide range of supported optimizers. Currently, **84 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion` * Easy to use, clean, and tested codes * Active maintenance @@ -192,6 +192,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) | | Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | | [cite](https://github.com/KellerJordan/Muon) | | LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) | +| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) | ## Supported LR Scheduler diff --git a/docs/changelogs/v3.3.1.md b/docs/changelogs/v3.3.1.md index 9b6eed73..cc4b5017 100644 --- a/docs/changelogs/v3.3.1.md +++ b/docs/changelogs/v3.3.1.md @@ -4,6 +4,9 @@ * Support `Cautious` variant to `AdaShift` optimizer. (#310) * Save the state of the `Lookahead` optimizer too. (#310) +* Implement `APOLLO` optimizer. (#311, #312) + * [SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270) +* Rename the `Apollo` (`An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization`) optimizer name to `ApolloDQN` not to overlap with the new optimizer name `APOLLO`. (#312) ### Bug diff --git a/docs/index.md b/docs/index.md index 6f550c2f..b49ec067 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,7 +10,7 @@ ## The reasons why you use `pytorch-optimizer`. -* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +* Wide range of supported optimizers. Currently, **84 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion` * Easy to use, clean, and tested codes * Active maintenance @@ -192,6 +192,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) | | Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | | [cite](https://github.com/KellerJordan/Muon) | | LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) | +| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) | ## Supported LR Scheduler diff --git a/docs/optimizer.md b/docs/optimizer.md index 887f5e48..0cf30f27 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -116,7 +116,11 @@ :docstring: :members: -::: pytorch_optimizer.Apollo +::: pytorch_optimizer.APOLLO + :docstring: + :members: + +::: pytorch_optimizer.ApolloDQN :docstring: :members: diff --git a/pyproject.toml b/pyproject.toml index 7e3eea6d..8f76709a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,14 +13,14 @@ keywords = [ "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound", "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT", "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", - "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", - "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", - "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", - "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", - "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", - "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", - "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", - "bitsandbytes", "WSD", "QGaLore", + "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", + "DAdaptLion", "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", + "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", + "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", + "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", + "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", + "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", + "LovaszHinge", "bitsandbytes", "WSD", "QGaLore", ] classifiers = [ "License :: OSI Approved :: Apache Software License", diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 8af7e027..dfd91c2b 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -42,6 +42,7 @@ ) from pytorch_optimizer.optimizer import ( ADOPT, + APOLLO, ASGD, BSAM, CAME, @@ -90,7 +91,7 @@ Aida, AliG, Amos, - Apollo, + ApolloDQN, AvaGrad, DAdaptAdaGrad, DAdaptAdam, diff --git a/pytorch_optimizer/optimizer/__init__.py b/pytorch_optimizer/optimizer/__init__.py index 23397307..b08209e4 100644 --- a/pytorch_optimizer/optimizer/__init__.py +++ b/pytorch_optimizer/optimizer/__init__.py @@ -34,7 +34,7 @@ from pytorch_optimizer.optimizer.aida import Aida from pytorch_optimizer.optimizer.alig import AliG from pytorch_optimizer.optimizer.amos import Amos -from pytorch_optimizer.optimizer.apollo import Apollo +from pytorch_optimizer.optimizer.apollo import APOLLO, ApolloDQN from pytorch_optimizer.optimizer.avagrad import AvaGrad from pytorch_optimizer.optimizer.came import CAME from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptLion, DAdaptSGD @@ -228,7 +228,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER: DAdaptAdan, AdamS, AdaFactor, - Apollo, + ApolloDQN, + APOLLO, SWATS, NovoGrad, Lion, diff --git a/pytorch_optimizer/optimizer/apollo.py b/pytorch_optimizer/optimizer/apollo.py index 04ca47cf..ae24af84 100644 --- a/pytorch_optimizer/optimizer/apollo.py +++ b/pytorch_optimizer/optimizer/apollo.py @@ -1,14 +1,18 @@ -from typing import Optional +import math +from typing import Literal, Optional import numpy as np import torch from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS +from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector +SCALE_TYPE = Literal['channel', 'tensor'] -class Apollo(BaseOptimizer): + +class ApolloDQN(BaseOptimizer): r"""An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -25,8 +29,8 @@ class Apollo(BaseOptimizer): def __init__( self, params: PARAMETERS, - lr: float = 1e-3, - init_lr: Optional[float] = None, + lr: float = 1e-2, + init_lr: Optional[float] = 1e-5, beta: float = 0.9, rebound: str = 'constant', weight_decay: float = 0.0, @@ -58,7 +62,7 @@ def __init__( super().__init__(params, defaults) def __str__(self) -> str: - return 'Apollo' + return 'ApolloDQN' @torch.no_grad() def reset(self): @@ -146,3 +150,155 @@ def step(self, closure: CLOSURE = None) -> LOSS: p.add_(d_p, alpha=-current_lr) return loss + + +class APOLLO(BaseOptimizer): + r"""SGD-like Memory, AdamW-level Performance. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. + :param fixed_decay: bool. fix weight decay. + :param correct_bias: bool. Whether to correct bias in Adam. + :param eps: float. term added to the denominator to improve numerical stability. + """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 1e-2, + betas: BETAS = (0.9, 0.999), + scale_type: SCALE_TYPE = 'tensor', + weight_decay: float = 0.0, + weight_decouple: bool = True, + fixed_decay: bool = False, + correct_bias: bool = True, + eps: float = 1e-6, + **kwargs, + ): + self.validate_learning_rate(lr) + self.validate_betas(betas) + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_non_negative(eps, 'eps') + + defaults: DEFAULTS = { + 'lr': lr, + 'betas': betas, + 'scale_type': scale_type, + 'weight_decay': weight_decay, + 'weight_decouple': weight_decouple, + 'fixed_decay': fixed_decay, + 'correct_bias': correct_bias, + 'eps': eps, + **kwargs, + } + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'APOLLO' + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + beta1, beta2 = group['betas'] + + step_size: float = group['lr'] + if group['correct_bias']: + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) + step_size *= bias_correction2_sq / bias_correction1 + + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + state = self.state[p] + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + state['exp_avg_sq'] = torch.zeros_like(p) + + if 'rank' in group and p.dim() > 1: + if 'projector' not in state: + state['projector'] = GaLoreProjector( + rank=group['rank'], + update_proj_gap=group['update_proj_gap'], + scale=group['scale'], + projection_type=group['projection_type'], + ) + + grad = state['projector'].project(grad, group['step'], from_random_matrix=True) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + de_nom = exp_avg_sq.sqrt().add_(group['eps']) + + norm_grad = exp_avg / de_nom + if 'rank' in group and p.dim() > 1: + if group['scale_type'] == 'channel': + norm_dim: int = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1 + scaling_factor = torch.norm(norm_grad, dim=norm_dim) / (torch.norm(grad, dim=norm_dim) + 1e-8) + if norm_dim == 1: + scaling_factor = scaling_factor.unsqueeze(1) + else: + scaling_factor = torch.norm(norm_grad) / (torch.norm(grad) + 1e-8) + + scaling_grad = grad * scaling_factor + + scaling_grad_norm = torch.norm(scaling_grad) + if 'scaling_grad' in state: + limiter = ( + max( + scaling_grad_norm / (state['scaling_grad'] + 1e-8), + 1.01, + ) + / 1.01 + ) + + scaling_grad.div_(limiter) + scaling_grad_norm.div_(limiter) + + state['scaling_grad'] = scaling_grad_norm + + norm_grad = scaling_grad * np.sqrt(group['scale']) + norm_grad = state['projector'].project_back(norm_grad) + + p.add_(norm_grad, alpha=-step_size) + + self.apply_weight_decay( + p, + grad, + lr=step_size, + weight_decay=group['weight_decay'], + weight_decouple=group['weight_decouple'], + fixed_decay=group['fixed_decay'], + ) + + return loss diff --git a/pytorch_optimizer/optimizer/galore.py b/pytorch_optimizer/optimizer/galore.py index 44d21f14..2dfaadc4 100644 --- a/pytorch_optimizer/optimizer/galore.py +++ b/pytorch_optimizer/optimizer/galore.py @@ -1,156 +1,11 @@ import math -from typing import Literal, Optional, Tuple, Union import torch from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS - -PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full', 'random'] - - -class GaLoreProjector: - r"""Memory-Efficient LLM Training by Gradient Low-Rank Projection. - - :param rank: int. low rank to project. - :param update_proj_gap: int. num steps to update the projection. - :param scale: float. scale factor. - :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' and - 'random' are supported. - """ - - def __init__( - self, - rank: int = 128, - update_proj_gap: int = 50, - scale: float = 1.0, - projection_type: PROJECTION_TYPE = 'std', - **kwargs, - ): - self.rank = rank - self.update_proj_gap = update_proj_gap - self.scale = scale - self.projection_type = projection_type - - self.ortho_matrix: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None - - @staticmethod - def get_orthogonal_matrix( - weights: torch.Tensor, rank: int, projection_type: str - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if projection_type not in {'right', 'left', 'full'}: - raise ValueError('projection_type should be one of left, right or full') - - original_type = weights.data.dtype - original_device = weights.data.device - is_float: bool = original_type == torch.float - - u, s, vh = torch.linalg.svd(weights if is_float else weights.float(), full_matrices=False) - - if projection_type == 'right': - b = vh[:rank, :] - return b if is_float else b.to(original_device).type(original_type) - if projection_type == 'left': - a = u[:, :rank] - return a if is_float else a.to(original_device).type(original_type) - - a = u[:, :rank] - b = vh[:rank, :] - - return ( - (a, b) - if is_float - else (a.to(original_device).type(original_type), b.to(original_device).type(original_type)) - ) - - def get_low_rank_grad_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor: - if grad.shape[0] >= grad.shape[1]: - if self.ortho_matrix is None or steps % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') - return torch.matmul(grad, self.ortho_matrix.t()) - - if self.ortho_matrix is None or steps % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') - - return torch.matmul(self.ortho_matrix.t(), grad) - - def get_low_rank_grad_reverse_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor: - if grad.shape[0] >= grad.shape[1]: - if self.ortho_matrix is None or steps % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') - return torch.matmul(self.ortho_matrix.t(), grad) - - if self.ortho_matrix is None or steps % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') - - return torch.matmul(grad, self.ortho_matrix.t()) - - def get_low_rank_grad_right(self, grad: torch.Tensor, steps: int) -> torch.Tensor: - if self.ortho_matrix is None or steps % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') - return torch.matmul(grad, self.ortho_matrix.t()) - - def get_low_rank_grad_left(self, grad: torch.Tensor, steps: int) -> torch.Tensor: - if self.ortho_matrix is None or steps % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') - return torch.matmul(self.ortho_matrix.t(), grad) - - def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor: - if self.ortho_matrix is None or steps % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full') - return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t() - - def get_low_rank_grad_random(self, grad: torch.Tensor, steps: int) -> torch.Tensor: - is_right: bool = grad.size(0) >= grad.size(1) - if self.ortho_matrix is None or steps % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix( - grad, self.rank, projection_type='right' if is_right else 'left' - ) - return torch.matmul(grad, self.ortho_matrix.t()) if is_right else torch.matmul(self.ortho_matrix.t(), grad) - - def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor: - if self.projection_type == 'std': - return self.get_low_rank_grad_std(full_rank_grad, steps) - if self.projection_type == 'reverse_std': - return self.get_low_rank_grad_reverse_std(full_rank_grad, steps) - if self.projection_type == 'right': - return self.get_low_rank_grad_right(full_rank_grad, steps) - if self.projection_type == 'left': - return self.get_low_rank_grad_left(full_rank_grad, steps) - if self.projection_type == 'full': - return self.get_low_rank_grad_full(full_rank_grad, steps) - if self.projection_type == 'random': - return self.get_low_rank_grad_random(full_rank_grad, steps) - raise NotImplementedError - - def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: - if self.projection_type == 'std': - return ( - torch.matmul(low_rank_grad, self.ortho_matrix) - if low_rank_grad.shape[0] >= low_rank_grad.shape[1] - else torch.matmul(self.ortho_matrix, low_rank_grad) - ) * self.scale - if self.projection_type == 'reverse_std': - return ( - torch.matmul(self.ortho_matrix, low_rank_grad.t()) - if low_rank_grad.shape[0] <= low_rank_grad.shape[1] - else torch.matmul(low_rank_grad, self.ortho_matrix.t()) - ) * self.scale - if self.projection_type == 'right': - return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale - if self.projection_type == 'left': - return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale - if self.projection_type == 'full': - return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale - if self.projection_type == 'random': - return ( - torch.matmul(low_rank_grad, self.ortho_matrix.t()) - if low_rank_grad.shape[0] >= low_rank_grad.shape[1] - else torch.matmul(self.ortho_matrix, low_rank_grad) - ) * self.scale - - raise NotImplementedError +from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector class GaLore(BaseOptimizer): diff --git a/pytorch_optimizer/optimizer/galore_utils.py b/pytorch_optimizer/optimizer/galore_utils.py new file mode 100644 index 00000000..5041c5c4 --- /dev/null +++ b/pytorch_optimizer/optimizer/galore_utils.py @@ -0,0 +1,170 @@ +import math +from typing import Literal, Optional, Tuple, Union + +import torch + +PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full', 'random'] + + +class GaLoreProjector: + r"""Memory-Efficient LLM Training by Gradient Low-Rank Projection. + + :param rank: int. low rank to project. + :param update_proj_gap: int. num steps to update the projection. + :param scale: float. scale factor. + :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' and + 'random' are supported. + """ + + def __init__( + self, + rank: int = 128, + update_proj_gap: int = 50, + scale: float = 1.0, + projection_type: PROJECTION_TYPE = 'std', + **kwargs, + ): + self.rank = rank + self.update_proj_gap = update_proj_gap + self.scale = scale + self.projection_type = projection_type + + self.ortho_matrix: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None + + @staticmethod + def get_orthogonal_matrix( + weights: torch.Tensor, rank: int, projection_type: str, from_random_matrix: bool = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if projection_type not in {'right', 'left', 'full'}: + raise ValueError('projection_type should be one of left, right or full') + + original_type = weights.data.dtype + original_device = weights.data.device + is_float: bool = original_type == torch.float + + if not from_random_matrix: + u, _, vh = torch.linalg.svd(weights if is_float else weights.float(), full_matrices=False) + else: + u = torch.randn((weights.size(0), rank), device=original_device, dtype=original_type) / math.sqrt(rank) + vh = torch.randn((rank, weights.size(1)), device=original_device, dtype=original_type) / math.sqrt(rank) + + if projection_type == 'right': + b = vh[:rank, :] + return b if is_float else b.to(original_device).type(original_type) + if projection_type == 'left': + a = u[:, :rank] + return a if is_float else a.to(original_device).type(original_type) + + a = u[:, :rank] + b = vh[:rank, :] + + return ( + (a, b) + if is_float + else (a.to(original_device).type(original_type), b.to(original_device).type(original_type)) + ) + + def get_low_rank_grad_std(self, grad: torch.Tensor, steps: int, from_random_matrix: bool) -> torch.Tensor: + if grad.shape[0] >= grad.shape[1]: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, self.rank, projection_type='right', from_random_matrix=from_random_matrix + ) + return torch.matmul(grad, self.ortho_matrix.t()) + + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, self.rank, projection_type='left', from_random_matrix=from_random_matrix + ) + + return torch.matmul(self.ortho_matrix.t(), grad) + + def get_low_rank_grad_reverse_std(self, grad: torch.Tensor, steps: int, from_random_matrix: bool) -> torch.Tensor: + if grad.shape[0] >= grad.shape[1]: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, self.rank, projection_type='left', from_random_matrix=from_random_matrix + ) + return torch.matmul(self.ortho_matrix.t(), grad) + + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, self.rank, projection_type='right', from_random_matrix=from_random_matrix + ) + + return torch.matmul(grad, self.ortho_matrix.t()) + + def get_low_rank_grad_right(self, grad: torch.Tensor, steps: int, from_random_matrix: bool) -> torch.Tensor: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, self.rank, projection_type='right', from_random_matrix=from_random_matrix + ) + return torch.matmul(grad, self.ortho_matrix.t()) + + def get_low_rank_grad_left(self, grad: torch.Tensor, steps: int, from_random_matrix: bool) -> torch.Tensor: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, self.rank, projection_type='left', from_random_matrix=from_random_matrix + ) + return torch.matmul(self.ortho_matrix.t(), grad) + + def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int, from_random_matrix: bool) -> torch.Tensor: + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, self.rank, projection_type='full', from_random_matrix=from_random_matrix + ) + return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t() + + def get_low_rank_grad_random(self, grad: torch.Tensor, steps: int, from_random_matrix: bool) -> torch.Tensor: + is_right: bool = grad.size(0) >= grad.size(1) + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, + self.rank, + projection_type='right' if is_right else 'left', + from_random_matrix=from_random_matrix, + ) + return torch.matmul(grad, self.ortho_matrix.t()) if is_right else torch.matmul(self.ortho_matrix.t(), grad) + + def project(self, full_rank_grad: torch.Tensor, steps: int, from_random_matrix: bool = False) -> torch.Tensor: + if self.projection_type == 'std': + return self.get_low_rank_grad_std(full_rank_grad, steps, from_random_matrix) + if self.projection_type == 'reverse_std': + return self.get_low_rank_grad_reverse_std(full_rank_grad, steps, from_random_matrix) + if self.projection_type == 'right': + return self.get_low_rank_grad_right(full_rank_grad, steps, from_random_matrix) + if self.projection_type == 'left': + return self.get_low_rank_grad_left(full_rank_grad, steps, from_random_matrix) + if self.projection_type == 'full': + return self.get_low_rank_grad_full(full_rank_grad, steps, from_random_matrix) + if self.projection_type == 'random': + return self.get_low_rank_grad_random(full_rank_grad, steps, from_random_matrix) + raise NotImplementedError + + def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: + if self.projection_type == 'std': + return ( + torch.matmul(low_rank_grad, self.ortho_matrix) + if low_rank_grad.shape[0] >= low_rank_grad.shape[1] + else torch.matmul(self.ortho_matrix, low_rank_grad) + ) * self.scale + if self.projection_type == 'reverse_std': + return ( + torch.matmul(self.ortho_matrix, low_rank_grad.t()) + if low_rank_grad.shape[0] <= low_rank_grad.shape[1] + else torch.matmul(low_rank_grad, self.ortho_matrix.t()) + ) * self.scale + if self.projection_type == 'right': + return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale + if self.projection_type == 'left': + return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale + if self.projection_type == 'full': + return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale + if self.projection_type == 'random': + return ( + torch.matmul(low_rank_grad, self.ortho_matrix.t()) + if low_rank_grad.shape[0] >= low_rank_grad.shape[1] + else torch.matmul(self.ortho_matrix, low_rank_grad) + ) * self.scale + + raise NotImplementedError diff --git a/tests/constants.py b/tests/constants.py index 21b4e0bf..2cce0c9b 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -2,6 +2,7 @@ from pytorch_optimizer.optimizer import ( ADOPT, + APOLLO, ASGD, CAME, FTRL, @@ -42,7 +43,7 @@ Aida, AliG, Amos, - Apollo, + ApolloDQN, AvaGrad, DAdaptAdaGrad, DAdaptAdam, @@ -149,6 +150,7 @@ 'soap', 'muon', 'laprop', + 'apollo', ] VALID_LR_SCHEDULER_NAMES: List[str] = [ @@ -381,9 +383,9 @@ (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100), (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120), (AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40), - (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10), - (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10), - (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50), + (ApolloDQN, {'lr': 5e-1, 'weight_decay': 1e-3}, 10), + (ApolloDQN, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10), + (ApolloDQN, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50), (NovoGrad, {'lr': 5e-1, 'weight_decay': 1e-3, 'grad_averaging': True}, 5), (Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 5), (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 5), @@ -504,6 +506,32 @@ (LaProp, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (LaProp, {'lr': 1e0, 'centered': True, 'weight_decay': 1e-3}, 11), (LaProp, {'lr': 1e0, 'ams_bound': True, 'weight_decay': 1e-3}, 5), + ( + APOLLO, + { + 'lr': 1e-1, + 'weight_decay': 1e-3, + 'rank': 2, + 'update_proj_gap': 1, + 'scale': 1.0, + 'scale_type': 'tensor', + 'projection_type': 'right', + }, + 15, + ), + ( + APOLLO, + { + 'lr': 1e-1, + 'weight_decay': 1e-3, + 'rank': 2, + 'update_proj_gap': 1, + 'scale': 1.0, + 'scale_type': 'channel', + 'projection_type': 'right', + }, + 15, + ), ] ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10), diff --git a/tests/test_general_optimizer_parameters.py b/tests/test_general_optimizer_parameters.py index 31f4dae2..7b6b35d5 100644 --- a/tests/test_general_optimizer_parameters.py +++ b/tests/test_general_optimizer_parameters.py @@ -112,7 +112,7 @@ def test_momentum(optimizer_name): optimizer(None, momentum=-1e-3) -@pytest.mark.parametrize('optimizer_name', ['nero', 'apollo', 'sm3', 'msvag', 'ranger21']) +@pytest.mark.parametrize('optimizer_name', ['nero', 'apollodqn', 'sm3', 'msvag', 'ranger21']) def test_beta(optimizer_name): optimizer = load_optimizer(optimizer_name) diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index ccca7ef5..62238ccd 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 82 + assert len(get_supported_optimizers()) == 83 assert len(get_supported_optimizers('adam*')) == 7 assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 9 diff --git a/tests/test_optimizer_parameters.py b/tests/test_optimizer_parameters.py index 90149013..1e2d6259 100644 --- a/tests/test_optimizer_parameters.py +++ b/tests/test_optimizer_parameters.py @@ -3,7 +3,7 @@ from torch import nn from pytorch_optimizer.optimizer import SAM, WSAM, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer -from pytorch_optimizer.optimizer.galore import GaLoreProjector +from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector from tests.constants import PULLBACK_MOMENTUM from tests.utils import Example, simple_parameter, simple_zero_rank_parameter @@ -231,7 +231,7 @@ def test_lars_parameters(): def test_apollo_parameters(): - opt = load_optimizer('apollo') + opt = load_optimizer('apollodqn') # test rebound type with pytest.raises(ValueError): @@ -257,6 +257,8 @@ def test_ranger_parameters(): def test_galore_projection_type(): p = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + _ = GaLoreProjector.get_orthogonal_matrix(p, 1, projection_type='left', from_random_matrix=True) + with pytest.raises(NotImplementedError): GaLoreProjector(projection_type='invalid').project(p, 1)