Skip to content

Commit

Permalink
Add forward & reverse-mode AD tests (should break)
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 23, 2021
1 parent 94a9387 commit aeacce8
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 5 deletions.
170 changes: 170 additions & 0 deletions tests/test_autodiff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Test autodiff.
"""

from typing import Callable, Dict, Tuple, Type, cast

import jax
import numpy as onp
from jax import numpy as jnp
from utils import assert_arrays_close, general_group_test, jacnumerical

import jaxlie

# Helper methods to test + shared Jacobian helpers
# We cache JITed Jacobian helpers to improve runtime
_jacfwd_jacrev_cache: Dict[Callable, Tuple[Callable, Callable]] = {}


def _assert_jacobians_close(
Group: Type[jaxlie.MatrixLieGroup],
f: Callable[
[Type[jaxlie.MatrixLieGroup], jaxlie.hints.Array], jaxlie.hints.ArrayJax
],
primal: jaxlie.hints.Array,
) -> None:

if f not in _jacfwd_jacrev_cache:
_jacfwd_jacrev_cache[f] = (
jax.jit(jax.jacfwd(f, argnums=1), static_argnums=0),
jax.jit(jax.jacrev(f, argnums=1), static_argnums=0),
)

jacfwd, jacrev = _jacfwd_jacrev_cache[f]
jacobian_fwd = jacfwd(Group, primal)
jacobian_rev = jacrev(Group, primal)
jacobian_numerical = jacnumerical(
lambda params: jax.jit(f, static_argnums=0)(Group, params)
)(primal)

assert_arrays_close(jacobian_fwd, jacobian_rev)
assert_arrays_close(jacobian_fwd, jacobian_numerical, rtol=5e-4, atol=5e-4)


# Exp
@jax.partial(jax.jit, static_argnums=0)
def _exp(
Group: Type[jaxlie.MatrixLieGroup], generator: jaxlie.hints.Array
) -> jaxlie.hints.ArrayJax:
return cast(jnp.ndarray, Group.exp(generator).parameters())


@general_group_test
def test_exp_random(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that exp Jacobians are consistent, with randomly sampled transforms."""
generator = onp.random.randn(Group.tangent_dim)
_assert_jacobians_close(Group=Group, f=_exp, primal=generator)


