Skip to content

Commit

Permalink
add an exp2 primitive and lax.exp2
Browse files — browse the repository at this point in the history
part of fixing jax-ml/jax-triton#204
  • Loading branch information
mattjj committed Jul 28, 2023
1 parent 640ee1e commit 296006f
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions jax/_src/internal_test_util/lax_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def lax_ops():
),
op_record("is_finite", 1, float_dtypes, test_util.rand_small),
op_record("exp", 1, float_dtypes + complex_dtypes, test_util.rand_small),
op_record("exp2", 1, float_dtypes + complex_dtypes, test_util.rand_small),
# TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
# precision.
op_record(
Expand Down
15 changes: 13 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`."""
return exp_p.bind(x)

def exp2(x: ArrayLike) -> Array:
  r"""Elementwise base-2 exponential: :math:`2^x`.

  Args:
    x: input array; must have a floating-point or complex dtype
      (``exp2_p`` is registered via ``standard_unop(_float | _complex, ...)``).

  Returns:
    An array of the same shape as ``x`` with ``2**x`` computed elementwise,
    by binding the ``exp2_p`` primitive.
  """
  return exp2_p.bind(x)

def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`."""
return expm1_p.bind(x)
Expand Down Expand Up @@ -1757,10 +1761,17 @@ def _round_lower(ctx, x, *, rounding_method):

# Primitive for lax.exp, defined on floating-point and complex inputs.
exp_p = standard_unop(_float | _complex, 'exp')
# JVP: d/dx e**x = e**x, so the tangent is g times the primal output `ans`.
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
# For exp_p it is more efficient to use the reconstructed output for the vjp
# rule instead of computing it again from the input.
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))

# Primitive for lax.exp2, defined on floating-point and complex inputs.
exp2_p = standard_unop(_float | _complex, 'exp2')
# JVP: d/dx 2**x = 2**x * ln(2); like exp_p, reuse the primal output `ans`
# instead of recomputing 2**x from the input.
ad.defjvp2(exp2_p, lambda g, ans, x: mul(mul(g, ans), log(_const(x, 2))))
def _exp2_lower(ctx, x):
  # Lower exp2(x) as exp(x * ln 2) — presumably because no dedicated exp2
  # HLO op is used here (TODO confirm); mathematically 2**x == e**(x ln 2).
  x_aval, = ctx.avals_in
  # ln(2) as a scalar constant of the operand's dtype, broadcast to its shape.
  log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
  log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
  return hlo.ExpOp(hlo.MulOp(log2, x).result).results
mlir.register_lowering(exp2_p, _exp2_lower)

# Primitive for lax.log, defined on floating-point and complex inputs.
log_p = standard_unop(_float | _complex, 'log')
# JVP: d/dx log(x) = 1/x, applied to the tangent g.
ad.defjvp(log_p, lambda g, x: div(g, x))
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))
Expand Down
1 change: 1 addition & 0 deletions jax/_src/lax_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def round(x):
is_finite = np.isfinite

# Elementwise exponential/logarithm reference implementations: the NumPy
# ufuncs serve directly as the ground truth for the corresponding lax ops.
exp = np.exp
exp2 = np.exp2
expm1 = np.expm1
log = np.log
log1p = np.log1p
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def log10(x: ArrayLike, /) -> Array:
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
  """Compute ``2**x`` elementwise.

  Promotes the argument to an inexact dtype, then defers to the
  ``lax.exp2`` primitive rather than composing it from exp/log.
  """
  x, = promote_args_inexact("exp2", x)
  return lax.exp2(x)


@_wraps(np.signbit, module='numpy')
Expand Down
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
eq_p as eq_p,
exp as exp,
exp_p as exp_p,
exp2 as exp2,
expand_dims as expand_dims,
expm1 as expm1,
expm1_p as expm1_p,
Expand Down
2 changes: 2 additions & 0 deletions tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):

grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.exp2, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive,
Expand Down

0 comments on commit 296006f

Please sign in to comment.