From 4c71b9dc8373e916480fca3c72dd548c0c607720 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 22:38:12 +0900 Subject: [PATCH 01/13] docs: c-opt --- README.md | 1 + docs/index.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index c72a59b6..afd924fc 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | SOAP | *Improving and Stabilizing Shampoo using Adam* | [github](https://github.com/nikhilvyas/SOAP) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240911321V/exportcitation) | | ADOPT | *Modified Adam Can Converge with Any β2 with the Optimal Rate* | [github](https://github.com/iShohei220/adopt) | | [cite](https://github.com/iShohei220/adopt?tab=readme-ov-file#citation) | | FTRL | *Follow The Regularized Leader* | | | | +| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) | ## Supported LR Scheduler diff --git a/docs/index.md b/docs/index.md index c72a59b6..afd924fc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -183,6 +183,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | SOAP | *Improving and Stabilizing Shampoo using Adam* | [github](https://github.com/nikhilvyas/SOAP) | | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240911321V/exportcitation) | | ADOPT | *Modified Adam Can Converge with Any β2 with the Optimal Rate* | [github](https://github.com/iShohei220/adopt) | | [cite](https://github.com/iShohei220/adopt?tab=readme-ov-file#citation) | | FTRL | *Follow The Regularized Leader* | | | | +| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) | ## Supported LR Scheduler From 8cbd63b2385b028a7fa5ffee6de379285aefe60d Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 22:38:20 +0900 Subject: [PATCH 02/13] docs: v3.3.0 changelog --- docs/changelogs/v3.3.0.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/changelogs/v3.3.0.md b/docs/changelogs/v3.3.0.md index 7f0eb296..c93987eb 100644 --- a/docs/changelogs/v3.3.0.md +++ b/docs/changelogs/v3.3.0.md @@ -8,6 +8,9 @@ * [Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853) * Implement `FTRL` optimizer. (#291) * [Follow The Regularized Leader](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf) +* Implement `Cautious optimizer` feature. (#292) + * [Improving Training with One Line of Code](https://arxiv.org/pdf/2411.16085v1) + * you can use it by setting `cautious=True` for `Lion` and `AdaFactor` optimizers. ### Refactor From 3adcc08f057c9174881e702beef1c4d0ae110ce9 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 22:38:38 +0900 Subject: [PATCH 03/13] feature: implement cautious-opt --- pytorch_optimizer/base/optimizer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index f496af7b..b3b1ff8c 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -255,6 +255,17 @@ def approximate_sq_grad( c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt() torch.mul(r_factor, c_factor, out=output) + @staticmethod + def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None: + r"""Apply the Cautious Optimizer feature. + + :param update: torch.Tensor. update. it'll be masked in in-place manner. + :param grad: torch.Tensor. gradient. + """ + mask = (update * grad > 0).to(grad.dtype) + mask.mul_(mask.numel() / (mask.sum() + 1)) + update.mul_(mask) + @staticmethod def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None: if range_type == '[)' and not low <= x < high: From dcea26a5e40077381ae0af5bb0863a914b0d13df Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 22:38:46 +0900 Subject: [PATCH 04/13] feature: cautious-opt --- pytorch_optimizer/optimizer/adafactor.py | 7 ++++++- pytorch_optimizer/optimizer/lion.py | 6 ++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pytorch_optimizer/optimizer/adafactor.py b/pytorch_optimizer/optimizer/adafactor.py index fa2a2696..10698b9f 100644 --- a/pytorch_optimizer/optimizer/adafactor.py +++ b/pytorch_optimizer/optimizer/adafactor.py @@ -30,6 +30,7 @@ class AdaFactor(BaseOptimizer): :param momentum_dtype: torch.dtype. type of momentum variable. In VIT paper observed that storing momentum in half-precision (bfloat16 type) does not affect training dynamics and has no effect on the outcome while reducing optimize overhead from 2-fold to 1.5-fold. + :param cautious: bool. whether to use the Cautious variant. """ def __init__( @@ -49,6 +50,7 @@ def __init__( eps1: float = 1e-30, eps2: float = 1e-3, momentum_dtype: torch.dtype = torch.bfloat16, + cautious: bool = False, **kwargs, ): self.validate_learning_rate(lr) @@ -62,6 +64,7 @@ def __init__( self.eps1 = eps1 self.eps2 = eps2 self.momentum_dtype = momentum_dtype + self.cautious = cautious defaults: DEFAULTS = { 'lr': lr, @@ -214,7 +217,9 @@ def step(self, closure: CLOSURE = None) -> LOSS: exp_avg = state['exp_avg'] exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1) - update = exp_avg + update = exp_avg.clone() + if self.cautious: + self.apply_cautious(update, grad) self.apply_weight_decay( p=p, diff --git a/pytorch_optimizer/optimizer/lion.py b/pytorch_optimizer/optimizer/lion.py index 5665ad31..511d1473 100644 --- a/pytorch_optimizer/optimizer/lion.py +++ b/pytorch_optimizer/optimizer/lion.py @@ -18,6 +18,7 @@ class Lion(BaseOptimizer): :param use_gc: bool. use gradient centralization. :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred. :param adanorm: bool. whether to use the AdaNorm variant. + :param cautious: bool. whether to use the Cautious variant. """ def __init__( @@ -31,6 +32,7 @@ def __init__( use_gc: bool = False, r: float = 0.95, adanorm: bool = False, + cautious: bool = False, **kwargs, ): self.validate_learning_rate(lr) @@ -38,6 +40,7 @@ def __init__( self.validate_non_negative(weight_decay, 'weight_decay') self.use_gc = use_gc + self.cautious = cautious defaults: DEFAULTS = { 'lr': lr, @@ -114,6 +117,9 @@ def step(self, closure: CLOSURE = None) -> LOSS: update.mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_() exp_avg.mul_(beta2).add_(s_grad, alpha=1.0 - beta2) + if self.cautious: + self.apply_cautious(update, grad) + p.add_(update, alpha=-group['lr']) return loss From 56cfedfcce5f26826988274e77a925a1423e9310 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 22:38:55 +0900 Subject: [PATCH 05/13] update: test cases --- tests/constants.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/constants.py b/tests/constants.py index 8f2abe12..d9badc6c 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -375,6 +375,7 @@ (AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20), (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100), (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120), + (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'cautious': True}, 70), (AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40), (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10), (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10), @@ -383,6 +384,7 @@ (Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 5), (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 5), (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 10), + (Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'cautious': True}, 5), (AliG, {'max_lr': 5e-1, 'momentum': 0.9}, 5), (AliG, {'max_lr': 5e-1, 'momentum': 0.9, 'adjusted_momentum': True}, 5), (SM3, {'lr': 5e-1, 'momentum': 0.9, 'beta': 0.9}, 5), From ba8957d5495dde5c1503ba43728f1f92ef7a5b15 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 22:52:20 +0900 Subject: [PATCH 06/13] update: support cautious-opt --- pytorch_optimizer/optimizer/ademamix.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pytorch_optimizer/optimizer/ademamix.py b/pytorch_optimizer/optimizer/ademamix.py index 1d2cdac2..a0c57e1e 100644 --- a/pytorch_optimizer/optimizer/ademamix.py +++ b/pytorch_optimizer/optimizer/ademamix.py @@ -19,6 +19,7 @@ class AdEMAMix(BaseOptimizer): :param fixed_decay: bool. fix weight decay. :param alpha: float. usually between 4 and 10 would work well. :param t_alpha_beta3: Optional[float]. total number of iterations is preferred when needed. + :param cautious: bool. whether to use cautious feature. :param eps: float. term added to the denominator to improve numerical stability. """ @@ -32,6 +33,7 @@ def __init__( fixed_decay: bool = False, alpha: float = 5.0, t_alpha_beta3: Optional[float] = None, + cautious: bool = False, eps: float = 1e-8, **kwargs, ): @@ -42,6 +44,8 @@ def __init__( self.validate_non_negative(weight_decay, 'weight_decay') self.validate_non_negative(eps, 'eps') + self.cautious = cautious + defaults: DEFAULTS = { 'lr': lr, 'betas': betas, @@ -71,9 +75,7 @@ def reset(self): @staticmethod def schedule_alpha(t_alpha_beta3: Optional[float], step: int, alpha: float) -> float: - if t_alpha_beta3 is None: - return alpha - return min(step * alpha / t_alpha_beta3, alpha) + return alpha if t_alpha_beta3 is None else min(step * alpha / t_alpha_beta3, alpha) @staticmethod def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta3: float) -> float: @@ -107,6 +109,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: bias_correction1: float = self.debias(beta1, group['step']) bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) + step_size: float = group['lr'] / bias_correction1 + alpha_t: float = self.schedule_alpha(group['t_alpha_beta3'], group['step'], group['alpha']) beta3_t: float = self.schedule_beta3(group['t_alpha_beta3'], group['step'], beta1, beta3) @@ -140,10 +144,12 @@ def step(self, closure: CLOSURE = None) -> LOSS: exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1.0 - beta3_t) - de_nom = (exp_avg_sq.sqrt() / bias_correction2_sq).add_(group['eps']) + de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps']) - step_size = group['lr'] / bias_correction1 + update = (exp_avg + alpha_t * exp_avg_slow).div_(de_nom) + if self.cautious: + self.apply_cautious(update, grad) - p.addcdiv_(exp_avg + alpha_t * exp_avg_slow, de_nom, value=-step_size) + p.add_(update, alpha=-step_size) return loss From b705a8bccc29ebf34f894663ebb3942f509983e1 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 22:52:26 +0900 Subject: [PATCH 07/13] update: test case --- tests/constants.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/constants.py b/tests/constants.py index d9badc6c..0ddea933 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -480,8 +480,9 @@ (Kate, {'lr': 5e-2}, 10), (StableAdamW, {'lr': 1e0}, 5), (AdamG, {'lr': 1e0}, 20), - (AdEMAMix, {'lr': 1e0}, 5), - (AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 5), + (AdEMAMix, {'lr': 1e0}, 3), + (AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 3), + (AdEMAMix, {'lr': 1e0, 'cautious': True}, 2), ( SOAP, {'lr': 1e0, 'shampoo_beta': 0.95, 'precondition_frequency': 1, 'merge_dims': False, 'precondition_1d': True}, From 6c004a7c33096152e7f3238fe0ba4aa0f8aa54f9 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 22:53:22 +0900 Subject: [PATCH 08/13] docs: v3.3.0 changelog --- docs/changelogs/v3.3.0.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/changelogs/v3.3.0.md b/docs/changelogs/v3.3.0.md index c93987eb..327785a0 100644 --- a/docs/changelogs/v3.3.0.md +++ b/docs/changelogs/v3.3.0.md @@ -8,9 +8,9 @@ * [Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853) * Implement `FTRL` optimizer. (#291) * [Follow The Regularized Leader](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf) -* Implement `Cautious optimizer` feature. (#292) +* Implement `Cautious optimizer` feature. (#294) * [Improving Training with One Line of Code](https://arxiv.org/pdf/2411.16085v1) - * you can use it by setting `cautious=True` for `Lion` and `AdaFactor` optimizers. + * you can use it by setting `cautious=True` for `Lion`, `AdaFactor` and `AdEMAMix` optimizers. ### Refactor From deb84f5425fc89737226d77d3027fbbe7c7418a8 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 23:23:26 +0900 Subject: [PATCH 09/13] docs: v3.3.0 changelog --- docs/changelogs/v3.3.0.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelogs/v3.3.0.md b/docs/changelogs/v3.3.0.md index 327785a0..5a0816c2 100644 --- a/docs/changelogs/v3.3.0.md +++ b/docs/changelogs/v3.3.0.md @@ -11,6 +11,8 @@ * Implement `Cautious optimizer` feature. (#294) * [Improving Training with One Line of Code](https://arxiv.org/pdf/2411.16085v1) * you can use it by setting `cautious=True` for `Lion`, `AdaFactor` and `AdEMAMix` optimizers. +* Improve the stability of `ADOPT` optimizer. (#294) + * [Note](https://github.com/iShohei220/adopt?tab=readme-ov-file#update-on-nov-22-2024) ### Refactor From d73034f7120ffd6661057b10dc2dd22449d802df Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 23:23:34 +0900 Subject: [PATCH 10/13] update: improve the stability --- pytorch_optimizer/optimizer/adopt.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_optimizer/optimizer/adopt.py b/pytorch_optimizer/optimizer/adopt.py index 8b7026db..9e95bb22 100644 --- a/pytorch_optimizer/optimizer/adopt.py +++ b/pytorch_optimizer/optimizer/adopt.py @@ -1,3 +1,6 @@ +import math +from typing import Callable, Optional + import torch from pytorch_optimizer.base.exception import NoSparseGradientError @@ -22,6 +25,7 @@ def __init__( params: PARAMETERS, lr: float = 1e-3, betas: BETAS = (0.9, 0.9999), + clip_lambda: Optional[Callable[[float], float]] = lambda step: math.pow(step, 0.25), weight_decay: float = 0.0, weight_decouple: bool = False, fixed_decay: bool = False, @@ -33,6 +37,8 @@ def __init__( self.validate_non_negative(weight_decay, 'weight_decay') self.validate_non_negative(eps, 'eps') + self.clip_lambda = clip_lambda + defaults: DEFAULTS = { 'lr': lr, 'betas': betas, @@ -104,10 +110,13 @@ def step(self, closure: CLOSURE = None) -> LOSS: exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2) de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps']) - if group['step'] == 2: - exp_avg.addcdiv_(grad, de_nom) - else: - exp_avg.mul_(beta1).addcdiv_(grad, de_nom, value=1.0 - beta1) + + normed_grad = grad.div(de_nom) + if self.clip_lambda is not None: + clip = self.clip_lambda(group['step']) + normed_grad.clamp_(-clip, clip) + + exp_avg.lerp_(normed_grad, weight=1.0 - beta1) p.add_(exp_avg, alpha=-group['lr']) From 82d81f2503bd9856017e6f87408d3ed02c91b009 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 23:48:38 +0900 Subject: [PATCH 11/13] update: test case --- tests/constants.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/constants.py b/tests/constants.py index 0ddea933..1397f670 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -471,6 +471,11 @@ {'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'}, 5, ), + ( + GaLore, + {'lr': 1e0, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 1, 'projection_type': 'random'}, + 5, + ), (Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5), (ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5), From 406c59ef49566d8d98389b333002e9355179901e Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 23:48:46 +0900 Subject: [PATCH 12/13] docs: v3.3.0 changelog --- docs/changelogs/v3.3.0.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelogs/v3.3.0.md b/docs/changelogs/v3.3.0.md index 5a0816c2..bb1fde17 100644 --- a/docs/changelogs/v3.3.0.md +++ b/docs/changelogs/v3.3.0.md @@ -13,6 +13,7 @@ * you can use it by setting `cautious=True` for `Lion`, `AdaFactor` and `AdEMAMix` optimizers. * Improve the stability of `ADOPT` optimizer. (#294) * [Note](https://github.com/iShohei220/adopt?tab=readme-ov-file#update-on-nov-22-2024) +* Support a new projection type `random` for `GaLoreProjector`. (#294) ### Refactor From db82a58d1a42b5d222e996cb48be23939a85c4b7 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 27 Nov 2024 23:49:13 +0900 Subject: [PATCH 13/13] feature: support random projection_type --- pytorch_optimizer/optimizer/galore.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pytorch_optimizer/optimizer/galore.py b/pytorch_optimizer/optimizer/galore.py index e0592119..44d21f14 100644 --- a/pytorch_optimizer/optimizer/galore.py +++ b/pytorch_optimizer/optimizer/galore.py @@ -7,7 +7,7 @@ from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full'] +PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full', 'random'] class GaLoreProjector: @@ -16,8 +16,8 @@ class GaLoreProjector: :param rank: int. low rank to project. :param update_proj_gap: int. num steps to update the projection. :param scale: float. scale factor. - :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are - supported. + :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' and + 'random' are supported. """ def __init__( @@ -101,6 +101,14 @@ def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full') return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t() + def get_low_rank_grad_random(self, grad: torch.Tensor, steps: int) -> torch.Tensor: + is_right: bool = grad.size(0) >= grad.size(1) + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + grad, self.rank, projection_type='right' if is_right else 'left' + ) + return torch.matmul(grad, self.ortho_matrix.t()) if is_right else torch.matmul(self.ortho_matrix.t(), grad) + def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor: if self.projection_type == 'std': return self.get_low_rank_grad_std(full_rank_grad, steps) @@ -112,6 +120,8 @@ def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor: return self.get_low_rank_grad_left(full_rank_grad, steps) if self.projection_type == 'full': return self.get_low_rank_grad_full(full_rank_grad, steps) + if self.projection_type == 'random': + return self.get_low_rank_grad_random(full_rank_grad, steps) raise NotImplementedError def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: @@ -133,6 +143,12 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale if self.projection_type == 'full': return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale + if self.projection_type == 'random': + return ( + torch.matmul(low_rank_grad, self.ortho_matrix.t()) + if low_rank_grad.shape[0] >= low_rank_grad.shape[1] + else torch.matmul(self.ortho_matrix, low_rank_grad) + ) * self.scale raise NotImplementedError