Skip to content

Commit

Permalink
Ensure optimizers return updates of same dtype as params.
Browse files Browse the repository at this point in the history
Fix #1038, fix #377, fix #1051

PiperOrigin-RevId: 674026550
  • Loading branch information
vroulet authored and OptaxDev committed Sep 16, 2024
1 parent ee63e45 commit d825e6b
Show file tree
Hide file tree
Showing 23 changed files with 791 additions and 479 deletions.
10 changes: 10 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ Tree
NamedTupleKey
tree_add
tree_add_scalar_mul
tree_cast
tree_div
tree_dtype
tree_get
tree_get_all_with_path
tree_l1_norm
Expand Down Expand Up @@ -121,6 +123,14 @@ Tree add and scalar multiply
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_add_scalar_mul

Tree cast
~~~~~~~~~
.. autofunction:: tree_cast

Tree dtype
~~~~~~~~~~
.. autofunction:: tree_dtype

Tree divide
~~~~~~~~~~~
.. autofunction:: tree_div
Expand Down
6 changes: 5 additions & 1 deletion docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ years, well-cited (100+ citations), and demonstrate broad utility.
if they offer clear advantages over widely used methods.

If your algorithm doesn't meet the main package criteria, the {doc}`api/contrib`
directory is perfect for sharing innovative work.
directory is perfect for sharing innovative work. Please make sure that all
common tests (in `optax/contrib/_common_test.py` or `optax/_src/alias_test.py`)
are passed when you propose a new algorithm. These tests ensure the
interoperability of algorithms with different features of optax (such as
gradient accumulation or varying hyperparameters).


## Design Documents
Expand Down
133 changes: 111 additions & 22 deletions optax/_src/alias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
from optax._src import update
from optax.losses import _classification
from optax.schedules import _inject
from optax.transforms import _accumulation
import optax.tree_utils as otu


import scipy.optimize as scipy_optimize
from sklearn import datasets
from sklearn import linear_model
Expand Down Expand Up @@ -164,13 +166,16 @@ def step(params, state):

params = initial_params
state = opt.init(params)
# A no-op change, to verify that tree map works.
state = otu.tree_map_params(opt, lambda v: v, state)

for _ in range(10000):
params, state = step(params, state)
with self.subTest('Test that tree_map_params works'):
# A no-op change, to verify that tree map works.
state = otu.tree_map_params(opt, lambda v: v, state)

with self.subTest('Test that optimization works'):
for _ in range(10000):
params, state = step(params, state)

chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)
chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)

@chex.all_variants
@parameterized.product(_OPTIMIZERS_UNDER_TEST)
Expand Down Expand Up @@ -211,24 +216,108 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
chex.assert_trees_all_close(
new_state_inject.inner_state, new_state, rtol=1e-4)

@parameterized.named_parameters([
('float32', 'float32'),
('bfloat16', 'bfloat16'),
('complex64', 'complex64'),
('None', None),
])
def test_explicit_dtype(self, dtype):
expected_dtype = jax.dtypes.canonicalize_dtype(dtype) # None -> float32
tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype)
trace_state, _ = tx.init(jnp.array([0.0, 0.0]))
self.assertEqual(expected_dtype, getattr(trace_state, 'trace').dtype)
tx = alias.adam(0.1, mu_dtype=dtype)
adam_state, _ = tx.init(jnp.array([0.0, 0.0]))
self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype)
tx = alias.adamw(0.1, mu_dtype=dtype)
adam_state, _, _ = tx.init(jnp.array([0.0, 0.0]))
self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype)
@parameterized.product(
params_dtype=('bfloat16', 'float32', 'complex64', None),
state_dtype=('bfloat16', 'float32', 'complex64', None),
opt_name=('sgd_mom', 'adam', 'adamw'),
)
def test_explicit_dtype(self, params_dtype, state_dtype, opt_name):
if opt_name == 'sgd_mom':
opt = alias.sgd(0.1, momentum=0.9, accumulator_dtype=state_dtype)
attribute_name = 'trace'
elif opt_name in ['adam', 'adamw']:
opt = getattr(alias, opt_name)(0.1, mu_dtype=state_dtype)
attribute_name = 'mu'
else:
raise ValueError(f'Unsupported optimizer: {opt_name}')

