From 467a2c2ffcf8590b81a37ba092fcce3d61fef67e Mon Sep 17 00:00:00 2001 From: alicjapolanska Date: Wed, 30 Oct 2024 13:15:11 +0000 Subject: [PATCH] Add tests for evidence calculation sample batching. --- tests/test_evidence.py | 79 ++++++++++++++++++++++++++++++++---------- 1 file changed, 61 insertions(+), 18 deletions(-) diff --git a/tests/test_evidence.py b/tests/test_evidence.py index 9b9b3df..68f6924 100644 --- a/tests/test_evidence.py +++ b/tests/test_evidence.py @@ -16,6 +16,16 @@ models_to_test_1 = [sphere_1000D, real_nvp_2D, spline_4D] models_to_test_2 = [sphere_2D, real_nvp_2D, spline_4D] +models_to_test_2D = [ + sphere_2D, + real_nvp_2D, + md.RealNVPModel(2, standardize=True), + md.RQSplineModel(2), + md.RQSplineModel(2, standardize=True), +] + +chain_batching_options = [None, 2, 10] + @pytest.mark.parametrize("model", models_to_test_1) def test_constructor(model): @@ -59,6 +69,29 @@ def test_set_shift(model): assert rho.shift_set == True +@pytest.mark.parametrize("model", models_to_test_2) +def test_add_chains_sample_batching_error(model): + + nchains = 10 + n_samples = 20 + ndim = model.ndim + num_slices = 300 + + X = np.zeros((nchains, n_samples, ndim)) + Y = np.zeros((nchains, n_samples)) + + # Add samples to chains + chain = ch.Chains(ndim) + chain.add_chains_3d(X, Y) + + model.fitted = True + + # Calculate evidence + cal_ev = cbe.Evidence(nchains, model) + with pytest.raises(ValueError): + cal_ev.add_chains(chain, num_slices=num_slices) + + @pytest.mark.parametrize("model", models_to_test_1) def test_process_run_with_shift(model): nchains = 10 @@ -111,7 +144,9 @@ def test_process_run_with_shift(model): assert np.exp(rho.ln_evidence_inv_var_var) == pytest.approx(evidence_inv_var_var) -def test_add_chains(): +@pytest.mark.parametrize("model", models_to_test_2D) +@pytest.mark.parametrize("num_slices", chain_batching_options) +def test_add_chains(model, num_slices): nchains = 200 nsamples = 500 ndim = 2 @@ -125,22 +160,25 @@ def test_add_chains(): chain = ch.Chains(ndim) chain.add_chains_3d(X, Y) - # Fit the Hyper_sphere - domain = [np.array([1e-1, 1e1])] - sphere = mdl.HyperSphere(ndim, domain) - sphere.fit(chain.samples, chain.ln_posterior) + if hasattr(model, "flow"): + model.fit(chain.samples, epochs=5) + else: + model.fit(chain.samples, chain.ln_posterior) # Calculate evidence - cal_ev = cbe.Evidence(nchains, sphere, cbe.Shifting.MEAN_SHIFT) - cal_ev.add_chains(chain) + cal_ev = cbe.Evidence(nchains, model, cbe.Shifting.MEAN_SHIFT) + cal_ev.add_chains(chain, num_slices=num_slices) print("cal_ev.evidence_inv = {}".format(np.exp(cal_ev.ln_evidence_inv))) - assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606) - assert np.exp(cal_ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07) - assert np.exp(cal_ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx( - 1.142786462e-08 - ) + if hasattr(model, "flow"): + assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606, rel=0.01) + else: + assert np.exp(cal_ev.ln_evidence_inv) == pytest.approx(0.159438606) + assert np.exp(cal_ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07) + assert np.exp(cal_ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx( + 1.142786462e-08 + ) nsamples1 = 300 chains1 = ch.Chains(ndim) @@ -150,14 +188,19 @@ def test_add_chains(): for i_chain in range(nchains): chains2.add_chain(X[i_chain, nsamples1:, :], Y[i_chain, nsamples1:]) - ev = cbe.Evidence(nchains, sphere, cbe.Shifting.MEAN_SHIFT) + ev = cbe.Evidence(nchains, model, cbe.Shifting.MEAN_SHIFT) # Might have small numerical differences if don't use same mean_shift. - ev.add_chains(chains1) - ev.add_chains(chains2) + ev.add_chains(chains1, num_slices=num_slices) + ev.add_chains(chains2, num_slices=num_slices) - assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606) - assert np.exp(ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07) - assert np.exp(ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx(1.142786462e-08) + if hasattr(model, "flow"): + assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606, rel=0.01) + else: + assert np.exp(ev.ln_evidence_inv) == pytest.approx(0.159438606) + assert np.exp(ev.ln_evidence_inv_var) == pytest.approx(1.164628268e-07) + assert np.exp(ev.ln_evidence_inv_var_var) ** 0.5 == pytest.approx( + 1.142786462e-08 + ) return