From 12f35e2b6b8435d9304d37f1de3d77fcd7d43a7b Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Fri, 20 Oct 2023 12:12:56 +0100 Subject: [PATCH] add todo's for code review (very minor) --- build_harmonic.sh | 3 ++- harmonic/flows.py | 10 ++++----- harmonic/model_nf.py | 47 ++++++++++++++++++++-------------------- tests/test_evidence.py | 3 +-- tests/test_flow_model.py | 3 ++- 5 files changed, 33 insertions(+), 33 deletions(-) diff --git a/build_harmonic.sh b/build_harmonic.sh index a146a797..087433a0 100755 --- a/build_harmonic.sh +++ b/build_harmonic.sh @@ -7,7 +7,7 @@ echo -ne 'Building Dependencies... \r' # Install jax and TFP on jax substrates (on M1 mac) conda install -q -c conda-forge jax==0.4.1 -y conda install -q -c conda-forge flax==0.6.1 chex==0.1.6 -y -pip install -Uq tfp-nightly[jax]==0.20.0.dev20230801 > /dev/null +# pip install -Uq tfp-nightly[jax]==0.20.0.dev20230801 > /dev/null pip install -q -r requirements/requirements-core.txt echo -ne 'Building Dependencies... ##### (25%)\r' @@ -19,6 +19,7 @@ pip install -q -r requirements/requirements-docs.txt echo -ne 'Building Dependencies... ####################(100%)\r' echo -ne '\n' +pip install -Uq tfp-nightly[jax]==0.20.0.dev20230801 > /dev/null # Install specific converter for building tutorial documentation diff --git a/harmonic/flows.py b/harmonic/flows.py index df18c964..0853cdd1 100644 --- a/harmonic/flows.py +++ b/harmonic/flows.py @@ -1,10 +1,8 @@ -from typing import Sequence, Callable, List, Any -import numpy as np +from typing import Sequence, Callable import flax.linen as nn import jax import jax.numpy as jnp import tensorflow_probability as tfp -from flax.training import train_state import distrax tfp = tfp.substrates.jax @@ -154,7 +152,6 @@ def log_prob(self, x: jnp.array, var_scale: float = 1.0) -> jnp.array: Returns: jnp.ndarray (batch_size,): Predicted log_e posterior value. """ - get_logprob = jax.jit(jax.vmap(self.__call__, in_axes=[0, None])) logprob = get_logprob(x, var_scale) @@ -206,7 +203,8 @@ class RQSpline(nn.Module): spline_range (Sequence[float]): Range of the spline. - Adapted from github.com/kazewong/flowMC + Note: + Adapted from github.com/kazewong/flowMC """ n_features: int @@ -235,7 +233,7 @@ def bijector_fn(params: jnp.ndarray): ) self.bijector_fn = bijector_fn - + # TODO: change scale to varscale in all functions below. def make_flow(self, scale: float =1.): """ Make distrax distribution containing the rational quadratic spline flow. diff --git a/harmonic/model_nf.py b/harmonic/model_nf.py index 2256b68d..d3f34e25 100644 --- a/harmonic/model_nf.py +++ b/harmonic/model_nf.py @@ -1,17 +1,13 @@ -from typing import Sequence, Callable, List +from typing import Sequence from harmonic import model as md -import numpy as np from harmonic import flows import jax import jax.numpy as jnp import optax -from functools import partial + from tqdm import trange -import cloudpickle -import flax from flax.training import train_state # Useful dataclass to keep train state -import flax.linen as nn def make_training_loop(model): """ @@ -22,6 +18,9 @@ def make_training_loop(model): Returns: train_flow (Callable): wrapper function that trains the model. 
+ + Note: + Adapted from github.com/kazewong/flowMC """ def train_step(batch, state, variables): @@ -68,7 +67,6 @@ def train_flow(rng, state, variables, data, num_epochs, batch_size, verbose: boo rng, input_rng = jax.random.split(rng) # Run an optimization step over a training batch value, state = train_epoch(input_rng, state, variables, data, batch_size) - # print('Train loss: %.3f' % value) loss_values = loss_values.at[epoch].set(value) if loss_values[epoch] < best_loss: best_state = state @@ -111,17 +109,17 @@ def __init__( ndim_in (int): Dimension of the problem to solve. - n_scaled_layers (int): Number of layers with scaler in RealNVP flow. + n_scaled_layers (int, optional): Number of layers with scaler in RealNVP flow. Default = 2. - n_unscaled_layers (int): Number of layers without scaler in RealNVP flow. + n_unscaled_layers (int, optional): Number of layers without scaler in RealNVP flow. Default = 4. - learning_rate (float): Learning rate for adam optimizer used in the fit method. + learning_rate (float, optional): Learning rate for adam optimizer used in the fit method. Default = 0.001. - momentum (float): Learning rate for Adam optimizer used in the fit method. + momentum (float, optional): Learning rate for Adam optimizer used in the fit method. Default = 0.9 - standardize(bool): Indicates if mean and variance should be removed from training data when training the flow. + standardize(bool, optional): Indicates if mean and variance should be removed from training data when training the flow. Default = False - temperature (float): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. + temperature (float, optional): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. Default = 0.8. Raises: @@ -169,26 +167,25 @@ def create_train_state(self, rng): ) - def fit(self, X, batch_size=64, epochs=3, key=jax.random.PRNGKey(1000), verbose = False): + def fit(self, X, batch_size = 64, epochs = 3, key = jax.random.PRNGKey(1000), verbose = False): """Fit the parameters of the model. Args: - X (double ndarray[nsamples, ndim]): Training samples. + X (ndarray[nsamples, ndim]): Training samples. - batch_size (int): Batch size used when training flow. + batch_size (int, optional): Batch size used when training flow. Default = 64. - epochs (int): Number of epochs flow is trained for. + epochs (int, optional): Number of epochs flow is trained for. Default = 3. - key (Union[Array, PRNGKeyArray])): Key used in random number generation process. + key (Union[Array, PRNGKeyArray], optional): Key used in random number generation process. - verbose (bool): Controls if progress bar and current loss are displayed when training. + verbose (bool, optional): Controls if progress bar and current loss are displayed when training. Default = False. Raises: - ValueError: Raised if the second dimension of X is not the same as - ndim. + ValueError: Raised if the second dimension of X is not the same as ndim. """ @@ -225,8 +222,7 @@ def predict(self, x): Args: - x (jnp.ndarray): Sample of shape at which to - predict posterior value. + x (jnp.ndarray): Sample of shape at which to predict posterior value. Returns: @@ -249,6 +245,7 @@ def predict(self, x): if self.standardize: x = (x-self.pre_offset)/self.pre_amp + # TODO: support 1D arrays here, not at the user level. 
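# A possible way to address the TODO above (a sketch, not part of this
# patch): promote 1D inputs to a single-sample batch before evaluating
# the flow, e.g.
#     x = jnp.atleast_2d(x)
# so that callers can pass either one sample or a batch of samples
# without reshaping at the user level.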
logprob = self.flow.apply( {"params": self.state.params, "variables": self.variables}, x, @@ -269,6 +266,10 @@ def sample(self, n_sample, rng_key=jax.random.PRNGKey(0)): rng_key (Union[Array, PRNGKeyArray])): Key used in random number generation process. + Raises: + + ValueError: If var_scale is negative or greater than 1. + Returns: jnp.array (n_sample, ndim): Samples from fitted distribution. diff --git a/tests/test_evidence.py b/tests/test_evidence.py index a98fa6c2..4c6884e8 100644 --- a/tests/test_evidence.py +++ b/tests/test_evidence.py @@ -4,7 +4,6 @@ import harmonic.chains as ch import harmonic.model as md import harmonic.evidence as cbe -import harmonic.utils as utils import harmonic.model_nf as model_nf @@ -401,7 +400,7 @@ def test_serialization(): # Deserialize evidence ev4 = cbe.Evidence.deserialize(".test.dat") - + # TODO: make sure model works correctly after deserializing. # Test evidence objects the same assert ev3.batch_calculation == ev4.batch_calculation assert ev3.nchains == ev4.nchains diff --git a/tests/test_flow_model.py b/tests/test_flow_model.py index a076ad69..f6b30822 100644 --- a/tests/test_flow_model.py +++ b/tests/test_flow_model.py @@ -4,6 +4,7 @@ import jax import harmonic as hm +# TODO: move this to conftest.py to follow best practices. def standard_nd_gaussian_pdf(x): """ Calculate the probability density function (PDF) of an n-dimensional Gaussian @@ -129,7 +130,7 @@ def test_RQSpline_constructor(): - +# TODO: combine tests into one test with a model variable. def test_RealNVP_gaussian(): # Define the number of dimensions and the mean of the Gaussian
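One possible shape for the "combine tests into one test with a model variable" TODO above, as a sketch only: it assumes the two flow models are exposed as harmonic.model_nf.RealNVPModel and harmonic.model_nf.RQSplineModel, with the constructor, fit, predict and sample signatures shown earlier in this patch; the class names themselves are not confirmed here.

import jax
import jax.numpy as jnp
import pytest

import harmonic.model_nf as model_nf


# Hypothetical parametrisation over the flow models; the class names are an
# assumption and may differ from the actual harmonic API.
@pytest.mark.parametrize(
    "model_class", [model_nf.RealNVPModel, model_nf.RQSplineModel]
)
def test_flow_gaussian(model_class):
    ndim = 2
    nsamples = 10000

    # Training data: samples from a standard ndim-dimensional Gaussian.
    samples = jax.random.normal(jax.random.PRNGKey(0), (nsamples, ndim))

    # Fit the flow, then check that predicted log-probabilities are finite
    # and that drawn samples have the expected shape.
    model = model_class(ndim)
    model.fit(samples, epochs=3, verbose=False)

    ln_prob = model.predict(samples)
    assert jnp.all(jnp.isfinite(ln_prob))

    flow_samples = model.sample(100)
    assert flow_samples.shape == (100, ndim)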