Skip to content

Commit

Permalink
Merge pull request #429 from mj-will/add-outdir-to-resume
Browse files Browse the repository at this point in the history
ENH: add `output` to resume functions
  • Loading branch information
mj-will authored Sep 16, 2024
2 parents 0ca7181 + 49b05db commit 3d19e39
Show file tree
Hide file tree
Showing 12 changed files with 230 additions and 12 deletions.
6 changes: 6 additions & 0 deletions docs/further-details.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ priority over the resume file.
This will be passed to the :python:`resume_from_pickled_sampler` of the
corresponding sampler class.

.. note::

If the output directory has been moved, make sure to change the
:code`output` argument when calling :code:`FlowSampler`. The sampler
will then automatically update the relevant paths.


Checkpoint callbacks
--------------------
Expand Down
6 changes: 5 additions & 1 deletion nessai/flowsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class FlowSampler:
def __init__(
self,
model,
output=os.getcwd(),
output=None,
importance_nested_sampler=False,
resume=True,
resume_file="nested_sampler_resume.pkl",
Expand Down Expand Up @@ -145,6 +145,8 @@ def __init__(
logger.debug("Overriding `parallelise_prior` in the model")
model.parallelise_prior = parallelise_prior

if output is None:
output = os.getcwd()
self.output = os.path.join(output, "")
os.makedirs(self.output, exist_ok=True)
self.save_kwargs(kwargs)
Expand All @@ -160,6 +162,7 @@ def __init__(
SamplerClass,
resume_data=resume_data,
model=model,
output=self.output,
weights_path=weights_path,
flow_config=kwargs.get("flow_config"),
checkpoint_callback=kwargs.get("checkpoint_callback"),
Expand All @@ -168,6 +171,7 @@ def __init__(
self.ns = self._resume_from_file(
SamplerClass,
model=model,
output=self.output,
resume_file=resume_file,
weights_path=weights_path,
flow_config=kwargs.get("flow_config"),
Expand Down
19 changes: 19 additions & 0 deletions nessai/proposal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import datetime
import logging
import os
from abc import ABC, abstractmethod

import numpy as np
Expand Down Expand Up @@ -53,6 +54,24 @@ def initialise(self):
"""
self.initialised = True

def update_output(self, output: str) -> None:
"""
Update the output directory.
Only updates the output if the proposal has an output attribute.
Parameters
----------
output: str
Path to the output directory
"""
if hasattr(self, "output"):
logger.debug(f"Updating output directory to {output}")
self.output = output
os.makedirs(self.output, exist_ok=True)
else:
logger.debug("No output directory to update")

def evaluate_likelihoods(self):
"""Evaluate the likelihoods for the pool of live points."""
self.samples["logL"] = self.model.batch_evaluate_log_likelihood(
Expand Down
37 changes: 35 additions & 2 deletions nessai/samplers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ def configure_output(
self.output = output
self.resume_file = resume_file

def update_output(self, output: str) -> None:
"""Update the output directory and resume file.
Parameters
----------
output: str
Path to the output directory
"""
self.output = output
resume_file = os.path.split(self.resume_file)[1]
self.resume_file = os.path.join(output, resume_file)

def configure_random_seed(self, seed: Optional[int]):
"""Initialise the random seed.
Expand Down Expand Up @@ -316,6 +328,7 @@ def resume_from_pickled_sampler(
cls,
sampler: Any,
model: Model,
output: Optional[str] = None,
checkpoint_callback: Optional[Callable] = None,
):
"""Resume from pickle data.
Expand All @@ -326,6 +339,9 @@ def resume_from_pickled_sampler(
Pickle data
model : :obj:`nessai.model.Model`
User-defined model
output : Optional[str]
New output directory. If not specified, the output directory will
be the same as the previous run.
checkpoint_callback : Optional[Callable]
Checkpoint callback function. If not specified, the default method
will be used.
Expand All @@ -341,13 +357,25 @@ def resume_from_pickled_sampler(
model.likelihood_evaluation_time += datetime.timedelta(
seconds=sampler._previous_likelihood_evaluation_time
)
if output is not None and output != sampler.output:
logger.info(
f"Overwriting output from {sampler.output} to {output}"
)
os.makedirs(output, exist_ok=True)
sampler.update_output(output)
sampler.model = model
sampler.resumed = True
sampler.checkpoint_callback = checkpoint_callback
return sampler

@classmethod
def resume(cls, filename: str, model: Model, **kwargs):
def resume(
cls,
filename: str,
model: Model,
output: Optional[str] = None,
**kwargs,
):
"""Resumes the interrupted state from a checkpoint pickle file.
Parameters
Expand All @@ -356,6 +384,9 @@ def resume(cls, filename: str, model: Model, **kwargs):
Pickle file to resume from
model : :obj:`nessai.model.Model`
User-defined model
output : Optional[str]
New output directory. If not specified, the output directory will
be the same as the previous run.
Returns
-------
Expand All @@ -365,7 +396,9 @@ def resume(cls, filename: str, model: Model, **kwargs):
logger.info(f"Resuming {cls.__name__} from {filename}")
with open(filename, "rb") as f:
sampler = pickle.load(f)
return cls.resume_from_pickled_sampler(sampler, model, **kwargs)
return cls.resume_from_pickled_sampler(
sampler, model, output=output, **kwargs
)

@abstractmethod
def nested_sampling_loop(self):
Expand Down
7 changes: 7 additions & 0 deletions nessai/samplers/importancesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,13 @@ def get_proposal(self, subdir: str = "levels", **kwargs):
proposal = ImportanceFlowProposal(self.model, output, **kwargs)
return proposal

def update_output(self, output: str) -> None:
super().update_output(output)
if self.proposal is not None:
# ImportanceFlowProposal uses a subdirectory
subdir = os.path.basename(os.path.normpath(self.proposal.output))
self.proposal.update_output(os.path.join(output, subdir, ""))

def configure_iterations(
self,
min_iteration: Optional[int] = None,
Expand Down
20 changes: 20 additions & 0 deletions nessai/samplers/nestedsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,26 @@ def configure_output(
if self.plot:
os.makedirs(os.path.join(output, "diagnostics"), exist_ok=True)

def update_output(self, output: str) -> None:
"""Update the output directory.
Also creates a "diagnostics" directory for plotting.
Parameters
----------
output : str
Path to the output directory.
"""
super().update_output(output)
if self.plot:
os.makedirs(os.path.join(output, "diagnostics"), exist_ok=True)
if self._flow_proposal is not None:
# FlowProposal uses a subdirectory
subdir = os.path.basename(
os.path.normpath(self._flow_proposal.output)
)
self._flow_proposal.update_output(os.path.join(output, subdir, ""))

def configure_flow_reset(
self, reset_weights, reset_permutations, reset_flow
):
Expand Down
18 changes: 14 additions & 4 deletions tests/test_flowsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def test_check_resume_files_do_not_exist(flow_sampler, tmp_path):

@pytest.mark.parametrize("resume", [False, True])
@pytest.mark.parametrize("use_ins", [False, True])
def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins):
@pytest.mark.parametrize("specify_output", [False, True])
def test_init_no_resume_file(
flow_sampler, tmp_path, resume, use_ins, specify_output
):
"""Test the init method when there is no run to resume from"""

integration_model = MagicMock()
Expand All @@ -139,11 +142,12 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins):
f"nessai.flowsampler.{sampler_class}", return_value="ns"
) as mock,
patch("nessai.flowsampler.configure_threads") as mock_threads,
patch("os.getcwd", return_value=output) as mock_getcwd,
):
FlowSampler.__init__(
flow_sampler,
integration_model,
output=output,
output=output if specify_output else None,
resume=resume,
exit_code=exit_code,
pytorch_threads=pytorch_threads,
Expand All @@ -169,10 +173,13 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins):

flow_sampler.save_kwargs.assert_called_once_with(kwargs)

if not specify_output:
mock_getcwd.assert_called_once()


def test_resume_from_resume_data(flow_sampler, model, tmp_path):
"""Test for resume from data"""
output = tmp_path / "test"
output = str(tmp_path / "test")
data = object()
flow_sampler.check_resume = MagicMock(return_value=True)
flow_sampler._resume_from_data = MagicMock()
Expand All @@ -183,6 +190,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path):
NestedSampler,
resume_data=data,
model=model,
output=os.path.join(output, ""),
weights_path=None,
flow_config=None,
checkpoint_callback=None,
Expand All @@ -191,7 +199,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path):

def test_resume_from_resume_file(flow_sampler, model, tmp_path):
"""Test for resume from data"""
output = tmp_path / "test"
output = str(tmp_path / "test")
resume_file = "resume.pkl"
flow_sampler.check_resume = MagicMock(return_value=True)
flow_sampler._resume_from_file = MagicMock()
Expand All @@ -207,6 +215,7 @@ def test_resume_from_resume_file(flow_sampler, model, tmp_path):
NestedSampler,
resume_file=resume_file,
model=model,
output=os.path.join(output, ""),
weights_path=None,
flow_config=None,
checkpoint_callback=None,
Expand Down Expand Up @@ -381,6 +390,7 @@ def test_init_resume(tmp_path, test_old, error):
mock_resume.assert_called_with(
expected_rf,
integration_model,
output=os.path.join(output, ""),
flow_config=flow_config,
weights_path=weights_file,
checkpoint_callback=None,
Expand Down
14 changes: 14 additions & 0 deletions tests/test_proposal/test_base_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import logging
import os
import pickle
from unittest.mock import MagicMock, Mock, create_autospec

Expand Down Expand Up @@ -62,6 +63,19 @@ def test_initialise(proposal):
assert proposal.initialised is True


def test_update_output(proposal, tmp_path):
tmp_path = tmp_path / "test"
proposal.output = tmp_path / "orig"
Proposal.update_output(proposal, tmp_path)
assert proposal.output == tmp_path
assert os.path.exists(tmp_path)


def test_update_output_no_output(proposal):
Proposal.update_output(proposal, "test")
assert not hasattr(proposal, "output")


def test_evaluate_likelihoods(proposal):
"""Assert the correct method is called"""
samples = numpy_array_to_live_points(np.array([[1], [2]]), ["x"])
Expand Down
26 changes: 22 additions & 4 deletions tests/test_samplers/test_base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,22 +428,31 @@ def test_close_pool(sampler):
sampler.model.close_pool.assert_called_once_with(code=2)


def test_resume_from_pickled_sampler(model):
@pytest.mark.parametrize("output", [None, "orig", "new"])
def test_resume_from_pickled_sampler(model, output):
"""Test the resume from pickled sampler method"""
obj = MagicMock()
obj.model = None
obj.output = "orig"
obj._previous_likelihood_evaluations = 3
obj._previous_likelihood_evaluation_time = 4.0

model.likelihood_evaluations = 1
model.likelihood_evaluation_time = datetime.timedelta(seconds=2)

out = BaseNestedSampler.resume_from_pickled_sampler(obj, model)
out = BaseNestedSampler.resume_from_pickled_sampler(
obj, model, output=output
)

assert out.model == model
assert out.model.likelihood_evaluations == 4
assert out.model.likelihood_evaluation_time.total_seconds() == 6

if output == "new":
obj.update_output.assert_called_once_with("new")
else:
obj.update_output.assert_not_called()


def test_resume(model):
"""Test the resume method"""
Expand All @@ -458,11 +467,11 @@ def test_resume(model):
return_value=pickle_out,
) as mock_resume,
):
out = BaseNestedSampler.resume("test.pkl", model)
out = BaseNestedSampler.resume("test.pkl", model, output="test")

