Skip to content

Commit

Permalink
Merge pull request #358 from mj-will/rework-rejection-sampling
Browse files Browse the repository at this point in the history
Rework rejection sampling
  • Loading branch information
mj-will authored Dec 13, 2023
2 parents b389742 + 992dd43 commit ee2a238
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 112 deletions.
114 changes: 63 additions & 51 deletions nessai/proposal/flowproposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import logging
import os
import re
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import numpy.lib.recfunctions as rfn
from scipy.special import logsumexp
import torch

from .. import config
Expand Down Expand Up @@ -1201,37 +1203,36 @@ def x_prime_log_prior(self, x):
"""
raise RuntimeError("Prime prior is not implemented")

def compute_weights(self, x, log_q):
def compute_weights(self, x, log_q, return_log_prior=False):
"""
Compute weights for the samples.
Computes the log weights for rejection sampling such that
the maximum log probability is zero.
Also sets the fields `logP` and `logL`. Note `logL` is set as the
proposal probability.
Does NOT normalise the weights
Parameters
----------
x : structured_arrays
Array of points
log_q : array_like
Array of log proposal probabilities.
return_log_prior: bool
If true, the log-prior probability is also returned.
Returns
-------
array_like
Log-weights for rejection sampling.
"""
if self.use_x_prime_prior:
x["logP"] = self.x_prime_log_prior(x)
log_p = self.x_prime_log_prior(x)
else:
x["logP"] = self.log_prior(x)
log_p = self.log_prior(x)

x["logL"] = log_q
log_w = x["logP"] - log_q
log_w -= np.max(log_w)
return log_w
log_w = log_p - log_q
if return_log_prior:
return log_w, log_p
else:
return log_w

def rejection_sampling(self, z, min_log_q=None):
"""
Expand All @@ -1255,6 +1256,12 @@ def rejection_sampling(self, z, min_log_q=None):
array_like
Array of accepted samples in the X space.
"""
msg = (
"`FlowProposal.rejection_sampling` is deprecated and will be "
"removed in a future release."
)
warn(msg, FutureWarning)

x, log_q, z = self.backward_pass(
z,
rescale=not self.use_x_prime_prior,
Expand Down Expand Up @@ -1360,6 +1367,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
plots with samples, these are often a few MB in size so
proceed with caution!
"""
st = datetime.datetime.now()
if not self.initialised:
raise RuntimeError(
"Proposal has not been initialised. "
Expand Down Expand Up @@ -1400,49 +1408,52 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
"Existing pool of samples is not empty. "
"Discarding existing samples."
)
self.x = empty_structured_array(N, dtype=self.population_dtype)
self.indices = []
z_samples = np.empty([N, self.dims])

proposed = 0
accepted = 0
percent = 0.1
warn = True
samples = empty_structured_array(0, dtype=self.population_dtype)

self.prep_latent_prior()

while accepted < N:
z = self.draw_latent_prior(self.drawsize)
proposed += z.shape[0]
log_n = np.log(N)
log_n_expected = -np.inf
n_proposed = 0
log_weights = np.empty(0)
log_constant = 0.0
n_accepted = 0

z, x = self.rejection_sampling(z, min_log_q=min_log_q)
while n_accepted < N:
z = self.draw_latent_prior(self.drawsize)
n_proposed += z.shape[0]

if not x.size:
x, log_q = self.backward_pass(
z, rescale=not self.use_x_prime_prior
)
if self.truncate_log_q:
above_min_log_q = log_q > min_log_q
x, log_q = get_subset_arrays(above_min_log_q, x, log_q)
# Handle case where all samples are below min_log_q
if not len(x):
continue
log_w = self.compute_weights(x, log_q)

if warn and (x.size / self.drawsize < 0.01):
logger.debug(
"Rejection sampling accepted less than 1 percent of "
f"samples! ({x.size / self.drawsize})"
)
warn = False
samples = np.concatenate([samples, x])
log_weights = np.concatenate([log_weights, log_w])
log_constant = np.nanmax(log_w)
log_n_expected = logsumexp(log_weights - log_constant)

n = min(x.size, N - accepted)
self.x[accepted : (accepted + n)] = x[:n]
z_samples[accepted : (accepted + n), ...] = z[:n]
accepted += n
if accepted > percent * N:
logger.debug(
f"Accepted {accepted} / {N} points, "
f"acceptance: {accepted/proposed:.4}"
)
percent += 0.1
# Only try rejection sampling if we expect to accept enough
# points. If we don't, we continue drawing samples.
if log_n_expected >= log_n:
log_u = np.log(np.random.rand(len(log_weights)))
accept = (log_weights - log_constant) > log_u
n_accepted = np.sum(accept)

self.x = samples[accept][:N]
self.samples = self.convert_to_samples(self.x, plot=plot)

if self._plot_pool and plot:
self.plot_pool(z_samples, self.samples)
self.plot_pool(self.samples)

self.population_time += datetime.datetime.now() - st
logger.debug("Evaluating log-likelihoods")
self.samples["logL"] = self.model.batch_evaluate_log_likelihood(
self.samples
Expand All @@ -1454,13 +1465,13 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
logger.debug(f"Current acceptance {self.acceptance[-1]}")

self.indices = np.random.permutation(self.samples.size).tolist()
self.population_acceptance = self.x.size / proposed
self.population_acceptance = n_accepted / n_proposed
self.populated_count += 1
self.populated = True
self._checked_population = False
logger.debug(f"Proposal populated with {len(self.indices)} samples")
logger.debug(
f"Overall proposal acceptance: {self.x.size / proposed:.4}"
f"Overall proposal acceptance: {self.x.size / n_proposed:.4}"
)

def get_alt_distribution(self):
Expand Down Expand Up @@ -1510,10 +1521,8 @@ def draw(self, worst_point):
self.populating = True
if self.update_poolsize:
self.update_poolsize_scale(self.ns_acceptance)
st = datetime.datetime.now()
while not self.populated:
self.populate(worst_point, N=self.poolsize)
self.population_time += datetime.datetime.now() - st
self.populating = False
# new sample is drawn randomly from proposed points
# popping from right end is faster
Expand All @@ -1526,14 +1535,12 @@ def draw(self, worst_point):
return new_sample

@nessai_style()
def plot_pool(self, z, x):
def plot_pool(self, x):
"""
Plot the pool of points.
Parameters
----------
z : array_like
Latent samples to plot
x : array_like
Corresponding samples to plot in the physical space.
"""
Expand All @@ -1555,7 +1562,12 @@ def plot_pool(self, z, x):
),
)

