From b02d98932f95fe0500c28698b38acb175e92e980 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Fri, 20 Oct 2023 03:20:22 +0100 Subject: [PATCH] Lion Optimizer (#1062) * initial commit * test set, fixed readme and docstring * Refactor Lion implementation --------- Co-authored-by: kamathis4 --- configs/neox_arguments.md | 4 +- megatron/neox_arguments/neox_args.py | 2 +- megatron/optimizers.py | 84 ++++++++++++++++++++++++- megatron/training.py | 8 +++ tests/model/test_model_instantiation.py | 1 + tests/model/test_model_train.py | 1 + 6 files changed, 96 insertions(+), 4 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index badc95e46..1f3511456 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -601,11 +601,11 @@ Optimizer Arguments -- **optimizer_type**: typing.Literal['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd'] +- **optimizer_type**: typing.Literal['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd', 'lion'] Default = adam - Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd'] + Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd', 'lion'] NOTE: sgd will use MuSGD from Mup. Mup must be enabled for this optimizer. diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index e427b2551..e1a58b6d9 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -387,7 +387,7 @@ class NeoXArgsOptimizer(NeoXArgsTemplate): """ optimizer_type: Literal[ - "adam", "onebitadam", "cpu_adam", "cpu_torch_adam", "sm3", "madgrad_wd", "sgd" + "adam", "onebitadam", "cpu_adam", "cpu_torch_adam", "sm3", "madgrad_wd", "sgd", "lion" ] = "adam" """ Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd'] diff --git a/megatron/optimizers.py b/megatron/optimizers.py index 8dc1d3264..fcf8a44c7 100644 --- a/megatron/optimizers.py +++ b/megatron/optimizers.py @@ -227,7 +227,7 @@ def _max_reduce_except_dim(tensor, dim): # closure is checked if callable or not since some code passes loss directly, rather than in closure param import math -from typing import Collection, TYPE_CHECKING, Any, Callable, Optional +from typing import Collection, TYPE_CHECKING, Any, Callable, Optional, Tuple import torch import torch.optim @@ -413,3 +413,85 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] self.state["k"] += 1 return loss + + +class Lion(Optimizer): + """ + Implements the Lion Algorithm + + .. / _Lion: https://arxiv.org/abs/2302.06675 + + Compared to AdamW and various adaptive optimizers that need to save both first and second moments, + Lion only needs the momentum, halving the additional memory footprint. This is beneficial when training large models + and / or with a large batch size. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate (default: 1e-2). + beta (float): + coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99)) + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + + """ + + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + ): + if lr <= 0: + raise ValueError(f"Learning rate {lr} must be positive") + if weight_decay < 0: + raise ValueError(f"Weight decay {weight_decay} must be non-negative") + if not (0 <= betas[0] <= 1 and 0 <= betas[1] <= 1): + raise ValueError(f"Betas {betas} must be in range [0, 1)") + + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + def update(self, p, grad, exp_avg, lr, wd, beta1, beta2): + """https://arxiv.org/pdf/2302.06675.pdf#appendix.A""" + + # update model parameters + p.mul_(1 - lr * wd) + sign = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_() + p.add_(sign, alpha=-lr) + + # update EMA + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + @torch.no_grad() + def step(self, closure: Optional[Callable] = None): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + state = self.state[p] + + # init state - exponential moving average of gradient values + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p.data).detach() + + self.update( + p, + p.grad, + state["exp_avg"], + group["lr"], + group["weight_decay"], + group["betas"][0], + group["betas"][1], + ) + + return loss diff --git a/megatron/training.py b/megatron/training.py index 548f81cb0..ed9c0bcd0 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -526,6 +526,14 @@ def get_optimizer(model, neox_args): weight_decay=neox_args.weight_decay, **neox_args.optimizer["params"], ) + elif neox_args.optimizer_type.lower() == "lion": + from .optimizers import Lion + + optimizer = Lion( + param_groups, + weight_decay=neox_args.weight_decay, + **neox_args.optimizer["params"] + ) elif neox_args.optimizer_type.lower() == "adam": # Use Adam if neox_args.use_mup: diff --git a/tests/model/test_model_instantiation.py b/tests/model/test_model_instantiation.py index 60412ad9a..37654c34c 100644 --- a/tests/model/test_model_instantiation.py +++ b/tests/model/test_model_instantiation.py @@ -85,6 +85,7 @@ def wrapper(): {"type": "cpu_adam", "params": {"lr": 0.0006}}, {"type": "cpu_torch_adam", "params": {"lr": 0.0006}}, {"type": "sm3", "params": {"lr": 0.0006}}, + {"type": "lion", "params": {"lr": 0.0006}}, {"type": "madgrad_wd", "params": {"lr": 0.0006}}, ] } diff --git a/tests/model/test_model_train.py b/tests/model/test_model_train.py index b7dda1efd..be5d8fccc 100644 --- a/tests/model/test_model_train.py +++ b/tests/model/test_model_train.py @@ -119,6 +119,7 @@ def wrapper(): {"type": "cpu_adam", "params": {"lr": 0.0006}}, {"type": "cpu_torch_adam", "params": {"lr": 0.0006}}, {"type": "sm3", "params": {"lr": 0.0006}}, + {"type": "lion", "params": {"lr": 0.0006}}, {"type": "madgrad_wd", "params": {"lr": 0.0006}}, ] }