diff --git a/src/nessai/config.py b/src/nessai/config.py index 6999facb..607590b4 100644 --- a/src/nessai/config.py +++ b/src/nessai/config.py @@ -144,6 +144,12 @@ class PlottingConfig(_BaseConfig): Based on the default DPI in matplotlib of 100, so this will give a maximum size of 5000 pixels. """ + clip_min: float = -1e10 + """Minimum value to clip data to for plotting. + + This is used to avoid issues with plots where the dynamic range is too + large for matplotlib to handle. + """ @dataclass diff --git a/src/nessai/plot.py b/src/nessai/plot.py index ff51b874..e1c69023 100644 --- a/src/nessai/plot.py +++ b/src/nessai/plot.py @@ -68,6 +68,36 @@ def wrapper(*args, **kwargs): return decorator +def sanitise_array( + a: np.ndarray, + /, + a_min: Optional[float] = None, + a_max: Optional[float] = None, +): + """Sanitise an array for plotting. + + If :code:`x_min` is not specified, it is set to the value in + :code:`nessai.config.plotting.clip_min`. + + Parameters + ---------- + x : array_like + Array to sanitise. + x_min : float, optional + Minimum value to clip the data to. + xmax : float, optional + Maximum value to clip the data to. + + Returns + ------- + np.ndarray + Sanitised array. + """ + if a_min is None: + a_min = config.plotting.clip_min + return np.clip(a, a_min, a_max) + + @nessai_style() def plot_live_points( live_points, filename=None, bounds=None, c=None, **kwargs diff --git a/src/nessai/samplers/nestedsampler.py b/src/nessai/samplers/nestedsampler.py index 9b662359..0ea0545d 100644 --- a/src/nessai/samplers/nestedsampler.py +++ b/src/nessai/samplers/nestedsampler.py @@ -19,7 +19,7 @@ from .. import config from ..evidence import _NSIntegralState from ..livepoint import empty_structured_array -from ..plot import nessai_style, plot_indices, plot_trace +from ..plot import nessai_style, plot_indices, plot_trace, sanitise_array from ..proposal.utils import ( check_proposal_kwargs, get_flow_proposal_class, @@ -1004,11 +1004,39 @@ def plot_state(self, filename=None): for a in ax: a.axvline(self.iteration, c="#ff9900", ls="-.") - ax[0].plot(it, self.history["min_log_likelihood"], label="Min log L") - ax[0].plot(it, self.history["max_log_likelihood"], label="Max log L") + ax[0].plot( + it, + sanitise_array(self.history["min_log_likelihood"]), + label="Min log L", + ) + ax[0].plot( + it, + sanitise_array(self.history["max_log_likelihood"]), + label="Max log L", + ) ax[0].set_ylabel(r"$\log L$") ax[0].legend(frameon=False) + ax_logl_diff = plt.twinx(ax[0]) + logl_diff = sanitise_array( + np.array(self.history["max_log_likelihood"]) + - np.array(self.history["min_log_likelihood"]), + ) + ax_logl_diff.plot( + it, + logl_diff, + c="C2", + ls=config.plotting.line_styles[2], + label=r"$\Delta \log L$ ", + ) + ax_logl_diff.set_yscale("log") + ax_logl_diff.set_ylabel(r"$\Delta \log L$") + handles, labels = ax[0].get_legend_handles_labels() + handles_diff, labels_diff = ax_logl_diff.get_legend_handles_labels() + ax[0].legend( + handles + handles_diff, labels + labels_diff, frameon=False + ) + logX_its = np.arange(len(self.state.log_vols)) ax[1].plot(logX_its, self.state.log_vols, label="log X") ax[1].set_ylabel(r"$\log X$") @@ -1054,7 +1082,7 @@ def plot_state(self, filename=None): handles + handles_time, labels + labels_time, frameon=False ) - ax[3].plot(it, self.history["logZ"], label="logZ") + ax[3].plot(it, sanitise_array(self.history["logZ"]), label="logZ") ax[3].set_ylabel(r"$\log Z$") ax[3].legend(frameon=False)