Rework rejection sampling #358

Merged
merged 15 commits on Dec 13, 2023
114 changes: 63 additions & 51 deletions nessai/proposal/flowproposal.py
@@ -8,10 +8,12 @@
import logging
import os
import re
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import numpy.lib.recfunctions as rfn
from scipy.special import logsumexp
import torch

from .. import config
@@ -1201,37 +1203,36 @@ def x_prime_log_prior(self, x):
"""
raise RuntimeError("Prime prior is not implemented")

def compute_weights(self, x, log_q):
def compute_weights(self, x, log_q, return_log_prior=False):
"""
Compute weights for the samples.

Computes the log weights for rejection sampling such that
the maximum log probability is zero.

Also sets the fields `logP` and `logL`. Note `logL` is set as the
proposal probability.
Does NOT normalise the weights.

Parameters
----------
x : structured_array
Array of points.
log_q : array_like
Array of log proposal probabilities.
return_log_prior : bool
If true, the log-prior probabilities are also returned.

Returns
-------
array_like
Log-weights for rejection sampling.
array_like
Log-prior probabilities, only returned if `return_log_prior` is true.
"""
if self.use_x_prime_prior:
x["logP"] = self.x_prime_log_prior(x)
log_p = self.x_prime_log_prior(x)
else:
x["logP"] = self.log_prior(x)
log_p = self.log_prior(x)

x["logL"] = log_q
log_w = x["logP"] - log_q
log_w -= np.max(log_w)
return log_w
log_w = log_p - log_q
if return_log_prior:
return log_w, log_p
else:
return log_w
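With the rescaling removed from `compute_weights`, normalisation is now the caller's responsibility. A minimal, self-contained sketch of the new contract, using toy log-probabilities in place of the prior and flow outputs (all variables below are illustrative, not part of the nessai API):

import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the quantities compute_weights combines: log-prior and
# log-proposal probabilities for a batch of samples.
log_p = rng.normal(size=1000)
log_q = rng.normal(size=1000)

# New contract: the returned log-weights are unnormalised ...
log_w = log_p - log_q

# ... and the caller rescales them before the accept/reject step, as
# populate and RejectionProposal.populate now do.
log_w -= np.nanmax(log_w)
accept = log_w > np.log(rng.random(log_w.size))
print(f"accepted {accept.sum()} of {log_w.size} samples")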

def rejection_sampling(self, z, min_log_q=None):
"""
@@ -1255,6 +1256,12 @@ def rejection_sampling(self, z, min_log_q=None):
array_like
Array of accepted samples in the X space.
"""
msg = (
"`FlowProposal.rejection_sampling` is deprecated and will be "
"removed in a future release."
)
warn(msg, FutureWarning)
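# Note: unlike DeprecationWarning, FutureWarning is shown by default, so
# callers see this notice without having to enable warning filters.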

x, log_q, z = self.backward_pass(
z,
rescale=not self.use_x_prime_prior,
@@ -1360,6 +1367,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
plots with samples, these are often a few MB in size so
proceed with caution!
"""
st = datetime.datetime.now()
if not self.initialised:
raise RuntimeError(
"Proposal has not been initialised. "
@@ -1400,49 +1408,52 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
"Existing pool of samples is not empty. "
"Discarding existing samples."
)
self.x = empty_structured_array(N, dtype=self.population_dtype)
self.indices = []
z_samples = np.empty([N, self.dims])

proposed = 0
accepted = 0
percent = 0.1
warn = True
samples = empty_structured_array(0, dtype=self.population_dtype)

self.prep_latent_prior()

while accepted < N:
z = self.draw_latent_prior(self.drawsize)
proposed += z.shape[0]
log_n = np.log(N)
log_n_expected = -np.inf
n_proposed = 0
log_weights = np.empty(0)
log_constant = 0.0
n_accepted = 0

z, x = self.rejection_sampling(z, min_log_q=min_log_q)
while n_accepted < N:
z = self.draw_latent_prior(self.drawsize)
n_proposed += z.shape[0]

if not x.size:
x, log_q = self.backward_pass(
z, rescale=not self.use_x_prime_prior
)
if self.truncate_log_q:
above_min_log_q = log_q > min_log_q
x, log_q = get_subset_arrays(above_min_log_q, x, log_q)
# Handle case where all samples are below min_log_q
if not len(x):
continue
log_w = self.compute_weights(x, log_q)

if warn and (x.size / self.drawsize < 0.01):
logger.debug(
"Rejection sampling accepted less than 1 percent of "
f"samples! ({x.size / self.drawsize})"
)
warn = False
samples = np.concatenate([samples, x])
log_weights = np.concatenate([log_weights, log_w])
log_constant = np.nanmax(log_w)
log_n_expected = logsumexp(log_weights - log_constant)

n = min(x.size, N - accepted)
self.x[accepted : (accepted + n)] = x[:n]
z_samples[accepted : (accepted + n), ...] = z[:n]
accepted += n
if accepted > percent * N:
logger.debug(
f"Accepted {accepted} / {N} points, "
f"acceptance: {accepted/proposed:.4}"
)
percent += 0.1
# Only try rejection sampling if we expect to accept enough points.
# If not, continue drawing samples.
if log_n_expected >= log_n:
log_u = np.log(np.random.rand(len(log_weights)))
accept = (log_weights - log_constant) > log_u
n_accepted = np.sum(accept)

self.x = samples[accept][:N]
self.samples = self.convert_to_samples(self.x, plot=plot)

if self._plot_pool and plot:
self.plot_pool(z_samples, self.samples)
self.plot_pool(self.samples)

self.population_time += datetime.datetime.now() - st
logger.debug("Evaluating log-likelihoods")
self.samples["logL"] = self.model.batch_evaluate_log_likelihood(
self.samples
@@ -1454,13 +1465,13 @@
logger.debug(f"Current acceptance {self.acceptance[-1]}")

self.indices = np.random.permutation(self.samples.size).tolist()
self.population_acceptance = self.x.size / proposed
self.population_acceptance = n_accepted / n_proposed
self.populated_count += 1
self.populated = True
self._checked_population = False
logger.debug(f"Proposal populated with {len(self.indices)} samples")
logger.debug(
f"Overall proposal acceptance: {self.x.size / proposed:.4}"
f"Overall proposal acceptance: {self.x.size / n_proposed:.4}"
)
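For reference, the reworked loop no longer rejection-samples each batch separately; it accumulates unnormalised log-weights and only runs the accept/reject step once the expected number of acceptances, estimated with logsumexp, reaches the target. A simplified, self-contained sketch of that logic, with toy weights standing in for draw_latent_prior, backward_pass and compute_weights (all names are illustrative):

import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(1)
N = 500          # target pool size (poolsize)
drawsize = 200   # samples drawn per iteration

samples = np.empty((0, 2))
log_weights = np.empty(0)
n_accepted = 0

while n_accepted < N:
    # Stand-in for draw_latent_prior -> backward_pass -> compute_weights
    x = rng.normal(size=(drawsize, 2))
    log_w = -0.5 * np.sum(x**2, axis=1)  # toy unnormalised log-weights

    samples = np.concatenate([samples, x])
    log_weights = np.concatenate([log_weights, log_w])

    # Expected number of acceptances if rejection sampling were run now
    log_constant = np.nanmax(log_weights)
    log_n_expected = logsumexp(log_weights - log_constant)

    # Only attempt the accept/reject step once enough points are expected;
    # otherwise keep drawing more samples.
    if log_n_expected >= np.log(N):
        log_u = np.log(rng.random(log_weights.size))
        accept = (log_weights - log_constant) > log_u
        n_accepted = accept.sum()

pool = samples[accept][:N]
print(f"kept {pool.shape[0]} of {samples.shape[0]} proposed samples")

In this sketch the accept/reject pass is applied once over all accumulated points with a single normalising constant, rather than once per batch.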

def get_alt_distribution(self):
@@ -1510,10 +1521,8 @@ def draw(self, worst_point):
self.populating = True
if self.update_poolsize:
self.update_poolsize_scale(self.ns_acceptance)
st = datetime.datetime.now()
while not self.populated:
self.populate(worst_point, N=self.poolsize)
self.population_time += datetime.datetime.now() - st
self.populating = False
# new sample is drawn randomly from proposed points
# popping from right end is faster
@@ -1526,14 +1535,12 @@ def draw(self, worst_point):
return new_sample

@nessai_style()
def plot_pool(self, z, x):
def plot_pool(self, x):
"""
Plot the pool of points.

Parameters
----------
z : array_like
Latent samples to plot
x : array_like
Corresponding samples to plot in the physical space.
"""
@@ -1555,7 +1562,12 @@ def plot_pool(self, z, x):
),
)

z_tensor = torch.from_numpy(z).to(self.flow.device)
z, log_q = self.forward_pass(x, compute_radius=False)
z_tensor = (
torch.from_numpy(z)
.type(torch.get_default_dtype())
.to(self.flow.device)
)
with torch.inference_mode():
if self.alt_dist is not None:
log_p = self.alt_dist.log_prob(z_tensor).cpu().numpy()
@@ -1568,8 +1580,8 @@

fig, axs = plt.subplots(3, 1, figsize=(3, 9))
axs = axs.ravel()
axs[0].hist(x["logL"], 20, histtype="step", label="log q")
axs[1].hist(x["logL"] - log_p, 20, histtype="step", label="log J")
axs[0].hist(log_q, 20, histtype="step", label="log q")
axs[1].hist(log_q - log_p, 20, histtype="step", label="log J")
axs[2].hist(
np.sqrt(np.sum(z**2, axis=1)),
20,
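Since plot_pool no longer receives the latent samples, it recomputes z from x via forward_pass and casts the resulting float64 array to torch's default dtype before evaluating log_prob. A small illustration of that cast against a placeholder base distribution (the Gaussian below is a stand-in, not the flow's actual base distribution):

import numpy as np
import torch

z = np.random.randn(10, 2)  # stand-in for the output of forward_pass
z_tensor = (
    torch.from_numpy(z)
    .type(torch.get_default_dtype())  # float64 -> float32 by default
    .to("cpu")                        # stand-in for self.flow.device
)
with torch.inference_mode():
    base = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
    log_p = base.log_prob(z_tensor).cpu().numpy()
print(log_p.shape)  # (10,)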
21 changes: 13 additions & 8 deletions nessai/proposal/rejection.py
@@ -60,28 +60,32 @@ def log_proposal(self, x):
"""
return self.model.new_point_log_prob(x)

def compute_weights(self, x):
def compute_weights(self, x, return_log_prior=False):
"""
Get weights for the samples.

Computes the log weights for rejection sampling such that
the maximum log probability is zero.
Computes the log weights for rejection sampling but does not
normalize the weights.

Parameters
----------
x : structured_array
Array of points.
return_log_prior : bool
If true, the log-prior probabilities are also returned.

Returns
-------
log_w : :obj:`numpy.ndarray`
Array of unnormalized log-weights. The log-prior probabilities are
also returned if `return_log_prior` is true.
"""
x["logP"] = self.model.batch_evaluate_log_prior(x)
log_p = self.model.batch_evaluate_log_prior(x)
log_q = self.log_proposal(x)
log_w = x["logP"] - log_q
log_w -= np.nanmax(log_w)
return log_w
log_w = log_p - log_q
if return_log_prior:
return log_w, log_p
else:
return log_w

def populate(self, N=None):
"""
@@ -101,7 +105,8 @@ def populate(self, N=None):
if N is None:
N = self.poolsize
x = self.draw_proposal(N=N)
log_w = self.compute_weights(x)
log_w, x["logP"] = self.compute_weights(x, return_log_prior=True)
log_w -= np.nanmax(log_w)
log_u = np.log(np.random.rand(N))
indices = np.where((log_w - log_u) >= 0)[0]
self.samples = x[indices]
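The nanmax rescaling that used to live inside compute_weights now happens here, just before the uniform draws. A short, self-contained illustration of why rescaling matters for the acceptance test (the numbers are toy values, not nessai output):

import numpy as np

rng = np.random.default_rng(2)
log_w = rng.normal(loc=-1000.0, scale=1.0, size=5000)  # toy unnormalised log-weights
log_u = np.log(rng.random(log_w.size))

# With the raw log-weights, essentially nothing passes the test ...
print(np.sum(log_w - log_u >= 0))   # 0

# ... whereas subtracting the maximum gives acceptance probability w / max(w),
# so the best point is always kept and the rest in proportion to their weight.
log_w -= np.nanmax(log_w)
print(np.sum(log_w - log_u >= 0))   # a few hundred out of 5000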
@@ -33,11 +33,8 @@ def test_draw_populated_last_sample(proposal):
@pytest.mark.parametrize("update", [False, True])
def test_draw_not_populated(proposal, update, wait):
"""Test the draw method when the proposal is not populated"""
import datetime

proposal.populated = False
proposal.poolsize = 100
proposal.population_time = datetime.timedelta()
proposal.samples = None
proposal.indices = []
proposal.update_poolsize = update
@@ -56,7 +53,6 @@ def mock_populate(*args, **kwargs):

assert out == 2
assert proposal.populated is True
assert proposal.population_time.total_seconds() > 0.0

proposal.populate.assert_called_once_with(1.0, N=100)

@@ -276,6 +276,16 @@ def test_reset_integration(tmpdir, model, latent_prior):
modified_proposal.populate(model.new_point())
modified_proposal.reset()

# attributes that should be different
ignore = [
"population_time",
]

d1 = proposal.__getstate__()
d2 = modified_proposal.__getstate__()

for key in ignore:
d1.pop(key)
d2.pop(key)

assert d1 == d2
@@ -70,7 +70,7 @@ def test_plot_pool_all(proposal):
proposal.populated_count = 0
x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"])
with patch("nessai.proposal.flowproposal.plot_live_points") as plot:
FlowProposal.plot_pool(proposal, None, x)
FlowProposal.plot_pool(proposal, x)
plot.assert_called_once_with(
x, c="logL", filename=os.path.join("test", "pool_0.png")
)
@@ -90,7 +90,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist):

z = np.random.randn(10, 2)
x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"])
x["logL"] = np.random.randn(10)
log_q = np.random.randn(10)
x["logP"] = np.random.randn(10)
training_data = numpy_array_to_live_points(
np.random.randn(10, 2), ["x", "y"]
@@ -100,6 +100,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist):
proposal.training_data = training_data
log_p = torch.arange(10)

proposal.forward_pass = MagicMock(return_value=(z, log_q))
proposal.flow = MagicMock()
proposal.flow.device = "cpu"
if alt_dist:
@@ -111,7 +112,7 @@
)
proposal.alt_dist = None
with patch("nessai.proposal.flowproposal.plot_1d_comparison") as plot:
FlowProposal.plot_pool(proposal, z, x)
FlowProposal.plot_pool(proposal, x)

plot.assert_called_once_with(
training_data,