diff --git a/README.md b/README.md
index 9a3675d49..0633c3746 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **62 optimizers (+ `bitsandbytes`)**, **10 lr schedulers**, and **13 loss functions** are supported!
+Currently, **62 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
@@ -171,10 +171,11 @@
 from pytorch_optimizer import get_supported_lr_schedulers
 
 supported_lr_schedulers = get_supported_lr_schedulers()
 ```
 
-| LR Scheduler    | Description                                                                      | Official Code | Paper | Citation                                                                      |
-|-----------------|---------------------------------------------------------------------------------|---------------|------------------------------------|------------------------------------------------------------------------------|
-| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule* |               |       | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation) |
-| Chebyshev       | *Acceleration via Fractal Learning Rate Schedules*                               |               |       | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation) |
+| LR Scheduler    | Description                                                                      | Official Code                                                                                                                        | Paper | Citation                                                                      |
+|-----------------|---------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------|------------------------------------|------------------------------------------------------------------------------|
+| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule* |                                                                                                                                      |       | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation) |
+| Chebyshev       | *Acceleration via Fractal Learning Rate Schedules*                               |                                                                                                                                      |       | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation) |
+| REX             | *Revisiting Budgeted Training with an Improved Schedule*                        | [github](https://github.com/Nerogar/OneTrainer/blob/2c6f34ea0838e5a86774a1cf75093d7e97c70f03/modules/util/lr_scheduler_util.py#L66) |       | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210704197C/exportcitation) |
 
 ## Supported Loss Function
diff --git a/docs/changelogs/v3.0.0.md b/docs/changelogs/v3.0.0.md
index 646e5eaac..7d4980be1 100644
--- a/docs/changelogs/v3.0.0.md
+++ b/docs/changelogs/v3.0.0.md
@@ -4,6 +4,8 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
 
 ### Feature
 
+* Implement `REX` lr scheduler. (#217, #222)
+    * [Revisiting Budgeted Training with an Improved Schedule](https://arxiv.org/abs/2107.04197)
 * Implement `Aida` optimizer. (#220, #221)
     * [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273)
 * Implement `WSAM` optimizer. (#213, #216)
diff --git a/docs/lr_scheduler.md b/docs/lr_scheduler.md
index 2d75ec1b8..37ff52586 100644
--- a/docs/lr_scheduler.md
+++ b/docs/lr_scheduler.md
@@ -25,3 +25,7 @@
 ::: pytorch_optimizer.ProportionScheduler
     :docstring:
     :members:
+
+::: pytorch_optimizer.REXScheduler
+    :docstring:
+    :members:
diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py
index c2b2d3baf..8ccdb7c36 100644
--- a/pytorch_optimizer/__init__.py
+++ b/pytorch_optimizer/__init__.py
@@ -26,6 +26,7 @@
 from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
 from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
+from pytorch_optimizer.lr_scheduler.rex import REXScheduler
 from pytorch_optimizer.optimizer.a2grad import A2Grad
 from pytorch_optimizer.optimizer.adabelief import AdaBelief
 from pytorch_optimizer.optimizer.adabound import AdaBound
@@ -195,6 +196,7 @@
     PolyScheduler,
     LinearScheduler,
     ProportionScheduler,
+    REXScheduler,
 ]
 LR_SCHEDULERS: Dict[str, SCHEDULER] = {
     str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
diff --git a/pytorch_optimizer/base/scheduler.py b/pytorch_optimizer/base/scheduler.py
index 3f28c80d3..4cfd5ed3e 100644
--- a/pytorch_optimizer/base/scheduler.py
+++ b/pytorch_optimizer/base/scheduler.py
@@ -76,7 +76,6 @@ def step(self):
 
         self.step_t += 1
 
-        # apply the lr to optimizer if it's provided
         if self.optimizer is not None:
             for param_group in self.optimizer.param_groups:
                 param_group['lr'] = value
diff --git a/pytorch_optimizer/lr_scheduler/rex.py b/pytorch_optimizer/lr_scheduler/rex.py
new file mode 100644
index 000000000..bebe4200d
--- /dev/null
+++ b/pytorch_optimizer/lr_scheduler/rex.py
@@ -0,0 +1,66 @@
+from typing import List
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+from pytorch_optimizer.base.types import OPTIMIZER
+
+
+class REXScheduler(_LRScheduler):
+    r"""Revisiting Budgeted Training with an Improved Schedule.
+
+    :param optimizer: Optimizer. wrapped optimizer instance.
+    :param total_steps: int. number of steps to optimize.
+    :param max_lr: float. max lr.
+    :param min_lr: float. min lr.
+ """ + + def __init__( + self, + optimizer: OPTIMIZER, + total_steps: int, + max_lr: float = 1.0, + min_lr: float = 0.0, + ): + self.total_steps = total_steps + self.max_lr = max_lr + self.min_lr = min_lr + + self.step_t: int = 0 + self.base_lrs: List[float] = [] + + # record current value in self._last_lr to match API from torch.optim.lr_scheduler + self.last_lr: List[float] = [self.max_lr] + + super().__init__(optimizer) + + self.init_lr() + + def init_lr(self): + self.base_lrs = [] + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.min_lr + self.base_lrs.append(self.min_lr) + + def get_lr(self) -> float: + return self.last_lr[0] + + def get_linear_lr(self) -> float: + if self.step_t >= self.total_steps: + return self.min_lr + + progress: float = self.step_t / self.total_steps + + return self.min_lr + (self.max_lr - self.min_lr) * ((1.0 - progress) / (1.0 - progress / 2.0)) + + def step(self): + value: float = self.get_linear_lr() + + self.step_t += 1 + + if self.optimizer is not None: + for param_group in self.optimizer.param_groups: + param_group['lr'] = value + + self.last_lr = [value] + + return value diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py index a3b97e3f9..57b4fd8d9 100644 --- a/tests/test_load_modules.py +++ b/tests/test_load_modules.py @@ -42,7 +42,7 @@ def test_get_supported_optimizers(): def test_get_supported_lr_schedulers(): - assert len(get_supported_lr_schedulers()) == 10 + assert len(get_supported_lr_schedulers()) == 11 def test_get_supported_loss_functions(): diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py index 87832d961..11720cdf6 100644 --- a/tests/test_lr_schedulers.py +++ b/tests/test_lr_schedulers.py @@ -10,6 +10,7 @@ from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler +from pytorch_optimizer.lr_scheduler.rex import REXScheduler from tests.utils import Example CAWR_RECIPES = [ @@ -263,6 +264,29 @@ def test_proportion_no_last_lr_scheduler(): np.testing.assert_almost_equal(2.0, rho_scheduler.get_lr(), 6) +def test_rex_lr_scheduler(): + lrs = [ + 0.888888, + 0.749999, + 0.571428, + 0.333333, + 0.0, + ] + + base_optimizer = AdamP(Example().parameters()) + + lr_scheduler = REXScheduler( + base_optimizer, + total_steps=5, + max_lr=1.0, + min_lr=0.0, + ) + + for expected_lr in lrs: + _ = lr_scheduler.step() + np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_lr(), 6) + + def test_deberta_v3_large_lr_scheduler(): model = nn.Sequential(*[nn.Linear(1, 1, bias=False) for _ in range(400)]) deberta_v3_large_lr_scheduler(model)