Skip to content

Commit

Permalink
Merge pull request #442 from mj-will/state-plot-improvements
Browse files Browse the repository at this point in the history
ENH: add delta logL to state plot and sanitse inputs
  • Loading branch information
mj-will authored Dec 12, 2024
2 parents f44294b + 505aa94 commit 5ef5885
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/nessai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ class PlottingConfig(_BaseConfig):
Based on the default DPI in matplotlib of 100, so this will give a maximum
size of 5000 pixels.
"""
clip_min: float = -1e10
"""Minimum value to clip data to for plotting.
This is used to avoid issues with plots where the dynamic range is too
large for matplotlib to handle.
"""


@dataclass
Expand Down
30 changes: 30 additions & 0 deletions src/nessai/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@ def wrapper(*args, **kwargs):
return decorator


def sanitise_array(
a: np.ndarray,
/,
a_min: Optional[float] = None,
a_max: Optional[float] = None,
):
"""Sanitise an array for plotting.
If :code:`x_min` is not specified, it is set to the value in
:code:`nessai.config.plotting.clip_min`.
Parameters
----------
x : array_like
Array to sanitise.
x_min : float, optional
Minimum value to clip the data to.
xmax : float, optional
Maximum value to clip the data to.
Returns
-------
np.ndarray
Sanitised array.
"""
if a_min is None:
a_min = config.plotting.clip_min
return np.clip(a, a_min, a_max)


@nessai_style()
def plot_live_points(
live_points, filename=None, bounds=None, c=None, **kwargs
Expand Down
36 changes: 32 additions & 4 deletions src/nessai/samplers/nestedsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .. import config
from ..evidence import _NSIntegralState
from ..livepoint import empty_structured_array
from ..plot import nessai_style, plot_indices, plot_trace
from ..plot import nessai_style, plot_indices, plot_trace, sanitise_array
from ..proposal.utils import (
check_proposal_kwargs,
get_flow_proposal_class,
Expand Down Expand Up @@ -1004,11 +1004,39 @@ def plot_state(self, filename=None):
for a in ax:
a.axvline(self.iteration, c="#ff9900", ls="-.")

ax[0].plot(it, self.history["min_log_likelihood"], label="Min log L")
ax[0].plot(it, self.history["max_log_likelihood"], label="Max log L")
ax[0].plot(
it,
sanitise_array(self.history["min_log_likelihood"]),
label="Min log L",
)
ax[0].plot(
it,
sanitise_array(self.history["max_log_likelihood"]),
label="Max log L",
)
ax[0].set_ylabel(r"$\log L$")
ax[0].legend(frameon=False)

ax_logl_diff = plt.twinx(ax[0])
logl_diff = sanitise_array(
np.array(self.history["max_log_likelihood"])
- np.array(self.history["min_log_likelihood"]),
)
ax_logl_diff.plot(
it,
logl_diff,
c="C2",
ls=config.plotting.line_styles[2],
label=r"$\Delta \log L$ ",
)
ax_logl_diff.set_yscale("log")
ax_logl_diff.set_ylabel(r"$\Delta \log L$")
handles, labels = ax[0].get_legend_handles_labels()
handles_diff, labels_diff = ax_logl_diff.get_legend_handles_labels()
ax[0].legend(
handles + handles_diff, labels + labels_diff, frameon=False
)

logX_its = np.arange(len(self.state.log_vols))
ax[1].plot(logX_its, self.state.log_vols, label="log X")
ax[1].set_ylabel(r"$\log X$")
Expand Down Expand Up @@ -1054,7 +1082,7 @@ def plot_state(self, filename=None):
handles + handles_time, labels + labels_time, frameon=False
)

ax[3].plot(it, self.history["logZ"], label="logZ")
ax[3].plot(it, sanitise_array(self.history["logZ"]), label="logZ")
ax[3].set_ylabel(r"$\log Z$")
ax[3].legend(frameon=False)

Expand Down

0 comments on commit 5ef5885

Please sign in to comment.