From a7160388449c540ad5e7a14795142a1f5f5fde68 Mon Sep 17 00:00:00 2001 From: alicjapolanska Date: Tue, 17 Oct 2023 15:01:06 +0100 Subject: [PATCH] Convert Gaussian example to jax. --- examples/gaussian_nondiagcov_splines.py | 156 +++++------------------- 1 file changed, 33 insertions(+), 123 deletions(-) diff --git a/examples/gaussian_nondiagcov_splines.py b/examples/gaussian_nondiagcov_splines.py index a4c50ab8..dbb58454 100644 --- a/examples/gaussian_nondiagcov_splines.py +++ b/examples/gaussian_nondiagcov_splines.py @@ -10,7 +10,10 @@ sys.path.append("examples") import utils sys.path.append("harmonic") -import model_nf +from harmonic import model_nf +import jax +import jax.numpy as jnp + def ln_analytic_evidence(ndim, cov): @@ -31,7 +34,7 @@ def ln_analytic_evidence(ndim, cov): ln_norm_lik = 0.5*ndim*np.log(2*np.pi) + 0.5*np.log(np.linalg.det(cov)) return ln_norm_lik - +#@partial(jax.jit, static_argnums=(1,)) def ln_posterior(x, inv_cov): """Compute log_e of posterior. @@ -47,7 +50,7 @@ def ln_posterior(x, inv_cov): """ - return -np.dot(x,np.dot(inv_cov,x))/2.0 + return -jnp.dot(x,jnp.dot(inv_cov,x))/2.0 def init_cov(ndim): @@ -75,7 +78,7 @@ def init_cov(ndim): def run_example(ndim=2, nchains=100, samples_per_chain=1000, - nburn=500, plot_corner=False, plot_surface=False): + plot_corner=False): """Run Gaussian example with non-diagonal covariance matrix. Args: @@ -86,23 +89,26 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000, samples_per_chain: Number of samples per chain. - nburn: Number of burn in samples for each chain. - plot_corner: Plot marginalised distributions if true. - plot_surface: Plot surface and samples if true. - """ savefigs = True # Initialise covariance matrix. cov = init_cov(ndim) - inv_cov = np.linalg.inv(cov) + inv_cov = jnp.linalg.inv(cov) training_proportion = 0.5 - epochs_num = 5 + epochs_num = 80 var_scale = 0.8 standardize = True + verbose = True + + #Spline params + n_layers = 13 + n_bins = 8 + hidden_size = [64, 64] + spline_range = (-10.0, 10.0) # Start timer. clock = time.process_time() @@ -115,26 +121,23 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000, if n_realisations > 0: hm.logs.info_log('Realisation = {}/{}' .format(i_realisation+1, n_realisations)) - - # Set up and run sampler. - hm.logs.info_log('Run sampling...') - - pos = np.random.rand(ndim * nchains).reshape((nchains, ndim)) - - sampler = emcee.EnsembleSampler(nchains, ndim, \ - ln_posterior, args=[inv_cov]) - rstate = np.random.get_state() # Set random state to repeatable - # across calls. - (pos, prob, state) = sampler.run_mcmc(pos, samples_per_chain, - rstate0=rstate) - samples = np.ascontiguousarray(sampler.chain[:,nburn:,:]) - lnprob = np.ascontiguousarray(sampler.lnprobability[:,nburn:]) + # Define the number of dimensions and the mean of the Gaussian + num_samples = nchains*samples_per_chain + # Initialize a PRNG key (you can use any valid key) + key = jax.random.PRNGKey(0) + mean = jnp.zeros(ndim) + + # Generate random samples from the 2D Gaussian distribution + samples = jax.random.multivariate_normal(key, mean, cov, shape=(num_samples,)) + lnprob = jax.vmap(ln_posterior, in_axes=(0, None))(samples,jnp.array(inv_cov)) + samples = jnp.reshape(samples, (nchains,-1,ndim)) + lnprob = jnp.reshape(lnprob, (nchains,-1)) # Calculate evidence using harmonic.... # Set up chains. chains = hm.Chains(ndim) - chains.add_chains_3d(samples, lnprob) + chains.add_chains_3d(np.array(samples).astype('double'), np.array(lnprob).astype('double')) chains_train, chains_test = hm.utils.split_data(chains, \ training_proportion=training_proportion) @@ -142,8 +145,8 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000, # Fit model #======================================================================= hm.logs.info_log('Fit model for {} epochs...'.format(epochs_num)) - model = model_nf.RQSplineFlow(ndim, standardize=standardize, temperature = var_scale) - model.fit(chains_train.samples, epochs=epochs_num) + model = model_nf.RQSplineFlow(ndim, n_layers = n_layers, n_bins = n_bins, hidden_size = hidden_size, spline_range = spline_range, standardize = standardize, temperature=var_scale) + model.fit(jnp.array(chains_train.samples), epochs=epochs_num, verbose = verbose) # Use chains and model to compute inverse evidence. hm.logs.info_log('Compute evidence...') @@ -242,96 +245,6 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000, bbox_inches='tight', dpi=300) plt.show() - - # ====================================================================== - # In 2D case, plot surface/image and samples. - # ====================================================================== - if plot_surface and ndim == 2 and i_realisation == 0: - - # ================================================================== - # Define plot parameters. - # ================================================================== - nx = 50 - xmin = -3.0 - xmax = 3.0 - - # ================================================================== - # 2D surface plot of posterior. - # ================================================================== - ln_posterior_func = partial(ln_posterior, inv_cov=inv_cov) - ln_posterior_grid, x_grid, y_grid = \ - utils.eval_func_on_grid(ln_posterior_func, - xmin=xmin, xmax=xmax, - ymin=xmin, ymax=xmax, - nx=nx, ny=nx) - i_chain = 0 - ax = utils.plot_surface(np.exp(ln_posterior_grid), x_grid, y_grid, - samples[i_chain,:,:].reshape((-1, ndim)), - np.exp(lnprob[i_chain,:].reshape((-1, 1))), - contour_z_offset=-0.5, alpha=0.3) - - ax.set_zlabel(r'$\mathcal{L}$') - - # Save. - if savefigs: - plt.savefig('examples/plots/splines_gaussian_nondiagcov_posterior_surface.png'\ - , bbox_inches='tight') - - plt.show(block=False) - - # ================================================================== - # Image of posterior samples overlayed with contour plot. - # ================================================================== - # Plot posterior image. - ax = utils.plot_image(np.exp(ln_posterior_grid), x_grid, y_grid, - samples[i_chain].reshape((-1, ndim)), - colorbar_label='$\mathcal{L}$', - plot_contour=True, markersize=1.0) - # Save. - if savefigs: - plt.savefig( - 'examples/plots/splines_gaussian_nondiagcov_posterior_image.png' - , bbox_inches='tight') - - plt.show(block=False) - - # ================================================================== - # Learnt model of the posterior - # ================================================================== - # Evaluate ln_posterior and model over grid. - x = np.linspace(xmin, xmax, nx); y = np.linspace(xmin, xmax, nx) - x, y = np.meshgrid(x, y) - ln_model_grid = np.zeros((nx,nx)) - for i in range(nx): - for j in range(nx): - ln_model_grid[i,j] =model.predict(np.array([x[i,j],y[i,j]])) - - i_chain = 0 - ax = utils.plot_surface(np.exp(ln_model_grid), x_grid, y_grid, - contour_z_offset=-0.075) - ax.set_zlabel(r'$\mathcal{L}$') - - # Save. - if savefigs: - plt.savefig('examples/plots/splines_gaussian_nondiagcov_surface.png' - , bbox_inches='tight') - - plt.show(block=False) - - # ================================================================== - # Projection of posteior onto x1,x2 plane with contours. - # ================================================================== - # Plot posterior image. - ax = utils.plot_image(np.exp(ln_model_grid), x_grid, y_grid, - colorbar_label='$\mathcal{L}$', - plot_contour=True) - # Save. - if savefigs: - plt.savefig('examples/plots/splines_gaussian_nondiagcov_image.png', - bbox_inches='tight') - - plt.show(block=False) - # ================================================================== evidence_inv_summary[i_realisation,0] = ev.evidence_inv evidence_inv_summary[i_realisation,1] = ev.evidence_inv_var @@ -361,10 +274,9 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000, hm.logs.setup_logging() # Define parameters. - ndim = 10 - nchains = 100 + ndim = 50 + nchains = 200 samples_per_chain = 5000 - nburn = 500 np.random.seed(10) hm.logs.info_log('Non-diagonal Covariance Gaussian example') @@ -374,10 +286,8 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000, hm.logs.debug_log('Dimensionality = {}'.format(ndim)) hm.logs.debug_log('Number of chains = {}'.format(nchains)) hm.logs.debug_log('Samples per chain = {}'.format(samples_per_chain)) - hm.logs.debug_log('Burn in = {}'.format(nburn)) hm.logs.debug_log('-------------------------') # Run example. - run_example(ndim, nchains, samples_per_chain, nburn, - plot_corner=False, plot_surface=False) + run_example(ndim, nchains, samples_per_chain, plot_corner=False)