diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2fc683cd2861..2a71144340fc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 
 ## jax 0.4.36
 
+* Breaking Changes
+  * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
+    or with `enable_xla=False` has been deprecated since July 2024, with
+    JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
+    with native serialization will still be supported.
+
 ## jax 0.4.35 (Oct 22, 2024)
 
 * Breaking Changes
@@ -21,6 +27,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * `jax.experimental.host_callback` has been deprecated since March 2024, with
     JAX version 0.4.26. Now we removed it.
     See {jax-issue}`#20385` for a discussion of alternatives.
+  * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
+    or with `enable_xla=False` has been deprecated since July 2024, with
+    JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
+    with native serialization is still supported.
 
 * Changes:
   * `jax.lax.FftType` was introduced as a public name for the enum of FFT
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 29a1034e51ed..b12edf2a37ec 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -119,6 +119,10 @@ def _sanitize_scope_name(name):
 # Line below is different externally and internally.
 allow_enable_xla_false = lambda: True
 
+# TODO(b/353437398): Deprecate support for `native_serialization=False`.
+# Line below is different externally and internally.
+allow_native_serialization_false = lambda: True
+
 # A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
 # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
 TfVal = Any
@@ -294,8 +298,8 @@ def convert(fun_jax: Callable,
       See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
       for more details.
-    polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of
-      the form `e1 >= e2` or `e1 <= e2`.
+    polymorphic_constraints: a sequence of constraints on symbolic dimension
+      expressions, of the form `e1 >= e2` or `e1 <= e2`.
       See more details at
       https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
     with_gradient: if set (default), add a tf.custom_gradient to the lowered
       function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
@@ -332,28 +336,38 @@ def convert(fun_jax: Callable,
     tuple/lists/dicts thereof), and returns TfVals as outputs, and uses only
     TensorFlow ops and thus can be called from a TensorFlow program.
""" - if not enable_xla: - if allow_enable_xla_false(): - warnings.warn("jax2tf.convert with enable_xla=False is deprecated.", - DeprecationWarning, - stacklevel=2) - else: - raise ValueError("jax2tf.convert with enable_xla=False is not supported.") - if native_serialization is DEFAULT_NATIVE_SERIALIZATION: if not enable_xla: native_serialization = False else: native_serialization = config.jax2tf_default_native_serialization.value - if not native_serialization: - warnings.warn( - "jax2tf.convert with native_serialization=False is deprecated.", - DeprecationWarning, - stacklevel=2) - if native_serialization and not enable_xla: - raise ValueError( - "native_serialization is not supported with enable_xla=False") + if not enable_xla: + if allow_enable_xla_false(): + warnings.warn( + "jax2tf.convert with enable_xla=False has been deprecated " + "since July 2024.", + DeprecationWarning, + stacklevel=2) + if native_serialization: + raise ValueError( + "native_serialization is not supported with enable_xla=False") + else: + raise ValueError( + "jax2tf.convert with enable_xla=False has been deprecated " + "since July 2024 and it is not supported anymore.") + + elif not native_serialization: + if allow_native_serialization_false(): + warnings.warn( + "jax2tf.convert with native_serialization=False has been deprecated " + "since July 2024.", + DeprecationWarning, + stacklevel=2) + else: + raise ValueError( + "jax2tf.convert with native_serialization=False has been deprecated " + "since July 2024 and it is not supported anymore.") if not native_serialization and polymorphic_constraints: raise ValueError( @@ -2188,7 +2202,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers, _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" # TODO(b/293247337): we ought to turn on this safety check, but this leads to - # failures. Since we are going to turn on native serializaton soon, wait + # failures. Since we are going to turn on native serialization soon, wait # until then to turn on this check. 
   # lhs_aval, rhs_aval = _in_avals
   # if lhs_aval.dtype != rhs_aval.dtype:
diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py
index e8d284178691..6d5efb7b1e66 100644
--- a/jax/experimental/jax2tf/tests/call_tf_test.py
+++ b/jax/experimental/jax2tf/tests/call_tf_test.py
@@ -90,7 +90,7 @@ def setUp(self):
     super().setUp()
     self.warning_ctx = jtu.ignore_warning(
         message=(
-            "(jax2tf.convert with native_serialization=False is deprecated"
+            "(jax2tf.convert with native_serialization=False has been deprecated"
             "|Calling from_dlpack with a DLPack tensor is deprecated)"
         )
     )
@@ -897,7 +897,7 @@ def setUp(self):
     super().setUp()
     self.warning_ctx = jtu.ignore_warning(
         message=(
-            "(jax2tf.convert with native_serialization=False is deprecated"
+            "(jax2tf.convert with native_serialization=False has been deprecated"
             "|Calling from_dlpack with a DLPack tensor is deprecated)"
         )
     )
@@ -1203,7 +1203,7 @@ def setUp(self):
     super().setUp()
     self.warning_ctx = jtu.ignore_warning(
         message=(
-            "(jax2tf.convert with native_serialization=False is deprecated"
+            "(jax2tf.convert with native_serialization=False has been deprecated"
             "|Calling from_dlpack with a DLPack tensor is deprecated)"
         )
     )
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index 8ef9a1a5dd25..27e001fbdb49 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -79,7 +79,7 @@ def setUpClass(cls):
   def setUp(self):
     super().setUp()
     self.warning_ctx = jtu.ignore_warning(
-        message="jax2tf.convert with native_serialization=False is deprecated"
+        message="jax2tf.convert with native_serialization=False has been deprecated"
     )
     self.warning_ctx.__enter__()
 
@@ -1722,7 +1722,7 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
   def setUp(self):
     super().setUp()
     self.warning_ctx = jtu.ignore_warning(
-        message="jax2tf.convert with native_serialization=False is deprecated"
+        message="jax2tf.convert with native_serialization=False has been deprecated"
     )
     self.warning_ctx.__enter__()
 
@@ -1763,7 +1763,7 @@ def setUp(self):
     super().setUp()
 
   @jtu.ignore_warning(
-      message="jax2tf.convert with native_serialization=False is deprecated"
+      message="jax2tf.convert with native_serialization=False has been deprecated"
   )
   def test_simple(self):
     self.ConvertAndCompare(jnp.sin, 0.7)
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index 38af6d9d76d5..786b98e339e9 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -1031,7 +1031,7 @@ def f_jax(x):  # A function whose gradient is a constant
     self.assertAllClose(f_jax(x), restored_f(x))
 
   @jtu.ignore_warning(
-      message="jax2tf.convert with native_serialization=False is deprecated"
+      message="jax2tf.convert with native_serialization=False has been deprecated"
   )
   def test_readme_examples(self):
     """Some of the examples from the README."""
@@ -1124,31 +1124,6 @@ def f2_jax(x):  # f32[b, b]
     # JAX with static shapes sees that x.shape[0] != x.shape[1]
     self.assertEqual(jnp.sum(x45), f2_jax(x45))
 
-    # In graph serialization eager mode, we catch the broken assumption b >= 1
-    with self.assertRaisesRegex(
-        tf.errors.InvalidArgumentError,
-        re.escape(
-            "Found inconsistency between dimension size args[0].shape[1] (= 5) "
-            "and the specification 'b' (= 4)")):
-      jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
-                     native_serialization=False)(x45)
-
-    # In graph serialization graph mode we also catch it (except on TPU, where
-    # the behavior is as for jit_compile=1)
-
-    f2_tf = tf.function(
-        jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"],
-                       native_serialization=False),
-        autograph=False,
-    ).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
-    if jtu.test_device_matches(["tpu"]):
-      self.assertEqual(1. + jnp.sum(x45), f2_tf(x45))
-    else:
-      with self.assertRaisesRegex(
-          tf.errors.InvalidArgumentError,
-          r"Found inconsistency"):
-        _ = f2_tf(x45)
-
     # We also catch the error with native serialization
     with self.assertRaisesRegex(
         tf.errors.InvalidArgumentError,
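For context, a minimal usage sketch (not part of this patch) of the path that remains supported after this change: converting a JAX function with the default native serialization. The function `f` and the sample input below are illustrative placeholders, not code from the JAX repository.

# Minimal sketch, assuming the default settings of jax2tf.convert:
# native_serialization is left at its default and enable_xla stays True.
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
  # A simple JAX function to convert.
  return jnp.sin(jnp.cos(x))

# Wrap the converted function in tf.function so it can be traced/compiled by TF.
f_tf = tf.function(jax2tf.convert(f), autograph=False)
print(f_tf(tf.constant(0.7, dtype=tf.float32)))

Calls that pass native_serialization=False or enable_xla=False now hit the warning or ValueError paths shown in the jax2tf.py hunk above.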