Skip to content

Commit

Permalink
Merge pull request #398 from mj-will/support-unit-hypercube
Browse files Browse the repository at this point in the history
Support unit hypercube
  • Loading branch information
mj-will authored Aug 19, 2024
2 parents 71f6e0f + 7bb1ad8 commit 7836002
Show file tree
Hide file tree
Showing 19 changed files with 808 additions and 211 deletions.
3 changes: 2 additions & 1 deletion nessai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,8 @@ def log_prior_unit_hypercube(self, x) -> np.ndarray:
hypercube.
"""
x = self.unstructured_view(x)
return np.log(~np.any((x < 0) | (x >= 1), axis=-1))
with np.errstate(divide="ignore"):
return np.log(~np.any((x < 0) | (x >= 1), axis=-1))

def from_unit_hypercube(self, x):
"""Map from the unit hypercube to the priors.
Expand Down
87 changes: 76 additions & 11 deletions nessai/proposal/flowproposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(
fallback_reparameterisation="zscore",
use_default_reparameterisations=None,
reverse_reparameterisations=False,
map_to_unit_hypercube=False,
):

super(FlowProposal, self).__init__(model)
Expand All @@ -174,6 +175,7 @@ def __init__(
self._x_prime_dtype = None
self._draw_func = None
self._populate_dist = None
self._prior_bounds = None

self.flow = None
self._flow_config = None
Expand All @@ -194,6 +196,7 @@ def __init__(
self.rescaling_set = False
self.use_x_prime_prior = False
self.should_update_reparameterisations = False
self.map_to_unit_hypercube = map_to_unit_hypercube
self.accumulate_weights = accumulate_weights

self.reparameterisations = reparameterisations
Expand Down Expand Up @@ -305,6 +308,23 @@ def population_dtype(self):
else:
return self.x_dtype

@property
def prior_bounds(self):
"""The priors bounds used when computing the priors.
If :code:`map_to_unit_hypercube` is true, these will be [0, 1]
"""
if self._prior_bounds is None:
if self.map_to_unit_hypercube:
logger.debug("Setting prior bounds to the unit-hypercube")
self._prior_bounds = {
n: np.array([0.0, 1.0]) for n in self.model.names
}
else:
logger.debug("Setting prior bounds to the model prior bounds")
self._prior_bounds = self.model.bounds
return self._prior_bounds

def configure_population(
self,
poolsize,
Expand Down Expand Up @@ -665,12 +685,12 @@ def configure_reparameterisations(self, reparameterisations):

if isinstance(default_config["parameters"], list):
prior_bounds = {
p: self.model.bounds[p]
p: self.prior_bounds[p]
for p in default_config["parameters"]
}
else:
prior_bounds = {
default_config["parameters"]: self.model.bounds[
default_config["parameters"]: self.prior_bounds[
default_config["parameters"]
]
}
Expand All @@ -693,7 +713,7 @@ def configure_reparameterisations(self, reparameterisations):
self.fallback_reparameterisation
)
fallback_kwargs["prior_bounds"] = {
p: self.model.bounds[p] for p in other_params
p: self.prior_bounds[p] for p in other_params
}
logger.info(
f"Assuming fallback reparameterisation "
Expand All @@ -712,6 +732,10 @@ def configure_reparameterisations(self, reparameterisations):
self.use_x_prime_prior = True
self.x_prime_log_prior = self._reparameterisation.x_prime_log_prior
logger.debug("Using x prime prior")
if self.map_to_unit_hypercube:
raise RuntimeError(
"x prime prior does not support map to unit hypercube"
)
else:
logger.debug("Prime prior is disabled")
if self._reparameterisation.requires_prime_prior:
Expand Down Expand Up @@ -839,6 +863,9 @@ def rescale(self, x, compute_radius=False, **kwargs):
if x.size == 1:
x = np.array([x], dtype=x.dtype)

if self.map_to_unit_hypercube:
x = self.model.to_unit_hypercube(x)

x, x_prime, log_J = self._reparameterisation.reparameterise(
x, x_prime, log_J, compute_radius=compute_radius, **kwargs
)
Expand All @@ -847,7 +874,7 @@ def rescale(self, x, compute_radius=False, **kwargs):
x_prime[p] = x[p]
return x_prime, log_J

def inverse_rescale(self, x_prime, **kwargs):
def inverse_rescale(self, x_prime, return_unit_hypercube=False, **kwargs):
"""
Rescale from the primed physical space to the original physical
space.
Expand All @@ -872,6 +899,10 @@ def inverse_rescale(self, x_prime, **kwargs):

for p in config.livepoints.non_sampling_parameters:
x[p] = x_prime[p]

if self.map_to_unit_hypercube and not return_unit_hypercube:
x = self.model.from_unit_hypercube(x)

return x, log_J

def check_state(self, x):
Expand All @@ -884,6 +915,8 @@ def check_state(self, x):
x: array_like
Array of training live points which can be used to set parameters
"""
if self.map_to_unit_hypercube:
x = self.model.to_unit_hypercube(x)
self._reparameterisation.update(x)

@nessai_style()
Expand All @@ -892,7 +925,7 @@ def _plot_training_data(self, output):
z_training_data, _ = self.forward_pass(
self.training_data, rescale=True
)
z_gen = np.random.randn(self.training_data.size, self.dims)
z_gen = self.flow.sample_latent_distribution(self.training_data.size)

fig = plt.figure()
plt.hist(np.sqrt(np.sum(z_training_data**2, axis=1)), "auto")
Expand Down Expand Up @@ -1083,7 +1116,12 @@ def forward_pass(self, x, rescale=True, compute_radius=True):
return z, log_prob + log_J

def backward_pass(
self, z, rescale=True, discard_nans=True, return_z=False
self,
z,
rescale=True,
discard_nans=True,
return_z=False,
return_unit_hypercube=False,
):
"""
A backwards pass from the model (latent -> real)
Expand Down Expand Up @@ -1127,10 +1165,13 @@ def backward_pass(
)
# Apply rescaling in rescale=True
if rescale:
x, log_J = self.inverse_rescale(x)
x, log_J = self.inverse_rescale(
x, return_unit_hypercube=return_unit_hypercube
)
# Include Jacobian for the rescaling
log_prob -= log_J
x, z, log_prob = self.check_prior_bounds(x, z, log_prob)
if not return_unit_hypercube:
x, z, log_prob = self.check_prior_bounds(x, z, log_prob)
if return_z:
return x, log_prob, z
else:
Expand Down Expand Up @@ -1193,6 +1234,22 @@ def x_prime_log_prior(self, x):
"""
raise RuntimeError("Prime prior is not implemented")

def unit_hypercube_log_prior(self, x):
"""
Compute the prior in the unit hypercube space.
Parameters
----------
x : array
Samples in the unit hypercube.
"""
if self._reparameterisation:
return self.model.batch_evaluate_log_prior_unit_hypercube(
x
) + self._reparameterisation.log_prior(x)
else:
return self.model.batch_evaluate_log_prior_unit_hypercube(x)

def compute_weights(self, x, log_q, return_log_prior=False):
"""
Compute weights for the samples.
Expand All @@ -1215,6 +1272,8 @@ def compute_weights(self, x, log_q, return_log_prior=False):
"""
if self.use_x_prime_prior:
log_p = self.x_prime_log_prior(x)
elif self.map_to_unit_hypercube:
log_p = self.unit_hypercube_log_prior(x)
else:
log_p = self.log_prior(x)

Expand Down Expand Up @@ -1257,6 +1316,7 @@ def rejection_sampling(self, z, min_log_q=None):
rescale=not self.use_x_prime_prior,
discard_nans=False,
return_z=True,
return_unit_hypercube=self.map_to_unit_hypercube,
)

if not x.size:
Expand Down Expand Up @@ -1312,10 +1372,13 @@ def convert_to_samples(self, x, plot=True):
)

x, _ = self.inverse_rescale(x)
x["logP"] = self.model.batch_evaluate_log_prior(x)
return rfn.repack_fields(
elif self.map_to_unit_hypercube:
x = self.model.from_unit_hypercube(x)
x = rfn.repack_fields(
x[self.model.names + config.livepoints.non_sampling_parameters]
)
x["logP"] = self.model.batch_evaluate_log_prior(x)
return x

def prep_latent_prior(self):
"""Prepare the latent prior."""
Expand Down Expand Up @@ -1422,7 +1485,9 @@ def populate(
n_proposed += z.shape[0]

x, log_q = self.backward_pass(
z, rescale=not self.use_x_prime_prior
z,
rescale=not self.use_x_prime_prior,
return_unit_hypercube=self.map_to_unit_hypercube,
)
if self.truncate_log_q:
above_min_log_q = log_q > min_log_q
Expand Down
48 changes: 48 additions & 0 deletions nessai/reparameterisations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,54 @@
ScaleAndShift,
{"estimate_scale": True, "estimate_shift": True},
),
"zscore-gaussian-cdf": (
ScaleAndShift,
{
"estimate_scale": True,
"estimate_shift": True,
"post_rescaling": "gaussian_cdf",
},
),
"z-score-gaussian-cdf": (
ScaleAndShift,
{
"estimate_scale": True,
"estimate_shift": True,
"post_rescaling": "gaussian_cdf",
},
),
"z-score-logit": (
ScaleAndShift,
{
"estimate_scale": True,
"estimate_shift": True,
"pre_rescaling": "logit",
},
),
"zscore-logit": (
ScaleAndShift,
{
"estimate_scale": True,
"estimate_shift": True,
"pre_rescaling": "logit",
},
),
"z-score-inv-gaussian-cdf": (
ScaleAndShift,
{
"estimate_scale": True,
"estimate_shift": True,
"pre_rescaling": "inv_gaussian_cdf",
},
),
"zscore-inv-gaussian-cdf": (
ScaleAndShift,
{
"estimate_scale": True,
"estimate_shift": True,
"pre_rescaling": "inv_gaussian_cdf",
},
),
"angle": (Angle, {}),
"angle-pi": (Angle, {"scale": 2.0, "prior": "uniform"}),
"angle-2pi": (Angle, {"scale": 1.0, "prior": "uniform"}),
Expand Down
Loading

0 comments on commit 7836002

Please sign in to comment.