From 0586be54ac98e0342fe01b3d65cfe39d6a41a9c2 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 2 Mar 2024 18:02:50 +0900
Subject: [PATCH 1/7] docs: REX lr scheduler

---
 README.md | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 9a3675d49..0633c3746 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
 
-Currently, **62 optimizers (+ `bitsandbytes`)**, **10 lr schedulers**, and **13 loss functions** are supported!
+Currently, **62 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
 
@@ -171,10 +171,11 @@ from pytorch_optimizer import get_supported_lr_schedulers
 
 supported_lr_schedulers = get_supported_lr_schedulers()
 ```
 
-| LR Scheduler | Description | Official Code | Paper | Citation |
-|-----------------|---------------------------------------------------------------------------------|---------------|------------------------------------|------------------------------------------------------------------------------|
-| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule* | | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation) |
-| Chebyshev | *Acceleration via Fractal Learning Rate Schedules* | | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation) |
+| LR Scheduler | Description | Official Code | Paper | Citation |
+|-----------------|---------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------|------------------------------------|------------------------------------------------------------------------------|
+| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule* | | | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation) |
+| Chebyshev | *Acceleration via Fractal Learning Rate Schedules* | | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation) |
+| REX | *Revisiting Budgeted Training with an Improved Schedule* | [github](https://github.com/Nerogar/OneTrainer/blob/2c6f34ea0838e5a86774a1cf75093d7e97c70f03/modules/util/lr_scheduler_util.py#L66) | | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210704197C/exportcitation) |
 
 ## Supported Loss Function

From 448db3d8894b5fe13e0b2a9c39086a32e52644fd Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 2 Mar 2024 18:25:41 +0900
Subject: [PATCH 2/7] docs: REX lr scheduler

---
 docs/lr_scheduler.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/docs/lr_scheduler.md b/docs/lr_scheduler.md
index 2d75ec1b8..37ff52586 100644
--- a/docs/lr_scheduler.md
+++ b/docs/lr_scheduler.md
@@ -25,3 +25,7 @@
 ::: pytorch_optimizer.ProportionScheduler
     :docstring:
     :members:
+
+::: pytorch_optimizer.REXScheduler
+    :docstring:
+    :members:

From 4af1880a2256df000c2f223be9bec5538383f39a Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 2 Mar 2024 18:25:51 +0900
Subject: [PATCH 3/7] feature: implement REX lr scheduler

---
 pytorch_optimizer/lr_scheduler/rex.py | 66 +++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)
 create mode 100644 pytorch_optimizer/lr_scheduler/rex.py

diff --git a/pytorch_optimizer/lr_scheduler/rex.py b/pytorch_optimizer/lr_scheduler/rex.py
new file mode 100644
index 000000000..bebe4200d
--- /dev/null
+++ b/pytorch_optimizer/lr_scheduler/rex.py
@@ -0,0 +1,66 @@
+from typing import List
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+from pytorch_optimizer.base.types import OPTIMIZER
+
+
+class REXScheduler(_LRScheduler):
+    r"""Revisiting Budgeted Training with an Improved Schedule.
+
+    :param optimizer: Optimizer. wrapped optimizer instance.
+    :param total_steps: int. number of steps to optimize.
+    :param max_lr: float. max lr.
+    :param min_lr: float. min lr.
+    """
+
+    def __init__(
+        self,
+        optimizer: OPTIMIZER,
+        total_steps: int,
+        max_lr: float = 1.0,
+        min_lr: float = 0.0,
+    ):
+        self.total_steps = total_steps
+        self.max_lr = max_lr
+        self.min_lr = min_lr
+
+        self.step_t: int = 0
+        self.base_lrs: List[float] = []
+
+        # record current value in self._last_lr to match API from torch.optim.lr_scheduler
+        self.last_lr: List[float] = [self.max_lr]
+
+        super().__init__(optimizer)
+
+        self.init_lr()
+
+    def init_lr(self):
+        self.base_lrs = []
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = self.min_lr
+            self.base_lrs.append(self.min_lr)
+
+    def get_lr(self) -> float:
+        return self.last_lr[0]
+
+    def get_linear_lr(self) -> float:
+        if self.step_t >= self.total_steps:
+            return self.min_lr
+
+        progress: float = self.step_t / self.total_steps
+
+        return self.min_lr + (self.max_lr - self.min_lr) * ((1.0 - progress) / (1.0 - progress / 2.0))
+
+    def step(self):
+        value: float = self.get_linear_lr()
+
+        self.step_t += 1
+
+        if self.optimizer is not None:
+            for param_group in self.optimizer.param_groups:
+                param_group['lr'] = value
+
+        self.last_lr = [value]
+
+        return value

From 457120de7bd5315ba34fbef98ac75fcef3666b18 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 2 Mar 2024 18:25:59 +0900
Subject: [PATCH 4/7] update: REX lr scheduler

---
 pytorch_optimizer/__init__.py       | 2 ++
 pytorch_optimizer/base/scheduler.py | 1 -
 tests/test_load_modules.py          | 2 +-
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py
index c2b2d3baf..8ccdb7c36 100644
--- a/pytorch_optimizer/__init__.py
+++ b/pytorch_optimizer/__init__.py
@@ -26,6 +26,7 @@
 from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
 from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
+from pytorch_optimizer.lr_scheduler.rex import REXScheduler
 from pytorch_optimizer.optimizer.a2grad import A2Grad
 from pytorch_optimizer.optimizer.adabelief import AdaBelief
 from pytorch_optimizer.optimizer.adabound import AdaBound
@@ -195,6 +196,7 @@
     PolyScheduler,
     LinearScheduler,
     ProportionScheduler,
+    REXScheduler,
 ]
 LR_SCHEDULERS: Dict[str, SCHEDULER] = {
     str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
diff --git a/pytorch_optimizer/base/scheduler.py b/pytorch_optimizer/base/scheduler.py
index 3f28c80d3..4cfd5ed3e 100644
--- a/pytorch_optimizer/base/scheduler.py
+++ b/pytorch_optimizer/base/scheduler.py
@@ -76,7 +76,6 @@ def step(self):
 
         self.step_t += 1
 
-        # apply the lr to optimizer if it's provided
         if self.optimizer is not None:
             for param_group in self.optimizer.param_groups:
                 param_group['lr'] = value
diff --git a/tests/test_load_modules.py b/tests/test_load_modules.py
index a3b97e3f9..57b4fd8d9 100644
--- a/tests/test_load_modules.py
+++ b/tests/test_load_modules.py
@@ -42,7 +42,7 @@ def test_get_supported_optimizers():
 
 
 def test_get_supported_lr_schedulers():
-    assert len(get_supported_lr_schedulers()) == 10
+    assert len(get_supported_lr_schedulers()) == 11
 
 
 def test_get_supported_loss_functions():

From 721e1dbbe624e2540bceb1ff1e494311c1ab3e93 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 2 Mar 2024 18:29:29 +0900
Subject: [PATCH 5/7] update: test_rex_lr_scheduler

---
 tests/test_lr_schedulers.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py
index 87832d961..d4ac4ba79 100644
--- a/tests/test_lr_schedulers.py
+++ b/tests/test_lr_schedulers.py
@@ -10,6 +10,7 @@
 from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
 from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
+from pytorch_optimizer.lr_scheduler.rex import REXScheduler
 from tests.utils import Example
 
 CAWR_RECIPES = [
@@ -263,6 +264,29 @@ def test_proportion_no_last_lr_scheduler():
     np.testing.assert_almost_equal(2.0, rho_scheduler.get_lr(), 6)
 
 
+def test_rex_lr_scheduler():
+    lrs = [
+        0.888888,
+        0.749999,
+        0.571428,
+        0.333333,
+        0.0,
+    ]
+
+    base_optimizer = AdamP(Example().parameters())
+
+    lr_scheduler = REXScheduler(
+        base_optimizer,
+        total_steps=5,
+        max_lr=1.0,
+        min_lr=0.0,
+    )
+
+    for expected_lr in lrs:
+        lr: float = lr_scheduler.step()
+        np.testing.assert_almost_equal(expected_lr, lr, 6)
+
+
 def test_deberta_v3_large_lr_scheduler():
     model = nn.Sequential(*[nn.Linear(1, 1, bias=False) for _ in range(400)])
     deberta_v3_large_lr_scheduler(model)

From bc3df01bcdfa93e905fa5d081436362ff742b378 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 2 Mar 2024 18:30:47 +0900
Subject: [PATCH 6/7] docs: v3.0.0 changelog

---
 docs/changelogs/v3.0.0.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/changelogs/v3.0.0.md b/docs/changelogs/v3.0.0.md
index 646e5eaac..7d4980be1 100644
--- a/docs/changelogs/v3.0.0.md
+++ b/docs/changelogs/v3.0.0.md
@@ -4,6 +4,8 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
 
 ### Feature
 
+* Implement `REX` lr scheduler. (#217, #222)
+    * [Revisiting Budgeted Training with an Improved Schedule](https://arxiv.org/abs/2107.04197)
 * Implement `Aida` optimizer. (#220, #221)
     * [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273)
 * Implement `WSAM` optimizer. (#213, #216)

From a3b1bfb1c7a115ee67aed9249261bcef782afebd Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 2 Mar 2024 18:34:39 +0900
Subject: [PATCH 7/7] update: test_rex_lr_scheduler

---
 tests/test_lr_schedulers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py
index d4ac4ba79..11720cdf6 100644
--- a/tests/test_lr_schedulers.py
+++ b/tests/test_lr_schedulers.py
@@ -283,8 +283,8 @@ def test_rex_lr_scheduler():
     )
 
     for expected_lr in lrs:
-        lr: float = lr_scheduler.step()
-        np.testing.assert_almost_equal(expected_lr, lr, 6)
+        _ = lr_scheduler.step()
+        np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_lr(), 6)
 
 
 def test_deberta_v3_large_lr_scheduler():
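For context, the snippet below is a minimal usage sketch of the scheduler introduced by this series, put together from the constructor, `step()`, and `get_lr()` calls visible in the patches above. The toy model, the random batch, and the learning-rate values are illustrative assumptions rather than part of the PR; `AdamP` and `REXScheduler` are the library classes used in the tests.

```python
import torch
from torch import nn

from pytorch_optimizer import AdamP, REXScheduler  # REXScheduler is exported in PATCH 4/7

# Toy setup; model, optimizer choice, and hyperparameters are placeholders.
model = nn.Linear(10, 1)
optimizer = AdamP(model.parameters(), lr=1e-3)

# Per the implementation in PATCH 3/7, the scheduler writes the computed value straight
# into each param group, so max_lr/min_lr are absolute learning rates, not multipliers.
total_steps = 100
lr_scheduler = REXScheduler(optimizer, total_steps=total_steps, max_lr=1e-3, min_lr=1e-6)

for _ in range(total_steps):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()  # dummy objective
    loss.backward()
    optimizer.step()

    # As in PATCH 7/7's test: step() advances the schedule, get_lr() reads the current value.
    lr_scheduler.step()
    current_lr = lr_scheduler.get_lr()
```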