
Releases: danielward27/flowjax

v17.0.2

19 Dec 09:50
5bcdc93

Updates to the workflow packages led to some errors in the release to PyPI, which were fixed in this release. Below are the changes from v16 to v17, which are of more interest to users.

Breaking Changes

  • The AbstractBijection class now implements the transform and inverse methods by indexing the outputs of transform_and_log_det and inverse_and_log_det. As far as I know, under JIT, dead code elimination prevents the log_det computation where it is unnecessary, so there is little benefit in implementing these methods directly. Custom bijections should generally avoid overriding transform and inverse.
  • The flowjax.wrappers module has been removed. Its functionality is now available in a new, separate package: paramax. Most functionality is unchanged when imported from paramax.
  • Partial has been renamed to Indexed, as the former was likely to be confused with functools.partial or jax.tree_util.Partial.
  • The deprecated fit_to_variational_target function has been removed. Use fit_to_key_based_loss instead.
  • Numerical inverse methods are now provided using the NumericalInverse composition. This means:
    • Users directly calling inverse methods on BlockAutoregressiveNetwork will encounter an error. To resolve this, supply an inverse method via NumericalInverse (as is done in block_neural_autoregressive_flow).
    • Users accessing the attributes of flows created with block_neural_autoregressive_flow may also face breaking changes; an additional .bijection may be required to extract the BlockAutoregressiveNetwork from the NumericalInverse.
  • The uniform distribution is now non-trainable by default. Optimizing a uniform distribution seemed most often to be a mistake, leading to problems such as violated support assumptions.
  • The root-finding algorithms in flowjax.bisection_search have been moved to flowjax.root_finding. Check the updated function names and documentation in the new module if you directly use these methods.
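
The default transform/inverse behaviour and the NumericalInverse composition can be illustrated with a pure-Python sketch. Everything below is hypothetical: the class and function names echo the notes above, but this is not flowjax's implementation (which operates on JAX arrays and provides its own root-finding methods in flowjax.root_finding). Here, transform and inverse simply index the (value, log_det) pairs, and a monotonic bijection without an analytic inverse gets one via bisection.

```python
import math


class AbstractBijection:
    """Hypothetical minimal base class; not flowjax's actual implementation."""

    def transform_and_log_det(self, x):
        raise NotImplementedError

    def inverse_and_log_det(self, y):
        raise NotImplementedError

    # Defaults index the (value, log_det) pairs, so subclasses only need
    # to implement the two *_and_log_det methods.
    def transform(self, x):
        return self.transform_and_log_det(x)[0]

    def inverse(self, y):
        return self.inverse_and_log_det(y)[0]


class SinhBijection(AbstractBijection):
    """Increasing map y = x + sinh(x), which has no closed-form inverse."""

    def transform_and_log_det(self, x):
        return x + math.sinh(x), math.log(1.0 + math.cosh(x))


def bisection_search(fn, target, lower, upper, tol=1e-10, max_iter=200):
    """Find x in [lower, upper] with fn(x) close to target, for increasing fn."""
    for _ in range(max_iter):
        mid = 0.5 * (lower + upper)
        if fn(mid) < target:
            lower = mid  # root lies above mid
        else:
            upper = mid  # root lies at or below mid
        if upper - lower < tol:
            break
    return 0.5 * (lower + upper)


class NumericalInverse(AbstractBijection):
    """Wraps a bijection and supplies its inverse by root finding."""

    def __init__(self, bijection, lower=-50.0, upper=50.0):
        self.bijection = bijection
        self.lower, self.upper = lower, upper

    def transform_and_log_det(self, x):
        return self.bijection.transform_and_log_det(x)

    def inverse_and_log_det(self, y):
        x = bisection_search(self.bijection.transform, y, self.lower, self.upper)
        # log|dx/dy| is the negative of the forward log-det at x.
        return x, -self.bijection.transform_and_log_det(x)[1]


flow = NumericalInverse(SinhBijection())
y = flow.transform(0.5)
x = flow.inverse(y)  # recovers 0.5 up to the bisection tolerance
```

Calling SinhBijection().inverse directly raises NotImplementedError, analogous to calling inverse on BlockAutoregressiveNetwork without the NumericalInverse wrapper.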

Apologies for any inconvenience caused by these breaking changes. If you encounter issues or have questions, please feel free to open an issue.


What's Changed

Full Changelog: v16.0.0...v17.0.0

v16.0.0

17 Oct 10:59
d711b18

What's Changed

Breaking changes:

  • Calls to get_ravelled_pytree_constructor now need to explicitly pass the *args and **kwargs used for partitioning the parameters (usually setting is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable)).
  • fit_to_data now returns a list of floats, rather than a list of scalar arrays.
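
To illustrate what the partitioning arguments control, here is a pure-Python sketch of a ravelled constructor over nested dicts and lists, where a flat list stands in for the ravelled array. It is hypothetical (the real get_ravelled_pytree_constructor operates on JAX pytrees), and it simplifies by keeping fixed anything matched by is_leaf, which here plays the role of the NonTrainable predicate.

```python
class NonTrainable:
    """Hypothetical marker for a leaf that the optimizer should leave fixed."""

    def __init__(self, value):
        self.value = value


def ravelled_constructor_sketch(tree, is_leaf):
    """Return (constructor, flat) for the trainable leaves of a nested dict/list.

    Nodes for which is_leaf(node) is True are kept fixed (a simplification).
    """

    def trainable(node):
        if is_leaf(node):
            return []
        if isinstance(node, dict):
            return [leaf for v in node.values() for leaf in trainable(v)]
        if isinstance(node, list):
            return [leaf for v in node for leaf in trainable(v)]
        return [node]

    def rebuild(node, values):
        if is_leaf(node):
            return node  # fixed node: reused unchanged
        if isinstance(node, dict):
            return {k: rebuild(v, values) for k, v in node.items()}
        if isinstance(node, list):
            return [rebuild(v, values) for v in node]
        return next(values)  # consume the next trainable value, in order

    def constructor(flat):
        return rebuild(tree, iter(flat))

    return constructor, trainable(tree)


params = {"loc": 1.0, "scale": NonTrainable(2.0), "weights": [3.0, 4.0]}
constructor, flat = ravelled_constructor_sketch(
    params, is_leaf=lambda leaf: isinstance(leaf, NonTrainable)
)
# flat holds only the trainable values: [1.0, 3.0, 4.0]
rebuilt = constructor([10.0, 30.0, 40.0])
```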

Note: fit_to_variational_target will be deprecated in the next version. This version adds its replacement, fit_to_key_based_loss. The main motivation was some "bad" defaults, e.g. steps: int = 100 and return_best=True (see #188 for details). The new name is also more general: it can be used to fit any pytree, and the loss does not have to be a variational inference loss function.
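
A key-based loss takes the parameters and a random key, so a fresh key can be drawn at every step. The pure-Python sketch below shows such a loop with a return_best option; fit_to_key_based_loss_sketch and noisy_loss are hypothetical, not flowjax's API. It also illustrates one plausible concern with return_best=True for stochastic losses (an assumption on my part, not taken from #188): the "best" parameters are chosen from single noisy loss estimates, so a low value may partly reflect a lucky key.

```python
import random


def noisy_loss(params, key):
    """Hypothetical stochastic loss: one-sample estimate of E[(params - z)^2], z ~ N(3, 1).

    Returns the loss value and its gradient with respect to params.
    """
    z = random.Random(key).gauss(3.0, 1.0)
    return (params - z) ** 2, 2.0 * (params - z)


def fit_to_key_based_loss_sketch(
    params, loss_fn, key, steps, learning_rate=0.05, return_best=False
):
    """Plain gradient descent on a loss that takes (params, key)."""
    rng = random.Random(key)
    losses = []
    best_params, best_loss = params, float("inf")
    for _ in range(steps):
        step_key = rng.getrandbits(32)  # fresh key at every step
        loss, grad = loss_fn(params, step_key)
        losses.append(loss)
        if loss < best_loss:  # judged from a single noisy estimate
            best_params, best_loss = params, loss
        params = params - learning_rate * grad
    return (best_params if return_best else params), losses


final_params, losses = fit_to_key_based_loss_sketch(0.0, noisy_loss, key=0, steps=2000)
```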

Full Changelog: v15.1.0...v16.0.0

v15.1.0

07 Oct 12:17

What's Changed

Full Changelog: v15.0.0...v15.1.0

v15.0.0

17 Sep 20:00
73163c9

Breaking changes:

  • Users must switch from old-style PRNGKey arrays to new-style typed keys (replacing jax.random.PRNGKey with jax.random.key). The old keys will be deprecated in JAX.
  • LogNormal is now an Exp-transformed Normal. This is an implementation detail, unless you previously relied on log_normal.base_dist or log_normal.bijection.
  • recursive_unwrap was removed, as on its own it provided no additional functionality.
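
Representing LogNormal as an Exp-transformed Normal relies on the change-of-variables formula: for Y = exp(X), log p_Y(y) = log p_X(log y) - log y. A minimal pure-Python check of this identity against the closed-form log-normal density (the function names are hypothetical, not flowjax's API):

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def exp_transformed_normal_log_prob(y, loc=0.0, scale=1.0):
    """Log density of Y = exp(X), X ~ Normal(loc, scale), via change of variables."""
    x = math.log(y)  # inverse of the Exp bijection
    inverse_log_det = -x  # log|dx/dy| = log(1 / y) = -log(y)
    return normal_log_prob(x, loc, scale) + inverse_log_det


def log_normal_log_prob(y, loc=0.0, scale=1.0):
    """Closed-form log-normal log density, for comparison."""
    return (
        -math.log(y)
        - math.log(scale)
        - 0.5 * math.log(2.0 * math.pi)
        - (math.log(y) - loc) ** 2 / (2.0 * scale**2)
    )
```

The two functions agree up to floating-point rounding for any y > 0.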

What's Changed

Full Changelog: v14.0.0...v15.0.0

v14.0.0

02 Sep 17:40
b8eb028

What's Changed

I was not happy with the original implementation of some wrappers in flowjax.wrappers; e.g. BijectionReparam introduced some (avoidable, but untidy) circular dependency issues. We now primarily use Parameterize, a new name for what was Lambda. This introduces some breaking changes: primarily the aforementioned renaming, in addition to the removal of BijectionReparam and Where. See the pull request for more information.
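
The idea behind Parameterize (formerly Lambda) is to store a function together with its arguments and only apply it when the model is unwrapped, e.g. to keep a scale parameter positive by storing its logarithm. The sketch below is a hypothetical pure-Python stand-in operating on nested dicts, lists and tuples; the real wrapper (now in paramax) works on JAX pytrees, and its internals may differ.

```python
import math


class Parameterize:
    """Hypothetical wrapper: stores fn with its arguments; fn is applied on unwrap."""

    def __init__(self, fn, *args, **kwargs):
        self.fn, self.args, self.kwargs = fn, args, kwargs

    def apply(self):
        # Arguments may themselves contain wrappers, so unwrap them first.
        return self.fn(*unwrap(self.args), **unwrap(self.kwargs))


def unwrap(tree):
    """Recursively replace each Parameterize in nested dicts/lists/tuples with its value."""
    if isinstance(tree, Parameterize):
        return tree.apply()
    if isinstance(tree, dict):
        return {k: unwrap(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(unwrap(v) for v in tree)
    return tree


# Store the log of a scale parameter, so it can be updated freely
# while the unwrapped scale stays positive.
model = {"loc": 0.0, "scale": Parameterize(math.exp, math.log(2.0))}
concrete = unwrap(model)  # {"loc": 0.0, "scale": ~2.0}
```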

Full Changelog: v13.1.1...v14.0.0

v13.1.1

29 Aug 13:24
157f867

What's Changed

Full Changelog: v13.1.0...v13.1.1

v13.1.0

29 Aug 10:08
a4bcab8

What's Changed

Full Changelog: v13.0.1...v13.1.0

Thanks for the help @weiyaw!

v13.0.1

24 Jul 14:52
91b1c39

New release to trigger an update to the docs.

What's Changed

v13.0.0

24 Jul 14:22
eeb3481

Some small breaking changes:

  • The interval attribute of RationalQuadraticSpline is now a two-tuple with type tuple[float | int, float | int], representing the lower and upper bounds of the spline. Previously, the interval was a float | int, forcing the interval to be symmetric about 0.
  • Any custom loss functions used with fit_to_data should now accept a key (whether or not it is used), with signature loss(params, static, x, condition, key), for consistency of the API.
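
For custom losses, the new signature means a key argument must be accepted even if it is ignored. Below is a hypothetical Gaussian negative log-likelihood loss in plain Python, where params and static are simple dicts standing in for the trainable and non-trainable parts of the model (in flowjax these are presumably the partitioned model, not dicts), together with a small hypothetical helper for converting an old symmetric interval value to the new tuple form.

```python
import math


def gaussian_nll_loss(params, static, x, condition, key):
    """Hypothetical custom loss using the (params, static, x, condition, key) signature.

    condition and key are accepted but unused: this toy loss is unconditional
    and deterministic.
    """
    loc, scale = params["loc"], static["scale"]
    log_probs = [
        -0.5 * ((xi - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2.0 * math.pi)
        for xi in x
    ]
    return -sum(log_probs) / len(log_probs)


def symmetric_interval(bound):
    """Hypothetical helper: convert an old symmetric interval value to a (lower, upper) tuple."""
    return (-bound, bound)


loss = gaussian_nll_loss({"loc": 2.0}, {"scale": 1.0}, [1.0, 2.0, 3.0], None, None)
interval = symmetric_interval(4)  # (-4, 4)
```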

What's Changed

Full Changelog: v12.4.0...v13.0.0

v12.4.0

21 Jun 17:29
19c7355

What's Changed

Full Changelog: v12.3.0...v12.4.0