Releases: danielward27/flowjax
v17.0.2
Updates to the workflow packages caused some errors in the PyPI release, which this release fixes. Below are the changes from v16 to v17, which are of more interest to users.
Breaking Changes
- The `AbstractBijection` class now implements `transform` and `inverse` methods by indexing the outputs of `transform_and_log_det` and `inverse_and_log_det`. As far as I know, under JIT, dead code elimination prevents computation of `log_det` where it is unnecessary, minimizing the benefits of implementing these methods directly. Custom bijections should generally avoid overriding `transform` and `inverse`.
- The `flowjax.wrappers` module has been removed. Its functionality is now available in a new, separate package: `paramax`. Most functionality is unchanged when imported from `paramax`. `Partial` has been renamed to `Indexed`, as the former was likely to be confused with `functools.partial` or `jax.tree_util.Partial`.
- The deprecated `fit_to_variational_target` function has been removed. Use `fit_to_key_based_loss` instead.
- Numerical inverse methods are now provided using the `NumericalInverse` composition. This means:
  - Users directly calling inverse methods on `BlockAutoregressiveNetwork` will encounter an error. To resolve this, supply an inverse method via `NumericalInverse` (as is done in `block_neural_autoregressive_flow`).
  - Users accessing the attributes of `block_neural_autoregressive_flow`s might also face a breaking change; an additional `.bijection` may be required to extract the `BlockAutoregressiveNetwork` from the `NumericalInverse`.
- The uniform distribution is now non-trainable by default. Optimization of a uniform distribution seemed most commonly a mistake leading to e.g. violating support assumptions.
- The root-finding algorithms in `flowjax.bisection_search` have been moved to `flowjax.root_finding`. Check the updated function names and documentation in the new module if you use these methods directly.
Apologies for any inconvenience caused by these breaking changes. If you encounter issues or have questions, please feel free to open an issue.
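To illustrate the indexing pattern, here is a minimal self-contained sketch in plain Python (the `Bijection` and `Scale` classes below are hypothetical stand-ins for illustration, not flowjax's actual `AbstractBijection`):

```python
import math
from abc import ABC, abstractmethod


class Bijection(ABC):
    """Sketch: derive transform/inverse by indexing the *_and_log_det outputs."""

    @abstractmethod
    def transform_and_log_det(self, x):
        """Return (y, log_det)."""

    @abstractmethod
    def inverse_and_log_det(self, y):
        """Return (x, log_det)."""

    def transform(self, x):
        # Under JIT, dead code elimination should remove the unused
        # log_det computation, so this costs little in practice.
        return self.transform_and_log_det(x)[0]

    def inverse(self, y):
        return self.inverse_and_log_det(y)[0]


class Scale(Bijection):
    """Toy bijection y = scale * x."""

    def __init__(self, scale):
        self.scale = scale

    def transform_and_log_det(self, x):
        return x * self.scale, math.log(abs(self.scale))

    def inverse_and_log_det(self, y):
        return y / self.scale, -math.log(abs(self.scale))
```

As in v17, only the `*_and_log_det` pair needs implementing; `transform` and `inverse` then come for free.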
What's Changed
- Update docs by @danielward27 in #190
- Paramax by @danielward27 in #192
- Provide transform and inverse method implementations on AbstractBijection by @danielward27 in #193
- Root finding by @danielward27 in #195
- Rename partial and rm fit to variational by @danielward27 in #196
- Root finding by @danielward27 in #198
- Make uniform non-trainable by @danielward27 in #199
Full Changelog: v16.0.0...v17.0.0
v16.0.0
What's Changed
- Return dict of list of floats or list of floats in fitting functions by @danielward27 in #185
- No need for static by @danielward27 in #186
- Support old style keys in numpyro wrappers by @danielward27 in #187
- key_based_loss update by @danielward27 in #188
- Reduce coupling by @danielward27 in #189
Breaking changes:
- Calls to `get_ravelled_pytree_constructor` will now need to explicitly pass the `*args` and `**kwargs` for partitioning parameters (usually setting `is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable)`).
- `fit_to_data` now returns a list of floats, rather than a list of scalar arrays.
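The role of the `is_leaf` predicate can be sketched with a toy partition over nested dicts (the `NonTrainable` wrapper and `partition` helper below are simplified stand-ins, not flowjax's actual implementations):

```python
class NonTrainable:
    """Simplified stand-in marking a value as static during optimization."""

    def __init__(self, value):
        self.value = value


def partition(tree, is_leaf):
    """Split a nested dict into (trainable, static) parts, stopping
    recursion at any node for which is_leaf returns True."""
    trainable, static = {}, {}
    for name, node in tree.items():
        if is_leaf(node):
            static[name] = node
        elif isinstance(node, dict):
            trainable[name], static[name] = partition(node, is_leaf)
        else:
            trainable[name] = node
    return trainable, static


params = {"loc": 0.0, "scale": NonTrainable(1.0)}
trainable, static = partition(
    params, is_leaf=lambda leaf: isinstance(leaf, NonTrainable)
)
```

The predicate decides where partitioning stops, which is why it must now be passed explicitly.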
Note, `fit_to_variational_target` will be deprecated in the next version. This version adds its replacement, `fit_to_key_based_loss`. The change was primarily motivated by some poor defaults, e.g. `steps: int = 100` and `return_best=True` (see #188 for details). The new name is also more general, as the function can be used to fit any pytree and doesn't have to be used with a variational inference loss function.
Full Changelog: v15.1.0...v16.0.0
v15.1.0
What's Changed
- Add gamma and power transform by @danielward27 in #182
- Update docs by @danielward27 in #183
- Add Beta distribution and Sigmoid transform by @danielward27 in #184
Full Changelog: v15.0.0...v15.1.0
v15.0.0
Breaking changes:
- Users must switch from old-style PRNG key arrays to new-style ones (replacing `jax.random.PRNGKey` with `jax.random.key`). The old keys will be deprecated in JAX.
- `LogNormal` is now an `Exp`-transformed `Normal`. This is an implementation detail, unless you previously relied on `log_normal.base_dist` or `log_normal.bijection`.
- `recursive_unwrap` was removed, as it did not provide additional functionality on its own.
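For reference, a minimal sketch of the key switch (assuming a recent JAX version where `jax.random.key` is available):

```python
import jax

# New style: a typed PRNG key (replaces jax.random.PRNGKey).
key = jax.random.key(0)
subkey, key = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

# Old style, to be deprecated in JAX:
#   key = jax.random.PRNGKey(0)  # a raw uint32[2] array rather than a typed key
```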
What's Changed
- Avoid array in AutoregressiveBisectionInverter by @danielward27 in #174
- Simplify wrappers by @danielward27 in #175
- Docs and switch to jr.key by @danielward27 in #177
- Documentation example fixes by @danielward27 in #178
- Update README.md by @danielward27 in #179
Full Changelog: v14.0.0...v15.0.0
v14.0.0
What's Changed
I was not happy with the original implementation of some wrappers in `flowjax.wrappers`; e.g., `BijectionReparam` introduced some (avoidable, but untidy) circular dependency issues. We now primarily use `Parameterize`, a new name for what was `Lambda`. This introduces some breaking changes, primarily the aforementioned renaming, in addition to the removal of `BijectionReparam` and `Where`. See the pull request for more information.
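The core idea of `Parameterize` can be sketched as a wrapper that stores a function with its arguments and applies it on unwrapping (a simplified sketch for illustration, not paramax's actual class):

```python
import math


class Parameterize:
    """Simplified sketch: defer fn(*args) until unwrap() is called.

    Useful for e.g. storing an unconstrained parameter that is mapped
    to a constrained space (here, a positive scale) when unwrapped.
    """

    def __init__(self, fn, *args):
        self.fn = fn
        self.args = args

    def unwrap(self):
        return self.fn(*self.args)


scale = Parameterize(math.exp, 0.0)  # unconstrained 0.0 -> positive scale
```

Because the wrapper is just a function plus its arguments, it avoids the circular dependencies that a bijection-specific reparameterization wrapper can introduce.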
- Update wrappers by @danielward27 in #172
Full Changelog: v13.1.1...v14.0.0
v13.1.1
What's Changed
- Test for updated equinox errors by @danielward27 in #171
Full Changelog: v13.1.0...v13.1.1
v13.1.0
What's Changed
- Planar inverse with leaky relu by @danielward27 in #170.
Full Changelog: v13.0.1...v13.1.0
Thanks for the help @weiyaw!
v13.0.1
v13.0.0
Some small breaking changes:
- The interval attribute of `RationalQuadraticSpline` is now a two-tuple with type `tuple[float | int, float | int]`, representing the lower and upper bounds of the spline. Previously, the interval was a `float | int`, forcing the interval to be symmetric about 0.
- Any custom loss functions used with `fit_to_data` should now accept a key (whether or not it is used), having the signature `loss(params, static, x, condition, key)`, for consistency of the API.
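A custom loss matching the required signature might look like the following sketch (the loss body is a toy placeholder, not flowjax's actual maximum likelihood loss):

```python
def mean_squared_loss(params, static, x, condition, key):
    """Toy loss with the five-argument signature fit_to_data expects.

    The key argument must be accepted even when it is unused.
    """
    del static, condition, key  # unused in this toy example
    mean = params["mean"]
    return sum((xi - mean) ** 2 for xi in x) / len(x)


value = mean_squared_loss({"mean": 0.0}, None, [1.0, -1.0], None, key=None)
```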
What's Changed
- Allow uneven interval in spline by @danielward27 in #166
- Sample contrastive by @danielward27 in #167
Full Changelog: v12.4.0...v13.0.0