Skip to content

Commit

Permalink
Convert Gaussian example to jax.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 17, 2023
1 parent 1b33245 commit a716038
Showing 1 changed file with 33 additions and 123 deletions.
156 changes: 33 additions & 123 deletions examples/gaussian_nondiagcov_splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
sys.path.append("examples")
import utils
sys.path.append("harmonic")
import model_nf
from harmonic import model_nf
import jax
import jax.numpy as jnp



def ln_analytic_evidence(ndim, cov):
Expand All @@ -31,7 +34,7 @@ def ln_analytic_evidence(ndim, cov):
ln_norm_lik = 0.5*ndim*np.log(2*np.pi) + 0.5*np.log(np.linalg.det(cov))
return ln_norm_lik


#@partial(jax.jit, static_argnums=(1,))
def ln_posterior(x, inv_cov):
"""Compute log_e of posterior.
Expand All @@ -47,7 +50,7 @@ def ln_posterior(x, inv_cov):
"""

return -np.dot(x,np.dot(inv_cov,x))/2.0
return -jnp.dot(x,jnp.dot(inv_cov,x))/2.0


def init_cov(ndim):
Expand Down Expand Up @@ -75,7 +78,7 @@ def init_cov(ndim):


def run_example(ndim=2, nchains=100, samples_per_chain=1000,
nburn=500, plot_corner=False, plot_surface=False):
plot_corner=False):
"""Run Gaussian example with non-diagonal covariance matrix.
Args:
Expand All @@ -86,23 +89,26 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000,
samples_per_chain: Number of samples per chain.
nburn: Number of burn in samples for each chain.
plot_corner: Plot marginalised distributions if true.
plot_surface: Plot surface and samples if true.
"""

savefigs = True

# Initialise covariance matrix.
cov = init_cov(ndim)
inv_cov = np.linalg.inv(cov)
inv_cov = jnp.linalg.inv(cov)
training_proportion = 0.5
epochs_num = 5
epochs_num = 80
var_scale = 0.8
standardize = True
verbose = True

#Spline params
n_layers = 13
n_bins = 8
hidden_size = [64, 64]
spline_range = (-10.0, 10.0)

# Start timer.
clock = time.process_time()
Expand All @@ -115,35 +121,32 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000,
if n_realisations > 0:
hm.logs.info_log('Realisation = {}/{}'
.format(i_realisation+1, n_realisations))

# Set up and run sampler.
hm.logs.info_log('Run sampling...')

pos = np.random.rand(ndim * nchains).reshape((nchains, ndim))

sampler = emcee.EnsembleSampler(nchains, ndim, \
ln_posterior, args=[inv_cov])
rstate = np.random.get_state() # Set random state to repeatable
# across calls.
(pos, prob, state) = sampler.run_mcmc(pos, samples_per_chain,
rstate0=rstate)
samples = np.ascontiguousarray(sampler.chain[:,nburn:,:])
lnprob = np.ascontiguousarray(sampler.lnprobability[:,nburn:])
# Define the number of dimensions and the mean of the Gaussian
num_samples = nchains*samples_per_chain
# Initialize a PRNG key (you can use any valid key)
key = jax.random.PRNGKey(0)
mean = jnp.zeros(ndim)

# Generate random samples from the 2D Gaussian distribution
samples = jax.random.multivariate_normal(key, mean, cov, shape=(num_samples,))
lnprob = jax.vmap(ln_posterior, in_axes=(0, None))(samples,jnp.array(inv_cov))
samples = jnp.reshape(samples, (nchains,-1,ndim))
lnprob = jnp.reshape(lnprob, (nchains,-1))

# Calculate evidence using harmonic....

# Set up chains.
chains = hm.Chains(ndim)
chains.add_chains_3d(samples, lnprob)
chains.add_chains_3d(np.array(samples).astype('double'), np.array(lnprob).astype('double'))
chains_train, chains_test = hm.utils.split_data(chains, \
training_proportion=training_proportion)

#=======================================================================
# Fit model
#=======================================================================
hm.logs.info_log('Fit model for {} epochs...'.format(epochs_num))
model = model_nf.RQSplineFlow(ndim, standardize=standardize, temperature = var_scale)
model.fit(chains_train.samples, epochs=epochs_num)
model = model_nf.RQSplineFlow(ndim, n_layers = n_layers, n_bins = n_bins, hidden_size = hidden_size, spline_range = spline_range, standardize = standardize, temperature=var_scale)
model.fit(jnp.array(chains_train.samples), epochs=epochs_num, verbose = verbose)

