diff --git a/flowjax/bijections/affine.py b/flowjax/bijections/affine.py index b676c025..cb46bdbd 100644 --- a/flowjax/bijections/affine.py +++ b/flowjax/bijections/affine.py @@ -4,13 +4,13 @@ from typing import ClassVar import jax.numpy as jnp +from jax.nn import softplus from jax.scipy.linalg import solve_triangular from jaxtyping import Array, ArrayLike, Shaped -from flowjax import wrappers from flowjax.bijections.bijection import AbstractBijection -from flowjax.bijections.softplus import SoftPlus -from flowjax.utils import arraylike_to_array +from flowjax.utils import arraylike_to_array, inv_softplus +from flowjax.wrappers import AbstractUnwrappable, Parameterize, unwrap class Affine(AbstractBijection): @@ -29,7 +29,7 @@ class Affine(AbstractBijection): shape: tuple[int, ...] cond_shape: ClassVar[None] = None loc: Array - scale: Array | wrappers.AbstractUnwrappable[Array] + scale: Array | AbstractUnwrappable[Array] def __init__( self, @@ -40,7 +40,7 @@ def __init__( *(arraylike_to_array(a, dtype=float) for a in (loc, scale)), ) self.shape = scale.shape - self.scale = wrappers.BijectionReparam(scale, SoftPlus()) + self.scale = Parameterize(softplus, inv_softplus(scale)) def transform(self, x, condition=None): return x * self.scale + self.loc @@ -92,15 +92,15 @@ class Scale(AbstractBijection): shape: tuple[int, ...] cond_shape: ClassVar[None] = None - scale: Array | wrappers.AbstractUnwrappable[Array] + scale: Array | AbstractUnwrappable[Array] def __init__( self, scale: ArrayLike, ): scale = arraylike_to_array(scale, "scale", dtype=float) - self.scale = wrappers.BijectionReparam(scale, SoftPlus()) - self.shape = jnp.shape(wrappers.unwrap(scale)) + self.scale = Parameterize(softplus, inv_softplus(scale)) + self.shape = jnp.shape(unwrap(scale)) def transform(self, x, condition=None): return x * self.scale @@ -120,7 +120,7 @@ class TriangularAffine(AbstractBijection): Transformation has the form :math:`Ax + b`, where :math:`A` is a lower or upper triangular matrix, and :math:`b` is the bias vector. We assume the diagonal - entries are positive, and constrain the values using SoftPlus. Other + entries are positive, and constrain the values using softplus. Other parameterizations can be achieved by e.g. replacing ``self.triangular`` after construction. @@ -135,7 +135,7 @@ class TriangularAffine(AbstractBijection): shape: tuple[int, ...] 
cond_shape: ClassVar[None] = None loc: Array - triangular: Array | wrappers.AbstractUnwrappable[Array] + triangular: Array | AbstractUnwrappable[Array] lower: bool def __init__( @@ -152,12 +152,12 @@ def __init__( raise ValueError("arr must be a square, 2-dimensional matrix.") dim = arr.shape[0] - def _to_triangular(diag, arr): - tri = jnp.tril(arr, k=-1) if lower else jnp.triu(arr, k=1) - return jnp.diag(diag) + tri + def _to_triangular(arr): + tri = jnp.tril(arr) if lower else jnp.triu(arr) + return jnp.fill_diagonal(tri, softplus(jnp.diag(tri)), inplace=False) - diag = wrappers.BijectionReparam(jnp.diag(arr), SoftPlus()) - self.triangular = wrappers.Lambda(_to_triangular, diag=diag, arr=arr) + arr = jnp.fill_diagonal(arr, inv_softplus(jnp.diag(arr)), inplace=False) + self.triangular = Parameterize(_to_triangular, arr) self.lower = lower self.shape = (dim,) self.loc = jnp.broadcast_to(loc, (dim,)) diff --git a/flowjax/bijections/block_autoregressive_network.py b/flowjax/bijections/block_autoregressive_network.py index 9cb6ad14..ab0a8f4a 100644 --- a/flowjax/bijections/block_autoregressive_network.py +++ b/flowjax/bijections/block_autoregressive_network.py @@ -9,14 +9,14 @@ import jax.numpy as jnp import jax.random as jr from jax import random +from jax.nn import softplus from jaxtyping import PRNGKeyArray from flowjax import masks from flowjax.bijections.bijection import AbstractBijection -from flowjax.bijections.softplus import SoftPlus from flowjax.bijections.tanh import LeakyTanh from flowjax.bisection_search import AutoregressiveBisectionInverter -from flowjax.wrappers import BijectionReparam, WeightNormalization, Where +from flowjax.wrappers import Parameterize, WeightNormalization class _CallableToBijection(AbstractBijection): @@ -219,13 +219,11 @@ def block_autoregressive_linear( block_diag_mask = masks.block_diag_mask(block_shape, n_blocks) block_tril_mask = masks.block_tril_mask(block_shape, n_blocks) - weight = Where(block_tril_mask, linear.weight, 0) - weight = Where( - block_diag_mask, - BijectionReparam(weight, SoftPlus(), invert_on_init=False), - weight, - ) - weight = WeightNormalization(weight) + def apply_mask(weight): + weight = jnp.where(block_tril_mask, weight, 0) + return jnp.where(block_diag_mask, softplus(weight), weight) + + weight = WeightNormalization(Parameterize(apply_mask, linear.weight)) linear = eqx.tree_at(lambda linear: linear.weight, linear, replace=weight) def linear_to_log_block_diagonal(linear: eqx.nn.Linear): diff --git a/flowjax/bijections/masked_autoregressive.py b/flowjax/bijections/masked_autoregressive.py index d3ca06eb..2f1fca82 100644 --- a/flowjax/bijections/masked_autoregressive.py +++ b/flowjax/bijections/masked_autoregressive.py @@ -13,7 +13,7 @@ from flowjax.bijections.jax_transforms import Vmap from flowjax.masks import rank_based_mask from flowjax.utils import get_ravelled_pytree_constructor -from flowjax.wrappers import Where +from flowjax.wrappers import Parameterize class MaskedAutoregressive(AbstractBijection): @@ -135,7 +135,7 @@ def masked_autoregressive_mlp( ) -> eqx.nn.MLP: """Returns an equinox multilayer perceptron, with autoregressive masks. - The weight matrices are wrapped using :class:`~flowjax.wrappers.Where`, which + The weight matrices are wrapped using :class:`~flowjax.wrappers.Parameterize`, which will apply the masking when :class:`~flowjax.wrappers.unwrap` is called on the MLP. For details of how the masks are formed, see https://arxiv.org/pdf/1502.03509.pdf. 
@@ -160,7 +160,9 @@ def masked_autoregressive_mlp( for i, linear in enumerate(mlp.layers): mask = rank_based_mask(ranks[i], ranks[i + 1], eq=i != len(mlp.layers) - 1) masked_linear = eqx.tree_at( - lambda linear: linear.weight, linear, Where(mask, linear.weight, 0) + lambda linear: linear.weight, + linear, + Parameterize(jnp.where, mask, linear.weight, 0), ) masked_layers.append(masked_linear) diff --git a/flowjax/bijections/rational_quadratic_spline.py b/flowjax/bijections/rational_quadratic_spline.py index f7d71423..59d55494 100644 --- a/flowjax/bijections/rational_quadratic_spline.py +++ b/flowjax/bijections/rational_quadratic_spline.py @@ -7,8 +7,9 @@ import jax.numpy as jnp from jaxtyping import Array, Float -from flowjax import wrappers from flowjax.bijections.bijection import AbstractBijection +from flowjax.utils import inv_softplus +from flowjax.wrappers import AbstractUnwrappable, Parameterize def _real_to_increasing_on_interval( @@ -62,9 +63,9 @@ class RationalQuadraticSpline(AbstractBijection): interval: tuple[int | float, int | float] softmax_adjust: float | int min_derivative: float - x_pos: Array | wrappers.AbstractUnwrappable[Array] - y_pos: Array | wrappers.AbstractUnwrappable[Array] - derivatives: Array | wrappers.AbstractUnwrappable[Array] + x_pos: Array | AbstractUnwrappable[Array] + y_pos: Array | AbstractUnwrappable[Array] + derivatives: Array | AbstractUnwrappable[Array] shape: ClassVar[tuple] = () cond_shape: ClassVar[None] = None @@ -89,11 +90,11 @@ def __init__( softmax_adjust=softmax_adjust, ) - self.x_pos = wrappers.Lambda(pos_parameterization, jnp.zeros(knots)) - self.y_pos = wrappers.Lambda(pos_parameterization, jnp.zeros(knots)) - self.derivatives = wrappers.Lambda( + self.x_pos = Parameterize(pos_parameterization, jnp.zeros(knots)) + self.y_pos = Parameterize(pos_parameterization, jnp.zeros(knots)) + self.derivatives = Parameterize( lambda arr: jax.nn.softplus(arr) + self.min_derivative, - jnp.full(knots + 2, jnp.log(jnp.exp(1 - min_derivative) - 1)), + jnp.full(knots + 2, inv_softplus(1 - min_derivative)), ) def transform(self, x, condition=None): diff --git a/flowjax/distributions.py b/flowjax/distributions.py index b7c24405..f3750e43 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -11,7 +11,7 @@ import jax.numpy as jnp import jax.random as jr from equinox import AbstractVar -from jax.nn import log_softmax +from jax.nn import log_softmax, softplus from jax.numpy import linalg from jax.scipy import stats as jstats from jax.scipy.special import logsumexp @@ -24,15 +24,15 @@ Chain, Exp, Scale, - SoftPlus, TriangularAffine, ) from flowjax.utils import ( _get_ufunc_signature, arraylike_to_array, + inv_softplus, merge_cond_shapes, ) -from flowjax.wrappers import AbstractUnwrappable, BijectionReparam, Lambda, unwrap +from flowjax.wrappers import AbstractUnwrappable, Parameterize, unwrap class AbstractDistribution(eqx.Module): @@ -609,7 +609,7 @@ def __init__(self, df: ArrayLike): df = arraylike_to_array(df, dtype=float) df = eqx.error_if(df, df <= 0, "Degrees of freedom values must be positive.") self.shape = jnp.shape(df) - self.df = BijectionReparam(df, SoftPlus()) + self.df = Parameterize(softplus, inv_softplus(df)) def _log_prob(self, x, condition=None): return jstats.t.logpdf(x, df=self.df).sum() @@ -761,7 +761,7 @@ def __init__( ): weights = eqx.error_if(weights, weights <= 0, "Weights must be positive.") self.dist = dist - self.log_normalized_weights = Lambda(lambda w: log_softmax(w), jnp.log(weights)) + self.log_normalized_weights = 
Parameterize(log_softmax, jnp.log(weights)) self.shape = dist.shape self.cond_shape = dist.cond_shape diff --git a/flowjax/flows.py b/flowjax/flows.py index 2c8a968a..6423e617 100644 --- a/flowjax/flows.py +++ b/flowjax/flows.py @@ -14,6 +14,7 @@ import jax.numpy as jnp import jax.random as jr from equinox.nn import Linear +from jax.nn import softplus from jax.nn.initializers import glorot_uniform from jaxtyping import PRNGKeyArray @@ -27,27 +28,22 @@ Flip, Invert, LeakyTanh, - Loc, MaskedAutoregressive, Permute, Planar, RationalQuadraticSpline, Scan, - SoftPlus, TriangularAffine, Vmap, ) from flowjax.distributions import AbstractDistribution, Transformed -from flowjax.wrappers import BijectionReparam, WeightNormalization, non_trainable +from flowjax.utils import inv_softplus +from flowjax.wrappers import Parameterize, WeightNormalization def _affine_with_min_scale(min_scale: float = 1e-2) -> Affine: - scale_reparam = Chain([SoftPlus(), non_trainable(Loc(min_scale))]) - return eqx.tree_at( - where=lambda aff: aff.scale, - pytree=Affine(), - replace=BijectionReparam(jnp.array(1), scale_reparam), - ) + scale = Parameterize(lambda x: softplus(x) + min_scale, inv_softplus(1 - min_scale)) + return eqx.tree_at(where=lambda aff: aff.scale, pytree=Affine(), replace=scale) def coupling_flow( diff --git a/flowjax/train/variational_fit.py b/flowjax/train/variational_fit.py index adc02333..cd22f440 100644 --- a/flowjax/train/variational_fit.py +++ b/flowjax/train/variational_fit.py @@ -61,7 +61,7 @@ def fit_to_variational_target( params, opt_state, loss = step( params, static, - key, + key=key, optimizer=optimizer, opt_state=opt_state, loss_fn=loss_fn, diff --git a/flowjax/utils.py b/flowjax/utils.py index 716218e8..2cf48290 100644 --- a/flowjax/utils.py +++ b/flowjax/utils.py @@ -10,6 +10,18 @@ import flowjax +def inv_softplus(x: ArrayLike) -> Array: + """The inverse of the softplus function, checking for positive inputs.""" + x = eqx.error_if( + x, + x < 0, + "Expected positive inputs to inv_softplus. If you are trying to use a negative " + "scale parameter, consider constructing with positive scales and modifying the " + "scale attribute post-construction, e.g., using eqx.tree_at.", + ) + return jnp.log(-jnp.expm1(-x)) + x + + def merge_cond_shapes(shapes: Sequence[tuple[int, ...] | None]): """Merges shapes (tuples of ints or None) used in bijections and distributions. diff --git a/flowjax/wrappers.py b/flowjax/wrappers.py index 809b6cd7..818f26cc 100644 --- a/flowjax/wrappers.py +++ b/flowjax/wrappers.py @@ -25,20 +25,16 @@ from abc import abstractmethod from collections.abc import Callable, Iterable -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar +from typing import Any, ClassVar, Generic, TypeVar import equinox as eqx -import jax.numpy as jnp -from jaxtyping import Int, Scalar - -from flowjax.utils import arraylike_to_array - -if TYPE_CHECKING: - from flowjax.bijections import AbstractBijection - import jax +import jax.numpy as jnp from jax import lax -from jaxtyping import Array, ArrayLike, PyTree +from jax.nn import softplus +from jaxtyping import Array, Int, PyTree, Scalar + +from flowjax.utils import inv_softplus T = TypeVar("T") @@ -134,68 +130,33 @@ def _map_fn(leaf): ) -def _apply_inverse_and_check_valid(bijection, arr): - param_inv = bijection._vectorize.inverse(arr) - return eqx.error_if( - param_inv, - jnp.logical_and(jnp.isfinite(arr), ~jnp.isfinite(param_inv)), - "Non-finite value(s) introduced when reparameterizing. 
This suggests " - "the parameter vector passed to BijectionReparam was incompatible with " - f"the bijection used for reparameterizing ({type(bijection).__name__}).", - ) - - -class BijectionReparam(AbstractUnwrappable[Array]): - """Reparameterize a parameter using a bijection. +class Parameterize(AbstractUnwrappable[T]): + """Unwrap an object by calling fn with args and kwargs. - When applying unwrap, ``bijection.transform`` is applied. By default, the inverse - of the bijection is applied when setting the parameter values. + All of fn, args and kwargs may contain trainable parameters. If the Parameterize is + created within ``eqx.filter_vmap``, unwrapping is automatically vectorized + correctly, as long as the vmapped constructor adds leading batch + dimensions to all arrays (the default for ``eqx.filter_vmap``). Args: - arr: The parameter to reparameterize. If invert_on_init is False, then this can - be a ``AbstractUnwrappable[Array]``. - bijection: A bijection whose shape is broadcastable to ``jnp.shape(arr)``. - invert_on_init: Whether to apply the inverse transformation when initializing. - Defaults to True. + fn: Callable to call with args, and kwargs. + *args: Positional arguments to pass to fn. + **kwargs: Keyword arguments to pass to fn. """ - arr: Array | AbstractUnwrappable[Array] - bijection: "AbstractBijection" + fn: Callable[..., T] + args: Iterable + kwargs: dict[str, Any] _dummy: Int[Scalar, ""] - def __init__( - self, - arr: Array | AbstractUnwrappable[Array], - bijection: "AbstractBijection", - *, - invert_on_init: bool = True, - ): - if invert_on_init: - self.arr = _apply_inverse_and_check_valid(bijection, arr) - else: - if not isinstance(arr, AbstractUnwrappable): - arr = arraylike_to_array(arr) - self.arr = arr - self.bijection = bijection + def __init__(self, fn: Callable, *args, **kwargs): + self.fn = fn + self.args = args + self.kwargs = kwargs self._dummy = jnp.empty((), int) - def unwrap(self) -> Array: - return self.bijection._vectorize.transform(self.arr) - - -class Where(AbstractUnwrappable[Array]): - """Applies jnp.where upon unwrapping. - - This can be used to construct masks by setting ``cond=mask`` and ``if_false=0``. - """ - - cond: ArrayLike - if_true: ArrayLike | AbstractUnwrappable[Array] - if_false: ArrayLike | AbstractUnwrappable[Array] - _dummy: ClassVar[None] = None - - def unwrap(self): - return jnp.where(self.cond, self.if_true, self.if_false) + def unwrap(self) -> T: + return self.fn(*self.args, **self.kwargs) class WeightNormalization(AbstractUnwrappable[Array]): @@ -210,40 +171,10 @@ class WeightNormalization(AbstractUnwrappable[Array]): _dummy: ClassVar[None] = None def __init__(self, weight: Array | AbstractUnwrappable[Array]): - from flowjax.bijections import SoftPlus # Delayed to avoid circular import... - self.weight = weight scale_init = 1 / jnp.linalg.norm(unwrap(weight), axis=-1, keepdims=True) - self.scale = BijectionReparam(scale_init, SoftPlus()) + self.scale = Parameterize(softplus, inv_softplus(scale_init)) def unwrap(self) -> Array: weight_norms = jnp.linalg.norm(self.weight, axis=-1, keepdims=True) return self.scale * self.weight / weight_norms - - -class Lambda(AbstractUnwrappable[T]): - """Unwrap an object by calling fn with (possibly trainable) args and kwargs. - - If the Lambda is created within ``eqx.filter_vmap``, unwrapping is automatically - vectorized correctly, as long as the vmapped constructor adds leading batch - dimensions to all arrays in Lambda (the default for ``eqx.filter_vmap``). 
- - Args: - fn: Function to call with args, and kwargs. - *args: Positional arguments to pass to fn. - **kwargs: Keyword arguments to pass to fn. - """ - - fn: Callable[..., T] - args: Iterable - kwargs: dict - _dummy: Int[Scalar, ""] - - def __init__(self, fn, *args, **kwargs): - self.fn = fn - self.args = args - self.kwargs = kwargs - self._dummy = jnp.empty((), int) - - def unwrap(self) -> T: - return self.fn(*self.args, **self.kwargs) diff --git a/pyproject.toml b/pyproject.toml index 45162fd7..31391a8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ license = { file = "LICENSE" } name = "flowjax" readme = "README.md" requires-python = ">=3.10" -version = "13.1.1" +version = "14.0.0" [project.urls] repository = "https://github.com/danielward27/flowjax" diff --git a/tests/test_bijections/test_masked_autoregressive.py b/tests/test_bijections/test_masked_autoregressive.py index 989e9da9..0412d838 100644 --- a/tests/test_bijections/test_masked_autoregressive.py +++ b/tests/test_bijections/test_masked_autoregressive.py @@ -15,9 +15,8 @@ def test_masked_autoregressive_mlp(): # Extract masks before unwrapping mlp = masked_autoregressive_mlp(in_ranks, hidden_ranks, out_ranks, depth=3, key=key) - masks = [layer.weight.cond for layer in mlp.layers] - mlp = unwrap(mlp) + masks = [layer.weight != 0 for layer in mlp.layers] x = jnp.ones(in_size) y = mlp(x) assert y.shape == out_ranks.shape diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 8416030e..8420bdd2 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -4,59 +4,47 @@ import jax.random as jr import pytest -from flowjax.bijections import Exp, Scale from flowjax.distributions import Normal from flowjax.wrappers import ( - BijectionReparam, - Lambda, NonTrainable, + Parameterize, WeightNormalization, non_trainable, unwrap, ) -def test_BijectionReparam(): - - with pytest.raises(eqx.EquinoxRuntimeError, match="Exp"): - BijectionReparam(-jnp.ones(3), Exp()) - - param = jnp.array([jnp.inf, 1, 2]) - wrapped = BijectionReparam(param, Exp()) - assert pytest.approx(unwrap(wrapped)) == param - assert pytest.approx(wrapped.arr) == jnp.log(param) - - # Test with vmapped constructor - - def _get_param(arr): - return BijectionReparam(arr, Scale(jnp.full(3, fill_value=2))) - - init_param = jnp.ones((1, 2, 3)) - param = eqx.filter_vmap(eqx.filter_vmap(_get_param))(init_param) - assert pytest.approx(init_param) == unwrap(param) - - -def test_Lambda(): - diag = Lambda(jnp.diag, jnp.ones(3)) +def test_Parameterize(): + diag = Parameterize(jnp.diag, jnp.ones(3)) assert pytest.approx(jnp.eye(3)) == unwrap(diag) # Test works when vmapped (note diag does not follow standard vectorization rules) - v_diag = eqx.filter_vmap(Lambda)(jnp.diag, jnp.ones((4, 3))) + v_diag = eqx.filter_vmap(Parameterize)(jnp.diag, jnp.ones((4, 3))) expected = eqx.filter_vmap(jnp.eye, axis_size=4)(3) assert pytest.approx(expected) == unwrap(v_diag) # Test works when double vmapped - v_diag = eqx.filter_vmap(eqx.filter_vmap(Lambda))(jnp.diag, jnp.ones((5, 4, 3))) + v_diag = eqx.filter_vmap(eqx.filter_vmap(Parameterize))( + jnp.diag, jnp.ones((5, 4, 3)) + ) expected = eqx.filter_vmap(eqx.filter_vmap(jnp.eye, axis_size=4), axis_size=5)(3) assert pytest.approx(expected) == unwrap(v_diag) # Test works when no arrays present (in which case axis_size is relied on) - unwrappable = eqx.filter_vmap(eqx.filter_vmap(Lambda, axis_size=2), axis_size=3)( - lambda: jnp.zeros(()) - ) + unwrappable = eqx.filter_vmap( + eqx.filter_vmap(Parameterize, 
axis_size=2), axis_size=3 + )(lambda: jnp.zeros(())) assert pytest.approx(unwrap(unwrappable)) == jnp.zeros((3, 2)) +def test_nested_Parameterized(): + param = Parameterize( + jnp.square, + Parameterize(jnp.square, Parameterize(jnp.square, 2)), + ) + assert unwrap(param) == jnp.square(jnp.square(jnp.square(2))) + + def test_NonTrainable_and_non_trainable(): dist1 = eqx.tree_at(lambda dist: dist.bijection, Normal(), replace_fn=NonTrainable) dist2 = non_trainable(Normal())
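Usage sketch: a minimal illustration of the two Parameterize patterns this diff standardizes on — positivity via softplus and masking via jnp.where — assuming the post-refactor API introduced above (flowjax.wrappers.Parameterize and unwrap, flowjax.utils.inv_softplus). The values here are illustrative only.

import jax.numpy as jnp
from jax.nn import softplus

from flowjax.utils import inv_softplus
from flowjax.wrappers import Parameterize, unwrap

# Positivity: store the unconstrained value inv_softplus(init) as the trainable
# leaf and apply softplus lazily on unwrap, as done above for Affine.scale,
# StudentT.df and WeightNormalization.scale.
init_scale = jnp.array([0.5, 2.0])
scale = Parameterize(softplus, inv_softplus(init_scale))
assert jnp.allclose(unwrap(scale), init_scale)  # softplus(inv_softplus(x)) == x

# Masking: non-trainable arguments (here a lower-triangular mask) can be passed
# alongside trainable ones, mirroring the masked autoregressive MLP weights.
mask = jnp.tril(jnp.ones((3, 3), dtype=bool))
weight = Parameterize(jnp.where, mask, jnp.ones((3, 3)), 0)
assert jnp.allclose(unwrap(weight), jnp.tril(jnp.ones((3, 3))))

The same pattern generalizes: any constraint expressible as a forward map can be passed as fn, with the unconstrained initial value stored as the trainable leaf.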