From aeacce8f9feefadd4c5ef6b5bb87abd8f7a8b9e7 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 23 May 2021 06:18:24 -0700 Subject: [PATCH] Add forward & reverse-mode AD tests (should break) --- tests/test_autodiff.py | 170 +++++++++++++++++++++++++++++++++++++++++ tests/utils.py | 49 ++++++++++-- 2 files changed, 214 insertions(+), 5 deletions(-) create mode 100644 tests/test_autodiff.py diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py new file mode 100644 index 0000000..ef0cbf1 --- /dev/null +++ b/tests/test_autodiff.py @@ -0,0 +1,170 @@ +"""Test autodiff. +""" + +from typing import Callable, Dict, Tuple, Type, cast + +import jax +import numpy as onp +from jax import numpy as jnp +from utils import assert_arrays_close, general_group_test, jacnumerical + +import jaxlie + +# Helper methods to test + shared Jacobian helpers +# We cache JITed Jacobian helpers to improve runtime +_jacfwd_jacrev_cache: Dict[Callable, Tuple[Callable, Callable]] = {} + + +def _assert_jacobians_close( + Group: Type[jaxlie.MatrixLieGroup], + f: Callable[ + [Type[jaxlie.MatrixLieGroup], jaxlie.hints.Array], jaxlie.hints.ArrayJax + ], + primal: jaxlie.hints.Array, +) -> None: + + if f not in _jacfwd_jacrev_cache: + _jacfwd_jacrev_cache[f] = ( + jax.jit(jax.jacfwd(f, argnums=1), static_argnums=0), + jax.jit(jax.jacrev(f, argnums=1), static_argnums=0), + ) + + jacfwd, jacrev = _jacfwd_jacrev_cache[f] + jacobian_fwd = jacfwd(Group, primal) + jacobian_rev = jacrev(Group, primal) + jacobian_numerical = jacnumerical( + lambda params: jax.jit(f, static_argnums=0)(Group, params) + )(primal) + + assert_arrays_close(jacobian_fwd, jacobian_rev) + assert_arrays_close(jacobian_fwd, jacobian_numerical, rtol=5e-4, atol=5e-4) + + +# Exp +@jax.partial(jax.jit, static_argnums=0) +def _exp( + Group: Type[jaxlie.MatrixLieGroup], generator: jaxlie.hints.Array +) -> jaxlie.hints.ArrayJax: + return cast(jnp.ndarray, Group.exp(generator).parameters()) + + +@general_group_test +def test_exp_random(Group: Type[jaxlie.MatrixLieGroup]): + """Check that exp Jacobians are consistent, with randomly sampled transforms.""" + generator = onp.random.randn(Group.tangent_dim) + _assert_jacobians_close(Group=Group, f=_exp, primal=generator) + + +@general_group_test +def test_exp_identity(Group: Type[jaxlie.MatrixLieGroup]): + """Check that exp Jacobians are consistent, with transforms close to the identity.""" + generator = onp.random.randn(Group.tangent_dim) * 1e-6 + _assert_jacobians_close(Group=Group, f=_exp, primal=generator) + + +# Log +def _log( + Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array +) -> jaxlie.hints.ArrayJax: + return Group.log(Group(params)) + + +@general_group_test +def test_log_random(Group: Type[jaxlie.MatrixLieGroup]): + """Check that log Jacobians are consistent, with randomly sampled transforms.""" + params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() + _assert_jacobians_close(Group=Group, f=_log, primal=params) + + +@general_group_test +def test_log_identity(Group: Type[jaxlie.MatrixLieGroup]): + """Check that log Jacobians are consistent, with transforms close to the identity.""" + params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() + _assert_jacobians_close(Group=Group, f=_log, primal=params) + + +# Adjoint +def _adjoint( + Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array +) -> jaxlie.hints.ArrayJax: + return cast(jnp.ndarray, Group(params).adjoint().flatten()) + + +@general_group_test +def test_adjoint_random(Group: Type[jaxlie.MatrixLieGroup]): + """Check that adjoint Jacobians are consistent, with randomly sampled transforms.""" + params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() + _assert_jacobians_close(Group=Group, f=_adjoint, primal=params) + + +@general_group_test +def test_adjoint_identity(Group: Type[jaxlie.MatrixLieGroup]): + """Check that adjoint Jacobians are consistent, with transforms close to the identity.""" + params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() + _assert_jacobians_close(Group=Group, f=_adjoint, primal=params) + + +# Apply +@jax.partial(jax.jit, static_argnums=0) +def _apply( + Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array +) -> jaxlie.hints.ArrayJax: + return Group(params) @ onp.ones(Group.space_dim) + + +@general_group_test +def test_apply_random(Group: Type[jaxlie.MatrixLieGroup]): + """Check that apply Jacobians are consistent, with randomly sampled transforms.""" + params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() + _assert_jacobians_close(Group=Group, f=_apply, primal=params) + + +@general_group_test +def test_apply_identity(Group: Type[jaxlie.MatrixLieGroup]): + """Check that apply Jacobians are consistent, with transforms close to the identity.""" + params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() + _assert_jacobians_close(Group=Group, f=_apply, primal=params) + + +# Multiply +@jax.partial(jax.jit, static_argnums=0) +def _multiply( + Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array +) -> jaxlie.hints.ArrayJax: + return cast(jnp.ndarray, (Group(params) @ Group(params)).parameters()) + + +@general_group_test +def test_multiply_random(Group: Type[jaxlie.MatrixLieGroup]): + """Check that multiply Jacobians are consistent, with randomly sampled transforms.""" + params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() + _assert_jacobians_close(Group=Group, f=_multiply, primal=params) + + +@general_group_test +def test_multiply_identity(Group: Type[jaxlie.MatrixLieGroup]): + """Check that multiply Jacobians are consistent, with transforms close to the identity.""" + params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() + _assert_jacobians_close(Group=Group, f=_multiply, primal=params) + + +# Inverse +@jax.partial(jax.jit, static_argnums=0) +def _inverse( + Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array +) -> jaxlie.hints.ArrayJax: + return cast(jnp.ndarray, Group(params).inverse().parameters()) + + +@general_group_test +def test_inverse_random(Group: Type[jaxlie.MatrixLieGroup]): + """Check that inverse Jacobians are consistent, with randomly sampled transforms.""" + params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() + _assert_jacobians_close(Group=Group, f=_inverse, primal=params) + + +@general_group_test +def test_inverse_identity(Group: Type[jaxlie.MatrixLieGroup]): + """Check that inverse Jacobians are consistent, with transforms close to the identity.""" + params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() + _assert_jacobians_close(Group=Group, f=_inverse, primal=params) diff --git a/tests/utils.py b/tests/utils.py index a158a5d..e690d83 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,10 @@ import random -from typing import Any, Callable, Type, TypeVar, cast +from typing import Any, Callable, List, Type, TypeVar, cast import jax import numpy as onp import pytest +import scipy.optimize from hypothesis import given, settings from hypothesis import strategies as st from jax import numpy as jnp @@ -20,7 +21,19 @@ def sample_transform(Group: Type[T]) -> T: """Sample a random transform from a group.""" seed = random.getrandbits(32) - return cast(T, Group.sample_uniform(key=jax.random.PRNGKey(seed=seed))) + strategy = random.randint(0, 2) + + if strategy == 0: + # Uniform sampling + return cast(T, Group.sample_uniform(key=jax.random.PRNGKey(seed=seed))) + elif strategy == 1: + # Sample from normally-sampled tangent vector + return cast(T, Group.exp(onp.random.randn(Group.tangent_dim))) + elif strategy == 2: + # Sample near identity + return cast(T, Group.exp(onp.random.randn(Group.tangent_dim) * 1e-7)) + else: + assert False def general_group_test( @@ -33,7 +46,7 @@ def f_wrapped(Group: Type[jaxlie.MatrixLieGroup], _random_module) -> None: f(Group) # Disable timing check (first run requires JIT tracing and will be slower) - f_wrapped = settings(deadline=None)(f_wrapped) + f_wrapped = settings(deadline=None, max_examples=max_examples)(f_wrapped) # Add _random_module parameter f_wrapped = given(_random_module=st.random_module())(f_wrapped) @@ -71,9 +84,35 @@ def assert_transforms_close(a: jaxlie.MatrixLieGroup, b: jaxlie.MatrixLieGroup): assert_arrays_close(p1, p2) -def assert_arrays_close(*arrays: jaxlie.hints.Array): +def assert_arrays_close( + *arrays: jaxlie.hints.Array, + rtol: float = 1e-8, + atol: float = 1e-8, +): """Make sure two arrays are close. (and not NaN)""" for array1, array2 in zip(arrays[:-1], arrays[1:]): - onp.testing.assert_allclose(array1, array2, rtol=1e-8, atol=1e-8) + onp.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol) assert not onp.any(onp.isnan(array1)) assert not onp.any(onp.isnan(array2)) + + +def jacnumerical( + f: Callable[[jaxlie.hints.Array], jaxlie.hints.ArrayJax] +) -> Callable[[jaxlie.hints.Array], jaxlie.hints.ArrayJax]: + """Decorator for computing numerical Jacobians of vector->vector functions.""" + + def wrapped(params: jaxlie.hints.Array) -> jaxlie.hints.ArrayJax: + output_dim: int + (output_dim,) = f(params).shape + + jacobian_rows: List[onp.ndarray] = [] + for i in range(output_dim): + jacobian_row: onp.ndarray = scipy.optimize.approx_fprime( + params, lambda p: f(p)[i], epsilon=1e-5 + ) + assert jacobian_row.shape == params.shape + jacobian_rows.append(jacobian_row) + + return jnp.stack(jacobian_rows, axis=0) + + return wrapped