Skip to content

Commit

Permalink
Fix reverse-mode AD for log/exp ops, bump version
Browse files (browse the repository at this point in the history)
  • Loading branch information
brentyi committed May 23, 2021
1 parent aeacce8 commit 3562997
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 28 deletions.
1 change: 0 additions & 1 deletion jaxlie/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import abc
from typing import ClassVar, Generic, Type, TypeVar, overload

import jax
import numpy as onp
from jax import numpy as jnp
from overrides import EnforceOverrides, final, overrides
Expand Down
26 changes: 22 additions & 4 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,25 @@ def exp(tangent: hints.TangentVector) -> "SE2":

theta = tangent[2]
use_taylor = jnp.abs(theta) < get_epsilon(tangent.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
safe_theta = jnp.where(
use_taylor,
1.0,  # Any non-zero value should do here
theta,
)

theta_sq = theta ** 2
sin_over_theta = jnp.where(
use_taylor,
1.0 - theta_sq / 6.0,
jnp.sin(theta) / theta,
jnp.sin(safe_theta) / safe_theta,
)
one_minus_cos_over_theta = jnp.where(
use_taylor,
0.5 * theta - theta * theta_sq / 24.0,
(1.0 - jnp.cos(theta)) / theta,
(1.0 - jnp.cos(safe_theta)) / safe_theta,
)

V = jnp.array(
Expand All @@ -144,12 +153,21 @@ def log(self: "SE2") -> hints.TangentVectorJax:
cos_minus_one = cos - 1.0
half_theta = theta / 2.0
use_taylor = jnp.abs(cos_minus_one) < get_epsilon(theta.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
safe_cos_minus_one = jnp.where(
use_taylor,
1.0, # Any non-zero value should do here
cos_minus_one,
)

half_theta_over_tan_half_theta = jnp.where(
use_taylor,
# First-order Taylor approximation
# Taylor approximation
1.0 - (theta ** 2) / 12.0,
# Default
-(half_theta * jnp.sin(theta)) / cos_minus_one,
-(half_theta * jnp.sin(theta)) / safe_cos_minus_one,
)

V_inv = jnp.array(
Expand Down
50 changes: 38 additions & 12 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,27 @@ def exp(tangent: hints.TangentVector) -> "SE3":
rotation = SO3.exp(tangent[3:])

theta_squared = tangent[3:] @ tangent[3:]
theta = jnp.sqrt(theta_squared)
use_taylor = theta_squared < get_epsilon(theta_squared.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
theta_squared_safe = jnp.where(
use_taylor,
1.0, # Any non-zero value should do here
theta_squared,
)
del theta_squared
theta_safe = jnp.sqrt(theta_squared_safe)

skew_omega = _skew(tangent[3:])
use_small_theta = theta < get_epsilon(theta_squared.dtype)
V = jnp.where(
use_small_theta,
use_taylor,
rotation.as_matrix(),
(
jnp.eye(3)
+ (1.0 - jnp.cos(theta)) / (theta_squared) * skew_omega
+ (theta - jnp.sin(theta))
/ (theta_squared * theta)
+ (1.0 - jnp.cos(theta_safe)) / (theta_squared_safe) * skew_omega
+ (theta_safe - jnp.sin(theta_safe))
/ (theta_squared_safe * theta_safe)
* (skew_omega @ skew_omega)
),
)
Expand All @@ -136,18 +146,34 @@ def log(self: "SE3") -> hints.TangentVectorJax:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223
omega = self.rotation().log()
theta_squared = omega @ omega
use_taylor = theta_squared < get_epsilon(theta_squared.dtype)

skew_omega = _skew(omega)
theta = jnp.sqrt(theta_squared)
half_theta = theta / 2.0
use_small_theta = theta < get_epsilon(theta_squared.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
theta_squared_safe = jnp.where(
use_taylor,
1.0, # Any non-zero value should do here
theta_squared,
)
del theta_squared
theta_safe = jnp.sqrt(theta_squared_safe)
half_theta_safe = theta_safe / 2.0

V_inv = jnp.where(
use_small_theta,
use_taylor,
jnp.eye(3) - 0.5 * skew_omega + (skew_omega @ skew_omega) / 12.0,
(
jnp.eye(3)
- 0.5 * skew_omega
+ (1.0 - theta * jnp.cos(half_theta) / (2.0 * jnp.sin(half_theta)))
/ theta_squared
+ (
1.0
- theta_safe
* jnp.cos(half_theta_safe)
/ (2.0 * jnp.sin(half_theta_safe))
)
/ theta_squared_safe
* (skew_omega @ skew_omega)
),
)
Expand Down
38 changes: 29 additions & 9 deletions jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,20 +285,30 @@ def exp(tangent: hints.TangentVector) -> "SO3":
assert tangent.shape == (3,)

theta_squared = tangent @ tangent
theta = jnp.sqrt(theta_squared)
half_theta = 0.5 * theta
theta_pow_4 = theta_squared * theta_squared
use_taylor = theta_squared < get_epsilon(tangent.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
safe_theta = jnp.sqrt(
jnp.where(
use_taylor,
0.0, # Any constant value should do here
theta_squared,
)
)
safe_half_theta = 0.5 * safe_theta

use_taylor = theta < get_epsilon(tangent.dtype)
real_factor = jnp.where(
use_taylor,
1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0,
jnp.cos(half_theta),
jnp.cos(safe_half_theta),
)

imaginary_factor = jnp.where(
use_taylor,
0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0,
jnp.sin(half_theta) / theta,
jnp.sin(safe_half_theta) / safe_theta,
)

return SO3(
Expand All @@ -317,15 +327,25 @@ def log(self: "SO3") -> hints.TangentVectorJax:

w = self.wxyz[..., 0]
norm_sq = self.wxyz[..., 1:] @ self.wxyz[..., 1:]
norm = jnp.sqrt(norm_sq)
use_taylor = norm < get_epsilon(norm_sq.dtype)
use_taylor = norm_sq < get_epsilon(norm_sq.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
# reverse-mode AD
norm_safe = jnp.sqrt(
jnp.where(
use_taylor,
1.0, # Any non-zero value should do here
norm_sq,
)
)

atan_factor = jnp.where(
use_taylor,
2.0 / w - 2.0 / 3.0 * norm_sq / (w ** 3),
jnp.where(
jnp.abs(w) < get_epsilon(w.dtype),
jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm,
2.0 * jnp.arctan(norm / w) / norm,
jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe,
2.0 * jnp.arctan(norm_safe / w) / norm_safe,
),
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="jaxlie",
version="1.2.4",
version="1.2.5",
description="Matrix Lie groups in Jax",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def sample_transform(Group: Type[T]) -> T:


def general_group_test(
f: Callable[[Type[jaxlie.MatrixLieGroup]], None], max_examples: int = 100
f: Callable[[Type[jaxlie.MatrixLieGroup]], None], max_examples: int = 30
) -> Callable[[Type[jaxlie.MatrixLieGroup], Any], None]:
"""Decorator for defining tests that run on all group types."""

Expand Down

0 comments on commit 3562997

Please sign in to comment.