diff --git a/README.rst b/README.rst index 313060d8d..8398757ca 100644 --- a/README.rst +++ b/README.rst @@ -112,6 +112,8 @@ Supported Optimizers +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+ | Adai | *Disentangling the Effects of Adaptive Learning Rate and Momentum* | `github `__ | `https://arxiv.org/abs/2006.15815 `__ | +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+ +| GSAM | *Surrogate Gap Guided Sharpness-Aware Minimization* | `github `__ | `https://openreview.net/pdf?id=edONMAnhLu- `__ | ++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+ Useful Resources ---------------- @@ -303,6 +305,8 @@ Citations `Adai `__ +`GSAM `__ + Citation -------- diff --git a/docs/optimizer_api.rst b/docs/optimizer_api.rst index 61e03fe4a..800b93021 100644 --- a/docs/optimizer_api.rst +++ b/docs/optimizer_api.rst @@ -1,5 +1,5 @@ -Implemented Optimizers -==================== +Optimizers +========== .. _AdaBelief: @@ -192,3 +192,11 @@ Shampoo .. autoclass:: pytorch_optimizer.Shampoo :members: + +.. _GSAM: + +GSAM +---- + +.. autoclass:: pytorch_optimizer.GSAM + :members: diff --git a/docs/scheduler_api.rst b/docs/scheduler_api.rst index da997cdd7..ce08fa9be 100644 --- a/docs/scheduler_api.rst +++ b/docs/scheduler_api.rst @@ -1,5 +1,5 @@ -Implemented LR Schedulers -========================= +LR Schedulers +============= .. _get_chebyshev_schedule: diff --git a/docs/util_api.rst b/docs/util_api.rst index 37a0d620a..86c84e2e8 100644 --- a/docs/util_api.rst +++ b/docs/util_api.rst @@ -1,5 +1,5 @@ -Implemented utilizations -======================== +Utilizations +============ .. _clip_grad_norm: @@ -56,3 +56,20 @@ SafeFP16Optimizer .. autoclass:: pytorch_optimizer.SafeFP16Optimizer :members: + +.. _enable_running_stats: + +enable_running_stats +-------------------- + +.. autoclass:: pytorch_optimizer.enable_running_stats + :members: + + +.. _disable_running_stats: + +disable_running_stats +--------------------- + +.. autoclass:: pytorch_optimizer.disable_running_stats + :members: diff --git a/pyproject.toml b/pyproject.toml index ee7f04914..984793f55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "pytorch_optimizer" -version = "2.1.1" -description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas." +version = "2.2.0" +description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas." 
license = "Apache-2.0" authors = ["kozistr "] maintainers = ["kozistr "] @@ -51,6 +51,11 @@ name = "torch" url = "https://download.pytorch.org/whl/cpu" secondary = true +[tool.coverage.run] +omit = [ + "./pytorch_optimizer/optimizer/gsam.py", +] + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py index cec1f76ba..bbf07a704 100644 --- a/pytorch_optimizer/__init__.py +++ b/pytorch_optimizer/__init__.py @@ -11,6 +11,8 @@ ) from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_schedule from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts +from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler +from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler from pytorch_optimizer.optimizer.adabelief import AdaBelief from pytorch_optimizer.optimizer.adabound import AdaBound from pytorch_optimizer.optimizer.adai import Adai @@ -22,6 +24,7 @@ from pytorch_optimizer.optimizer.diffrgrad import DiffRGrad from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer from pytorch_optimizer.optimizer.gc import centralize_gradient +from pytorch_optimizer.optimizer.gsam import GSAM from pytorch_optimizer.optimizer.lamb import Lamb from pytorch_optimizer.optimizer.lars import LARS from pytorch_optimizer.optimizer.lookahead import Lookahead @@ -38,6 +41,8 @@ from pytorch_optimizer.optimizer.shampoo import Shampoo from pytorch_optimizer.optimizer.utils import ( clip_grad_norm, + disable_running_stats, + enable_running_stats, get_optimizer_parameters, matrix_power, normalize_gradient, @@ -74,6 +79,10 @@ CosineAnnealingWarmRestarts, CyclicLR, OneCycleLR, + CosineScheduler, + PolyScheduler, + LinearScheduler, + ProportionScheduler, ] LR_SCHEDULERS: Dict[str, SCHEDULER] = { str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST diff --git a/pytorch_optimizer/base/exception.py b/pytorch_optimizer/base/exception.py index 0233f7162..f81c9b0f1 100644 --- a/pytorch_optimizer/base/exception.py +++ b/pytorch_optimizer/base/exception.py @@ -25,3 +25,21 @@ class NoClosureError(Exception): def __init__(self, optimizer_name: str): self.message: str = f'[-] {optimizer_name} requires closure.' super().__init__(self.message) + + +class NegativeLRError(Exception): + """Raised when learning rate is negative""" + + def __init__(self, lr: float, lr_type: str = ''): + self.note: str = 'learning rate' if lr_type == '' else lr_type + self.message: str = f'[-] {self.note} must be positive. ({lr} > 0)' + super().__init__(self.message) + + +class NegativeStepError(Exception): + """Raised when step is negative""" + + def __init__(self, num_steps: int, step_type: str = ''): + self.note: str = 'step' if step_type == '' else step_type + self.message: str = f'[-] {self.note} must be positive. 
({num_steps} > 0)' + super().__init__(self.message) diff --git a/pytorch_optimizer/base/base_optimizer.py b/pytorch_optimizer/base/optimizer.py similarity index 96% rename from pytorch_optimizer/base/base_optimizer.py rename to pytorch_optimizer/base/optimizer.py index aa1418702..f7fc2f42c 100644 --- a/pytorch_optimizer/base/base_optimizer.py +++ b/pytorch_optimizer/base/optimizer.py @@ -2,6 +2,7 @@ import torch +from pytorch_optimizer.base.exception import NegativeLRError from pytorch_optimizer.base.types import BETAS @@ -9,7 +10,7 @@ class BaseOptimizer(ABC): @staticmethod def validate_learning_rate(learning_rate: float): if learning_rate < 0.0: - raise ValueError(f'[-] learning rate {learning_rate} must be positive') + raise NegativeLRError(learning_rate) @staticmethod def validate_beta(beta: float): diff --git a/pytorch_optimizer/base/scheduler.py b/pytorch_optimizer/base/scheduler.py new file mode 100644 index 000000000..cced06909 --- /dev/null +++ b/pytorch_optimizer/base/scheduler.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractmethod +from typing import List + +from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError +from pytorch_optimizer.base.types import OPTIMIZER + + +class BaseLinearWarmupScheduler(ABC): + r"""BaseLinearWarmupScheduler class. The LR Scheduler class based on this class has linear warmup strategy. + + :param optimizer: Optimizer. OPTIMIZER. It will set learning rate to all trainable parameters in optimizer. + :param t_max: int. total steps to train. + :param max_lr: float. maximum lr. + :param min_lr: float. minimum lr. + :param init_lr: float. initial lr. + :param warmup_steps: int. steps to warm-up. + """ + + def __init__( + self, + optimizer: OPTIMIZER, + t_max: int, + max_lr: float, + min_lr: float = 0.0, + init_lr: float = 0.0, + warmup_steps: int = 0, + ): + self.optimizer = optimizer + self.total_steps = t_max + self.max_lr = max_lr + self.min_lr = min_lr + self.init_lr = init_lr + self.warmup_steps = warmup_steps + + self.step_t: int = 0 + self.base_lrs: List[float] = [] + + # record current value in self._last_lr to match API from torch.optim.lr_scheduler + self.last_lr: List[float] = [init_lr] + + self.validate_parameters() + + self._init_lr() + + def validate_parameters(self): + if self.min_lr < 0: + raise NegativeLRError(self.min_lr, 'min_lr') + + if self.max_lr < 0: + raise NegativeLRError(self.max_lr, 'max_lr') + + if self.init_lr < 0: + raise NegativeLRError(self.init_lr, 'init_lr') + + if self.total_steps < 0: + raise NegativeStepError(self.total_steps, 't_max') + + if self.warmup_steps < 0: + raise NegativeStepError(self.warmup_steps, 'warmup_steps') + + def _init_lr(self): + self.base_lrs = [] + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.min_lr + self.base_lrs.append(self.min_lr) + + def step(self): + if self.step_t < self.warmup_steps: + value = self.init_lr + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps + elif self.step_t == self.warmup_steps: + value = self.max_lr + else: + value = self._step() + + self.step_t += 1 + + # apply the lr to optimizer if it's provided + if self.optimizer is not None: + for param_group in self.optimizer.param_groups: + param_group['lr'] = value + + self.last_lr = [value] + + return value + + @abstractmethod + def _step(self) -> float: + raise NotImplementedError + + def get_lr(self) -> float: + return self.last_lr[0] diff --git a/pytorch_optimizer/lr_scheduler/linear_warmup.py b/pytorch_optimizer/lr_scheduler/linear_warmup.py new file mode 
100644 index 000000000..a35ebb74e --- /dev/null +++ b/pytorch_optimizer/lr_scheduler/linear_warmup.py @@ -0,0 +1,36 @@ +import math + +import numpy as np + +from pytorch_optimizer.base.scheduler import BaseLinearWarmupScheduler + + +class LinearScheduler(BaseLinearWarmupScheduler): + def _step(self) -> float: + return self.max_lr + (self.min_lr - self.max_lr) * (self.step_t - self.warmup_steps) / ( + self.total_steps - self.warmup_steps + ) + + +class CosineScheduler(BaseLinearWarmupScheduler): + def _step(self) -> float: + phase: float = (self.step_t - self.warmup_steps) / (self.total_steps - self.warmup_steps) * math.pi + return self.min_lr + (self.max_lr - self.min_lr) * (np.cos(phase) + 1.0) / 2.0 + + +class PolyScheduler(BaseLinearWarmupScheduler): + r"""Poly LR Scheduler + + :param: poly_order: float. lr scheduler decreases with steps. + """ + + def __init__(self, poly_order: float = 0.5, **kwargs): + self.poly_order = poly_order + + if poly_order <= 0: + raise ValueError(f'[-] poly_order must be positive. {poly_order}') + + super().__init__(**kwargs) + + def _step(self) -> float: + return self.min_lr + (self.max_lr - self.min_lr) * (self.step_t - self.warmup_steps) ** self.poly_order diff --git a/pytorch_optimizer/lr_scheduler/proportion.py b/pytorch_optimizer/lr_scheduler/proportion.py new file mode 100644 index 000000000..3ceeee2d2 --- /dev/null +++ b/pytorch_optimizer/lr_scheduler/proportion.py @@ -0,0 +1,49 @@ +from typing import List + + +class ProportionScheduler: + r"""ProportionScheduler (Rho Scheduler of GSAM) + This scheduler outputs a value that evolves proportional to lr_scheduler. + + :param lr_scheduler: learning rate scheduler. + :param max_lr: float. maximum lr. + :param min_lr: float. minimum lr. + :param max_value: float. maximum of rho. + :param min_value: float. minimum of rho. 
+ """ + + def __init__( + self, lr_scheduler, max_lr: float, min_lr: float = 0.0, max_value: float = 2.0, min_value: float = 2.0 + ): + self.lr_scheduler = lr_scheduler + self.max_lr = max_lr + self.min_lr = min_lr + self.max_value = max_value + self.min_value = min_value + + self.step_t: int = 0 + self.last_lr: List[float] = [] + + self.step() + + def get_lr(self) -> float: + return self.last_lr[0] + + def step(self) -> float: + self.step_t += 1 + + if hasattr(self.lr_scheduler, 'last_lr'): + lr = self.lr_scheduler.last_lr[0] + else: + lr = self.lr_scheduler.optimizer.param_groups[0]['lr'] + + if self.max_lr > self.min_lr: + value = self.min_value + (self.max_value - self.min_value) * (lr - self.min_lr) / ( + self.max_lr - self.min_lr + ) + else: + value = self.max_value + + self.last_lr = [value] + + return value diff --git a/pytorch_optimizer/optimizer/adabelief.py b/pytorch_optimizer/optimizer/adabelief.py index 0b4b57991..b6ce274c4 100644 --- a/pytorch_optimizer/optimizer/adabelief.py +++ b/pytorch_optimizer/optimizer/adabelief.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/adabound.py b/pytorch_optimizer/optimizer/adabound.py index f8eb8f1ec..4bc3ec941 100644 --- a/pytorch_optimizer/optimizer/adabound.py +++ b/pytorch_optimizer/optimizer/adabound.py @@ -4,8 +4,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/adai.py b/pytorch_optimizer/optimizer/adai.py index 113c35763..d98b76ed7 100644 --- a/pytorch_optimizer/optimizer/adai.py +++ b/pytorch_optimizer/optimizer/adai.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.gc import centralize_gradient diff --git a/pytorch_optimizer/optimizer/adamp.py b/pytorch_optimizer/optimizer/adamp.py index 351864803..0d56966f1 100644 --- a/pytorch_optimizer/optimizer/adamp.py +++ b/pytorch_optimizer/optimizer/adamp.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.gc import centralize_gradient from pytorch_optimizer.optimizer.utils import projection diff --git a/pytorch_optimizer/optimizer/adan.py b/pytorch_optimizer/optimizer/adan.py index dc8dcee50..5acc6412d 100644 --- a/pytorch_optimizer/optimizer/adan.py +++ b/pytorch_optimizer/optimizer/adan.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer 
import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.gc import centralize_gradient diff --git a/pytorch_optimizer/optimizer/adapnm.py b/pytorch_optimizer/optimizer/adapnm.py index 06dddeb0e..0ca1173b5 100644 --- a/pytorch_optimizer/optimizer/adapnm.py +++ b/pytorch_optimizer/optimizer/adapnm.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/diffgrad.py b/pytorch_optimizer/optimizer/diffgrad.py index 96633e664..1c9dbdbcc 100644 --- a/pytorch_optimizer/optimizer/diffgrad.py +++ b/pytorch_optimizer/optimizer/diffgrad.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/diffrgrad.py b/pytorch_optimizer/optimizer/diffrgrad.py index 9ea06d28b..8f9d282b6 100644 --- a/pytorch_optimizer/optimizer/diffrgrad.py +++ b/pytorch_optimizer/optimizer/diffrgrad.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/gsam.py b/pytorch_optimizer/optimizer/gsam.py new file mode 100644 index 000000000..4960f9c61 --- /dev/null +++ b/pytorch_optimizer/optimizer/gsam.py @@ -0,0 +1,228 @@ +from contextlib import ExitStack +from typing import Callable, Dict, Optional, Tuple + +import torch +from torch import nn +from torch.distributed import ReduceOp, all_reduce, get_world_size, is_initialized +from torch.optim.optimizer import Optimizer + +from pytorch_optimizer.base.optimizer import BaseOptimizer +from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS +from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats + + +class GSAM(Optimizer, BaseOptimizer): + r"""Surrogate Gap Guided Sharpness-Aware Minimization + + Example: + Here's an example:: + + model = YourModel() + base_optimizer = AdamP(model.parameters()) + lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps) + rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr) + optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler) + + def loss_fn(predictions, targets): + return F.cross_entropy(predictions, targets) + + for inputs, targets in data: + optimizer.set_closure(loss_fn, inputs, targets) + predictions, loss = optimizer.step() + lr_scheduler.step() + optimizer.update_rho_t() + + :param params: PARAMETERS. 
iterable of parameters to optimize or dicts defining parameter groups + :param base_optimizer: Optimizer. base optimizer. + :param model: nn.Module. model. + :param alpha: float. rho alpha. + :param rho_scheduler: rho scheduler. + :param adaptive: bool. element-wise Adaptive SAM. + :param perturb_eps: float. epsilon for perturbation. + :param kwargs: Dict. parameters for optimizer. + """ + + def __init__( + self, + params: PARAMETERS, + base_optimizer: OPTIMIZER, + model: nn.Module, + rho_scheduler, + alpha: float = 0.4, + adaptive: bool = False, + perturb_eps: float = 1e-12, + **kwargs, + ): + self.model = model + self.rho_scheduler = rho_scheduler + self.alpha = alpha + self.adaptive = adaptive + self.perturb_eps = perturb_eps + + self.rho_t: float = 0.0 + self.forward_backward_func: Optional[Callable] = None + + if hasattr(ReduceOp, 'AVG'): + self.grad_reduce = ReduceOp.AVG + self.manual_average: bool = False + else: # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes + self.grad_reduce = ReduceOp.SUM + self.manual_average: bool = True + + self.base_optimizer = base_optimizer + self.param_groups = self.base_optimizer.param_groups + + self.validate_parameters() + + defaults: DEFAULTS = dict(adaptive=adaptive, **kwargs) + super().__init__(params, defaults) + + self.update_rho_t() + + def validate_parameters(self): + self.validate_alpha(self.alpha) + + @property + def __name__(self) -> str: + return 'GSAM' + + @torch.no_grad() + def reset(self): + pass + + @torch.no_grad() + def update_rho_t(self) -> float: + self.rho_t = self.rho_scheduler.step() + return self.rho_t + + @torch.no_grad() + def perturb_weights(self, rho: float): + grad_norm = self.grad_norm(weight_adaptive=self.adaptive) + for group in self.param_groups: + scale = rho / (grad_norm + self.perturb_eps) + + for p in group['params']: + if p.grad is None: + continue + + self.state[p]['old_g'] = p.grad.clone() + + e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p) + p.add_(e_w) # climb to the local maximum "w + e(w)" + + self.state[p]['e_w'] = e_w + + @torch.no_grad() + def un_perturb(self): + for group in self.param_groups: + for p in group['params']: + if 'e_w' in self.state[p].keys(): + p.sub_(self.state[p]['e_w']) + + @torch.no_grad() + def gradient_decompose(self, alpha: float = 0.0): + inner_prod = 0.0 + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + inner_prod += torch.sum(self.state[p]['old_g'] * p.grad) + + new_grad_norm = self.grad_norm(by=None) + old_grad_norm = self.grad_norm(by='old_g') + + cosine = inner_prod / (new_grad_norm * old_grad_norm + self.perturb_eps) + + # gradient decomposition + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad.data / ( + new_grad_norm + self.perturb_eps + ) + p.grad.add_(vertical, alpha=-alpha) + + @torch.no_grad() + def sync_grad(self): + if is_initialized(): + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + all_reduce(p.grad, op=self.grad_reduce) + if self.manual_average: + p.grad.div_(float(get_world_size())) + + @torch.no_grad() + def grad_norm(self, by: Optional[str] = None, weight_adaptive: bool = False) -> torch.Tensor: + return torch.norm( + torch.stack( + [ + ((torch.abs(p) if weight_adaptive else 1.0) * (p.grad if not by else self.state[p][by])).norm(p=2) + for group in self.param_groups + for p in group['params'] + if 
p.grad is not None + ] + ), + p=2, + ) + + def maybe_no_sync(self): + return self.model.no_sync() if is_initialized() else ExitStack() + + @torch.no_grad() + def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs): + r"""set closure + create self.forward_backward_func, which is a function such that self.forward_backward_func() automatically + performs forward and backward passes. This function does not take any arguments, and the inputs and + targets data should be pre-set in the definition of partial-function. + + :param loss_fn: nn.Module. loss function. + :param inputs: torch.Tensor. inputs. + :param targets: torch.Tensor. targets. + """ + + def get_grad(): + self.base_optimizer.zero_grad() + with torch.enable_grad(): + outputs = self.model(inputs) + loss = loss_fn(outputs, targets, **kwargs) + + loss.backward() + + return outputs, loss.detach() + + self.forward_backward_func = get_grad + + @torch.no_grad() + def step(self, closure: CLOSURE = None) -> Tuple[torch.Tensor, float]: + get_grad = closure if closure else self.forward_backward_func + + with self.maybe_no_sync(): + outputs, loss = get_grad() + + self.perturb_weights(rho=self.rho_t) + + disable_running_stats(self.model) + + get_grad() + + self.gradient_decompose(self.alpha) + + self.un_perturb() + + self.sync_grad() + + self.base_optimizer.step() + + enable_running_stats(self.model) + + return outputs, loss + + def load_state_dict(self, state_dict: Dict): + super().load_state_dict(state_dict) + self.base_optimizer.param_groups = self.param_groups diff --git a/pytorch_optimizer/optimizer/lamb.py b/pytorch_optimizer/optimizer/lamb.py index 827aeb669..2b0cf6974 100644 --- a/pytorch_optimizer/optimizer/lamb.py +++ b/pytorch_optimizer/optimizer/lamb.py @@ -3,8 +3,8 @@ import torch from torch.optim import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/lars.py b/pytorch_optimizer/optimizer/lars.py index bdbdd6048..30ed2faa9 100644 --- a/pytorch_optimizer/optimizer/lars.py +++ b/pytorch_optimizer/optimizer/lars.py @@ -1,8 +1,8 @@ import torch from torch.optim import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/lookahead.py b/pytorch_optimizer/optimizer/lookahead.py index da4f4f576..726db020b 100644 --- a/pytorch_optimizer/optimizer/lookahead.py +++ b/pytorch_optimizer/optimizer/lookahead.py @@ -4,7 +4,7 @@ import torch from torch.optim import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER, STATE diff --git a/pytorch_optimizer/optimizer/madgrad.py b/pytorch_optimizer/optimizer/madgrad.py index 76bed3b84..70b36d346 100644 --- a/pytorch_optimizer/optimizer/madgrad.py +++ b/pytorch_optimizer/optimizer/madgrad.py @@ -8,8 +8,8 @@ import torch from torch.optim import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import 
NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/nero.py b/pytorch_optimizer/optimizer/nero.py index 1ac172ba2..748dfc965 100644 --- a/pytorch_optimizer/optimizer/nero.py +++ b/pytorch_optimizer/optimizer/nero.py @@ -1,8 +1,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.utils import neuron_mean, neuron_norm diff --git a/pytorch_optimizer/optimizer/pcgrad.py b/pytorch_optimizer/optimizer/pcgrad.py index 80f921bf4..a961f4c82 100644 --- a/pytorch_optimizer/optimizer/pcgrad.py +++ b/pytorch_optimizer/optimizer/pcgrad.py @@ -5,7 +5,7 @@ import torch from torch import nn -from pytorch_optimizer.base.base_optimizer import BaseOptimizer +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import OPTIMIZER from pytorch_optimizer.optimizer.utils import flatten_grad, un_flatten_grad diff --git a/pytorch_optimizer/optimizer/pnm.py b/pytorch_optimizer/optimizer/pnm.py index 92f9d47b0..4deb87d8c 100644 --- a/pytorch_optimizer/optimizer/pnm.py +++ b/pytorch_optimizer/optimizer/pnm.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/radam.py b/pytorch_optimizer/optimizer/radam.py index 0a8c492d2..aa80904b1 100644 --- a/pytorch_optimizer/optimizer/radam.py +++ b/pytorch_optimizer/optimizer/radam.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/ralamb.py b/pytorch_optimizer/optimizer/ralamb.py index 2db9fac6a..17277ff9e 100644 --- a/pytorch_optimizer/optimizer/ralamb.py +++ b/pytorch_optimizer/optimizer/ralamb.py @@ -3,8 +3,8 @@ import torch from torch.optim import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS diff --git a/pytorch_optimizer/optimizer/ranger.py b/pytorch_optimizer/optimizer/ranger.py index 66d465af1..ecd2b1bb4 100644 --- a/pytorch_optimizer/optimizer/ranger.py +++ b/pytorch_optimizer/optimizer/ranger.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.gc 
import centralize_gradient diff --git a/pytorch_optimizer/optimizer/ranger21.py b/pytorch_optimizer/optimizer/ranger21.py index 53a703e43..bcfe39cc0 100644 --- a/pytorch_optimizer/optimizer/ranger21.py +++ b/pytorch_optimizer/optimizer/ranger21.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from torch.optim import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.agc import agc from pytorch_optimizer.optimizer.gc import centralize_gradient diff --git a/pytorch_optimizer/optimizer/sam.py b/pytorch_optimizer/optimizer/sam.py index f2bb0a755..7ed23a4bb 100644 --- a/pytorch_optimizer/optimizer/sam.py +++ b/pytorch_optimizer/optimizer/sam.py @@ -3,8 +3,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoClosureError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, OPTIMIZER, PARAMETERS diff --git a/pytorch_optimizer/optimizer/sgdp.py b/pytorch_optimizer/optimizer/sgdp.py index 3906088af..9fa8a7e62 100644 --- a/pytorch_optimizer/optimizer/sgdp.py +++ b/pytorch_optimizer/optimizer/sgdp.py @@ -1,8 +1,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.utils import projection diff --git a/pytorch_optimizer/optimizer/shampoo.py b/pytorch_optimizer/optimizer/shampoo.py index 1a0467c85..701285289 100644 --- a/pytorch_optimizer/optimizer/shampoo.py +++ b/pytorch_optimizer/optimizer/shampoo.py @@ -1,8 +1,8 @@ import torch from torch.optim.optimizer import Optimizer -from pytorch_optimizer.base.base_optimizer import BaseOptimizer from pytorch_optimizer.base.exception import NoSparseGradientError +from pytorch_optimizer.base.optimizer import BaseOptimizer from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS from pytorch_optimizer.optimizer.utils import matrix_power diff --git a/pytorch_optimizer/optimizer/utils.py b/pytorch_optimizer/optimizer/utils.py index e8880b875..150158820 100644 --- a/pytorch_optimizer/optimizer/utils.py +++ b/pytorch_optimizer/optimizer/utils.py @@ -6,6 +6,7 @@ from torch import nn from torch.distributed import all_reduce from torch.nn import functional as F +from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.utils import clip_grad_norm_ from pytorch_optimizer.base.types import PARAMETERS @@ -194,3 +195,24 @@ def neuron_mean(x: torch.Tensor) -> torch.Tensor: x = x.view(x.shape[0], -1) return x.mean(dim=1).view(*view_shape) + + +def disable_running_stats(model): + r"""disable running stats (momentum) of BatchNorm""" + + def _disable(module): + if isinstance(module, _BatchNorm): + module.backup_momentum = module.momentum + module.momentum = 0 + + model.apply(_disable) + + +def enable_running_stats(model): + r"""enable running stats (momentum) of BatchNorm""" + + def _enable(module): + if isinstance(module, _BatchNorm) and hasattr(module, 'backup_momentum'): + 
module.momentum = module.backup_momentum + + model.apply(_enable) diff --git a/tests/constants.py b/tests/constants.py index dc8b8a3c1..c7f9c873b 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -74,6 +74,7 @@ INVALID_OPTIMIZER_NAMES: List[str] = [ 'asam', 'sam', + 'gsam', 'pcgrad', 'adamd', 'lookahead', diff --git a/tests/test_load_lr_schedulers.py b/tests/test_load_lr_schedulers.py index fa3530e2e..9e35e284e 100644 --- a/tests/test_load_lr_schedulers.py +++ b/tests/test_load_lr_schedulers.py @@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_lr_scheduler_names): def test_get_supported_lr_schedulers(): - assert len(get_supported_lr_schedulers()) == 6 + assert len(get_supported_lr_schedulers()) == 10 diff --git a/tests/test_lr_scheduler_parameters.py b/tests/test_lr_scheduler_parameters.py index 5b0e67d77..c2c66dbb1 100644 --- a/tests/test_lr_scheduler_parameters.py +++ b/tests/test_lr_scheduler_parameters.py @@ -2,13 +2,14 @@ import pytest from pytorch_optimizer import AdamP, get_chebyshev_schedule +from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts +from pytorch_optimizer.lr_scheduler.linear_warmup import PolyScheduler from tests.utils import Example def test_cosine_annealing_warmup_restarts_params(): - model = Example() - optimizer = AdamP(model.parameters()) + optimizer = AdamP(Example().parameters()) with pytest.raises(ValueError): CosineAnnealingWarmupRestarts( @@ -33,6 +34,28 @@ def test_cosine_annealing_warmup_restarts_params(): lr_scheduler.step(epoch=None) +def test_linear_warmup_lr_scheduler_params(): + optimizer = AdamP(Example().parameters()) + + with pytest.raises(ValueError): + PolyScheduler(poly_order=-1, optimizer=optimizer, t_max=1, max_lr=1) + + with pytest.raises(NegativeLRError): + PolyScheduler(optimizer=optimizer, t_max=1, max_lr=-1) + + with pytest.raises(NegativeLRError): + PolyScheduler(optimizer=optimizer, t_max=1, max_lr=1, min_lr=-1) + + with pytest.raises(NegativeLRError): + PolyScheduler(optimizer=optimizer, t_max=1, max_lr=1, min_lr=1, init_lr=-1) + + with pytest.raises(NegativeStepError): + PolyScheduler(optimizer=optimizer, t_max=-1, max_lr=1, min_lr=1, init_lr=1) + + with pytest.raises(NegativeStepError): + PolyScheduler(optimizer=optimizer, t_max=1, max_lr=1, min_lr=1, init_lr=1, warmup_steps=-1) + + def test_chebyshev_params(): with pytest.raises(IndexError): get_chebyshev_schedule(2) diff --git a/tests/test_lr_schedulers.py b/tests/test_lr_schedulers.py index b69c9bb94..92add24fa 100644 --- a/tests/test_lr_schedulers.py +++ b/tests/test_lr_schedulers.py @@ -1,3 +1,5 @@ +from typing import Tuple + import numpy as np import pytest from torch import nn @@ -6,6 +8,8 @@ from pytorch_optimizer.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler from pytorch_optimizer.lr_scheduler.chebyshev import chebyshev_perm from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts +from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler +from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler from tests.utils import Example CAWR_RECIPES = [ @@ -72,6 +76,43 @@ ], ), ] +LWL_RECIPE = [ + 0.001, + 0.0028, + 0.0046, + 0.0064, + 0.0082, + 0.01, + 0.00802, + 0.00604, + 0.00406, + 0.00208, +] +LWC_RECIPE = [ + 0.001, + 0.00280, + 0.00460, + 0.00640, + 0.00820, + 0.01000, + 0.00905, + 0.00658, + 0.00352, + 0.00105, +] 
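For reference, the ``LWL_RECIPE`` and ``LWC_RECIPE`` values above follow directly from the warmup and decay formulas of ``LinearScheduler`` and ``CosineScheduler`` added in this patch. A minimal sketch that regenerates them (the ``nn.Linear`` stand-in model is hypothetical; the scheduler arguments mirror the tests further down)::

    from torch import nn

    from pytorch_optimizer import AdamP, CosineScheduler, LinearScheduler

    model = nn.Linear(1, 1)  # hypothetical stand-in for tests.utils.Example

    for scheduler_cls in (LinearScheduler, CosineScheduler):
        optimizer = AdamP(model.parameters())
        lr_scheduler = scheduler_cls(
            optimizer, t_max=10, max_lr=1e-2, min_lr=1e-4, init_lr=1e-3, warmup_steps=5
        )

        # step() both applies the new lr to the optimizer and returns it
        lrs = [round(lr_scheduler.step(), 5) for _ in range(10)]
        print(scheduler_cls.__name__, lrs)

    # both schedulers ramp linearly from init_lr (1e-3) up to max_lr (1e-2) during warmup;
    # LinearScheduler then decays linearly while CosineScheduler follows a half-cosine down
    # toward min_lr, matching the recipe lists above to roughly 5 decimal places.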
+LWP_RECIPE = [ + 0.001, + 0.002800, + 0.004600, + 0.006400, + 0.008200, + 0.010000, + 0.010000, + 0.014101, + 0.017247, + 0.019900, +] +PROPORTION_LEARNING_RATES = [(1e-1, 1e-1, 2.0), (1e-1, 1e-3, 1.090909)] @pytest.mark.parametrize('cosine_annealing_warmup_restart_param', CAWR_RECIPES) @@ -110,11 +151,78 @@ def test_cosine_annealing_warmup_restarts(cosine_annealing_warmup_restart_param) np.testing.assert_almost_equal(expected_lrs[epoch], lr) -def test_get_chebyshev_schedule(): +def test_get_chebyshev_scheduler(): np.testing.assert_almost_equal(get_chebyshev_schedule(3), 1.81818182, decimal=6) np.testing.assert_array_equal(chebyshev_perm(5), np.asarray([0, 7, 3, 4, 1, 6, 2, 5])) +def test_linear_warmup_linear_scheduler(): + optimizer = AdamP(Example().parameters()) + lr_scheduler = LinearScheduler(optimizer, t_max=10, max_lr=1e-2, min_lr=1e-4, init_lr=1e-3, warmup_steps=5) + + for expected_lr in LWL_RECIPE: + lr_scheduler.step() + np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_lr()) + + +def test_linear_warmup_cosine_scheduler(): + optimizer = AdamP(Example().parameters()) + lr_scheduler = CosineScheduler(optimizer, t_max=10, max_lr=1e-2, min_lr=1e-4, init_lr=1e-3, warmup_steps=5) + + for expected_lr in LWC_RECIPE: + lr_scheduler.step() + np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_lr(), 5) + + +def test_linear_warmup_poly_scheduler(): + optimizer = AdamP(Example().parameters()) + lr_scheduler = PolyScheduler(optimizer=optimizer, t_max=10, max_lr=1e-2, min_lr=1e-4, init_lr=1e-3, warmup_steps=5) + + for expected_lr in LWP_RECIPE: + lr_scheduler.step() + np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_lr(), 6) + + +@pytest.mark.parametrize('proportion_learning_rate', PROPORTION_LEARNING_RATES) +def test_proportion_scheduler(proportion_learning_rate: Tuple[float, float, float]): + base_optimizer = AdamP(Example().parameters()) + lr_scheduler = CosineScheduler( + base_optimizer, t_max=10, max_lr=proportion_learning_rate[0], min_lr=proportion_learning_rate[1], init_lr=1e-2 + ) + rho_scheduler = ProportionScheduler( + lr_scheduler, + max_lr=proportion_learning_rate[0], + min_lr=proportion_learning_rate[1], + max_value=2.0, + min_value=1.0, + ) + + for _ in range(10): + _ = rho_scheduler.step() + np.testing.assert_almost_equal(proportion_learning_rate[2], rho_scheduler.get_lr(), 6) + + +def test_proportion_no_last_lr_scheduler(): + base_optimizer = AdamP(Example().parameters()) + lr_scheduler = CosineAnnealingWarmupRestarts( + base_optimizer, + first_cycle_steps=10, + max_lr=1e-2, + min_lr=1e-2, + ) + rho_scheduler = ProportionScheduler( + lr_scheduler, + max_lr=1e-2, + min_lr=1e-2, + max_value=2.0, + min_value=1.0, + ) + + for _ in range(10): + _ = rho_scheduler.step() + np.testing.assert_almost_equal(2.0, rho_scheduler.get_lr(), 6) + + def test_deberta_v3_large_lr_scheduler(): try: from transformers import AutoConfig, AutoModel diff --git a/tests/test_optimizer_parameters.py b/tests/test_optimizer_parameters.py index e02f11885..1f47bdcde 100644 --- a/tests/test_optimizer_parameters.py +++ b/tests/test_optimizer_parameters.py @@ -3,7 +3,7 @@ from torch import nn from pytorch_optimizer import SAM, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer -from pytorch_optimizer.base.exception import ZeroParameterSizeError +from pytorch_optimizer.base.exception import NegativeLRError, ZeroParameterSizeError from tests.constants import BETA_OPTIMIZER_NAMES, PULLBACK_MOMENTUM, VALID_OPTIMIZER_NAMES from tests.utils import Example, simple_parameter @@ 
-12,7 +12,7 @@ def test_learning_rate(optimizer_name): optimizer = load_optimizer(optimizer_name) - with pytest.raises(ValueError): + with pytest.raises(NegativeLRError): if optimizer_name == 'ranger21': optimizer(None, num_iterations=100, lr=-1e-2) else: diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index a19bd52f4..2b2f118f0 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -3,7 +3,16 @@ import torch from torch import nn -from pytorch_optimizer import SAM, Lookahead, PCGrad, SafeFP16Optimizer, load_optimizer +from pytorch_optimizer import ( + GSAM, + SAM, + CosineScheduler, + Lookahead, + PCGrad, + ProportionScheduler, + SafeFP16Optimizer, + load_optimizer, +) from pytorch_optimizer.base.exception import NoClosureError, ZeroParameterSizeError from tests.constants import ADAMD_SUPPORTED_OPTIMIZERS, ADAPTIVE_FLAGS, OPTIMIZERS, PULLBACK_MOMENTUM from tests.utils import ( @@ -167,6 +176,41 @@ def closure(): assert tensor_to_numpy(init_loss) > 2.0 * tensor_to_numpy(loss) +@pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS) +def test_gsam_optimizers(adaptive): + if not torch.cuda.is_available(): + pytest.skip(f'there\'s no cuda. skip test.') + + (x_data, y_data), model, loss_fn = build_environment() + + x_data = x_data.cuda() + y_data = y_data.cuda() + model.cuda() + + lr: float = 5e-1 + num_iterations: int = 50 + + base_optimizer = load_optimizer('adamp')(model.parameters(), lr=lr) + lr_scheduler = CosineScheduler(base_optimizer, t_max=num_iterations, max_lr=lr, min_lr=lr, init_lr=lr) + rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=lr, min_lr=lr) + optimizer = GSAM( + model.parameters(), base_optimizer=base_optimizer, model=model, rho_scheduler=rho_scheduler, adaptive=adaptive + ) + + init_loss, loss = np.inf, np.inf + for _ in range(num_iterations): + optimizer.set_closure(loss_fn, x_data, y_data) + _, loss = optimizer.step() + + if init_loss == np.inf: + init_loss = loss + + lr_scheduler.step() + optimizer.update_rho_t() + + assert tensor_to_numpy(init_loss) > tensor_to_numpy(loss) + + @pytest.mark.parametrize('optimizer_adamd_config', ADAMD_SUPPORTED_OPTIMIZERS, ids=ids) def test_adamd_optimizers(optimizer_adamd_config): (x_data, y_data), model, loss_fn = build_environment() diff --git a/tests/test_utils.py b/tests/test_utils.py index 15d783ac8..08af6e847 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,6 +7,8 @@ from pytorch_optimizer.optimizer.utils import ( clip_grad_norm, + disable_running_stats, + enable_running_stats, get_optimizer_parameters, has_overflow, is_valid_parameters, @@ -98,3 +100,19 @@ def test_is_valid_parameters(): after_parameters = get_optimizer_parameters(model, weight_decay=1e-3, wd_ban_list=wd_ban_list) assert is_valid_parameters(after_parameters) + + +def test_running_stats(): + model = nn.Sequential( + nn.Linear(1, 1), + nn.BatchNorm2d(1), + ) + model[1].momentum = 0.1 + + disable_running_stats(model) + + assert (model[1].momentum == 0) and (model[1].backup_momentum == 0.1) + + enable_running_stats(model) + + assert model[1].momentum == 0.1
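Putting the new pieces together, an end-to-end GSAM training loop looks roughly as follows. This is a minimal sketch, not part of the patch: the model, batch, and hyper-parameter values are hypothetical, while the ``GSAM``, ``CosineScheduler``, and ``ProportionScheduler`` calls mirror the GSAM docstring and ``test_gsam_optimizers`` above::

    import torch
    from torch import nn

    from pytorch_optimizer import GSAM, AdamP, CosineScheduler, ProportionScheduler

    model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))  # hypothetical model
    loss_fn = nn.MSELoss()

    num_steps: int = 100
    max_lr: float = 1e-2

    base_optimizer = AdamP(model.parameters(), lr=max_lr)
    lr_scheduler = CosineScheduler(
        base_optimizer, t_max=num_steps, max_lr=max_lr, min_lr=1e-4, init_lr=1e-3, warmup_steps=10
    )
    # rho follows the lr proportionally, shrinking the SAM neighborhood as the lr decays
    rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr, min_lr=1e-4, max_value=2.0, min_value=0.5)
    optimizer = GSAM(model.parameters(), base_optimizer=base_optimizer, model=model, rho_scheduler=rho_scheduler)

    for _ in range(num_steps):
        inputs, targets = torch.randn(16, 2), torch.randn(16, 1)  # hypothetical batch
        optimizer.set_closure(loss_fn, inputs, targets)

        # step() perturbs the weights, re-computes gradients (with BatchNorm running stats
        # disabled), decomposes the gradient, restores the weights, and then calls
        # base_optimizer.step()
        predictions, loss = optimizer.step()

        lr_scheduler.step()
        optimizer.update_rho_t()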