Merge pull request #979 from evanatyourservice:sophia_h
PiperOrigin-RevId: 703157245
OptaxDev committed Dec 5, 2024
2 parents 3f0a64b + 341c6c2 commit 3d8c391
Showing 4 changed files with 361 additions and 12 deletions.
9 changes: 9 additions & 0 deletions docs/api/contrib.rst
@@ -34,6 +34,8 @@ Experimental features and algorithms that don't meet the
schedule_free_eval_params
schedule_free_sgd
ScheduleFreeState
sophia
SophiaState
split_real_and_imaginary
SplitRealAndImaginaryState

@@ -99,3 +101,10 @@ Sharpness aware minimization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: sam
.. autoclass:: SAMState

Sophia
~~~~~~
.. autofunction:: hutchinson_estimator_diag_hessian
.. autoclass:: HutchinsonState
.. autofunction:: sophia
.. autoclass:: SophiaState
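
A minimal usage sketch to accompany the new docs entries, inferred from the test changes in this commit: sophia is constructed with a learning rate, and its update call takes the objective as an extra obj_fn keyword so the optimizer can estimate the diagonal Hessian. The constructor defaults and any further extra arguments are assumptions here, not confirmed by this diff.

import jax
import jax.numpy as jnp
import optax
from optax import contrib

target = jnp.array([1.0, -1.0, 1.0])

def obj_fn(params):
  # Scalar objective; same parabola used in the tests below.
  return jnp.sum((params - target) ** 2)

params = jnp.array([-1.0, 10.0, 1.0])
opt = contrib.sophia(learning_rate=1e-2)
state = opt.init(params)

@jax.jit
def step(params, state):
  value, grads = jax.value_and_grad(obj_fn)(params)
  # Unlike plain first-order optimizers, sophia's update also receives the
  # objective function itself, passed as the obj_fn keyword.
  updates, state = opt.update(grads, state, params, obj_fn=obj_fn)
  return optax.apply_updates(params, updates), state, value

for _ in range(100):
  params, state, loss = step(params, state)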
4 changes: 4 additions & 0 deletions optax/contrib/__init__.py
@@ -51,3 +51,7 @@
from optax.contrib._schedule_free import schedule_free_eval_params
from optax.contrib._schedule_free import schedule_free_sgd
from optax.contrib._schedule_free import ScheduleFreeState
from optax.contrib._sophia import hutchinson_estimator_diag_hessian
from optax.contrib._sophia import HutchinsonState
from optax.contrib._sophia import sophia
from optax.contrib._sophia import SophiaState
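
For context on the new hutchinson_estimator_diag_hessian export, here is a self-contained sketch of the underlying technique, Hutchinson's estimator of the Hessian diagonal: for a Rademacher probe v, the elementwise product v * (H v) is an unbiased estimate of diag(H), and the Hessian-vector product can be formed with forward-over-reverse differentiation. This is illustrative only, not the library implementation; the function below is hypothetical.

import jax
import jax.numpy as jnp

def estimate_diag_hessian(obj_fn, params, key):
  # Rademacher probe v with entries in {-1, +1}, so E[v * (H @ v)] = diag(H).
  v = jax.random.rademacher(key, params.shape, dtype=params.dtype)
  # Hessian-vector product: jvp of the gradient function along v.
  _, hvp = jax.jvp(jax.grad(obj_fn), (params,), (v,))
  return v * hvp

# For f(x) = sum(x**2) the true Hessian diagonal is 2 everywhere, and a
# single Rademacher probe already recovers it exactly.
key = jax.random.PRNGKey(0)
print(estimate_diag_hessian(lambda x: jnp.sum(x**2), jnp.ones(3), key))

A single probe is noisy in general; averaging several probes, or keeping an exponential moving average across steps, reduces the variance.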
47 changes: 35 additions & 12 deletions optax/contrib/_common_test.py
@@ -27,12 +27,14 @@
import jax.numpy as jnp
from optax import contrib
from optax._src import alias
from optax._src import base
from optax._src import combine
from optax._src import numerics
from optax._src import update
from optax.schedules import _inject
from optax.transforms import _accumulation
from optax.tree_utils import _state_utils
from optax.tree_utils import _tree_math

# Testing contributions coded as GradientTransformations
_MAIN_OPTIMIZERS_UNDER_TEST = [
@@ -53,6 +55,10 @@
opt_name='schedule_free_adamw',
opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000),
),
dict(
opt_name='sophia',
opt_kwargs=dict(learning_rate=1e-2),
),
]
for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST:
optimizer['wrapper_name'] = None
@@ -144,11 +150,10 @@ def _setup_parabola(dtype):
initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)

@jax.value_and_grad
def get_updates(params):
def obj_fn(params):
return jnp.sum(numerics.abs_sq(params - final_params))

return initial_params, final_params, get_updates
return initial_params, final_params, obj_fn


def _setup_rosenbrock(dtype):
@@ -159,13 +164,12 @@ def _setup_rosenbrock(dtype):
initial_params = jnp.array([0.0, 0.0], dtype=dtype)
final_params = jnp.array([a, a**2], dtype=dtype)

@jax.value_and_grad
def get_updates(params):
def obj_fn(params):
return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq(
params[1] - params[0] ** 2
)

return initial_params, final_params, get_updates
return initial_params, final_params, obj_fn


class ContribTest(chex.TestCase):
@@ -188,16 +192,18 @@ def test_optimizers(
opt = _get_opt_factory(opt_name)(**opt_kwargs)
if wrapper_name is not None:
opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs)
initial_params, final_params, get_updates = target(dtype)
initial_params, final_params, obj_fn = target(dtype)

@jax.jit
def step(params, state):
value, updates = get_updates(params)
value, updates = jax.value_and_grad(obj_fn)(params)
if (
opt_name in ['momo', 'momo_adam']
or wrapper_name == 'reduce_on_plateau'
):
update_kwargs = {'value': value}
elif opt_name == 'sophia':
update_kwargs = {'obj_fn': obj_fn}
else:
update_kwargs = {}
updates, state = opt.update(updates, state, params, **update_kwargs)
@@ -266,14 +272,21 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
update_kwargs = {'value': jnp.array(1.0)}
else:
update_kwargs = {}
if opt_name == 'sophia':
obj_fn = lambda x: _tree_math.tree_l2_norm(x, squared=True)
update_fn = functools.partial(opt.update, obj_fn=obj_fn)
inject_update_fn = functools.partial(opt_inject.update, obj_fn=obj_fn)
else:
update_fn = opt.update
inject_update_fn = opt_inject.update

state = self.variant(opt.init)(params)
updates, new_state = self.variant(opt.update)(
updates, new_state = self.variant(update_fn)(
grads, state, params, **update_kwargs
)

state_inject = self.variant(opt_inject.init)(params)
updates_inject, new_state_inject = self.variant(opt_inject.update)(
updates_inject, new_state_inject = self.variant(inject_update_fn)(
grads, state_inject, params, **update_kwargs
)

@@ -320,7 +333,11 @@ def test_preserve_dtype(
update_kwargs = {'value': value}
else:
update_kwargs = {}
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
if opt_name == 'sophia':
update_fn = functools.partial(opt.update, obj_fn=fun)
else:
update_fn = opt.update
updates, _ = self.variant(update_fn)(grads, state, params, **update_kwargs)
self.assertEqual(updates.dtype, params.dtype)

@chex.variants(
@@ -339,10 +356,16 @@ def test_gradient_accumulation(
opt = _get_opt_factory(opt_name)(**opt_kwargs)
if wrapper_name is not None:
opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs)
opt = _accumulation.MultiSteps(opt, every_k_schedule=4)

fun = lambda x: jnp.sum(x**2)

if opt_name == 'sophia':
update_fn = functools.partial(opt.update, obj_fn=fun)
else:
update_fn = opt.update
opt = base.GradientTransformationExtraArgs(opt.init, update_fn)
opt = _accumulation.MultiSteps(opt, every_k_schedule=4)

params = jnp.array([1.0, 2.0], dtype=dtype)
value, grads = jax.value_and_grad(fun)(params)
state = self.variant(opt.init)(params)
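
Taken together, the wrapper-related test changes above suggest a pattern for composing sophia with transformations that expect the standard update signature, such as optax.MultiSteps for gradient accumulation: bind obj_fn up front with functools.partial and re-wrap the result as a GradientTransformationExtraArgs. A sketch, assuming the public optax names below behave like their counterparts used in the tests:

import functools
import jax
import jax.numpy as jnp
import optax
from optax import contrib

obj_fn = lambda x: jnp.sum(x ** 2)

opt = contrib.sophia(learning_rate=1e-2)
# Bind the objective so the wrapped optimizer no longer needs extra kwargs.
opt = optax.GradientTransformationExtraArgs(
    opt.init, functools.partial(opt.update, obj_fn=obj_fn)
)
opt = optax.MultiSteps(opt, every_k_schedule=4)

params = jnp.array([1.0, 2.0])
state = opt.init(params)
grads = jax.grad(obj_fn)(params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)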
