-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #222 from kozistr/feature/rex-lr-scheduler
[Feature] Implement REX lr scheduler
- Loading branch information
Showing
8 changed files
with
105 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from typing import List | ||
|
||
from torch.optim.lr_scheduler import _LRScheduler | ||
|
||
from pytorch_optimizer.base.types import OPTIMIZER | ||
|
||
|
||
class REXScheduler(_LRScheduler): | ||
r"""Revisiting Budgeted Training with an Improved Schedule. | ||
:param optimizer: Optimizer. wrapped optimizer instance. | ||
:param total_steps: int. number of steps to optimize. | ||
:param max_lr: float. max lr. | ||
:param min_lr: float. min lr. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
optimizer: OPTIMIZER, | ||
total_steps: int, | ||
max_lr: float = 1.0, | ||
min_lr: float = 0.0, | ||
): | ||
self.total_steps = total_steps | ||
self.max_lr = max_lr | ||
self.min_lr = min_lr | ||
|
||
self.step_t: int = 0 | ||
self.base_lrs: List[float] = [] | ||
|
||
# record current value in self._last_lr to match API from torch.optim.lr_scheduler | ||
self.last_lr: List[float] = [self.max_lr] | ||
|
||
super().__init__(optimizer) | ||
|
||
self.init_lr() | ||
|
||
def init_lr(self): | ||
self.base_lrs = [] | ||
for param_group in self.optimizer.param_groups: | ||
param_group['lr'] = self.min_lr | ||
self.base_lrs.append(self.min_lr) | ||
|
||
def get_lr(self) -> float: | ||
return self.last_lr[0] | ||
|
||
def get_linear_lr(self) -> float: | ||
if self.step_t >= self.total_steps: | ||
return self.min_lr | ||
|
||
progress: float = self.step_t / self.total_steps | ||
|
||
return self.min_lr + (self.max_lr - self.min_lr) * ((1.0 - progress) / (1.0 - progress / 2.0)) | ||
|
||
def step(self): | ||
value: float = self.get_linear_lr() | ||
|
||
self.step_t += 1 | ||
|
||
if self.optimizer is not None: | ||
for param_group in self.optimizer.param_groups: | ||
param_group['lr'] = value | ||
|
||
self.last_lr = [value] | ||
|
||
return value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters