Skip to content

Commit

Permalink
TST: update tests for flowproposal refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Aug 20, 2024
1 parent 3c98b37 commit ef45202
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,23 @@

def test_config_drawsize_none(proposal):
"""Test the popluation configuration with no drawsize given"""
FlowProposal.configure_population(
proposal, 2000, None, True, 10, 1.0, 0.0, "gaussian"
)
proposal.poolsize = 2000
FlowProposal.configure_population(proposal, None, 1.0, 0.0, "gaussian")
assert proposal.drawsize == 2000


def test_config_poolsize_none(proposal):
"""
Test the popluation configuration raises an error if poolsize is None.
"""
with pytest.raises(RuntimeError) as excinfo:
FlowProposal.configure_population(
proposal, None, None, True, 10, 1.0, 0.0, "gaussian"
with pytest.raises(RuntimeError, match=r"Must specify `poolsize`"):
FlowProposal.configure_poolsize(
proposal,
None,
True,
10,
)

assert "poolsize" in str(excinfo.value)


@pytest.mark.parametrize("fixed_radius", [False, 5.0, 1])
def test_config_fixed_radius(proposal, fixed_radius):
Expand Down Expand Up @@ -128,7 +128,8 @@ def test_configure_constant_volume(proposal, latent_prior):
proposal.min_radius = 5.0
proposal.fuzz = 1.5
with patch(
"nessai.proposal.flowproposal.compute_radius", return_value=4.0
"nessai.proposal.flowproposal.flowproposal.compute_radius",
return_value=4.0,
) as mock:
FlowProposal.configure_constant_volume(proposal)
mock.assert_called_once_with(5, 0.95)
Expand All @@ -141,7 +142,9 @@ def test_configure_constant_volume(proposal, latent_prior):
def test_configure_constant_volume_disabled(proposal):
"""Assert nothing happens if constant_volume is False"""
proposal.constant_volume_mode = False
with patch("nessai.proposal.flowproposal.compute_radius") as mock:
with patch(
"nessai.proposal.flowproposal.flowproposal.compute_radius"
) as mock:
FlowProposal.configure_constant_volume(proposal)
mock.assert_not_called()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ def test_training(proposal, tmp_path, save, plot, plot_training):
proposal._plot_training_data = MagicMock()

with patch(
"nessai.proposal.flowproposal.live_points_to_array",
"nessai.proposal.flowproposal.base.live_points_to_array",
return_value=data_prime,
), patch("nessai.proposal.flowproposal.save_live_points") as mock_save:
), patch(
"nessai.proposal.flowproposal.base.save_live_points"
) as mock_save:
FlowProposal.train(proposal, x, plot=plot)

