Merge pull request #172 from danielward27/update_wrappers
Update wrappers
danielward27 authored Sep 2, 2024
2 parents 157f867 + e53c865 commit b8eb028
Showing 12 changed files with 104 additions and 177 deletions.
30 changes: 15 additions & 15 deletions flowjax/bijections/affine.py
@@ -4,13 +4,13 @@
from typing import ClassVar

import jax.numpy as jnp
+ from jax.nn import softplus
from jax.scipy.linalg import solve_triangular
from jaxtyping import Array, ArrayLike, Shaped

- from flowjax import wrappers
from flowjax.bijections.bijection import AbstractBijection
- from flowjax.bijections.softplus import SoftPlus
- from flowjax.utils import arraylike_to_array
+ from flowjax.utils import arraylike_to_array, inv_softplus
+ from flowjax.wrappers import AbstractUnwrappable, Parameterize, unwrap


class Affine(AbstractBijection):
@@ -29,7 +29,7 @@ class Affine(AbstractBijection):
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
loc: Array
- scale: Array | wrappers.AbstractUnwrappable[Array]
+ scale: Array | AbstractUnwrappable[Array]

def __init__(
self,
@@ -40,7 +40,7 @@ def __init__(
*(arraylike_to_array(a, dtype=float) for a in (loc, scale)),
)
self.shape = scale.shape
- self.scale = wrappers.BijectionReparam(scale, SoftPlus())
+ self.scale = Parameterize(softplus, inv_softplus(scale))

def transform(self, x, condition=None):
return x * self.scale + self.loc
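For context on the new pattern: Parameterize(softplus, inv_softplus(scale)) stores the scale in unconstrained form and applies softplus when the module is unwrapped, so positivity survives gradient updates. A minimal sketch of the round trip, with inv_softplus written inline as a stand-in for the helper this PR adds to flowjax/utils.py:

import jax.numpy as jnp
from jax.nn import softplus

def inv_softplus(x):
    # log(exp(x) - 1), written to avoid overflow for large x
    return jnp.log(-jnp.expm1(-x)) + x

scale = jnp.array([0.5, 2.0])
unconstrained = inv_softplus(scale)    # what gets stored and optimised
recovered = softplus(unconstrained)    # what unwrap produces
print(jnp.allclose(recovered, scale))  # True; softplus guarantees positivity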
@@ -92,15 +92,15 @@ class Scale(AbstractBijection):

shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
- scale: Array | wrappers.AbstractUnwrappable[Array]
+ scale: Array | AbstractUnwrappable[Array]

def __init__(
self,
scale: ArrayLike,
):
scale = arraylike_to_array(scale, "scale", dtype=float)
- self.scale = wrappers.BijectionReparam(scale, SoftPlus())
- self.shape = jnp.shape(wrappers.unwrap(scale))
+ self.scale = Parameterize(softplus, inv_softplus(scale))
+ self.shape = jnp.shape(unwrap(scale))

def transform(self, x, condition=None):
return x * self.scale
@@ -120,7 +120,7 @@ class TriangularAffine(AbstractBijection):
Transformation has the form :math:`Ax + b`, where :math:`A` is a lower or upper
triangular matrix, and :math:`b` is the bias vector. We assume the diagonal
- entries are positive, and constrain the values using SoftPlus. Other
+ entries are positive, and constrain the values using softplus. Other
parameterizations can be achieved by e.g. replacing ``self.triangular``
after construction.
@@ -135,7 +135,7 @@ class TriangularAffine(AbstractBijection):
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
loc: Array
- triangular: Array | wrappers.AbstractUnwrappable[Array]
+ triangular: Array | AbstractUnwrappable[Array]
lower: bool

def __init__(
@@ -152,12 +152,12 @@ def __init__(
raise ValueError("arr must be a square, 2-dimensional matrix.")
dim = arr.shape[0]

- def _to_triangular(diag, arr):
- tri = jnp.tril(arr, k=-1) if lower else jnp.triu(arr, k=1)
- return jnp.diag(diag) + tri
+ def _to_triangular(arr):
+ tri = jnp.tril(arr) if lower else jnp.triu(arr)
+ return jnp.fill_diagonal(tri, softplus(jnp.diag(tri)), inplace=False)

- diag = wrappers.BijectionReparam(jnp.diag(arr), SoftPlus())
- self.triangular = wrappers.Lambda(_to_triangular, diag=diag, arr=arr)
+ arr = jnp.fill_diagonal(arr, inv_softplus(jnp.diag(arr)), inplace=False)
+ self.triangular = Parameterize(_to_triangular, arr)
self.lower = lower
self.shape = (dim,)
self.loc = jnp.broadcast_to(loc, (dim,))
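The diagonal handling in the new _to_triangular can be seen in isolation below; a small sketch assuming lower=True and an illustrative 2x2 matrix:

import jax.numpy as jnp
from jax.nn import softplus

arr = jnp.array([[0.4, 0.0],
                 [0.3, 1.5]])
tri = jnp.tril(arr)  # keep the lower triangle, as when lower=True
# Constrain the diagonal to be positive without in-place mutation.
tri = jnp.fill_diagonal(tri, softplus(jnp.diag(tri)), inplace=False)
print(tri)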
16 changes: 7 additions & 9 deletions flowjax/bijections/block_autoregressive_network.py
@@ -9,14 +9,14 @@
import jax.numpy as jnp
import jax.random as jr
from jax import random
+ from jax.nn import softplus
from jaxtyping import PRNGKeyArray

from flowjax import masks
from flowjax.bijections.bijection import AbstractBijection
- from flowjax.bijections.softplus import SoftPlus
from flowjax.bijections.tanh import LeakyTanh
from flowjax.bisection_search import AutoregressiveBisectionInverter
- from flowjax.wrappers import BijectionReparam, WeightNormalization, Where
+ from flowjax.wrappers import Parameterize, WeightNormalization


class _CallableToBijection(AbstractBijection):
@@ -219,13 +219,11 @@ def block_autoregressive_linear(
block_diag_mask = masks.block_diag_mask(block_shape, n_blocks)
block_tril_mask = masks.block_tril_mask(block_shape, n_blocks)

- weight = Where(block_tril_mask, linear.weight, 0)
- weight = Where(
- block_diag_mask,
- BijectionReparam(weight, SoftPlus(), invert_on_init=False),
- weight,
- )
- weight = WeightNormalization(weight)
+ def apply_mask(weight):
+ weight = jnp.where(block_tril_mask, weight, 0)
+ return jnp.where(block_diag_mask, softplus(weight), weight)
+
+ weight = WeightNormalization(Parameterize(apply_mask, linear.weight))
linear = eqx.tree_at(lambda linear: linear.weight, linear, replace=weight)

def linear_to_log_block_diagonal(linear: eqx.nn.Linear):
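The two nested Where wrappers plus BijectionReparam collapse into the single apply_mask closure above. A toy version with small illustrative masks standing in for those produced by flowjax.masks:

import jax.numpy as jnp
from jax.nn import softplus

block_tril_mask = jnp.tril(jnp.ones((3, 3), dtype=bool), k=-1)  # strictly-lower entries kept as-is
block_diag_mask = jnp.eye(3, dtype=bool)                        # diagonal entries made positive

def apply_mask(weight):
    weight = jnp.where(block_tril_mask, weight, 0)
    return jnp.where(block_diag_mask, softplus(weight), weight)

print(apply_mask(jnp.arange(9.0).reshape(3, 3)))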
8 changes: 5 additions & 3 deletions flowjax/bijections/masked_autoregressive.py
@@ -13,7 +13,7 @@
from flowjax.bijections.jax_transforms import Vmap
from flowjax.masks import rank_based_mask
from flowjax.utils import get_ravelled_pytree_constructor
- from flowjax.wrappers import Where
+ from flowjax.wrappers import Parameterize


class MaskedAutoregressive(AbstractBijection):
@@ -135,7 +135,7 @@ def masked_autoregressive_mlp(
) -> eqx.nn.MLP:
"""Returns an equinox multilayer perceptron, with autoregressive masks.
- The weight matrices are wrapped using :class:`~flowjax.wrappers.Where`, which
+ The weight matrices are wrapped using :class:`~flowjax.wrappers.Parameterize`, which
will apply the masking when :class:`~flowjax.wrappers.unwrap` is called on the MLP.
For details of how the masks are formed, see https://arxiv.org/pdf/1502.03509.pdf.
@@ -160,7 +160,9 @@ def masked_autoregressive_mlp(
for i, linear in enumerate(mlp.layers):
mask = rank_based_mask(ranks[i], ranks[i + 1], eq=i != len(mlp.layers) - 1)
masked_linear = eqx.tree_at(
- lambda linear: linear.weight, linear, Where(mask, linear.weight, 0)
+ lambda linear: linear.weight,
+ linear,
+ Parameterize(jnp.where, mask, linear.weight, 0),
)
masked_layers.append(masked_linear)

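Each Linear weight is now wrapped so the autoregressive mask is re-applied on every unwrap. A short usage sketch, assuming the Parameterize and unwrap API from flowjax.wrappers as used in this PR, with an illustrative lower-triangular mask rather than a rank-based one:

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from flowjax.wrappers import Parameterize, unwrap

linear = eqx.nn.Linear(3, 3, key=jr.PRNGKey(0))
mask = jnp.tril(jnp.ones((3, 3), dtype=bool))  # illustrative mask only
masked = eqx.tree_at(
    lambda l: l.weight, linear, Parameterize(jnp.where, mask, linear.weight, 0)
)
print(unwrap(masked).weight)  # entries where mask is False are zeroed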
17 changes: 9 additions & 8 deletions flowjax/bijections/rational_quadratic_spline.py
@@ -7,8 +7,9 @@
import jax.numpy as jnp
from jaxtyping import Array, Float

- from flowjax import wrappers
from flowjax.bijections.bijection import AbstractBijection
+ from flowjax.utils import inv_softplus
+ from flowjax.wrappers import AbstractUnwrappable, Parameterize


def _real_to_increasing_on_interval(
@@ -62,9 +62,9 @@ class RationalQuadraticSpline(AbstractBijection):
interval: tuple[int | float, int | float]
softmax_adjust: float | int
min_derivative: float
- x_pos: Array | wrappers.AbstractUnwrappable[Array]
- y_pos: Array | wrappers.AbstractUnwrappable[Array]
- derivatives: Array | wrappers.AbstractUnwrappable[Array]
+ x_pos: Array | AbstractUnwrappable[Array]
+ y_pos: Array | AbstractUnwrappable[Array]
+ derivatives: Array | AbstractUnwrappable[Array]
shape: ClassVar[tuple] = ()
cond_shape: ClassVar[None] = None

@@ -89,11 +89,11 @@ def __init__(
softmax_adjust=softmax_adjust,
)

- self.x_pos = wrappers.Lambda(pos_parameterization, jnp.zeros(knots))
- self.y_pos = wrappers.Lambda(pos_parameterization, jnp.zeros(knots))
- self.derivatives = wrappers.Lambda(
+ self.x_pos = Parameterize(pos_parameterization, jnp.zeros(knots))
+ self.y_pos = Parameterize(pos_parameterization, jnp.zeros(knots))
+ self.derivatives = Parameterize(
lambda arr: jax.nn.softplus(arr) + self.min_derivative,
- jnp.full(knots + 2, jnp.log(jnp.exp(1 - min_derivative) - 1)),
+ jnp.full(knots + 2, inv_softplus(1 - min_derivative)),
)

def transform(self, x, condition=None):
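The new derivative initialisation is algebraically the same as the expression it replaces, since inv_softplus(y) = log(exp(y) - 1). A quick check with an illustrative min_derivative, using the inv_softplus helper this PR adds to flowjax/utils.py:

import jax.numpy as jnp
from flowjax.utils import inv_softplus

min_derivative = 1e-3
old_init = jnp.log(jnp.exp(1 - min_derivative) - 1)
new_init = inv_softplus(1 - min_derivative)
print(jnp.allclose(old_init, new_init))  # True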
10 changes: 5 additions & 5 deletions flowjax/distributions.py
@@ -11,7 +11,7 @@
import jax.numpy as jnp
import jax.random as jr
from equinox import AbstractVar
- from jax.nn import log_softmax
+ from jax.nn import log_softmax, softplus
from jax.numpy import linalg
from jax.scipy import stats as jstats
from jax.scipy.special import logsumexp
@@ -24,15 +24,15 @@
Chain,
Exp,
Scale,
- SoftPlus,
TriangularAffine,
)
from flowjax.utils import (
_get_ufunc_signature,
arraylike_to_array,
+ inv_softplus,
merge_cond_shapes,
)
from flowjax.wrappers import AbstractUnwrappable, BijectionReparam, Lambda, unwrap
+ from flowjax.wrappers import AbstractUnwrappable, Parameterize, unwrap


class AbstractDistribution(eqx.Module):
@@ -609,7 +609,7 @@ def __init__(self, df: ArrayLike):
df = arraylike_to_array(df, dtype=float)
df = eqx.error_if(df, df <= 0, "Degrees of freedom values must be positive.")
self.shape = jnp.shape(df)
- self.df = BijectionReparam(df, SoftPlus())
+ self.df = Parameterize(softplus, inv_softplus(df))

def _log_prob(self, x, condition=None):
return jstats.t.logpdf(x, df=self.df).sum()
@@ -761,7 +761,7 @@ def __init__(
):
weights = eqx.error_if(weights, weights <= 0, "Weights must be positive.")
self.dist = dist
- self.log_normalized_weights = Lambda(lambda w: log_softmax(w), jnp.log(weights))
+ self.log_normalized_weights = Parameterize(log_softmax, jnp.log(weights))
self.shape = dist.shape
self.cond_shape = dist.cond_shape

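The mixture weights are now stored as unnormalised log-weights and renormalised with log_softmax at unwrap time, so they always define a valid categorical distribution. A small sketch of that transformation:

import jax.numpy as jnp
from jax.nn import log_softmax

weights = jnp.array([1.0, 2.0, 1.0])            # illustrative positive weights
log_normalised = log_softmax(jnp.log(weights))  # what unwrap produces
print(jnp.exp(log_normalised))                  # [0.25, 0.5, 0.25]; sums to 1 by construction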
14 changes: 5 additions & 9 deletions flowjax/flows.py
@@ -14,6 +14,7 @@
import jax.numpy as jnp
import jax.random as jr
from equinox.nn import Linear
+ from jax.nn import softplus
from jax.nn.initializers import glorot_uniform
from jaxtyping import PRNGKeyArray

@@ -27,27 +28,22 @@
Flip,
Invert,
LeakyTanh,
- Loc,
MaskedAutoregressive,
Permute,
Planar,
RationalQuadraticSpline,
Scan,
- SoftPlus,
TriangularAffine,
Vmap,
)
from flowjax.distributions import AbstractDistribution, Transformed
- from flowjax.wrappers import BijectionReparam, WeightNormalization, non_trainable
+ from flowjax.utils import inv_softplus
+ from flowjax.wrappers import Parameterize, WeightNormalization


def _affine_with_min_scale(min_scale: float = 1e-2) -> Affine:
- scale_reparam = Chain([SoftPlus(), non_trainable(Loc(min_scale))])
- return eqx.tree_at(
- where=lambda aff: aff.scale,
- pytree=Affine(),
- replace=BijectionReparam(jnp.array(1), scale_reparam),
- )
+ scale = Parameterize(lambda x: softplus(x) + min_scale, inv_softplus(1 - min_scale))
+ return eqx.tree_at(where=lambda aff: aff.scale, pytree=Affine(), replace=scale)


def coupling_flow(
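In _affine_with_min_scale, the unconstrained value is initialised with inv_softplus(1 - min_scale), so the scale equals exactly 1 at construction while remaining bounded below by min_scale. A quick check of that initialisation, assuming the inv_softplus helper from flowjax/utils.py:

import jax.numpy as jnp
from jax.nn import softplus
from flowjax.utils import inv_softplus

min_scale = 1e-2
x = inv_softplus(1 - min_scale)                     # stored unconstrained parameter
print(jnp.allclose(softplus(x) + min_scale, 1.0))   # True at initialisation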
2 changes: 1 addition & 1 deletion flowjax/train/variational_fit.py
@@ -61,7 +61,7 @@ def fit_to_variational_target(
params, opt_state, loss = step(
params,
static,
- key,
+ key=key,
optimizer=optimizer,
opt_state=opt_state,
loss_fn=loss_fn,
12 changes: 12 additions & 0 deletions flowjax/utils.py
@@ -10,6 +10,18 @@
import flowjax


+ def inv_softplus(x: ArrayLike) -> Array:
+ """The inverse of the softplus function, checking for positive inputs."""
+ x = eqx.error_if(
+ x,
+ x < 0,
+ "Expected positive inputs to inv_softplus. If you are trying to use a negative "
+ "scale parameter, consider constructing with positive scales and modifying the "
+ "scale attribute post-construction, e.g., using eqx.tree_at.",
+ )
+ return jnp.log(-jnp.expm1(-x)) + x


def merge_cond_shapes(shapes: Sequence[tuple[int, ...] | None]):
"""Merges shapes (tuples of ints or None) used in bijections and distributions.
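inv_softplus uses log(-expm1(-x)) + x rather than the naive log(exp(x) - 1); the two agree for positive x, but the naive form overflows for large inputs. A quick comparison:

import jax.numpy as jnp

naive = lambda x: jnp.log(jnp.exp(x) - 1)
stable = lambda x: jnp.log(-jnp.expm1(-x)) + x

print(naive(2.0), stable(2.0))      # both approximately 1.8546
print(naive(100.0), stable(100.0))  # inf vs 100.0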