assert out is pickle_out
mock_pickle.assert_called_once()
mock_resume.assert_called_once_with(obj, model)
mock_resume.assert_called_once_with(obj, model, output="test")


def test_get_result_dictionary(sampler):
Expand Down Expand Up @@ -516,3 +525,12 @@ def test_update_history(sampler):
BaseNestedSampler.update_history(sampler)
assert sampler.history["likelihood_evaluations"] == [10, 20]
assert sampler.history["sampling_time"] == [1, 2]


def test_update_output(sampler, tmp_path):
sampler.output = tmp_path / "orig"
sampler.resume_file = sampler.output / "resume.pkl"
new_output = tmp_path / "new"
BaseNestedSampler.update_output(sampler, new_output)
assert sampler.output == new_output
assert sampler.resume_file == str(new_output / "resume.pkl")
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test configuration of INS"""

from unittest.mock import MagicMock
import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
Expand Down Expand Up @@ -81,3 +82,21 @@ def check_configuration_okay(ins):
ins.nlive = 100
ins.min_remove = 1
assert INS.check_configuration(ins) is True


@pytest.mark.parametrize("has_proposal", [False, True])
def test_update_output(ins, tmp_path, has_proposal):
output = tmp_path / "new"
if has_proposal:
ins.proposal = MagicMock()
ins.proposal.output = tmp_path / "orig" / "levels"
else:
ins.proposal = None
with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock:
INS.update_output(ins, output)

mock.assert_called_once_with(output)
if has_proposal:
ins.proposal.update_output.assert_called_once_with(
os.path.join(output, "levels", "")
)
Loading

0 comments on commit 3d19e39

Please sign in to comment.