Reviewed patch for nessai/proposal/flowproposal.py (truncate -> truncate_log_q,
"flow" latent prior, backward_pass(return_z=...), radius(*arrays)).

NOTE(review): line breaks, indentation and blank context lines were
reconstructed from a whitespace-collapsed copy of the patch using the hunk
line counts; they are a best effort. The post-image blob hash on the "index"
line belongs to the un-reviewed patch and no longer matches. Regenerate the
patch against the real tree before applying.

Changes made to the added code during review:
  * radius(): tuple + generator raised TypeError; now returns the bare radius
    when no extra arrays are passed and a tuple otherwise.
  * populate(): forward_pass() returns (z, log_prob); unpack it instead of
    passing the whole tuple to radius().
  * populate(): keep `worst_point = self.training_data` so that
    `compute_radius_with_all` still does what its log message says
    (TODO confirm the removal was not intentional).
  * populate(): add the missing f-prefix to the "Truncating with log_q" log.
  * rejection_sampling(): `if min_log_q:` skipped truncation for a bound of
    exactly 0.0; compare against None instead.
  * backward_pass(): filter z together with x and log_prob when discarding
    non-finite samples, and return three empty arrays from the
    AssertionError path when return_z=True.
  * backward_pass() docstring: `discard_nan` -> `discard_nans`.

diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py
index 1ef1127a..b2192d98 100644
--- a/nessai/proposal/flowproposal.py
+++ b/nessai/proposal/flowproposal.py
@@ -34,6 +34,7 @@
     save_live_points,
 )
 from ..utils.sampling import NDimensionalTruncatedGaussian
+from ..utils.structures import get_subset_arrays
 
 logger = logging.getLogger(__name__)
 
@@ -100,9 +101,8 @@ class FlowProposal(RejectionProposal):
         Similar to ``fuzz`` but instead a scaling factor applied to the radius
         this specifies a rescaling for volume of the n-ball used to draw
         samples. This is translated to a value for ``fuzz``.
-    truncate : bool, optional
-        Truncate proposals using probability compute for worst point.
-        Not recommended.
+    truncate_log_q : bool, optional
+        Truncate proposals using minimum log-probability of the training data.
     rescale_parameters : list or bool, optional
         If True live points are rescaled to `rescale_bounds` before training.
         If an instance of `list` then must contain names of parameters to
