Skip to content

Commit

Permalink
Update Gaussian example to use log space and MCMC.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 28, 2024
1 parent 38d7993 commit dd6e08e
Showing 1 changed file with 40 additions and 34 deletions.
74 changes: 40 additions & 34 deletions examples/gaussian_nondiagcov.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
import emcee
import logging


def ln_analytic_evidence(ndim, cov):
Expand Down Expand Up @@ -99,30 +100,33 @@ def run_example(
inv_cov = jnp.linalg.inv(cov)
training_proportion = 0.5
if flow_type == "RealNVP":
epochs_num = 5
epochs_num = 10 #5
elif flow_type == "RQSpline":
# epochs_num = 3
epochs_num = 10
#epochs_num = 5
epochs_num = 110

# Beginning of path where plots will be saved
save_name_start = "examples/plots/" + flow_type

temperature = 0.8
temperature = 0.9
standardize = True
verbose = True

# Spline params
n_layers = 5
n_bins = 16
n_layers = 3
n_bins = 128
hidden_size = [32, 32]
spline_range = (-10.0, 10.0)

if flow_type == "RQSpline":
save_name_start += "_" + str(n_layers) + "l_" + str(n_bins) + "b_" + str(epochs_num) + "e_" + str(int(training_proportion * 100)) + "perc_" + str(temperature) + "T" + "_emcee"

# Start timer.
clock = time.process_time()

# Run multiple realisations.
n_realisations = 1
ln_evidence_inv_summary = np.zeros((n_realisations, 3))
n_realisations = 100
ln_evidence_inv_summary = np.zeros((n_realisations, 5))
for i_realisation in range(n_realisations):
if n_realisations > 0:
hm.logs.info_log(
Expand All @@ -140,7 +144,7 @@ def run_example(
samples = jnp.reshape(samples, (nchains, -1, ndim))
lnprob = jnp.reshape(lnprob, (nchains, -1))

MCMC = False
MCMC = True
if MCMC:
nburn = 500
# Set up and run sampler.
Expand All @@ -152,7 +156,7 @@ def run_example(
rstate = np.random.get_state() # Set random state to repeatable
# across calls.
(pos, prob, state) = sampler.run_mcmc(
pos, samples_per_chain, rstate0=rstate
pos, samples_per_chain, rstate0=rstate, progress=True
)
samples = np.ascontiguousarray(sampler.chain[:, nburn:, :])
lnprob = np.ascontiguousarray(sampler.lnprobability[:, nburn:])
Expand Down Expand Up @@ -218,42 +222,42 @@ def run_example(
)
)

hm.logs.debug_log("kurtosis = {}".format(ev.kurtosis))
hm.logs.debug_log(" Aim for ~3.")
hm.logs.info_log("kurtosis = {}".format(ev.kurtosis))
hm.logs.info_log(" Aim for ~3.")
check = np.exp(0.5 * ev.ln_evidence_inv_var_var - ev.ln_evidence_inv_var)
hm.logs.debug_log("sqrt( var(var) ) / var = {}".format(check))
hm.logs.debug_log(
hm.logs.info_log("sqrt( var(var) ) / var = {}".format(check))
hm.logs.info_log(
" Aim for sqrt( 2/(n_eff-1) ) = {}".format(np.sqrt(2.0 / (ev.n_eff - 1)))
)

# ===========================================================================
# Display more technical details
# ===========================================================================
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log("Technical Details")
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log(
hm.logs.info_log("---------------------------------")
hm.logs.info_log("Technical Details")
hm.logs.info_log("---------------------------------")
hm.logs.info_log(
"lnargmax = {}, lnargmin = {}".format(ev.lnargmax, ev.lnargmin)
)
hm.logs.debug_log(
hm.logs.info_log(
"lnprobmax = {}, lnprobmin = {}".format(ev.lnprobmax, ev.lnprobmin)
)
hm.logs.debug_log(
hm.logs.info_log(
"lnpredictmax = {}, lnpredictmin = {}".format(
ev.lnpredictmax, ev.lnpredictmin
)
)
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log(
hm.logs.info_log("---------------------------------")
hm.logs.info_log(
"shift = {}, shift setting = {}".format(ev.shift_value, ev.shift)
)
hm.logs.debug_log("running sum total = {}".format(sum(ev.running_sum)))
hm.logs.debug_log("running sum = \n{}".format(ev.running_sum))
hm.logs.debug_log("nsamples per chain = \n{}".format(ev.nsamples_per_chain))
hm.logs.debug_log(
hm.logs.info_log("running sum total = {}".format(sum(ev.running_sum)))
hm.logs.info_log("running sum = \n{}".format(ev.running_sum))
hm.logs.info_log("nsamples per chain = \n{}".format(ev.nsamples_per_chain))
hm.logs.info_log(
"nsamples eff per chain = \n{}".format(ev.nsamples_eff_per_chain)
)
hm.logs.debug_log("===============================")
hm.logs.info_log("===============================")

# ======================================================================
# Create corner/triangle plot.
Expand Down Expand Up @@ -293,8 +297,10 @@ def run_example(

# Save out realisations for violin plot.
ln_evidence_inv_summary[i_realisation, 0] = ev.ln_evidence_inv
ln_evidence_inv_summary[i_realisation, 1] = ev.ln_evidence_inv_var
ln_evidence_inv_summary[i_realisation, 2] = ev.ln_evidence_inv_var_var
ln_evidence_inv_summary[i_realisation, 1] = err_ln_inv_evidence[0]
ln_evidence_inv_summary[i_realisation, 2] = err_ln_inv_evidence[1]
ln_evidence_inv_summary[i_realisation, 3] = ev.ln_evidence_inv_var
ln_evidence_inv_summary[i_realisation, 4] = ev.ln_evidence_inv_var_var

clock = time.process_time() - clock
hm.logs.info_log("Execution_time = {}s".format(clock))
Expand Down Expand Up @@ -322,13 +328,13 @@ def run_example(

if __name__ == "__main__":
# Setup logging config.
hm.logs.setup_logging()
hm.logs.setup_logging(default_level=logging.DEBUG)

# Define parameters.
ndim = 21
nchains = 100
nchains = 80
samples_per_chain = 5000
# flow_str = "RealNVP"
#flow_str = "RealNVP"
flow_str = "RQSpline"
np.random.seed(10) # used for initializing covariance matrix

Expand All @@ -343,4 +349,4 @@ def run_example(
hm.logs.debug_log("-------------------------")

# Run example.
run_example(flow_str, ndim, nchains, samples_per_chain, plot_corner=False)
run_example(flow_str, ndim, nchains, samples_per_chain, plot_corner=True)

0 comments on commit dd6e08e

Please sign in to comment.