params_dtype = jax.dtypes.canonicalize_dtype(params_dtype)
params = jnp.array([0.0, 0.0], dtype=params_dtype)
state_has_lower_dtype = (
jnp.promote_types(params_dtype, jnp.dtype(state_dtype))
== params_dtype
)
if state_dtype is None or state_has_lower_dtype:
state = opt.init(params)

with self.subTest('Test that attribute dtype is correct'):
if state_dtype is None:
expected_dtype = params_dtype
else:
expected_dtype = jax.dtypes.canonicalize_dtype(state_dtype)
attribute = otu.tree_get(state, attribute_name)
self.assertEqual(expected_dtype, attribute.dtype)

with self.subTest(
'Verifies that the updates keep the same type as params'
):
updates, _ = opt.update(jnp.ones_like(params), state, params)
self.assertEqual(updates.dtype, params.dtype)
else:
with self.subTest(
'Test that we forbid setting dtype s.t. updates dtype get promoted to'
' the state dtype'
):
with self.assertRaises(ValueError):
opt.init(params)

# Not testing with `without_device=True` because without_device set the
# variables to the host which appears to convert then the dtype, so we
# lose control of the dtype and the test fails.
@chex.variants(
with_jit=True, without_jit=True, with_device=True, with_pmap=True
)
@parameterized.product(
_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32')
)
def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
"""Test that the optimizers return updates of same dtype as params."""
# When debugging this test, note that operations like
# x = 0.5**jnp.asarray(1, dtype=jnp.int32)
# (appearing in e.g. optax.tree_utils.tree_bias_correction)
# are promoted (strictly) to float32 when jitted
# see https://github.com/google/jax/issues/23337
# This may end up letting updates have a dtype different from params.
# The solution is to fix the dtype of the result to the desired dtype
# (just as done in optax.tree_utils.tree_bias_correction).
dtype = jnp.dtype(dtype)
opt_factory = getattr(alias, opt_name)
opt = opt_factory(**opt_kwargs)
fun = lambda x: jnp.sum(x**2)

params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
if opt_name == 'polyak_sgd':
update_kwargs = {'value': fun(params)}
else:
update_kwargs = {}
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
self.assertEqual(updates.dtype, params.dtype)

@chex.variants(
with_jit=True, without_jit=True, with_device=True, with_pmap=True
)
@parameterized.product(_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32'))
def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype):
"""Test that the optimizers can safely be used with optax.MultiSteps."""
# Checks for issues like https://github.com/google-deepmind/optax/issues/377
dtype = jnp.dtype(dtype)
opt_factory = getattr(alias, opt_name)
base_opt = opt_factory(**opt_kwargs)
opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4)

fun = lambda x: jnp.sum(x**2)

params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
if opt_name == 'polyak_sgd':
update_kwargs = {'value': fun(params)}
else:
update_kwargs = {}
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
chex.assert_trees_all_equal(updates, jnp.zeros_like(grads))

##########################
# ALGORITHM SPECIFIC TESTS
Expand Down
29 changes: 16 additions & 13 deletions optax/_src/factorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,23 +126,23 @@ def init_fn(params):
"""Initialise the optimiser's state."""

def _init(param):
shape = param.shape
shape, dtype = param.shape, param.dtype
factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor)
if factored_dims is not None:
d1, d0 = factored_dims
vr_shape = np.delete(shape, d0)
vc_shape = np.delete(shape, d1)
return _UpdateResult(
update=jnp.zeros((1,)),
v_row=jnp.zeros(vr_shape),
v_col=jnp.zeros(vc_shape),
v=jnp.zeros((1,)))
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros(vr_shape, dtype=dtype),
v_col=jnp.zeros(vc_shape, dtype=dtype),
v=jnp.zeros((1,), dtype=dtype))
else:
return _UpdateResult(
update=jnp.zeros((1,)),
v_row=jnp.zeros((1,)),
v_col=jnp.zeros((1,)),
v=jnp.zeros(param.shape))
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros((1,), dtype=dtype),
v_col=jnp.zeros((1,), dtype=dtype),
v=jnp.zeros(param.shape, dtype=dtype))