assert_structured_arrays_equal(x, proposal.training_data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_plot_pool_all(proposal):
proposal._plot_pool = "all"
proposal.populated_count = 0
x = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"])
with patch("nessai.proposal.flowproposal.plot_live_points") as plot:
with patch("nessai.proposal.flowproposal.base.plot_live_points") as plot:
FlowProposal.plot_pool(proposal, x)
plot.assert_called_once_with(
x, c="logL", filename=os.path.join("test", "pool_0.png")
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_plot_pool_1d(proposal, tmpdir, alt_dist):
return_value=log_p
)
proposal.alt_dist = None
with patch("nessai.proposal.flowproposal.plot_1d_comparison") as plot:
with patch("nessai.proposal.flowproposal.base.plot_1d_comparison") as plot:
FlowProposal.plot_pool(proposal, x)

plot.assert_called_once_with(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,78 +165,6 @@ def test_compute_weights_unit_hypercube(proposal, x, log_q):
np.testing.assert_array_equal(log_w, out)


@patch("numpy.random.rand", return_value=np.array([0.1, 0.9]))
def test_rejection_sampling(proposal, z, x, log_q):
"""Test rejection sampling method."""
proposal.use_x_prime_prior = False
proposal.truncate = False
proposal.backward_pass = MagicMock(return_value=(x, log_q, z))
log_w = np.log(np.array([0.5, 0.5]))
proposal.compute_weights = MagicMock(return_value=log_w)

z_out, x_out = FlowProposal.rejection_sampling(proposal, z)

proposal.backward_pass.assert_called_once_with(
z,
rescale=True,
return_z=True,
discard_nans=False,
return_unit_hypercube=proposal.map_to_unit_hypercube,
)
proposal.compute_weights.assert_called_once()
assert x_out.size == 1
assert z_out.shape == (1, 2)
assert_structured_arrays_equal(x_out[0], x[0])
assert np.array_equal(z_out[0], z[0])


def test_rejection_sampling_empty(proposal, z):
"""Test rejection sampling method if no valid points are produced by
`backwards_pass`
"""
proposal.use_x_prime_prior = False
proposal.truncate = False
proposal.backward_pass = MagicMock(
return_value=(np.array([]), np.array([]), np.array([]))
)

z_out, x_out = FlowProposal.rejection_sampling(proposal, z)

assert x_out.size == 0
assert z_out.size == 0


@patch("numpy.random.rand", return_value=np.array([0.1]))
def test_rejection_sampling_truncate(proposal, z, x):
"""Test rejection sampling method with truncation"""
proposal.use_x_prime_prior = False
proposal.truncate = True
log_q = np.array([0.0, 1.0])
proposal.backward_pass = MagicMock(return_value=(x, log_q, z))
min_log_q = 0.5
log_w = np.log(np.array([0.5]))
proposal.compute_weights = MagicMock(return_value=log_w)

z_out, x_out = FlowProposal.rejection_sampling(
proposal,
z,
min_log_q=min_log_q,
)

proposal.backward_pass.assert_called_once_with(
z,
rescale=True,
return_z=True,
discard_nans=False,
return_unit_hypercube=proposal.map_to_unit_hypercube,
)
proposal.compute_weights.assert_called_once()
assert x_out.size == 1
assert z_out.shape == (1, 2)
assert_structured_arrays_equal(x_out[0], x[1])
assert np.array_equal(z_out[0], z[1])


def test_compute_acceptance(proposal):
"""Test the compute_acceptance method"""
proposal.samples = np.arange(1, 11, dtype=float).view([("logL", "f8")])
Expand All @@ -261,7 +189,7 @@ def test_convert_to_samples(proposal):
)


@patch("nessai.proposal.flowproposal.plot_1d_comparison")
@patch("nessai.proposal.flowproposal.base.plot_1d_comparison")
def test_convert_to_samples_with_prime(mock_plot, proposal):
"""Test convert to sample with the prime prior"""
samples = numpy_array_to_live_points(np.random.randn(10, 2), ["x", "y"])
Expand Down Expand Up @@ -337,7 +265,7 @@ def test_get_alt_distribution_uniform(proposal, prior):
proposal.flow = Mock()
proposal.flow.device = "cpu"
with patch(
"nessai.proposal.flowproposal.get_uniform_distribution"
"nessai.proposal.flowproposal.flowproposal.get_uniform_distribution"
) as mock:
dist = FlowProposal.get_alt_distribution(proposal)

Expand Down Expand Up @@ -388,7 +316,7 @@ def test_prep_latent_prior_truncated(proposal):
dist.sample = MagicMock()

with patch(
"nessai.proposal.flowproposal.NDimensionalTruncatedGaussian",
"nessai.proposal.flowproposal.flowproposal.NDimensionalTruncatedGaussian", # noqa: E501
return_value=dist,
) as mock_dist:
FlowProposal.prep_latent_prior(proposal)
Expand All @@ -412,7 +340,8 @@ def draw(dims, N=None, r=None, fuzz=None):
proposal._draw_latent_prior = draw

with patch(
"nessai.proposal.flowproposal.partial", side_effect=partial
"nessai.proposal.flowproposal.flowproposal.partial",
side_effect=partial,
) as mock_partial:
FlowProposal.prep_latent_prior(proposal)

Expand Down Expand Up @@ -538,7 +467,7 @@ def convert_to_samples(samples, plot):

x_empty = np.empty(0, dtype=proposal.population_dtype)
with patch(
"nessai.proposal.flowproposal.empty_structured_array",
"nessai.proposal.flowproposal.flowproposal.empty_structured_array",
return_value=x_empty,
) as mock_empty, patch(
"numpy.random.rand", return_value=rand_u
Expand Down Expand Up @@ -714,7 +643,7 @@ def convert_to_samples(samples, plot):

x_empty = np.empty(poolsize, dtype=proposal.population_dtype)
with patch(
"nessai.proposal.flowproposal.empty_structured_array",
"nessai.proposal.flowproposal.flowproposal.empty_structured_array",
return_value=x_empty,
) as mock_empty, patch(
"numpy.random.rand", side_effect=rand_u
Expand Down Expand Up @@ -881,7 +810,7 @@ def test_populate_truncate_log_q(proposal):

x_empty = np.empty(0, dtype=proposal.population_dtype)
with patch(
"nessai.proposal.flowproposal.empty_structured_array",
"nessai.proposal.flowproposal.flowproposal.empty_structured_array",
return_value=x_empty,
) as mock_empty, patch(
"numpy.random.rand", return_value=rand_u
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from nessai.livepoint import get_dtype, numpy_array_to_live_points
from nessai.model import Model
from nessai.proposal import FlowProposal
from nessai.proposal.flowproposal import FlowProposal
from nessai.reparameterisations import (
NullReparameterisation,
RescaleToBounds,
Expand Down Expand Up @@ -45,7 +45,7 @@ def test_default_reparameterisation(proposal):
FlowProposal.add_default_reparameterisations(proposal)


@patch("nessai.proposal.flowproposal.get_reparameterisation")
@patch("nessai.proposal.flowproposal.base.get_reparameterisation")
def test_get_reparamaterisation(mocked_fn, proposal):
"""Make sure the underlying function is called"""
FlowProposal.get_reparameterisation(proposal, "angle")
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_configure_reparameterisations_dict(
proposal.prior_bounds = proposal.model.bounds

with patch(
"nessai.proposal.flowproposal.CombinedReparameterisation",
"nessai.proposal.flowproposal.base.CombinedReparameterisation",
return_value=dummy_cmb_rc,
) as mocked_class:
FlowProposal.configure_reparameterisations(
Expand All @@ -111,7 +111,7 @@ def test_configure_reparameterisations_dict(
assert proposal.parameters == ["x"]


@patch("nessai.proposal.flowproposal.CombinedReparameterisation")
@patch("nessai.proposal.flowproposal.base.CombinedReparameterisation")
def test_configure_reparameterisations_dict_w_params(
mocked_class, proposal, dummy_rc, dummy_cmb_rc
):
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_configure_reparameterisations_dict_w_params(
proposal.prior_bounds = proposal.model.bounds

with patch(
"nessai.proposal.flowproposal.CombinedReparameterisation",
"nessai.proposal.flowproposal.base.CombinedReparameterisation",
return_value=dummy_cmb_rc,
) as mocked_class:
FlowProposal.configure_reparameterisations(
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_configure_reparameterisations_requires_prime_prior(
proposal.map_to_unit_hypercube = False

with patch(
"nessai.proposal.flowproposal.CombinedReparameterisation",
"nessai.proposal.flowproposal.base.CombinedReparameterisation",
return_value=dummy_cmb_rc,
), pytest.raises(RuntimeError) as excinfo:
FlowProposal.configure_reparameterisations(
Expand Down Expand Up @@ -220,7 +220,7 @@ def test_configure_reparameterisations_prime_prior_unit_hypercube(
proposal.map_to_unit_hypercube = True

with patch(
"nessai.proposal.flowproposal.CombinedReparameterisation",
"nessai.proposal.flowproposal.base.CombinedReparameterisation",
return_value=dummy_cmb_rc,
), pytest.raises(
RuntimeError,
Expand Down Expand Up @@ -440,6 +440,7 @@ def test_set_rescaling_with_model(proposal, model):
"""
proposal.model = model
proposal.model.reparameterisations = {"x": "default"}
proposal.expansion_fraction = None

def update(self):
proposal.parameters = model.names
Expand All @@ -455,6 +456,7 @@ def update(self):
)
assert proposal.reparameterisations == {"x": "default"}
assert proposal.prime_parameters == ["x_prime"]
proposal.configure_constant_volume.assert_called_once()


def test_set_rescaling_with_reparameterisations(proposal, model):
Expand All @@ -464,6 +466,7 @@ def test_set_rescaling_with_reparameterisations(proposal, model):
proposal.model = model
proposal.model.reparameterisations = None
proposal.reparameterisations = {"x": "default"}
proposal.expansion_fraction = None

def update(self):
proposal.parameters = model.names
Expand All @@ -479,6 +482,7 @@ def update(self):
)
assert proposal.reparameterisations == {"x": "default"}
assert proposal.prime_parameters == ["x_prime"]
proposal.configure_constant_volume.assert_called_once()


@pytest.mark.parametrize("n", [1, 10])
Expand Down

0 comments on commit ef45202

Please sign in to comment.