@general_group_test
def test_exp_identity(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that exp Jacobians are consistent, with transforms close to the identity."""
generator = onp.random.randn(Group.tangent_dim) * 1e-6
_assert_jacobians_close(Group=Group, f=_exp, primal=generator)


# Log
def _log(
Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array
) -> jaxlie.hints.ArrayJax:
return Group.log(Group(params))


@general_group_test
def test_log_random(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that log Jacobians are consistent, with randomly sampled transforms."""
params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters()
_assert_jacobians_close(Group=Group, f=_log, primal=params)


@general_group_test
def test_log_identity(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that log Jacobians are consistent, with transforms close to the identity."""
params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters()
_assert_jacobians_close(Group=Group, f=_log, primal=params)


# Adjoint
def _adjoint(
Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array
) -> jaxlie.hints.ArrayJax:
return cast(jnp.ndarray, Group(params).adjoint().flatten())


@general_group_test
def test_adjoint_random(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that adjoint Jacobians are consistent, with randomly sampled transforms."""
params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters()
_assert_jacobians_close(Group=Group, f=_adjoint, primal=params)


@general_group_test
def test_adjoint_identity(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that adjoint Jacobians are consistent, with transforms close to the identity."""
params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters()
_assert_jacobians_close(Group=Group, f=_adjoint, primal=params)


# Apply
@jax.partial(jax.jit, static_argnums=0)
def _apply(
Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array
) -> jaxlie.hints.ArrayJax:
return Group(params) @ onp.ones(Group.space_dim)


@general_group_test
def test_apply_random(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that apply Jacobians are consistent, with randomly sampled transforms."""
params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters()
_assert_jacobians_close(Group=Group, f=_apply, primal=params)


@general_group_test
def test_apply_identity(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that apply Jacobians are consistent, with transforms close to the identity."""
params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters()
_assert_jacobians_close(Group=Group, f=_apply, primal=params)


# Multiply
@jax.partial(jax.jit, static_argnums=0)
def _multiply(
Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array
) -> jaxlie.hints.ArrayJax:
return cast(jnp.ndarray, (Group(params) @ Group(params)).parameters())


@general_group_test
def test_multiply_random(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that multiply Jacobians are consistent, with randomly sampled transforms."""
params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters()
_assert_jacobians_close(Group=Group, f=_multiply, primal=params)


@general_group_test
def test_multiply_identity(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that multiply Jacobians are consistent, with transforms close to the identity."""
params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters()
_assert_jacobians_close(Group=Group, f=_multiply, primal=params)


# Inverse
@jax.partial(jax.jit, static_argnums=0)
def _inverse(
Group: Type[jaxlie.MatrixLieGroup], params: jaxlie.hints.Array
) -> jaxlie.hints.ArrayJax:
return cast(jnp.ndarray, Group(params).inverse().parameters())


@general_group_test
def test_inverse_random(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that inverse Jacobians are consistent, with randomly sampled transforms."""
params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters()
_assert_jacobians_close(Group=Group, f=_inverse, primal=params)


@general_group_test
def test_inverse_identity(Group: Type[jaxlie.MatrixLieGroup]):
"""Check that inverse Jacobians are consistent, with transforms close to the identity."""
params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters()
_assert_jacobians_close(Group=Group, f=_inverse, primal=params)
49 changes: 44 additions & 5 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import random
from typing import Any, Callable, Type, TypeVar, cast
from typing import Any, Callable, List, Type, TypeVar, cast

import jax
import numpy as onp
import pytest
import scipy.optimize
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp
Expand All @@ -20,7 +21,19 @@
def sample_transform(Group: Type[T]) -> T:
"""Sample a random transform from a group."""
seed = random.getrandbits(32)
return cast(T, Group.sample_uniform(key=jax.random.PRNGKey(seed=seed)))
strategy = random.randint(0, 2)

if strategy == 0:
# Uniform sampling
return cast(T, Group.sample_uniform(key=jax.random.PRNGKey(seed=seed)))
elif strategy == 1:
# Sample from normally-sampled tangent vector
return cast(T, Group.exp(onp.random.randn(Group.tangent_dim)))
elif strategy == 2:
# Sample near identity
return cast(T, Group.exp(onp.random.randn(Group.tangent_dim) * 1e-7))
else:
assert False


def general_group_test(
Expand All @@ -33,7 +46,7 @@ def f_wrapped(Group: Type[jaxlie.MatrixLieGroup], _random_module) -> None:
f(Group)

# Disable timing check (first run requires JIT tracing and will be slower)
f_wrapped = settings(deadline=None)(f_wrapped)
f_wrapped = settings(deadline=None, max_examples=max_examples)(f_wrapped)

# Add _random_module parameter
f_wrapped = given(_random_module=st.random_module())(f_wrapped)
Expand Down Expand Up @@ -71,9 +84,35 @@ def assert_transforms_close(a: jaxlie.MatrixLieGroup, b: jaxlie.MatrixLieGroup):
assert_arrays_close(p1, p2)


def assert_arrays_close(*arrays: jaxlie.hints.Array):
def assert_arrays_close(
*arrays: jaxlie.hints.Array,
rtol: float = 1e-8,
atol: float = 1e-8,
):
"""Make sure two arrays are close. (and not NaN)"""
for array1, array2 in zip(arrays[:-1], arrays[1:]):
onp.testing.assert_allclose(array1, array2, rtol=1e-8, atol=1e-8)
onp.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol)
assert not onp.any(onp.isnan(array1))
assert not onp.any(onp.isnan(array2))


def jacnumerical(
f: Callable[[jaxlie.hints.Array], jaxlie.hints.ArrayJax]
) -> Callable[[jaxlie.hints.Array], jaxlie.hints.ArrayJax]:
"""Decorator for computing numerical Jacobians of vector->vector functions."""

def wrapped(params: jaxlie.hints.Array) -> jaxlie.hints.ArrayJax:
output_dim: int
(output_dim,) = f(params).shape

jacobian_rows: List[onp.ndarray] = []
for i in range(output_dim):
jacobian_row: onp.ndarray = scipy.optimize.approx_fprime(
params, lambda p: f(p)[i], epsilon=1e-5
)
assert jacobian_row.shape == params.shape
jacobian_rows.append(jacobian_row)

return jnp.stack(jacobian_rows, axis=0)

return wrapped

0 comments on commit aeacce8

Please sign in to comment.