diff --git a/tests/test_flowsampler.py b/tests/test_flowsampler.py index 223f268b..89d691fb 100644 --- a/tests/test_flowsampler.py +++ b/tests/test_flowsampler.py @@ -172,7 +172,7 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins): def test_resume_from_resume_data(flow_sampler, model, tmp_path): """Test for resume from data""" - output = tmp_path / "test" + output = str(tmp_path / "test") data = object() flow_sampler.check_resume = MagicMock(return_value=True) flow_sampler._resume_from_data = MagicMock() @@ -183,6 +183,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path): NestedSampler, resume_data=data, model=model, + output=os.path.join(output, ""), weights_path=None, flow_config=None, checkpoint_callback=None, @@ -191,7 +192,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path): def test_resume_from_resume_file(flow_sampler, model, tmp_path): """Test for resume from data""" - output = tmp_path / "test" + output = str(tmp_path / "test") resume_file = "resume.pkl" flow_sampler.check_resume = MagicMock(return_value=True) flow_sampler._resume_from_file = MagicMock() @@ -207,6 +208,7 @@ def test_resume_from_resume_file(flow_sampler, model, tmp_path): NestedSampler, resume_file=resume_file, model=model, + output=os.path.join(output, ""), weights_path=None, flow_config=None, checkpoint_callback=None, @@ -381,6 +383,7 @@ def test_init_resume(tmp_path, test_old, error): mock_resume.assert_called_with( expected_rf, integration_model, + output=os.path.join(output, ""), flow_config=flow_config, weights_path=weights_file, checkpoint_callback=None, diff --git a/tests/test_proposal/test_base_proposal.py b/tests/test_proposal/test_base_proposal.py index 9c593043..82545e28 100644 --- a/tests/test_proposal/test_base_proposal.py +++ b/tests/test_proposal/test_base_proposal.py @@ -4,6 +4,7 @@ """ import logging +import os import pickle from unittest.mock import MagicMock, Mock, create_autospec @@ -62,6 +63,19 @@ def test_initialise(proposal): assert proposal.initialised is True +def test_update_output(proposal, tmp_path): + tmp_path = tmp_path / "test" + proposal.output = tmp_path / "orig" + Proposal.update_output(proposal, tmp_path) + assert proposal.output == tmp_path + assert os.path.exists(tmp_path) + + +def test_update_output_no_output(proposal): + Proposal.update_output(proposal, "test") + assert not hasattr(proposal, "output") + + def test_evaluate_likelihoods(proposal): """Assert the correct method is called""" samples = numpy_array_to_live_points(np.array([[1], [2]]), ["x"]) diff --git a/tests/test_samplers/test_base_sampler.py b/tests/test_samplers/test_base_sampler.py index 6ea4f196..e6267d49 100644 --- a/tests/test_samplers/test_base_sampler.py +++ b/tests/test_samplers/test_base_sampler.py @@ -428,22 +428,31 @@ def test_close_pool(sampler): sampler.model.close_pool.assert_called_once_with(code=2) -def test_resume_from_pickled_sampler(model): +@pytest.mark.parametrize("output", [None, "orig", "new"]) +def test_resume_from_pickled_sampler(model, output): """Test the resume from pickled sampler method""" obj = MagicMock() obj.model = None + obj.output = "orig" obj._previous_likelihood_evaluations = 3 obj._previous_likelihood_evaluation_time = 4.0 model.likelihood_evaluations = 1 model.likelihood_evaluation_time = datetime.timedelta(seconds=2) - out = BaseNestedSampler.resume_from_pickled_sampler(obj, model) + out = BaseNestedSampler.resume_from_pickled_sampler( + obj, model, output=output + ) assert out.model == model assert out.model.likelihood_evaluations == 4 assert out.model.likelihood_evaluation_time.total_seconds() == 6 + if output == "new": + obj.update_output.assert_called_once_with("new") + else: + obj.update_output.assert_not_called() + def test_resume(model): """Test the resume method""" @@ -458,11 +467,11 @@ def test_resume(model): return_value=pickle_out, ) as mock_resume, ): - out = BaseNestedSampler.resume("test.pkl", model) + out = BaseNestedSampler.resume("test.pkl", model, output="test") assert out is pickle_out mock_pickle.assert_called_once() - mock_resume.assert_called_once_with(obj, model) + mock_resume.assert_called_once_with(obj, model, output="test") def test_get_result_dictionary(sampler): @@ -516,3 +525,12 @@ def test_update_history(sampler): BaseNestedSampler.update_history(sampler) assert sampler.history["likelihood_evaluations"] == [10, 20] assert sampler.history["sampling_time"] == [1, 2] + + +def test_update_output(sampler, tmp_path): + sampler.output = tmp_path / "orig" + sampler.resume_file = sampler.output / "resume.pkl" + new_output = tmp_path / "new" + BaseNestedSampler.update_output(sampler, new_output) + assert sampler.output == new_output + assert sampler.resume_file == str(new_output / "resume.pkl") diff --git a/tests/test_samplers/test_importance_nested_sampler/test_config.py b/tests/test_samplers/test_importance_nested_sampler/test_config.py index 3b4968df..be1b903d 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_config.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_config.py @@ -1,6 +1,7 @@ """Test configuration of INS""" -from unittest.mock import MagicMock +import os +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -81,3 +82,21 @@ def check_configuration_okay(ins): ins.nlive = 100 ins.min_remove = 1 assert INS.check_configuration(ins) is True + + +@pytest.mark.parametrize("has_proposal", [False, True]) +def test_update_output(ins, tmp_path, has_proposal): + output = tmp_path / "new" + if has_proposal: + ins.proposal = MagicMock() + ins.proposal.output = tmp_path / "orig" / "levels" + else: + ins.proposal = None + with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock: + INS.update_output(ins, output) + + mock.assert_called_once_with(output) + if has_proposal: + ins.proposal.update_output.assert_called_once_with( + os.path.join(output, "levels", "") + ) diff --git a/tests/test_samplers/test_nested_sampler/test_general_config.py b/tests/test_samplers/test_nested_sampler/test_general_config.py index 490e809c..b67712b8 100644 --- a/tests/test_samplers/test_nested_sampler/test_general_config.py +++ b/tests/test_samplers/test_nested_sampler/test_general_config.py @@ -87,6 +87,29 @@ def test_setup_output(sampler, tmpdir, plot): assert os.path.exists(os.path.join(path, "diagnostics")) +@pytest.mark.parametrize("plot", [False, True]) +@pytest.mark.parametrize("has_proposal", [False, True]) +def test_update_output(sampler, tmp_path, plot, has_proposal): + output = tmp_path / "new" + sampler.plot = plot + if has_proposal: + sampler._flow_proposal = MagicMock() + sampler._flow_proposal.output = tmp_path / "orig" / "proposal" + else: + sampler._flow_proposal = None + with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock: + NestedSampler.update_output(sampler, output) + + mock.assert_called_once_with(output) + if has_proposal: + sampler._flow_proposal.update_output.assert_called_once_with( + os.path.join(output, "proposal", "") + ) + + if plot: + assert os.path.exists(os.path.join(output, "diagnostics")) + + def test_configure_max_iteration(sampler): """Test to make sure the maximum iteration is set correctly""" NestedSampler.configure_max_iteration(sampler, 10)