diff --git a/README.md b/README.md index 939258fbf..a5c426e45 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch. I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas. -Currently, **60 optimizers (+ `bitsandbytes`)**, **10 lr schedulers**, and **13 loss functions** are supported! +Currently, **61 optimizers (+ `bitsandbytes`)**, **10 lr schedulers**, and **13 loss functions** are supported! Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer). @@ -91,71 +91,72 @@ from pytorch_optimizer import get_supported_optimizers supported_optimizers = get_supported_optimizers() ``` -| Optimizer | Description | Official Code | Paper | Citation | -|--------------|---------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------| -| AdaBelief | *Adapting Step-sizes by the Belief in Observed Gradients* | [github](https://github.com/juntang-zhuang/Adabelief-Optimizer) | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv201007468Z/exportcitation) | -| AdaBound | *Adaptive Gradient Methods with Dynamic Bound of Learning Rate* | [github](https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py) | | [cite](https://github.com/Luolc/AdaBound#citing) | -| AdaHessian | *An Adaptive Second Order Optimizer for Machine Learning* | [github](https://github.com/amirgholami/adahessian) | | [cite](https://github.com/amirgholami/adahessian#citation) | -| AdamD | *Improved bias-correction in Adam* | | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv211010828S/exportcitation) | -| AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | [github](https://github.com/clovaai/AdamP) | | [cite](https://github.com/clovaai/AdamP#how-to-cite) | -| diffGrad | *An Optimization Method for Convolutional Neural Networks* | [github](https://github.com/shivram1987/diffGrad) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190911015D/exportcitation) | -| MADGRAD | *A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic* | [github](https://github.com/facebookresearch/madgrad) | | [cite](https://github.com/facebookresearch/madgrad#tech-report) | -| RAdam | *On the Variance of the Adaptive Learning Rate and Beyond* | [github](https://github.com/LiyuanLucasLiu/RAdam) | | [cite](https://github.com/LiyuanLucasLiu/RAdam#citation) | -| Ranger | *a synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer* | [github](https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer) | | [cite](https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer#citing-this-work) | -| Ranger21 | *a synergistic deep learning optimizer* | [github](https://github.com/lessw2020/Ranger21) | | [cite](https://github.com/lessw2020/Ranger21#referencing-this-work) | -| Lamb | *Large Batch Optimization for Deep Learning* | [github](https://github.com/cybertronai/pytorch-lamb) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190400962Y/exportcitation) | -| Shampoo | *Preconditioned Stochastic Tensor Optimization* | 
[github](https://github.com/moskomule/shampoo.pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv180209568G/exportcitation) | -| Nero | *Learning by Turning: Neural Architecture Aware Optimisation* | [github](https://github.com/jxbz/nero) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210207227L/exportcitation) | -| Adan | *Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models* | [github](https://github.com/sail-sg/Adan) | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220806677X/exportcitation) | -| Adai | *Disentangling the Effects of Adaptive Learning Rate and Momentum* | [github](https://github.com/zeke-xie/adaptive-inertia-adai) | | [cite](https://github.com/zeke-xie/adaptive-inertia-adai#citing) | -| SAM | *Sharpness-Aware Minimization* | [github](https://github.com/davda54/sam) | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv201001412F/exportcitation) | -| ASAM | *Adaptive Sharpness-Aware Minimization* | [github](https://github.com/davda54/sam) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210211600K/exportcitation) | -| GSAM | *Surrogate Gap Guided Sharpness-Aware Minimization* | [github](https://github.com/juntang-zhuang/GSAM) | | [cite](https://github.com/juntang-zhuang/GSAM#citation) | -| D-Adaptation | *Learning-Rate-Free Learning by D-Adaptation* | [github](https://github.com/facebookresearch/dadaptation) | | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230107733D/exportcitation) | -| AdaFactor | *Adaptive Learning Rates with Sublinear Memory Cost* | [github](https://github.com/DeadAt0m/adafactor-pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv180404235S/exportcitation) | -| Apollo | *An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization* | [github](https://github.com/XuezheMax/apollo) | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200913586M/exportcitation) | -| NovoGrad | *Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks* | [github](https://github.com/lonePatient/NovoGrad-pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190511286G/exportcitation) | -| Lion | *Symbolic Discovery of Optimization Algorithms* | [github](https://github.com/google/automl/tree/master/lion) | | [cite](https://github.com/google/automl/tree/master/lion#citation) | -| Ali-G | *Adaptive Learning Rates for Interpolation with Gradients* | [github](https://github.com/oval-group/ali-g) | | [cite](https://github.com/oval-group/ali-g#adaptive-learning-rates-for-interpolation-with-gradients) | -| SM3 | *Memory-Efficient Adaptive Optimization* | [github](https://github.com/google-research/google-research/tree/master/sm3) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190111150A/exportcitation) | -| AdaNorm | *Adaptive Gradient Norm Correction based Optimizer for CNNs* | [github](https://github.com/shivram1987/AdaNorm) | | [cite](https://github.com/shivram1987/AdaNorm/tree/main#citation) | -| RotoGrad | *Gradient Homogenization in Multitask Learning* | [github](https://github.com/adrianjav/rotograd) | | [cite](https://github.com/adrianjav/rotograd#citing) | -| A2Grad | *Optimal Adaptive and Accelerated Stochastic Gradient Descent* | [github](https://github.com/severilov/A2Grad_optimizer) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv181000553D/exportcitation) | -| AccSGD | *Accelerating Stochastic Gradient Descent For Least Squares Regression* | [github](https://github.com/rahulkidambi/AccSGD) | | 
[cite](https://github.com/rahulkidambi/AccSGD#citation) | -| SGDW | *Decoupled Weight Decay Regularization* | [github](https://github.com/loshchil/AdamW-and-SGDW) | | [cite](https://github.com/loshchil/AdamW-and-SGDW#contact) | -| ASGD | *Adaptive Gradient Descent without Descent* | [github](https://github.com/ymalitsky/adaptive_GD) | | [cite](https://github.com/ymalitsky/adaptive_GD#reference) | -| Yogi | *Adaptive Methods for Nonconvex Optimization* | | [NIPS 2018](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization) | [cite](https://proceedings.neurips.cc/paper_files/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) | -| SWATS | *Improving Generalization Performance by Switching from Adam to SGD* | | | [cite](https://ui.adsabs.harvard.edu/abs/2017arXiv171207628S/exportcitation) | -| Fromage | *On the distance between two neural networks and the stability of learning* | [github](https://github.com/jxbz/fromage) | | [cite](https://github.com/jxbz/fromage#citation) | -| MSVAG | *Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients* | [github](https://github.com/lballes/msvag) | | [cite](https://github.com/lballes/msvag#citation) | -| AdaMod | *An Adaptive and Momental Bound Method for Stochastic Learning* | [github](https://github.com/lancopku/AdaMod) | | [cite](https://github.com/lancopku/AdaMod#citation) | -| AggMo | *Aggregated Momentum: Stability Through Passive Damping* | [github](https://github.com/AtheMathmo/AggMo) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv180400325L/exportcitation) | -| QHAdam | *Quasi-hyperbolic momentum and Adam for deep learning* | [github](https://github.com/facebookresearch/qhoptim) | | [cite](https://github.com/facebookresearch/qhoptim#reference) | -| PID | *A PID Controller Approach for Stochastic Optimization of Deep Networks* | [github](https://github.com/tensorboy/PIDOptimizer) | [CVPR 18](http://www4.comp.polyu.edu.hk/~cslzhang/paper/CVPR18_PID.pdf) | [cite](https://github.com/tensorboy/PIDOptimizer#citation) | -| Gravity | *a Kinematic Approach on Optimization in Deep Learning* | [github](https://github.com/dariush-bahrami/gravity.optimizer) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210109192B/exportcitation) | -| AdaSmooth | *An Adaptive Learning Rate Method based on Effective Ratio* | | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220400825L/exportcitation) | -| SRMM | *Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates* | [github](https://github.com/HanbaekLyu/SRMM) | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220101652L/exportcitation) | -| AvaGrad | *Domain-independent Dominance of Adaptive Methods* | [github](https://github.com/lolemacs/avagrad) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv191201823S/exportcitation) | -| PCGrad | *Gradient Surgery for Multi-Task Learning* | [github](https://github.com/tianheyu927/PCGrad) | | [cite](https://github.com/tianheyu927/PCGrad#reference) | -| AMSGrad | *On the Convergence of Adam and Beyond* | | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190409237R/exportcitation) | -| Lookahead | *k steps forward, 1 step back* | [github](https://github.com/pytorch/examples/tree/main/imagenet) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190708610Z/exportcitation) | -| PNM | *Manipulating Stochastic Gradient Noise to Improve Generalization* | [github](https://github.com/zeke-xie/Positive-Negative-Momentum) | | 
[cite](https://github.com/zeke-xie/Positive-Negative-Momentum#citing) | -| GC | *Gradient Centralization* | [github](https://github.com/Yonghongwei/Gradient-Centralization) | | [cite](https://github.com/Yonghongwei/Gradient-Centralization#citation) | -| AGC | *Adaptive Gradient Clipping* | [github](https://github.com/deepmind/deepmind-research/tree/master/nfnets) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210206171B/exportcitation) | -| Stable WD | *Understanding and Scheduling Weight Decay* | [github](https://github.com/zeke-xie/stable-weight-decay-regularization) | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv201111152X/exportcitation) | -| Softplus T | *Calibrating the Adaptive Learning Rate to Improve Convergence of ADAM* | | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190800700T/exportcitation) | -| Un-tuned w/u | *On the adequacy of untuned warmup for adaptive optimization* | | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv191004209M/exportcitation) | -| Norm Loss | *An efficient yet effective regularization method for deep neural networks* | | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210306583G/exportcitation) | -| AdaShift | *Decorrelation and Convergence of Adaptive Learning Rate Methods* | [github](https://github.com/MichaelKonobeev/adashift) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv181000143Z/exportcitation) | -| AdaDelta | *An Adaptive Learning Rate Method* | | | [cite](https://ui.adsabs.harvard.edu/abs/2012arXiv1212.5701Z/exportcitation) | -| Amos | *An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale* | [github](https://github.com/google-research/jestimator) | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv221011693T/exportcitation) | -| SignSGD | *Compressed Optimisation for Non-Convex Problems* | [github](https://github.com/jxbz/signSGD) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv180204434B/exportcitation) | -| Sophia | *A Scalable Stochastic Second-order Optimizer for Language Model Pre-training* | [github](https://github.com/Liuhong99/Sophia) | | [cite](https://github.com/Liuhong99/Sophia) | -| Prodigy | *An Expeditiously Adaptive Parameter-Free Learner* | [github](https://github.com/konstmish/prodigy) | | [cite](https://github.com/konstmish/prodigy#how-to-cite) | -| PAdam | *Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks* | [github](https://github.com/uclaml/Padam) | | [cite](https://github.com/uclaml/Padam#citation) | -| LOMO | *Full Parameter Fine-tuning for Large Language Models with Limited Resources* | [github](https://github.com/OpenLMLab/LOMO) | | [cite](https://github.com/OpenLMLab/LOMO#citation) | -| Tiger | *A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious* | [github](https://github.com/bojone/tiger) | | [cite](https://github.com/bojone/tiger/blob/main/README_en.md#citation) | -| CAME | *Confidence-guided Adaptive Memory Efficient Optimization* | [github](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME) | | [cite](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME#citation) | +| Optimizer | Description | Official Code | Paper | Citation | 
+|--------------|---------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------| +| AdaBelief | *Adapting Step-sizes by the Belief in Observed Gradients* | [github](https://github.com/juntang-zhuang/Adabelief-Optimizer) | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv201007468Z/exportcitation) | +| AdaBound | *Adaptive Gradient Methods with Dynamic Bound of Learning Rate* | [github](https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py) | | [cite](https://github.com/Luolc/AdaBound#citing) | +| AdaHessian | *An Adaptive Second Order Optimizer for Machine Learning* | [github](https://github.com/amirgholami/adahessian) | | [cite](https://github.com/amirgholami/adahessian#citation) | +| AdamD | *Improved bias-correction in Adam* | | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv211010828S/exportcitation) | +| AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | [github](https://github.com/clovaai/AdamP) | | [cite](https://github.com/clovaai/AdamP#how-to-cite) | +| diffGrad | *An Optimization Method for Convolutional Neural Networks* | [github](https://github.com/shivram1987/diffGrad) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190911015D/exportcitation) | +| MADGRAD | *A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic* | [github](https://github.com/facebookresearch/madgrad) | | [cite](https://github.com/facebookresearch/madgrad#tech-report) | +| RAdam | *On the Variance of the Adaptive Learning Rate and Beyond* | [github](https://github.com/LiyuanLucasLiu/RAdam) | | [cite](https://github.com/LiyuanLucasLiu/RAdam#citation) | +| Ranger | *a synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer* | [github](https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer) | | [cite](https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer#citing-this-work) | +| Ranger21 | *a synergistic deep learning optimizer* | [github](https://github.com/lessw2020/Ranger21) | | [cite](https://github.com/lessw2020/Ranger21#referencing-this-work) | +| Lamb | *Large Batch Optimization for Deep Learning* | [github](https://github.com/cybertronai/pytorch-lamb) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190400962Y/exportcitation) | +| Shampoo | *Preconditioned Stochastic Tensor Optimization* | [github](https://github.com/moskomule/shampoo.pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv180209568G/exportcitation) | +| Nero | *Learning by Turning: Neural Architecture Aware Optimisation* | [github](https://github.com/jxbz/nero) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210207227L/exportcitation) | +| Adan | *Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models* | [github](https://github.com/sail-sg/Adan) | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220806677X/exportcitation) | +| Adai | *Disentangling the Effects of Adaptive Learning Rate and Momentum* | [github](https://github.com/zeke-xie/adaptive-inertia-adai) | | [cite](https://github.com/zeke-xie/adaptive-inertia-adai#citing) | +| SAM | *Sharpness-Aware Minimization* | [github](https://github.com/davda54/sam) | | 
[cite](https://ui.adsabs.harvard.edu/abs/2020arXiv201001412F/exportcitation) | +| ASAM | *Adaptive Sharpness-Aware Minimization* | [github](https://github.com/davda54/sam) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210211600K/exportcitation) | +| GSAM | *Surrogate Gap Guided Sharpness-Aware Minimization* | [github](https://github.com/juntang-zhuang/GSAM) | | [cite](https://github.com/juntang-zhuang/GSAM#citation) | +| D-Adaptation | *Learning-Rate-Free Learning by D-Adaptation* | [github](https://github.com/facebookresearch/dadaptation) | | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230107733D/exportcitation) | +| AdaFactor | *Adaptive Learning Rates with Sublinear Memory Cost* | [github](https://github.com/DeadAt0m/adafactor-pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv180404235S/exportcitation) | +| Apollo | *An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization* | [github](https://github.com/XuezheMax/apollo) | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200913586M/exportcitation) | +| NovoGrad | *Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks* | [github](https://github.com/lonePatient/NovoGrad-pytorch) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190511286G/exportcitation) | +| Lion | *Symbolic Discovery of Optimization Algorithms* | [github](https://github.com/google/automl/tree/master/lion) | | [cite](https://github.com/google/automl/tree/master/lion#citation) | +| Ali-G | *Adaptive Learning Rates for Interpolation with Gradients* | [github](https://github.com/oval-group/ali-g) | | [cite](https://github.com/oval-group/ali-g#adaptive-learning-rates-for-interpolation-with-gradients) | +| SM3 | *Memory-Efficient Adaptive Optimization* | [github](https://github.com/google-research/google-research/tree/master/sm3) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190111150A/exportcitation) | +| AdaNorm | *Adaptive Gradient Norm Correction based Optimizer for CNNs* | [github](https://github.com/shivram1987/AdaNorm) | | [cite](https://github.com/shivram1987/AdaNorm/tree/main#citation) | +| RotoGrad | *Gradient Homogenization in Multitask Learning* | [github](https://github.com/adrianjav/rotograd) | | [cite](https://github.com/adrianjav/rotograd#citing) | +| A2Grad | *Optimal Adaptive and Accelerated Stochastic Gradient Descent* | [github](https://github.com/severilov/A2Grad_optimizer) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv181000553D/exportcitation) | +| AccSGD | *Accelerating Stochastic Gradient Descent For Least Squares Regression* | [github](https://github.com/rahulkidambi/AccSGD) | | [cite](https://github.com/rahulkidambi/AccSGD#citation) | +| SGDW | *Decoupled Weight Decay Regularization* | [github](https://github.com/loshchil/AdamW-and-SGDW) | | [cite](https://github.com/loshchil/AdamW-and-SGDW#contact) | +| ASGD | *Adaptive Gradient Descent without Descent* | [github](https://github.com/ymalitsky/adaptive_GD) | | [cite](https://github.com/ymalitsky/adaptive_GD#reference) | +| Yogi | *Adaptive Methods for Nonconvex Optimization* | | [NIPS 2018](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization) | [cite](https://proceedings.neurips.cc/paper_files/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) | +| SWATS | *Improving Generalization Performance by Switching from Adam to SGD* | | | [cite](https://ui.adsabs.harvard.edu/abs/2017arXiv171207628S/exportcitation) | +| Fromage | *On the distance 
between two neural networks and the stability of learning* | [github](https://github.com/jxbz/fromage) | | [cite](https://github.com/jxbz/fromage#citation) | +| MSVAG | *Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients* | [github](https://github.com/lballes/msvag) | | [cite](https://github.com/lballes/msvag#citation) | +| AdaMod | *An Adaptive and Momental Bound Method for Stochastic Learning* | [github](https://github.com/lancopku/AdaMod) | | [cite](https://github.com/lancopku/AdaMod#citation) | +| AggMo | *Aggregated Momentum: Stability Through Passive Damping* | [github](https://github.com/AtheMathmo/AggMo) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv180400325L/exportcitation) | +| QHAdam | *Quasi-hyperbolic momentum and Adam for deep learning* | [github](https://github.com/facebookresearch/qhoptim) | | [cite](https://github.com/facebookresearch/qhoptim#reference) | +| PID | *A PID Controller Approach for Stochastic Optimization of Deep Networks* | [github](https://github.com/tensorboy/PIDOptimizer) | [CVPR 18](http://www4.comp.polyu.edu.hk/~cslzhang/paper/CVPR18_PID.pdf) | [cite](https://github.com/tensorboy/PIDOptimizer#citation) | +| Gravity | *a Kinematic Approach on Optimization in Deep Learning* | [github](https://github.com/dariush-bahrami/gravity.optimizer) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210109192B/exportcitation) | +| AdaSmooth | *An Adaptive Learning Rate Method based on Effective Ratio* | | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220400825L/exportcitation) | +| SRMM | *Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates* | [github](https://github.com/HanbaekLyu/SRMM) | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220101652L/exportcitation) | +| AvaGrad | *Domain-independent Dominance of Adaptive Methods* | [github](https://github.com/lolemacs/avagrad) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv191201823S/exportcitation) | +| PCGrad | *Gradient Surgery for Multi-Task Learning* | [github](https://github.com/tianheyu927/PCGrad) | | [cite](https://github.com/tianheyu927/PCGrad#reference) | +| AMSGrad | *On the Convergence of Adam and Beyond* | | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190409237R/exportcitation) | +| Lookahead | *k steps forward, 1 step back* | [github](https://github.com/pytorch/examples/tree/main/imagenet) | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190708610Z/exportcitation) | +| PNM | *Manipulating Stochastic Gradient Noise to Improve Generalization* | [github](https://github.com/zeke-xie/Positive-Negative-Momentum) | | [cite](https://github.com/zeke-xie/Positive-Negative-Momentum#citing) | +| GC | *Gradient Centralization* | [github](https://github.com/Yonghongwei/Gradient-Centralization) | | [cite](https://github.com/Yonghongwei/Gradient-Centralization#citation) | +| AGC | *Adaptive Gradient Clipping* | [github](https://github.com/deepmind/deepmind-research/tree/master/nfnets) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210206171B/exportcitation) | +| Stable WD | *Understanding and Scheduling Weight Decay* | [github](https://github.com/zeke-xie/stable-weight-decay-regularization) | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv201111152X/exportcitation) | +| Softplus T | *Calibrating the Adaptive Learning Rate to Improve Convergence of ADAM* | | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv190800700T/exportcitation) | +| Un-tuned w/u | *On the adequacy of untuned warmup for adaptive 
optimization* | | | [cite](https://ui.adsabs.harvard.edu/abs/2019arXiv191004209M/exportcitation) | +| Norm Loss | *An efficient yet effective regularization method for deep neural networks* | | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210306583G/exportcitation) | +| AdaShift | *Decorrelation and Convergence of Adaptive Learning Rate Methods* | [github](https://github.com/MichaelKonobeev/adashift) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv181000143Z/exportcitation) | +| AdaDelta | *An Adaptive Learning Rate Method* | | | [cite](https://ui.adsabs.harvard.edu/abs/2012arXiv1212.5701Z/exportcitation) | +| Amos | *An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale* | [github](https://github.com/google-research/jestimator) | | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv221011693T/exportcitation) | +| SignSGD | *Compressed Optimisation for Non-Convex Problems* | [github](https://github.com/jxbz/signSGD) | | [cite](https://ui.adsabs.harvard.edu/abs/2018arXiv180204434B/exportcitation) | +| Sophia | *A Scalable Stochastic Second-order Optimizer for Language Model Pre-training* | [github](https://github.com/Liuhong99/Sophia) | | [cite](https://github.com/Liuhong99/Sophia) | +| Prodigy | *An Expeditiously Adaptive Parameter-Free Learner* | [github](https://github.com/konstmish/prodigy) | | [cite](https://github.com/konstmish/prodigy#how-to-cite) | +| PAdam | *Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks* | [github](https://github.com/uclaml/Padam) | | [cite](https://github.com/uclaml/Padam#citation) | +| LOMO | *Full Parameter Fine-tuning for Large Language Models with Limited Resources* | [github](https://github.com/OpenLMLab/LOMO) | | [cite](https://github.com/OpenLMLab/LOMO#citation) | +| Tiger | *A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious* | [github](https://github.com/bojone/tiger) | | [cite](https://github.com/bojone/tiger/blob/main/README_en.md#citation) | +| CAME | *Confidence-guided Adaptive Memory Efficient Optimization* | [github](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME) | | [cite](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/CAME#citation) | +| WSAM | *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* | [github](https://github.com/intelligent-machine-learning/dlrover/blob/master/atorch/atorch/optimizers/wsam.py) | | [cite](https://github.com/intelligent-machine-learning/dlrover) | ## Supported LR Scheduler diff --git a/docs/optimizer.md b/docs/optimizer.md index 0216eff93..00da8ceb4 100644 --- a/docs/optimizer.md +++ b/docs/optimizer.md @@ -272,6 +272,10 @@ :docstring: :members: +::: pytorch_optimizer.WSAM + :docstring: + :members: + ::: pytorch_optimizer.Yogi :docstring: :members: diff --git a/pyproject.toml b/pyproject.toml index a2b9529ac..b27d0d90f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,8 +16,8 @@ keywords = [ "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH", - "SRMM", "SWATS", "Tiger", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", - "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", + "SRMM", 
"SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", + "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", ] classifiers = [ "License :: OSI Approved :: Apache Software License", @@ -126,7 +126,6 @@ testpaths = "tests" [tool.coverage.run] omit = [ - "./pytorch_optimizer/optimizer/gsam.py", "./pytorch_optimizer/optimizer/rotograd.py", ] diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index 236b55d25..6ad519088 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -55,7 +55,6 @@ from pytorch_optimizer.optimizer.fromage import Fromage from pytorch_optimizer.optimizer.gc import centralize_gradient from pytorch_optimizer.optimizer.gravity import Gravity -from pytorch_optimizer.optimizer.gsam import GSAM from pytorch_optimizer.optimizer.lamb import Lamb from pytorch_optimizer.optimizer.lars import LARS from pytorch_optimizer.optimizer.lion import Lion @@ -76,7 +75,7 @@ from pytorch_optimizer.optimizer.ranger import Ranger from pytorch_optimizer.optimizer.ranger21 import Ranger21 from pytorch_optimizer.optimizer.rotograd import RotoGrad -from pytorch_optimizer.optimizer.sam import SAM +from pytorch_optimizer.optimizer.sam import GSAM, SAM, WSAM from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD from pytorch_optimizer.optimizer.sgdp import SGDP from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo diff --git a/pytorch_optimizer/optimizer/gsam.py b/pytorch_optimizer/optimizer/gsam.py deleted file mode 100644 index a0d663887..000000000 --- a/pytorch_optimizer/optimizer/gsam.py +++ /dev/null @@ -1,227 +0,0 @@ -from contextlib import ExitStack -from typing import Callable, Dict, Optional, Tuple - -import torch -from torch import nn -from torch.distributed import ReduceOp, all_reduce, get_world_size, is_initialized -from torch.optim.optimizer import Optimizer - -from pytorch_optimizer.base.optimizer import BaseOptimizer -from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS -from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats - - -class GSAM(Optimizer, BaseOptimizer): - r"""Surrogate Gap Guided Sharpness-Aware Minimization. - - Example: - ------- - Here's an example:: - - model = YourModel() - base_optimizer = AdamP(model.parameters()) - lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps) - rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr) - optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler) - - def loss_fn(predictions, targets): - return F.cross_entropy(predictions, targets) - - for inputs, targets in data: - optimizer.set_closure(loss_fn, inputs, targets) - predictions, loss = optimizer.step() - lr_scheduler.step() - optimizer.update_rho_t() - - :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. - :param base_optimizer: Optimizer. base optimizer. - :param model: nn.Module. model. - :param alpha: float. rho alpha. - :param rho_scheduler: rho scheduler. - :param adaptive: bool. element-wise Adaptive SAM. - :param perturb_eps: float. epsilon for perturbation. - :param kwargs: Dict. parameters for optimizer. 
- """ - - def __init__( - self, - params: PARAMETERS, - base_optimizer: OPTIMIZER, - model: nn.Module, - rho_scheduler, - alpha: float = 0.4, - adaptive: bool = False, - perturb_eps: float = 1e-12, - **kwargs, - ): - self.validate_range(alpha, 'alpha', 0.0, 1.0) - - self.model = model - self.rho_scheduler = rho_scheduler - self.alpha = alpha - self.adaptive = adaptive - self.perturb_eps = perturb_eps - - self.rho_t: float = 0.0 - self.forward_backward_func: Optional[Callable] = None - - if hasattr(ReduceOp, 'AVG'): - self.grad_reduce = ReduceOp.AVG - self.manual_average: bool = False - else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes - self.grad_reduce = ReduceOp.SUM - self.manual_average: bool = True - - self.base_optimizer = base_optimizer - self.param_groups = self.base_optimizer.param_groups - - defaults: DEFAULTS = {'adaptive': adaptive} - defaults.update(kwargs) - super().__init__(params, defaults) - - self.update_rho_t() - - def __str__(self) -> str: - return 'GSAM' - - @torch.no_grad() - def reset(self): - pass - - @torch.no_grad() - def update_rho_t(self) -> float: - self.rho_t = self.rho_scheduler.step() - return self.rho_t - - @torch.no_grad() - def perturb_weights(self, rho: float): - grad_norm = self.grad_norm(weight_adaptive=self.adaptive) - for group in self.param_groups: - scale = rho / (grad_norm + self.perturb_eps) - - for p in group['params']: - if p.grad is None: - continue - - self.state[p]['old_g'] = p.grad.clone() - - e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p) - p.add_(e_w) # climb to the local maximum "w + e(w)" - - self.state[p]['e_w'] = e_w - - @torch.no_grad() - def un_perturb(self): - for group in self.param_groups: - for p in group['params']: - if 'e_w' in self.state[p]: - p.sub_(self.state[p]['e_w']) - - @torch.no_grad() - def gradient_decompose(self, alpha: float = 0.0): - inner_prod = 0.0 - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - - inner_prod += torch.sum(self.state[p]['old_g'] * p.grad) - - new_grad_norm = self.grad_norm(by=None) - old_grad_norm = self.grad_norm(by='old_g') - - cosine = inner_prod / (new_grad_norm * old_grad_norm + self.perturb_eps) - - # gradient decomposition - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - - vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad / ( - new_grad_norm + self.perturb_eps - ) - p.grad.add_(vertical, alpha=-alpha) - - @torch.no_grad() - def sync_grad(self): - if is_initialized(): - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - - all_reduce(p.grad, op=self.grad_reduce) - if self.manual_average: - p.grad.div_(float(get_world_size())) - - @torch.no_grad() - def grad_norm(self, by: Optional[str] = None, weight_adaptive: bool = False) -> torch.Tensor: - return torch.norm( - torch.stack( - [ - ((torch.abs(p) if weight_adaptive else 1.0) * (p.grad if not by else self.state[p][by])).norm(p=2) - for group in self.param_groups - for p in group['params'] - if p.grad is not None - ] - ), - p=2, - ) - - def maybe_no_sync(self): - return self.model.no_sync() if is_initialized() else ExitStack() - - @torch.no_grad() - def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs): - r"""Set closure. - - Create `self.forward_backward_func`, which is a function such that `self.forward_backward_func()` - automatically performs forward and backward passes. 
This function does not take any arguments, - and the inputs and targets data should be pre-set in the definition of partial-function. - - :param loss_fn: nn.Module. loss function. - :param inputs: torch.Tensor. inputs. - :param targets: torch.Tensor. targets. - """ - - def get_grad(): - self.base_optimizer.zero_grad() - with torch.enable_grad(): - outputs = self.model(inputs) - loss = loss_fn(outputs, targets, **kwargs) - - loss.backward() - - return outputs, loss.detach() - - self.forward_backward_func = get_grad - - @torch.no_grad() - def step(self, closure: CLOSURE = None) -> Tuple[torch.Tensor, float]: - get_grad = closure if closure else self.forward_backward_func - - with self.maybe_no_sync(): - outputs, loss = get_grad() - - self.perturb_weights(rho=self.rho_t) - - disable_running_stats(self.model) - - get_grad() - - self.gradient_decompose(self.alpha) - - self.un_perturb() - - self.sync_grad() - - self.base_optimizer.step() - - enable_running_stats(self.model) - - return outputs, loss - - def load_state_dict(self, state_dict: Dict): - super().load_state_dict(state_dict) - self.base_optimizer.param_groups = self.param_groups diff --git a/pytorch_optimizer/optimizer/sam.py b/pytorch_optimizer/optimizer/sam.py index 7ae417750..f27145159 100644 --- a/pytorch_optimizer/optimizer/sam.py +++ b/pytorch_optimizer/optimizer/sam.py @@ -1,11 +1,17 @@ -from typing import Dict +from contextlib import ExitStack +from typing import Callable, Dict, Optional, Tuple, Union import torch +from torch import nn +from torch.distributed import ReduceOp, all_reduce, get_world_size, is_initialized +from torch.nn.parallel import DistributedDataParallel +from torch.nn.utils import clip_grad_norm_ from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoClosureError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS +from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats class SAM(Optimizer, BaseOptimizer): @@ -145,3 +151,377 @@ def grad_norm(self) -> torch.Tensor: def load_state_dict(self, state_dict: Dict): super().load_state_dict(state_dict) self.base_optimizer.param_groups = self.param_groups + + +class GSAM(Optimizer, BaseOptimizer): # pragma: no cover + r"""Surrogate Gap Guided Sharpness-Aware Minimization. + + Example: + ------- + Here's an example:: + + model = YourModel() + base_optimizer = AdamP(model.parameters()) + lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps) + rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr) + optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler) + + def loss_fn(predictions, targets): + return F.cross_entropy(predictions, targets) + + for inputs, targets in data: + optimizer.set_closure(loss_fn, inputs, targets) + predictions, loss = optimizer.step() + lr_scheduler.step() + optimizer.update_rho_t() + + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param base_optimizer: Optimizer. base optimizer. + :param model: nn.Module. model. + :param alpha: float. rho alpha. + :param rho_scheduler: rho scheduler. + :param adaptive: bool. element-wise Adaptive SAM. + :param perturb_eps: float. epsilon for perturbation. + :param kwargs: Dict. parameters for optimizer. 
+ """ + + def __init__( + self, + params: PARAMETERS, + base_optimizer: OPTIMIZER, + model: nn.Module, + rho_scheduler, + alpha: float = 0.4, + adaptive: bool = False, + perturb_eps: float = 1e-12, + **kwargs, + ): + self.validate_range(alpha, 'alpha', 0.0, 1.0) + + self.model = model + self.rho_scheduler = rho_scheduler + self.alpha = alpha + self.adaptive = adaptive + self.perturb_eps = perturb_eps + + self.rho_t: float = 0.0 + self.forward_backward_func: Optional[Callable] = None + + if hasattr(ReduceOp, 'AVG'): + self.grad_reduce = ReduceOp.AVG + self.manual_average: bool = False + else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes + self.grad_reduce = ReduceOp.SUM + self.manual_average: bool = True + + self.base_optimizer = base_optimizer + self.param_groups = self.base_optimizer.param_groups + + defaults: DEFAULTS = {'adaptive': adaptive} + defaults.update(kwargs) + super().__init__(params, defaults) + + self.update_rho_t() + + def __str__(self) -> str: + return 'GSAM' + + @torch.no_grad() + def reset(self): + pass + + @torch.no_grad() + def update_rho_t(self) -> float: + self.rho_t = self.rho_scheduler.step() + return self.rho_t + + @torch.no_grad() + def perturb_weights(self, rho: float): + grad_norm = self.grad_norm(weight_adaptive=self.adaptive) + for group in self.param_groups: + scale = rho / (grad_norm + self.perturb_eps) + + for p in group['params']: + if p.grad is None: + continue + + self.state[p]['old_g'] = p.grad.clone() + + e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p) + p.add_(e_w) # climb to the local maximum "w + e(w)" + + self.state[p]['e_w'] = e_w + + @torch.no_grad() + def un_perturb(self): + for group in self.param_groups: + for p in group['params']: + if 'e_w' in self.state[p]: + p.sub_(self.state[p]['e_w']) + + @torch.no_grad() + def gradient_decompose(self, alpha: float = 0.0): + inner_prod = 0.0 + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + inner_prod += torch.sum(self.state[p]['old_g'] * p.grad) + + new_grad_norm = self.grad_norm(by=None) + old_grad_norm = self.grad_norm(by='old_g') + + cosine = inner_prod / (new_grad_norm * old_grad_norm + self.perturb_eps) + + # gradient decomposition + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad / ( + new_grad_norm + self.perturb_eps + ) + p.grad.add_(vertical, alpha=-alpha) + + @torch.no_grad() + def sync_grad(self): + if is_initialized(): + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + all_reduce(p.grad, op=self.grad_reduce) + if self.manual_average: + p.grad.div_(float(get_world_size())) + + @torch.no_grad() + def grad_norm(self, by: Optional[str] = None, weight_adaptive: bool = False) -> torch.Tensor: + return torch.norm( + torch.stack( + [ + ((torch.abs(p) if weight_adaptive else 1.0) * (p.grad if not by else self.state[p][by])).norm(p=2) + for group in self.param_groups + for p in group['params'] + if p.grad is not None + ] + ), + p=2, + ) + + def maybe_no_sync(self): + return self.model.no_sync() if is_initialized() else ExitStack() + + @torch.no_grad() + def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs): + r"""Set closure. + + Create `self.forward_backward_func`, which is a function such that `self.forward_backward_func()` + automatically performs forward and backward passes. 
This function does not take any arguments, + and the inputs and targets data should be pre-set in the definition of partial-function. + + :param loss_fn: nn.Module. loss function. + :param inputs: torch.Tensor. inputs. + :param targets: torch.Tensor. targets. + """ + + def get_grad(): + self.base_optimizer.zero_grad() + with torch.enable_grad(): + outputs = self.model(inputs) + loss = loss_fn(outputs, targets, **kwargs) + + loss.backward() + + return outputs, loss.detach() + + self.forward_backward_func = get_grad + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> Tuple[torch.Tensor, float]: + get_grad = closure if closure else self.forward_backward_func + + with self.maybe_no_sync(): + outputs, loss = get_grad() + + self.perturb_weights(rho=self.rho_t) + + disable_running_stats(self.model) + + get_grad() + + self.gradient_decompose(self.alpha) + + self.un_perturb() + + self.sync_grad() + + self.base_optimizer.step() + + enable_running_stats(self.model) + + return outputs, loss + + def load_state_dict(self, state_dict: Dict): + super().load_state_dict(state_dict) + self.base_optimizer.param_groups = self.param_groups + + +class WSAM(Optimizer, BaseOptimizer): + r"""Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term. + + :param model: Union[torch.nn.Module, torch.nn.DataParallel]. the model instance. DDP model is recommended to make + `model.no_sync` to work. + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. + :param base_optimizer: Optimizer. base optimizer. + :param rho: float. size of the neighborhood for computing the max loss. + :param gamma: float. weighted factor gamma / (1 - gamma) of the sharpness term. 0.8 ~ 0.95 is the optimal. + :param adaptive: bool. element-wise adaptive SAM. + :param decouple: bool. whether to perform a decoupled sharpness regularization. + :param max_norm: Optional[float]. max norm of the gradients. + :param eps: float. term added to the denominator of WSAM to improve numerical stability. + :param kwargs: Dict. parameters for optimizer. 
+ """ + + def __init__( + self, + model: Union[nn.Module, DistributedDataParallel], + params: PARAMETERS, + base_optimizer: OPTIMIZER, + rho: float = 0.05, + gamma: float = 0.9, + adaptive: bool = False, + decouple: bool = True, + max_norm: Optional[float] = None, + eps: float = 1e-12, + **kwargs, + ): + self.validate_non_negative(rho, 'rho') + + self.model = model + self.decouple = decouple + self.max_norm = max_norm + + alpha: float = gamma / (1.0 - gamma) + + defaults: DEFAULTS = {'rho': rho, 'alpha': alpha, 'adaptive': adaptive, 'sam_eps': eps} + defaults.update(kwargs) + super().__init__(params, defaults) + + self.base_optimizer = base_optimizer(self.param_groups, **kwargs) + self.param_groups = self.base_optimizer.param_groups + + def __str__(self) -> str: + return 'WSAM' + + @torch.no_grad() + def reset(self): + pass + + @torch.no_grad() + def first_step(self, zero_grad: bool = False): + grad_norm = self.grad_norm() + for group in self.param_groups: + scale = group['rho'] / (grad_norm + group['sam_eps']) + + for p in group['params']: + if p.grad is None: + continue + + e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale.to(p) + + # climb to the local maximum "w + e(w)" + p.add_(e_w) + + self.state[p]['e_w'] = e_w + + if is_initialized(): # pragma: no cover + all_reduce(p.grad, op=ReduceOp.AVG) + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + self.state[p]['grad'] = p.grad.clone() + + if zero_grad: + self.zero_grad() + + @torch.no_grad() + def second_step(self, zero_grad: bool = False): + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + if is_initialized(): # pragma: no cover + all_reduce(p.grad, ReduceOp.AVG) + + # get back to "w" from "w + e(w)" + p.add_(self.state[p]['e_w'], alpha=-1.0) + + if self.max_norm is not None: + clip_grad_norm_(self.model.parameters(), self.max_norm) + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + if not self.decouple: + p.grad.mul_(group['alpha']).add_(self.state[p]['grad'], alpha=1.0 - group['alpha']) + else: + self.state[p]['sharpness'] = p.grad.clone() - self.state[p]['grad'] + p.grad.mul_(0.0).add_(self.state[p]['grad'], alpha=1.0) + + # do the actual "sharpness-aware" update + self.base_optimizer.step() + + if self.decouple: + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + p.add_(self.state[p]['sharpness'], alpha=-group['lr'] * group['alpha']) + + if zero_grad: + self.zero_grad() + + @torch.no_grad() + def step(self, closure: CLOSURE = None): + if closure is None: + raise NoClosureError(str(self)) + + closure = torch.enable_grad()(closure) + + enable_running_stats(self.model) + loss = closure() + self.first_step(zero_grad=True) + + disable_running_stats(self.model) + closure() + self.second_step() + + return loss + + def grad_norm(self) -> torch.Tensor: + shared_device = self.param_groups[0]['params'][0].device + return torch.norm( + torch.stack( + [ + ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device) + for group in self.param_groups + for p in group['params'] + if p.grad is not None + ] + ), + p=2, + ) + + def load_state_dict(self, state_dict: Dict): + super().load_state_dict(state_dict) + self.base_optimizer.param_groups = self.param_groups diff --git a/tests/constants.py b/tests/constants.py index bfc65edb9..4d53fbe09 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -64,6 +64,7 @@ ) from 
tests.utils import build_lookahead +DECOUPLE_FLAGS: List[bool] = [True, False] ADAPTIVE_FLAGS: List[bool] = [True, False] PULLBACK_MOMENTUM: List[str] = ['none', 'reset', 'pullback'] @@ -72,6 +73,7 @@ 'asam', 'sam', 'gsam', + 'wsam', 'pcgrad', 'lookahead', ] diff --git a/tests/test_gradients.py b/tests/test_gradients.py index b72551346..7fa8a594d 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,7 +1,7 @@ import pytest import torch -from pytorch_optimizer import SAM, AdamP, Lookahead, load_optimizer +from pytorch_optimizer import SAM, WSAM, AdamP, Lookahead, load_optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss @@ -140,6 +140,22 @@ def test_sam_no_gradient(): optimizer.second_step(zero_grad=True) +def test_wsam_no_gradient(): + (x_data, y_data), model, loss_fn = build_environment() + model.fc1.weight.requires_grad = False + model.fc1.weight.grad = None + + optimizer = WSAM(model, model.parameters(), AdamP) + optimizer.zero_grad() + + loss = loss_fn(y_data, model(x_data)) + loss.backward() + optimizer.first_step(zero_grad=True) + + loss_fn(y_data, model(x_data)).backward() + optimizer.second_step(zero_grad=True) + + @pytest.mark.parametrize('optimizer_name', ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD', 'DAdaptAdan', 'Prodigy']) def test_no_progression(optimizer_name): param = simple_parameter(True) diff --git a/tests/test_optimizer_parameters.py b/tests/test_optimizer_parameters.py index 3817d986e..26d7ff59c 100644 --- a/tests/test_optimizer_parameters.py +++ b/tests/test_optimizer_parameters.py @@ -2,7 +2,7 @@ import torch from torch import nn -from pytorch_optimizer import SAM, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer +from pytorch_optimizer import SAM, WSAM, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer from tests.constants import PULLBACK_MOMENTUM from tests.utils import Example, simple_parameter, simple_zero_rank_parameter @@ -54,11 +54,15 @@ def test_pcgrad_parameters(): def test_sam_parameters(): - # test rho with pytest.raises(ValueError): SAM(None, load_optimizer('adamp'), rho=-0.1) +def test_wsam_parameters(): + with pytest.raises(ValueError): + WSAM(None, None, load_optimizer('adamp'), rho=-0.1) + + def test_lookahead_parameters(): optimizer = load_optimizer('adamp')([simple_parameter()]) @@ -91,6 +95,12 @@ def test_sam_methods(): optimizer.load_state_dict(optimizer.state_dict()) +def test_wsam_methods(): + optimizer = WSAM(None, [simple_parameter()], load_optimizer('adamp')) + optimizer.reset() + optimizer.load_state_dict(optimizer.state_dict()) + + def test_safe_fp16_methods(): optimizer = SafeFP16Optimizer(load_optimizer('adamp')([simple_parameter()], lr=5e-1)) optimizer.load_state_dict(optimizer.state_dict()) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index fc57f06a0..b7e7c99d0 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -6,6 +6,7 @@ from pytorch_optimizer import ( GSAM, SAM, + WSAM, CosineScheduler, DynamicLossScaler, Lookahead, @@ -19,6 +20,7 @@ ADAMD_SUPPORTED_OPTIMIZERS, ADANORM_SUPPORTED_OPTIMIZERS, ADAPTIVE_FLAGS, + DECOUPLE_FLAGS, OPTIMIZERS, PULLBACK_MOMENTUM, ) @@ -106,7 +108,7 @@ def test_lookahead(pullback_momentum, environment): @pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS) -def test_sam_optimizers(adaptive, environment): +def 
test_sam_optimizer(adaptive, environment): (x_data, y_data), model, loss_fn = environment optimizer = SAM(model.parameters(), load_optimizer('asgd'), lr=5e-1, adaptive=adaptive) @@ -127,7 +129,7 @@ def test_sam_optimizers(adaptive, environment): @pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS) -def test_sam_optimizers_with_closure(adaptive, environment): +def test_sam_optimizer_with_closure(adaptive, environment): (x_data, y_data), model, loss_fn = environment optimizer = SAM(model.parameters(), load_optimizer('adamp'), lr=5e-1, adaptive=adaptive) @@ -152,7 +154,60 @@ def closure(): @pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS) -def test_gsam_optimizers(adaptive, environment): +@pytest.mark.parametrize('decouple', DECOUPLE_FLAGS) +def test_wsam_optimizer(adaptive, decouple, environment): + (x_data, y_data), model, loss_fn = environment + + optimizer = WSAM( + model, + model.parameters(), + load_optimizer('adamp'), + lr=5e-2, + adaptive=adaptive, + decouple=decouple, + max_norm=100.0, + ) + + init_loss, loss = np.inf, np.inf + for _ in range(10): + loss = loss_fn(y_data, model(x_data)) + loss.backward() + optimizer.first_step(zero_grad=True) + + loss_fn(y_data, model(x_data)).backward() + optimizer.second_step(zero_grad=True) + + if init_loss == np.inf: + init_loss = loss + + assert tensor_to_numpy(init_loss) > 1.5 * tensor_to_numpy(loss) + + +@pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS) +def test_wsam_optimizer_with_closure(adaptive, environment): + (x_data, y_data), model, loss_fn = environment + + optimizer = WSAM(model, model.parameters(), load_optimizer('adamp'), lr=5e-2, adaptive=adaptive, max_norm=100.0) + + def closure(): + output = model(x_data) + loss = loss_fn(output, y_data) + loss.backward() + return loss + + init_loss, loss = np.inf, np.inf + for _ in range(10): + loss = optimizer.step(closure) + optimizer.zero_grad() + + if init_loss == np.inf: + init_loss = loss + + assert tensor_to_numpy(init_loss) > 1.5 * tensor_to_numpy(loss) + + +@pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS) +def test_gsam_optimizer(adaptive, environment): pytest.skip('skip GSAM optimizer') (x_data, y_data), model, loss_fn = environment @@ -182,7 +237,7 @@ def test_gsam_optimizers(adaptive, environment): @pytest.mark.parametrize('optimizer_config', ADANORM_SUPPORTED_OPTIMIZERS, ids=ids) -def test_adanorm_optimizers(optimizer_config, environment): +def test_adanorm_optimizer(optimizer_config, environment): (x_data, y_data), model, loss_fn = environment optimizer_class, config, num_iterations = optimizer_config @@ -306,6 +361,12 @@ def test_no_closure(): with pytest.raises(NoClosureError): optimizer.step() + optimizer = WSAM(None, [param], load_optimizer('adamp')) + optimizer.zero_grad() + + with pytest.raises(NoClosureError): + optimizer.step() + def test_nero_zero_scale(): param = simple_parameter()
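
For reference, a minimal usage sketch of the `WSAM` optimizer introduced by this patch, mirroring the two-step protocol exercised in the added tests. `Net`, `data_loader`, the loss function, and the hyperparameter values are placeholders, not part of this change; the constructor arguments and `first_step`/`second_step` calls follow the signatures added in `pytorch_optimizer/optimizer/sam.py`.

```python
import torch
from torch import nn

from pytorch_optimizer import WSAM, AdamP

model = Net()  # placeholder nn.Module; wrap in DistributedDataParallel for multi-GPU runs
criterion = nn.CrossEntropyLoss()

# WSAM wraps a base optimizer class; extra kwargs (e.g. lr) are forwarded to it.
optimizer = WSAM(model, model.parameters(), AdamP, lr=1e-3, gamma=0.9, decouple=True)

for inputs, targets in data_loader:  # placeholder iterable of (inputs, targets) batches
    # first step: compute gradients at w and climb to the perturbed point w + e(w)
    criterion(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)

    # second step: recompute gradients at w + e(w) and take the sharpness-aware update
    criterion(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)
```

Alternatively, as in `test_wsam_optimizer_with_closure`, a closure that performs the forward/backward pass can be passed to `optimizer.step(closure)`, which runs both steps internally; calling `step()` without a closure raises `NoClosureError`.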