Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[export] Refactor the imports for the public API of jax.experimental.…
…export Previously we used `from jax.experimental.export import export` and `export.export(fun)`. Now we want to add the public API directly to `jax.experimental.export`, for the following desired usage: ``` from jax.experimental import export exp: export.Exported = export.export(fun) ser: bytearray = export.serialize(exp) exp1: export.Exported = export.deserialized(ser) export.call(exp1) ``` This change also includes a workaround to allow users to still do `from jax.experimental.export import export`, for a while. Essentially, with the new structure `jax.experimental.export.export` is the exporting function, while before it was the export module. We add to the exporting function the attributes that the old export module had. This workaround is sufficient for just running the module, but confuses pytype. Therefore, I am including in this change all the uses I can find internally. For OSS packages, the workaround should be sufficient. PiperOrigin-RevId: 591228914
- Loading branch information