diff --git a/docs/further-details.rst b/docs/further-details.rst index ce56b395..353c30d2 100644 --- a/docs/further-details.rst +++ b/docs/further-details.rst @@ -112,6 +112,12 @@ priority over the resume file. This will be passed to the :python:`resume_from_pickled_sampler` of the corresponding sampler class. +.. note:: + + If the output directory has been moved, make sure to change the + :code:`output` argument when calling :code:`FlowSampler`. The sampler + will then automatically update the relevant paths. + Checkpoint callbacks -------------------- diff --git a/nessai/flowsampler.py b/nessai/flowsampler.py index a6446f2d..fa934c52 100644 --- a/nessai/flowsampler.py +++ b/nessai/flowsampler.py @@ -85,7 +85,7 @@ class FlowSampler: def __init__( self, model, - output=os.getcwd(), + output=None, importance_nested_sampler=False, resume=True, resume_file="nested_sampler_resume.pkl", @@ -145,6 +145,8 @@ def __init__( logger.debug("Overriding `parallelise_prior` in the model") model.parallelise_prior = parallelise_prior + if output is None: + output = os.getcwd() self.output = os.path.join(output, "") os.makedirs(self.output, exist_ok=True) self.save_kwargs(kwargs) @@ -160,6 +162,7 @@ def __init__( SamplerClass, resume_data=resume_data, model=model, + output=self.output, weights_path=weights_path, flow_config=kwargs.get("flow_config"), checkpoint_callback=kwargs.get("checkpoint_callback"), @@ -168,6 +171,7 @@ def __init__( self.ns = self._resume_from_file( SamplerClass, model=model, + output=self.output, resume_file=resume_file, weights_path=weights_path, flow_config=kwargs.get("flow_config"), diff --git a/nessai/proposal/base.py b/nessai/proposal/base.py index 687c08e2..3027c6f9 100644 --- a/nessai/proposal/base.py +++ b/nessai/proposal/base.py @@ -5,6 +5,7 @@ import datetime import logging +import os from abc import ABC, abstractmethod import numpy as np @@ -53,6 +54,24 @@ def initialise(self): """ self.initialised = True + def update_output(self, output: str) -> None: + """ + Update the output directory. + + Only updates the output if the proposal has an output attribute. + + Parameters + ---------- + output : str + Path to the output directory + """ + if hasattr(self, "output"): + logger.debug(f"Updating output directory to {output}") + self.output = output + os.makedirs(self.output, exist_ok=True) + else: + logger.debug("No output directory to update") + def evaluate_likelihoods(self): """Evaluate the likelihoods for the pool of live points.""" self.samples["logL"] = self.model.batch_evaluate_log_likelihood( diff --git a/nessai/samplers/base.py b/nessai/samplers/base.py index 8e51156b..4d9ab96d 100644 --- a/nessai/samplers/base.py +++ b/nessai/samplers/base.py @@ -168,6 +168,18 @@ def configure_output( self.output = output self.resume_file = resume_file + def update_output(self, output: str) -> None: + """Update the output directory and resume file. + + Parameters + ---------- + output : str + Path to the output directory + """ + self.output = output + resume_file = os.path.split(self.resume_file)[1] + self.resume_file = os.path.join(output, resume_file) + def configure_random_seed(self, seed: Optional[int]): """Initialise the random seed. @@ -316,6 +328,7 @@ def resume_from_pickled_sampler( cls, sampler: Any, model: Model, + output: Optional[str] = None, checkpoint_callback: Optional[Callable] = None, ): """Resume from pickle data.
@@ -326,6 +339,9 @@ def resume_from_pickled_sampler( Pickle data model : :obj:`nessai.model.Model` User-defined model + output : Optional[str] + New output directory. If not specified, the output directory will + be the same as that of the previous run. checkpoint_callback : Optional[Callable] Checkpoint callback function. If not specified, the default method will be used. @@ -341,13 +357,25 @@ def resume_from_pickled_sampler( model.likelihood_evaluation_time += datetime.timedelta( seconds=sampler._previous_likelihood_evaluation_time ) + if output is not None and output != sampler.output: + logger.info( + f"Updating output from {sampler.output} to {output}" + ) + os.makedirs(output, exist_ok=True) + sampler.update_output(output) sampler.model = model sampler.resumed = True sampler.checkpoint_callback = checkpoint_callback return sampler @classmethod - def resume(cls, filename: str, model: Model, **kwargs): + def resume( + cls, + filename: str, + model: Model, + output: Optional[str] = None, + **kwargs, + ): """Resumes the interrupted state from a checkpoint pickle file. Parameters @@ -356,6 +384,9 @@ def resume(cls, filename: str, model: Model, **kwargs): Pickle file to resume from model : :obj:`nessai.model.Model` User-defined model + output : Optional[str] + New output directory. If not specified, the output directory will + be the same as that of the previous run. Returns ------- @@ -365,7 +396,9 @@ def resume(cls, filename: str, model: Model, **kwargs): logger.info(f"Resuming {cls.__name__} from {filename}") with open(filename, "rb") as f: sampler = pickle.load(f) - return cls.resume_from_pickled_sampler(sampler, model, **kwargs) + return cls.resume_from_pickled_sampler( + sampler, model, output=output, **kwargs + ) @abstractmethod def nested_sampling_loop(self): diff --git a/nessai/samplers/importancesampler.py b/nessai/samplers/importancesampler.py index fce4e52a..9a6b1e77 100644 --- a/nessai/samplers/importancesampler.py +++ b/nessai/samplers/importancesampler.py @@ -718,6 +718,13 @@ def get_proposal(self, subdir: str = "levels", **kwargs): proposal = ImportanceFlowProposal(self.model, output, **kwargs) return proposal + def update_output(self, output: str) -> None: + super().update_output(output) + if self.proposal is not None: + # ImportanceFlowProposal uses a subdirectory + subdir = os.path.basename(os.path.normpath(self.proposal.output)) + self.proposal.update_output(os.path.join(output, subdir, "")) + def configure_iterations( self, min_iteration: Optional[int] = None, diff --git a/nessai/samplers/nestedsampler.py b/nessai/samplers/nestedsampler.py index 916be42c..d5476e5c 100644 --- a/nessai/samplers/nestedsampler.py +++ b/nessai/samplers/nestedsampler.py @@ -499,6 +499,26 @@ def configure_output( if self.plot: os.makedirs(os.path.join(output, "diagnostics"), exist_ok=True) + def update_output(self, output: str) -> None: + """Update the output directory. + + Also creates a "diagnostics" directory if plotting is enabled. + + Parameters + ---------- + output : str + Path to the output directory.
+ """ + super().update_output(output) + if self.plot: + os.makedirs(os.path.join(output, "diagnostics"), exist_ok=True) + if self._flow_proposal is not None: + # FlowProposal uses a subdirectory + subdir = os.path.basename( + os.path.normpath(self._flow_proposal.output) + ) + self._flow_proposal.update_output(os.path.join(output, subdir, "")) + def configure_flow_reset( self, reset_weights, reset_permutations, reset_flow ): diff --git a/tests/test_flowsampler.py b/tests/test_flowsampler.py index 223f268b..72513b07 100644 --- a/tests/test_flowsampler.py +++ b/tests/test_flowsampler.py @@ -113,7 +113,10 @@ def test_check_resume_files_do_not_exist(flow_sampler, tmp_path): @pytest.mark.parametrize("resume", [False, True]) @pytest.mark.parametrize("use_ins", [False, True]) -def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins): +@pytest.mark.parametrize("specify_output", [False, True]) +def test_init_no_resume_file( + flow_sampler, tmp_path, resume, use_ins, specify_output +): """Test the init method when there is no run to resume from""" integration_model = MagicMock() @@ -139,11 +142,12 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins): f"nessai.flowsampler.{sampler_class}", return_value="ns" ) as mock, patch("nessai.flowsampler.configure_threads") as mock_threads, + patch("os.getcwd", return_value=output) as mock_getcwd, ): FlowSampler.__init__( flow_sampler, integration_model, - output=output, + output=output if specify_output else None, resume=resume, exit_code=exit_code, pytorch_threads=pytorch_threads, @@ -169,10 +173,13 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins): flow_sampler.save_kwargs.assert_called_once_with(kwargs) + if not specify_output: + mock_getcwd.assert_called_once() + def test_resume_from_resume_data(flow_sampler, model, tmp_path): """Test for resume from data""" - output = tmp_path / "test" + output = str(tmp_path / "test") data = object() flow_sampler.check_resume = MagicMock(return_value=True) flow_sampler._resume_from_data = MagicMock() @@ -183,6 +190,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path): NestedSampler, resume_data=data, model=model, + output=os.path.join(output, ""), weights_path=None, flow_config=None, checkpoint_callback=None, @@ -191,7 +199,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path): def test_resume_from_resume_file(flow_sampler, model, tmp_path): """Test for resume from data""" - output = tmp_path / "test" + output = str(tmp_path / "test") resume_file = "resume.pkl" flow_sampler.check_resume = MagicMock(return_value=True) flow_sampler._resume_from_file = MagicMock() @@ -207,6 +215,7 @@ def test_resume_from_resume_file(flow_sampler, model, tmp_path): NestedSampler, resume_file=resume_file, model=model, + output=os.path.join(output, ""), weights_path=None, flow_config=None, checkpoint_callback=None, @@ -381,6 +390,7 @@ def test_init_resume(tmp_path, test_old, error): mock_resume.assert_called_with( expected_rf, integration_model, + output=os.path.join(output, ""), flow_config=flow_config, weights_path=weights_file, checkpoint_callback=None, diff --git a/tests/test_proposal/test_base_proposal.py b/tests/test_proposal/test_base_proposal.py index 9c593043..82545e28 100644 --- a/tests/test_proposal/test_base_proposal.py +++ b/tests/test_proposal/test_base_proposal.py @@ -4,6 +4,7 @@ """ import logging +import os import pickle from unittest.mock import MagicMock, Mock, create_autospec @@ -62,6 +63,19 @@ def test_initialise(proposal): assert 
proposal.initialised is True +def test_update_output(proposal, tmp_path): + tmp_path = tmp_path / "test" + proposal.output = tmp_path / "orig" + Proposal.update_output(proposal, tmp_path) + assert proposal.output == tmp_path + assert os.path.exists(tmp_path) + + +def test_update_output_no_output(proposal): + Proposal.update_output(proposal, "test") + assert not hasattr(proposal, "output") + + def test_evaluate_likelihoods(proposal): """Assert the correct method is called""" samples = numpy_array_to_live_points(np.array([[1], [2]]), ["x"]) diff --git a/tests/test_samplers/test_base_sampler.py b/tests/test_samplers/test_base_sampler.py index 6ea4f196..e6267d49 100644 --- a/tests/test_samplers/test_base_sampler.py +++ b/tests/test_samplers/test_base_sampler.py @@ -428,22 +428,31 @@ def test_close_pool(sampler): sampler.model.close_pool.assert_called_once_with(code=2) -def test_resume_from_pickled_sampler(model): +@pytest.mark.parametrize("output", [None, "orig", "new"]) +def test_resume_from_pickled_sampler(model, output): """Test the resume from pickled sampler method""" obj = MagicMock() obj.model = None + obj.output = "orig" obj._previous_likelihood_evaluations = 3 obj._previous_likelihood_evaluation_time = 4.0 model.likelihood_evaluations = 1 model.likelihood_evaluation_time = datetime.timedelta(seconds=2) - out = BaseNestedSampler.resume_from_pickled_sampler(obj, model) + out = BaseNestedSampler.resume_from_pickled_sampler( + obj, model, output=output + ) assert out.model == model assert out.model.likelihood_evaluations == 4 assert out.model.likelihood_evaluation_time.total_seconds() == 6 + if output == "new": + obj.update_output.assert_called_once_with("new") + else: + obj.update_output.assert_not_called() + def test_resume(model): """Test the resume method""" @@ -458,11 +467,11 @@ def test_resume(model): return_value=pickle_out, ) as mock_resume, ): - out = BaseNestedSampler.resume("test.pkl", model) + out = BaseNestedSampler.resume("test.pkl", model, output="test") assert out is pickle_out mock_pickle.assert_called_once() - mock_resume.assert_called_once_with(obj, model) + mock_resume.assert_called_once_with(obj, model, output="test") def test_get_result_dictionary(sampler): @@ -516,3 +525,12 @@ def test_update_history(sampler): BaseNestedSampler.update_history(sampler) assert sampler.history["likelihood_evaluations"] == [10, 20] assert sampler.history["sampling_time"] == [1, 2] + + +def test_update_output(sampler, tmp_path): + sampler.output = tmp_path / "orig" + sampler.resume_file = sampler.output / "resume.pkl" + new_output = tmp_path / "new" + BaseNestedSampler.update_output(sampler, new_output) + assert sampler.output == new_output + assert sampler.resume_file == str(new_output / "resume.pkl") diff --git a/tests/test_samplers/test_importance_nested_sampler/test_config.py b/tests/test_samplers/test_importance_nested_sampler/test_config.py index 3b4968df..be1b903d 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_config.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_config.py @@ -1,6 +1,7 @@ """Test configuration of INS""" -from unittest.mock import MagicMock +import os +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -81,3 +82,21 @@ def check_configuration_okay(ins): ins.nlive = 100 ins.min_remove = 1 assert INS.check_configuration(ins) is True + + +@pytest.mark.parametrize("has_proposal", [False, True]) +def test_update_output(ins, tmp_path, has_proposal): + output = tmp_path / "new" + if has_proposal: + 
ins.proposal = MagicMock() + ins.proposal.output = tmp_path / "orig" / "levels" + else: + ins.proposal = None + with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock: + INS.update_output(ins, output) + + mock.assert_called_once_with(output) + if has_proposal: + ins.proposal.update_output.assert_called_once_with( + os.path.join(output, "levels", "") + ) diff --git a/tests/test_samplers/test_nested_sampler/test_general_config.py b/tests/test_samplers/test_nested_sampler/test_general_config.py index 490e809c..b67712b8 100644 --- a/tests/test_samplers/test_nested_sampler/test_general_config.py +++ b/tests/test_samplers/test_nested_sampler/test_general_config.py @@ -87,6 +87,29 @@ def test_setup_output(sampler, tmpdir, plot): assert os.path.exists(os.path.join(path, "diagnostics")) +@pytest.mark.parametrize("plot", [False, True]) +@pytest.mark.parametrize("has_proposal", [False, True]) +def test_update_output(sampler, tmp_path, plot, has_proposal): + output = tmp_path / "new" + sampler.plot = plot + if has_proposal: + sampler._flow_proposal = MagicMock() + sampler._flow_proposal.output = tmp_path / "orig" / "proposal" + else: + sampler._flow_proposal = None + with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock: + NestedSampler.update_output(sampler, output) + + mock.assert_called_once_with(output) + if has_proposal: + sampler._flow_proposal.update_output.assert_called_once_with( + os.path.join(output, "proposal", "") + ) + + if plot: + assert os.path.exists(os.path.join(output, "diagnostics")) + + def test_configure_max_iteration(sampler): """Test to make sure the maximum iteration is set correctly""" NestedSampler.configure_max_iteration(sampler, 10) diff --git a/tests/test_sampling/test_standard_sampling.py b/tests/test_sampling/test_standard_sampling.py index 67dc792c..ad3bb474 100644 --- a/tests/test_sampling/test_standard_sampling.py +++ b/tests/test_sampling/test_standard_sampling.py @@ -5,6 +5,7 @@ import logging import os +import shutil from unittest.mock import patch import numpy as np @@ -441,6 +442,50 @@ def test_resume_reparameterisation_values(tmpdir, model, flow_config): ) +@pytest.mark.slow_integration_test +def test_sampling_resume_move_files(model, flow_config, tmp_path): + """ + Test resuming the sampler after moving the resume files. + """ + output = tmp_path / "resume" + fp = FlowSampler( + model, + output=output, + resume=True, + nlive=100, + plot=False, + flow_config=flow_config, + training_frequency=10, + maximum_uninformed=9, + checkpoint_on_iteration=True, + checkpoint_interval=5, + seed=1234, + max_iteration=11, + poolsize=10, + ) + fp.run() + + assert os.path.exists(os.path.join(output, "nested_sampler_resume.pkl")) + new_output = tmp_path / "new_resume" + + shutil.move(output, new_output) + assert not os.path.exists(output) + + fp = FlowSampler( + model, + output=new_output, + resume=True, + flow_config=flow_config, + ) + assert fp.ns.iteration == 11 + fp.ns.max_iteration = 21 + fp.run() + assert fp.ns.iteration == 21 + assert os.path.exists( + os.path.join(new_output, "nested_sampler_resume.pkl.old") + ) + + @pytest.mark.slow_integration_test def test_sampling_with_infinite_prior_bounds(tmpdir): """