Skip to content

Commit

Permalink
Migrate from jax.core to jax.extend.core for several deprecated symbols
Browse files Browse the repository at this point in the history
A number of symbols from jax.core are deprecated as of recent JAX releases; some of them are newly available in jax.extend.core.

PiperOrigin-RevId: 705932315
  • Loading branch information
Jake VanderPlas authored and TF2JAXDev committed Dec 16, 2024
1 parent 34f26f5 commit 78898b9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import jax
from jax.experimental import checkify
import jax.extend as jex
from jax.lib import xla_client
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -2554,9 +2555,9 @@ def _func(x: jnp.ndarray) -> jnp.ndarray:


def _maybe_get_jaxpreqn(
jaxpr: jax.core.ClosedJaxpr) -> Optional[jax.core.JaxprEqn]:
jaxpr: jex.core.ClosedJaxpr) -> Optional[jex.core.JaxprEqn]:
def is_all_vars(vs):
return all([isinstance(v, jax.core.Var) for v in vs])
return all([isinstance(v, jex.core.Var) for v in vs])

if (len(jaxpr.eqns) == 1 and
is_all_vars(jaxpr.jaxpr.invars) and is_all_vars(jaxpr.jaxpr.outvars) and
Expand Down
3 changes: 2 additions & 1 deletion tf2jax/experimental/mhlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax
from jax import core
from jax import export
import jax.extend as jex
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.lib import xla_client as xc
Expand All @@ -31,7 +32,7 @@

safe_zip = jax.util.safe_zip

mhlo_apply_p = core.Primitive("mhlo_apply")
mhlo_apply_p = jex.core.Primitive("mhlo_apply")
mhlo_apply_p.multiple_results = True


Expand Down

0 comments on commit 78898b9

Please sign in to comment.