From 52212a779a2852a9d1f7d8cec2854f5959cef262 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 15 Dec 2023 06:02:50 -0800 Subject: [PATCH] [export] Refactor the imports for the public API of jax.experimental.export Previously we used `from jax.experimental.export import export` and `export.export(fun)`. Now we want to add the public API directly to `jax.experimental.export`, for the following desired usage: ``` from jax.experimental import export exp: export.Exported = export.export(fun) ser: bytearray = export.serialize(exp) exp1: export.Exported = export.deserialized(ser) export.call(exp1) ``` This change requires changing the type of `jax.experimental.export.export` from a module to a function. This confuses pytype for the targets with strict type checking, which is why I attempt to make this change atomically throughout the internal code base. In order to support backwards compatibility with OSS packages, this change also includes explicit JAX version checks in several OSS packages, and also adds to the `export` function the attributes that the old export module had. PiperOrigin-RevId: 591228914 --- tf2jax/experimental/mhlo.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tf2jax/experimental/mhlo.py b/tf2jax/experimental/mhlo.py index 4dc7664..ca3895a 100644 --- a/tf2jax/experimental/mhlo.py +++ b/tf2jax/experimental/mhlo.py @@ -127,12 +127,16 @@ def mhlo_apply_abstract_eval( assert has_polymorphic, has_polymorphic if jax.__version_info__ <= (0, 4, 14): from jax.experimental.jax2tf import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error - else: + out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access # pytype: disable=module-attr + elif jax.__version_info__ <= (0, 4, 20): + from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error + out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access # pytype: disable=module-attr + elif jax.__version_info__ <= (0, 4, 23): from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error - if jax.__version_info__ <= (0, 4, 20): - out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access - else: out_shape = shape_poly.symbolic_shape(out_shape, like=res.shape) + else: + from jax.experimental import export # pylint: disable=g-import-not-at-top # pytype: disable=import-error + out_shape = export.symbolic_shape(out_shape, like=res.shape) else: out_shape = res.shape output_specs.append(