
Commit 12f35e2

add todo's for code review (very minor)
CosmoMatt committed Oct 20, 2023
1 parent 52fb4b2 commit 12f35e2
Showing 5 changed files with 33 additions and 33 deletions.
3 changes: 2 additions & 1 deletion build_harmonic.sh
@@ -7,7 +7,7 @@ echo -ne 'Building Dependencies... \r'
# Install jax and TFP on jax substrates (on M1 mac)
conda install -q -c conda-forge jax==0.4.1 -y
conda install -q -c conda-forge flax==0.6.1 chex==0.1.6 -y
-pip install -Uq tfp-nightly[jax]==0.20.0.dev20230801 > /dev/null
+# pip install -Uq tfp-nightly[jax]==0.20.0.dev20230801 > /dev/null

pip install -q -r requirements/requirements-core.txt
echo -ne 'Building Dependencies... ##### (25%)\r'
@@ -19,6 +19,7 @@ pip install -q -r requirements/requirements-docs.txt
echo -ne 'Building Dependencies... ####################(100%)\r'
echo -ne '\n'

+pip install -Uq tfp-nightly[jax]==0.20.0.dev20230801 > /dev/null


# Install specific converter for building tutorial documentation
10 changes: 4 additions & 6 deletions harmonic/flows.py
@@ -1,10 +1,8 @@
-from typing import Sequence, Callable, List, Any
-import numpy as np
+from typing import Sequence, Callable
import flax.linen as nn
import jax
import jax.numpy as jnp
import tensorflow_probability as tfp
-from flax.training import train_state
import distrax

tfp = tfp.substrates.jax
@@ -154,7 +152,6 @@ def log_prob(self, x: jnp.array, var_scale: float = 1.0) -> jnp.array:
Returns:
jnp.ndarray (batch_size,): Predicted log_e posterior value.
"""

get_logprob = jax.jit(jax.vmap(self.__call__, in_axes=[0, None]))
logprob = get_logprob(x, var_scale)
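As a side note, the jit/vmap batching pattern used here can be reproduced in isolation; the stand-in density below is hypothetical and only the batching mechanics mirror this method:

import jax
import jax.numpy as jnp

def single_log_prob(x_row, var_scale):
    # Stand-in for self.__call__: one sample in, one scalar log-density out.
    return -0.5 * jnp.sum((x_row / var_scale) ** 2)

# vmap maps over the batch axis of x while broadcasting var_scale unchanged;
# jit then compiles the batched function.
get_logprob = jax.jit(jax.vmap(single_log_prob, in_axes=[0, None]))
logprob = get_logprob(jnp.ones((4, 3)), 1.0)  # shape (4,)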

@@ -206,7 +203,8 @@ class RQSpline(nn.Module):
spline_range (Sequence[float]): Range of the spline.
-Adapted from github.com/kazewong/flowMC
+Note:
+    Adapted from github.com/kazewong/flowMC
"""

n_features: int
@@ -235,7 +233,7 @@ def bijector_fn(params: jnp.ndarray):
)

self.bijector_fn = bijector_fn

+# TODO: change scale to varscale in all functions below.
def make_flow(self, scale: float =1.):
"""
Make distrax distribution containing the rational quadratic spline flow.
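For context on make_flow and the TODO above, a rational quadratic spline flow can be assembled with distrax roughly as follows. This is a minimal sketch with an assumed number of features, an assumed number of bins, and a stand-in conditioner, not harmonic's exact implementation:

import distrax
import jax.numpy as jnp

n_features, num_bins, spline_range = 3, 8, (-10.0, 10.0)

def bijector_fn(params: jnp.ndarray):
    # params carries 3 * num_bins + 1 values per transformed dimension.
    return distrax.RationalQuadraticSpline(
        params, range_min=spline_range[0], range_max=spline_range[1]
    )

# An alternating binary mask selects which dimensions the coupling layer updates.
mask = jnp.arange(n_features) % 2 == 0

def conditioner(x):
    # Stand-in conditioner; in practice this is a small neural network.
    return jnp.zeros(x.shape + (3 * num_bins + 1,))

layer = distrax.MaskedCoupling(mask=mask, conditioner=conditioner, bijector=bijector_fn)
base = distrax.MultivariateNormalDiag(
    loc=jnp.zeros(n_features), scale_diag=jnp.ones(n_features)
)
flow = distrax.Transformed(base, distrax.Chain([layer]))
log_prob = flow.log_prob(jnp.zeros((1, n_features)))  # evaluate the flow density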
47 changes: 24 additions & 23 deletions harmonic/model_nf.py
@@ -1,17 +1,13 @@
-from typing import Sequence, Callable, List
+from typing import Sequence
from harmonic import model as md
import numpy as np
from harmonic import flows
import jax
import jax.numpy as jnp
import optax
-from functools import partial

from tqdm import trange
import cloudpickle

-import flax
from flax.training import train_state # Useful dataclass to keep train state
-import flax.linen as nn

def make_training_loop(model):
"""
@@ -22,6 +18,9 @@ def make_training_loop(model):
Returns:
train_flow (Callable): wrapper function that trains the model.
+Note:
+    Adapted from github.com/kazewong/flowMC
"""

def train_step(batch, state, variables):
@@ -68,7 +67,6 @@ def train_flow(rng, state, variables, data, num_epochs, batch_size, verbose: boo
rng, input_rng = jax.random.split(rng)
# Run an optimization step over a training batch
value, state = train_epoch(input_rng, state, variables, data, batch_size)
-# print('Train loss: %.3f' % value)
loss_values = loss_values.at[epoch].set(value)
if loss_values[epoch] < best_loss:
best_state = state
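A stripped-down sketch of the train_step built by this wrapper is shown below; the apply_fn call signature mirrors the predict call later in this diff, and the rest is assumed:

import jax
import jax.numpy as jnp

@jax.jit
def train_step(batch, state, variables):
    """One gradient step minimising the negative mean log-likelihood."""
    def loss_fn(params):
        log_prob = state.apply_fn({"params": params, "variables": variables}, batch)
        return -jnp.mean(log_prob)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return loss, state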
@@ -111,17 +109,17 @@ def __init__(
ndim_in (int): Dimension of the problem to solve.
-n_scaled_layers (int): Number of layers with scaler in RealNVP flow.
+n_scaled_layers (int, optional): Number of layers with scaler in RealNVP flow. Default = 2.
-n_unscaled_layers (int): Number of layers without scaler in RealNVP flow.
+n_unscaled_layers (int, optional): Number of layers without scaler in RealNVP flow. Default = 4.
-learning_rate (float): Learning rate for adam optimizer used in the fit method.
+learning_rate (float, optional): Learning rate for adam optimizer used in the fit method. Default = 0.001.
-momentum (float): Learning rate for Adam optimizer used in the fit method.
+momentum (float, optional): Learning rate for Adam optimizer used in the fit method. Default = 0.9
-standardize(bool): Indicates if mean and variance should be removed from training data when training the flow.
+standardize(bool, optional): Indicates if mean and variance should be removed from training data when training the flow. Default = False
-temperature (float): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1.
+temperature (float, optional): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. Default = 0.8.
Raises:
@@ -169,26 +167,25 @@ def create_train_state(self, rng):
)


-def fit(self, X, batch_size=64, epochs=3, key=jax.random.PRNGKey(1000), verbose = False):
+def fit(self, X, batch_size = 64, epochs = 3, key = jax.random.PRNGKey(1000), verbose = False):
"""Fit the parameters of the model.
Args:
-X (double ndarray[nsamples, ndim]): Training samples.
+X (ndarray[nsamples, ndim]): Training samples.
-batch_size (int): Batch size used when training flow.
+batch_size (int, optional): Batch size used when training flow. Default = 64.
-epochs (int): Number of epochs flow is trained for.
+epochs (int, optional): Number of epochs flow is trained for. Default = 3.
-key (Union[Array, PRNGKeyArray])): Key used in random number generation process.
+key (Union[Array, PRNGKeyArray], optional): Key used in random number generation process.
-verbose (bool): Controls if progress bar and current loss are displayed when training.
+verbose (bool, optional): Controls if progress bar and current loss are displayed when training. Default = False.
Raises:
-ValueError: Raised if the second dimension of X is not the same as
-ndim.
+ValueError: Raised if the second dimension of X is not the same as ndim.
"""

@@ -225,8 +222,7 @@ def predict(self, x):
Args:
-x (jnp.ndarray): Sample of shape at which to
-predict posterior value.
+x (jnp.ndarray): Sample of shape at which to predict posterior value.
Returns:
@@ -249,6 +245,7 @@ def predict(self, x):
if self.standardize:
x = (x-self.pre_offset)/self.pre_amp

+# TODO: support 1D arrays here, not at the user level.
logprob = self.flow.apply(
{"params": self.state.params, "variables": self.variables},
x,
@@ -269,6 +266,10 @@ def sample(self, n_sample, rng_key=jax.random.PRNGKey(0)):
rng_key (Union[Array, PRNGKeyArray])): Key used in random number generation process.
+Raises:
+    ValueError: If var_scale is negative or greater than 1.
Returns:
jnp.array (n_sample, ndim): Samples from fitted distribution.
3 changes: 1 addition & 2 deletions tests/test_evidence.py
@@ -4,7 +4,6 @@
import harmonic.chains as ch
import harmonic.model as md
import harmonic.evidence as cbe
-import harmonic.utils as utils
import harmonic.model_nf as model_nf


@@ -401,7 +400,7 @@ def test_serialization():

# Deserialize evidence
ev4 = cbe.Evidence.deserialize(".test.dat")

+# TODO: make sure model works correctly after deserializing.
# Test evidence objects the same
assert ev3.batch_calculation == ev4.batch_calculation
assert ev3.nchains == ev4.nchains
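One possible way to address the TODO above, sketched with an assumed model attribute on the deserialized evidence objects and the ndim defined in the surrounding test:

import numpy as np

# Sketch only: check the attached flow predicts identically after the
# serialize/deserialize round trip (the `model` attribute name is assumed).
x_check = np.zeros((1, ndim))
assert np.allclose(ev3.model.predict(x_check), ev4.model.predict(x_check))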
3 changes: 2 additions & 1 deletion tests/test_flow_model.py
@@ -4,6 +4,7 @@
import jax
import harmonic as hm

+# TODO: move this to conftest.py to follow best practices.
def standard_nd_gaussian_pdf(x):
"""
Calculate the probability density function (PDF) of an n-dimensional Gaussian
@@ -129,7 +130,7 @@ def test_RQSpline_constructor():




+# TODO: combine tests into one test with a model variable.
def test_RealNVP_gaussian():

# Define the number of dimensions and the mean of the Gaussian
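The TODO about combining tests could be realised with pytest parametrisation along these lines; the flow model class names and constructor signatures are assumptions in this sketch:

import jax
import pytest
import harmonic.model_nf as model_nf

@pytest.mark.parametrize(
    "make_model",
    [
        lambda ndim: model_nf.RealNVPModel(ndim),
        lambda ndim: model_nf.RQSplineModel(ndim),
    ],
)
def test_flow_gaussian(make_model):
    ndim = 2
    model = make_model(ndim)
    training_samples = jax.random.normal(jax.random.PRNGKey(0), (100, ndim))
    model.fit(training_samples, epochs=1)
    assert model.predict(training_samples[:1]).shape == (1,)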
