From 3b3d41f22316d5f1d3e1199938882a47c441a2f6 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 11:19:24 +0100 Subject: [PATCH 1/9] ENH: add `update_output` method to proposal classes --- nessai/proposal/base.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/nessai/proposal/base.py b/nessai/proposal/base.py index 687c08e2..3027c6f9 100644 --- a/nessai/proposal/base.py +++ b/nessai/proposal/base.py @@ -5,6 +5,7 @@ import datetime import logging +import os from abc import ABC, abstractmethod import numpy as np @@ -53,6 +54,24 @@ def initialise(self): """ self.initialised = True + def update_output(self, output: str) -> None: + """ + Update the output directory. + + Only updates the output if the proposal has an output attribute. + + Parameters + ---------- + output: str + Path to the output directory + """ + if hasattr(self, "output"): + logger.debug(f"Updating output directory to {output}") + self.output = output + os.makedirs(self.output, exist_ok=True) + else: + logger.debug("No output directory to update") + def evaluate_likelihoods(self): """Evaluate the likelihoods for the pool of live points.""" self.samples["logL"] = self.model.batch_evaluate_log_likelihood( From 1b6d992c05e58710b823e864be0fee22eb5b2349 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 11:20:17 +0100 Subject: [PATCH 2/9] ENH: add output directory to resume functions --- nessai/samplers/base.py | 38 ++++++++++++++++++++++++++-- nessai/samplers/importancesampler.py | 7 +++++ nessai/samplers/nestedsampler.py | 20 +++++++++++++++ 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/nessai/samplers/base.py b/nessai/samplers/base.py index 8e51156b..591385c0 100644 --- a/nessai/samplers/base.py +++ b/nessai/samplers/base.py @@ -168,6 +168,18 @@ def configure_output( self.output = output self.resume_file = resume_file + def update_output(self, output: str) -> None: + """Update the output directory and resume file. + + Parameters + ---------- + output: str + Path to the output directory + """ + self.output = output + resume_file = os.path.split(self.resume_file)[1] + self.resume_file = os.path.join(output, resume_file) + def configure_random_seed(self, seed: Optional[int]): """Initialise the random seed. @@ -316,6 +328,7 @@ def resume_from_pickled_sampler( cls, sampler: Any, model: Model, + output: Optional[str] = None, checkpoint_callback: Optional[Callable] = None, ): """Resume from pickle data. @@ -326,6 +339,9 @@ def resume_from_pickled_sampler( Pickle data model : :obj:`nessai.model.Model` User-defined model + output : Optional[str] + New output directory. If not specified, the output directory will + be the same as the previous run. checkpoint_callback : Optional[Callable] Checkpoint callback function. If not specified, the default method will be used. @@ -341,13 +357,26 @@ def resume_from_pickled_sampler( model.likelihood_evaluation_time += datetime.timedelta( seconds=sampler._previous_likelihood_evaluation_time ) + if output is not None: + if output != sampler.output: + logger.info( + f"Overwriting output from {sampler.output} to {output}" + ) + os.makedirs(output, exist_ok=True) + sampler.update_output(output) sampler.model = model sampler.resumed = True sampler.checkpoint_callback = checkpoint_callback return sampler @classmethod - def resume(cls, filename: str, model: Model, **kwargs): + def resume( + cls, + filename: str, + model: Model, + output: Optional[str] = None, + **kwargs, + ): """Resumes the interrupted state from a checkpoint pickle file. Parameters @@ -356,6 +385,9 @@ def resume(cls, filename: str, model: Model, **kwargs): Pickle file to resume from model : :obj:`nessai.model.Model` User-defined model + output : Optional[str] + New output directory. If not specified, the output directory will + be the same as the previous run. Returns ------- @@ -365,7 +397,9 @@ def resume(cls, filename: str, model: Model, **kwargs): logger.info(f"Resuming {cls.__name__} from {filename}") with open(filename, "rb") as f: sampler = pickle.load(f) - return cls.resume_from_pickled_sampler(sampler, model, **kwargs) + return cls.resume_from_pickled_sampler( + sampler, model, output=output, **kwargs + ) @abstractmethod def nested_sampling_loop(self): diff --git a/nessai/samplers/importancesampler.py b/nessai/samplers/importancesampler.py index fce4e52a..10e337da 100644 --- a/nessai/samplers/importancesampler.py +++ b/nessai/samplers/importancesampler.py @@ -718,6 +718,13 @@ def get_proposal(self, subdir: str = "levels", **kwargs): proposal = ImportanceFlowProposal(self.model, output, **kwargs) return proposal + def update_output(self, output: str) -> None: + super().update_output(output) + if self.proposal is not None: + # ImportanceFlowProposal uses a subdirectory + subdir = os.path.basename(os.path.normpath(self.proposal.output)) + self.proposal.update_output(os.path.join(self.output, subdir, "")) + def configure_iterations( self, min_iteration: Optional[int] = None, diff --git a/nessai/samplers/nestedsampler.py b/nessai/samplers/nestedsampler.py index 916be42c..d5476e5c 100644 --- a/nessai/samplers/nestedsampler.py +++ b/nessai/samplers/nestedsampler.py @@ -499,6 +499,26 @@ def configure_output( if self.plot: os.makedirs(os.path.join(output, "diagnostics"), exist_ok=True) + def update_output(self, output: str) -> None: + """Update the output directory. + + Also creates a "diagnostics" directory for plotting. + + Parameters + ---------- + output : str + Path to the output directory. + """ + super().update_output(output) + if self.plot: + os.makedirs(os.path.join(output, "diagnostics"), exist_ok=True) + if self._flow_proposal is not None: + # FlowProposal uses a subdirectory + subdir = os.path.basename( + os.path.normpath(self._flow_proposal.output) + ) + self._flow_proposal.update_output(os.path.join(output, subdir, "")) + def configure_flow_reset( self, reset_weights, reset_permutations, reset_flow ): From a7cd529e5937a27a3ebdcc2075109c8896405d9f Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 11:20:36 +0100 Subject: [PATCH 3/9] ENH: specify output directory when resuming --- nessai/flowsampler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nessai/flowsampler.py b/nessai/flowsampler.py index a6446f2d..fa934c52 100644 --- a/nessai/flowsampler.py +++ b/nessai/flowsampler.py @@ -85,7 +85,7 @@ class FlowSampler: def __init__( self, model, - output=os.getcwd(), + output=None, importance_nested_sampler=False, resume=True, resume_file="nested_sampler_resume.pkl", @@ -145,6 +145,8 @@ def __init__( logger.debug("Overriding `parallelise_prior` in the model") model.parallelise_prior = parallelise_prior + if output is None: + output = os.getcwd() self.output = os.path.join(output, "") os.makedirs(self.output, exist_ok=True) self.save_kwargs(kwargs) @@ -160,6 +162,7 @@ def __init__( SamplerClass, resume_data=resume_data, model=model, + output=self.output, weights_path=weights_path, flow_config=kwargs.get("flow_config"), checkpoint_callback=kwargs.get("checkpoint_callback"), @@ -168,6 +171,7 @@ def __init__( self.ns = self._resume_from_file( SamplerClass, model=model, + output=self.output, resume_file=resume_file, weights_path=weights_path, flow_config=kwargs.get("flow_config"), From a4b3af0c36d0104013ce2e28f164a640b5cb902e Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 11:24:55 +0100 Subject: [PATCH 4/9] DOC: add note about resuming from a different directory --- docs/further-details.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/further-details.rst b/docs/further-details.rst index ce56b395..353c30d2 100644 --- a/docs/further-details.rst +++ b/docs/further-details.rst @@ -112,6 +112,12 @@ priority over the resume file. This will be passed to the :python:`resume_from_pickled_sampler` of the corresponding sampler class. +.. note:: + + If the output directory has been moved, make sure to change the + :code`output` argument when calling :code:`FlowSampler`. The sampler + will then automatically update the relevant paths. + Checkpoint callbacks -------------------- From 65da4fe1df59ac64190c35e4a0fc57b99caac21a Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 12:28:10 +0100 Subject: [PATCH 5/9] STY: simplify if statement --- nessai/samplers/base.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/nessai/samplers/base.py b/nessai/samplers/base.py index 591385c0..4d9ab96d 100644 --- a/nessai/samplers/base.py +++ b/nessai/samplers/base.py @@ -357,13 +357,12 @@ def resume_from_pickled_sampler( model.likelihood_evaluation_time += datetime.timedelta( seconds=sampler._previous_likelihood_evaluation_time ) - if output is not None: - if output != sampler.output: - logger.info( - f"Overwriting output from {sampler.output} to {output}" - ) - os.makedirs(output, exist_ok=True) - sampler.update_output(output) + if output is not None and output != sampler.output: + logger.info( + f"Overwriting output from {sampler.output} to {output}" + ) + os.makedirs(output, exist_ok=True) + sampler.update_output(output) sampler.model = model sampler.resumed = True sampler.checkpoint_callback = checkpoint_callback From 3cf460aea7e56e321ab2d17aa51557f5b5932529 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 12:28:23 +0100 Subject: [PATCH 6/9] MAINT: use local variable --- nessai/samplers/importancesampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nessai/samplers/importancesampler.py b/nessai/samplers/importancesampler.py index 10e337da..9a6b1e77 100644 --- a/nessai/samplers/importancesampler.py +++ b/nessai/samplers/importancesampler.py @@ -723,7 +723,7 @@ def update_output(self, output: str) -> None: if self.proposal is not None: # ImportanceFlowProposal uses a subdirectory subdir = os.path.basename(os.path.normpath(self.proposal.output)) - self.proposal.update_output(os.path.join(self.output, subdir, "")) + self.proposal.update_output(os.path.join(output, subdir, "")) def configure_iterations( self, From 7d66feec0e8f18297d242fc1a66ea6b5c0de0421 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 12:28:52 +0100 Subject: [PATCH 7/9] TST: add tests for `update_output` and resume --- tests/test_flowsampler.py | 7 +++-- tests/test_proposal/test_base_proposal.py | 14 ++++++++++ tests/test_samplers/test_base_sampler.py | 26 ++++++++++++++++--- .../test_config.py | 21 ++++++++++++++- .../test_general_config.py | 23 ++++++++++++++++ 5 files changed, 84 insertions(+), 7 deletions(-) diff --git a/tests/test_flowsampler.py b/tests/test_flowsampler.py index 223f268b..89d691fb 100644 --- a/tests/test_flowsampler.py +++ b/tests/test_flowsampler.py @@ -172,7 +172,7 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins): def test_resume_from_resume_data(flow_sampler, model, tmp_path): """Test for resume from data""" - output = tmp_path / "test" + output = str(tmp_path / "test") data = object() flow_sampler.check_resume = MagicMock(return_value=True) flow_sampler._resume_from_data = MagicMock() @@ -183,6 +183,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path): NestedSampler, resume_data=data, model=model, + output=os.path.join(output, ""), weights_path=None, flow_config=None, checkpoint_callback=None, @@ -191,7 +192,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path): def test_resume_from_resume_file(flow_sampler, model, tmp_path): """Test for resume from data""" - output = tmp_path / "test" + output = str(tmp_path / "test") resume_file = "resume.pkl" flow_sampler.check_resume = MagicMock(return_value=True) flow_sampler._resume_from_file = MagicMock() @@ -207,6 +208,7 @@ def test_resume_from_resume_file(flow_sampler, model, tmp_path): NestedSampler, resume_file=resume_file, model=model, + output=os.path.join(output, ""), weights_path=None, flow_config=None, checkpoint_callback=None, @@ -381,6 +383,7 @@ def test_init_resume(tmp_path, test_old, error): mock_resume.assert_called_with( expected_rf, integration_model, + output=os.path.join(output, ""), flow_config=flow_config, weights_path=weights_file, checkpoint_callback=None, diff --git a/tests/test_proposal/test_base_proposal.py b/tests/test_proposal/test_base_proposal.py index 9c593043..82545e28 100644 --- a/tests/test_proposal/test_base_proposal.py +++ b/tests/test_proposal/test_base_proposal.py @@ -4,6 +4,7 @@ """ import logging +import os import pickle from unittest.mock import MagicMock, Mock, create_autospec @@ -62,6 +63,19 @@ def test_initialise(proposal): assert proposal.initialised is True +def test_update_output(proposal, tmp_path): + tmp_path = tmp_path / "test" + proposal.output = tmp_path / "orig" + Proposal.update_output(proposal, tmp_path) + assert proposal.output == tmp_path + assert os.path.exists(tmp_path) + + +def test_update_output_no_output(proposal): + Proposal.update_output(proposal, "test") + assert not hasattr(proposal, "output") + + def test_evaluate_likelihoods(proposal): """Assert the correct method is called""" samples = numpy_array_to_live_points(np.array([[1], [2]]), ["x"]) diff --git a/tests/test_samplers/test_base_sampler.py b/tests/test_samplers/test_base_sampler.py index 6ea4f196..e6267d49 100644 --- a/tests/test_samplers/test_base_sampler.py +++ b/tests/test_samplers/test_base_sampler.py @@ -428,22 +428,31 @@ def test_close_pool(sampler): sampler.model.close_pool.assert_called_once_with(code=2) -def test_resume_from_pickled_sampler(model): +@pytest.mark.parametrize("output", [None, "orig", "new"]) +def test_resume_from_pickled_sampler(model, output): """Test the resume from pickled sampler method""" obj = MagicMock() obj.model = None + obj.output = "orig" obj._previous_likelihood_evaluations = 3 obj._previous_likelihood_evaluation_time = 4.0 model.likelihood_evaluations = 1 model.likelihood_evaluation_time = datetime.timedelta(seconds=2) - out = BaseNestedSampler.resume_from_pickled_sampler(obj, model) + out = BaseNestedSampler.resume_from_pickled_sampler( + obj, model, output=output + ) assert out.model == model assert out.model.likelihood_evaluations == 4 assert out.model.likelihood_evaluation_time.total_seconds() == 6 + if output == "new": + obj.update_output.assert_called_once_with("new") + else: + obj.update_output.assert_not_called() + def test_resume(model): """Test the resume method""" @@ -458,11 +467,11 @@ def test_resume(model): return_value=pickle_out, ) as mock_resume, ): - out = BaseNestedSampler.resume("test.pkl", model) + out = BaseNestedSampler.resume("test.pkl", model, output="test") assert out is pickle_out mock_pickle.assert_called_once() - mock_resume.assert_called_once_with(obj, model) + mock_resume.assert_called_once_with(obj, model, output="test") def test_get_result_dictionary(sampler): @@ -516,3 +525,12 @@ def test_update_history(sampler): BaseNestedSampler.update_history(sampler) assert sampler.history["likelihood_evaluations"] == [10, 20] assert sampler.history["sampling_time"] == [1, 2] + + +def test_update_output(sampler, tmp_path): + sampler.output = tmp_path / "orig" + sampler.resume_file = sampler.output / "resume.pkl" + new_output = tmp_path / "new" + BaseNestedSampler.update_output(sampler, new_output) + assert sampler.output == new_output + assert sampler.resume_file == str(new_output / "resume.pkl") diff --git a/tests/test_samplers/test_importance_nested_sampler/test_config.py b/tests/test_samplers/test_importance_nested_sampler/test_config.py index 3b4968df..be1b903d 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_config.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_config.py @@ -1,6 +1,7 @@ """Test configuration of INS""" -from unittest.mock import MagicMock +import os +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -81,3 +82,21 @@ def check_configuration_okay(ins): ins.nlive = 100 ins.min_remove = 1 assert INS.check_configuration(ins) is True + + +@pytest.mark.parametrize("has_proposal", [False, True]) +def test_update_output(ins, tmp_path, has_proposal): + output = tmp_path / "new" + if has_proposal: + ins.proposal = MagicMock() + ins.proposal.output = tmp_path / "orig" / "levels" + else: + ins.proposal = None + with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock: + INS.update_output(ins, output) + + mock.assert_called_once_with(output) + if has_proposal: + ins.proposal.update_output.assert_called_once_with( + os.path.join(output, "levels", "") + ) diff --git a/tests/test_samplers/test_nested_sampler/test_general_config.py b/tests/test_samplers/test_nested_sampler/test_general_config.py index 490e809c..b67712b8 100644 --- a/tests/test_samplers/test_nested_sampler/test_general_config.py +++ b/tests/test_samplers/test_nested_sampler/test_general_config.py @@ -87,6 +87,29 @@ def test_setup_output(sampler, tmpdir, plot): assert os.path.exists(os.path.join(path, "diagnostics")) +@pytest.mark.parametrize("plot", [False, True]) +@pytest.mark.parametrize("has_proposal", [False, True]) +def test_update_output(sampler, tmp_path, plot, has_proposal): + output = tmp_path / "new" + sampler.plot = plot + if has_proposal: + sampler._flow_proposal = MagicMock() + sampler._flow_proposal.output = tmp_path / "orig" / "proposal" + else: + sampler._flow_proposal = None + with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock: + NestedSampler.update_output(sampler, output) + + mock.assert_called_once_with(output) + if has_proposal: + sampler._flow_proposal.update_output.assert_called_once_with( + os.path.join(output, "proposal", "") + ) + + if plot: + assert os.path.exists(os.path.join(output, "diagnostics")) + + def test_configure_max_iteration(sampler): """Test to make sure the maximum iteration is set correctly""" NestedSampler.configure_max_iteration(sampler, 10) From 941e6786db690481f7cb9d4a51fe557cc3700fd1 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 13:08:18 +0100 Subject: [PATCH 8/9] TST: add a test for `output=None` --- tests/test_flowsampler.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_flowsampler.py b/tests/test_flowsampler.py index 89d691fb..72513b07 100644 --- a/tests/test_flowsampler.py +++ b/tests/test_flowsampler.py @@ -113,7 +113,10 @@ def test_check_resume_files_do_not_exist(flow_sampler, tmp_path): @pytest.mark.parametrize("resume", [False, True]) @pytest.mark.parametrize("use_ins", [False, True]) -def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins): +@pytest.mark.parametrize("specify_output", [False, True]) +def test_init_no_resume_file( + flow_sampler, tmp_path, resume, use_ins, specify_output +): """Test the init method when there is no run to resume from""" integration_model = MagicMock() @@ -139,11 +142,12 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins): f"nessai.flowsampler.{sampler_class}", return_value="ns" ) as mock, patch("nessai.flowsampler.configure_threads") as mock_threads, + patch("os.getcwd", return_value=output) as mock_getcwd, ): FlowSampler.__init__( flow_sampler, integration_model, - output=output, + output=output if specify_output else None, resume=resume, exit_code=exit_code, pytorch_threads=pytorch_threads, @@ -169,6 +173,9 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins): flow_sampler.save_kwargs.assert_called_once_with(kwargs) + if not specify_output: + mock_getcwd.assert_called_once() + def test_resume_from_resume_data(flow_sampler, model, tmp_path): """Test for resume from data""" From 49b05db99ffe18f6b833da0127f678e6cfb2e640 Mon Sep 17 00:00:00 2001 From: mj-will Date: Mon, 16 Sep 2024 13:38:26 +0100 Subject: [PATCH 9/9] TST: add integration test for moving resume files --- tests/test_sampling/test_standard_sampling.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_sampling/test_standard_sampling.py b/tests/test_sampling/test_standard_sampling.py index 67dc792c..ad3bb474 100644 --- a/tests/test_sampling/test_standard_sampling.py +++ b/tests/test_sampling/test_standard_sampling.py @@ -5,6 +5,7 @@ import logging import os +import shutil from unittest.mock import patch import numpy as np @@ -441,6 +442,50 @@ def test_resume_reparameterisation_values(tmpdir, model, flow_config): ) +@pytest.mark.slow_integration_test +def test_sampling_resume_move_files(model, flow_config, tmp_path): + """ + Test resuming the sampler after moving the resume files. + """ + output = tmp_path / "resume" + fp = FlowSampler( + model, + output=output, + resume=True, + nlive=100, + plot=False, + flow_config=flow_config, + training_frequency=10, + maximum_uninformed=9, + checkpoint_on_iteration=True, + checkpoint_interval=5, + seed=1234, + max_iteration=11, + poolsize=10, + ) + fp.run() + + assert os.path.exists(os.path.join(output, "nested_sampler_resume.pkl")) + new_output = tmp_path / "new_resume" + + shutil.move(output, new_output) + assert not os.path.exists(output) + + fp = FlowSampler( + model, + output=new_output, + resume=True, + flow_config=flow_config, + ) + assert fp.ns.iteration == 11 + fp.ns.max_iteration = 21 + fp.run() + assert fp.ns.iteration == 21 + assert os.path.exists( + os.path.join(new_output, "nested_sampler_resume.pkl.old") + ) + + @pytest.mark.slow_integration_test def test_sampling_with_infinite_prior_bounds(tmpdir): """