Make docstrings more detailed.
alicjapolanska committed Oct 18, 2023
1 parent a716038 commit 62b9da9
Showing 2 changed files with 153 additions and 20 deletions.
165 changes: 149 additions & 16 deletions harmonic/flows.py
@@ -20,7 +20,15 @@

class RealNVP(nn.Module):
"""
Real NVP flow using flax and tfp-jax.
Real-valued non-volume preserving flow using flax and tfp-jax.
Args:
n_features (int): Number of features in the data.
n_scaled_layers (int): Non-zero number of scaled layers in the flow.
n_unscaled_layers (int): Number of unscaled layers in the flow.
"""

n_features: int
@@ -34,6 +42,20 @@ def setup(self):
]

def make_flow(self, var_scale=1.0):
"""
Make tfp-jax distribution object containing the RealNVP flow.
Args:
var_scale (float): Factor by which base Gaussian unit covariance matrix is scaled.
Should be between 0 and 1 for use in evidence estimation.
Returns:
tfd.Distribution: Base Gaussian transformed by the scaled affine coupling layers
contained in the scaled_layers attribute, followed by the unscaled affine coupling
layers contained in the unscaled_layers attribute.
"""
chain = []
ix = jnp.arange(self.n_features)
permutation = [ix[-1], *ix[:-1]]
@@ -73,27 +95,71 @@ def make_flow(self, var_scale=1.0):
return nvp
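
The make_flow method above returns a tfp-jax distribution that combines a base Gaussian, scaled by var_scale, with chained affine coupling layers. A minimal sketch of that composition pattern (the helper name and layer list are assumptions, not harmonic's implementation):

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

def toy_scaled_flow(n_features, coupling_bijectors, var_scale=1.0):
    # Base Gaussian whose unit covariance is scaled by var_scale,
    # i.e. each marginal standard deviation is sqrt(var_scale).
    base = tfd.MultivariateNormalDiag(
        loc=jnp.zeros(n_features),
        scale_diag=jnp.full(n_features, jnp.sqrt(var_scale)),
    )
    # Chain the coupling layers and wrap everything as a single distribution.
    return tfd.TransformedDistribution(
        distribution=base, bijector=tfb.Chain(coupling_bijectors)
    )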

def __call__(self, x, var_scale=1.0) -> jnp.array:
"""
Evaluate the log probability of the flow for non-batched input x.
Args:
x (jnp.ndarray (ndim)): Sample at which to predict posterior value.
var_scale (float): Factor by which base Gaussian unit covariance matrix is scaled.
Should be between 0 and 1 for use in evidence estimation.
Returns:
float: Predicted log_e posterior value.
"""

flow = self.make_flow(var_scale=var_scale)
return flow.log_prob(x)

def sample(
self, rng: jax.random.PRNGKey, num_samples: int, var_scale: float = 1.0
) -> jnp.array:
""" "
""""
Sample from the flow.
Args:
rng (Union[Array, PRNGKeyArray]): Key used in random number generation process.
num_samples (int): Number of samples generated.
var_scale (float): Factor by which base Gaussian unit covariance matrix is scaled.
Should be between 0 and 1 for use in evidence estimation.
Returns:
jnp.array (num_samples, ndim): Samples from fitted distribution.
"""
nvp = self.make_flow(var_scale=var_scale)
samples = nvp.sample(num_samples, seed=rng)

return samples

def log_prob(self, x: jnp.array, var_scale: float = 1.0) -> jnp.array:
"""
Evaluate the log probability of the flow for a batched input.
Args:
x (jnp.ndarray (batch_size, ndim)): Sample for which to predict posterior values.
var_scale (float): Factor by which base Gaussian unit covariance matrix is scaled.
Should be between 0 and 1 for use in evidence estimation.
Returns:
jnp.ndarray (batch_size,): Predicted log_e posterior value.
"""

get_logprob = jax.jit(jax.vmap(self.__call__, in_axes=[0, None]))
logprob = get_logprob(x, var_scale)

return logprob
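
A self-contained illustration (toy function, not harmonic code) of the vmap pattern used in log_prob above: in_axes=[0, None] maps over the leading batch axis of x while broadcasting the scalar var_scale to every call.

import jax
import jax.numpy as jnp

def toy_log_density(x, var_scale):
    # Unnormalised log-density of a single, non-batched point.
    return -0.5 * jnp.sum(x**2) / var_scale

batched_log_density = jax.jit(jax.vmap(toy_log_density, in_axes=[0, None]))
x_batch = jnp.ones((4, 3))
print(batched_log_density(x_batch, 0.9).shape)  # (4,)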

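Putting the documented RealNVP methods together, a hypothetical end-to-end usage sketch; the constructor values and the standard Flax init/apply pattern are assumptions:

import jax
import jax.numpy as jnp

model = RealNVP(n_features=2, n_scaled_layers=2, n_unscaled_layers=4)
params = model.init(jax.random.PRNGKey(0), jnp.zeros(2))  # initialises via __call__
samples = model.apply(
    params, jax.random.PRNGKey(1), 500, 0.8, method=model.sample
)  # (500, 2) samples drawn with var_scale=0.8
log_post = model.apply(params, samples, 0.8, method=model.log_prob)  # (500,) log_e values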

class AffineCoupling(nn.Module):
"""
Affine coupling layer used in RealNVP flow class.
Args:
apply_scaling (bool): If True, the shift is followed by a scaling.
"""
apply_scaling: bool = True

@nn.compact
@@ -122,15 +188,17 @@ class RQSpline(nn.Module):
Rational quadratic spline normalizing flow model using distrax.
Args:
n_features : (int) Number of features in the data.
num_layers : (int) Number of layers in the flow.
num_bins : (int) Number of bins in the spline.
hidden_size : (Sequence[int]) Size of the hidden layers in the conditioner.
spline_range : (Sequence[float]) Range of the spline.
Properties:
base_mean: (ndarray) Mean of Gaussian base distribution
base_cov: (ndarray) Covariance of Gaussian base distribution
n_features (int): Number of features in the data.
num_layers (int): Number of layers in the flow.
num_bins (int): Number of bins in the spline.
hidden_size (Sequence[int]): Size of the hidden layers in the conditioner.
spline_range (Sequence[float]): Range of the spline.
Adapted from github.com/kazewong/flowMC
"""

n_features: int
@@ -161,6 +229,16 @@ def bijector_fn(params: jnp.ndarray):
self.bijector_fn = bijector_fn

def make_flow(self, scale: float =1.):
"""
Make the base Gaussian distribution and the distrax rational quadratic spline flow.
Args:
scale (float): Factor by which base Gaussian unit covariance matrix is scaled.
Should be between 0 and 1 for use in evidence estimation.
Returns:
Tuple (base_dist, flow): Base Gaussian distribution and the rational quadratic
spline bijector used to transform it.
"""
mask = (jnp.arange(0, self.n_features) % 2).astype(bool)
mask_all = (jnp.zeros(self.n_features)).astype(bool)
layers = []
@@ -190,21 +268,59 @@ def make_flow(self, scale: float =1.):
return base_dist, flow
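
make_flow here returns the base distribution and the flow separately, to be combined later with distrax.Transformed. A minimal sketch of that pattern (the helper name and layer list are assumptions, not the actual implementation):

import distrax
import jax.numpy as jnp

def toy_spline_flow(n_features, layers, scale=1.0):
    # Base Gaussian with unit covariance scaled by `scale`.
    base_dist = distrax.MultivariateNormalDiag(
        loc=jnp.zeros(n_features),
        scale_diag=jnp.full(n_features, jnp.sqrt(scale)),
    )
    # Compose the coupling/spline layers into a single bijector.
    flow = distrax.Chain(layers)
    return base_dist, flow

# Log-probability and sampling then go through distrax.Transformed(base_dist, flow),
# as in the __call__ and sample methods shown here.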

def __call__(self, x: jnp.array, scale: float =1.) -> jnp.array:
"""
Evaluate the log probability of the flow for non-batched input x.
Args:
x (jnp.ndarray (ndim)): Sample at which to predict posterior value.
scale (float): Factor by which base Gaussian unit covariance matrix is scaled.
Should be between 0 and 1 for use in evidence estimation.
Returns:
jnp.ndarray (float): Predicted log_e posterior value.
"""
base_dist, flow = self.make_flow(scale=scale)

return distrax.Transformed(base_dist, flow).log_prob(x)

def sample(self, rng: jax.random.PRNGKey, num_samples: int, scale: float = 1.) -> jnp.array:
""""
Sample from the flow.
Args:
rng (Union[Array, PRNGKeyArray]): Key used in random number generation process.
num_samples (int): Number of samples generated.
scale (float): Factor by which base Gaussian unit covariance matrix is scaled.
Should be between 0 and 1 for use in evidence estimation.
Returns:
jnp.array (num_samples, ndim): Samples from fitted distribution.
"""

base_dist, flow = self.make_flow(scale=scale)
samples = distrax.Transformed(base_dist, flow).sample(
seed=rng, sample_shape=(num_samples)
)

return samples
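
A hypothetical usage sketch for RQSpline, analogous to the RealNVP example above; the constructor values are assumptions, and log_prob is documented just below:

import jax
import jax.numpy as jnp

spline = RQSpline(n_features=3, num_layers=4, num_bins=8,
                  hidden_size=[32, 32], spline_range=(-10.0, 10.0))
params = spline.init(jax.random.PRNGKey(0), jnp.zeros(3))
draws = spline.apply(params, jax.random.PRNGKey(1), 256, 0.9, method=spline.sample)
log_post = spline.apply(params, draws, 0.9, method=spline.log_prob)  # shape (256,)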


def log_prob(self, x:jnp.array, scale:float = 1.) -> jnp.array:
"""
Evaluate the log probability of the flow for a batched input.
Args:
x (jnp.ndarray (batch_size, ndim)): Sample for which to predict posterior values.
scale (float): Factor by which base Gaussian unit covariance matrix is scaled.
Should be between 0 and 1 for use in evidence estimation.
Returns:
jnp.ndarray (batch_size,): Predicted log_e posterior value.
"""

get_logprob = jax.jit(jax.vmap(self.__call__, in_axes=[0, None]))
logprob = get_logprob(x, scale)
@@ -219,6 +335,11 @@ def __call__(self, x):


class Conditioner(nn.Module):
"""
Conditioner used to construct the bijector function in RQSpline class.
Adapted from github.com/kazewong/flowMC
"""
n_features: int
hidden_size: Sequence[int]
num_bijector_params: int
@@ -243,6 +364,12 @@ def __call__(self, x):


class Scalar(nn.Module):
"""
Scalar used to construct the spline flow in RQSpline class.
Adapted from github.com/kazewong/flowMC
"""

n_features: int

def setup(self):
@@ -267,11 +394,17 @@ class MLP(nn.Module):
of `init_weight_scale=1e-4` by default.
Args:
features: (list of int) The number of features in each layer.
activation: (callable) The activation function at each level
use_bias: (bool) Whether to use bias in the layers.
init_weight_scale: (float) The initial weight scale for the layers.
kernel_init: (callable) The kernel initializer for the layers.
features (list of int): The number of features in each layer.
activation (callable): The activation function at each level.
use_bias (bool): Whether to use bias in the layers.
init_weight_scale (float): The initial weight scale for the layers.
kernel_init (callable): The kernel initializer for the layers.
Adapted from github.com/kazewong/flowMC
"""

features: Sequence[int]
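
A minimal Flax sketch of a network whose layer widths come from a features list, as the MLP docstring above describes; the structure and activation are assumptions, not harmonic's implementation:

import flax.linen as nn
import jax.numpy as jnp

class ToyMLP(nn.Module):
    features: tuple = (16, 16, 1)

    @nn.compact
    def __call__(self, x):
        # Hidden layers with tanh activations, followed by a linear output layer.
        for width in self.features[:-1]:
            x = jnp.tanh(nn.Dense(width)(x))
        return nn.Dense(self.features[-1])(x)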
8 changes: 4 additions & 4 deletions harmonic/model_nf.py
@@ -431,16 +431,16 @@ def fit(self, X, batch_size=64, epochs=3, key=jax.random.PRNGKey(1000), verbose=
return

def predict(self, x):
"""Predict the value of log_e posterior at x.
"""Predict the value of log_e posterior at batched input x.
Args:
x (jnp.ndarray): Sample of shape at which to
predict posterior value.
x (jnp.ndarray (batch_size, ndim)): Batched sample for which to
predict posterior values.
Returns:
jnp.ndarray: Predicted log_e posterior value.
jnp.ndarray (batch_size,): Predicted log_e posterior value.
Raises:

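A hypothetical end-to-end sketch of the fit/predict interface documented above; the concrete model class and training data are assumptions:

import jax
import jax.numpy as jnp

# `model` stands for any flow model in harmonic.model_nf exposing fit and predict.
X = jnp.ones((1000, 5))                # training samples, shape (n_samples, ndim)
model.fit(X, batch_size=64, epochs=3, key=jax.random.PRNGKey(1000))
x_batch = jnp.zeros((128, 5))          # batched query points
ln_posterior = model.predict(x_batch)  # shape (128,)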