From 2b9c73d10d08d7415337ab71cb7718022e89c408 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 31 Oct 2024 15:40:54 -0700 Subject: [PATCH] Remove a number of expired deprecations. These APIs were all removed 3 or more months ago, and the registrations here cause them to raise informative AttributeErrors. Enough time has passed now that we can remove these. --- jax/__init__.py | 7 ------- jax/core.py | 22 ---------------------- jax/interpreters/ad.py | 11 ----------- jax/interpreters/xla.py | 36 ------------------------------------ jax/lax/__init__.py | 13 ------------- jax/nn/__init__.py | 14 -------------- jax/numpy/__init__.py | 5 ----- jax/random.py | 17 ----------------- 8 files changed, 125 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 4f5c256b0c9d..7916ef0e3962 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -223,13 +223,6 @@ "jax.clear_backends is deprecated.", _deprecated_clear_backends ), - # Remove after jax 0.4.35 release. - "xla_computation": ( - "jax.xla_computation is deleted. Please use the AOT APIs; see " - "https://jax.readthedocs.io/en/latest/aot.html. For example, replace " - "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See " - "CHANGELOG.md for 0.4.30 for more examples.", None - ), } import typing as _typing diff --git a/jax/core.py b/jax/core.py index 2880e42c681b..fb08763fd3a1 100644 --- a/jax/core.py +++ b/jax/core.py @@ -147,28 +147,6 @@ "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), - # Finalized 2024-05-13; remove after 2024-08-13 - "DimSize": ( - "jax.core.DimSize is deprecated. Use DimSize = int | Any.", - None, - ), - "Shape": ( - "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].", - None, - ), - # Finalized 2024-06-24; remove after 2024-09-24 - "canonicalize_shape": ( - "jax.core.canonicalize_shape is deprecated.", None, - ), - "dimension_as_value": ( - "jax.core.dimension_as_value is deprecated. Use jnp.array.", None, - ), - "definitely_equal": ( - "jax.core.definitely_equal is deprecated. Use ==.", None, - ), - "symbolic_equal_dim": ( - "jax.core.symbolic_equal_dim is deprecated. Use ==.", None, - ), # Added Jan 8, 2024 "non_negative_dim": ( "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim, diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 160a96fae368..4ded4a803ae0 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -68,17 +68,6 @@ zeros_like_p as zeros_like_p, ) -_deprecations = { - # Finalized Mar 18, 2024; remove after June 18, 2024 - "config": ( - "jax.interpreters.ad.config is deprecated. Use jax.config directly.", - None, - ), - "source_info_util": ( - "jax.interpreters.ad.source_info_util is deprecated. Use jax.extend.source_info_util.", - None, - ), -} def backward_pass(jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 2711bcfb80d5..b3a470f5e049 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -42,42 +42,6 @@ ("jax.interpreters.xla.xe was removed in JAX v0.4.36. " "Use jax.lib.xla_extension instead."), None ), - # Finalized 2024-05-13; remove after 2024-08-13 - "backend_specific_translations": ( - "jax.interpreters.xla.backend_specific_translations is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "translations": ( - "jax.interpreters.xla.translations is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "register_translation": ( - "jax.interpreters.xla.register_translation is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "xla_destructure": ( - "jax.interpreters.xla.xla_destructure is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "TranslationRule": ( - "jax.interpreters.xla.TranslationRule is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "TranslationContext": ( - "jax.interpreters.xla.TranslationContext is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "XlaOp": ( - "jax.interpreters.xla.XlaOp is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), } from jax._src.deprecations import deprecation_getattr as _deprecation_getattr diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 5f3bfa057912..d2fb6a9bae3c 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -377,16 +377,3 @@ from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p - - -_deprecations = { - # Finalized 2024-05-13; remove after 2024-08-13 - "tie_in": ( - "jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. " - "Replace z = tie_in(x, y) with z = y.", None, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 496d03261384..ebe725c448ee 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -49,17 +49,3 @@ squareplus as squareplus, mish as mish, ) - -# Deprecations - -_deprecations = { - # Finalized 2024-05-13; remove after 2024-08-13 - "normalize": ( - "jax.nn.normalize is deprecated. Use jax.nn.standardize instead.", - None, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 93405cc03ef7..9be73e96adcf 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -472,11 +472,6 @@ "jnp.round_ is deprecated; use jnp.round instead.", round ), - # Deprecated 18 Sept 2023 and removed 06 Feb 2024 - "trapz": ( - "jnp.trapz is deprecated; use jnp.trapezoid instead.", - None - ), } import typing diff --git a/jax/random.py b/jax/random.py index 29a625389811..b99cd531f18c 100644 --- a/jax/random.py +++ b/jax/random.py @@ -251,20 +251,3 @@ weibull_min as weibull_min, wrap_key_data as wrap_key_data, ) - -_deprecations = { - # Finalized Jul 26 2024; remove after Nov 2024. - "shuffle": ( - "jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.", - None, - ) -} - -import typing -if typing.TYPE_CHECKING: - pass -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing