Commit 6a59636

Address PR comments
nbouziani committed Sep 10, 2024
1 parent ada00a5 commit 6a59636
Showing 2 changed files with 6 additions and 7 deletions.
9 changes: 4 additions & 5 deletions firedrake/ml/jax/fem_operator.py
@@ -17,7 +17,6 @@ def custom_vjp(_, **kwargs):
         raise ImportError("JAX is not installed and is required to use the FiredrakeJaxOperator.")
 
 import collections
-import warnings
 import numpy as np
 from functools import partial
@@ -118,6 +117,8 @@ def fem_operator(F):
         raise ValueError("F must be a ReducedFunctional")
 
     jax_op = FiredrakeJaxOperator(F)
+    # `jax_op.forward` currently does not work and causes issues related to the function
+    # signature during JAX compilation. As a workaround, we use `functools.partial` instead.
     return partial(FiredrakeJaxOperator.forward, jax_op)
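
For context, a minimal usage sketch of the workaround above, assuming the standard firedrake.adjoint taping API; the mesh, functional, and variable names here are illustrative, not part of this commit:

    # Hedged sketch: wrap a Firedrake ReducedFunctional as a JAX-callable operator.
    # `fem_operator` returns `partial(FiredrakeJaxOperator.forward, jax_op)` rather
    # than the bound method `jax_op.forward`, sidestepping the signature issue that
    # the new comment in the diff describes.
    import jax.numpy as jnp
    from firedrake import UnitSquareMesh, FunctionSpace, Function, assemble, dx
    from firedrake.adjoint import continue_annotation, pause_annotation, ReducedFunctional, Control
    from firedrake.ml.jax import fem_operator

    continue_annotation()                 # start taping for the adjoint machinery
    mesh = UnitSquareMesh(4, 4)
    V = FunctionSpace(mesh, "CG", 1)
    f = Function(V)
    J = assemble(f ** 2 * dx)             # a simple scalar functional of f
    Jhat = ReducedFunctional(J, Control(f))
    pause_annotation()

    G = fem_operator(Jhat)                # a plain function, safe to hand to JAX
    y = G(jnp.ones(V.dim()))              # evaluate via Firedrake on a JAX array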


@@ -208,10 +209,8 @@ def from_jax(x, V=None):
 
     if isinstance(x, jax.core.ShapedArray):
         if not isinstance(x, jax.core.ConcreteArray):
-            warnings.warn("Cannot convert a JAX abstract array to a Firedrake object. Returning a zero function.")
-            x = np.zeros(x.shape)
-        else:
-            x = x.val
+            raise TypeError("Cannot convert a JAX abstract array to a Firedrake object.")
+        x = x.val
 
     if not isinstance(x, np.ndarray) and x.device.platform != "cpu":
         raise NotImplementedError("Firedrake does not support GPU/TPU tensors")
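
A short sketch of the resulting `from_jax` behaviour, assuming `from_jax` is importable from `firedrake.ml.jax`; the mesh and function-space setup are illustrative:

    # Hedged sketch: `from_jax` converts a concrete JAX array into a Firedrake
    # object; abstract (non-concrete) arrays now raise TypeError instead of being
    # silently replaced by a zero function, so misuse fails loudly.
    import jax.numpy as jnp
    from firedrake import UnitSquareMesh, FunctionSpace
    from firedrake.ml.jax import from_jax

    mesh = UnitSquareMesh(2, 2)
    V = FunctionSpace(mesh, "CG", 1)
    u = from_jax(jnp.ones(V.dim()), V)    # concrete array -> firedrake.Function
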
4 changes: 2 additions & 2 deletions firedrake/ml/jax/ml_operator.py
@@ -60,8 +60,8 @@ class is achieved using JAX differentiation on the JAX model associated with the
                          argument_slots=argument_slots, operator_data=operator_data)
 
         # Check that JAX double precision is enabled if Firedrake operates in double precision.
-        if utils.ScalarType == jnp.float64 and not jax.config.jax_enable_x64:
-            warnings.warn("JAX is not configured to use 64-bit precision. Consider setting `jax_enable_x64=True`, e.g. `jax.config.update('jax_enable_x64', True)`.", RuntimeWarning)
+        if utils.ScalarType in (jnp.float64, jnp.complex128) and not jax.config.jax_enable_x64:
+            warnings.warn("JAX is not configured to use double precision. Consider setting `jax_enable_x64=True`, e.g. `jax.config.update('jax_enable_x64', True)`.", RuntimeWarning)
 
     # --- Callbacks --- #

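For reference, a minimal sketch of the configuration the warning points to (standard JAX API):

    # Enable double precision in JAX before creating any arrays, so that JAX's
    # default float64/complex128 dtypes line up with Firedrake's ScalarType.
    import jax
    jax.config.update("jax_enable_x64", True)

    import jax.numpy as jnp
    assert jnp.ones(3).dtype == jnp.float64   # without x64, this would be float32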
