Skip to content

Commit

Permalink
TST: add more tests for map to unit hypercube
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Aug 19, 2024
1 parent 3f9c6d5 commit 7bb1ad8
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,41 @@ def test_prime_log_prior(proposal):
assert "Prime prior is not implemented" in str(excinfo.value)


def test_unit_hypercube_log_prior_wo_reparameterisation(proposal, x):

log_prior = -np.ones(x.size)
proposal._reparameterisation = None
proposal.model = MagicMock()
proposal.model.batch_evaluate_log_prior_unit_hypercube = MagicMock(
return_value=log_prior
)

log_prior_out = FlowProposal.unit_hypercube_log_prior(proposal, x)

assert np.array_equal(log_prior, log_prior_out)
proposal.model.batch_evaluate_log_prior_unit_hypercube.assert_called_once_with( # noqa: E501
x
)


def test_unit_hypercube_log_prior_w_reparameterisation(proposal, x):
log_prior = -np.ones(x.size)
proposal._reparameterisation = MagicMock()
proposal._reparameterisation.log_prior = MagicMock(return_value=log_prior)
proposal.model = MagicMock()
proposal.model.batch_evaluate_log_prior_unit_hypercube = MagicMock(
return_value=log_prior.copy()
)

log_prior_out = FlowProposal.unit_hypercube_log_prior(proposal, x)

assert np.array_equal(log_prior_out, -2 * np.ones(x.size))
proposal._reparameterisation.log_prior.assert_called_once_with(x)
proposal.model.batch_evaluate_log_prior_unit_hypercube.assert_called_once_with( # noqa: E501
x
)


@pytest.mark.parametrize(
"acceptance, scale", [(0.0, 10.0), (0.5, 2.0), (0.01, 10.0), (2.0, 1.0)]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,39 @@ def test_configure_reparameterisations_requires_prime_prior(
assert "One or more reparameterisations require " in str(excinfo.value)


def test_configure_reparameterisations_prime_prior_unit_hypercube(
proposal, dummy_rc, dummy_cmb_rc
):
dummy_rc.return_value = "r"
# Need to add the parameters before hand to prevent a
# NullReparameterisation from being added
dummy_cmb_rc.parameters = ["x", "y"]
dummy_cmb_rc.has_prime_prior = True
dummy_cmb_rc.requires_prime_prior = True
proposal.add_default_reparameterisations = MagicMock()
proposal.get_reparameterisation = MagicMock(
return_value=(
dummy_rc,
{},
)
)
proposal.model.bounds = {"x": [-1, 1], "y": [-1, 1]}
proposal.model.names = ["x", "y"]
proposal.map_to_unit_hypercube = True

with patch(
"nessai.proposal.flowproposal.CombinedReparameterisation",
return_value=dummy_cmb_rc,
), pytest.raises(
RuntimeError,
match="x prime prior does not support map to unit hypercube",
):
FlowProposal.configure_reparameterisations(
proposal,
{"x": {"reparameterisation": "default", "parameters": ["y"]}},
)


@patch("nessai.reparameterisations.CombinedReparameterisation")
def test_configure_reparameterisations_dict_missing(mocked_class, proposal):
"""
Expand Down Expand Up @@ -483,7 +516,10 @@ def test_rescale(proposal, n, map_to_unit_hypercube):


@pytest.mark.parametrize("n", [1, 10])
def test_inverse_rescale(proposal, n, map_to_unit_hypercube):
@pytest.mark.parametrize("return_unit_hypercube", [True, False])
def test_inverse_rescale(
proposal, n, map_to_unit_hypercube, return_unit_hypercube
):
"""Test rescaling when using reparameterisation dict"""
x = numpy_array_to_live_points(np.random.randn(n, 2), ["x", "y"]).squeeze()
x_prime = numpy_array_to_live_points(
Expand All @@ -499,14 +535,16 @@ def test_inverse_rescale(proposal, n, map_to_unit_hypercube):
)
proposal.model.from_unit_hypercube = MagicMock(side_effect=lambda a: a)

x_out, log_j = FlowProposal.inverse_rescale(proposal, x_prime)
x_out, log_j = FlowProposal.inverse_rescale(
proposal, x_prime, return_unit_hypercube=return_unit_hypercube
)

np.testing.assert_array_equal(x[["x", "y"]], x_out[["x", "y"]])
np.testing.assert_array_equal(
x_prime[["logP", "logL"]], x_out[["logP", "logL"]]
)
proposal._reparameterisation.inverse_reparameterise.assert_called_once()
if map_to_unit_hypercube:
if map_to_unit_hypercube and not return_unit_hypercube:
proposal.model.from_unit_hypercube.assert_called_once_with(x)
else:
proposal.model.from_unit_hypercube.assert_not_called()
Expand Down

0 comments on commit 7bb1ad8

Please sign in to comment.