diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst index fdee8656f..44c9bac8d 100644 --- a/docs/api/contrib.rst +++ b/docs/api/contrib.rst @@ -34,6 +34,8 @@ Experimental features and algorithms that don't meet the schedule_free_eval_params schedule_free_sgd ScheduleFreeState + sophia + SophiaState split_real_and_imaginary SplitRealAndImaginaryState @@ -99,3 +101,10 @@ Sharpness aware minimization ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: sam .. autoclass:: SAMState + +Sophia +~~~~~~ +.. autofunction:: hutchinson_estimator_diag_hessian +.. autoclass:: HutchinsonState +.. autofunction:: sophia +.. autoclass:: SophiaState diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index a310cc23b..8d559092b 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -51,3 +51,7 @@ from optax.contrib._schedule_free import schedule_free_eval_params from optax.contrib._schedule_free import schedule_free_sgd from optax.contrib._schedule_free import ScheduleFreeState +from optax.contrib._sophia import hutchinson_estimator_diag_hessian +from optax.contrib._sophia import HutchinsonState +from optax.contrib._sophia import sophia +from optax.contrib._sophia import SophiaState diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index b20118483..32954dd91 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -27,12 +27,14 @@ import jax.numpy as jnp from optax import contrib from optax._src import alias +from optax._src import base from optax._src import combine from optax._src import numerics from optax._src import update from optax.schedules import _inject from optax.transforms import _accumulation from optax.tree_utils import _state_utils +from optax.tree_utils import _tree_math # Testing contributions coded as GradientTransformations _MAIN_OPTIMIZERS_UNDER_TEST = [ @@ -53,6 +55,10 @@ opt_name='schedule_free_adamw', opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), ), + dict( + opt_name='sophia', + opt_kwargs=dict(learning_rate=1e-2), + ), ] for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST: optimizer['wrapper_name'] = None @@ -144,11 +150,10 @@ def _setup_parabola(dtype): initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) - @jax.value_and_grad - def get_updates(params): + def obj_fn(params): return jnp.sum(numerics.abs_sq(params - final_params)) - return initial_params, final_params, get_updates + return initial_params, final_params, obj_fn def _setup_rosenbrock(dtype): @@ -159,13 +164,12 @@ def _setup_rosenbrock(dtype): initial_params = jnp.array([0.0, 0.0], dtype=dtype) final_params = jnp.array([a, a**2], dtype=dtype) - @jax.value_and_grad - def get_updates(params): + def obj_fn(params): return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq( params[1] - params[0] ** 2 ) - return initial_params, final_params, get_updates + return initial_params, final_params, obj_fn class ContribTest(chex.TestCase): @@ -188,16 +192,18 @@ def test_optimizers( opt = _get_opt_factory(opt_name)(**opt_kwargs) if wrapper_name is not None: opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - initial_params, final_params, get_updates = target(dtype) + initial_params, final_params, obj_fn = target(dtype) @jax.jit def step(params, state): - value, updates = get_updates(params) + value, updates = jax.value_and_grad(obj_fn)(params) if ( opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau' ): update_kwargs = {'value': value} + elif opt_name == 'sophia': + 
update_kwargs = {'obj_fn': obj_fn} else: update_kwargs = {} updates, state = opt.update(updates, state, params, **update_kwargs) @@ -266,14 +272,21 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( update_kwargs = {'value': jnp.array(1.0)} else: update_kwargs = {} + if opt_name == 'sophia': + obj_fn = lambda x: _tree_math.tree_l2_norm(x, squared=True) + update_fn = functools.partial(opt.update, obj_fn=obj_fn) + inject_update_fn = functools.partial(opt_inject.update, obj_fn=obj_fn) + else: + update_fn = opt.update + inject_update_fn = opt_inject.update state = self.variant(opt.init)(params) - updates, new_state = self.variant(opt.update)( + updates, new_state = self.variant(update_fn)( grads, state, params, **update_kwargs ) state_inject = self.variant(opt_inject.init)(params) - updates_inject, new_state_inject = self.variant(opt_inject.update)( + updates_inject, new_state_inject = self.variant(inject_update_fn)( grads, state_inject, params, **update_kwargs ) @@ -320,7 +333,11 @@ def test_preserve_dtype( update_kwargs = {'value': value} else: update_kwargs = {} - updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + if opt_name == 'sophia': + update_fn = functools.partial(opt.update, obj_fn=fun) + else: + update_fn = opt.update + updates, _ = self.variant(update_fn)(grads, state, params, **update_kwargs) self.assertEqual(updates.dtype, params.dtype) @chex.variants( @@ -339,10 +356,16 @@ def test_gradient_accumulation( opt = _get_opt_factory(opt_name)(**opt_kwargs) if wrapper_name is not None: opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - opt = _accumulation.MultiSteps(opt, every_k_schedule=4) fun = lambda x: jnp.sum(x**2) + if opt_name == 'sophia': + update_fn = functools.partial(opt.update, obj_fn=fun) + else: + update_fn = opt.update + opt = base.GradientTransformationExtraArgs(opt.init, update_fn) + opt = _accumulation.MultiSteps(opt, every_k_schedule=4) + params = jnp.array([1.0, 2.0], dtype=dtype) value, grads = jax.value_and_grad(fun)(params) state = self.variant(opt.init)(params) diff --git a/optax/contrib/_sophia.py b/optax/contrib/_sophia.py new file mode 100644 index 000000000..07565ec9b --- /dev/null +++ b/optax/contrib/_sophia.py @@ -0,0 +1,313 @@ +# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sophia optimizer. + +A contributed implementation of the Sophia optimizer from "Sophia: A Scalable +Stochastic Second-order Optimizer for Language Model Pre-training" +(https://arxiv.org/abs/2305.14342) by Hong Liu, Zhiyuan Li, David Hall, +Percy Liang, and Tengyu Ma. + +This contribution is heavily based on the implementation of Sophia by levanter +(https://github.com/stanford-crfm/levanter) with some changes. 
+""" +from typing import Any, Callable, NamedTuple, Optional, Union + +import jax +import jax.numpy as jnp +from optax import tree_utils as otu +from optax._src import base +from optax._src import combine +from optax._src import numerics +from optax._src import transform +from optax._src import utils +from optax.transforms import _adding + + +class HutchinsonState(NamedTuple): + key: jax.Array + + +def hutchinson_estimator_diag_hessian(random_seed: Optional[jax.Array] = None): + """Returns a GradientTransformation that computes the diagonal of the Hessian. + + The Hessian diagonal is estimated using Hutchinson's estimator, which is + unbiased but has high variance. + + Args: + random_seed: key used to generate random vectors. + + Returns: + GradientTransformationExtraArgs + """ + + def init_fn(params): + del params + key = random_seed if random_seed is not None else jax.random.PRNGKey(0) + return HutchinsonState(key=key) + + def update_fn(updates, state, params=None, obj_fn=None): + if params is None: + raise ValueError("params must be provided to hutchinson update function.") + if obj_fn is None: + raise ValueError("obj_fn must be provided to hutchinson update function.") + del updates + key, subkey = jax.random.split(state.key) + random_signs = otu.tree_random_like( + subkey, + params, + jax.random.rademacher, + dtype=jnp.float32, + ) + random_signs = otu.tree_cast(random_signs, otu.tree_dtype(params, "lowest")) + hvp = jax.jvp(jax.grad(obj_fn), (params,), (random_signs,))[1] + product = jax.tree.map(lambda h, r: h * r, hvp, random_signs) + return product, HutchinsonState(key=key) + + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +class SophiaState(NamedTuple): + """State for Sophia Optimizer.""" + + count: jax.Array # shape=(), dtype=jnp.int32 + mu: base.Updates # momentum + nu: base.Updates # EMA of hessian diagonal + hessian_fn_state: Any + + +def scale_by_sophia( + b1: float = 0.965, + b2: float = 0.99, + eps: float = 1e-8, + gamma: float = 0.01, + clip_threshold: Optional[float] = 1.0, + update_interval: int = 10, + hessian_diagonal_fn: Union[ + base.GradientTransformation, + base.GradientTransformationExtraArgs, + ] = hutchinson_estimator_diag_hessian(), + mu_dtype: Optional[Any] = None, + verbose: bool = False, + print_win_rate_every_n_steps: int = 0, +) -> base.GradientTransformationExtraArgs: + """Sophia optimizer. + + See :func:`optax.contrib.sophia` for more details. + + Args: + b1: Exponential decay rate for the first moment estimates. + b2: Exponential decay rate for the hessian diagonal estimates. Keep in mind + effective `b2` is `1 - (1 - b2) / update_interval`, e.g. default `b2` of + 0.99 is effectively 0.999 because default `update_interval` is every 10. + eps: Small constant to avoid division by zero. + gamma: Normalizing constant for the hessian diagonal. + clip_threshold: Threshold for clipping updates. + update_interval: Interval for updating the hessian diagonal. + hessian_diagonal_fn: GradientTransformation that computes the diagonal of + the Hessian. Default is Hutchinson's estimator (sophia-h). If using more + than one device, be sure this function properly averages the hessian + diagonal across devices. + mu_dtype: dtype of the first moment estimates. + verbose: If True, print win rate every n steps. + print_win_rate_every_n_steps: Print sophia win rate every n steps for + diagnostic purposes. Authors state this value should stay between 0.1 and + 0.5 during training. If win rate is too low, try increasing `gamma`. 0 to + turn off. 
+ + Returns: + optax.GradientTransformationExtraArgs + """ + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + return SophiaState( + count=jnp.zeros([], jnp.int32), + mu=otu.tree_zeros_like(params, dtype=mu_dtype), + nu=otu.tree_zeros_like(params), + hessian_fn_state=hessian_diagonal_fn.init(params), + ) + + def update_fn(updates, state: SophiaState, params=None, **hess_fn_kwargs): + if params is None: + raise ValueError("params must be provided to sophia's update function.") + count_inc = numerics.safe_int32_increment(state.count) + + grads = updates + + # Sophia update + mu = otu.tree_update_moment(updates, state.mu, b1, 1) + mu_hat = otu.tree_bias_correction(mu, b1, count_inc) + updates = jax.tree.map( + lambda m, h: m / jnp.maximum(gamma * h, eps), mu_hat, state.nu + ) + if clip_threshold is not None: + sum_not_clipped = jax.tree.reduce( + lambda x, y: x + y, + jax.tree.map(lambda u: jnp.sum(jnp.abs(u) < clip_threshold), updates), + ) + total_tree_size = sum(x.size for x in jax.tree.leaves(updates)) + if verbose: + win_rate = sum_not_clipped / total_tree_size + jax.lax.cond( + count_inc % print_win_rate_every_n_steps == 0, + lambda: jax.debug.print("Sophia optimizer win rate: {}", win_rate), + lambda: None, + ) + + updates = jax.tree.map( + lambda u: jnp.clip(u, -clip_threshold, clip_threshold), updates + ) + + # Hessian diagonal update + def update_hessian_diag(hess_fn_state, nu): + hessian_diag, hess_fn_state = hessian_diagonal_fn.update( + grads, hess_fn_state, params=params, **hess_fn_kwargs + ) + + # ema of hessian diagonal + nu = otu.tree_update_moment(hessian_diag, nu, b2, 1) + + return hess_fn_state, nu + + hessian_fn_state, nu = jax.lax.cond( + jnp.equal(state.count % update_interval, 0), + update_hessian_diag, + lambda h, n: (h, n), + state.hessian_fn_state, + state.nu, + ) + + # Cast momentum back to mu_dtype + mu = otu.tree_cast(mu, mu_dtype) + + state = SophiaState( + count=count_inc, + mu=mu, + nu=nu, + hessian_fn_state=hessian_fn_state, + ) + return updates, state + + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +def sophia( + learning_rate: base.ScalarOrSchedule, + b1: float = 0.965, + b2: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 1e-4, + weight_decay_mask: Optional[ + Union[Any, Callable[[base.Params], Any]] + ] = None, + gamma: float = 0.01, + clip_threshold: Optional[float] = 1.0, + update_interval: int = 10, + hessian_diagonal_fn: Union[ + base.GradientTransformation, + base.GradientTransformationExtraArgs, + ] = hutchinson_estimator_diag_hessian(), + mu_dtype: Optional[Any] = None, + verbose: bool = False, + print_win_rate_every_n_steps: int = 0, +) -> base.GradientTransformationExtraArgs: + """Sophia optimizer. + + A separate GradientTransformation is required through the argument + `hessian_diagonal_fn` to compute the diagonal of the Hessian. Any extra + arguments required by the hessian_diagonal_fn's update function can be + passed through sophia's update function as trailing keyword arguments + (**kwargs). The default hessian_diagonal_fn is Hutchinson's estimator + and needs the objective function as an extra argument, `obj_fn`. + obj_fn must accept `params` as its only argument and return only a + scalar (the loss). 
+
+  For example, if your experiment's loss function is
+  `loss_fn(params, batch) -> (loss, aux)`, taking multiple arguments and
+  returning multiple outputs, wrap it so that it takes only `params` and
+  returns only the scalar loss:
+
+  `obj_fn = lambda params: loss_fn(params, batch)[0]`
+
+  where `batch` is the current step's batch.
+
+  Then pass it to the update function of the optimizer returned by `sophia`,
+  which forwards it to the hessian_diagonal_fn's update function:
+
+  `updates, state = opt.update(grads, state, params, obj_fn=obj_fn)`
+
+  where `opt = sophia(learning_rate)`.
+
+  Optionally, you can write your own GradientTransformation to compute the
+  hessian diagonal; use this module's hutchinson_estimator_diag_hessian
+  function as an example. If you are using more than one device, be sure the
+  hessian diagonal function properly averages the hessian diagonal across
+  devices. The default hessian_diagonal_fn does not do this, and params would
+  diverge across devices when using pmap, for example.
+
+  Args:
+    learning_rate: A global scaling factor, either fixed or evolving along
+      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
+    b1: Exponential decay rate for the first moment estimates.
+    b2: Exponential decay rate for the hessian diagonal estimates. Keep in mind
+      the effective `b2` is `1 - (1 - b2) / update_interval`, e.g. the default
+      `b2` of 0.99 is effectively 0.999 because the default `update_interval`
+      is 10.
+    eps: Small constant to avoid division by zero.
+    weight_decay: Rate at which to decay weights.
+    weight_decay_mask: A tree with the same structure as (or a prefix of) the
+      params pytree, or a Callable that returns such a pytree given the
+      params/updates. The leaves should be booleans, `True` for leaves/subtrees
+      you want to apply the transformation to, and `False` for those you want
+      to skip.
+    gamma: Normalizing constant for the hessian diagonal.
+    clip_threshold: Threshold for clipping updates.
+    update_interval: Interval for updating the hessian diagonal.
+    hessian_diagonal_fn: GradientTransformation that computes the diagonal of
+      the Hessian. Default is Hutchinson's estimator (sophia-h). If using more
+      than one device, be sure this function properly averages the hessian
+      diagonal across devices.
+    mu_dtype: dtype of the first moment estimates.
+    verbose: If True, print the win rate every
+      `print_win_rate_every_n_steps` steps.
+    print_win_rate_every_n_steps: Print the Sophia win rate (the fraction of
+      update entries below the clipping threshold) every n steps for
+      diagnostic purposes. The authors state the win rate should stay between
+      0.1 and 0.5 during training; if it is too low, try increasing `gamma`.
+      Set to 0 to turn off.
+
+  Returns:
+    optax.GradientTransformationExtraArgs
+
+  References:
+    Liu et al., `Sophia: A Scalable Stochastic Second-order Optimizer for
+    Language Model Pre-training <https://arxiv.org/abs/2305.14342>`_, 2023
+
+    `Levanter <https://github.com/stanford-crfm/levanter>`_
+
+  .. note::
+    We use a Rademacher vector to estimate the diagonal of the Hessian, unlike
+    the original implementation, which uses a normal random vector.
+  """
+  tx = [
+      scale_by_sophia(
+          b1=b1,
+          b2=b2,
+          eps=eps,
+          gamma=gamma,
+          clip_threshold=clip_threshold,
+          update_interval=update_interval,
+          hessian_diagonal_fn=hessian_diagonal_fn,
+          mu_dtype=mu_dtype,
+          verbose=verbose,
+          print_win_rate_every_n_steps=print_win_rate_every_n_steps,
+      ),
+      _adding.add_decayed_weights(weight_decay, mask=weight_decay_mask),
+      transform.scale_by_learning_rate(learning_rate),
+  ]
+  return combine.chain(*tx)
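
As a quick illustration of the API added above, here is a minimal usage sketch. The toy loss, `loss_fn`, `batch`, and the parameter values are placeholders, not part of this change; the point is how `obj_fn` is threaded through `opt.update` so the default Hutchinson estimator can take Hessian-vector products.

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
  # Hypothetical scalar objective; replace with your model's loss.
  return jnp.sum((params - batch) ** 2)

params = jnp.array([1.0, 2.0, 3.0])
batch = jnp.array([0.5, 1.5, 2.5])

opt = optax.contrib.sophia(learning_rate=1e-2)
state = opt.init(params)

# Sophia's update needs the objective as `obj_fn(params) -> scalar loss`
# so the Hutchinson estimator can form Hessian-vector products.
obj_fn = lambda p: loss_fn(p, batch)
grads = jax.grad(obj_fn)(params)
updates, state = opt.update(grads, state, params, obj_fn=obj_fn)
params = optax.apply_updates(params, updates)
```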
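The docstring above warns that the default `hutchinson_estimator_diag_hessian` does not average the Hessian diagonal across devices. A possible sketch of such a wrapper is below, assuming the training step runs under `jax.pmap` with an axis named `'devices'`; the wrapper name and axis name are illustrative, not part of this change.

```python
import jax
import optax
from optax import contrib

def cross_device_hutchinson(axis_name='devices', random_seed=None):
  # Wrap the contributed Hutchinson estimator and average its per-device
  # Hessian-diagonal estimates so replicas stay in sync under pmap.
  inner = contrib.hutchinson_estimator_diag_hessian(random_seed)

  def init_fn(params):
    return inner.init(params)

  def update_fn(updates, state, params=None, obj_fn=None):
    diag, state = inner.update(updates, state, params=params, obj_fn=obj_fn)
    diag = jax.tree.map(lambda d: jax.lax.pmean(d, axis_name), diag)
    return diag, state

  return optax.GradientTransformationExtraArgs(init_fn, update_fn)

# Note: opt.update must then be called inside jax.pmap(..., axis_name='devices'),
# since jax.lax.pmean requires a mapped axis with that name.
opt = contrib.sophia(
    learning_rate=1e-2, hessian_diagonal_fn=cross_device_hutchinson()
)
```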