return _to_state(
jnp.zeros([], jnp.int32), jax.tree_util.tree_map(_init, params))
Expand All @@ -153,13 +153,13 @@ def update_fn(grads, state, params):
raise ValueError(base.NO_PARAMS_MSG)

def _update(grad, v_row, v_col, v, param, step):
shape = param.shape
shape, dtype = param.shape, param.dtype
decay_rate_t = decay_rate_fn(step - step_offset, decay_rate)

# Scaled by factorized second moment statistics.
new_v_row = jnp.zeros((1,))
new_v_col = jnp.zeros((1,))
new_v = jnp.zeros((1,))
new_v_row = jnp.zeros((1,), dtype=dtype)
new_v_col = jnp.zeros((1,), dtype=dtype)
new_v = jnp.zeros((1,), dtype=dtype)

factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor)
if factored_dims is not None:
Expand All @@ -171,6 +171,8 @@ def _update(grad, v_row, v_col, v, param, step):
new_v_col = (
decay_rate_t * v_col +
(1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1))
new_v_row = new_v_row.astype(dtype)
new_v_col = new_v_col.astype(dtype)
reduced_d1 = d1-1 if d1 > d0 else d1
row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True)
row_factor = (new_v_row / row_col_mean) ** -0.5
Expand All @@ -182,6 +184,7 @@ def _update(grad, v_row, v_col, v, param, step):
else:
grad_sqr = numerics.abs_sq(grad) + epsilon
new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr
new_v = new_v.astype(dtype)
update = grad * (new_v)**-0.5

return _UpdateResult(update, new_v_row, new_v_col, new_v)
Expand Down
49 changes: 49 additions & 0 deletions optax/_src/factorized_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from absl.testing import parameterized

import chex
import jax
import jax.numpy as jnp

from optax._src import factorized
from optax.transforms import _accumulation


class FactorizedTest(parameterized.TestCase):
Expand All @@ -45,6 +47,53 @@ def test_scale_by_factored_rms(self):
chex.assert_tree_all_finite((params, updates, state))
chex.assert_trees_all_equal_shapes(params, updates)

@chex.variants(with_jit=True, without_jit=True, with_device=True)
@parameterized.product(
factorized_dims=(True, False),
dtype=('bfloat16', 'float32')
)
def test_preserve_dtype(self, factorized_dims: bool, dtype: str):
"""Test that the optimizer returns updates of same dtype as params."""
dtype = jnp.dtype(dtype)
opt = factorized.scale_by_factored_rms()
fun = lambda x: jnp.sum(x**2)

if factorized_dims:
# The updates are factored only for large enough parameters
# default min_dim_size_to_factor is 128 so we use 129 here.
params = jnp.ones((129, 129), dtype=dtype)
else:
params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
updates, _ = self.variant(opt.update)(grads, state, params)
self.assertEqual(updates.dtype, params.dtype)

@chex.variants(with_jit=True, without_jit=True, with_device=True)
@parameterized.product(
factorized_dims=(True, False),
dtype=('bfloat16', 'float32')
)
def test_gradient_accumulation(self, factorized_dims, dtype):
"""Test that the optimizers can safely be used with optax.MultiSteps."""
# Checks if https://github.com/google-deepmind/optax/issues/377 is fixed.
dtype = jnp.dtype(dtype)
base_opt = factorized.scale_by_factored_rms()
opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4)

fun = lambda x: jnp.sum(x**2)

if factorized_dims:
# The updates are factored only for large enough parameters
# default min_dim_size_to_factor is 128 so we use 129 here.
params = jnp.ones((129, 129), dtype=dtype)
else:
params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
updates, _ = self.variant(opt.update)(grads, state, params)
chex.assert_trees_all_equal(updates, jnp.zeros_like(grads))


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit d825e6b

Please sign in to comment.