From 68d660e8c6c26a822267bab841f7e08a23de3ebb Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 21 Jul 2024 16:18:11 +0900
Subject: [PATCH 01/14] refactor: debias, debias_adam

---
 Makefile                                  |  4 ++--
 pytorch_optimizer/base/optimizer.py       | 21 +++++++++++++++++++++
 pytorch_optimizer/optimizer/adabelief.py  |  4 ++--
 pytorch_optimizer/optimizer/adabound.py   |  4 ++--
 pytorch_optimizer/optimizer/adahessian.py |  4 ++--
 pytorch_optimizer/optimizer/adai.py       |  4 ++--
 pytorch_optimizer/optimizer/adam_mini.py  |  4 ++--
 pytorch_optimizer/optimizer/adamax.py     |  2 +-
 pytorch_optimizer/optimizer/adamod.py     |  4 ++--
 pytorch_optimizer/optimizer/adamp.py      |  4 ++--
 pytorch_optimizer/optimizer/adams.py      |  6 +++---
 pytorch_optimizer/optimizer/adamw.py      |  5 ++---
 pytorch_optimizer/optimizer/adan.py       |  6 +++---
 pytorch_optimizer/optimizer/adanorm.py    |  4 ++--
 pytorch_optimizer/optimizer/adapnm.py     |  5 +++--
 pytorch_optimizer/optimizer/adashift.py   |  2 +-
 pytorch_optimizer/optimizer/aida.py       |  4 ++--
 pytorch_optimizer/optimizer/amos.py       |  2 +-
 pytorch_optimizer/optimizer/apollo.py     |  2 +-
 pytorch_optimizer/optimizer/avagrad.py    |  6 +++---
 pytorch_optimizer/optimizer/diffgrad.py   |  2 +-
 pytorch_optimizer/optimizer/fadam.py      |  2 +-
 pytorch_optimizer/optimizer/galore.py     |  4 ++--
 pytorch_optimizer/optimizer/grokfast.py   |  4 ++--
 pytorch_optimizer/optimizer/lamb.py       |  2 +-
 pytorch_optimizer/optimizer/novograd.py   |  4 ++--
 pytorch_optimizer/optimizer/padam.py      |  4 ++--
 pytorch_optimizer/optimizer/prodigy.py    |  4 ++--
 pytorch_optimizer/optimizer/radam.py      |  2 +-
 pytorch_optimizer/optimizer/ranger.py     |  3 ++-
 pytorch_optimizer/optimizer/ranger21.py   |  2 +-
 pytorch_optimizer/optimizer/swats.py      |  4 ++--
 pytorch_optimizer/optimizer/utils.py      | 11 -----------
 pytorch_optimizer/optimizer/yogi.py       |  4 ++--
 34 files changed, 80 insertions(+), 69 deletions(-)

diff --git a/Makefile b/Makefile
index 4dc0430d0..89b7fcdef 100644
--- a/Makefile
+++ b/Makefile
@@ -16,8 +16,8 @@ check:
 	ruff check pytorch_optimizer examples tests hubconf.py
 
 requirements:
-	python -m poetry export -f requirements.txt --output requirements.txt --without-hashes
-	python -m poetry export -f requirements.txt --output requirements-dev.txt --without-hashes --with dev
+	poetry export -f requirements.txt --output requirements.txt --without-hashes
+	poetry export -f requirements.txt --output requirements-dev.txt --without-hashes --with dev
 
 docs:
 	mkdocs serve
diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py
index 655fba7e4..a2ef6ff7d 100644
--- a/pytorch_optimizer/base/optimizer.py
+++ b/pytorch_optimizer/base/optimizer.py
@@ -145,6 +145,27 @@ def apply_ams_bound(
 
         return de_nom.sqrt_().add_(eps)
 
+    @staticmethod
+    def debias(beta: float, step: int) -> float:
+        r"""Adam-style debias correction. Returns `1.0 - beta ** step`.
+
+        :param beta: float. beta.
+        :param step: int. number of step.
+        """
+        return 1.0 - math.pow(beta, step)  # fmt: skip
+
+    @staticmethod
+    def debias_beta(beta: float, step: int) -> float:
+        r"""Apply the Adam-style debias correction into beta.
+
+        Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`
+
+        :param beta: float. beta.
+        :param step: int. number of step.
+        """
+        beta_n: float = math.pow(beta, step)
+        return (beta_n - beta) / (beta_n - 1.0)  # fmt: skip
+
     @staticmethod
     def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
         r"""Apply AdamD variant.
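Note on the two helpers above: a minimal, self-contained sketch (illustration only, not code from the repository) of what `debias` and `debias_beta` compute, plus a check that `debias_beta` is algebraically equivalent to the expanded expression it replaces later in this series (e.g. in `StableAdamW` and `FAdam`). The free functions below simply mirror the added static methods.

import math


def debias(beta: float, step: int) -> float:
    # classic Adam bias-correction term: 1 - beta^step
    return 1.0 - math.pow(beta, step)


def debias_beta(beta: float, step: int) -> float:
    # simplified form of beta * (1 - beta^(step - 1)) / (1 - beta^step)
    beta_n: float = math.pow(beta, step)
    return (beta_n - beta) / (beta_n - 1.0)


beta, step = 0.999, 10
expanded = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)

assert math.isclose(debias(beta, step), 1.0 - beta ** step)  # matches the old inline bias_correction
assert math.isclose(debias_beta(beta, step), expanded)       # same value, computed with a single pow()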
diff --git a/pytorch_optimizer/optimizer/adabelief.py b/pytorch_optimizer/optimizer/adabelief.py index 41c8a43ab..06d681f39 100644 --- a/pytorch_optimizer/optimizer/adabelief.py +++ b/pytorch_optimizer/optimizer/adabelief.py @@ -101,8 +101,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) step_size, n_sma = self.get_rectify_step_size( is_rectify=group['rectify'], diff --git a/pytorch_optimizer/optimizer/adabound.py b/pytorch_optimizer/optimizer/adabound.py index 629bf24e4..7ae5dbf48 100644 --- a/pytorch_optimizer/optimizer/adabound.py +++ b/pytorch_optimizer/optimizer/adabound.py @@ -90,8 +90,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) final_lr: float = group['final_lr'] * group['lr'] / base_lr lower_bound: float = final_lr * (1 - 1 / (group['gamma'] * group['step'] + 1)) diff --git a/pytorch_optimizer/optimizer/adahessian.py b/pytorch_optimizer/optimizer/adahessian.py index fe152cc6e..aabb4bb10 100644 --- a/pytorch_optimizer/optimizer/adahessian.py +++ b/pytorch_optimizer/optimizer/adahessian.py @@ -104,8 +104,8 @@ def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] = beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2: float = 1.0 - beta2 ** group['step'] + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2: float = self.debias(beta2, group['step']) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/adai.py b/pytorch_optimizer/optimizer/adai.py index 793ded4e2..0fcfc5270 100644 --- a/pytorch_optimizer/optimizer/adai.py +++ b/pytorch_optimizer/optimizer/adai.py @@ -105,7 +105,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: if self.use_gc: centralize_gradient(grad, gc_conv_only=False) - bias_correction2: float = 1.0 - beta2 ** state['step'] + bias_correction2: float = self.debias(beta2, state['step']) if not group['stable_weight_decay'] and group['weight_decay'] > 0.0: self.apply_weight_decay( @@ -148,7 +148,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: fixed_decay=group['fixed_decay'], ) - bias_correction2: float = 1.0 - beta2 ** state['step'] + bias_correction2: float = self.debias(beta2, state['step']) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] diff --git a/pytorch_optimizer/optimizer/adam_mini.py b/pytorch_optimizer/optimizer/adam_mini.py index 043ae8439..ac7f65a99 100644 --- a/pytorch_optimizer/optimizer/adam_mini.py +++ b/pytorch_optimizer/optimizer/adam_mini.py @@ -276,8 +276,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2: float = 1.0 - beta2 ** group['step'] + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2: float = self.debias(beta2, group['step']) bias_correction2_sq: float = math.sqrt(bias_correction2) for p in group['params']: diff --git 
a/pytorch_optimizer/optimizer/adamax.py b/pytorch_optimizer/optimizer/adamax.py index bc24d64dd..d8c25b154 100644 --- a/pytorch_optimizer/optimizer/adamax.py +++ b/pytorch_optimizer/optimizer/adamax.py @@ -84,7 +84,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] + bias_correction1: float = self.debias(beta1, group['step']) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/adamod.py b/pytorch_optimizer/optimizer/adamod.py index 156306680..41f76b11f 100644 --- a/pytorch_optimizer/optimizer/adamod.py +++ b/pytorch_optimizer/optimizer/adamod.py @@ -78,8 +78,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2, beta3 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/adamp.py b/pytorch_optimizer/optimizer/adamp.py index 1b2d9ecc5..9bfbfa18b 100644 --- a/pytorch_optimizer/optimizer/adamp.py +++ b/pytorch_optimizer/optimizer/adamp.py @@ -103,8 +103,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/adams.py b/pytorch_optimizer/optimizer/adams.py index eb4fa8948..3ee189a7c 100644 --- a/pytorch_optimizer/optimizer/adams.py +++ b/pytorch_optimizer/optimizer/adams.py @@ -111,7 +111,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: state['step'] += 1 - bias_correction2: float = 1.0 - beta2 ** state['step'] + bias_correction2: float = self.debias(beta2, state['step']) s_grad = self.get_adanorm_gradient( grad=grad, @@ -156,8 +156,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: ratio=1.0 / exp_avg_sq_hat_mean, ) - bias_correction1: float = 1.0 - beta1 ** state['step'] - bias_correction2: float = 1.0 - beta2 ** state['step'] + bias_correction1: float = self.debias(beta1, state['step']) + bias_correction2: float = self.debias(beta2, state['step']) exp_avg_sq_hat = state['max_exp_avg_sq'] if group['ams_bound'] else state['exp_avg_sq'] exp_avg_sq_hat.div_(bias_correction2) diff --git a/pytorch_optimizer/optimizer/adamw.py b/pytorch_optimizer/optimizer/adamw.py index 02e4af5ca..158d362a8 100644 --- a/pytorch_optimizer/optimizer/adamw.py +++ b/pytorch_optimizer/optimizer/adamw.py @@ -6,7 +6,6 @@ from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -from pytorch_optimizer.optimizer.utils import debias_beta class StableAdamW(Optimizer, BaseOptimizer): @@ -80,8 +79,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - beta1_comp: float = 1.0 - debias_beta(beta1, group['step']) - beta2_hat: float = debias_beta(beta2, group['step']) + beta1_comp: float = 1.0 - self.debias_beta(beta1, group['step']) + beta2_hat: float = self.debias_beta(beta2, group['step']) eps_p2: float = 
math.pow(group['eps'], 2) diff --git a/pytorch_optimizer/optimizer/adan.py b/pytorch_optimizer/optimizer/adan.py index ff58262f4..b71ea296f 100644 --- a/pytorch_optimizer/optimizer/adan.py +++ b/pytorch_optimizer/optimizer/adan.py @@ -106,9 +106,9 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2, beta3 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2: float = 1.0 - beta2 ** group['step'] - bias_correction3_sq: float = math.sqrt(1.0 - beta3 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2: float = self.debias(beta2, group['step']) + bias_correction3_sq: float = math.sqrt(self.debias(beta3, group['step'])) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/adanorm.py b/pytorch_optimizer/optimizer/adanorm.py index 2200f4551..84bb58382 100644 --- a/pytorch_optimizer/optimizer/adanorm.py +++ b/pytorch_optimizer/optimizer/adanorm.py @@ -85,8 +85,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/adapnm.py b/pytorch_optimizer/optimizer/adapnm.py index c08de1314..9fec329fc 100644 --- a/pytorch_optimizer/optimizer/adapnm.py +++ b/pytorch_optimizer/optimizer/adapnm.py @@ -93,8 +93,9 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2, beta3 = group['betas'] noise_norm: float = math.sqrt((1 + beta3) ** 2 + beta3 ** 2) # fmt: skip - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/adashift.py b/pytorch_optimizer/optimizer/adashift.py index 1d673e1e2..a577b5797 100644 --- a/pytorch_optimizer/optimizer/adashift.py +++ b/pytorch_optimizer/optimizer/adashift.py @@ -73,7 +73,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: first_grad_weight: float = beta1 ** (group['keep_num'] - 1) / exp_weight_sum last_grad_weight: float = 1.0 / exp_weight_sum - bias_correction: float = 1.0 - beta2 ** (group['step'] - group['keep_num']) + bias_correction: float = self.debias(beta2, group['step'] - group['keep_num']) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/aida.py b/pytorch_optimizer/optimizer/aida.py index 4d5417bc2..ffe2532c4 100644 --- a/pytorch_optimizer/optimizer/aida.py +++ b/pytorch_optimizer/optimizer/aida.py @@ -109,8 +109,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) step_size, n_sma = self.get_rectify_step_size( is_rectify=group['rectify'], diff --git a/pytorch_optimizer/optimizer/amos.py b/pytorch_optimizer/optimizer/amos.py index a772da0db..854e9af71 100644 --- a/pytorch_optimizer/optimizer/amos.py +++ 
b/pytorch_optimizer/optimizer/amos.py @@ -92,7 +92,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: momentum, beta = group['momentum'], group['beta'] lr_sq: float = math.sqrt(group['lr']) - bias_correction: float = 1.0 - beta ** group['step'] + bias_correction: float = self.debias(beta, group['step']) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/apollo.py b/pytorch_optimizer/optimizer/apollo.py index ac20ea22f..2dcef87d7 100644 --- a/pytorch_optimizer/optimizer/apollo.py +++ b/pytorch_optimizer/optimizer/apollo.py @@ -92,7 +92,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: weight_decay, eps = group['weight_decay'], group['eps'] - bias_correction: float = 1.0 - group['beta'] ** group['step'] + bias_correction: float = self.debias(group['beta'], group['step']) alpha: float = (1.0 - group['beta']) / bias_correction for p in group['params']: diff --git a/pytorch_optimizer/optimizer/avagrad.py b/pytorch_optimizer/optimizer/avagrad.py index c8c2080b7..483492472 100644 --- a/pytorch_optimizer/optimizer/avagrad.py +++ b/pytorch_optimizer/optimizer/avagrad.py @@ -78,9 +78,9 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) - prev_bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] - 1)) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) + prev_bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'] - 1)) squared_norm: float = 0.0 num_params: float = 0.0 diff --git a/pytorch_optimizer/optimizer/diffgrad.py b/pytorch_optimizer/optimizer/diffgrad.py index 279dda54b..dbcac9d67 100644 --- a/pytorch_optimizer/optimizer/diffgrad.py +++ b/pytorch_optimizer/optimizer/diffgrad.py @@ -100,7 +100,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] + bias_correction1: float = self.debias(beta1, group['step']) step_size, n_sma = self.get_rectify_step_size( is_rectify=group['rectify'], diff --git a/pytorch_optimizer/optimizer/fadam.py b/pytorch_optimizer/optimizer/fadam.py index 26ee108f5..4d8d5056e 100644 --- a/pytorch_optimizer/optimizer/fadam.py +++ b/pytorch_optimizer/optimizer/fadam.py @@ -81,7 +81,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - curr_beta2: float = beta2 * (1 - beta2 ** (group['step'] - 1)) / (1 - beta2 ** group['step']) + curr_beta2: float = self.debias_beta(beta2, group['step']) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/galore.py b/pytorch_optimizer/optimizer/galore.py index ced850c2d..222c332ed 100644 --- a/pytorch_optimizer/optimizer/galore.py +++ b/pytorch_optimizer/optimizer/galore.py @@ -194,8 +194,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) step_size: float = group['lr'] * bias_correction2_sq / bias_correction1 diff --git a/pytorch_optimizer/optimizer/grokfast.py b/pytorch_optimizer/optimizer/grokfast.py index 22b57a9db..47a3dfede 100644 --- a/pytorch_optimizer/optimizer/grokfast.py +++ 
b/pytorch_optimizer/optimizer/grokfast.py @@ -181,8 +181,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) should_grokfast: bool = ( group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0 diff --git a/pytorch_optimizer/optimizer/lamb.py b/pytorch_optimizer/optimizer/lamb.py index 8c5f47b70..f5fbabca9 100644 --- a/pytorch_optimizer/optimizer/lamb.py +++ b/pytorch_optimizer/optimizer/lamb.py @@ -129,7 +129,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] beta3: float = 1.0 - beta1 if group['grad_averaging'] else 1.0 - bias_correction1: float = 1.0 - beta1 ** group['step'] + bias_correction1: float = self.debias(beta1, group['step']) step_size, n_sma = self.get_rectify_step_size( is_rectify=group['rectify'], diff --git a/pytorch_optimizer/optimizer/novograd.py b/pytorch_optimizer/optimizer/novograd.py index dddf0463e..957f94180 100644 --- a/pytorch_optimizer/optimizer/novograd.py +++ b/pytorch_optimizer/optimizer/novograd.py @@ -83,8 +83,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) step_size: float = group['lr'] * bias_correction2_sq if not group['adam_debias']: diff --git a/pytorch_optimizer/optimizer/padam.py b/pytorch_optimizer/optimizer/padam.py index d93ccb52d..bcfe0a31e 100644 --- a/pytorch_optimizer/optimizer/padam.py +++ b/pytorch_optimizer/optimizer/padam.py @@ -77,8 +77,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/prodigy.py b/pytorch_optimizer/optimizer/prodigy.py index 6853e11f4..dd456d129 100644 --- a/pytorch_optimizer/optimizer/prodigy.py +++ b/pytorch_optimizer/optimizer/prodigy.py @@ -102,8 +102,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] beta3 = group['beta3'] if group['beta3'] is not None else math.sqrt(beta2) - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) bias_correction: float = (bias_correction1 / bias_correction2_sq) if group['bias_correction'] else 1.0 d, d0 = group['d'], group['d0'] diff --git a/pytorch_optimizer/optimizer/radam.py b/pytorch_optimizer/optimizer/radam.py index f6741c087..22d1b5180 100644 --- a/pytorch_optimizer/optimizer/radam.py +++ b/pytorch_optimizer/optimizer/radam.py @@ -91,7 +91,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] + bias_correction1: 
float = self.debias(beta1, group['step']) step_size, n_sma = self.get_rectify_step_size( is_rectify=True, diff --git a/pytorch_optimizer/optimizer/ranger.py b/pytorch_optimizer/optimizer/ranger.py index 6cd13d4ee..45a4277d6 100644 --- a/pytorch_optimizer/optimizer/ranger.py +++ b/pytorch_optimizer/optimizer/ranger.py @@ -106,7 +106,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: group['step'] = 1 beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] + + bias_correction1: float = self.debias(beta1, group['step']) step_size, n_sma = self.get_rectify_step_size( is_rectify=True, diff --git a/pytorch_optimizer/optimizer/ranger21.py b/pytorch_optimizer/optimizer/ranger21.py index 9b112e58b..1992d1da5 100644 --- a/pytorch_optimizer/optimizer/ranger21.py +++ b/pytorch_optimizer/optimizer/ranger21.py @@ -198,7 +198,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction2: float = 1.0 - beta2 ** group['step'] + bias_correction2: float = self.debias(beta2, group['step']) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/swats.py b/pytorch_optimizer/optimizer/swats.py index 172454e72..0761479db 100644 --- a/pytorch_optimizer/optimizer/swats.py +++ b/pytorch_optimizer/optimizer/swats.py @@ -96,8 +96,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2: float = 1.0 - beta2 ** group['step'] + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2: float = self.debias(beta2, group['step']) for p in group['params']: if p.grad is None: diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index df96bb908..e1b75dbea 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -36,17 +36,6 @@ def is_deepspeed_zero3_enabled() -> bool: return False -def debias_beta(beta: float, step: int) -> float: - r"""Apply the Adam-style debias correction into beta. - - Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)` - - :param beta: float. beta. - :param step: int. number of step. 
- """ - return (beta ** step - beta) / (beta ** step - 1.0) # fmt: skip - - def is_valid_parameters(parameters: PARAMETERS) -> bool: r"""Check where the parameters are valid.""" return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict) diff --git a/pytorch_optimizer/optimizer/yogi.py b/pytorch_optimizer/optimizer/yogi.py index 9188c89f3..24a0f48b6 100644 --- a/pytorch_optimizer/optimizer/yogi.py +++ b/pytorch_optimizer/optimizer/yogi.py @@ -89,8 +89,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) for p in group['params']: if p.grad is None: From f5c109f5491e796ecb0aeb46d952511f876fb5a1 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 16:19:05 +0900 Subject: [PATCH 02/14] docs: v3.1.0 changelog --- docs/changelogs/v3.1.0.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelogs/v3.1.0.md b/docs/changelogs/v3.1.0.md index 19e423d47..78ada9684 100644 --- a/docs/changelogs/v3.1.0.md +++ b/docs/changelogs/v3.1.0.md @@ -18,6 +18,7 @@ * Deprecate optional dependency, `bitsandbytes`. (#258) * Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258) * Refactor `shampoo_utils.py`. (#259) +* Add `debias`, `debias_adam` methods in `BaseOptimizer`. (#261) ### Bug From 21c16908d2994edbf2a4a56608911ead20db0f41 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 18:48:12 +0900 Subject: [PATCH 03/14] refactor: get_adanorm_gradient --- pytorch_optimizer/base/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index a2ef6ff7d..4a97f0168 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -226,14 +226,14 @@ def get_adanorm_gradient( :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm. :param r: float. Optional[float]. momentum (ratio). 
""" - if not adanorm: + if not adanorm or exp_grad_norm is None: return grad grad_norm = torch.linalg.norm(grad) exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r) - return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad + return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad @staticmethod def get_rms(x: torch.Tensor) -> float: From e8cd0bf89f3e9158417fb61664da631a1dbbd76d Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 19:05:50 +0900 Subject: [PATCH 04/14] refactor: use BaseOptimizer only --- pytorch_optimizer/base/optimizer.py | 25 +++++++++++++-------- pytorch_optimizer/optimizer/a2grad.py | 3 +-- pytorch_optimizer/optimizer/adabelief.py | 3 +-- pytorch_optimizer/optimizer/adabound.py | 3 +-- pytorch_optimizer/optimizer/adadelta.py | 3 +-- pytorch_optimizer/optimizer/adafactor.py | 3 +-- pytorch_optimizer/optimizer/adahessian.py | 3 +-- pytorch_optimizer/optimizer/adai.py | 3 +-- pytorch_optimizer/optimizer/adalite.py | 3 +-- pytorch_optimizer/optimizer/adam_mini.py | 3 +-- pytorch_optimizer/optimizer/adamax.py | 3 +-- pytorch_optimizer/optimizer/adamod.py | 3 +-- pytorch_optimizer/optimizer/adamp.py | 3 +-- pytorch_optimizer/optimizer/adams.py | 3 +-- pytorch_optimizer/optimizer/adamw.py | 3 +-- pytorch_optimizer/optimizer/adan.py | 3 +-- pytorch_optimizer/optimizer/adanorm.py | 3 +-- pytorch_optimizer/optimizer/adapnm.py | 3 +-- pytorch_optimizer/optimizer/adashift.py | 3 +-- pytorch_optimizer/optimizer/adasmooth.py | 3 +-- pytorch_optimizer/optimizer/aggmo.py | 3 +-- pytorch_optimizer/optimizer/aida.py | 3 +-- pytorch_optimizer/optimizer/alig.py | 3 +-- pytorch_optimizer/optimizer/amos.py | 3 +-- pytorch_optimizer/optimizer/apollo.py | 3 +-- pytorch_optimizer/optimizer/avagrad.py | 3 +-- pytorch_optimizer/optimizer/came.py | 3 +-- pytorch_optimizer/optimizer/dadapt.py | 11 +++++---- pytorch_optimizer/optimizer/diffgrad.py | 3 +-- pytorch_optimizer/optimizer/fadam.py | 3 +-- pytorch_optimizer/optimizer/fromage.py | 3 +-- pytorch_optimizer/optimizer/galore.py | 3 +-- pytorch_optimizer/optimizer/gravity.py | 3 +-- pytorch_optimizer/optimizer/grokfast.py | 3 +-- pytorch_optimizer/optimizer/kate.py | 3 +-- pytorch_optimizer/optimizer/lamb.py | 3 +-- pytorch_optimizer/optimizer/lars.py | 3 +-- pytorch_optimizer/optimizer/lion.py | 3 +-- pytorch_optimizer/optimizer/lomo.py | 5 ++--- pytorch_optimizer/optimizer/lookahead.py | 3 +-- pytorch_optimizer/optimizer/madgrad.py | 3 +-- pytorch_optimizer/optimizer/msvag.py | 3 +-- pytorch_optimizer/optimizer/nero.py | 3 +-- pytorch_optimizer/optimizer/novograd.py | 3 +-- pytorch_optimizer/optimizer/padam.py | 3 +-- pytorch_optimizer/optimizer/pid.py | 3 +-- pytorch_optimizer/optimizer/pnm.py | 3 +-- pytorch_optimizer/optimizer/prodigy.py | 3 +-- pytorch_optimizer/optimizer/qhadam.py | 3 +-- pytorch_optimizer/optimizer/qhm.py | 3 +-- pytorch_optimizer/optimizer/radam.py | 3 +-- pytorch_optimizer/optimizer/ranger.py | 3 +-- pytorch_optimizer/optimizer/ranger21.py | 7 +++--- pytorch_optimizer/optimizer/sam.py | 9 ++++---- pytorch_optimizer/optimizer/schedulefree.py | 5 ++--- pytorch_optimizer/optimizer/sgd.py | 9 ++++---- pytorch_optimizer/optimizer/sgdp.py | 3 +-- pytorch_optimizer/optimizer/shampoo.py | 5 ++--- pytorch_optimizer/optimizer/sm3.py | 3 +-- pytorch_optimizer/optimizer/sophia.py | 3 +-- pytorch_optimizer/optimizer/srmm.py | 3 +-- pytorch_optimizer/optimizer/swats.py | 3 +-- pytorch_optimizer/optimizer/tiger.py | 3 +-- pytorch_optimizer/optimizer/yogi.py | 3 +-- 64 
files changed, 94 insertions(+), 150 deletions(-) diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index 4a97f0168..33b2c6636 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -1,19 +1,23 @@ import math from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch +from torch.optim import Optimizer from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError -from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G, PARAMETERS, STATE +from pytorch_optimizer.base.types import BETAS, DEFAULTS, HUTCHINSON_G, PARAMETERS, STATE -class BaseOptimizer(ABC): - r"""Base optimizer class.""" +class BaseOptimizer(ABC, Optimizer): + r"""Base optimizer class. Provides common functionalities for the optimizers.""" + + def __init__(self, params: PARAMETERS, defaults: DEFAULTS) -> None: + super().__init__(params, defaults) @staticmethod @torch.no_grad() - def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]): + def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None: r"""Set hessian to state from external source. Generally useful when using functorch as a base. Example: @@ -45,7 +49,7 @@ def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tens i += 1 @staticmethod - def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True): + def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True) -> None: r"""Zero-out hessian. :param param_groups: PARAMETERS. parameter groups. @@ -68,7 +72,7 @@ def compute_hutchinson_hessian( num_samples: int = 1, alpha: float = 1.0, distribution: HUTCHINSON_G = 'gaussian', - ): + ) -> None: r"""Hutchinson's approximate hessian, added to the state under key `hessian`. :param param_groups: PARAMETERS. parameter groups. @@ -110,7 +114,7 @@ def apply_weight_decay( weight_decouple: bool, fixed_decay: bool, ratio: Optional[float] = None, - ): + ) -> None: r"""Apply weight decay. :param p: torch.Tensor. parameter. @@ -320,5 +324,8 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None: self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]') @abstractmethod - def reset(self): # pragma: no cover + def reset(self) -> None: # pragma: no cover + raise NotImplementedError + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: raise NotImplementedError diff --git a/pytorch_optimizer/optimizer/a2grad.py b/pytorch_optimizer/optimizer/a2grad.py index 1d2b5a758..93d91adc9 100644 --- a/pytorch_optimizer/optimizer/a2grad.py +++ b/pytorch_optimizer/optimizer/a2grad.py @@ -2,14 +2,13 @@ from typing import Optional import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class A2Grad(Optimizer, BaseOptimizer): +class A2Grad(BaseOptimizer): r"""Optimal Adaptive and Accelerated Stochastic Gradient Descent. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
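For PATCH 04: since `BaseOptimizer` now derives from `torch.optim.Optimizer` and forwards `params`/`defaults` in its `__init__`, the per-optimizer classes below drop the explicit `Optimizer` base and its import. A rough sketch of the resulting subclassing pattern, using a hypothetical `ToySGD` that is not part of the patch; `validate_learning_rate` is assumed to be one of the existing `validate_*` helpers on `BaseOptimizer`.

import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class ToySGD(BaseOptimizer):  # previously: class ToySGD(Optimizer, BaseOptimizer)
    def __init__(self, params: PARAMETERS, lr: float = 1e-3):
        self.validate_learning_rate(lr)  # assumed validate_* helper on BaseOptimizer

        defaults: DEFAULTS = {'lr': lr}
        super().__init__(params, defaults)  # forwarded to torch.optim.Optimizer by the new base

    def reset(self) -> None:  # BaseOptimizer.reset stays abstract, so subclasses still provide it
        pass

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # plain SGD update, just to show the class skeleton
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])

        return loss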
diff --git a/pytorch_optimizer/optimizer/adabelief.py b/pytorch_optimizer/optimizer/adabelief.py index 06d681f39..23a35a59e 100644 --- a/pytorch_optimizer/optimizer/adabelief.py +++ b/pytorch_optimizer/optimizer/adabelief.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaBelief(Optimizer, BaseOptimizer): +class AdaBelief(BaseOptimizer): r"""Adapting Step-sizes by the Belief in Observed Gradients. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adabound.py b/pytorch_optimizer/optimizer/adabound.py index 7ae5dbf48..114ef1b14 100644 --- a/pytorch_optimizer/optimizer/adabound.py +++ b/pytorch_optimizer/optimizer/adabound.py @@ -2,14 +2,13 @@ from typing import List import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaBound(Optimizer, BaseOptimizer): +class AdaBound(BaseOptimizer): r"""Adaptive Gradient Methods with Dynamic Bound of Learning Rate. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adadelta.py b/pytorch_optimizer/optimizer/adadelta.py index 3a12b2746..75693a773 100644 --- a/pytorch_optimizer/optimizer/adadelta.py +++ b/pytorch_optimizer/optimizer/adadelta.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaDelta(Optimizer, BaseOptimizer): +class AdaDelta(BaseOptimizer): r"""An Adaptive Learning Rate Method. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adafactor.py b/pytorch_optimizer/optimizer/adafactor.py index 9299cde12..bd9ca02d6 100644 --- a/pytorch_optimizer/optimizer/adafactor.py +++ b/pytorch_optimizer/optimizer/adafactor.py @@ -2,14 +2,13 @@ from typing import Optional, Tuple import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaFactor(Optimizer, BaseOptimizer): +class AdaFactor(BaseOptimizer): r"""Adaptive Learning Rates with Sublinear Memory Cost with some tweaks. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
diff --git a/pytorch_optimizer/optimizer/adahessian.py b/pytorch_optimizer/optimizer/adahessian.py index aabb4bb10..284c10e43 100644 --- a/pytorch_optimizer/optimizer/adahessian.py +++ b/pytorch_optimizer/optimizer/adahessian.py @@ -1,14 +1,13 @@ from typing import List, Optional import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS -class AdaHessian(Optimizer, BaseOptimizer): +class AdaHessian(BaseOptimizer): r"""An Adaptive Second Order Optimizer for Machine Learning. Requires `loss.backward(create_graph=True)` in order to calculate hessians. diff --git a/pytorch_optimizer/optimizer/adai.py b/pytorch_optimizer/optimizer/adai.py index 0fcfc5270..607c1b4ec 100644 --- a/pytorch_optimizer/optimizer/adai.py +++ b/pytorch_optimizer/optimizer/adai.py @@ -1,7 +1,6 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -9,7 +8,7 @@ from pytorch_optimizer.optimizer.gc import centralize_gradient -class Adai(Optimizer, BaseOptimizer): +class Adai(BaseOptimizer): r"""Disentangling the Effects of Adaptive Learning Rate and Momentum. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adalite.py b/pytorch_optimizer/optimizer/adalite.py index 71a5f7cb4..4047b4c66 100644 --- a/pytorch_optimizer/optimizer/adalite.py +++ b/pytorch_optimizer/optimizer/adalite.py @@ -1,13 +1,12 @@ import torch from torch.nn.functional import softmax -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Adalite(Optimizer, BaseOptimizer): +class Adalite(BaseOptimizer): r"""Adalite optimizer. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adam_mini.py b/pytorch_optimizer/optimizer/adam_mini.py index ac7f65a99..3f6a308da 100644 --- a/pytorch_optimizer/optimizer/adam_mini.py +++ b/pytorch_optimizer/optimizer/adam_mini.py @@ -4,14 +4,13 @@ import torch from torch import distributed as dist from torch import nn -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS -class AdamMini(Optimizer, BaseOptimizer): # pragma: no cover +class AdamMini(BaseOptimizer): # pragma: no cover r"""Use Fewer Learning Rates To Gain More. :param model: nn.Module. model instance. 
diff --git a/pytorch_optimizer/optimizer/adamax.py b/pytorch_optimizer/optimizer/adamax.py index d8c25b154..bafae8109 100644 --- a/pytorch_optimizer/optimizer/adamax.py +++ b/pytorch_optimizer/optimizer/adamax.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaMax(Optimizer, BaseOptimizer): +class AdaMax(BaseOptimizer): r"""An Adaptive and Momental Bound Method for Stochastic Learning. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adamod.py b/pytorch_optimizer/optimizer/adamod.py index 41f76b11f..6f30433dc 100644 --- a/pytorch_optimizer/optimizer/adamod.py +++ b/pytorch_optimizer/optimizer/adamod.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaMod(Optimizer, BaseOptimizer): +class AdaMod(BaseOptimizer): r"""An Adaptive and Momental Bound Method for Stochastic Learning. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adamp.py b/pytorch_optimizer/optimizer/adamp.py index 9bfbfa18b..0adb79f16 100644 --- a/pytorch_optimizer/optimizer/adamp.py +++ b/pytorch_optimizer/optimizer/adamp.py @@ -1,7 +1,6 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -10,7 +9,7 @@ from pytorch_optimizer.optimizer.utils import projection -class AdamP(Optimizer, BaseOptimizer): +class AdamP(BaseOptimizer): r"""Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adams.py b/pytorch_optimizer/optimizer/adams.py index 3ee189a7c..1b1717169 100644 --- a/pytorch_optimizer/optimizer/adams.py +++ b/pytorch_optimizer/optimizer/adams.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdamS(Optimizer, BaseOptimizer): +class AdamS(BaseOptimizer): r"""Adam with stable weight decay. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
diff --git a/pytorch_optimizer/optimizer/adamw.py b/pytorch_optimizer/optimizer/adamw.py index 158d362a8..9de9768f1 100644 --- a/pytorch_optimizer/optimizer/adamw.py +++ b/pytorch_optimizer/optimizer/adamw.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class StableAdamW(Optimizer, BaseOptimizer): +class StableAdamW(BaseOptimizer): r"""Stable and low-precision training for large-scale vision-language models. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adan.py b/pytorch_optimizer/optimizer/adan.py index b71ea296f..82a7d3a18 100644 --- a/pytorch_optimizer/optimizer/adan.py +++ b/pytorch_optimizer/optimizer/adan.py @@ -2,7 +2,6 @@ from typing import Union import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -11,7 +10,7 @@ from pytorch_optimizer.optimizer.utils import get_global_gradient_norm -class Adan(Optimizer, BaseOptimizer): +class Adan(BaseOptimizer): r"""Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adanorm.py b/pytorch_optimizer/optimizer/adanorm.py index 84bb58382..9da5fc1a7 100644 --- a/pytorch_optimizer/optimizer/adanorm.py +++ b/pytorch_optimizer/optimizer/adanorm.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaNorm(Optimizer, BaseOptimizer): +class AdaNorm(BaseOptimizer): r"""Symbolic Discovery of Optimization Algorithms. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adapnm.py b/pytorch_optimizer/optimizer/adapnm.py index 9fec329fc..22389a6c5 100644 --- a/pytorch_optimizer/optimizer/adapnm.py +++ b/pytorch_optimizer/optimizer/adapnm.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaPNM(Optimizer, BaseOptimizer): +class AdaPNM(BaseOptimizer): r"""Adam + Positive-Negative Momentum Optimizers. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
diff --git a/pytorch_optimizer/optimizer/adashift.py b/pytorch_optimizer/optimizer/adashift.py index a577b5797..bdb44d841 100644 --- a/pytorch_optimizer/optimizer/adashift.py +++ b/pytorch_optimizer/optimizer/adashift.py @@ -2,14 +2,13 @@ from typing import Callable, Optional import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaShift(Optimizer, BaseOptimizer): +class AdaShift(BaseOptimizer): r"""Decorrelation and Convergence of Adaptive Learning Rate Methods. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/adasmooth.py b/pytorch_optimizer/optimizer/adasmooth.py index 453124bca..8d65b1a3c 100644 --- a/pytorch_optimizer/optimizer/adasmooth.py +++ b/pytorch_optimizer/optimizer/adasmooth.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AdaSmooth(Optimizer, BaseOptimizer): +class AdaSmooth(BaseOptimizer): r"""An Adaptive Learning Rate Method based on Effective Ratio. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/aggmo.py b/pytorch_optimizer/optimizer/aggmo.py index 232565aa1..1739e4f3c 100644 --- a/pytorch_optimizer/optimizer/aggmo.py +++ b/pytorch_optimizer/optimizer/aggmo.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AggMo(Optimizer, BaseOptimizer): +class AggMo(BaseOptimizer): r"""Aggregated Momentum: Stability Through Passive Damping. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/aida.py b/pytorch_optimizer/optimizer/aida.py index ffe2532c4..e04c89ad9 100644 --- a/pytorch_optimizer/optimizer/aida.py +++ b/pytorch_optimizer/optimizer/aida.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Aida(Optimizer, BaseOptimizer): +class Aida(BaseOptimizer): r"""A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
diff --git a/pytorch_optimizer/optimizer/alig.py b/pytorch_optimizer/optimizer/alig.py index 1fb7e2ea2..15666263d 100644 --- a/pytorch_optimizer/optimizer/alig.py +++ b/pytorch_optimizer/optimizer/alig.py @@ -1,7 +1,6 @@ from typing import Callable, Optional import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoClosureError, NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -9,7 +8,7 @@ from pytorch_optimizer.optimizer.utils import get_global_gradient_norm -class AliG(Optimizer, BaseOptimizer): +class AliG(BaseOptimizer): r"""Adaptive Learning Rates for Interpolation with Gradients. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/amos.py b/pytorch_optimizer/optimizer/amos.py index 854e9af71..1b76f1721 100644 --- a/pytorch_optimizer/optimizer/amos.py +++ b/pytorch_optimizer/optimizer/amos.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Amos(Optimizer, BaseOptimizer): +class Amos(BaseOptimizer): r"""An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/apollo.py b/pytorch_optimizer/optimizer/apollo.py index 2dcef87d7..785eef28f 100644 --- a/pytorch_optimizer/optimizer/apollo.py +++ b/pytorch_optimizer/optimizer/apollo.py @@ -2,14 +2,13 @@ import numpy as np import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Apollo(Optimizer, BaseOptimizer): +class Apollo(BaseOptimizer): r"""An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/avagrad.py b/pytorch_optimizer/optimizer/avagrad.py index 483492472..0b9f52dbf 100644 --- a/pytorch_optimizer/optimizer/avagrad.py +++ b/pytorch_optimizer/optimizer/avagrad.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AvaGrad(Optimizer, BaseOptimizer): +class AvaGrad(BaseOptimizer): r"""Domain-independent Dominance of Adaptive Methods. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
diff --git a/pytorch_optimizer/optimizer/came.py b/pytorch_optimizer/optimizer/came.py index 7199bbb87..411fab23e 100644 --- a/pytorch_optimizer/optimizer/came.py +++ b/pytorch_optimizer/optimizer/came.py @@ -2,14 +2,13 @@ from typing import Tuple import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class CAME(Optimizer, BaseOptimizer): +class CAME(BaseOptimizer): r"""Confidence-guided Adaptive Memory Efficient Optimization. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/dadapt.py b/pytorch_optimizer/optimizer/dadapt.py index fa4d298e4..acb83e21b 100644 --- a/pytorch_optimizer/optimizer/dadapt.py +++ b/pytorch_optimizer/optimizer/dadapt.py @@ -7,7 +7,6 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -15,7 +14,7 @@ from pytorch_optimizer.optimizer.utils import get_global_gradient_norm, to_real -class DAdaptAdaGrad(Optimizer, BaseOptimizer): +class DAdaptAdaGrad(BaseOptimizer): r"""AdaGrad with D-Adaptation. Leave LR set to 1 unless you encounter instability. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -240,7 +239,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class DAdaptAdam(Optimizer, BaseOptimizer): +class DAdaptAdam(BaseOptimizer): r"""Adam with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -401,7 +400,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class DAdaptSGD(Optimizer, BaseOptimizer): +class DAdaptSGD(BaseOptimizer): r"""SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -537,7 +536,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class DAdaptAdan(Optimizer, BaseOptimizer): +class DAdaptAdan(BaseOptimizer): r"""Adan with D-Adaptation. Leave LR set to 1 unless you encounter instability. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -701,7 +700,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class DAdaptLion(Optimizer, BaseOptimizer): +class DAdaptLion(BaseOptimizer): r"""Lion with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
diff --git a/pytorch_optimizer/optimizer/diffgrad.py b/pytorch_optimizer/optimizer/diffgrad.py index dbcac9d67..fd94295ef 100644 --- a/pytorch_optimizer/optimizer/diffgrad.py +++ b/pytorch_optimizer/optimizer/diffgrad.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class DiffGrad(Optimizer, BaseOptimizer): +class DiffGrad(BaseOptimizer): r"""An Optimization Method for Convolutional Neural Networks. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/fadam.py b/pytorch_optimizer/optimizer/fadam.py index 4d8d5056e..91a89dc3a 100644 --- a/pytorch_optimizer/optimizer/fadam.py +++ b/pytorch_optimizer/optimizer/fadam.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class FAdam(Optimizer, BaseOptimizer): +class FAdam(BaseOptimizer): r"""Adam is a natural gradient optimizer using diagonal empirical Fisher information. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/fromage.py b/pytorch_optimizer/optimizer/fromage.py index 73ab0260d..a45677697 100644 --- a/pytorch_optimizer/optimizer/fromage.py +++ b/pytorch_optimizer/optimizer/fromage.py @@ -7,14 +7,13 @@ from typing import Optional import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Fromage(Optimizer, BaseOptimizer): +class Fromage(BaseOptimizer): r"""On the distance between two neural networks and the stability of learning. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/galore.py b/pytorch_optimizer/optimizer/galore.py index 222c332ed..c7ce06dc8 100644 --- a/pytorch_optimizer/optimizer/galore.py +++ b/pytorch_optimizer/optimizer/galore.py @@ -2,7 +2,6 @@ from typing import Literal, Optional, Tuple, Union import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -133,7 +132,7 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: raise NotImplementedError -class GaLore(Optimizer, BaseOptimizer): +class GaLore(BaseOptimizer): r"""AdamW optimizer with GaLore projector. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
diff --git a/pytorch_optimizer/optimizer/gravity.py b/pytorch_optimizer/optimizer/gravity.py index a0fa65bb7..c7e7ed546 100644 --- a/pytorch_optimizer/optimizer/gravity.py +++ b/pytorch_optimizer/optimizer/gravity.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Gravity(Optimizer, BaseOptimizer): +class Gravity(BaseOptimizer): r"""a Kinematic Approach on Optimization in Deep Learning. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/grokfast.py b/pytorch_optimizer/optimizer/grokfast.py index 47a3dfede..ad8dc0c43 100644 --- a/pytorch_optimizer/optimizer/grokfast.py +++ b/pytorch_optimizer/optimizer/grokfast.py @@ -4,7 +4,6 @@ import torch from torch import nn -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -99,7 +98,7 @@ def gradfilter_ema( return grads -class GrokFastAdamW(Optimizer, BaseOptimizer): +class GrokFastAdamW(BaseOptimizer): r"""Accelerated Grokking by Amplifying Slow Gradients with AdamW. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/kate.py b/pytorch_optimizer/optimizer/kate.py index 5ea26c0f9..f006d1b4c 100644 --- a/pytorch_optimizer/optimizer/kate.py +++ b/pytorch_optimizer/optimizer/kate.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Kate(Optimizer, BaseOptimizer): +class Kate(BaseOptimizer): r"""Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/lamb.py b/pytorch_optimizer/optimizer/lamb.py index f5fbabca9..5f9795a2b 100644 --- a/pytorch_optimizer/optimizer/lamb.py +++ b/pytorch_optimizer/optimizer/lamb.py @@ -1,7 +1,6 @@ from typing import Union import torch -from torch.optim import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -9,7 +8,7 @@ from pytorch_optimizer.optimizer.utils import get_global_gradient_norm -class Lamb(Optimizer, BaseOptimizer): +class Lamb(BaseOptimizer): r"""Large Batch Optimization for Deep Learning. This Lamb implementation is based on the paper v3, which does not use de-biasing. diff --git a/pytorch_optimizer/optimizer/lars.py b/pytorch_optimizer/optimizer/lars.py index e45847bfd..32c7f4710 100644 --- a/pytorch_optimizer/optimizer/lars.py +++ b/pytorch_optimizer/optimizer/lars.py @@ -1,12 +1,11 @@ import torch -from torch.optim import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class LARS(Optimizer, BaseOptimizer): +class LARS(BaseOptimizer): r"""Layer-wise Adaptive Rate Scaling (no rate scaling or weight decay for parameters <= 1D). 
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/lion.py b/pytorch_optimizer/optimizer/lion.py index 8b6d38f04..98f2faf0e 100644 --- a/pytorch_optimizer/optimizer/lion.py +++ b/pytorch_optimizer/optimizer/lion.py @@ -1,5 +1,4 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -7,7 +6,7 @@ from pytorch_optimizer.optimizer.gc import centralize_gradient -class Lion(Optimizer, BaseOptimizer): +class Lion(BaseOptimizer): r"""Symbolic Discovery of Optimization Algorithms. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/lomo.py b/pytorch_optimizer/optimizer/lomo.py index 16d7e7a1b..63610c1b2 100644 --- a/pytorch_optimizer/optimizer/lomo.py +++ b/pytorch_optimizer/optimizer/lomo.py @@ -5,7 +5,6 @@ import torch from torch import nn from torch.distributed import ReduceOp, all_reduce -from torch.optim import Optimizer from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import DEFAULTS @@ -13,7 +12,7 @@ from pytorch_optimizer.optimizer.utils import has_overflow, is_deepspeed_zero3_enabled -class LOMO(BaseOptimizer, Optimizer): +class LOMO(BaseOptimizer): r"""Full Parameter Fine-tuning for Large Language Models with Limited Resources. Reference : https://github.com/OpenLMLab/LOMO/blob/main/src/lomo.py @@ -202,7 +201,7 @@ def grad_norm(self, loss): self.gather_norm = False -class AdaLOMO(BaseOptimizer, Optimizer): +class AdaLOMO(BaseOptimizer): r"""Low-memory Optimization with Adaptive Learning Rate. :param model: nn.Module. pytorch model. diff --git a/pytorch_optimizer/optimizer/lookahead.py b/pytorch_optimizer/optimizer/lookahead.py index 6e149f0f4..44a0be2ed 100644 --- a/pytorch_optimizer/optimizer/lookahead.py +++ b/pytorch_optimizer/optimizer/lookahead.py @@ -2,13 +2,12 @@ from typing import Callable, Dict import torch -from torch.optim import Optimizer from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER, STATE -class Lookahead(Optimizer, BaseOptimizer): +class Lookahead(BaseOptimizer): r"""k steps forward, 1 step back. :param optimizer: OPTIMIZER. base optimizer. diff --git a/pytorch_optimizer/optimizer/madgrad.py b/pytorch_optimizer/optimizer/madgrad.py index 060a0a347..52b8e12e9 100644 --- a/pytorch_optimizer/optimizer/madgrad.py +++ b/pytorch_optimizer/optimizer/madgrad.py @@ -6,14 +6,13 @@ import math import torch -from torch.optim import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class MADGRAD(Optimizer, BaseOptimizer): +class MADGRAD(BaseOptimizer): r"""A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified). :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
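The `Lookahead` docstring above summarizes the algorithm as "k steps forward, 1 step back". A self-contained sketch of that idea, using a plain SGD loop rather than the library's `Lookahead` wrapper; `k` and `alpha` are assumed hyper-parameter names and illustrative values, not taken from the diff above:

```python
# Hedged sketch of "k steps forward, 1 step back": run k fast steps with any inner
# optimizer, then move a slow copy of the weights toward the fast weights and reset.
import torch

w = torch.nn.Parameter(torch.tensor([5.0]))
slow_w = w.detach().clone()         # slow weights ("1 step back" target)
inner = torch.optim.SGD([w], lr=0.1)
k, alpha = 5, 0.5

for step in range(1, 21):
    loss = (w ** 2).sum()           # toy objective with its minimum at 0
    inner.zero_grad()
    loss.backward()
    inner.step()                    # fast step ("k steps forward")
    if step % k == 0:
        slow_w += alpha * (w.detach() - slow_w)  # slow <- slow + alpha * (fast - slow)
        with torch.no_grad():
            w.copy_(slow_w)         # reset fast weights to the slow weights

print(w.item())                     # approaches 0
```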
diff --git a/pytorch_optimizer/optimizer/msvag.py b/pytorch_optimizer/optimizer/msvag.py index f3260e276..792d87eb8 100644 --- a/pytorch_optimizer/optimizer/msvag.py +++ b/pytorch_optimizer/optimizer/msvag.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class MSVAG(Optimizer, BaseOptimizer): +class MSVAG(BaseOptimizer): r"""Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/nero.py b/pytorch_optimizer/optimizer/nero.py index 33f43e82b..06e499863 100644 --- a/pytorch_optimizer/optimizer/nero.py +++ b/pytorch_optimizer/optimizer/nero.py @@ -1,5 +1,4 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -7,7 +6,7 @@ from pytorch_optimizer.optimizer.utils import neuron_mean, neuron_norm -class Nero(Optimizer, BaseOptimizer): +class Nero(BaseOptimizer): """Learning by Turning: Neural Architecture Aware Optimisation. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/novograd.py b/pytorch_optimizer/optimizer/novograd.py index 957f94180..f9aa937c5 100644 --- a/pytorch_optimizer/optimizer/novograd.py +++ b/pytorch_optimizer/optimizer/novograd.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class NovoGrad(Optimizer, BaseOptimizer): +class NovoGrad(BaseOptimizer): r"""Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/padam.py b/pytorch_optimizer/optimizer/padam.py index bcfe0a31e..c15d35f31 100644 --- a/pytorch_optimizer/optimizer/padam.py +++ b/pytorch_optimizer/optimizer/padam.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class PAdam(Optimizer, BaseOptimizer): +class PAdam(BaseOptimizer): """Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. 
diff --git a/pytorch_optimizer/optimizer/pid.py b/pytorch_optimizer/optimizer/pid.py index 7dd73ba20..b148a203c 100644 --- a/pytorch_optimizer/optimizer/pid.py +++ b/pytorch_optimizer/optimizer/pid.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class PID(Optimizer, BaseOptimizer): +class PID(BaseOptimizer): r"""A PID Controller Approach for Stochastic Optimization of Deep Networks. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/pnm.py b/pytorch_optimizer/optimizer/pnm.py index 7d5338eb9..6d16822a8 100644 --- a/pytorch_optimizer/optimizer/pnm.py +++ b/pytorch_optimizer/optimizer/pnm.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class PNM(Optimizer, BaseOptimizer): +class PNM(BaseOptimizer): r"""Positive-Negative Momentum Optimizers. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/prodigy.py b/pytorch_optimizer/optimizer/prodigy.py index dd456d129..77a26632a 100644 --- a/pytorch_optimizer/optimizer/prodigy.py +++ b/pytorch_optimizer/optimizer/prodigy.py @@ -2,14 +2,13 @@ from typing import Optional import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Prodigy(Optimizer, BaseOptimizer): +class Prodigy(BaseOptimizer): r"""An Expeditiously Adaptive Parameter-Free Learner. Leave LR set to 1 unless you encounter instability. diff --git a/pytorch_optimizer/optimizer/qhadam.py b/pytorch_optimizer/optimizer/qhadam.py index 9d95fd05f..88700be5d 100644 --- a/pytorch_optimizer/optimizer/qhadam.py +++ b/pytorch_optimizer/optimizer/qhadam.py @@ -1,14 +1,13 @@ from typing import Tuple import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class QHAdam(Optimizer, BaseOptimizer): +class QHAdam(BaseOptimizer): r"""Quasi-hyperbolic momentum and Adam for deep learning. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/qhm.py b/pytorch_optimizer/optimizer/qhm.py index e9bedf977..986edb3bd 100644 --- a/pytorch_optimizer/optimizer/qhm.py +++ b/pytorch_optimizer/optimizer/qhm.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class QHM(Optimizer, BaseOptimizer): +class QHM(BaseOptimizer): r"""Quasi-hyperbolic momentum (QHM) optimization algorithm. :param params: PARAMETERS. 
iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/radam.py b/pytorch_optimizer/optimizer/radam.py index 22d1b5180..6243f7f95 100644 --- a/pytorch_optimizer/optimizer/radam.py +++ b/pytorch_optimizer/optimizer/radam.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class RAdam(Optimizer, BaseOptimizer): +class RAdam(BaseOptimizer): r"""Rectified Adam. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/ranger.py b/pytorch_optimizer/optimizer/ranger.py index 45a4277d6..021166977 100644 --- a/pytorch_optimizer/optimizer/ranger.py +++ b/pytorch_optimizer/optimizer/ranger.py @@ -1,5 +1,4 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -7,7 +6,7 @@ from pytorch_optimizer.optimizer.gc import centralize_gradient -class Ranger(Optimizer, BaseOptimizer): +class Ranger(BaseOptimizer): r"""a synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/ranger21.py b/pytorch_optimizer/optimizer/ranger21.py index 1992d1da5..a553190cb 100644 --- a/pytorch_optimizer/optimizer/ranger21.py +++ b/pytorch_optimizer/optimizer/ranger21.py @@ -3,7 +3,6 @@ import torch from torch.nn import functional as f -from torch.optim import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -13,7 +12,7 @@ from pytorch_optimizer.optimizer.utils import normalize_gradient, unit_norm -class Ranger21(Optimizer, BaseOptimizer): +class Ranger21(BaseOptimizer): r"""Integrating the latest deep learning components into a single optimizer. Here's the components @@ -240,8 +239,8 @@ def step(self, closure: CLOSURE = None) -> LOSS: for group in self.param_groups: beta1, beta2 = group['betas'] - bias_correction1: float = 1.0 - beta1 ** group['step'] # fmt: skip - bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) # fmt: skip + bias_correction1: float = self.debias(beta1, group['step']) + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2) # fmt: skip diff --git a/pytorch_optimizer/optimizer/sam.py b/pytorch_optimizer/optimizer/sam.py index 4c820e5a4..771f76430 100644 --- a/pytorch_optimizer/optimizer/sam.py +++ b/pytorch_optimizer/optimizer/sam.py @@ -6,7 +6,6 @@ from torch.distributed import ReduceOp, all_reduce, get_world_size, is_initialized from torch.nn.parallel import DistributedDataParallel from torch.nn.utils import clip_grad_norm_ -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoClosureError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -14,7 +13,7 @@ from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats -class SAM(Optimizer, BaseOptimizer): +class SAM(BaseOptimizer): r"""Sharpness-Aware Minimization for Efficiently Improving Generalization. 
Example: @@ -153,7 +152,7 @@ def load_state_dict(self, state_dict: Dict): self.base_optimizer.param_groups = self.param_groups -class GSAM(Optimizer, BaseOptimizer): # pragma: no cover +class GSAM(BaseOptimizer): # pragma: no cover r"""Surrogate Gap Guided Sharpness-Aware Minimization. Example: @@ -369,7 +368,7 @@ def load_state_dict(self, state_dict: Dict): self.base_optimizer.param_groups = self.param_groups -class WSAM(Optimizer, BaseOptimizer): +class WSAM(BaseOptimizer): r"""Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term. :param model: Union[torch.nn.Module, torch.nn.DataParallel]. the model instance. DDP model is recommended to make @@ -528,7 +527,7 @@ def load_state_dict(self, state_dict: Dict): self.base_optimizer.param_groups = self.param_groups -class BSAM(Optimizer, BaseOptimizer): +class BSAM(BaseOptimizer): r"""SAM as an Optimal Relaxation of Bayes. Example: diff --git a/pytorch_optimizer/optimizer/schedulefree.py b/pytorch_optimizer/optimizer/schedulefree.py index 0898f531d..aada3cfa0 100644 --- a/pytorch_optimizer/optimizer/schedulefree.py +++ b/pytorch_optimizer/optimizer/schedulefree.py @@ -2,14 +2,13 @@ from typing import List import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class ScheduleFreeSGD(Optimizer, BaseOptimizer): +class ScheduleFreeSGD(BaseOptimizer): r"""Schedule-Free SGD. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -151,7 +150,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class ScheduleFreeAdamW(Optimizer, BaseOptimizer): +class ScheduleFreeAdamW(BaseOptimizer): r"""Schedule-Free AdamW. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/sgd.py b/pytorch_optimizer/optimizer/sgd.py index 9a27e8e72..e94c67796 100644 --- a/pytorch_optimizer/optimizer/sgd.py +++ b/pytorch_optimizer/optimizer/sgd.py @@ -2,14 +2,13 @@ from typing import Dict, Tuple import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class AccSGD(Optimizer, BaseOptimizer): +class AccSGD(BaseOptimizer): r"""Accelerating Stochastic Gradient Descent For Least Squares Regression. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -104,7 +103,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class SGDW(Optimizer, BaseOptimizer): +class SGDW(BaseOptimizer): r"""Decoupled Weight Decay Regularization. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -198,7 +197,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class ASGD(Optimizer, BaseOptimizer): +class ASGD(BaseOptimizer): r"""Adaptive SGD with estimation of the local smoothness (curvature). :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -323,7 +322,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class SignSGD(Optimizer, BaseOptimizer): +class SignSGD(BaseOptimizer): r"""Compressed Optimisation for Non-Convex Problems. 
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/sgdp.py b/pytorch_optimizer/optimizer/sgdp.py index 4a5c91dfe..63cc54827 100644 --- a/pytorch_optimizer/optimizer/sgdp.py +++ b/pytorch_optimizer/optimizer/sgdp.py @@ -1,5 +1,4 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -7,7 +6,7 @@ from pytorch_optimizer.optimizer.utils import projection -class SGDP(Optimizer, BaseOptimizer): +class SGDP(BaseOptimizer): r"""SGD + Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/shampoo.py b/pytorch_optimizer/optimizer/shampoo.py index 4659dab70..d1c08ba2b 100644 --- a/pytorch_optimizer/optimizer/shampoo.py +++ b/pytorch_optimizer/optimizer/shampoo.py @@ -1,5 +1,4 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -13,7 +12,7 @@ ) -class Shampoo(Optimizer, BaseOptimizer): +class Shampoo(BaseOptimizer): r"""Preconditioned Stochastic Tensor Optimization. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. @@ -137,7 +136,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: return loss -class ScalableShampoo(Optimizer, BaseOptimizer): +class ScalableShampoo(BaseOptimizer): r"""Scalable Preconditioned Stochastic Tensor Optimization. This version of Scalable Shampoo Optimizer aims for a single GPU environment, not for a distributed environment diff --git a/pytorch_optimizer/optimizer/sm3.py b/pytorch_optimizer/optimizer/sm3.py index b97953cf8..35b0c59ad 100644 --- a/pytorch_optimizer/optimizer/sm3.py +++ b/pytorch_optimizer/optimizer/sm3.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.utils import reduce_max_except_dim -class SM3(Optimizer, BaseOptimizer): +class SM3(BaseOptimizer): r"""Memory-Efficient Adaptive Optimization. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/sophia.py b/pytorch_optimizer/optimizer/sophia.py index 5c5d94ae8..c46c64bc8 100644 --- a/pytorch_optimizer/optimizer/sophia.py +++ b/pytorch_optimizer/optimizer/sophia.py @@ -1,14 +1,13 @@ from typing import List, Optional import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS -class SophiaH(Optimizer, BaseOptimizer): +class SophiaH(BaseOptimizer): r"""Second-order Clipped Stochastic Optimization. Requires `loss.backward(create_graph=True)` in order to calculate hessians. 
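The `SophiaH` docstring above notes that it requires `loss.backward(create_graph=True)` in order to calculate hessians. A short, library-independent sketch of why: a Hutchinson-style Hessian estimate differentiates through the gradient itself, so the first-order graph has to be retained (here via `torch.autograd.grad(..., create_graph=True)`, which has the same effect as `backward(create_graph=True)`):

```python
# Library-independent sketch of a Hutchinson-style Hessian-diagonal estimate.
# The toy loss is separable, so a single Rademacher probe recovers the diagonal exactly.
import torch

w = torch.nn.Parameter(torch.randn(4))
loss = (w ** 4).sum()                                      # exact Hessian diagonal: 12 * w**2

(grad,) = torch.autograd.grad(loss, w, create_graph=True)  # keep the graph for 2nd-order grads
u = torch.randn_like(w).sign()                             # Rademacher probe (+1 / -1 entries)
(hvp,) = torch.autograd.grad(grad, w, grad_outputs=u)      # Hessian-vector product H @ u

print(u * hvp)                                             # Hutchinson estimate of diag(H)
print(12.0 * w.detach() ** 2)                              # exact diagonal for comparison
```

Without `create_graph=True` in the first call, `grad` carries no graph, and the second `torch.autograd.grad` call raises a runtime error.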
diff --git a/pytorch_optimizer/optimizer/srmm.py b/pytorch_optimizer/optimizer/srmm.py index 16d4af7d5..5c73f9a5c 100644 --- a/pytorch_optimizer/optimizer/srmm.py +++ b/pytorch_optimizer/optimizer/srmm.py @@ -1,14 +1,13 @@ from typing import List, Optional import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class SRMM(Optimizer, BaseOptimizer): +class SRMM(BaseOptimizer): """Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/swats.py b/pytorch_optimizer/optimizer/swats.py index 0761479db..0f1705e2f 100644 --- a/pytorch_optimizer/optimizer/swats.py +++ b/pytorch_optimizer/optimizer/swats.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class SWATS(Optimizer, BaseOptimizer): +class SWATS(BaseOptimizer): r"""Improving Generalization Performance by Switching from Adam to SGD. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/tiger.py b/pytorch_optimizer/optimizer/tiger.py index 477b6c208..101efa71c 100644 --- a/pytorch_optimizer/optimizer/tiger.py +++ b/pytorch_optimizer/optimizer/tiger.py @@ -1,12 +1,11 @@ import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Tiger(Optimizer, BaseOptimizer): +class Tiger(BaseOptimizer): r"""A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. diff --git a/pytorch_optimizer/optimizer/yogi.py b/pytorch_optimizer/optimizer/yogi.py index 24a0f48b6..7964e6c8b 100644 --- a/pytorch_optimizer/optimizer/yogi.py +++ b/pytorch_optimizer/optimizer/yogi.py @@ -1,14 +1,13 @@ import math import torch -from torch.optim.optimizer import Optimizer from pytorch_optimizer.base.exception import NoSparseGradientError from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS -class Yogi(Optimizer, BaseOptimizer): +class Yogi(BaseOptimizer): r"""Decoupled Weight Decay Regularization. :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. From 046770998550803c28588f675a76514385977c72 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 19:07:14 +0900 Subject: [PATCH 05/14] docs: v3.1.0 changelog --- docs/changelogs/v3.1.0.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelogs/v3.1.0.md b/docs/changelogs/v3.1.0.md index 78ada9684..266b2568d 100644 --- a/docs/changelogs/v3.1.0.md +++ b/docs/changelogs/v3.1.0.md @@ -19,6 +19,7 @@ * Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258) * Refactor `shampoo_utils.py`. 
(#259) * Add `debias`, `debias_adam` methods in `BaseOptimizer`. (#261) +* Refactor to use `BaseOptimizer` only, not inherit multiple classes. (#261) ### Bug From dbef58311c1b7cb457c235b1a9506152976fb935 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 19:23:10 +0900 Subject: [PATCH 06/14] refactor: gc --- pytorch_optimizer/optimizer/gc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_optimizer/optimizer/gc.py b/pytorch_optimizer/optimizer/gc.py index 782657293..b00ac4303 100644 --- a/pytorch_optimizer/optimizer/gc.py +++ b/pytorch_optimizer/optimizer/gc.py @@ -1,12 +1,12 @@ import torch -def centralize_gradient(x: torch.Tensor, gc_conv_only: bool = False): +def centralize_gradient(grad: torch.Tensor, gc_conv_only: bool = False) -> None: r"""Gradient Centralization (GC). - :param x: torch.Tensor. gradient. + :param grad: torch.Tensor. gradient. :param gc_conv_only: bool. 'False' for both conv & fc layers. """ - size: int = x.dim() + size: int = grad.dim() if (gc_conv_only and size > 3) or (not gc_conv_only and size > 1): - x.add_(-x.mean(dim=tuple(range(1, size)), keepdim=True)) + grad.add_(-grad.mean(dim=tuple(range(1, size)), keepdim=True)) From 7420d322a82dcc58abaa7a0ad6f77e954be87b5b Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 19:28:33 +0900 Subject: [PATCH 07/14] ci: disable coverage --- pytorch_optimizer/base/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index 33b2c6636..4659712fd 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -327,5 +327,5 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None: def reset(self) -> None: # pragma: no cover raise NotImplementedError - def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # pragma: no cover raise NotImplementedError From 35fb3fcfca4bbe056033109a9dc81cd30c694d0a Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 19:36:00 +0900 Subject: [PATCH 08/14] refactor: type hint --- pytorch_optimizer/base/optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index 4659712fd..e4eb610b4 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -1,12 +1,12 @@ import math from abc import ABC, abstractmethod -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch.optim import Optimizer from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError -from pytorch_optimizer.base.types import BETAS, DEFAULTS, HUTCHINSON_G, PARAMETERS, STATE +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS, STATE class BaseOptimizer(ABC, Optimizer): @@ -327,5 +327,5 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None: def reset(self) -> None: # pragma: no cover raise NotImplementedError - def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # pragma: no cover + def step(self, closure: CLOSURE = None) -> LOSS: # pragma: no cover raise NotImplementedError From 024e611778e346bdee8e00bf05f3d812a288d07c Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 20:26:04 +0900 Subject: [PATCH 
09/14] update: ranger21 recipe --- tests/test_optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 21cdd86fa..2d25f4650 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -271,7 +271,7 @@ def test_adanorm_optimizer(optimizer_config, environment): optimizer_class, config, num_iterations = optimizer_config if optimizer_class.__name__ == 'Ranger21': - config.update({'num_iterations': num_iterations}) + config.update({'num_iterations': num_iterations, 'disable_lr_scheduler': True}) optimizer = optimizer_class(model.parameters(), **config) From b85e0ad5eaae4c85b73a1307f1f9194badc5a9a5 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 20:26:13 +0900 Subject: [PATCH 10/14] feature: disable_lr_scheduler --- pytorch_optimizer/optimizer/ranger21.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pytorch_optimizer/optimizer/ranger21.py b/pytorch_optimizer/optimizer/ranger21.py index a553190cb..db22febd3 100644 --- a/pytorch_optimizer/optimizer/ranger21.py +++ b/pytorch_optimizer/optimizer/ranger21.py @@ -2,7 +2,7 @@ from typing import Optional import torch -from torch.nn import functional as f +from torch.nn.functional import softplus from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError from pytorch_optimizer.base.optimizer import BaseOptimizer @@ -38,6 +38,7 @@ class Ranger21(BaseOptimizer): :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. :param use_softplus: bool. use softplus to smooth. :param beta_softplus: float. beta. + :param disable_lr_scheduler: bool. whether to disable learning rate schedule. :param num_warm_up_iterations: Optional[int]. number of warm-up iterations. Ranger21 performs linear learning rate warmup. :param num_warm_down_iterations: Optional[int]. number of warm-down iterations. 
Ranger21 performs Explore-exploit @@ -65,6 +66,7 @@ def __init__( # pylint: disable=R0913 betas: BETAS = (0.9, 0.999), use_softplus: bool = True, beta_softplus: float = 50.0, + disable_lr_scheduler: bool = False, num_warm_up_iterations: Optional[int] = None, num_warm_down_iterations: Optional[int] = None, warm_down_min_lr: float = 3e-5, @@ -93,6 +95,7 @@ def __init__( # pylint: disable=R0913 self.min_lr = warm_down_min_lr self.use_softplus = use_softplus self.beta_softplus = beta_softplus + self.disable_lr_scheduler = disable_lr_scheduler self.agc_clipping_value = agc_clipping_value self.agc_eps = agc_eps self.centralize_gradients = centralize_gradients @@ -245,8 +248,11 @@ def step(self, closure: CLOSURE = None) -> LOSS: noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2) # fmt: skip # warm up & down - lr: float = self.warm_up_dampening(group['lr'], group['step']) - lr = self.warm_down(lr, group['step']) + if self.disable_lr_scheduler: + lr: float = group['lr'] + else: + lr: float = self.warm_up_dampening(group['lr'], group['step']) + lr = self.warm_down(lr, group['step']) for p in group['params']: if p.grad is None: @@ -279,7 +285,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: de_nom = (variance_ma.sqrt() / bias_correction2_sq).add_(group['eps']) if self.use_softplus: - de_nom = f.softplus(de_nom, beta=self.beta_softplus) + de_nom = softplus(de_nom, beta=self.beta_softplus) grad = p.grad centralize_gradient(grad, gc_conv_only=False) @@ -289,7 +295,7 @@ def step(self, closure: CLOSURE = None) -> LOSS: step_size: float = self.apply_adam_debias(group['adam_debias'], lr, bias_correction1) - pn_momentum = grad_ma.mul(1.0 + 1.0).add(neg_grad_ma, alpha=-1.0).mul(1.0 / noise_norm) + pn_momentum = grad_ma.mul(1.0 + 1.0).add_(neg_grad_ma, alpha=-1.0).mul_(1.0 / noise_norm) p.addcdiv_(pn_momentum, de_nom, value=-step_size) self.lookahead_process_step() From 3ece36d852c13166146c8dde855fc1603157c991 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 20:26:22 +0900 Subject: [PATCH 11/14] docs: v3.1.0 changelog --- docs/changelogs/v3.1.0.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelogs/v3.1.0.md b/docs/changelogs/v3.1.0.md index 266b2568d..2bd1b69ff 100644 --- a/docs/changelogs/v3.1.0.md +++ b/docs/changelogs/v3.1.0.md @@ -11,6 +11,7 @@ * `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`. * Improve `power_iteration()` speed up to 40%. (#259) * Improve `reg_noise()` (E-MCMC) speed up to 120%. (#260) +* Support `disable_lr_scheduler` parameter for `Ranger21` optimizer to disable built-in learning rate scheduler. 
(#261) ### Refactor From c599541d43d21cd00b7b54c87d5ec36deba31a9d Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 20:35:57 +0900 Subject: [PATCH 12/14] update: ranger21 recipe --- tests/constants.py | 6 +++++- tests/test_optimizers.py | 4 ---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/constants.py b/tests/constants.py index bc42a266e..498fbf46b 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -493,7 +493,11 @@ (Lamb, {'lr': 1e0, 'weight_decay': 1e-3, 'rectify': True, 'adam_debias': True}, 30), (RAdam, {'lr': 1e0, 'weight_decay': 1e-3, 'adam_debias': True}, 25), (Ranger, {'lr': 5e0, 'weight_decay': 1e-3, 'adam_debias': True}, 50), - (Ranger21, {'lr': 1e0, 'weight_decay': 1e-3, 'adam_debias': True, 'num_iterations': 125}, 125), + ( + Ranger21, + {'lr': 5e-1, 'weight_decay': 1e-3, 'adam_debias': True, 'num_iterations': 125, 'disable_lr_scheduler': True}, + 125, + ), (AdaPNM, {'lr': 1e0, 'weight_decay': 1e-3, 'adam_debias': True}, 10), (NovoGrad, {'lr': 1e0, 'weight_decay': 1e-3, 'adam_debias': True}, 5), (AdaNorm, {'lr': 1e0, 'weight_decay': 1e-3, 'adam_debias': True}, 5), diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 2d25f4650..6501de3ee 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -270,8 +270,6 @@ def test_adanorm_optimizer(optimizer_config, environment): (x_data, y_data), model, loss_fn = environment optimizer_class, config, num_iterations = optimizer_config - if optimizer_class.__name__ == 'Ranger21': - config.update({'num_iterations': num_iterations, 'disable_lr_scheduler': True}) optimizer = optimizer_class(model.parameters(), **config) @@ -311,8 +309,6 @@ def test_adamd_optimizers(optimizer_config, environment): (x_data, y_data), model, loss_fn = environment optimizer_class, config, num_iterations = optimizer_config - if optimizer_class.__name__ == 'Ranger21': - config.update({'num_iterations': num_iterations}) optimizer = optimizer_class(model.parameters(), **config) From 137f336d778fa232bcd1f3e6383a11e4fa0a9b2d Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 20:45:46 +0900 Subject: [PATCH 13/14] docs: fix typo --- pytorch_optimizer/base/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_optimizer/base/optimizer.py b/pytorch_optimizer/base/optimizer.py index e4eb610b4..5fe4e8cbb 100644 --- a/pytorch_optimizer/base/optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -154,7 +154,7 @@ def debias(beta: float, step: int) -> float: r"""Adam-style debias correction. Returns `1.0 - beta ** step`. :param beta: float. beta. - :param step. int. number of step. + :param step: int. number of step. 
""" return 1.0 - math.pow(beta, step) # fmt: skip From 59e3ec306a50bd7e26d7a56f6112a50ed32b04eb Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 21 Jul 2024 20:48:22 +0900 Subject: [PATCH 14/14] build(deps): update docs packages --- requirements-docs.txt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/requirements-docs.txt b/requirements-docs.txt index deb3aad49..8b5c48bb0 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,11 +1,11 @@ --index-url https://pypi.org/simple --extra-index-url https://download.pytorch.org/whl/cpu -numpy -torch==2.1.0 -mkdocs==1.5.2 -mkdocs-material==9.3.1 -pymdown-extensions==10.3 -mkdocstrings-python==1.7.0 +numpy<2.0 +torch==2.3.1 +mkdocs==1.6.0 +mkdocs-material==9.5.29 +pymdown-extensions==10.8.1 +mkdocstrings-python==1.10.5 markdown-include==0.8.1 mdx_truly_sane_lists==1.3 mkdocs-awesome-pages-plugin==2.9.2