z_tensor = torch.from_numpy(z).to(self.flow.device)
z, log_q = self.forward_pass(x, compute_radius=False)
z_tensor = (
torch.from_numpy(z)
.type(torch.get_default_dtype())
.to(self.flow.device)
)
with torch.inference_mode():
if self.alt_dist is not None:
log_p = self.alt_dist.log_prob(z_tensor).cpu().numpy()
Expand All @@ -1568,8 +1580,8 @@ def plot_pool(self, z, x):

fig, axs = plt.subplots(3, 1, figsize=(3, 9))
axs = axs.ravel()
axs[0].hist(x["logL"], 20, histtype="step", label="log q")
axs[1].hist(x["logL"] - log_p, 20, histtype="step", label="log J")
axs[0].hist(log_q, 20, histtype="step", label="log q")
axs[1].hist(log_q - log_p, 20, histtype="step", label="log J")
axs[2].hist(
np.sqrt(np.sum(z**2, axis=1)),
20,
Expand Down
21 changes: 13 additions & 8 deletions nessai/proposal/rejection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,32 @@ def log_proposal(self, x):
"""
return self.model.new_point_log_prob(x)

def compute_weights(self, x):
def compute_weights(self, x, return_log_prior=False):
"""
Get weights for the samples.
Computes the log weights for rejection sampling such that
the maximum log probability is zero.
Computes the log weights for rejection sampling but does not
normalize the weights.
Parameters
----------
x : structured_array
Array of points
return_log_prior: bool
If true, the log-prior probability is also returned.
Returns
-------
log_w : :obj:`numpy.ndarray`
Array of log-weights. The weights are not normalized. If
`return_log_prior` is true, the log-prior is returned as well.
"""
x["logP"] = self.model.batch_evaluate_log_prior(x)
log_p = self.model.batch_evaluate_log_prior(x)
log_q = self.log_proposal(x)
log_w = x["logP"] - log_q
log_w -= np.nanmax(log_w)
return log_w
log_w = log_p - log_q
if return_log_prior:
return log_w, log_p
else:
return log_w

def populate(self, N=None):
"""
Expand All @@ -101,7 +105,8 @@ def populate(self, N=None):
if N is None:
N = self.poolsize
x = self.draw_proposal(N=N)
log_w = self.compute_weights(x)
log_w, x["logP"] = self.compute_weights(x, return_log_prior=True)
log_w -= np.nanmax(log_w)
log_u = np.log(np.random.rand(N))
indices = np.where((log_w - log_u) >= 0)[0]
self.samples = x[indices]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,8 @@ def test_draw_populated_last_sample(proposal):
@pytest.mark.parametrize("update", [False, True])
def test_draw_not_populated(proposal, update, wait):
"""Test the draw method when the proposal is not populated"""
import datetime

proposal.populated = False
proposal.poolsize = 100
proposal.population_time = datetime.timedelta()
proposal.samples = None
proposal.indices = []
proposal.update_poolsize = update
Expand All @@ -56,7 +53,6 @@ def mock_populate(*args, **kwargs):

assert out == 2
assert proposal.populated is True
assert proposal.population_time.total_seconds() > 0.0

proposal.populate.assert_called_once_with(1.0, N=100)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,16 @@ def test_reset_integration(tmpdir, model, latent_prior):
modified_proposal.populate(model.new_point())
modified_proposal.reset()

# attributes that should be different
ignore = [
"population_time",
]

d1 = proposal.__getstate__()
d2 = modified_proposal.__getstate__()

for key in ignore:
d1.pop(key)
d2.pop(key)

assert d1 == d2
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_plot_pool_all(proposal):
proposal.populated_count = 0
x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"])
with patch("nessai.proposal.flowproposal.plot_live_points") as plot:
FlowProposal.plot_pool(proposal, None, x)
FlowProposal.plot_pool(proposal, x)
plot.assert_called_once_with(
x, c="logL", filename=os.path.join("test", "pool_0.png")
)
Expand All @@ -90,7 +90,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist):

z = np.random.randn(10, 2)
x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"])
x["logL"] = np.random.randn(10)
log_q = np.random.randn(10)
x["logP"] = np.random.randn(10)
training_data = numpy_array_to_live_points(
np.random.randn(10, 2), ["x", "y"]
Expand All @@ -100,6 +100,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist):
proposal.training_data = training_data
log_p = torch.arange(10)

proposal.forward_pass = MagicMock(return_value=(z, log_q))
proposal.flow = MagicMock()
proposal.flow.device = "cpu"
if alt_dist:
Expand All @@ -111,7 +112,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist):
)
proposal.alt_dist = None
with patch("nessai.proposal.flowproposal.plot_1d_comparison") as plot:
FlowProposal.plot_pool(proposal, z, x)
FlowProposal.plot_pool(proposal, x)

plot.assert_called_once_with(
training_data,
Expand Down
Loading

0 comments on commit ee2a238

Please sign in to comment.