diff --git a/examples/gaussian_nondiagcov_nvp.py b/examples/gaussian_nondiagcov_nvp.py
index 58ebac58..d272e42a 100644
--- a/examples/gaussian_nondiagcov_nvp.py
+++ b/examples/gaussian_nondiagcov_nvp.py
@@ -102,8 +102,8 @@ def run_example(
     cov = init_cov(ndim)
     inv_cov = np.linalg.inv(cov)
     training_proportion = 0.5
-    epochs_num = 50
-    var_scale = 0.9
+    epochs_num = 5
+    temperature = 0.9
     standardize = True
     verbose = True
@@ -145,7 +145,7 @@ def run_example(
     # =======================================================================
     hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
     model = model_nf.RealNVPModel(
-        ndim, standardize=standardize, temperature=var_scale
+        ndim, standardize=standardize, temperature=temperature
     )
     model.fit(chains_train.samples, epochs=epochs_num, verbose=verbose)
diff --git a/examples/gaussian_nondiagcov_splines.py b/examples/gaussian_nondiagcov_splines.py
index fbaeefe1..3bdc63f7 100644
--- a/examples/gaussian_nondiagcov_splines.py
+++ b/examples/gaussian_nondiagcov_splines.py
@@ -102,7 +102,7 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000, plot_corner=False):
     inv_cov = jnp.linalg.inv(cov)
     training_proportion = 0.5
     epochs_num = 80
-    var_scale = 0.8
+    temperature = 0.8
     standardize = True
     verbose = True
@@ -157,7 +157,7 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000, plot_corner=False):
         hidden_size=hidden_size,
         spline_range=spline_range,
         standardize=standardize,
-        temperature=var_scale,
+        temperature=temperature,
     )
     model.fit(jnp.array(chains_train.samples), epochs=epochs_num, verbose=verbose)
