Merge pull request #222 from kozistr/feature/rex-lr-scheduler

[Feature] Implement REX lr scheduler

kozistr authored Mar 2, 2024
2 parents fd717fc + a3b1bfb commit da65344
Showing 8 changed files with 105 additions and 7 deletions.
11 changes: 6 additions & 5 deletions README.md
@@ -10,7 +10,7 @@

**pytorch-optimizer** is a collection of optimizers and LR schedulers for PyTorch.
The algorithms are re-implemented from the original papers, with speed & memory tweaks and plug-ins. It also includes useful and practical optimization ideas.
Currently, **62 optimizers (+ `bitsandbytes`)**, **10 lr schedulers**, and **13 loss functions** are supported!
Currently, **62 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!

Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -171,10 +171,11 @@ from pytorch_optimizer import get_supported_lr_schedulers
supported_lr_schedulers = get_supported_lr_schedulers()
```

| LR Scheduler | Description | Official Code | Paper | Citation |
|-----------------|---------------------------------------------------------------------------------|---------------|------------------------------------|------------------------------------------------------------------------------|
| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule* | | <https://arxiv.org/abs/2003.03977> | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation) |
| Chebyshev | *Acceleration via Fractal Learning Rate Schedules* | | <https://arxiv.org/abs/2103.01338> | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation) |
| LR Scheduler | Description | Official Code | Paper | Citation |
|-----------------|---------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------|------------------------------------|------------------------------------------------------------------------------|
| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule* | | <https://arxiv.org/abs/2003.03977> | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation) |
| Chebyshev | *Acceleration via Fractal Learning Rate Schedules* | | <https://arxiv.org/abs/2103.01338> | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation) |
| REX | *Revisiting Budgeted Training with an Improved Schedule* | [github](https://github.com/Nerogar/OneTrainer/blob/2c6f34ea0838e5a86774a1cf75093d7e97c70f03/modules/util/lr_scheduler_util.py#L66) | <https://arxiv.org/abs/2107.04197> | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210704197C/exportcitation) |
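
For context, a minimal usage sketch of the new `REXScheduler` (the model, base optimizer, and step count below are illustrative placeholders, not taken from this PR):

```python
import torch

from pytorch_optimizer import REXScheduler

model = torch.nn.Linear(10, 1)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# REXScheduler drives the learning rate itself, annealing it from max_lr down to min_lr
# over total_steps calls to step().
scheduler = REXScheduler(optimizer, total_steps=100, max_lr=1e-3, min_lr=1e-6)

for _ in range(100):
    # ... forward / backward pass ...
    optimizer.step()
    scheduler.step()
```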

## Supported Loss Function

2 changes: 2 additions & 0 deletions docs/changelogs/v3.0.0.md
@@ -4,6 +4,8 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)

### Feature

* Implement `REX` lr scheduler. (#217, #222)
* [Revisiting Budgeted Training with an Improved Schedule](https://arxiv.org/abs/2107.04197)
* Implement `Aida` optimizer. (#220, #221)
* [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273)
* Implement `WSAM` optimizer. (#213, #216)
4 changes: 4 additions & 0 deletions docs/lr_scheduler.md
@@ -25,3 +25,7 @@
::: pytorch_optimizer.ProportionScheduler
    :docstring:
    :members:

::: pytorch_optimizer.REXScheduler
    :docstring:
    :members:
2 changes: 2 additions & 0 deletions pytorch_optimizer/__init__.py
@@ -26,6 +26,7 @@
from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
from pytorch_optimizer.lr_scheduler.rex import REXScheduler
from pytorch_optimizer.optimizer.a2grad import A2Grad
from pytorch_optimizer.optimizer.adabelief import AdaBelief
from pytorch_optimizer.optimizer.adabound import AdaBound
@@ -195,6 +196,7 @@
    PolyScheduler,
    LinearScheduler,
    ProportionScheduler,
    REXScheduler,
]
LR_SCHEDULERS: Dict[str, SCHEDULER] = {
    str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
1 change: 0 additions & 1 deletion pytorch_optimizer/base/scheduler.py
@@ -76,7 +76,6 @@ def step(self):

        self.step_t += 1

        # apply the lr to optimizer if it's provided
        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = value
66 changes: 66 additions & 0 deletions pytorch_optimizer/lr_scheduler/rex.py
@@ -0,0 +1,66 @@
from typing import List

from torch.optim.lr_scheduler import _LRScheduler

from pytorch_optimizer.base.types import OPTIMIZER


class REXScheduler(_LRScheduler):
    r"""Revisiting Budgeted Training with an Improved Schedule.

    :param optimizer: Optimizer. wrapped optimizer instance.
    :param total_steps: int. number of steps to optimize.
    :param max_lr: float. max lr.
    :param min_lr: float. min lr.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER,
        total_steps: int,
        max_lr: float = 1.0,
        min_lr: float = 0.0,
    ):
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr

        self.step_t: int = 0
        self.base_lrs: List[float] = []

        # record the current value in self.last_lr to mirror the torch.optim.lr_scheduler API
        self.last_lr: List[float] = [self.max_lr]

        super().__init__(optimizer)

        self.init_lr()

    def init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self) -> float:
        return self.last_lr[0]

    def get_linear_lr(self) -> float:
        # REX schedule: anneal from max_lr to min_lr along (1 - p) / (1 - p / 2), p = step_t / total_steps
        if self.step_t >= self.total_steps:
            return self.min_lr

        progress: float = self.step_t / self.total_steps

        return self.min_lr + (self.max_lr - self.min_lr) * ((1.0 - progress) / (1.0 - progress / 2.0))

    def step(self):
        value: float = self.get_linear_lr()

        self.step_t += 1

        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = value

        self.last_lr = [value]

        return value
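
A note on the schedule itself (not part of the diff): `get_linear_lr` implements the REX curve `lr(t) = min_lr + (max_lr - min_lr) * (1 - p) / (1 - p / 2)` with `p = step_t / total_steps`, and since `torch.optim.lr_scheduler._LRScheduler.__init__` already calls `step()` once during construction, the first explicit `step()` evaluates the curve at `step_t = 1`. A standalone sketch reproducing the values asserted in the unit test further down, assuming `total_steps = 5`:

```python
# Standalone sketch of the REX curve for T = 5, max_lr = 1.0, min_lr = 0.0
# (mirrors the expectations in test_rex_lr_scheduler below).
T, max_lr, min_lr = 5, 1.0, 0.0


def rex_lr(t: int) -> float:
    """Learning rate after the t-th scheduler step."""
    if t >= T:
        return min_lr
    p = t / T
    return min_lr + (max_lr - min_lr) * (1.0 - p) / (1.0 - p / 2.0)


print([round(rex_lr(t), 6) for t in range(1, T + 1)])
# [0.888889, 0.75, 0.571429, 0.333333, 0.0]
```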
2 changes: 1 addition & 1 deletion tests/test_load_modules.py
@@ -42,7 +42,7 @@ def test_get_supported_optimizers():


def test_get_supported_lr_schedulers():
    assert len(get_supported_lr_schedulers()) == 10
    assert len(get_supported_lr_schedulers()) == 11


def test_get_supported_loss_functions():
24 changes: 24 additions & 0 deletions tests/test_lr_schedulers.py
@@ -10,6 +10,7 @@
from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
from pytorch_optimizer.lr_scheduler.rex import REXScheduler
from tests.utils import Example

CAWR_RECIPES = [
@@ -263,6 +264,29 @@ def test_proportion_no_last_lr_scheduler():
    np.testing.assert_almost_equal(2.0, rho_scheduler.get_lr(), 6)


def test_rex_lr_scheduler():
    lrs = [
        0.888888,
        0.749999,
        0.571428,
        0.333333,
        0.0,
    ]

    base_optimizer = AdamP(Example().parameters())

    lr_scheduler = REXScheduler(
        base_optimizer,
        total_steps=5,
        max_lr=1.0,
        min_lr=0.0,
    )

    for expected_lr in lrs:
        _ = lr_scheduler.step()
        np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_lr(), 6)


def test_deberta_v3_large_lr_scheduler():
    model = nn.Sequential(*[nn.Linear(1, 1, bias=False) for _ in range(400)])
    deberta_v3_large_lr_scheduler(model)
