Skip to content

Commit

Permalink
[jax2tf] Disable jax2tf with non-native serialization.
Browse files Browse the repository at this point in the history
jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024.

This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.

PiperOrigin-RevId: 689708392
  • Loading branch information
gnecula authored and Google-ML-Automation committed Oct 25, 2024
1 parent 0bc70bb commit 9088add
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 51 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jax 0.4.36

* Breaking Changes
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
with native serialization will still be supported.

## jax 0.4.35 (Oct 22, 2024)

* Breaking Changes
Expand All @@ -21,6 +27,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `jax.experimental.host_callback` has been deprecated since March 2024, with
JAX version 0.4.26. Now we removed it.
See {jax-issue}`#20385` for a discussion of alternatives.
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
or with `enable_xla=False` have been deprecated since July 2024, with
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
with native serialization is still supported.

* Changes:
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
Expand Down
52 changes: 33 additions & 19 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def _sanitize_scope_name(name):
# Line below is different externally and internally.
allow_enable_xla_false = lambda: True

# TODO(b/353437398): Deprecate support for `native_serialization=False`.
# Line below is different externally and internally.
allow_native_serialization_false = lambda: True

# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
Expand Down Expand Up @@ -294,8 +298,8 @@ def convert(fun_jax: Callable,
See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of
the form `e1 >= e2` or `e1 <= e2`.
polymorphic_constraints: a sequence of constraints on symbolic dimension
expressions, of the form `e1 >= e2` or `e1 <= e2`.
See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
with_gradient: if set (default), add a tf.custom_gradient to the lowered
function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
Expand Down Expand Up @@ -332,28 +336,38 @@ def convert(fun_jax: Callable,
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
only TensorFlow ops and thus can be called from a TensorFlow program.
"""
if not enable_xla:
if allow_enable_xla_false():
warnings.warn("jax2tf.convert with enable_xla=False is deprecated.",
DeprecationWarning,
stacklevel=2)
else:
raise ValueError("jax2tf.convert with enable_xla=False is not supported.")

if native_serialization is DEFAULT_NATIVE_SERIALIZATION:
if not enable_xla:
native_serialization = False
else:
native_serialization = config.jax2tf_default_native_serialization.value

if not native_serialization:
warnings.warn(
"jax2tf.convert with native_serialization=False is deprecated.",
DeprecationWarning,
stacklevel=2)
if native_serialization and not enable_xla:
raise ValueError(
"native_serialization is not supported with enable_xla=False")
if not enable_xla:
if allow_enable_xla_false():
warnings.warn(
"jax2tf.convert with enable_xla=False has been deprecated "
"since July 2024.",
DeprecationWarning,
stacklevel=2)
if native_serialization:
raise ValueError(
"native_serialization is not supported with enable_xla=False")
else:
raise ValueError(
"jax2tf.convert with enable_xla=False has been deprecated "
"since July 2024 and it is not supported anymore.")

elif not native_serialization:
if allow_native_serialization_false():
warnings.warn(
"jax2tf.convert with native_serialization=False has been deprecated "
"since July 2024.",
DeprecationWarning,
stacklevel=2)
else:
raise ValueError(
"jax2tf.convert with native_serialization=False has been deprecated "
"since July 2024 and it is not supported anymore.")

if not native_serialization and polymorphic_constraints:
raise ValueError(
Expand Down Expand Up @@ -2188,7 +2202,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers,
_out_aval: core.ShapedArray):
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
# TODO(b/293247337): we ought to turn on this safety check, but this leads to
# failures. Since we are going to turn on native serializaton soon, wait
# failures. Since we are going to turn on native serialization soon, wait
# until then to turn on this check.
# lhs_aval, rhs_aval = _in_avals
# if lhs_aval.dtype != rhs_aval.dtype:
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/jax2tf/tests/call_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
Expand Down Expand Up @@ -897,7 +897,7 @@ def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
Expand Down Expand Up @@ -1203,7 +1203,7 @@ def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"(jax2tf.convert with native_serialization=False has been deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def setUpClass(cls):
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
message="jax2tf.convert with native_serialization=False has been deprecated"
)
self.warning_ctx.__enter__()

Expand Down Expand Up @@ -1722,7 +1722,7 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
message="jax2tf.convert with native_serialization=False has been deprecated"
)
self.warning_ctx.__enter__()

Expand Down Expand Up @@ -1763,7 +1763,7 @@ def setUp(self):
super().setUp()

@jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
message="jax2tf.convert with native_serialization=False has been deprecated"
)
def test_simple(self):
self.ConvertAndCompare(jnp.sin, 0.7)
Expand Down
27 changes: 1 addition & 26 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ def f_jax(x): # A function whose gradient is a constant
self.assertAllClose(f_jax(x), restored_f(x))

@jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
message="jax2tf.convert with native_serialization=False has been deprecated"
)
def test_readme_examples(self):
"""Some of the examples from the README."""
Expand Down Expand Up @@ -1124,31 +1124,6 @@ def f2_jax(x): # f32[b, b]
# JAX with static shapes sees that x.shape[0] != x.shape[1]
self.assertEqual(jnp.sum(x45), f2_jax(x45))

# In graph serialization eager mode, we catch the broken assumption b >= 1
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
re.escape(
"Found inconsistency between dimension size args[0].shape[1] (= 5) "
"and the specification 'b' (= 4)")):
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False)(x45)

# In graph serialization graph mode we also catch it (except on TPU, where
# the behavior is as for jit_compile=1)

f2_tf = tf.function(
jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
native_serialization=False),
autograph=False,
).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
if jtu.test_device_matches(["tpu"]):
self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
else:
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
r"Found inconsistency"):
_ = f2_tf(x45)

# We also catch the error with native serialization
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
Expand Down

0 comments on commit 9088add

Please sign in to comment.