From 02e022c8aa903d6522e153c1eaa24aba6f630a4a Mon Sep 17 00:00:00 2001 From: mj-will Date: Wed, 22 Nov 2023 15:50:28 +0000 Subject: [PATCH 01/15] feat: rework rejection sampling in flowproposal --- nessai/proposal/flowproposal.py | 82 +++++++++++++++------------------ 1 file changed, 37 insertions(+), 45 deletions(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 7e544e24..80539812 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -12,6 +12,7 @@ import matplotlib.pyplot as plt import numpy as np import numpy.lib.recfunctions as rfn +from scipy.special import logsumexp import torch from .. import config @@ -1201,15 +1202,11 @@ def x_prime_log_prior(self, x): """ raise RuntimeError("Prime prior is not implemented") - def compute_weights(self, x, log_q): + def compute_weights(self, x, log_q, return_log_prior=False): """ Compute weights for the samples. - Computes the log weights for rejection sampling sampling such that - that the maximum log probability is zero. - - Also sets the fields `logP` and `logL`. Note `logL` is set as the - proposal probability. + Does NOT normalise the weights Parameters ---------- @@ -1217,6 +1214,8 @@ def compute_weights(self, x, log_q): Array of points log_q : array_like Array of log proposal probabilities. + return_log_prior: bool + If true, the log-prior probability is also returned. Returns ------- @@ -1224,14 +1223,15 @@ def compute_weights(self, x, log_q): Log-weights for rejection sampling. 
""" if self.use_x_prime_prior: - x["logP"] = self.x_prime_log_prior(x) + log_p = self.x_prime_log_prior(x) else: - x["logP"] = self.log_prior(x) + log_p = self.log_prior(x) - x["logL"] = log_q - log_w = x["logP"] - log_q - log_w -= np.max(log_w) - return log_w + log_w = log_p - log_q + if return_log_prior: + return log_w, log_p + else: + return log_w def rejection_sampling(self, z, min_log_q=None): """ @@ -1400,48 +1400,40 @@ def populate(self, worst_point, N=10000, plot=True, r=None): "Existing pool of samples is not empty. " "Discarding existing samples." ) - self.x = empty_structured_array(N, dtype=self.population_dtype) self.indices = [] - z_samples = np.empty([N, self.dims]) - - proposed = 0 - accepted = 0 - percent = 0.1 - warn = True + samples = empty_structured_array(0, dtype=self.population_dtype) self.prep_latent_prior() - while accepted < N: - z = self.draw_latent_prior(self.drawsize) - proposed += z.shape[0] - - z, x = self.rejection_sampling(z, min_log_q=min_log_q) + log_n = np.log(N) + log_n_expected = -np.inf + n_proposed = 0 + log_weights = np.empty(0) + log_constant = 0.0 - if not x.size: - continue + while log_n_expected < log_n: + z = self.draw_latent_prior(self.drawsize) + n_proposed += z.shape[0] - if warn and (x.size / self.drawsize < 0.01): - logger.debug( - "Rejection sampling accepted less than 1 percent of " - f"samples! ({x.size / self.drawsize})" - ) - warn = False + x, log_q = self.backward_pass( + z, rescale=not self.use_x_prime_prior + ) + log_w, x["logP"] = self.compute_weights( + x, log_q, return_log_prior=True + ) - n = min(x.size, N - accepted) - self.x[accepted : (accepted + n)] = x[:n] - z_samples[accepted : (accepted + n), ...] 
= z[:n] - accepted += n - if accepted > percent * N: - logger.debug( - f"Accepted {accepted} / {N} points, " - f"acceptance: {accepted/proposed:.4}" - ) - percent += 0.1 + samples = np.concatenate([samples, x]) + log_weights = np.concatenate([log_weights, log_w]) + log_constant = np.nanmax(log_w) + log_n_expected = logsumexp(log_weights - log_constant) + log_u = np.log(np.random.rand(len(log_weights))) + accept = (log_weights - log_constant) > log_u + self.x = samples[accept] self.samples = self.convert_to_samples(self.x, plot=plot) if self._plot_pool and plot: - self.plot_pool(z_samples, self.samples) + self.plot_pool(None, self.samples) logger.debug("Evaluating log-likelihoods") self.samples["logL"] = self.model.batch_evaluate_log_likelihood( @@ -1454,13 +1446,13 @@ def populate(self, worst_point, N=10000, plot=True, r=None): logger.debug(f"Current acceptance {self.acceptance[-1]}") self.indices = np.random.permutation(self.samples.size).tolist() - self.population_acceptance = self.x.size / proposed + self.population_acceptance = self.x.size / n_proposed self.populated_count += 1 self.populated = True self._checked_population = False logger.debug(f"Proposal populated with {len(self.indices)} samples") logger.debug( - f"Overall proposal acceptance: {self.x.size / proposed:.4}" + f"Overall proposal acceptance: {self.x.size / n_proposed:.4}" ) def get_alt_distribution(self): From 26d179b606e8c22aa23ee2b56cbec9df4b30f8b6 Mon Sep 17 00:00:00 2001 From: mj-will Date: Fri, 24 Nov 2023 16:28:40 +0000 Subject: [PATCH 02/15] refactor: change rejection sampling to be consistent with flowproposal --- nessai/proposal/rejection.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/nessai/proposal/rejection.py b/nessai/proposal/rejection.py index e9c767ba..dd00777e 100644 --- a/nessai/proposal/rejection.py +++ b/nessai/proposal/rejection.py @@ -60,28 +60,32 @@ def log_proposal(self, x): """ return self.model.new_point_log_prob(x) - def 
compute_weights(self, x): + def compute_weights(self, x, return_log_prior=False): """ Get weights for the samples. - Computes the log weights for rejection sampling sampling such that - that the maximum log probability is zero. + Computes the log weights for rejection sampling but does not + normalize the weights. Parameters ---------- x : structured_array Array of points + return_log_prior: bool + If true, the log-prior probability is also returned. Returns ------- log_w : :obj:`numpy.ndarray` Array of log-weights rescaled such that the maximum value is zero. """ - x["logP"] = self.model.batch_evaluate_log_prior(x) + log_p = self.model.batch_evaluate_log_prior(x) log_q = self.log_proposal(x) - log_w = x["logP"] - log_q - log_w -= np.nanmax(log_w) - return log_w + log_w = log_p - log_q + if return_log_prior: + return log_w, log_p + else: + return log_w def populate(self, N=None): """ @@ -101,7 +105,9 @@ def populate(self, N=None): if N is None: N = self.poolsize x = self.draw_proposal(N=N) - log_w = self.compute_weights(x) + log_w, log_p = self.compute_weights(x, return_log_prior=True) + log_w -= np.nanmax(log_w) + x["logP"] = log_p log_u = np.log(np.random.rand(N)) indices = np.where((log_w - log_u) >= 0)[0] self.samples = x[indices] From ffd4bc733b5e6759733b747ade3a8c0069dbdf0c Mon Sep 17 00:00:00 2001 From: mj-will Date: Wed, 29 Nov 2023 16:23:20 +0000 Subject: [PATCH 03/15] fix: ensure at most n samples are added --- nessai/proposal/flowproposal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 80539812..278ad067 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -1429,7 +1429,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): log_u = np.log(np.random.rand(len(log_weights))) accept = (log_weights - log_constant) > log_u - self.x = samples[accept] + self.x = samples[accept][:N] self.samples = 
self.convert_to_samples(self.x, plot=plot) if self._plot_pool and plot: From 72de28234fd5e1ce42c968dd78c715cf73f99ec7 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 14:44:00 +0000 Subject: [PATCH 04/15] refactor!: change plot pool to not require z --- nessai/proposal/flowproposal.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 278ad067..3e3b8b5e 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -1433,7 +1433,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): self.samples = self.convert_to_samples(self.x, plot=plot) if self._plot_pool and plot: - self.plot_pool(None, self.samples) + self.plot_pool(self.samples) logger.debug("Evaluating log-likelihoods") self.samples["logL"] = self.model.batch_evaluate_log_likelihood( @@ -1518,14 +1518,12 @@ def draw(self, worst_point): return new_sample @nessai_style() - def plot_pool(self, z, x): + def plot_pool(self, x): """ Plot the pool of points. Parameters ---------- - z : array_like - Latent samples to plot x : array_like Corresponding samples to plot in the physical space. 
""" @@ -1547,7 +1545,12 @@ def plot_pool(self, z, x): ), ) - z_tensor = torch.from_numpy(z).to(self.flow.device) + z, log_q = self.forward_pass(x, compute_radius=False) + z_tensor = ( + torch.from_numpy(z) + .type(torch.get_default_dtype()) + .to(self.flow.device) + ) with torch.inference_mode(): if self.alt_dist is not None: log_p = self.alt_dist.log_prob(z_tensor).cpu().numpy() @@ -1560,8 +1563,8 @@ def plot_pool(self, z, x): fig, axs = plt.subplots(3, 1, figsize=(3, 9)) axs = axs.ravel() - axs[0].hist(x["logL"], 20, histtype="step", label="log q") - axs[1].hist(x["logL"] - log_p, 20, histtype="step", label="log J") + axs[0].hist(log_q, 20, histtype="step", label="log q") + axs[1].hist(log_q - log_p, 20, histtype="step", label="log J") axs[2].hist( np.sqrt(np.sum(z**2, axis=1)), 20, From 1e086e8f5135d6c61b7d0aa006c2fbd59934e1d0 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 14:44:51 +0000 Subject: [PATCH 05/15] test: add integration test that includes plotting --- tests/test_sampling/test_standard_sampling.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_sampling/test_standard_sampling.py b/tests/test_sampling/test_standard_sampling.py index 1002659c..83474f17 100644 --- a/tests/test_sampling/test_standard_sampling.py +++ b/tests/test_sampling/test_standard_sampling.py @@ -459,6 +459,21 @@ def test_constant_volume_mode(integration_model, tmpdir): fs.run(plot=False) +@pytest.mark.slow_integration_test +def test_sampling_with_plotting(integration_model, tmpdir): + """Test sampling with plots enabled""" + output = str(tmpdir.mkdir("test")) + fs = FlowSampler( + integration_model, + output=output, + nlive=100, + plot=True, + proposal_plots=True, + stopping=1.0, + ) + fs.run(plot=True) + + @pytest.mark.slow_integration_test def test_truncate_log_q(integration_model, tmpdir): """Test sampling with truncate_log_q""" From 476fc8e5e1c2fe89d4f9b8937ec25ba8bbb9a20d Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 
15:49:55 +0000 Subject: [PATCH 06/15] fix: support truncate log q --- nessai/proposal/flowproposal.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 3e3b8b5e..6d8d0622 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -1418,9 +1418,10 @@ def populate(self, worst_point, N=10000, plot=True, r=None): x, log_q = self.backward_pass( z, rescale=not self.use_x_prime_prior ) - log_w, x["logP"] = self.compute_weights( - x, log_q, return_log_prior=True - ) + if self.truncate_log_q: + above_min_log_q = log_q > min_log_q + x, log_q = get_subset_arrays(above_min_log_q, x, log_q) + log_w = self.compute_weights(x, log_q) samples = np.concatenate([samples, x]) log_weights = np.concatenate([log_weights, log_w]) From b3ad400859d58c65ef0bc85beef894306ff427a8 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 16:01:20 +0000 Subject: [PATCH 07/15] fix: address rejection sampling edge cases - Handle case where no samples are above log_q_min - Handle case where rejection sampling yields too few samples --- nessai/proposal/flowproposal.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 6d8d0622..84c317e3 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -8,6 +8,7 @@ import logging import os import re +from warnings import warn import matplotlib.pyplot as plt import numpy as np @@ -1255,6 +1256,12 @@ def rejection_sampling(self, z, min_log_q=None): array_like Array of accepted samples in the X space. """ + msg = ( + "`FlowProposal.rejection_sampling` is deprecated and will be " + "removed in a future release." 
+ ) + warn(msg, FutureWarning) + x, log_q, z = self.backward_pass( z, rescale=not self.use_x_prime_prior, @@ -1410,8 +1417,9 @@ def populate(self, worst_point, N=10000, plot=True, r=None): n_proposed = 0 log_weights = np.empty(0) log_constant = 0.0 + n_accepted = 0 - while log_n_expected < log_n: + while n_accepted < N: z = self.draw_latent_prior(self.drawsize) n_proposed += z.shape[0] @@ -1421,6 +1429,9 @@ def populate(self, worst_point, N=10000, plot=True, r=None): if self.truncate_log_q: above_min_log_q = log_q > min_log_q x, log_q = get_subset_arrays(above_min_log_q, x, log_q) + # Handle case where all samples are below min_log_q + if not len(x): + continue log_w = self.compute_weights(x, log_q) samples = np.concatenate([samples, x]) @@ -1428,8 +1439,13 @@ def populate(self, worst_point, N=10000, plot=True, r=None): log_constant = np.nanmax(log_w) log_n_expected = logsumexp(log_weights - log_constant) - log_u = np.log(np.random.rand(len(log_weights))) - accept = (log_weights - log_constant) > log_u + # Only try rejection sampling if we expected to accept enough + # points. 
In the case where we don't, we continue drawing samples + if log_n_expected >= log_n: + log_u = np.log(np.random.rand(len(log_weights))) + accept = (log_weights - log_constant) > log_u + n_accepted = np.sum(accept) + self.x = samples[accept][:N] self.samples = self.convert_to_samples(self.x, plot=plot) @@ -1447,7 +1463,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): logger.debug(f"Current acceptance {self.acceptance[-1]}") self.indices = np.random.permutation(self.samples.size).tolist() - self.population_acceptance = self.x.size / n_proposed + self.population_acceptance = n_accepted / n_proposed self.populated_count += 1 self.populated = True self._checked_population = False From c277ddd3dd7f417d8c13a54b273d92fb04166737 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 16:07:36 +0000 Subject: [PATCH 08/15] refactor: assign value directly --- nessai/proposal/rejection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nessai/proposal/rejection.py b/nessai/proposal/rejection.py index dd00777e..67eec4db 100644 --- a/nessai/proposal/rejection.py +++ b/nessai/proposal/rejection.py @@ -105,9 +105,8 @@ def populate(self, N=None): if N is None: N = self.poolsize x = self.draw_proposal(N=N) - log_w, log_p = self.compute_weights(x, return_log_prior=True) + log_w, x["logP"] = self.compute_weights(x, return_log_prior=True) log_w -= np.nanmax(log_w) - x["logP"] = log_p log_u = np.log(np.random.rand(N)) indices = np.where((log_w - log_u) >= 0)[0] self.samples = x[indices] From e69b97edaa4b88d0c5d58c7b24c326ec8dd54340 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 16:08:03 +0000 Subject: [PATCH 09/15] test: refactor test for compute weights --- tests/test_proposal/test_rejection.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/test_proposal/test_rejection.py b/tests/test_proposal/test_rejection.py index 0d62cf9e..14e7ac65 100644 --- 
a/tests/test_proposal/test_rejection.py +++ b/tests/test_proposal/test_rejection.py @@ -53,16 +53,21 @@ def test_log_proposal(proposal): np.testing.assert_array_equal(out, log_prob) -def test_compute_weights(proposal): +@pytest.mark.parametrize("return_log_prior", [False, True]) +def test_compute_weights(proposal, return_log_prior): """Test the compute weights method""" x = numpy_array_to_live_points(np.array([[1], [2], [3]]), "x") proposal.model = Mock() - proposal.model.batch_evaluate_log_prior = Mock( - return_value=np.array([6, 6, 6]) - ) + log_p = np.array([6, 6, 6]) + proposal.model.batch_evaluate_log_prior = Mock(return_value=log_p) proposal.log_proposal = Mock(return_value=np.array([3, 4, np.nan])) - log_w = np.array([0, -1, np.nan]) - out = RejectionProposal.compute_weights(proposal, x) + log_w = np.array([3, 2, np.nan]) + out = RejectionProposal.compute_weights( + proposal, x, return_log_prior=return_log_prior + ) + if return_log_prior: + assert out[1] is log_p + out = out[0] proposal.model.batch_evaluate_log_prior.assert_called_once_with(x) proposal.log_proposal.assert_called_once_with(x) @@ -84,11 +89,12 @@ def test_populate(proposal, N): u[::2] = 1e-10 samples = x[::2] log_l = np.log(np.random.rand(samples.size)) + log_prior = np.zeros(len(x)) samples["logL"] = log_l proposal.poolsize = poolsize proposal.populated = False proposal.draw_proposal = Mock(return_value=x) - proposal.compute_weights = Mock(return_value=log_w) + proposal.compute_weights = Mock(return_value=(log_w, log_prior)) proposal.model = Mock() proposal.model.batch_evaluate_log_likelihood = MagicMock( return_value=log_l From c29a4f639b28c6fb75abf8ca314343553d8d8a45 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 16:09:35 +0000 Subject: [PATCH 10/15] test: refactor for changes to rejection sampling --- .../test_flowproposal_plots.py | 7 +- .../test_flowproposal_population.py | 111 ++++++++++++------ 2 files changed, 80 insertions(+), 38 deletions(-) diff --git 
a/tests/test_proposal/test_flowproposal/test_flowproposal_plots.py b/tests/test_proposal/test_flowproposal/test_flowproposal_plots.py index a8987e49..30482229 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_plots.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_plots.py @@ -70,7 +70,7 @@ def test_plot_pool_all(proposal): proposal.populated_count = 0 x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"]) with patch("nessai.proposal.flowproposal.plot_live_points") as plot: - FlowProposal.plot_pool(proposal, None, x) + FlowProposal.plot_pool(proposal, x) plot.assert_called_once_with( x, c="logL", filename=os.path.join("test", "pool_0.png") ) @@ -90,7 +90,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist): z = np.random.randn(10, 2) x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"]) - x["logL"] = np.random.randn(10) + log_q = np.random.randn(10) x["logP"] = np.random.randn(10) training_data = numpy_array_to_live_points( np.random.randn(10, 2), ["x", "y"] @@ -100,6 +100,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist): proposal.training_data = training_data log_p = torch.arange(10) + proposal.forward_pass = MagicMock(return_value=(z, log_q)) proposal.flow = MagicMock() proposal.flow.device = "cpu" if alt_dist: @@ -111,7 +112,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist): ) proposal.alt_dist = None with patch("nessai.proposal.flowproposal.plot_1d_comparison") as plot: - FlowProposal.plot_pool(proposal, z, x) + FlowProposal.plot_pool(proposal, x) plot.assert_called_once_with( training_data, diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py index 77ae388c..5ec345e0 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py @@ -85,8 +85,7 @@ def test_compute_weights(proposal, x, log_q): 
proposal.log_prior.assert_called_once_with(x) out = -1 - log_q - out -= out.max() - assert np.array_equal(log_w, out) + np.testing.assert_array_equal(log_w, out) def test_compute_weights_prime_prior(proposal, x, log_q): @@ -99,8 +98,7 @@ def test_compute_weights_prime_prior(proposal, x, log_q): proposal.x_prime_log_prior.assert_called_once_with(x) out = -1 - log_q - out -= out.max() - assert np.array_equal(log_w, out) + np.testing.assert_array_equal(log_w, out) @patch("numpy.random.rand", return_value=np.array([0.1, 0.9])) @@ -370,12 +368,23 @@ def test_populate( numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), ] + log_q = [ + np.log(np.random.rand(drawsize)), + np.log(np.random.rand(drawsize)), + np.log(np.random.rand(drawsize)), + ] + log_w = [ + np.log(np.concatenate([np.ones(drawsize - 1), np.zeros(1)])), + np.log(np.concatenate([np.ones(drawsize - 1), np.zeros(1)])), + np.log(np.concatenate([np.ones(drawsize - 1), np.zeros(1)])), + ] + # Control rejection sampling using log_w + rand_u = 0.5 * np.ones(3 * drawsize) + log_l = np.random.rand(poolsize) r_flow = 1.0 - min_log_q = None - if r is None: r_out = r_flow if min_radius is not None: @@ -400,17 +409,17 @@ def test_populate( proposal.check_acceptance = check_acceptance proposal._plot_pool = True proposal.populated_count = 1 - proposal.population_dtype = get_dtype(["x_prime", "y_prime"]) + proposal.population_dtype = get_dtype(names) proposal.truncate_log_q = False + proposal.use_x_prime_prior = False proposal.forward_pass = MagicMock(return_value=(worst_z, np.nan)) + proposal.backward_pass = MagicMock(side_effect=zip(x, log_q)) proposal.radius = MagicMock(return_value=r_flow) proposal.get_alt_distribution = MagicMock(return_value=None) proposal.prep_latent_prior = MagicMock() proposal.draw_latent_prior = MagicMock(side_effect=z) - proposal.rejection_sampling = MagicMock( - side_effect=[(a[:-1], b[:-1]) for a, b in 
zip(z, x)] - ) + proposal.compute_weights = MagicMock(side_effect=log_w) proposal.compute_acceptance = MagicMock(return_value=0.8) proposal.model = MagicMock() proposal.model.batch_evaluate_log_likelihood = MagicMock( @@ -422,17 +431,22 @@ def test_populate( side_effect=lambda *args, **kwargs: args[0] ) - x_empty = np.empty(poolsize, dtype=proposal.population_dtype) + x_empty = np.empty(0, dtype=proposal.population_dtype) with patch( "nessai.proposal.flowproposal.empty_structured_array", return_value=x_empty, - ) as mock_empty: - FlowProposal.populate(proposal, worst_point, N=10, plot=True, r=r) + ) as mock_empty, patch( + "numpy.random.rand", return_value=rand_u + ) as mock_rand: + FlowProposal.populate( + proposal, worst_point, N=poolsize, plot=True, r=r + ) mock_empty.assert_called_once_with( - poolsize, + 0, dtype=proposal.population_dtype, ) + mock_rand.assert_called_once_with(3 * drawsize) if r is None: proposal.forward_pass.assert_called_once_with( @@ -451,12 +465,11 @@ def test_populate( draw_calls = 3 * [call(5)] proposal.draw_latent_prior.assert_has_calls(draw_calls) - rejection_calls = [ - call(z[0], min_log_q=min_log_q), - call(z[1], min_log_q=min_log_q), - call(z[2], min_log_q=min_log_q), - ] - proposal.rejection_sampling.assert_has_calls(rejection_calls) + backwards_calls = [call(zz, rescale=True) for zz in z] + proposal.backward_pass.assert_has_calls(backwards_calls) + + compute_weights_calls = [call(xx, lq) for xx, lq in zip(x, log_q)] + proposal.compute_weights.assert_has_calls(compute_weights_calls) proposal.plot_pool.assert_called_once() proposal.convert_to_samples.assert_called_once() @@ -465,7 +478,7 @@ def test_populate( ) assert proposal.convert_to_samples.call_args[1]["plot"] is True - assert proposal.population_acceptance == (10 / 15) + assert proposal.population_acceptance == (12 / 15) assert proposal.populated_count == 2 assert proposal.populated is True assert proposal.x.size == 10 @@ -493,7 +506,7 @@ def 
test_populate_not_initialised(proposal): def test_populate_truncate_log_q(proposal): n_dims = 2 nlive = 8 - poolsize = 10 + poolsize = 8 drawsize = 5 names = ["x", "y"] r_flow = 2.0 @@ -510,6 +523,22 @@ def test_populate_truncate_log_q(proposal): numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), numpy_array_to_live_points(np.random.randn(drawsize, n_dims), names), ] + log_q = [ + np.zeros(drawsize), + np.zeros(drawsize), + np.zeros(drawsize), + ] + # This sample will be discarded because of the logq min check + for i in range(3): + log_q[i][-1] = np.nan_to_num(-np.inf) + log_w = [ + np.log(np.concatenate([np.ones(drawsize - 2), np.zeros(1)])), + np.log(np.concatenate([np.ones(drawsize - 2), np.zeros(1)])), + np.log(np.concatenate([np.ones(drawsize - 2), np.zeros(1)])), + ] + # Control rejection sampling using log_w + rand_u = 0.5 * np.ones(3 * (drawsize - 1)) + log_l = np.random.rand(poolsize) proposal.initialised = True @@ -524,20 +553,23 @@ def test_populate_truncate_log_q(proposal): proposal.compute_radius_with_all = False proposal.check_acceptance = False proposal._plot_pool = False + proposal.use_x_prime_prior = False proposal.populated_count = 1 - proposal.population_dtype = get_dtype(["x_prime", "y_prime"]) + proposal.population_dtype = get_dtype(names) proposal.truncate_log_q = True proposal.training_data = numpy_array_to_live_points( np.random.randn(nlive, n_dims), names=names, ) - log_q_live = np.log(np.random.rand(nlive)) - min_log_q = log_q_live.min() + log_q_live = np.zeros(nlive) + log_q_live[-1] = -1.0 proposal.forward_pass = MagicMock( return_value=(nlive * [None], log_q_live) ) + proposal.backward_pass = MagicMock(side_effect=zip(x, log_q)) + proposal.compute_weights = MagicMock(side_effect=log_w) proposal.radius = MagicMock(return_value=r_flow) proposal.get_alt_distribution = MagicMock(return_value=None) proposal.prep_latent_prior = MagicMock() @@ -555,23 +587,32 @@ def test_populate_truncate_log_q(proposal): side_effect=lambda 
*args, **kwargs: args[0] ) - x_empty = np.empty(poolsize, dtype=proposal.population_dtype) + x_empty = np.empty(0, dtype=proposal.population_dtype) with patch( "nessai.proposal.flowproposal.empty_structured_array", return_value=x_empty, - ) as mock_empty: - FlowProposal.populate(proposal, worst_point, N=10, plot=False) + ) as mock_empty, patch( + "numpy.random.rand", return_value=rand_u + ) as mock_rand: + FlowProposal.populate(proposal, worst_point, N=poolsize, plot=False) mock_empty.assert_called_once_with( - poolsize, + 0, dtype=proposal.population_dtype, ) + mock_rand.assert_called_once_with(3 * drawsize - 3) + + assert proposal.population_acceptance == (9 / 15) proposal.forward_pass.assert_called_once_with(proposal.training_data) - rejection_calls = [ - call(z[0], min_log_q=min_log_q), - call(z[1], min_log_q=min_log_q), - call(z[2], min_log_q=min_log_q), - ] - proposal.rejection_sampling.assert_has_calls(rejection_calls) + backwards_calls = [call(zz, rescale=True) for zz in z] + proposal.backward_pass.assert_has_calls(backwards_calls) + + compute_weights_calls = [(xx[:-1], lq[:-1]) for xx, lq in zip(x, log_q)] + for actual_call, expected_call in zip( + proposal.compute_weights.call_args_list, + compute_weights_calls, + ): + assert_structured_arrays_equal(actual_call[0][0], expected_call[0]) + np.testing.assert_array_equal(actual_call[0][1], expected_call[1]) From 3fb86deea53e6ddf7432495d5ed9139ad06a1881 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 16:16:55 +0000 Subject: [PATCH 11/15] refactor: exclude likelihood from population time --- nessai/proposal/flowproposal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index 84c317e3..3b55c03a 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -1367,6 +1367,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): plots with samples, these are often a few MB in size 
so proceed with caution! """ + st = datetime.datetime.now() if not self.initialised: raise RuntimeError( "Proposal has not been initialised. " @@ -1452,6 +1453,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None): if self._plot_pool and plot: self.plot_pool(self.samples) + self.population_time += datetime.datetime.now() - st logger.debug("Evaluating log-likelihoods") self.samples["logL"] = self.model.batch_evaluate_log_likelihood( self.samples @@ -1519,10 +1521,8 @@ def draw(self, worst_point): self.populating = True if self.update_poolsize: self.update_poolsize_scale(self.ns_acceptance) - st = datetime.datetime.now() while not self.populated: self.populate(worst_point, N=self.poolsize) - self.population_time += datetime.datetime.now() - st self.populating = False # new sample is drawn randomly from proposed points # popping from right end is faster From 173c3685631616b673a0223d2c202e74ed9443cb Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 16:46:44 +0000 Subject: [PATCH 12/15] test: update for change to population time --- .../test_flowproposal/test_flowproposal_draw.py | 4 ---- .../test_flowproposal/test_flowproposal_init_resume.py | 10 ++++++++++ .../test_flowproposal/test_flowproposal_population.py | 5 +++++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_draw.py b/tests/test_proposal/test_flowproposal/test_flowproposal_draw.py index 32cdd19f..afa2c0f1 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_draw.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_draw.py @@ -33,11 +33,8 @@ def test_draw_populated_last_sample(proposal): @pytest.mark.parametrize("update", [False, True]) def test_draw_not_populated(proposal, update, wait): """Test the draw method when the proposal is not populated""" - import datetime - proposal.populated = False proposal.poolsize = 100 - proposal.population_time = datetime.timedelta() proposal.samples = None 
proposal.indices = [] proposal.update_poolsize = update @@ -56,7 +53,6 @@ def mock_populate(*args, **kwargs): assert out == 2 assert proposal.populated is True - assert proposal.population_time.total_seconds() > 0.0 proposal.populate.assert_called_once_with(1.0, N=100) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py b/tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py index af083e04..e0b5be56 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_init_resume.py @@ -276,6 +276,16 @@ def test_reset_integration(tmpdir, model, latent_prior): modified_proposal.populate(model.new_point()) modified_proposal.reset() + # attributes that should be different + ignore = [ + "population_time", + ] + d1 = proposal.__getstate__() d2 = modified_proposal.__getstate__() + + for key in ignore: + d1.pop(key) + d2.pop(key) + assert d1 == d2 diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py index 5ec345e0..acf91225 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Test methods related to popluation of the proposal after training""" +import datetime from functools import partial import os @@ -394,6 +395,7 @@ def test_populate( else: r_out = r + proposal.population_time = datetime.timedelta() proposal.initialised = True proposal.max_radius = max_radius proposal.dims = n_dims @@ -494,6 +496,8 @@ def test_populate( ) np.testing.assert_array_equal(proposal.samples["logL"], log_l) + assert proposal.population_time.total_seconds() > 0.0 + def test_populate_not_initialised(proposal): """Assert populate fails if the proposal is not initialised""" @@ -541,6 +545,7 @@ def 
test_populate_truncate_log_q(proposal): log_l = np.random.rand(poolsize) + proposal.population_time = datetime.timedelta() proposal.initialised = True proposal.dims = n_dims proposal.poolsize = poolsize From e15d11677823b38fd8f31896be87a2e2b1b76060 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 11 Dec 2023 17:36:14 +0000 Subject: [PATCH 13/15] test: fix timing issue with windows --- .../test_flowproposal_population.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py index acf91225..b952802c 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py @@ -348,7 +348,7 @@ def test_draw_latent_prior(proposal): [(0.5, None), (None, 1.5), (None, None)], ) def test_populate( - proposal, check_acceptance, indices, r, min_radius, max_radius + proposal, check_acceptance, indices, r, min_radius, max_radius, wait ): """Test the main populate method""" n_dims = 2 @@ -383,6 +383,7 @@ def test_populate( rand_u = 0.5 * np.ones(3 * drawsize) log_l = np.random.rand(poolsize) + log_p = np.random.rand(poolsize) r_flow = 1.0 @@ -428,10 +429,14 @@ def test_populate( return_value=log_l ) + def convert_to_samples(samples, plot): + samples["logP"] = log_p + # wait for windows + wait() + return samples + proposal.plot_pool = MagicMock() - proposal.convert_to_samples = MagicMock( - side_effect=lambda *args, **kwargs: args[0] - ) + proposal.convert_to_samples = MagicMock(side_effect=convert_to_samples) x_empty = np.empty(0, dtype=proposal.population_dtype) with patch( From 1fe42ca0b34bdbd6c8b2ebd4a437dc4d5323d80f Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 12 Dec 2023 11:36:16 +0000 Subject: [PATCH 14/15] test: check return_log_prior in compute_weights --- .../test_flowproposal_population.py | 15 +++++++++++++++ 1 
file changed, 15 insertions(+) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py index b952802c..85e8e26d 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_population.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_population.py @@ -89,6 +89,21 @@ def test_compute_weights(proposal, x, log_q): np.testing.assert_array_equal(log_w, out) +def test_compute_weights_return_prior(proposal, x, log_q): + """Assert prior is returned""" + proposal.use_x_prime_prior = False + log_p = -np.ones(x.size) + proposal.log_prior = MagicMock(return_value=log_p) + log_w, log_p_out = FlowProposal.compute_weights( + proposal, x, log_q, return_log_prior=True + ) + + proposal.log_prior.assert_called_once_with(x) + expected = -1 - log_q + np.testing.assert_array_equal(log_w, expected) + assert log_p_out is log_p + + def test_compute_weights_prime_prior(proposal, x, log_q): """Test method for computing rejection sampling weights with the prime prior. From 992dd43b083d9864c225f7ad1df6d0c264ddcdc2 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 12 Dec 2023 11:38:32 +0000 Subject: [PATCH 15/15] test: check plots exist --- tests/test_sampling/test_standard_sampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_sampling/test_standard_sampling.py b/tests/test_sampling/test_standard_sampling.py index 83474f17..ff185ac8 100644 --- a/tests/test_sampling/test_standard_sampling.py +++ b/tests/test_sampling/test_standard_sampling.py @@ -469,9 +469,10 @@ def test_sampling_with_plotting(integration_model, tmpdir): nlive=100, plot=True, proposal_plots=True, - stopping=1.0, ) fs.run(plot=True) + assert os.path.exists(os.path.join(output, "proposal", "pool_0.png")) + assert os.path.exists(os.path.join(output, "proposal", "pool_0_log_q.png")) @pytest.mark.slow_integration_test