@@ -174,7 +174,7 @@ def __init__(
         fixed_radius=False,
         drawsize=None,
         check_acceptance=False,
-        truncate=False,
+        truncate_log_q=False,
         rescale_bounds=[-1, 1],
         expansion_fraction=4.0,
         boundary_inversion=False,
@@ -245,7 +245,7 @@ def __init__(
         self.update_bounds = update_bounds
         self.check_acceptance = check_acceptance
         self.rescale_bounds = rescale_bounds
-        self.truncate = truncate
+        self.truncate_log_q = truncate_log_q
         self.boundary_inversion = boundary_inversion
         self.inversion_type = inversion_type
         self.flow_config = flow_config
@@ -420,6 +420,8 @@ def configure_latent_prior(self):
             from ..utils import draw_nsphere
 
             self._draw_latent_prior = draw_nsphere
+        elif self.latent_prior == "flow":
+            self._draw_latent_prior = None
         else:
             raise RuntimeError(
                 f"Unknown latent prior: {self.latent_prior}, choose from: "
@@ -1088,7 +1090,9 @@ def forward_pass(self, x, rescale=True, compute_radius=True):
 
         return z, log_prob + log_J
 
-    def backward_pass(self, z, rescale=True):
+    def backward_pass(
+        self, z, rescale=True, discard_nans=True, return_z=False
+    ):
         """
         A backwards pass from the model (latent -> real)
 
@@ -1098,14 +1102,21 @@ def backward_pass(self, z, rescale=True):
             Structured array of points in the latent space
         rescale : bool, optional (True)
             Apply inverse rescaling function
+        discard_nans : bool
+            If True, samples with NaNs or Infs in log_q are removed.
+        return_z : bool
+            If True, return the array of latent samples, this may differ from
+            the input since samples can be discarded.
 
         Returns
         -------
         x : array_like
-            Samples in the latent space
+            Samples in the data space
         log_prob : array_like
             Log probabilities corresponding to each sample (including the
             Jacobian)
+        z : array_like
+            Samples in the latent space, only returned if :code:`return_z=True`
         """
         # Compute the log probability
         try:
@@ -1115,8 +1126,9 @@ def backward_pass(self, z, rescale=True):
         except AssertionError:
-            return np.array([]), np.array([])
+            return tuple(np.array([]) for _ in range(3 if return_z else 2))
 
-        valid = np.isfinite(log_prob)
-        x, log_prob = x[valid], log_prob[valid]
+        if discard_nans:
+            valid = np.isfinite(log_prob)
+            x, z, log_prob = x[valid], z[valid], log_prob[valid]
         x = numpy_array_to_live_points(
             x.astype(config.livepoints.default_float_dtype),
             self.rescaled_names,
@@ -1126,10 +1138,13 @@ def backward_pass(self, z, rescale=True):
             x, log_J = self.inverse_rescale(x)
             # Include Jacobian for the rescaling
             log_prob -= log_J
-        x, log_prob = self.check_prior_bounds(x, log_prob)
-        return x, log_prob
+        x, z, log_prob = self.check_prior_bounds(x, z, log_prob)
+        if return_z:
+            return x, log_prob, z
+        else:
+            return x, log_prob
 
-    def radius(self, z, log_q=None):
+    def radius(self, z, *arrays):
         """
         Calculate the radius of a latent point or set of latent points.
         If multiple points are parsed the maximum radius is returned.
@@ -1138,22 +1153,18 @@ def radius(self, z, log_q=None):
         ----------
         z : :obj:`np.ndarray`
             Array of points in the latent space
-        log_q : :obj:`np.ndarray`, optional (None)
-            Array of corresponding probabilities. If specified
-            then probability of the maximum radius is also returned.
+        *arrays :
+            Additional arrays to return the corresponding value
 
         Returns
         -------
         tuple of arrays
-            Tuple of array with the maximum radius and corresponding log_q
-            if it was a specified input.
+            Tuple of array with the maximum radius and corresponding values
+            from any additional arrays, or just the radius if none are given.
         """
-        if log_q is not None:
-            r = np.sqrt(np.sum(z**2.0, axis=-1))
-            i = np.argmax(r)
-            return r[i], log_q[i]
-        else:
-            return np.nanmax(np.sqrt(np.sum(z**2.0, axis=-1)))
+        r = np.sqrt(np.sum(z**2.0, axis=-1))
+        i = np.nanargmax(r)
+        return (r[i], *(a[i] for a in arrays)) if arrays else r[i]
 
     def log_prior(self, x):
         """
@@ -1219,7 +1230,7 @@ def compute_weights(self, x, log_q):
         log_w -= np.max(log_w)
         return log_w
 
-    def rejection_sampling(self, z, worst_q=None):
+    def rejection_sampling(self, z, min_log_q=None):
         """
         Perform rejection sampling.
 
@@ -1230,9 +1241,9 @@ def rejection_sampling(self, z, worst_q=None):
         ----------
         z : ndarray
             Samples from the latent space
-        worst_q : float, optional
+        min_log_q : float, optional
             Lower bound on the log-probability computed using the flow that
-            is used to truncate new samples. Not recommended.
+            is used to truncate new samples.
 
         Returns
         -------
@@ -1241,20 +1252,24 @@ def rejection_sampling(self, z, worst_q=None):
         array_like
             Array of accepted samples in the X space.
         """
-        x, log_q = self.backward_pass(z, rescale=not self.use_x_prime_prior)
+        x, log_q, z = self.backward_pass(
+            z,
+            rescale=not self.use_x_prime_prior,
+            discard_nans=False,
+            return_z=True,
+        )
 
         if not x.size:
             return np.array([]), x
 
-        if self.truncate:
-            if worst_q is None:
-                raise ValueError(
-                    "`worst_q` is None but truncation is enabled."
-                )
-            cut = log_q >= worst_q
-            x = x[cut]
-            z = z[cut]
-            log_q = log_q[cut]
+        if min_log_q is not None:
+            above = log_q >= min_log_q
+            x = x[above]
+            z = z[above]
+            log_q = log_q[above]
+        else:
+            valid = np.isfinite(log_q)
+            x, z, log_q = get_subset_arrays(valid, x, z, log_q)
 
         # rescale given priors used initially, need for priors
         log_w = self.compute_weights(x, log_q)
@@ -1311,6 +1326,8 @@ def prep_latent_prior(self):
                 fuzz=self.fuzz,
             )
             self._draw_func = self._populate_dist.sample
+        elif self.latent_prior == "flow":
+            self._draw_func = lambda N: self.flow.sample_latent_distribution(N)
         else:
             self._draw_func = partial(
                 self._draw_latent_prior,
@@ -1347,25 +1364,30 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
             )
         if r is not None:
             logger.debug(f"Using user inputs for radius {r}")
-            worst_q = None
         elif self.fixed_radius:
             r = self.fixed_radius
-            worst_q = None
         else:
             logger.debug(f"Populating with worst point: {worst_point}")
             if self.compute_radius_with_all:
                 logger.debug("Using previous live points to compute radius")
                 worst_point = self.training_data
-            worst_z, worst_q = self.forward_pass(
+            worst_z, _ = self.forward_pass(
                 worst_point, rescale=True, compute_radius=True
             )
-            r, worst_q = self.radius(worst_z, worst_q)
+            r = self.radius(worst_z)
             if self.max_radius and r > self.max_radius:
                 r = self.max_radius
             if self.min_radius and r < self.min_radius:
                 r = self.min_radius
 
-        logger.debug(f"Populating proposal with lantent radius: {r:.5}")
+        if self.truncate_log_q:
+            log_q_live_points = self.forward_pass(self.training_data)[1]
+            min_log_q = log_q_live_points.min()
+            logger.debug(f"Truncating with log_q={min_log_q:.3f}")
+        else:
+            min_log_q = None
+
+        logger.debug(f"Populating proposal with latent radius: {r:.5}")
         self.r = r
         self.alt_dist = self.get_alt_distribution()
 
@@ -1390,7 +1412,7 @@ def populate(self, worst_point, N=10000, plot=True, r=None):
             z = self.draw_latent_prior(self.drawsize)
             proposed += z.shape[0]
 
-            z, x = self.rejection_sampling(z, worst_q)
+            z, x = self.rejection_sampling(z, min_log_q=min_log_q)
 
             if not x.size:
                 continue