Skip to content

Commit

Permalink
Add tests for evidence calculation sample batching.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 30, 2024
1 parent 89325d5 commit 467a2c2
Showing 1 changed file with 61 additions and 18 deletions.
79 changes: 61 additions & 18 deletions tests/test_evidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
models_to_test_1 = [sphere_1000D, real_nvp_2D, spline_4D]
models_to_test_2 = [sphere_2D, real_nvp_2D, spline_4D]

models_to_test_2D = [
sphere_2D,
real_nvp_2D,
md.RealNVPModel(2, standardize=True),
md.RQSplineModel(2),
md.RQSplineModel(2, standardize=True),
]

chain_batching_options = [None, 2, 10]


@pytest.mark.parametrize("model", models_to_test_1)
def test_constructor(model):
Expand Down Expand Up @@ -59,6 +69,29 @@ def test_set_shift(model):
assert rho.shift_set == True


@pytest.mark.parametrize("model", models_to_test_2)
def test_add_chains_sample_batching_error(model):

nchains = 10
n_samples = 20
ndim = model.ndim
num_slices = 300

X = np.zeros((nchains, n_samples, ndim))
Y = np.zeros((nchains, n_samples))

# Add samples to chains
chain = ch.Chains(ndim)
chain.add_chains_3d(X, Y)

model.fitted = True

# Calculate evidence
cal_ev = cbe.Evidence(nchains, model)
with pytest.raises(ValueError):
cal_ev.add_chains(chain, num_slices=num_slices)


@pytest.mark.parametrize("model", models_to_test_1)
def test_process_run_with_shift(model):
nchains = 10
Expand Down Expand Up @@ -111,7 +144,9 @@ def test_process_run_with_shift(model):
assert np.exp(rho.ln_evidence_inv_var_var) == pytest.approx(evidence_inv_var_var)


def test_add_chains():
@pytest.mark.parametrize("model", models_to_test_2D)
@pytest.mark.parametrize("num_slices", chain_batching_options)
def test_add_chains(model, num_slices):
nchains = 200
nsamples = 500
ndim = 2
Expand All @@ -125,22 +160,25 @@ def test_add_chains():
chain = ch.Chains(ndim)
chain.add_chains_3d(X, Y)

# Fit the Hyper_sphere
domain = [np.array([1e-1, 1e1])]
sphere = mdl.HyperSphere(ndim, domain)
sphere.fit(chain.samples, chain.ln_posterior)
if hasattr(model, "flow"):
model.fit(chain.samples, epochs=5)
else:
model.fit(chain.samples, chain.ln_posterior)

# Calculate evidence
cal_ev = cbe.Evidence(nchains, sphere, cbe.Shifting.MEAN_SHIFT)
cal_ev.add_chains(chain)
cal_ev = cbe.Evidence(nchains, model, cbe.Shifting.MEAN_SHIFT)
cal_ev.add_chains(chain, num_slices=num_slices)

print("cal_ev.evidence_inv = {}".format(np.exp(cal_ev.ln_evidence_inv)))

assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606)
assert np.exp(cal_ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07)
assert np.exp(cal_ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(
1.142786462e-08
)
if hasattr(model, "flow"):
assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606, rel=0.01)
else:
assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606)
assert np.exp(cal_ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07)
assert np.exp(cal_ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(
1.142786462e-08
)

nsamples1 = 300
chains1 = ch.Chains(ndim)
Expand All @@ -150,14 +188,19 @@ def test_add_chains():
for i_chain in range(nchains):
chains2.add_chain(X[i_chain, nsamples1:, :], Y[i_chain, nsamples1:])

ev = cbe.Evidence(nchains, sphere, cbe.Shifting.MEAN_SHIFT)
ev = cbe.Evidence(nchains, model, cbe.Shifting.MEAN_SHIFT)
# Might have small numerical differences if don't use same mean_shift.
ev.add_chains(chains1)
ev.add_chains(chains2)
ev.add_chains(chains1, num_slices=num_slices)
ev.add_chains(chains2, num_slices=num_slices)

assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606)
assert np.exp(ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07)
assert np.exp(ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(1.142786462e-08)
if hasattr(model, "flow"):
assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606, rel=0.01)
else:
assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606)
assert np.exp(ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07)
assert np.exp(ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(
1.142786462e-08
)

return

Expand Down

0 comments on commit 467a2c2

Please sign in to comment.