diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 7e544e24..3b55c03a 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -8,10 +8,12 @@ import logging import os import re +from warnings import warn import matplotlib.pyplot as plt import numpy as np import numpy.lib.recfunctions as rfn +from scipy.special import logsumexp import torch from .. import config @@ -1201,15 +1203,11 @@ def x_prime_log_prior(self, x): """ raise RuntimeError("Prime prior is not implemented") - def compute_weights(self, x, log_q): + def compute_weights(self, x, log_q, return_log_prior=False): """ Compute weights for the samples. - Computes the log weights for rejection sampling sampling such that - that the maximum log probability is zero. - - Also sets the fields `logP` and `logL`. Note `logL` is set as the - proposal probability. + Does NOT normalise the weights Parameters ---------- @@ -1217,6 +1215,8 @@ def compute_weights(self, x, log_q): Array of points log_q : array_like Array of log proposal probabilities. + return_log_prior: bool + If true, the log-prior probability is also returned. Returns ------- @@ -1224,14 +1224,15 @@ def compute_weights(self, x, log_q): Log-weights for rejection sampling. """ if self.use_x_prime_prior: - x["logP"] = self.x_prime_log_prior(x) + log_p = self.x_prime_log_prior(x) else: - x["logP"] = self.log_prior(x) + log_p = self.log_prior(x) - x["logL"] = log_q - log_w = x["logP"] - log_q - log_w -= np.max(log_w) - return log_w + log_w = log_p - log_q + if return_log_prior: + return log_w, log_p + else: + return log_w def rejection_sampling(self, z, min_log_q=None): """ @@ -1255,6 +1256,12 @@ def rejection_sampling(self, z, min_log_q=None): array_like Array of accepted samples in the X space. """ + msg = ( + "`FlowProposal.rejection_sampling` is deprecated and will be " + "removed in a future release." + ) + warn(msg, FutureWarning) + x, log_q, z = self.backward_pass( z, rescale=not self.use_x_prime_prior, @@ -1360,6 +1367,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): plots with samples, these are often a few MB in size so proceed with caution! """ + st = datetime.datetime.now() if not self.initialised: raise RuntimeError( "Proposal has not been initialised. " @@ -1400,49 +1408,52 @@ def populate(self, worst_point, N=10000, plot=True, r=None): "Existing pool of samples is not empty. " "Discarding existing samples." ) - self.x = empty_structured_array(N, dtype=self.population_dtype) self.indices = [] - z_samples = np.empty([N, self.dims]) - - proposed = 0 - accepted = 0 - percent = 0.1 - warn = True + samples = empty_structured_array(0, dtype=self.population_dtype) self.prep_latent_prior() - while accepted < N: - z = self.draw_latent_prior(self.drawsize) - proposed += z.shape[0] + log_n = np.log(N) + log_n_expected = -np.inf + n_proposed = 0 + log_weights = np.empty(0) + log_constant = 0.0 + n_accepted = 0 - z, x = self.rejection_sampling(z, min_log_q=min_log_q) + while n_accepted < N: + z = self.draw_latent_prior(self.drawsize) + n_proposed += z.shape[0] - if not x.size: + x, log_q = self.backward_pass( + z, rescale=not self.use_x_prime_prior + ) + if self.truncate_log_q: + above_min_log_q = log_q > min_log_q + x, log_q = get_subset_arrays(above_min_log_q, x, log_q) + # Handle case where all samples are below min_log_q + if not len(x): continue + log_w = self.compute_weights(x, log_q) - if warn and (x.size / self.drawsize < 0.01): - logger.debug( - "Rejection sampling accepted less than 1 percent of " - f"samples! ({x.size / self.drawsize})" - ) - warn = False + samples = np.concatenate([samples, x]) + log_weights = np.concatenate([log_weights, log_w]) + log_constant = np.nanmax(log_w) + log_n_expected = logsumexp(log_weights - log_constant) - n = min(x.size, N - accepted) - self.x[accepted : (accepted + n)] = x[:n] - z_samples[accepted : (accepted + n), ...] = z[:n] - accepted += n - if accepted > percent * N: - logger.debug( - f"Accepted {accepted} / {N} points, " - f"acceptance: {accepted/proposed:.4}" - ) - percent += 0.1 + # Only try rejection sampling if we expected to accept enough + # points. In the case where we don't, we continue drawing samples + if log_n_expected >= log_n: + log_u = np.log(np.random.rand(len(log_weights))) + accept = (log_weights - log_constant) > log_u + n_accepted = np.sum(accept) + self.x = samples[accept][:N] self.samples = self.convert_to_samples(self.x, plot=plot) if self._plot_pool and plot: - self.plot_pool(z_samples, self.samples) + self.plot_pool(self.samples) + self.population_time += datetime.datetime.now() - st logger.debug("Evaluating log-likelihoods") self.samples["logL"] = self.model.batch_evaluate_log_likelihood( self.samples @@ -1454,13 +1465,13 @@ def populate(self, worst_point, N=10000, plot=True, r=None): logger.debug(f"Current acceptance {self.acceptance[-1]}") self.indices = np.random.permutation(self.samples.size).tolist() - self.population_acceptance = self.x.size / proposed + self.population_acceptance = n_accepted / n_proposed self.populated_count += 1 self.populated = True self._checked_population = False logger.debug(f"Proposal populated with {len(self.indices)} samples") logger.debug( - f"Overall proposal acceptance: {self.x.size / proposed:.4}" + f"Overall proposal acceptance: {self.x.size / n_proposed:.4}" ) def get_alt_distribution(self): @@ -1510,10 +1521,8 @@ def draw(self, worst_point): self.populating = True if self.update_poolsize: self.update_poolsize_scale(self.ns_acceptance) - st = datetime.datetime.now() while not self.populated: self.populate(worst_point, N=self.poolsize) - self.population_time += datetime.datetime.now() - st self.populating = False # new sample is drawn randomly from proposed points # popping from right end is faster @@ -1526,14 +1535,12 @@ def draw(self, worst_point): return new_sample @nessai_style() - def plot_pool(self, z, x): + def plot_pool(self, x): """ Plot the pool of points. Parameters ---------- - z : array_like - Latent samples to plot x : array_like Corresponding samples to plot in the physical space. """ @@ -1555,7 +1562,12 @@ def plot_pool(self, z, x): ), ) - z_tensor = torch.from_numpy(z).to(self.flow.device) + z, log_q = self.forward_pass(x, compute_radius=False) + z_tensor = ( + torch.from_numpy(z) + .type(torch.get_default_dtype()) + .to(self.flow.device) + ) with torch.inference_mode(): if self.alt_dist is not None: log_p = self.alt_dist.log_prob(z_tensor).cpu().numpy() @@ -1568,8 +1580,8 @@ def plot_pool(self, z, x): fig, axs = plt.subplots(3, 1, figsize=(3, 9)) axs = axs.ravel() - axs[0].hist(x["logL"], 20, histtype="step", label="log q") - axs[1].hist(x["logL"] - log_p, 20, histtype="step", label="log J") + axs[0].hist(log_q, 20, histtype="step", label="log q") + axs[1].hist(log_q - log_p, 20, histtype="step", label="log J") axs[2].hist( np.sqrt(np.sum(z**2, axis=1)), 20, diff --git a/nessai/proposal/rejection.py b/nessai/proposal/rejection.py index e9c767ba..67eec4db 100644 --- a/nessai/proposal/rejection.py +++ b/nessai/proposal/rejection.py @@ -60,28 +60,32 @@ def log_proposal(self, x): """ return self.model.new_point_log_prob(x) - def compute_weights(self, x): + def compute_weights(self, x, return_log_prior=False): """ Get weights for the samples. - Computes the log weights for rejection sampling sampling such that - that the maximum log probability is zero. + Computes the log weights for rejection sampling sampling but does not + normalize the weights. Parameters ---------- x : structured_array Array of points + return_log_prior: bool + If true, the log-prior probability is also returned. Returns ------- log_w : :obj:`numpy.ndarray` Array of log-weights rescaled such that the maximum value is zero. """ - x["logP"] = self.model.batch_evaluate_log_prior(x) + log_p = self.model.batch_evaluate_log_prior(x) log_q = self.log_proposal(x) - log_w = x["logP"] - log_q - log_w -= np.nanmax(log_w) - return log_w + log_w = log_p - log_q + if return_log_prior: + return log_w, log_p + else: + return log_w def populate(self, N=None): """ @@ -101,7 +105,8 @@ def populate(self, N=None): if N is None: N = self.poolsize x = self.draw_proposal(N=N) - log_w = self.compute_weights(x) + log_w, x["logP"] = self.compute_weights(x, return_log_prior=True) + log_w -= np.nanmax(log_w) log_u = np.log(np.random.rand(N)) indices = np.where((log_w - log_u) >= 0)[0] self.samples = x[indices] diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_draw.py b/tests/test_proposal/test_flowproposal/test_flowproposal_draw.py index 32cdd19f..afa2c0f1 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_draw.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_draw.py @@ -33,11 +33,8 @@ def test_draw_populated_last_sample(proposal): @pytest.mark.parametrize("update", [False, True]) def test_draw_not_populated(proposal, update, wait): """Test the draw method when the proposal is not populated""" - import datetime - proposal.populated = False proposal.poolsize = 100 - proposal.population_time = datetime.timedelta() proposal.samples = None proposal.indices = [] proposal.update_poolsize = update @@ -56,7 +53,6 @@ def mock_populate(*args, **kwargs): assert out == 2 assert proposal.populated is True - assert proposal.population_time.total_seconds() > 0.0 proposal.populate.assert_called_once_with(1.0, N=100) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py b/tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py index af083e04..e0b5be56 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py @@ -276,6 +276,16 @@ def test_reset_integration(tmpdir, model, latent_prior): modified_proposal.populate(model.new_point()) modified_proposal.reset() + # attributes that should be different + ignore = [ + "population_time", + ] + d1 = proposal.__getstate__() d2 = modified_proposal.__getstate__() + + for key in ignore: + d1.pop(key) + d2.pop(key) + assert d1 == d2 diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_plots.py b/tests/test_proposal/test_flowproposal/test_flowproposal_plots.py index a8987e49..30482229 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_plots.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_plots.py @@ -70,7 +70,7 @@ def test_plot_pool_all(proposal): proposal.populated_count = 0 x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"]) with patch("nessai.proposal.flowproposal.plot_live_points") as plot: - FlowProposal.plot_pool(proposal, None, x) + FlowProposal.plot_pool(proposal, x) plot.assert_called_once_with( x, c="logL", filename=os.path.join("test", "pool_0.png") ) @@ -90,7 +90,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist): z = np.random.randn(10, 2) x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"]) - x["logL"] = np.random.randn(10) + log_q = np.random.randn(10) x["logP"] = np.random.randn(10) training_data = numpy_array_to_live_points( np.random.randn(10, 2), ["x", "y"] @@ -100,6 +100,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist): proposal.training_data = training_data log_p = torch.arange(10) + proposal.forward_pass = MagicMock(return_value=(z, log_q)) proposal.flow = MagicMock() proposal.flow.device = "cpu" if alt_dist: @@ -111,7 +112,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist): ) proposal.alt_dist = None with patch("nessai.proposal.flowproposal.plot_1d_comparison") as plot: - FlowProposal.plot_pool(proposal, z, x) + FlowProposal.plot_pool(proposal, x) plot.assert_called_once_with( training_data, diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py index 77ae388c..85e8e26d 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Test methods related to popluation of the proposal after training""" +import datetime from functools import partial import os @@ -85,8 +86,22 @@ def test_compute_weights(proposal, x, log_q): proposal.log_prior.assert_called_once_with(x) out = -1 - log_q - out -= out.max() - assert np.array_equal(log_w, out) + np.testing.assert_array_equal(log_w, out) + + +def test_compute_weights_return_prior(proposal, x, log_q): + """Assert prior is returned""" + proposal.use_x_prime_prior = False + log_p = -np.ones(x.size) + proposal.log_prior = MagicMock(return_value=log_p) + log_w, log_p_out = FlowProposal.compute_weights( + proposal, x, log_q, return_log_prior=True + ) + + proposal.log_prior.assert_called_once_with(x) + expected = -1 - log_q + np.testing.assert_array_equal(log_w, expected) + assert log_p_out is log_p def test_compute_weights_prime_prior(proposal, x, log_q): @@ -99,8 +114,7 @@ def test_compute_weights_prime_prior(proposal, x, log_q): proposal.x_prime_log_prior.assert_called_once_with(x) out = -1 - log_q - out -= out.max() - assert np.array_equal(log_w, out) + np.testing.assert_array_equal(log_w, out) @patch("numpy.random.rand", return_value=np.array([0.1, 0.9])) @@ -349,7 +363,7 @@ def test_draw_latent_prior(proposal): [(0.5, None), (None, 1.5), (None, None)], ) def test_populate( - proposal, check_acceptance, indices, r, min_radius, max_radius + proposal, check_acceptance, indices, r, min_radius, max_radius, wait ): """Test the main populate method""" n_dims = 2 @@ -370,12 +384,24 @@ def test_populate( numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), ] + log_q = [ + np.log(np.random.rand(drawsize)), + np.log(np.random.rand(drawsize)), + np.log(np.random.rand(drawsize)), + ] + log_w = [ + np.log(np.concatenate([np.ones(drawsize - 1), np.zeros(1)])), + np.log(np.concatenate([np.ones(drawsize - 1), np.zeros(1)])), + np.log(np.concatenate([np.ones(drawsize - 1), np.zeros(1)])), + ] + # Control rejection sampling using log_w + rand_u = 0.5 * np.ones(3 * drawsize) + log_l = np.random.rand(poolsize) + log_p = np.random.rand(poolsize) r_flow = 1.0 - min_log_q = None - if r is None: r_out = r_flow if min_radius is not None: @@ -385,6 +411,7 @@ def test_populate( else: r_out = r + proposal.population_time = datetime.timedelta() proposal.initialised = True proposal.max_radius = max_radius proposal.dims = n_dims @@ -400,39 +427,48 @@ def test_populate( proposal.check_acceptance = check_acceptance proposal._plot_pool = True proposal.populated_count = 1 - proposal.population_dtype = get_dtype(["x_prime", "y_prime"]) + proposal.population_dtype = get_dtype(names) proposal.truncate_log_q = False + proposal.use_x_prime_prior = False proposal.forward_pass = MagicMock(return_value=(worst_z, np.nan)) + proposal.backward_pass = MagicMock(side_effect=zip(x, log_q)) proposal.radius = MagicMock(return_value=r_flow) proposal.get_alt_distribution = MagicMock(return_value=None) proposal.prep_latent_prior = MagicMock() proposal.draw_latent_prior = MagicMock(side_effect=z) - proposal.rejection_sampling = MagicMock( - side_effect=[(a[:-1], b[:-1]) for a, b in zip(z, x)] - ) + proposal.compute_weights = MagicMock(side_effect=log_w) proposal.compute_acceptance = MagicMock(return_value=0.8) proposal.model = MagicMock() proposal.model.batch_evaluate_log_likelihood = MagicMock( return_value=log_l ) + def convert_to_samples(samples, plot): + samples["logP"] = log_p + # wait for windows + wait() + return samples + proposal.plot_pool = MagicMock() - proposal.convert_to_samples = MagicMock( - side_effect=lambda *args, **kwargs: args[0] - ) + proposal.convert_to_samples = MagicMock(side_effect=convert_to_samples) - x_empty = np.empty(poolsize, dtype=proposal.population_dtype) + x_empty = np.empty(0, dtype=proposal.population_dtype) with patch( "nessai.proposal.flowproposal.empty_structured_array", return_value=x_empty, - ) as mock_empty: - FlowProposal.populate(proposal, worst_point, N=10, plot=True, r=r) + ) as mock_empty, patch( + "numpy.random.rand", return_value=rand_u + ) as mock_rand: + FlowProposal.populate( + proposal, worst_point, N=poolsize, plot=True, r=r + ) mock_empty.assert_called_once_with( - poolsize, + 0, dtype=proposal.population_dtype, ) + mock_rand.assert_called_once_with(3 * drawsize) if r is None: proposal.forward_pass.assert_called_once_with( @@ -451,12 +487,11 @@ def test_populate( draw_calls = 3 * [call(5)] proposal.draw_latent_prior.assert_has_calls(draw_calls) - rejection_calls = [ - call(z[0], min_log_q=min_log_q), - call(z[1], min_log_q=min_log_q), - call(z[2], min_log_q=min_log_q), - ] - proposal.rejection_sampling.assert_has_calls(rejection_calls) + backwards_calls = [call(zz, rescale=True) for zz in z] + proposal.backward_pass.assert_has_calls(backwards_calls) + + compute_weights_calls = [call(xx, lq) for xx, lq in zip(x, log_q)] + proposal.compute_weights.assert_has_calls(compute_weights_calls) proposal.plot_pool.assert_called_once() proposal.convert_to_samples.assert_called_once() @@ -465,7 +500,7 @@ def test_populate( ) assert proposal.convert_to_samples.call_args[1]["plot"] is True - assert proposal.population_acceptance == (10 / 15) + assert proposal.population_acceptance == (12 / 15) assert proposal.populated_count == 2 assert proposal.populated is True assert proposal.x.size == 10 @@ -481,6 +516,8 @@ def test_populate( ) np.testing.assert_array_equal(proposal.samples["logL"], log_l) + assert proposal.population_time.total_seconds() > 0.0 + def test_populate_not_initialised(proposal): """Assert populate fails if the proposal is not initialised""" @@ -493,7 +530,7 @@ def test_populate_not_initialised(proposal): def test_populate_truncate_log_q(proposal): n_dims = 2 nlive = 8 - poolsize = 10 + poolsize = 8 drawsize = 5 names = ["x", "y"] r_flow = 2.0 @@ -510,8 +547,25 @@ def test_populate_truncate_log_q(proposal): numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), ] + log_q = [ + np.zeros(drawsize), + np.zeros(drawsize), + np.zeros(drawsize), + ] + # This sample will be discarded because of the logq min check + for i in range(3): + log_q[i][-1] = np.nan_to_num(-np.inf) + log_w = [ + np.log(np.concatenate([np.ones(drawsize - 2), np.zeros(1)])), + np.log(np.concatenate([np.ones(drawsize - 2), np.zeros(1)])), + np.log(np.concatenate([np.ones(drawsize - 2), np.zeros(1)])), + ] + # Control rejection sampling using log_w + rand_u = 0.5 * np.ones(3 * (drawsize - 1)) + log_l = np.random.rand(poolsize) + proposal.population_time = datetime.timedelta() proposal.initialised = True proposal.dims = n_dims proposal.poolsize = poolsize @@ -524,20 +578,23 @@ def test_populate_truncate_log_q(proposal): proposal.compute_radius_with_all = False proposal.check_acceptance = False proposal._plot_pool = False + proposal.use_x_prime_prior = False proposal.populated_count = 1 - proposal.population_dtype = get_dtype(["x_prime", "y_prime"]) + proposal.population_dtype = get_dtype(names) proposal.truncate_log_q = True proposal.training_data = numpy_array_to_live_points( np.random.randn(nlive, n_dims), names=names, ) - log_q_live = np.log(np.random.rand(nlive)) - min_log_q = log_q_live.min() + log_q_live = np.zeros(nlive) + log_q_live[-1] = -1.0 proposal.forward_pass = MagicMock( return_value=(nlive * [None], log_q_live) ) + proposal.backward_pass = MagicMock(side_effect=zip(x, log_q)) + proposal.compute_weights = MagicMock(side_effect=log_w) proposal.radius = MagicMock(return_value=r_flow) proposal.get_alt_distribution = MagicMock(return_value=None) proposal.prep_latent_prior = MagicMock() @@ -555,23 +612,32 @@ def test_populate_truncate_log_q(proposal): side_effect=lambda *args, **kwargs: args[0] ) - x_empty = np.empty(poolsize, dtype=proposal.population_dtype) + x_empty = np.empty(0, dtype=proposal.population_dtype) with patch( "nessai.proposal.flowproposal.empty_structured_array", return_value=x_empty, - ) as mock_empty: - FlowProposal.populate(proposal, worst_point, N=10, plot=False) + ) as mock_empty, patch( + "numpy.random.rand", return_value=rand_u + ) as mock_rand: + FlowProposal.populate(proposal, worst_point, N=poolsize, plot=False) mock_empty.assert_called_once_with( - poolsize, + 0, dtype=proposal.population_dtype, ) + mock_rand.assert_called_once_with(3 * drawsize - 3) + + assert proposal.population_acceptance == (9 / 15) proposal.forward_pass.assert_called_once_with(proposal.training_data) - rejection_calls = [ - call(z[0], min_log_q=min_log_q), - call(z[1], min_log_q=min_log_q), - call(z[2], min_log_q=min_log_q), - ] - proposal.rejection_sampling.assert_has_calls(rejection_calls) + backwards_calls = [call(zz, rescale=True) for zz in z] + proposal.backward_pass.assert_has_calls(backwards_calls) + + compute_weights_calls = [(xx[:-1], lq[:-1]) for xx, lq in zip(x, log_q)] + for actual_call, expected_call in zip( + proposal.compute_weights.call_args_list, + compute_weights_calls, + ): + assert_structured_arrays_equal(actual_call[0][0], expected_call[0]) + np.testing.assert_array_equal(actual_call[0][1], expected_call[1]) diff --git a/tests/test_proposal/test_rejection.py b/tests/test_proposal/test_rejection.py index 0d62cf9e..14e7ac65 100644 --- a/tests/test_proposal/test_rejection.py +++ b/tests/test_proposal/test_rejection.py @@ -53,16 +53,21 @@ def test_log_proposal(proposal): np.testing.assert_array_equal(out, log_prob) -def test_compute_weights(proposal): +@pytest.mark.parametrize("return_log_prior", [False, True]) +def test_compute_weights(proposal, return_log_prior): """Test the compute weights method""" x = numpy_array_to_live_points(np.array([[1], [2], [3]]), "x") proposal.model = Mock() - proposal.model.batch_evaluate_log_prior = Mock( - return_value=np.array([6, 6, 6]) - ) + log_p = np.array([6, 6, 6]) + proposal.model.batch_evaluate_log_prior = Mock(return_value=log_p) proposal.log_proposal = Mock(return_value=np.array([3, 4, np.nan])) - log_w = np.array([0, -1, np.nan]) - out = RejectionProposal.compute_weights(proposal, x) + log_w = np.array([3, 2, np.nan]) + out = RejectionProposal.compute_weights( + proposal, x, return_log_prior=return_log_prior + ) + if return_log_prior: + assert out[1] is log_p + out = out[0] proposal.model.batch_evaluate_log_prior.assert_called_once_with(x) proposal.log_proposal.assert_called_once_with(x) @@ -84,11 +89,12 @@ def test_populate(proposal, N): u[::2] = 1e-10 samples = x[::2] log_l = np.log(np.random.rand(samples.size)) + log_prior = np.zeros(len(x)) samples["logL"] = log_l proposal.poolsize = poolsize proposal.populated = False proposal.draw_proposal = Mock(return_value=x) - proposal.compute_weights = Mock(return_value=log_w) + proposal.compute_weights = Mock(return_value=(log_w, log_prior)) proposal.model = Mock() proposal.model.batch_evaluate_log_likelihood = MagicMock( return_value=log_l diff --git a/tests/test_sampling/test_standard_sampling.py b/tests/test_sampling/test_standard_sampling.py index 1002659c..ff185ac8 100644 --- a/tests/test_sampling/test_standard_sampling.py +++ b/tests/test_sampling/test_standard_sampling.py @@ -459,6 +459,22 @@ def test_constant_volume_mode(integration_model, tmpdir): fs.run(plot=False) +@pytest.mark.slow_integration_test +def test_sampling_with_plotting(integration_model, tmpdir): + """Test sampling with plots enabled""" + output = str(tmpdir.mkdir("test")) + fs = FlowSampler( + integration_model, + output=output, + nlive=100, + plot=True, + proposal_plots=True, + ) + fs.run(plot=True) + assert os.path.exists(os.path.join(output, "proposal", "pool_0.png")) + assert os.path.exists(os.path.join(output, "proposal", "pool_0_log_q.png")) + + @pytest.mark.slow_integration_test def test_truncate_log_q(integration_model, tmpdir): """Test sampling with truncate_log_q"""