diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index bce5b31..d529b93 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -22,6 +22,7 @@ import jax from jax.experimental import checkify +import jax.extend as jex from jax.lib import xla_client import jax.numpy as jnp import numpy as np @@ -2554,9 +2555,9 @@ def _func(x: jnp.ndarray) -> jnp.ndarray: def _maybe_get_jaxpreqn( - jaxpr: jax.core.ClosedJaxpr) -> Optional[jax.core.JaxprEqn]: + jaxpr: jex.core.ClosedJaxpr) -> Optional[jex.core.JaxprEqn]: def is_all_vars(vs): - return all([isinstance(v, jax.core.Var) for v in vs]) + return all([isinstance(v, jex.core.Var) for v in vs]) if (len(jaxpr.eqns) == 1 and is_all_vars(jaxpr.jaxpr.invars) and is_all_vars(jaxpr.jaxpr.outvars) and diff --git a/tf2jax/experimental/mhlo.py b/tf2jax/experimental/mhlo.py index 793570e..3550144 100644 --- a/tf2jax/experimental/mhlo.py +++ b/tf2jax/experimental/mhlo.py @@ -20,6 +20,7 @@ import jax from jax import core from jax import export +import jax.extend as jex from jax.interpreters import mlir from jax.interpreters import xla from jax.lib import xla_client as xc @@ -31,7 +32,7 @@ safe_zip = jax.util.safe_zip -mhlo_apply_p = core.Primitive("mhlo_apply") +mhlo_apply_p = jex.core.Primitive("mhlo_apply") mhlo_apply_p.multiple_results = True