From 05d223398b5c123caf6e3357ea8b1bad53a6edc9 Mon Sep 17 00:00:00 2001 From: mj-will Date: Fri, 5 Jul 2024 17:51:11 +0100 Subject: [PATCH 1/2] ENH: handle one-to-many reparameterisations --- nessai/proposal/flowproposal.py | 5 +++++ nessai/reparameterisations/base.py | 1 + nessai/reparameterisations/combined.py | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/nessai/proposal/flowproposal.py b/nessai/proposal/flowproposal.py index c02d2f44..7c4b6ea6 100644 --- a/nessai/proposal/flowproposal.py +++ b/nessai/proposal/flowproposal.py @@ -801,6 +801,11 @@ def verify_rescaling(self): raise RuntimeError( "Rescaling must be set before it can be verified" ) + if not self._reparameterisation.one_to_one: + logger.warning( + "Could not check if reparameterisation is invertible" + ) + return logger.info("Verifying rescaling functions") x = self.model.new_point(N=1000) for inversion in ["lower", "upper", False, None]: diff --git a/nessai/reparameterisations/base.py b/nessai/reparameterisations/base.py index dab2e480..ba7557b5 100644 --- a/nessai/reparameterisations/base.py +++ b/nessai/reparameterisations/base.py @@ -28,6 +28,7 @@ class Reparameterisation: requires_bounded_prior = False prior_bounds = None prime_prior_bounds = None + one_to_one = True def __init__(self, parameters=None, prior_bounds=None): if not isinstance(parameters, (str, list)): diff --git a/nessai/reparameterisations/combined.py b/nessai/reparameterisations/combined.py index 0ffb1687..6bfb096e 100644 --- a/nessai/reparameterisations/combined.py +++ b/nessai/reparameterisations/combined.py @@ -46,6 +46,10 @@ def requires_prime_prior(self): """Boolean to check if any of the priors require the prime space""" return any(r.requires_prime_prior for r in self.values()) + @property + def one_to_one(self): + return all(r.one_to_one for r in self.values()) + @property def to_prime_order(self): """Order when converting to the prime space""" From fd057d307b4efdc37a6d71c47cf6c84474a6bbd5 Mon Sep 17 00:00:00 2001 From: mj-will Date: Tue, 20 Aug 2024 12:29:17 +0100 Subject: [PATCH 2/2] TST: update tests for one_to_one --- .../test_flowproposal_rescaling.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_proposal/test_flowproposal/test_flowproposal_rescaling.py b/tests/test_proposal/test_flowproposal/test_flowproposal_rescaling.py index 195cbdf6..30dd5c7d 100644 --- a/tests/test_proposal/test_flowproposal/test_flowproposal_rescaling.py +++ b/tests/test_proposal/test_flowproposal/test_flowproposal_rescaling.py @@ -576,6 +576,7 @@ def test_verify_rescaling(proposal, has_inversion): proposal.check_state = MagicMock() proposal.rescaling_set = True proposal._reparameterisation = MagicMock() + proposal._reparameterisation.one_to_one = True FlowProposal.verify_rescaling(proposal) @@ -611,6 +612,8 @@ def test_verify_rescaling_invertible_error(proposal, has_inversion): proposal.rescale = MagicMock(return_value=(x_prime, log_j)) proposal.inverse_rescale = MagicMock(return_value=(x_out, log_j_inv)) proposal.rescaling_set = True + proposal._reparameterisation = MagicMock() + proposal._reparameterisation.one_to_one = True with pytest.raises(RuntimeError) as excinfo: FlowProposal.verify_rescaling(proposal) @@ -640,6 +643,8 @@ def test_verify_rescaling_invertible_error_non_sampling( proposal.rescale = MagicMock(return_value=(x_prime, log_j)) proposal.inverse_rescale = MagicMock(return_value=(x_out, log_j_inv)) proposal.rescaling_set = True + proposal._reparameterisation = MagicMock() + proposal._reparameterisation.one_to_one = True with pytest.raises(RuntimeError) as excinfo: FlowProposal.verify_rescaling(proposal) @@ -665,6 +670,8 @@ def test_verify_rescaling_jacobian_error(proposal, has_inversion): proposal.rescale = MagicMock(return_value=(x_prime, log_j)) proposal.inverse_rescale = MagicMock(return_value=(x_out, log_j_inv)) proposal.rescaling_set = True + proposal._reparameterisation = MagicMock() + proposal._reparameterisation.one_to_one = True with pytest.raises(RuntimeError) as excinfo: FlowProposal.verify_rescaling(proposal) @@ -678,6 +685,18 @@ def test_verify_rescaling_rescaling_not_set(proposal): FlowProposal.verify_rescaling(proposal) +def test_verify_rescaling_not_one_to_one(proposal, caplog): + proposal.rescaling_set = True + proposal._reparameterisation = MagicMock() + proposal._reparameterisation.one_to_one = False + proposal.model.new_point = MagicMock() + FlowProposal.verify_rescaling(proposal) + assert "Could not check if reparameterisation is invertible" in str( + caplog.text + ) + proposal.model.new_point.assert_not_called() + + def test_check_state_update(proposal, map_to_unit_hypercube): """Assert the update method is called""" x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"])