diff --git a/README.md b/README.md index a284ee3b..c33bd391 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ ## The reasons why you use `pytorch-optimizer`. -* Wide range of supported optimizers. Currently, **85 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion` * Easy to use, clean, and tested codes * Active maintenance @@ -194,6 +194,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) | | APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) | | MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) | +| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) | ## Supported LR Scheduler diff --git a/docs/changelogs/v3.3.2.md b/docs/changelogs/v3.3.2.md new file mode 100644 index 00000000..3893ad82 --- /dev/null +++ b/docs/changelogs/v3.3.2.md @@ -0,0 +1,10 @@ +### Change Log + +### Feature + +* Implement `SGDSaI` optimizer. (#315, #316) + * [No More Adam: Learning Rate Scaling at Initialization is All You Need](https://arxiv.org/abs/2412.11768) + +### Bug + +* Clone `exp_avg` before calling `apply_cautious` not to mask `exp_avg`. (#316) diff --git a/docs/index.md b/docs/index.md index a284ee3b..c33bd391 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,7 +10,7 @@ ## The reasons why you use `pytorch-optimizer`. -* Wide range of supported optimizers. Currently, **85 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! +* Wide range of supported optimizers. Currently, **86 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported! 
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion` * Easy to use, clean, and tested codes * Active maintenance @@ -194,6 +194,7 @@ get_supported_optimizers(['adam*', 'ranger*']) | LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) | | APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) | | MARS | *Unleashing the Power of Variance Reduction for Training Large Models* | [github](https://github.com/AGI-Arena/MARS) | | [cite](https://github.com/AGI-Arena/MARS/tree/main?tab=readme-ov-file#citation) | +| SGDSaI | *No More Adam: Learning Rate Scaling at Initialization is All You Need* | [github](https://github.com/AnonymousAlethiometer/SGD_SaI) | | [cite](https://github.com/AnonymousAlethiometer/SGD_SaI?tab=readme-ov-file#citation) | ## Supported LR Scheduler diff --git a/docs/optimizer.md b/docs/optimizer.md index 66a2e5aa..07e671e5 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -332,6 +332,10 @@ :docstring: :members: +::: pytorch_optimizer.SGDSaI + :docstring: + :members: + ::: pytorch_optimizer.SGDP :docstring: :members: diff --git a/docs/visualization.md b/docs/visualization.md index 1ab5a506..acb1587d 100644 --- a/docs/visualization.md +++ b/docs/visualization.md @@ -274,6 +274,10 @@ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDP.png) +### SGDSaI + +![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDSaI.png) + ### SGDW ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SGDW.png) @@ -592,6 +596,10 @@ ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDP.png) +### SGDSaI + +![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDSaI.png) + ### SGDW ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SGDW.png) diff --git a/docs/visualizations/rastrigin_SGDSaI.png b/docs/visualizations/rastrigin_SGDSaI.png new file mode 100644 index 00000000..7aac4ba1 Binary files /dev/null and b/docs/visualizations/rastrigin_SGDSaI.png differ diff --git a/docs/visualizations/rosenbrock_SGDSaI.png b/docs/visualizations/rosenbrock_SGDSaI.png new file mode 100644 index 00000000..38136d02 Binary files /dev/null and b/docs/visualizations/rosenbrock_SGDSaI.png differ diff --git a/pyproject.toml b/pyproject.toml index 3aac1a48..9df77de0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pytorch_optimizer" -version = "3.3.1" +version = "3.3.2" description = "optimizer & lr scheduler & objective function collections in PyTorch" license = "Apache-2.0" authors = ["kozistr "] diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 39a200d1..75585a89 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -128,6 +128,7 @@ ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD, + SGDSaI, Shampoo, SignSGD, SophiaH, diff --git a/pytorch_optimizer/optimizer/__init__.py b/pytorch_optimizer/optimizer/__init__.py index fdbf7e6d..cfd2eb9d 100644 --- a/pytorch_optimizer/optimizer/__init__.py +++ 
b/pytorch_optimizer/optimizer/__init__.py @@ -74,7 +74,7 @@ from pytorch_optimizer.optimizer.rotograd import RotoGrad from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD -from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD +from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD from pytorch_optimizer.optimizer.sgdp import SGDP from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo from pytorch_optimizer.optimizer.sm3 import SM3 @@ -281,6 +281,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER: ScheduleFreeRAdam, LaProp, MARS, + SGDSaI, ] OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST} diff --git a/pytorch_optimizer/optimizer/adashift.py b/pytorch_optimizer/optimizer/adashift.py index 64f78b7e..d61f383f 100644 --- a/pytorch_optimizer/optimizer/adashift.py +++ b/pytorch_optimizer/optimizer/adashift.py @@ -110,10 +110,11 @@ def step(self, closure: CLOSURE = None) -> LOSS: exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2) update = exp_avg.clone() - update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps'])) if self.cautious: self.apply_cautious(update, grad) + update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps'])) + p.add_(update, alpha=-group['lr']) return loss diff --git a/pytorch_optimizer/optimizer/ademamix.py b/pytorch_optimizer/optimizer/ademamix.py index 7a3431fb..52a36a18 100644 --- a/pytorch_optimizer/optimizer/ademamix.py +++ b/pytorch_optimizer/optimizer/ademamix.py @@ -146,10 +146,11 @@ def step(self, closure: CLOSURE = None) -> LOSS: de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps']) + update = exp_avg.clone() if self.cautious: - self.apply_cautious(exp_avg, grad) + self.apply_cautious(update, grad) - update = (exp_avg + alpha_t * exp_avg_slow).div_(de_nom) + update.add_(exp_avg_slow, alpha=alpha_t).div_(de_nom) p.add_(update, alpha=-step_size) diff --git a/pytorch_optimizer/optimizer/mars.py b/pytorch_optimizer/optimizer/mars.py index 10d8dc50..ac6f4f31 100644 --- a/pytorch_optimizer/optimizer/mars.py +++ b/pytorch_optimizer/optimizer/mars.py @@ -121,8 +121,9 @@ def optimize_mixed( exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1) + update = exp_avg.clone() if cautious: - self.apply_cautious(exp_avg, grad) + self.apply_cautious(update, grad) if mars_type == 'adamw' or (mars_type == 'shampoo' and not is_grad_2d): exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2) @@ -130,17 +131,17 @@ def optimize_mixed( bias_correction1: float = self.debias(beta1, step) bias_correction2_sq: float = math.sqrt(self.debias(beta2, step)) - update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps) - update.div_(bias_correction2_sq).mul_(bias_correction1) + de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps) + de_nom.div_(bias_correction2_sq).mul_(bias_correction1) - return exp_avg.div(update) + return update.div_(de_nom) if mars_type == 'lion': - return exp_avg.sign() + return update.sign_() - factor: float = max(1.0, grad.size(0) / grad.size(1)) ** 0.5 + factor: float = math.sqrt(max(1.0, grad.size(0) / grad.size(1))) - return zero_power_via_newton_schulz_5(exp_avg.mul(1.0 / (1.0 - beta1)), eps=eps).mul_(factor) + return zero_power_via_newton_schulz_5(update.mul_(1.0 / (1.0 - beta1)), eps=eps).mul_(factor) def optimize_1d( self, @@ -162,13 
+163,15 @@ def optimize_1d( exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) - update = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps) - update.div_(bias_correction2_sq).mul_(bias_correction1) + update = exp_avg.clone() if cautious: - self.apply_cautious(exp_avg, grad) + self.apply_cautious(update, grad) - return exp_avg.div(update) + de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps) + de_nom.div_(bias_correction2_sq).mul_(bias_correction1) + + return update.div_(de_nom) @torch.no_grad() def step(self, closure: CLOSURE = None) -> LOSS: diff --git a/pytorch_optimizer/optimizer/sgd.py b/pytorch_optimizer/optimizer/sgd.py index c3e668a8..df996be2 100644 --- a/pytorch_optimizer/optimizer/sgd.py +++ b/pytorch_optimizer/optimizer/sgd.py @@ -356,6 +356,9 @@ def __init__( } super().__init__(params, defaults) + def __str__(self) -> str: + return 'SignSGD' + @torch.no_grad() def reset(self): for group in self.param_groups: @@ -396,3 +399,123 @@ def step(self, closure: CLOSURE = None) -> LOSS: p.add_(torch.sign(buf), alpha=-group['lr']) return loss + + +class SGDSaI(BaseOptimizer): + r"""No More Adam: Learning Rate Scaling at Initialization is All You Need. + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param lr: float. learning rate. + :param momentum: float. coefficients used for computing running averages of gradient. + :param weight_decay: float. weight decay (L2 penalty). + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. + :param eps: float. term added to the denominator to improve numerical stability. + """ + + def __init__( + self, + params: PARAMETERS, + lr: float = 1e-2, + momentum: float = 0.9, + weight_decay: float = 1e-2, + weight_decouple: bool = True, + eps: float = 1e-8, + **kwargs, + ): + self.validate_learning_rate(lr) + self.validate_range(momentum, 'beta', 0.0, 1.0) + self.validate_non_negative(weight_decay, 'weight_decay') + self.validate_non_negative(eps, 'eps') + + self.has_warmup: bool = False + + defaults: DEFAULTS = { + 'lr': lr, + 'momentum': momentum, + 'weight_decay': weight_decay, + 'weight_decouple': weight_decouple, + 'eps': eps, + } + super().__init__(params, defaults) + + def __str__(self) -> str: + return 'SGDSaI' + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + group['step'] = 0 + for p in group['params']: + state = self.state[p] + + if group['momentum'] > 0.0: + state['momentum_buffer'] = torch.zeros_like(p) + + @torch.no_grad() + def warmup_step(self, closure: CLOSURE = None) -> LOSS: + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + if grad.is_sparse: + raise NoSparseGradientError(str(self)) + + sigma = grad.std().nan_to_num_() + grad_norm = grad.norm() + + g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm + + self.state[p]['gsnr'] = g_snr + + self.has_warmup = True + + return loss + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> LOSS: + if not self.has_warmup: + self.warmup_step(closure) + + loss: LOSS = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + momentum: float = group['momentum'] + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad + + 
state = self.state[p] + + if momentum > 0.0: + if 'momentum_buffer' not in state: + state['momentum_buffer'] = grad.clone() + + buf = state['momentum_buffer'] + buf.mul_(momentum).add_(grad, alpha=1.0 - momentum) + else: + buf = grad + + self.apply_weight_decay( + p, + grad, + group['lr'], + group['weight_decay'], + group['weight_decouple'], + False, + ) + + p.add_(buf, alpha=-group['lr'] * state['gsnr']) + + return loss diff --git a/tests/constants.py b/tests/constants.py index 07a34bcd..f0969331 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -74,6 +74,7 @@ ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD, + SGDSaI, Shampoo, SignSGD, SophiaH, @@ -538,6 +539,8 @@ (MARS, {'lr': 1e-1, 'weight_decay': 1e-3, 'mars_type': 'lion', 'optimize_1d': True}, 5), (MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'mars_type': 'shampoo'}, 5), (MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'mars_type': 'adamw', 'ams_bound': True}, 5), + (SGDSaI, {'lr': 1e0}, 15), + (SGDSaI, {'lr': 1e0, 'momentum': 0.0}, 15), ] ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [ (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10), diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index aaa70195..7f64b54e 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names): def test_get_supported_optimizers(): - assert len(get_supported_optimizers()) == 84 + assert len(get_supported_optimizers()) == 85 assert len(get_supported_optimizers('adam*')) == 7 assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 9
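
For reference, a minimal usage sketch of the new `SGDSaI` optimizer, assuming `pytorch-optimizer` 3.3.2 is installed and using a toy `torch.nn.Linear` model (the model, data, and hyperparameter values below are illustrative only). The first `step()` call runs `warmup_step()` internally, which caches a per-parameter gradient signal-to-noise ratio (gSNR) that then scales that parameter's learning rate on every subsequent step.

```python
import torch
from pytorch_optimizer import SGDSaI  # exported as of v3.3.2

# Toy model and data, purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = SGDSaI(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-2)

x, y = torch.randn(32, 10), torch.randn(32, 2)

for _ in range(5):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # The first call triggers warmup_step(), which stores
    # gSNR = ||g|| / (std(g) + eps) per parameter tensor (falling back to ||g||
    # when std(g) is zero); later steps apply
    # p <- p - lr * gSNR * momentum_buffer, with decoupled weight decay.
    optimizer.step()
```

The cautious-masking bug fix (cloning `exp_avg` before `apply_cautious` in AdaShift, AdEMAMix, and MARS) matters because the mask was otherwise written into the optimizer state itself. The standalone sketch below mirrors the usual cautious-optimizer masking rule to show the difference; it is a hand-rolled stand-in, not the library's `apply_cautious` verbatim.

```python
import torch


def cautious_mask(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Zero update components whose sign disagrees with the gradient, then rescale."""
    mask = (update * grad > 0).to(update.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))
    return update.mul_(mask)


exp_avg = torch.tensor([0.5, -0.3, 0.2])  # stand-in for optimizer state
grad = torch.tensor([1.0, 1.0, -1.0])

# Before the fix: masking exp_avg in place zeroed parts of the running average,
# corrupting the state for every later step. After the fix the mask is applied
# to a clone, so only this step's update is affected.
update = cautious_mask(exp_avg.clone(), grad)

print(exp_avg)  # tensor([ 0.5000, -0.3000,  0.2000]) -- state left intact
print(update)   # sign-mismatched components zeroed, survivors rescaled by the inverse surviving fraction
```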