Skip to content

Commit

Permalink
Merge pull request #307 from astro-informatics/20D_gaussian
Browse files Browse the repository at this point in the history
Add batching of evidence estimation inputs and update examples
  • Loading branch information
alicjapolanska authored Nov 5, 2024
2 parents 9726ce2 + 16b72d5 commit 55c970e
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 177 deletions.
142 changes: 63 additions & 79 deletions examples/gaussian_nondiagcov.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
import emcee
import logging


def ln_analytic_evidence(ndim, cov):
Expand Down Expand Up @@ -99,29 +100,33 @@ def run_example(
inv_cov = jnp.linalg.inv(cov)
training_proportion = 0.5
if flow_type == "RealNVP":
epochs_num = 5
epochs_num = 10 #5
elif flow_type == "RQSpline":
epochs_num = 3
#epochs_num = 5
epochs_num = 110

# Beginning of path where plots will be saved
save_name_start = "examples/plots/" + flow_type

temperature = 0.8
temperature = 0.9
standardize = True
verbose = True

# Spline params
n_layers = 5
n_bins = 5
n_layers = 3
n_bins = 128
hidden_size = [32, 32]
spline_range = (-10.0, 10.0)

if flow_type == "RQSpline":
save_name_start += "_" + str(n_layers) + "l_" + str(n_bins) + "b_" + str(epochs_num) + "e_" + str(int(training_proportion * 100)) + "perc_" + str(temperature) + "T" + "_emcee"

# Start timer.
clock = time.process_time()

# Run multiple realisations.
n_realisations = 1
evidence_inv_summary = np.zeros((n_realisations, 3))
n_realisations = 100
ln_evidence_inv_summary = np.zeros((n_realisations, 5))
for i_realisation in range(n_realisations):
if n_realisations > 0:
hm.logs.info_log(
Expand All @@ -130,7 +135,7 @@ def run_example(
# Define the number of dimensions and the mean of the Gaussian
num_samples = nchains * samples_per_chain
# Initialize a PRNG key (you can use any valid key)
key = jax.random.PRNGKey(0)
key = jax.random.PRNGKey(i_realisation)
mean = jnp.zeros(ndim)

# Generate random samples from the 2D Gaussian distribution
Expand All @@ -139,7 +144,7 @@ def run_example(
samples = jnp.reshape(samples, (nchains, -1, ndim))
lnprob = jnp.reshape(lnprob, (nchains, -1))

MCMC = False
MCMC = True
if MCMC:
nburn = 500
# Set up and run sampler.
Expand All @@ -151,7 +156,7 @@ def run_example(
rstate = np.random.get_state() # Set random state to repeatable
# across calls.
(pos, prob, state) = sampler.run_mcmc(
pos, samples_per_chain, rstate0=rstate
pos, samples_per_chain, rstate0=rstate, progress=True
)
samples = np.ascontiguousarray(sampler.chain[:, nburn:, :])
lnprob = np.ascontiguousarray(sampler.lnprobability[:, nburn:])
Expand Down Expand Up @@ -191,92 +196,68 @@ def run_example(
ev = hm.Evidence(chains_test.nchains, model)
# ev.set_mean_shift(0.0)
ev.add_chains(chains_test)
ln_evidence, ln_evidence_std = ev.compute_ln_evidence()
err_ln_inv_evidence = ev.compute_ln_inv_evidence_errors()

# Compute analytic evidence.
if i_realisation == 0:
ln_evidence_analytic = ln_analytic_evidence(ndim, cov)

# ======================================================================
# Display evidence computation results.
# ======================================================================
hm.logs.info_log("---------------------------------")
hm.logs.info_log("The inverse evidence in log space is:")
hm.logs.info_log(
"Evidence: analytic = {}, estimated = {}".format(
np.exp(ln_evidence_analytic), np.exp(ln_evidence)
"ln_inv_evidence = {} +/- {}".format(
ev.ln_evidence_inv, err_ln_inv_evidence
)
)
hm.logs.info_log(
"Evidence: std = {}, std / estimate = {}".format(
np.exp(ln_evidence_std), np.exp(ln_evidence_std - ln_evidence)
"ln evidence = {} +/- {} {}".format(
-ev.ln_evidence_inv, -err_ln_inv_evidence[1], -err_ln_inv_evidence[0]
)
)
diff = np.log(np.abs(np.exp(ln_evidence_analytic) - np.exp(ln_evidence)))
hm.logs.info_log("Analytic ln evidence is {}".format(ln_evidence_analytic))
delta = -ln_evidence_analytic - ev.ln_evidence_inv
hm.logs.info_log(
"Evidence: |analytic - estimate| / estimate = {}".format(
np.exp(diff - ln_evidence)
)
)
# ======================================================================
# Display inverse evidence computation results.
# ======================================================================
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log(
"Inv Evidence: analytic = {}, estimate = {}".format(
np.exp(-ln_evidence_analytic), ev.evidence_inv
)
)
hm.logs.debug_log(
"Inv Evidence: std = {}, std / estimate = {}".format(
np.sqrt(ev.evidence_inv_var),
np.sqrt(ev.evidence_inv_var) / ev.evidence_inv,
)
)
hm.logs.debug_log(
"Inv Evidence: kurtosis = {}, sqrt( 2 / ( n_eff - 1 ) ) = {}".format(
ev.kurtosis, np.sqrt(2.0 / (ev.n_eff - 1))
)
)
hm.logs.debug_log(
"Inv Evidence: sqrt( var(var) ) / var = {}".format(
np.sqrt(ev.evidence_inv_var_var) / ev.evidence_inv_var
"Difference between analytic and harmonic is {} +- {} {}".format(
delta, err_ln_inv_evidence[0], err_ln_inv_evidence[1]
)
)

hm.logs.info_log("kurtosis = {}".format(ev.kurtosis))
hm.logs.info_log(" Aim for ~3.")
check = np.exp(0.5 * ev.ln_evidence_inv_var_var - ev.ln_evidence_inv_var)
hm.logs.info_log("sqrt( var(var) ) / var = {}".format(check))
hm.logs.info_log(
"Inv Evidence: |analytic - estimate| / estimate = {}".format(
np.abs(np.exp(-ln_evidence_analytic) - ev.evidence_inv)
/ ev.evidence_inv
)
" Aim for sqrt( 2/(n_eff-1) ) = {}".format(np.sqrt(2.0 / (ev.n_eff - 1)))
)

# ===========================================================================
# Display more technical details
# ===========================================================================
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log("Technical Details")
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log(
hm.logs.info_log("---------------------------------")
hm.logs.info_log("Technical Details")
hm.logs.info_log("---------------------------------")
hm.logs.info_log(
"lnargmax = {}, lnargmin = {}".format(ev.lnargmax, ev.lnargmin)
)
hm.logs.debug_log(
hm.logs.info_log(
"lnprobmax = {}, lnprobmin = {}".format(ev.lnprobmax, ev.lnprobmin)
)
hm.logs.debug_log(
hm.logs.info_log(
"lnpredictmax = {}, lnpredictmin = {}".format(
ev.lnpredictmax, ev.lnpredictmin
)
)
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log(
hm.logs.info_log("---------------------------------")
hm.logs.info_log(
"shift = {}, shift setting = {}".format(ev.shift_value, ev.shift)
)
hm.logs.debug_log("running sum total = {}".format(sum(ev.running_sum)))
hm.logs.debug_log("running sum = \n{}".format(ev.running_sum))
hm.logs.debug_log("nsamples per chain = \n{}".format(ev.nsamples_per_chain))
hm.logs.debug_log(
hm.logs.info_log("running sum total = {}".format(sum(ev.running_sum)))
hm.logs.info_log("running sum = \n{}".format(ev.running_sum))
hm.logs.info_log("nsamples per chain = \n{}".format(ev.nsamples_per_chain))
hm.logs.info_log(
"nsamples eff per chain = \n{}".format(ev.nsamples_eff_per_chain)
)
hm.logs.debug_log("===============================")
hm.logs.info_log("===============================")

# ======================================================================
# Create corner/triangle plot.
Expand Down Expand Up @@ -314,28 +295,31 @@ def run_example(

plt.show()

evidence_inv_summary[i_realisation, 0] = ev.evidence_inv
evidence_inv_summary[i_realisation, 1] = ev.evidence_inv_var
evidence_inv_summary[i_realisation, 2] = ev.evidence_inv_var_var
# Save out realisations for violin plot.
ln_evidence_inv_summary[i_realisation, 0] = ev.ln_evidence_inv
ln_evidence_inv_summary[i_realisation, 1] = err_ln_inv_evidence[0]
ln_evidence_inv_summary[i_realisation, 2] = err_ln_inv_evidence[1]
ln_evidence_inv_summary[i_realisation, 3] = ev.ln_evidence_inv_var
ln_evidence_inv_summary[i_realisation, 4] = ev.ln_evidence_inv_var_var

clock = time.process_time() - clock
hm.logs.info_log("Execution_time = {}s".format(clock))

if n_realisations > 1:
save_name = (
save_name_start
+ "_gaussian_nondiagcov_evidence_inv"
+ "_gaussian_nondiagcov_ln_evidence_inv"
+ "_realisations_{}D.dat".format(ndim)
)
np.savetxt(save_name, evidence_inv_summary)
evidence_inv_analytic_summary = np.zeros(1)
evidence_inv_analytic_summary[0] = np.exp(-ln_evidence_analytic)
np.savetxt(save_name, ln_evidence_inv_summary)
ln_evidence_inv_analytic_summary = np.zeros(1)
ln_evidence_inv_analytic_summary[0] = -ln_evidence_analytic
save_name = (
save_name_start
+ "_gaussian_nondiagcov_evidence_inv"
+ "_gaussian_nondiagcov_ln_evidence_inv"
+ "_analytic_{}D.dat".format(ndim)
)
np.savetxt(save_name, evidence_inv_analytic_summary)
np.savetxt(save_name, ln_evidence_inv_analytic_summary)

created_plots = True
if created_plots:
Expand All @@ -344,14 +328,14 @@ def run_example(

if __name__ == "__main__":
# Setup logging config.
hm.logs.setup_logging()
hm.logs.setup_logging(default_level=logging.DEBUG)

# Define parameters.
ndim = 5
nchains = 100
ndim = 21
nchains = 80
samples_per_chain = 5000
flow_str = "RealNVP"
# flow_str = "RQSpline"
#flow_str = "RealNVP"
flow_str = "RQSpline"
np.random.seed(10) # used for initializing covariance matrix

hm.logs.info_log("Non-diagonal Covariance Gaussian example")
Expand All @@ -365,4 +349,4 @@ def run_example(
hm.logs.debug_log("-------------------------")

# Run example.
run_example(flow_str, ndim, nchains, samples_per_chain, plot_corner=False)
run_example(flow_str, ndim, nchains, samples_per_chain, plot_corner=True)
Loading

0 comments on commit 55c970e

Please sign in to comment.