Skip to content

Commit

Permalink
[export] Refactor the imports for the public API of jax.experimental.…
Browse files Browse the repository at this point in the history
…export

Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change also includes a workaround to allow users to still
do `from jax.experimental.export import export`, for a while.
Essentially, with the new structure `jax.experimental.export.export`
is the exporting function, while before it was the export module.
We add to the exporting function the attributes that the
old export module had.

This workaround is sufficient for just running the module, but
confuses pytype. Therefore, I am including in this change all
the uses I can find internally. For OSS packages, the workaround
should be sufficient.

PiperOrigin-RevId: 591228914
  • Loading branch information
gnecula authored and TF2JAXDev committed Dec 15, 2023
1 parent ac44167 commit 66355ae
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tf2jax/experimental/mhlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,13 @@ def mhlo_apply_abstract_eval(
assert has_polymorphic, has_polymorphic
if jax.__version_info__ <= (0, 4, 14):
from jax.experimental.jax2tf import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error
else:
out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access
elif jax.__version_info__ <= (0, 4, 20):
from jax.experimental.export import shape_poly # pylint: disable=g-import-not-at-top # pytype: disable=import-error
if jax.__version_info__ <= (0, 4, 20):
out_shape = shape_poly._parse_spec(out_shape, res.shape) # pylint: disable=protected-access
else:
out_shape = shape_poly.symbolic_shape(out_shape, like=res.shape)
from jax.experimental import export # pylint: disable=g-import-not-at-top # pytype: disable=import-error
out_shape = export.symbolic_shape(out_shape, like=res.shape)
else:
out_shape = res.shape
output_specs.append(
Expand Down

0 comments on commit 66355ae

Please sign in to comment.