From 66355ae5146261a2ecaeacefc88a70518fc9ac8f Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 15 Dec 2023 06:02:50 -0800 Subject: [PATCH] [export] Refactor the imports for the public API of jax.experimental.export Previously we used `from jax.experimental.export import export` and `export.export(fun)`. Now we want to add the public API directly to `jax.experimental.export`, for the following desired usage: ``` from jax.experimental import export exp: export.Exported = export.export(fun) ser: bytearray = export.serialize(exp) exp1: export.Exported = export.deserialized(ser) export.call(exp1) ``` This change also includes a workaround to allow users to still do `from jax.experimental.export import export`, for a while. Essentially, with the new structure `jax.experimental.export.export` is the exporting function, while before it was the export module. We add to the exporting function the attributes that the old export module had. This workaround is sufficient for just running the module, but confuses pytype. Therefore, I am including in this change all the uses I can find internally. For OSS packages, the workaround should be sufficient. PiperOrigin-RevId: 591228914 --- tf2jax/experimental/mhlo.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tf2jax/experimental/mhlo.py b/tf2jax/experimental/mhlo.py index 4dc7664..aedd0bd 100644 --- a/tf2jax/experimental/mhlo.py +++ b/tf2jax/experimental/mhlo.py @@ -127,12 +127,13 @@ def mhlo_apply_abstract_eval( assert has_polymorphic, has_polymorphic if jax.__version_info__ <= (0, 4, 14): from jax.experimental.jax2tf import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error - else: + out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access + elif jax.__version_info__ <= (0, 4, 20): from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error - if jax.__version_info__ <= (0, 4, 20): out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access else: - out_shape = shape_poly.symbolic_shape(out_shape, like=res.shape) + from jax.experimental import export # pylint: disable=g-import-not-at-top # pytype: disable=import-error + out_shape = export.symbolic_shape(out_shape, like=res.shape) else: out_shape = res.shape output_specs.append(