Skip to content

Commit

Permalink
TST: add tests for update_output and resume
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Sep 16, 2024
1 parent 3cf460a commit 7d66fee
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 7 deletions.
7 changes: 5 additions & 2 deletions tests/test_flowsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_init_no_resume_file(flow_sampler, tmp_path, resume, use_ins):

def test_resume_from_resume_data(flow_sampler, model, tmp_path):
"""Test for resume from data"""
output = tmp_path / "test"
output = str(tmp_path / "test")
data = object()
flow_sampler.check_resume = MagicMock(return_value=True)
flow_sampler._resume_from_data = MagicMock()
Expand All @@ -183,6 +183,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path):
NestedSampler,
resume_data=data,
model=model,
output=os.path.join(output, ""),
weights_path=None,
flow_config=None,
checkpoint_callback=None,
Expand All @@ -191,7 +192,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path):

def test_resume_from_resume_file(flow_sampler, model, tmp_path):
"""Test for resume from data"""
output = tmp_path / "test"
output = str(tmp_path / "test")
resume_file = "resume.pkl"
flow_sampler.check_resume = MagicMock(return_value=True)
flow_sampler._resume_from_file = MagicMock()
Expand All @@ -207,6 +208,7 @@ def test_resume_from_resume_file(flow_sampler, model, tmp_path):
NestedSampler,
resume_file=resume_file,
model=model,
output=os.path.join(output, ""),
weights_path=None,
flow_config=None,
checkpoint_callback=None,
Expand Down Expand Up @@ -381,6 +383,7 @@ def test_init_resume(tmp_path, test_old, error):
mock_resume.assert_called_with(
expected_rf,
integration_model,
output=os.path.join(output, ""),
flow_config=flow_config,
weights_path=weights_file,
checkpoint_callback=None,
Expand Down
14 changes: 14 additions & 0 deletions tests/test_proposal/test_base_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import logging
import os
import pickle
from unittest.mock import MagicMock, Mock, create_autospec

Expand Down Expand Up @@ -62,6 +63,19 @@ def test_initialise(proposal):
assert proposal.initialised is True


def test_update_output(proposal, tmp_path):
tmp_path = tmp_path / "test"
proposal.output = tmp_path / "orig"
Proposal.update_output(proposal, tmp_path)
assert proposal.output == tmp_path
assert os.path.exists(tmp_path)


def test_update_output_no_output(proposal):
Proposal.update_output(proposal, "test")
assert not hasattr(proposal, "output")


def test_evaluate_likelihoods(proposal):
"""Assert the correct method is called"""
samples = numpy_array_to_live_points(np.array([[1], [2]]), ["x"])
Expand Down
26 changes: 22 additions & 4 deletions tests/test_samplers/test_base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,22 +428,31 @@ def test_close_pool(sampler):
sampler.model.close_pool.assert_called_once_with(code=2)


def test_resume_from_pickled_sampler(model):
@pytest.mark.parametrize("output", [None, "orig", "new"])
def test_resume_from_pickled_sampler(model, output):
"""Test the resume from pickled sampler method"""
obj = MagicMock()
obj.model = None
obj.output = "orig"
obj._previous_likelihood_evaluations = 3
obj._previous_likelihood_evaluation_time = 4.0

model.likelihood_evaluations = 1
model.likelihood_evaluation_time = datetime.timedelta(seconds=2)

out = BaseNestedSampler.resume_from_pickled_sampler(obj, model)
out = BaseNestedSampler.resume_from_pickled_sampler(
obj, model, output=output
)

assert out.model == model
assert out.model.likelihood_evaluations == 4
assert out.model.likelihood_evaluation_time.total_seconds() == 6

if output == "new":
obj.update_output.assert_called_once_with("new")
else:
obj.update_output.assert_not_called()


def test_resume(model):
"""Test the resume method"""
Expand All @@ -458,11 +467,11 @@ def test_resume(model):
return_value=pickle_out,
) as mock_resume,
):
out = BaseNestedSampler.resume("test.pkl", model)
out = BaseNestedSampler.resume("test.pkl", model, output="test")

assert out is pickle_out
mock_pickle.assert_called_once()
mock_resume.assert_called_once_with(obj, model)
mock_resume.assert_called_once_with(obj, model, output="test")


def test_get_result_dictionary(sampler):
Expand Down Expand Up @@ -516,3 +525,12 @@ def test_update_history(sampler):
BaseNestedSampler.update_history(sampler)
assert sampler.history["likelihood_evaluations"] == [10, 20]
assert sampler.history["sampling_time"] == [1, 2]


def test_update_output(sampler, tmp_path):
sampler.output = tmp_path / "orig"
sampler.resume_file = sampler.output / "resume.pkl"
new_output = tmp_path / "new"
BaseNestedSampler.update_output(sampler, new_output)
assert sampler.output == new_output
assert sampler.resume_file == str(new_output / "resume.pkl")
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test configuration of INS"""

from unittest.mock import MagicMock
import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
Expand Down Expand Up @@ -81,3 +82,21 @@ def check_configuration_okay(ins):
ins.nlive = 100
ins.min_remove = 1
assert INS.check_configuration(ins) is True


@pytest.mark.parametrize("has_proposal", [False, True])
def test_update_output(ins, tmp_path, has_proposal):
output = tmp_path / "new"
if has_proposal:
ins.proposal = MagicMock()
ins.proposal.output = tmp_path / "orig" / "levels"
else:
ins.proposal = None
with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock:
INS.update_output(ins, output)

mock.assert_called_once_with(output)
if has_proposal:
ins.proposal.update_output.assert_called_once_with(
os.path.join(output, "levels", "")
)
23 changes: 23 additions & 0 deletions tests/test_samplers/test_nested_sampler/test_general_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,29 @@ def test_setup_output(sampler, tmpdir, plot):
assert os.path.exists(os.path.join(path, "diagnostics"))


@pytest.mark.parametrize("plot", [False, True])
@pytest.mark.parametrize("has_proposal", [False, True])
def test_update_output(sampler, tmp_path, plot, has_proposal):
output = tmp_path / "new"
sampler.plot = plot
if has_proposal:
sampler._flow_proposal = MagicMock()
sampler._flow_proposal.output = tmp_path / "orig" / "proposal"
else:
sampler._flow_proposal = None
with patch("nessai.samplers.base.BaseNestedSampler.update_output") as mock:
NestedSampler.update_output(sampler, output)

mock.assert_called_once_with(output)
if has_proposal:
sampler._flow_proposal.update_output.assert_called_once_with(
os.path.join(output, "proposal", "")
)

if plot:
assert os.path.exists(os.path.join(output, "diagnostics"))


def test_configure_max_iteration(sampler):
"""Test to make sure the maximum iteration is set correctly"""
NestedSampler.configure_max_iteration(sampler, 10)
Expand Down

0 comments on commit 7d66fee

Please sign in to comment.