-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #94 from kozistr/refactor/lr_scheduler
[Feature] Implement GSAM optimizer
- Loading branch information
Showing
42 changed files
with
720 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
Implemented LR Schedulers | ||
========================= | ||
LR Schedulers | ||
============= | ||
|
||
.. _get_chebyshev_schedule: | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
[tool.poetry] | ||
name = "pytorch_optimizer" | ||
version = "2.1.1" | ||
description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas." | ||
version = "2.2.0" | ||
description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas." | ||
license = "Apache-2.0" | ||
authors = ["kozistr <[email protected]>"] | ||
maintainers = ["kozistr <[email protected]>"] | ||
|
@@ -51,6 +51,11 @@ name = "torch" | |
url = "https://download.pytorch.org/whl/cpu" | ||
secondary = true | ||
|
||
[tool.coverage.run] | ||
omit = [ | ||
"./pytorch_optimizer/optimizer/gsam.py", | ||
] | ||
|
||
[build-system] | ||
requires = ["poetry-core>=1.0.0"] | ||
build-backend = "poetry.core.masonry.api" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError | ||
from pytorch_optimizer.base.types import OPTIMIZER | ||
|
||
|
||
class BaseLinearWarmupScheduler(ABC):
    r"""BaseLinearWarmupScheduler class. The LR Scheduler class based on this class has linear warmup strategy.

    :param optimizer: Optimizer. OPTIMIZER. It will set learning rate to all trainable parameters in optimizer.
    :param t_max: int. total steps to train.
    :param max_lr: float. maximum lr.
    :param min_lr: float. minimum lr.
    :param init_lr: float. initial lr.
    :param warmup_steps: int. steps to warm-up.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER,
        t_max: int,
        max_lr: float,
        min_lr: float = 0.0,
        init_lr: float = 0.0,
        warmup_steps: int = 0,
    ):
        self.optimizer = optimizer
        self.total_steps = t_max
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps

        self.step_t: int = 0
        self.base_lrs: List[float] = []

        # record current value in self.last_lr to match the API of torch.optim.lr_scheduler
        self.last_lr: List[float] = [init_lr]

        self.validate_parameters()

        self._init_lr()

    def validate_parameters(self):
        """Reject any negative learning rate or step count up front."""
        for lr_value, lr_name in ((self.min_lr, 'min_lr'), (self.max_lr, 'max_lr'), (self.init_lr, 'init_lr')):
            if lr_value < 0:
                raise NegativeLRError(lr_value, lr_name)

        for step_value, step_name in ((self.total_steps, 't_max'), (self.warmup_steps, 'warmup_steps')):
            if step_value < 0:
                raise NegativeStepError(step_value, step_name)

    def _init_lr(self):
        """Seed every optimizer param group (and base_lrs) with the minimum lr."""
        self.base_lrs = [self.min_lr for _ in self.optimizer.param_groups]
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr

    def step(self):
        """Advance one step: linear warm-up first, then delegate to the subclass ``_step``."""
        if self.step_t < self.warmup_steps:
            # interpolate linearly from init_lr toward max_lr during warm-up
            value = self.init_lr + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps
        elif self.step_t == self.warmup_steps:
            value = self.max_lr
        else:
            value = self._step()

        self.step_t += 1

        # apply the lr to optimizer if it's provided
        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = value

        self.last_lr = [value]

        return value

    @abstractmethod
    def _step(self) -> float:
        """Return the lr for the current post-warm-up step; implemented by subclasses."""
        raise NotImplementedError

    def get_lr(self) -> float:
        """Return the most recently applied learning rate."""
        return self.last_lr[0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import math | ||
|
||
import numpy as np | ||
|
||
from pytorch_optimizer.base.scheduler import BaseLinearWarmupScheduler | ||
|
||
|
||
class LinearScheduler(BaseLinearWarmupScheduler):
    r"""Linearly anneal the lr from ``max_lr`` down to ``min_lr`` after warm-up."""

    def _step(self) -> float:
        elapsed = self.step_t - self.warmup_steps
        decay_span = self.total_steps - self.warmup_steps
        return self.max_lr + (self.min_lr - self.max_lr) * elapsed / decay_span
|
||
|
||
class CosineScheduler(BaseLinearWarmupScheduler):
    r"""Cosine-anneal the lr from ``max_lr`` down to ``min_lr`` after warm-up."""

    def _step(self) -> float:
        # map post-warm-up progress onto [0, pi] so cos sweeps from 1 to -1
        phase: float = (self.step_t - self.warmup_steps) / (self.total_steps - self.warmup_steps) * math.pi
        # math.cos instead of np.cos: the argument is a plain scalar, so this avoids
        # the numpy round-trip and actually returns the annotated float type
        return self.min_lr + (self.max_lr - self.min_lr) * (math.cos(phase) + 1.0) / 2.0
|
||
|
||
class PolyScheduler(BaseLinearWarmupScheduler):
    r"""Poly LR Scheduler.

    :param poly_order: float. exponent applied to the number of post-warm-up steps. must be positive.
    :raises ValueError: when ``poly_order`` is not positive.
    """

    def __init__(self, poly_order: float = 0.5, **kwargs):
        self.poly_order = poly_order

        if poly_order <= 0:
            raise ValueError(f'[-] poly_order must be positive. {poly_order}')

        super().__init__(**kwargs)

    def _step(self) -> float:
        # NOTE(review): the polynomial term is not normalized by the remaining steps,
        # so with step counts > 1 the value grows past max_lr rather than decaying
        # toward min_lr — confirm this is the intended schedule shape
        return self.min_lr + (self.max_lr - self.min_lr) * (self.step_t - self.warmup_steps) ** self.poly_order
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from typing import List | ||
|
||
|
||
class ProportionScheduler:
    r"""ProportionScheduler (Rho Scheduler of GSAM).

    As the wrapped ``lr_scheduler`` moves the learning rate between ``min_lr`` and ``max_lr``,
    this scheduler maps it linearly onto ``[min_value, max_value]``.

    :param lr_scheduler: learning rate scheduler.
    :param max_lr: float. maximum lr.
    :param min_lr: float. minimum lr.
    :param max_value: float. maximum of rho.
    :param min_value: float. minimum of rho.
    """

    def __init__(
        self, lr_scheduler, max_lr: float, min_lr: float = 0.0, max_value: float = 2.0, min_value: float = 2.0
    ):
        self.lr_scheduler = lr_scheduler
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_value = max_value
        self.min_value = min_value

        self.step_t: int = 0
        self.last_lr: List[float] = []

        # prime last_lr so get_lr() is valid immediately after construction
        self.step()

    def get_lr(self) -> float:
        """Return the most recently computed rho value."""
        return self.last_lr[0]

    def step(self) -> float:
        """Recompute rho from the wrapped scheduler's current learning rate."""
        self.step_t += 1

        # prefer the scheduler's own record of the last lr; otherwise read it
        # from the first param group of its optimizer
        lr = (
            self.lr_scheduler.last_lr[0]
            if hasattr(self.lr_scheduler, 'last_lr')
            else self.lr_scheduler.optimizer.param_groups[0]['lr']
        )

        if self.max_lr > self.min_lr:
            lr_fraction = (lr - self.min_lr) / (self.max_lr - self.min_lr)
            value = self.min_value + (self.max_value - self.min_value) * lr_fraction
        else:
            # degenerate lr range: fall back to the maximum rho
            value = self.max_value

        self.last_lr = [value]

        return value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.