Skip to content

Commit

Permalink
[export] Refactor the imports for the public API of jax.experimental.…
Browse files Browse the repository at this point in the history
…export

Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 591228914
  • Loading branch information
gnecula authored and TF2JAXDev committed Dec 18, 2023
1 parent ac44167 commit f7abd48
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tf2jax/experimental/mhlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,16 @@ def mhlo_apply_abstract_eval(
assert has_polymorphic, has_polymorphic
if jax.__version_info__ <= (0, 4, 14):
from jax.experimental.jax2tf import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error
else:
out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access # pytype: disable=module-attr
elif jax.__version_info__ <= (0, 4, 20):
from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error
out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access # pytype: disable=module-attr
elif jax.__version_info__ <= (0, 4, 23):
from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error
if jax.__version_info__ <= (0, 4, 20):
out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access
else:
out_shape = shape_poly.symbolic_shape(out_shape, like=res.shape)
else:
from jax.experimental import export # pylint: disable=g-import-not-at-top # pytype: disable=import-error
out_shape = export.symbolic_shape(out_shape, like=res.shape)
else:
out_shape = res.shape
output_specs.append(
Expand Down

0 comments on commit f7abd48

Please sign in to comment.