# Use chains and model to compute inverse evidence.
hm.logs.info_log('Compute evidence...')
Expand Down Expand Up @@ -242,96 +245,6 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000,
bbox_inches='tight', dpi=300)

plt.show()

# ======================================================================
# In 2D case, plot surface/image and samples.
# ======================================================================
if plot_surface and ndim == 2 and i_realisation == 0:

# ==================================================================
# Define plot parameters.
# ==================================================================
nx = 50
xmin = -3.0
xmax = 3.0

# ==================================================================
# 2D surface plot of posterior.
# ==================================================================
ln_posterior_func = partial(ln_posterior, inv_cov=inv_cov)
ln_posterior_grid, x_grid, y_grid = \
utils.eval_func_on_grid(ln_posterior_func,
xmin=xmin, xmax=xmax,
ymin=xmin, ymax=xmax,
nx=nx, ny=nx)
i_chain = 0
ax = utils.plot_surface(np.exp(ln_posterior_grid), x_grid, y_grid,
samples[i_chain,:,:].reshape((-1, ndim)),
np.exp(lnprob[i_chain,:].reshape((-1, 1))),
contour_z_offset=-0.5, alpha=0.3)

ax.set_zlabel(r'$\mathcal{L}$')

# Save.
if savefigs:
plt.savefig('examples/plots/splines_gaussian_nondiagcov_posterior_surface.png'\
, bbox_inches='tight')

plt.show(block=False)

# ==================================================================
# Image of posterior samples overlayed with contour plot.
# ==================================================================
# Plot posterior image.
ax = utils.plot_image(np.exp(ln_posterior_grid), x_grid, y_grid,
samples[i_chain].reshape((-1, ndim)),
colorbar_label='$\mathcal{L}$',
plot_contour=True, markersize=1.0)
# Save.
if savefigs:
plt.savefig(
'examples/plots/splines_gaussian_nondiagcov_posterior_image.png'
, bbox_inches='tight')

plt.show(block=False)

# ==================================================================
# Learnt model of the posterior
# ==================================================================
# Evaluate ln_posterior and model over grid.
x = np.linspace(xmin, xmax, nx); y = np.linspace(xmin, xmax, nx)
x, y = np.meshgrid(x, y)
ln_model_grid = np.zeros((nx,nx))
for i in range(nx):
for j in range(nx):
ln_model_grid[i,j] =model.predict(np.array([x[i,j],y[i,j]]))

i_chain = 0
ax = utils.plot_surface(np.exp(ln_model_grid), x_grid, y_grid,
contour_z_offset=-0.075)
ax.set_zlabel(r'$\mathcal{L}$')

# Save.
if savefigs:
plt.savefig('examples/plots/splines_gaussian_nondiagcov_surface.png'
, bbox_inches='tight')

plt.show(block=False)

# ==================================================================
# Projection of posteior onto x1,x2 plane with contours.
# ==================================================================
# Plot posterior image.
ax = utils.plot_image(np.exp(ln_model_grid), x_grid, y_grid,
colorbar_label='$\mathcal{L}$',
plot_contour=True)
# Save.
if savefigs:
plt.savefig('examples/plots/splines_gaussian_nondiagcov_image.png',
bbox_inches='tight')

plt.show(block=False)
# ==================================================================

evidence_inv_summary[i_realisation,0] = ev.evidence_inv
evidence_inv_summary[i_realisation,1] = ev.evidence_inv_var
Expand Down Expand Up @@ -361,10 +274,9 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000,
hm.logs.setup_logging()

# Define parameters.
ndim = 10
nchains = 100
ndim = 50
nchains = 200
samples_per_chain = 5000
nburn = 500
np.random.seed(10)

hm.logs.info_log('Non-diagonal Covariance Gaussian example')
Expand All @@ -374,10 +286,8 @@ def run_example(ndim=2, nchains=100, samples_per_chain=1000,
hm.logs.debug_log('Dimensionality = {}'.format(ndim))
hm.logs.debug_log('Number of chains = {}'.format(nchains))
hm.logs.debug_log('Samples per chain = {}'.format(samples_per_chain))
hm.logs.debug_log('Burn in = {}'.format(nburn))

hm.logs.debug_log('-------------------------')

# Run example.
run_example(ndim, nchains, samples_per_chain, nburn,
plot_corner=False, plot_surface=False)
run_example(ndim, nchains, samples_per_chain, plot_corner=False)

0 comments on commit a716038

Please sign in to comment.