Skip to content

Commit

Permalink
feat: rework truncation based on log_q
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Oct 25, 2023
1 parent f557f2b commit 17cf217
Showing 1 changed file with 64 additions and 43 deletions.
107 changes: 64 additions & 43 deletions nessai/proposal/flowproposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
save_live_points,
)
from ..utils.sampling import NDimensionalTruncatedGaussian
from ..utils.structures import get_subset_arrays

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,9 +101,8 @@ class FlowProposal(RejectionProposal):
Similar to ``fuzz`` but, instead of a scaling factor applied to the
radius, this specifies a rescaling for the volume of the n-ball used to
draw samples. This is translated to a value for ``fuzz``.
truncate : bool, optional
Truncate proposals using probability compute for worst point.
Not recommended.
truncate_log_q : bool, optional
Truncate proposals using the minimum log-probability of the training data.
rescale_parameters : list or bool, optional
If True live points are rescaled to `rescale_bounds` before training.
If an instance of `list` then must contain names of parameters to
Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(
fixed_radius=False,
drawsize=None,
check_acceptance=False,
truncate=False,
truncate_log_q=False,
rescale_bounds=[-1, 1],
expansion_fraction=4.0,
boundary_inversion=False,
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(
self.update_bounds = update_bounds
self.check_acceptance = check_acceptance
self.rescale_bounds = rescale_bounds
self.truncate = truncate
self.truncate_log_q = truncate_log_q
self.boundary_inversion = boundary_inversion
self.inversion_type = inversion_type
self.flow_config = flow_config
Expand Down Expand Up @@ -420,6 +420,8 @@ def configure_latent_prior(self):
from ..utils import draw_nsphere

self._draw_latent_prior = draw_nsphere
elif self.latent_prior == "flow":
self._draw_latent_prior = None
else:
raise RuntimeError(
f"Unknown latent prior: {self.latent_prior}, choose from: "
Expand Down Expand Up @@ -1088,7 +1090,9 @@ def forward_pass(self, x, rescale=True, compute_radius=True):

return z, log_prob + log_J

def backward_pass(self, z, rescale=True):
def backward_pass(
self, z, rescale=True, discard_nans=True, return_z=False
):
"""
A backwards pass from the model (latent -> real)
Expand All @@ -1098,14 +1102,21 @@ def backward_pass(self, z, rescale=True):
Structured array of points in the latent space
rescale : bool, optional (True)
Apply inverse rescaling function
discard_nans : bool
If True, samples with NaNs or Infs in log_q are removed.
return_z : bool
If True, return the array of latent samples, this may differ from
the input since samples can be discarded.
Returns
-------
x : array_like
Samples in the latent space
Samples in the data space
log_prob : array_like
Log probabilities corresponding to each sample (including the
Jacobian)
z : array_like
Samples in the latent space, only returned if :code:`return_z=True`
"""
# Compute the log probability
try:
Expand All @@ -1115,8 +1126,9 @@ def backward_pass(self, z, rescale=True):
except AssertionError:
return np.array([]), np.array([])

valid = np.isfinite(log_prob)
x, log_prob = x[valid], log_prob[valid]
if discard_nans:
valid = np.isfinite(log_prob)
x, log_prob = x[valid], log_prob[valid]
x = numpy_array_to_live_points(
x.astype(config.livepoints.default_float_dtype),
self.rescaled_names,
Expand All @@ -1126,10 +1138,13 @@ def backward_pass(self, z, rescale=True):
x, log_J = self.inverse_rescale(x)
# Include Jacobian for the rescaling
log_prob -= log_J
x, log_prob = self.check_prior_bounds(x, log_prob)
return x, log_prob
x, z, log_prob = self.check_prior_bounds(x, z, log_prob)
if return_z:
return x, log_prob, z
else:
return x, log_prob

def radius(self, z, *arrays):
    """
    Calculate the radius of a latent point or set of latent points.

    If multiple points are passed the maximum radius is returned.

    Parameters
    ----------
    z : :obj:`np.ndarray`
        Array of points in the latent space
    *arrays :
        Additional arrays from which to return the value corresponding to
        the point with the maximum radius.

    Returns
    -------
    float or tuple
        The maximum radius if no additional arrays were passed. Otherwise
        a tuple with the maximum radius followed by the corresponding
        value from each additional array, in the order they were passed.

    Raises
    ------
    ValueError
        Raised by :code:`np.nanargmax` if every radius is NaN.
    """
    # atleast_1d allows a single latent point (1-d input) to be indexed
    r = np.atleast_1d(np.sqrt(np.sum(z**2.0, axis=-1)))
    # NaN radii are ignored when locating the maximum
    i = np.nanargmax(r)
    if arrays:
        # A generator cannot be concatenated to a tuple, so build a tuple
        return (r[i],) + tuple(np.atleast_1d(a)[i] for a in arrays)
    # Callers that only need the radius expect a scalar, not a 1-tuple
    return r[i]

def log_prior(self, x):
"""
Expand Down Expand Up @@ -1219,7 +1230,7 @@ def compute_weights(self, x, log_q):
log_w -= np.max(log_w)
return log_w

def rejection_sampling(self, z, worst_q=None):
def rejection_sampling(self, z, min_log_q=None):
"""
Perform rejection sampling.
Expand All @@ -1230,9 +1241,9 @@ def rejection_sampling(self, z, worst_q=None):
----------
z : ndarray
Samples from the latent space
worst_q : float, optional
min_log_q : float, optional
Lower bound on the log-probability computed using the flow that
is used to truncate new samples. Not recommended.
is used to truncate new samples.
Returns
-------
Expand All @@ -1241,20 +1252,24 @@ def rejection_sampling(self, z, worst_q=None):
array_like
Array of accepted samples in the X space.
"""
x, log_q = self.backward_pass(z, rescale=not self.use_x_prime_prior)
x, log_q, z = self.backward_pass(
z,
rescale=not self.use_x_prime_prior,
discard_nans=False,
return_z=True,
)

if not x.size:
return np.array([]), x

if self.truncate:
if worst_q is None:
raise ValueError(
"`worst_q` is None but truncation is enabled."
)
cut = log_q >= worst_q
x = x[cut]
z = z[cut]
log_q = log_q[cut]
if min_log_q:
above = log_q >= min_log_q
x = x[above]
z = z[above]
log_q = log_q[above]
else:
valid = np.isfinite(log_q)
x, z, log_q = get_subset_arrays(valid, x, z, log_q)

# rescale given priors used initially, need for priors
log_w = self.compute_weights(x, log_q)
Expand Down Expand Up @@ -1311,6 +1326,8 @@ def prep_latent_prior(self):
fuzz=self.fuzz,
)
self._draw_func = self._populate_dist.sample
elif self.latent_prior == "flow":
self._draw_func = lambda N: self.flow.sample_latent_distribution(N)
else:
self._draw_func = partial(
self._draw_latent_prior,
Expand Down Expand Up @@ -1347,25 +1364,29 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
)
if r is not None:
logger.debug(f"Using user inputs for radius {r}")
worst_q = None
elif self.fixed_radius:
r = self.fixed_radius
worst_q = None
else:
logger.debug(f"Populating with worst point: {worst_point}")
if self.compute_radius_with_all:
logger.debug("Using previous live points to compute radius")
worst_point = self.training_data
worst_z, worst_q = self.forward_pass(
worst_z = self.forward_pass(
worst_point, rescale=True, compute_radius=True
)
r, worst_q = self.radius(worst_z, worst_q)
r = self.radius(worst_z)
if self.max_radius and r > self.max_radius:
r = self.max_radius
if self.min_radius and r < self.min_radius:
r = self.min_radius

logger.debug(f"Populating proposal with lantent radius: {r:.5}")
if self.truncate_log_q:
log_q_live_points = self.forward_pass(self.training_data)[1]
min_log_q = log_q_live_points.min()
logger.debug("Truncating with log_q={min_log_q:.3f}")
else:
min_log_q = None

logger.debug(f"Populating proposal with latent radius: {r:.5}")
self.r = r

self.alt_dist = self.get_alt_distribution()
Expand All @@ -1390,7 +1411,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
z = self.draw_latent_prior(self.drawsize)
proposed += z.shape[0]

z, x = self.rejection_sampling(z, worst_q)
z, x = self.rejection_sampling(z, min_log_q=min_log_q)

if not x.size:
continue
Expand Down

0 comments on commit 17cf217

Please sign in to comment.