diff --git a/examples/normal_gamma_nvp.py b/examples/normal_gamma_nvp.py
index bd8c5e15..40277a53 100644
--- a/examples/normal_gamma_nvp.py
+++ b/examples/normal_gamma_nvp.py
@@ -192,12 +192,12 @@ def run_example(
     created_plots = False
     training_proportion = 0.5
-    var_scale = 0.9
+    temperature = 0.9
     epochs_num = 100
     standardize = False
     plot_comparison_2var = True
-    var_scale_2 = 0.95
+    temperature_2 = 0.95
     # ===========================================================================
     # Simulate data
@@ -269,7 +269,7 @@ def run_example(
     # =======================================================================
     hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
     model = model_nf.RealNVPModel(
-        ndim, standardize=standardize, temperature=var_scale
+        ndim, standardize=standardize, temperature=temperature
     )
     model.fit(chains_train.samples, epochs=epochs_num)
@@ -299,7 +299,7 @@ def run_example(
     if plot_comparison_2var:
         model2 = model
-        model2.temperature = var_scale_2
+        model2.temperature = temperature_2
         ev_2 = hm.Evidence(chains_test.nchains, model)
         ev_2.add_chains(chains_test)
         ln_evidence_2, ln_evidence_std_2 = ev_2.compute_ln_evidence()
@@ -411,7 +411,7 @@ def run_example(
     if savefigs:
         plt.savefig(
             "examples/plots/nvp_normalgamma_corner_all_"
-            + str(var_scale)
+            + str(temperature)
             + "tau"
             + str(tau_prior)
             + ".png",
@@ -448,7 +448,7 @@ def run_example(
     )
     if savefigs:
         plt.savefig(
-            "examples/plots/nvp_normalgamma_comparison" + str(var_scale) + ".pdf",
+            "examples/plots/nvp_normalgamma_comparison" + str(temperature) + ".pdf",
             bbox_inches="tight",
         )
     plt.show(block=False)
@@ -472,7 +472,7 @@ def run_example(
         capsize=4,
         capthick=2,
         elinewidth=2,
-        label="T=" + str(var_scale),
+        label="T=" + str(temperature),
     )
     ax.errorbar(
         np.array(tau_array) * 1.13,
@@ -482,15 +482,15 @@ def run_example(
         capsize=4,
         capthick=2,
         elinewidth=2,
-        label="T=" + str(var_scale_2),
+        label="T=" + str(temperature_2),
     )
     ax.legend(loc="lower right")
     if savefigs:
         plt.savefig(
             "examples/plots/nvp_normalgamma_comparison_"
-            + str(var_scale)
+            + str(temperature)
             + "_"
-            + str(var_scale_2)
+            + str(temperature_2)
             + ".pdf",
             bbox_inches="tight",
             dpi=3000,
diff --git a/examples/normal_gamma_splines.py b/examples/normal_gamma_splines.py
index 0c47a71f..f5c5427f 100644
--- a/examples/normal_gamma_splines.py
+++ b/examples/normal_gamma_splines.py
@@ -192,12 +192,12 @@ def run_example(
     created_plots = False
     training_proportion = 0.5
-    var_scale = 0.8
+    temperature = 0.8
     epochs_num = 10
     standardize = False
     plot_comparison_2var = False
-    var_scale_2 = 0.95
+    temperature_2 = 0.95
     # ===========================================================================
     # Simulate data
@@ -269,7 +269,7 @@ def run_example(
     # =======================================================================
     hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
     model = model_nf.RQSplineModel(
-        ndim, standardize=standardize, temperature=var_scale
+        ndim, standardize=standardize, temperature=temperature
     )
     model.fit(chains_train.samples, epochs=epochs_num)
@@ -299,7 +299,7 @@ def run_example(
     if plot_comparison_2var:
         model2 = model
-        model2.temperature = var_scale_2  # double check this doesn't modify model
+        model2.temperature = temperature_2  # double check this doesn't modify model
         ev_2 = hm.Evidence(chains_test.nchains, model)
         ev_2.add_chains(chains_test)
         ln_evidence_2, ln_evidence_std_2 = ev_2.compute_ln_evidence()
@@ -411,7 +411,7 @@ def run_example(
     if savefigs:
         plt.savefig(
             "examples/plots/splines_normalgamma_corner_all_"
-            + str(var_scale)
+            + str(temperature)
             + "tau"
             + str(tau_prior)
             + ".png",
@@ -449,7 +449,7 @@ def run_example(
     if savefigs:
         plt.savefig(
             "examples/plots/splines_normalgamma_comparison"
-            + str(var_scale)
+            + str(temperature)
             + ".pdf",
             bbox_inches="tight",
         )
@@ -474,7 +474,7 @@ def run_example(
         capsize=4,
         capthick=2,
         elinewidth=2,
-        label="T=" + str(var_scale),
+        label="T=" + str(temperature),
     )
     ax.errorbar(
         np.array(tau_array) * 1.13,
@@ -484,15 +484,15 @@ def run_example(
         capsize=4,
         capthick=2,
         elinewidth=2,
-        label="T=" + str(var_scale_2),
+        label="T=" + str(temperature_2),
     )
     ax.legend(loc="lower right")
     if savefigs:
         plt.savefig(
             "examples/plots/splines_normalgamma_comparison_"
-            + str(var_scale)
+            + str(temperature)
             + "_"
-            + str(var_scale_2)
+            + str(temperature_2)
             + ".pdf",
             bbox_inches="tight",
             dpi=3000,
diff --git a/examples/pima_indian_nvp.py b/examples/pima_indian_nvp.py
index 07423571..8050c702 100644
--- a/examples/pima_indian_nvp.py
+++ b/examples/pima_indian_nvp.py
@@ -188,7 +188,7 @@ def run_example(
     """
     training_proportion = 0.5
-    var_scale = 0.9
+    temperature = 0.9
     epochs_num = 50
     n_scaled = 6
     n_unscaled = 2
@@ -259,7 +259,7 @@ def run_example(
         ndim,
         n_scaled_layers=n_scaled,
         n_unscaled_layers=n_unscaled,
-        temperature=var_scale,
+        temperature=temperature,
     )
     model.fit(chains_train.samples, epochs=epochs_num)
@@ -268,7 +268,7 @@ def run_example(
     # =======================================================================
     num_samp = chains_train.samples.shape[0]
-    # samps = np.array(model.sample(num_samp, var_scale=1.))
+    # samps = np.array(model.sample(num_samp, temperature=1.))
     samps_compressed = np.array(model.sample(num_samp))
     labels = ["Bias", "NP", "PGC", "BMI", "DP", "AGE"]
@@ -286,7 +286,7 @@ def run_example(
     if savefigs:
         plt.savefig(
             "examples/plots/nvp_pima_indian_corner_all_{}_T{}_tau{}_".format(
-                n_scaled + n_unscaled, var_scale, tau
+                n_scaled + n_unscaled, temperature, tau
             )
             + model_lab
             + ".png",
diff --git a/examples/radiata_pine_nvp.py b/examples/radiata_pine_nvp.py
index 87f0394b..36dc194a 100644
--- a/examples/radiata_pine_nvp.py
+++ b/examples/radiata_pine_nvp.py
@@ -325,7 +325,7 @@ def run_example(
     savefigs = True
     training_proportion = 0.5
-    var_scale = 0.9
+    temperature = 0.9
     epochs_num = 50
     n_scaled = 3
     n_unscaled = 3
@@ -442,7 +442,7 @@ def run_example(
         n_unscaled_layers=n_unscaled,
         learning_rate=learning_rate,
         standardize=standardize,
-        temperature=var_scale,
+        temperature=temperature,
     )
     # model = model_nf.RQSplineFlow(ndim)
     model.fit(chains_train.samples, epochs=epochs_num)
@@ -537,8 +537,8 @@ def run_example(
     # =======================================================================
     num_samp = chains_train.samples.shape[0]
-    # samps = np.array(model.sample(num_samp, var_scale=1.))
-    samps_compressed = np.array(model.sample(num_samp, var_scale=var_scale))
+    # samps = np.array(model.sample(num_samp, temperature=1.))
+    samps_compressed = np.array(model.sample(num_samp, temperature=temperature))
     utils.plot_getdist_compare(chains_train.samples, samps_compressed)
     if savefigs:
diff --git a/examples/radiata_pine_splines.py b/examples/radiata_pine_splines.py
index 20551853..70c6eff3 100644
--- a/examples/radiata_pine_splines.py
+++ b/examples/radiata_pine_splines.py
@@ -325,7 +325,7 @@ def run_example(
     savefigs = True
     training_proportion = 0.5
-    var_scale = 0.8
+    temperature = 0.8
     epochs_num = 30
     standardize = True
@@ -433,7 +433,9 @@ def run_example(
     Fit model by selecing the configuration of hyper-parameters which
     minimises the validation variances.
     """
-    model = model_nf.RQSplineModel(ndim, standardize=standardize, temperature=var_scale)
+    model = model_nf.RQSplineModel(
+        ndim, standardize=standardize, temperature=temperature
+    )
     model.fit(chains_train.samples, epochs=epochs_num)
     # ===========================================================================
@@ -526,7 +528,7 @@ def run_example(
     # =======================================================================
     num_samp = chains_train.samples.shape[0]
-    # samps = np.array(model.sample(num_samp, var_scale=1.))
+    # samps = np.array(model.sample(num_samp, temperature=1.))
     samps_compressed = np.array(model.sample(num_samp))
     utils.plot_getdist_compare(chains_train.samples, samps_compressed)
diff --git a/examples/rastrigin_splines.py b/examples/rastrigin_splines.py
index ac1dd353..c12462ca 100644
--- a/examples/rastrigin_splines.py
+++ b/examples/rastrigin_splines.py
@@ -122,7 +122,7 @@ def run_example(
     hyper_parameters = [[10 ** (R)] for R in range(-nhyper + step, step)]
     hm.logs.debug_log("Hyper-parameters = {}".format(hyper_parameters))
-    var_scale = 0.8
+    temperature = 0.8
     epochs_num = 30
     # Spline params
@@ -209,7 +209,7 @@ def run_example(
         standardize=standardize,
         learning_rate=learning_rate,
         momentum=momentum,
-        temperature=var_scale,
+        temperature=temperature,
     )
     model.fit(chains_train.samples, epochs=epochs_num)
diff --git a/examples/rosenbrock_nvp.py b/examples/rosenbrock_nvp.py
index b4961aa6..71749c8a 100644
--- a/examples/rosenbrock_nvp.py
+++ b/examples/rosenbrock_nvp.py
@@ -147,7 +147,7 @@ def run_example(
     a = 1.0
     b = 100.0
     epochs_num = 8
-    var_scale = 0.9
+    temperature = 0.9
     training_proportion = 0.5
     standardize = False
     """
@@ -221,7 +221,7 @@ def run_example(
     # =======================================================================
     hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
     model = model_nf.RealNVPModel(
-        ndim, standardize=standardize, temperature=var_scale
+        ndim, standardize=standardize, temperature=temperature
     )
     model.fit(chains_train.samples, epochs=epochs_num)
@@ -364,7 +364,7 @@ def run_example(
     if savefigs:
         plt.savefig(
             "examples/plots/nvp_rosenbrock_corner_all_T"
-            + str(var_scale)
+            + str(temperature)
             + ".png",
             bbox_inches="tight",
             dpi=300,
@@ -388,7 +388,7 @@ def run_example(
     if n_realisations > 1:
         np.savetxt(
             "examples/data/nvp_rosenbrock_evidence_inv_T"
-            + str(var_scale)
+            + str(temperature)
             + "_realisations.dat",
             evidence_inv_summary,
         )
diff --git a/examples/rosenbrock_splines.py b/examples/rosenbrock_splines.py
index 35a0a64b..2ca873f8 100644
--- a/examples/rosenbrock_splines.py
+++ b/examples/rosenbrock_splines.py
@@ -147,7 +147,7 @@ def run_example(
     a = 1.0
     b = 100.0
     epochs_num = 5
-    var_scale = 0.8
+    temperature = 0.8
     training_proportion = 0.5
     standardize = True
     """
@@ -221,7 +221,7 @@ def run_example(
     # =======================================================================
     hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
     model = model_nf.RQSplineModel(
-        ndim, standardize=standardize, temperature=var_scale
+        ndim, standardize=standardize, temperature=temperature
    )
     model.fit(chains_train.samples, epochs=epochs_num)
@@ -356,7 +356,7 @@ def run_example(
     if savefigs:
         plt.savefig(
             "examples/plots/spline_rosenbrock_corner_all_T"
-            + str(var_scale)
+            + str(temperature)
             + ".png",
             bbox_inches="tight",
             dpi=300,
@@ -378,7 +378,7 @@ def run_example(
     if n_realisations > 1:
         np.savetxt(
             "examples/data/spline_rosenbrock_evidence_inv_T"
-            + str(var_scale)
+            + str(temperature)
             + "_realisations.dat",
             evidence_inv_summary,
         )
diff --git a/examples/temperature_diagram.py b/examples/temperature_diagram.py
index 0bde50c0..9a081335 100644
--- a/examples/temperature_diagram.py
+++ b/examples/temperature_diagram.py
@@ -109,9 +109,9 @@ def get_batch(seed):
 num_samp = batch_size * 5
-# samps = np.array(model.sample(num_samp, var_scale=1.))
-samps2 = np.array(model.sample(num_samp, var_scale=0.7))
-samps3 = np.array(model.sample(num_samp, var_scale=0.4))
+# samps = np.array(model.sample(num_samp, temperature=1.))
+samps2 = np.array(model.sample(num_samp, temperature=0.7))
+samps3 = np.array(model.sample(num_samp, temperature=0.4))
 # Get the getdist MCSamples objects for the samples, specifying same parameter
 # names and labels; if not specified weights are assumed to all be unity
diff --git a/harmonic/flows.py b/harmonic/flows.py
index f7150cd0..0de23917 100644
--- a/harmonic/flows.py
+++ b/harmonic/flows.py
@@ -38,12 +38,12 @@ def setup(self):
             AffineCoupling(apply_scaling=False) for i in range(self.n_unscaled_layers)
         ]
-    def make_flow(self, var_scale: float = 1.0):
+    def make_flow(self, temperature: float = 1.0):
         """
         Make tfp-jax distribution object containing the RealNVP flow.
         Args:
-            var_scale (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
+            temperature (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
                 Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
@@ -91,32 +91,32 @@ def make_flow(self, var_scale: float = 1.0):
         nvp = tfd.TransformedDistribution(
             distribution=tfd.MultivariateNormalDiag(
                 loc=jnp.zeros(self.n_features),
-                scale_diag=jnp.full(self.n_features, var_scale),
+                scale_diag=jnp.full(self.n_features, temperature),
             ),
             bijector=chain,
         )
         return nvp
-    def __call__(self, x: jnp.ndarray, var_scale: int = 1.0) -> jnp.array:
+    def __call__(self, x: jnp.ndarray, temperature: int = 1.0) -> jnp.array:
         """
         Evaluate the log probability of the flow for non-batched input x.
         Args:
             x (jnp.ndarray (ndim)): Sample at which to predict posterior value.
-            var_scale (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
+            temperature (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
                 Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
         Returns:
             float: Predicted log_e posterior value.
         """
-        flow = self.make_flow(var_scale=var_scale)
+        flow = self.make_flow(temperature=temperature)
         return flow.log_prob(x)
     def sample(
-        self, rng: jax.random.PRNGKey, num_samples: int, var_scale: float = 1.0
+        self, rng: jax.random.PRNGKey, num_samples: int, temperature: float = 1.0
     ) -> jnp.array:
         """ "
         Sample from the flow.
@@ -126,32 +126,32 @@ def sample(
             num_samples (int): Number of samples generated.
-            var_scale (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
+            temperature (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
                 Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
         Returns:
             jnp.array (num_samples, ndim): Samples from fitted distribution.
         """
-        nvp = self.make_flow(var_scale=var_scale)
+        nvp = self.make_flow(temperature=temperature)
         samples = nvp.sample(num_samples, seed=rng)
         return samples
-    def log_prob(self, x: jnp.array, var_scale: float = 1.0) -> jnp.array:
+    def log_prob(self, x: jnp.array, temperature: float = 1.0) -> jnp.array:
         """
         Evaluate the log probability of the flow for a batched input.
         Args:
             x (jnp.ndarray (batch_size, ndim)): Sample for which to predict posterior values.
-            var_scale (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
+            temperature (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
                 Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
         Returns:
             jnp.ndarray (batch_size,): Predicted log_e posterior value.
         """
         get_logprob = jax.jit(jax.vmap(self.__call__, in_axes=[0, None]))
-        logprob = get_logprob(x, var_scale)
+        logprob = get_logprob(x, temperature)
         return logprob
@@ -231,12 +231,12 @@ def bijector_fn(params: jnp.ndarray):
         self.bijector_fn = bijector_fn
-    def make_flow(self, var_scale: float = 1.0):
+    def make_flow(self, temperature: float = 1.0):
         """
         Make distrax distribution containing the rational quadratic spline flow.
         Args:
-            var_scale (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
+            temperature (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
                 Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
         Returns:
@@ -264,31 +264,31 @@ def make_flow(self, var_scale: float = 1.0):
         base_dist = distrax.Independent(
             distrax.MultivariateNormalFullCovariance(
                 loc=jnp.zeros(self.n_features),
-                covariance_matrix=jnp.eye(self.n_features) * var_scale,
+                covariance_matrix=jnp.eye(self.n_features) * temperature,
             )
         )
         return base_dist, flow
-    def __call__(self, x: jnp.array, var_scale: float = 1.0) -> jnp.array:
+    def __call__(self, x: jnp.array, temperature: float = 1.0) -> jnp.array:
         """
         Evaluate the log probability of the flow for non-batched input x.
         Args:
             x (jnp.ndarray (ndim)): Sample at which to predict posterior value.
-            var_scale (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
+            temperature (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
                 Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
         Returns:
             jnp.ndarray (float): Predicted log_e posterior value.
         """
-        base_dist, flow = self.make_flow(var_scale=var_scale)
+        base_dist, flow = self.make_flow(temperature=temperature)
         return distrax.Transformed(base_dist, flow).log_prob(x)
     def sample(
-        self, rng: jax.random.PRNGKey, num_samples: int, var_scale: float = 1.0
+        self, rng: jax.random.PRNGKey, num_samples: int, temperature: float = 1.0
     ) -> jnp.array:
         """ "
         Sample from the flow.
@@ -298,28 +298,28 @@ def sample(
             num_samples (int): Number of samples generated.
-            var_scale (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
+            temperature (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
                 Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
         Returns:
             jnp.array (num_samples, ndim): Samples from fitted distribution.
         """
-        base_dist, flow = self.make_flow(var_scale=var_scale)
+        base_dist, flow = self.make_flow(temperature=temperature)
         samples = distrax.Transformed(base_dist, flow).sample(
             seed=rng, sample_shape=(num_samples)
         )
         return samples
-    def log_prob(self, x: jnp.array, var_scale: float = 1.0) -> jnp.array:
+    def log_prob(self, x: jnp.array, temperature: float = 1.0) -> jnp.array:
         """
         Evaluate the log probability of the flow for a batched input.
         Args:
             x (jnp.ndarray (batch_size, ndim)): Sample for which to predict posterior values.
-            var_scale (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
+            temperature (float, optional): Factor by which base Gaussian unit covariance matrix is scaled.
                 Should be between 0 and 1 for use in evidence estimation. Defaults to 1.
         Returns:
@@ -327,7 +327,7 @@ def log_prob(self, x: jnp.array, var_scale: float = 1.0) -> jnp.array:
         """
         get_logprob = jax.jit(jax.vmap(self.__call__, in_axes=[0, None]))
-        logprob = get_logprob(x, var_scale)
+        logprob = get_logprob(x, temperature)
         return logprob
diff --git a/harmonic/model_nf.py b/harmonic/model_nf.py
index c95d94be..6e13ca63 100644
--- a/harmonic/model_nf.py
+++ b/harmonic/model_nf.py
@@ -27,7 +27,7 @@ def loss(params):
         log_det = model.apply(
             {"params": params, "variables": variables},
             batch,
-            var_scale=1.0,
+            temperature=1.0,
             method=model.log_prob,
         )
         return -jnp.mean(log_det)
@@ -204,16 +204,16 @@ def predict(self, x: jnp.ndarray) -> jnp.ndarray:
         Raises:
-            ValueError: If var_scale is negative or greater than 1.
+            ValueError: If temperature is negative or greater than 1.
         """
-        var_scale = self.temperature
+        temperature = self.temperature
-        if var_scale > 1:
+        if temperature > 1:
             raise ValueError("Scaling must not be greater than 1.")
-        if var_scale <= 0:
+        if temperature <= 0:
             raise ValueError("Scaling must be positive.")
         if self.standardize:
@@ -223,7 +223,7 @@ def predict(self, x: jnp.ndarray) -> jnp.ndarray:
         logprob = self.flow.apply(
             {"params": self.state.params, "variables": self.variables},
             x,
-            var_scale,
+            temperature,
             method=self.flow.log_prob,
         )
@@ -242,26 +242,26 @@ def sample(self, n_sample: int, rng_key=jax.random.PRNGKey(0)) -> jnp.ndarray:
         Raises:
-            ValueError: If var_scale is negative or greater than 1.
+            ValueError: If temperature is negative or greater than 1.
         Returns:
             jnp.array (n_sample, ndim): Samples from fitted distribution.
         """
-        var_scale = self.temperature
+        temperature = self.temperature
-        if var_scale > 1:
+        if temperature > 1:
             raise ValueError("Scaling must not be greater than 1.")
-        if var_scale <= 0:
+        if temperature <= 0:
             raise ValueError("Scaling must be positive.")
         samples = self.flow.apply(
             {"params": self.state.params, "variables": self.variables},
             rng_key,
             n_sample,
-            var_scale,
+            temperature,
             method=self.flow.sample,
         )
diff --git a/tests/test_flow_model.py b/tests/test_flow_model.py
index 1c6496cf..0bca7f61 100644
--- a/tests/test_flow_model.py
+++ b/tests/test_flow_model.py
@@ -256,7 +256,7 @@ def test_model_serialization():
     learning_rate = 0.01
     momentum = 0.8
     standardize = True
-    var_scale = 0.6
+    temperature = 0.6
     model_NVP = model_nf.RealNVPModel(
         ndim,
@@ -265,7 +265,7 @@ def test_model_serialization():
         learning_rate=learning_rate,
         momentum=momentum,
         standardize=standardize,
-        temperature=var_scale,
+        temperature=temperature,
     )
     model_NVP.fit(samples, epochs=epochs_NVP)
@@ -302,7 +302,7 @@ def test_model_serialization():
         standardize=standardize,
         learning_rate=learning_rate,
         momentum=momentum,
-        temperature=var_scale,
+        temperature=temperature,
     )
     model_spline.fit(samples, epochs=epochs_spline)
     # Serialize model