From 8e96993af66590f9482001046e9cc0009e3d8b0a Mon Sep 17 00:00:00 2001 From: David Straub Date: Thu, 25 Jun 2020 17:06:09 +0200 Subject: [PATCH 1/2] [backends] Make backends inherit from common base class --- src/emcee/backends/backend.py | 177 ++++---------------------- src/emcee/backends/base.py | 226 ++++++++++++++++++++++++++++++++++ src/emcee/backends/hdf.py | 22 ++-- 3 files changed, 261 insertions(+), 164 deletions(-) create mode 100644 src/emcee/backends/base.py diff --git a/src/emcee/backends/backend.py b/src/emcee/backends/backend.py index 9d5cfde7..db720e2e 100644 --- a/src/emcee/backends/backend.py +++ b/src/emcee/backends/backend.py @@ -4,15 +4,16 @@ from .. import autocorr from ..state import State +from .base import BackendBase __all__ = ["Backend"] -class Backend(object): +class Backend(BackendBase): """A simple default backend that stores the chain in memory""" def __init__(self, dtype=None): - self.initialized = False + self._initialized = False if dtype is None: dtype = np.float self.dtype = dtype @@ -27,26 +28,34 @@ def reset(self, nwalkers, ndim): """ self.nwalkers = int(nwalkers) self.ndim = int(ndim) - self.iteration = 0 + self._iteration = 0 self.accepted = np.zeros(self.nwalkers, dtype=self.dtype) self.chain = np.empty((0, self.nwalkers, self.ndim), dtype=self.dtype) self.log_prob = np.empty((0, self.nwalkers), dtype=self.dtype) self.blobs = None - self.random_state = None - self.initialized = True + self._random_state = None + self._initialized = True def has_blobs(self): """Returns ``True`` if the model includes blobs""" return self.blobs is not None - def get_value(self, name, flat=False, thin=1, discard=0): - if self.iteration <= 0: - raise AttributeError( - "you must run the sampler with " - "'store == True' before accessing the " - "results" - ) + @property + def iteration(self): + """Return the iteration number.""" + return self._iteration + + @property + def initialized(self): + """Return true if backend has been initialized.""" + return self._initialized + @property + def shape(self): + """The dimensions of the ensemble ``(nwalkers, ndim)``""" + return self.nwalkers, self.ndim + + def _get_value(self, name, flat, thin, discard): if name == "blobs" and not self.has_blobs(): return None @@ -57,110 +66,6 @@ def get_value(self, name, flat=False, thin=1, discard=0): return v.reshape(s) return v - def get_chain(self, **kwargs): - """Get the stored chain of MCMC samples - - Args: - flat (Optional[bool]): Flatten the chain across the ensemble. - (default: ``False``) - thin (Optional[int]): Take only every ``thin`` steps from the - chain. (default: ``1``) - discard (Optional[int]): Discard the first ``discard`` steps in - the chain as burn-in. (default: ``0``) - - Returns: - array[..., nwalkers, ndim]: The MCMC samples. - - """ - return self.get_value("chain", **kwargs) - - def get_blobs(self, **kwargs): - """Get the chain of blobs for each sample in the chain - - Args: - flat (Optional[bool]): Flatten the chain across the ensemble. - (default: ``False``) - thin (Optional[int]): Take only every ``thin`` steps from the - chain. (default: ``1``) - discard (Optional[int]): Discard the first ``discard`` steps in - the chain as burn-in. (default: ``0``) - - Returns: - array[..., nwalkers]: The chain of blobs. - - """ - return self.get_value("blobs", **kwargs) - - def get_log_prob(self, **kwargs): - """Get the chain of log probabilities evaluated at the MCMC samples - - Args: - flat (Optional[bool]): Flatten the chain across the ensemble. - (default: ``False``) - thin (Optional[int]): Take only every ``thin`` steps from the - chain. (default: ``1``) - discard (Optional[int]): Discard the first ``discard`` steps in - the chain as burn-in. (default: ``0``) - - Returns: - array[..., nwalkers]: The chain of log probabilities. - - """ - return self.get_value("log_prob", **kwargs) - - def get_last_sample(self): - """Access the most recent sample in the chain""" - if (not self.initialized) or self.iteration <= 0: - raise AttributeError( - "you must run the sampler with " - "'store == True' before accessing the " - "results" - ) - it = self.iteration - blobs = self.get_blobs(discard=it - 1) - if blobs is not None: - blobs = blobs[0] - return State( - self.get_chain(discard=it - 1)[0], - log_prob=self.get_log_prob(discard=it - 1)[0], - blobs=blobs, - random_state=self.random_state, - ) - - def get_autocorr_time(self, discard=0, thin=1, **kwargs): - """Compute an estimate of the autocorrelation time for each parameter - - Args: - thin (Optional[int]): Use only every ``thin`` steps from the - chain. The returned estimate is multiplied by ``thin`` so the - estimated time is in units of steps, not thinned steps. - (default: ``1``) - discard (Optional[int]): Discard the first ``discard`` steps in - the chain as burn-in. (default: ``0``) - - Other arguments are passed directly to - :func:`emcee.autocorr.integrated_time`. - - Returns: - array[ndim]: The integrated autocorrelation time estimate for the - chain for each parameter. - - """ - x = self.get_chain(discard=discard, thin=thin) - return thin * autocorr.integrated_time(x, **kwargs) - - @property - def shape(self): - """The dimensions of the ensemble ``(nwalkers, ndim)``""" - return self.nwalkers, self.ndim - - def _check_blobs(self, blobs): - has_blobs = self.has_blobs() - if has_blobs and blobs is None: - raise ValueError("inconsistent use of blobs") - if self.iteration > 0 and blobs is not None and not has_blobs: - raise ValueError("inconsistent use of blobs") - def grow(self, ngrow, blobs): """Expand the storage space by some number of samples @@ -184,33 +89,6 @@ def grow(self, ngrow, blobs): else: self.blobs = np.concatenate((self.blobs, a), axis=0) - def _check(self, state, accepted): - self._check_blobs(state.blobs) - nwalkers, ndim = self.shape - has_blobs = self.has_blobs() - if state.coords.shape != (nwalkers, ndim): - raise ValueError( - "invalid coordinate dimensions; expected {0}".format( - (nwalkers, ndim) - ) - ) - if state.log_prob.shape != (nwalkers,): - raise ValueError( - "invalid log probability size; expected {0}".format(nwalkers) - ) - if state.blobs is not None and not has_blobs: - raise ValueError("unexpected blobs") - if state.blobs is None and has_blobs: - raise ValueError("expected blobs, but none were given") - if state.blobs is not None and len(state.blobs) != nwalkers: - raise ValueError( - "invalid blobs size; expected {0}".format(nwalkers) - ) - if accepted.shape != (nwalkers,): - raise ValueError( - "invalid acceptance size; expected {0}".format(nwalkers) - ) - def save_step(self, state, accepted): """Save a step to the backend @@ -227,11 +105,10 @@ def save_step(self, state, accepted): if state.blobs is not None: self.blobs[self.iteration, :] = state.blobs self.accepted += accepted - self.random_state = state.random_state - self.iteration += 1 + self._random_state = state.random_state + self._iteration += 1 - def __enter__(self): - return self - - def __exit__(self, exception_type, exception_value, traceback): - pass + @property + def random_state(self): + """Return the random state.""" + return self._random_state diff --git a/src/emcee/backends/base.py b/src/emcee/backends/base.py new file mode 100644 index 00000000..98b85fdc --- /dev/null +++ b/src/emcee/backends/base.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- + +"""Backend base class.""" + + +from .. import autocorr +from ..state import State + +__all__ = ["BackendBase"] + + +class BackendBase: + """Backend base class. Not meant to be used directly.""" + + # Methods to be implemented by children + + def __init__(self, dtype=None): + raise NotImplementedError("Method must be implemented by child class.") + + def reset(self, nwalkers, ndim): + """Clear the state of the chain and empty the backend + + Args: + nwakers (int): The size of the ensemble + ndim (int): The number of dimensions + + """ + raise NotImplementedError("Method must be implemented by child class.") + + def has_blobs(self): + """Returns ``True`` if the model includes blobs.""" + raise NotImplementedError("Method must be implemented by child class.") + + @property + def iteration(self): + """Return the iteration number.""" + raise NotImplementedError("Method must be implemented by child class.") + + @property + def initialized(self): + """Return true if backend has been initialized.""" + raise NotImplementedError("Method must be implemented by child class.") + + @property + def shape(self): + """The dimensions of the ensemble ``(nwalkers, ndim)``""" + raise NotImplementedError("Method must be implemented by child class.") + + def _get_value(self, name, flat, thin, discard): + """Get a value from the backend.""" + raise NotImplementedError("Method must be implemented by child class.") + + def grow(self, ngrow, blobs): + """Expand the storage space by some number of samples + + Args: + ngrow (int): The number of steps to grow the chain. + blobs: The current list of blobs. This is used to compute the + dtype for the blobs array. + + """ + raise NotImplementedError("Method must be implemented by child class.") + + def save_step(self, state, accepted): + """Save a step to the backend + + Args: + state (State): The :class:`State` of the ensemble. + accepted (ndarray): An array of boolean flags indicating whether + or not the proposal for each walker was accepted. + + """ + raise NotImplementedError("Method must be implemented by child class.") + + @property + def random_state(self): + """Return the random state.""" + raise NotImplementedError("Method must be implemented by child class.") + + # Methods that *can* be overwritten by children + + def __enter__(self): + """Enter method for context manager.""" + return self + + def __exit__(self, exception_type, exception_value, traceback): + """Exit method for context manager.""" + pass + + # Common methods + + def get_value(self, name, flat=False, thin=1, discard=0): + """Get a value from the backend.""" + if not self.initialized or self.iteration <= 0: + raise AttributeError( + "you must run the sampler with " + "'store == True' before accessing the " + "results" + ) + return self._get_value(name, flat=flat, thin=thin, discard=discard) + + def get_chain(self, **kwargs): + """Get the stored chain of MCMC samples + + Args: + flat (Optional[bool]): Flatten the chain across the ensemble. + (default: ``False``) + thin (Optional[int]): Take only every ``thin`` steps from the + chain. (default: ``1``) + discard (Optional[int]): Discard the first ``discard`` steps in + the chain as burn-in. (default: ``0``) + + Returns: + array[..., nwalkers, ndim]: The MCMC samples. + + """ + return self.get_value("chain", **kwargs) + + def get_blobs(self, **kwargs): + """Get the chain of blobs for each sample in the chain + + Args: + flat (Optional[bool]): Flatten the chain across the ensemble. + (default: ``False``) + thin (Optional[int]): Take only every ``thin`` steps from the + chain. (default: ``1``) + discard (Optional[int]): Discard the first ``discard`` steps in + the chain as burn-in. (default: ``0``) + + Returns: + array[..., nwalkers]: The chain of blobs. + + """ + return self.get_value("blobs", **kwargs) + + def get_log_prob(self, **kwargs): + """Get the chain of log probabilities evaluated at the MCMC samples + + Args: + flat (Optional[bool]): Flatten the chain across the ensemble. + (default: ``False``) + thin (Optional[int]): Take only every ``thin`` steps from the + chain. (default: ``1``) + discard (Optional[int]): Discard the first ``discard`` steps in + the chain as burn-in. (default: ``0``) + + Returns: + array[..., nwalkers]: The chain of log probabilities. + + """ + return self.get_value("log_prob", **kwargs) + + def get_last_sample(self): + """Access the most recent sample in the chain""" + if (not self.initialized) or self.iteration <= 0: + raise AttributeError( + "you must run the sampler with " + "'store == True' before accessing the " + "results" + ) + it = self.iteration + blobs = self.get_blobs(discard=it - 1) + if blobs is not None: + blobs = blobs[0] + return State( + self.get_chain(discard=it - 1)[0], + log_prob=self.get_log_prob(discard=it - 1)[0], + blobs=blobs, + random_state=self.random_state, + ) + + def get_autocorr_time(self, discard=0, thin=1, **kwargs): + """Compute an estimate of the autocorrelation time for each parameter + + Args: + thin (Optional[int]): Use only every ``thin`` steps from the + chain. The returned estimate is multiplied by ``thin`` so the + estimated time is in units of steps, not thinned steps. + (default: ``1``) + discard (Optional[int]): Discard the first ``discard`` steps in + the chain as burn-in. (default: ``0``) + + Other arguments are passed directly to + :func:`emcee.autocorr.integrated_time`. + + Returns: + array[ndim]: The integrated autocorrelation time estimate for the + chain for each parameter. + + """ + x = self.get_chain(discard=discard, thin=thin) + return thin * autocorr.integrated_time(x, **kwargs) + + def _check_blobs(self, blobs): + has_blobs = self.has_blobs() + if has_blobs and blobs is None: + raise ValueError("inconsistent use of blobs") + if self.iteration > 0 and blobs is not None and not has_blobs: + raise ValueError("inconsistent use of blobs") + + def _check(self, state, accepted): + self._check_blobs(state.blobs) + nwalkers, ndim = self.shape + has_blobs = self.has_blobs() + if state.coords.shape != (nwalkers, ndim): + raise ValueError( + "invalid coordinate dimensions; expected {0}".format( + (nwalkers, ndim) + ) + ) + if state.log_prob.shape != (nwalkers,): + raise ValueError( + "invalid log probability size; expected {0}".format(nwalkers) + ) + if state.blobs is not None and not has_blobs: + raise ValueError("unexpected blobs") + if state.blobs is None and has_blobs: + raise ValueError("expected blobs, but none were given") + if state.blobs is not None and len(state.blobs) != nwalkers: + raise ValueError( + "invalid blobs size; expected {0}".format(nwalkers) + ) + if accepted.shape != (nwalkers,): + raise ValueError( + "invalid acceptance size; expected {0}".format(nwalkers) + ) diff --git a/src/emcee/backends/hdf.py b/src/emcee/backends/hdf.py index d4e6d6c1..0fb42e94 100644 --- a/src/emcee/backends/hdf.py +++ b/src/emcee/backends/hdf.py @@ -10,7 +10,7 @@ import numpy as np from .. import __version__ -from .backend import Backend +from .base import BackendBase try: @@ -19,7 +19,7 @@ h5py = None -class HDFBackend(Backend): +class HDFBackend(BackendBase): """A backend that stores the chain in an HDF5 file using h5py .. note:: You must install `h5py `_ to use this @@ -34,6 +34,7 @@ class HDFBackend(Backend): ``RuntimeError`` if the file is opened with write access. """ + def __init__(self, filename, name="mcmc", read_only=False, dtype=None): if h5py is None: raise ImportError("you must install 'h5py' to use the HDFBackend") @@ -108,13 +109,7 @@ def has_blobs(self): with self.open() as f: return f[self.name].attrs["has_blobs"] - def get_value(self, name, flat=False, thin=1, discard=0): - if not self.initialized: - raise AttributeError( - "You must run the sampler with " - "'store == True' before accessing the " - "results" - ) + def _get_value(self, name, flat, thin, discard): with self.open() as f: g = f[self.name] iteration = g.attrs["iteration"] @@ -219,16 +214,15 @@ def save_step(self, state, accepted): g.attrs["iteration"] = iteration + 1 -class TempHDFBackend(object): - +class TempHDFBackend: def __init__(self, dtype=None): self.dtype = dtype self.filename = None def __enter__(self): - f = NamedTemporaryFile(prefix="emcee-temporary-hdf5", - suffix=".hdf5", - delete=False) + f = NamedTemporaryFile( + prefix="emcee-temporary-hdf5", suffix=".hdf5", delete=False + ) f.close() self.filename = f.name return HDFBackend(f.name, "test", dtype=self.dtype) From 3346472bccada470663e9c6169c3722760dd72bf Mon Sep 17 00:00:00 2001 From: David Straub Date: Thu, 25 Jun 2020 17:19:37 +0200 Subject: [PATCH 2/2] [backends] Make HDF inherit from file backend --- src/emcee/backends/file.py | 40 ++++++++++++++++++++++++++++++++++++++ src/emcee/backends/hdf.py | 7 ++++--- 2 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 src/emcee/backends/file.py diff --git a/src/emcee/backends/file.py b/src/emcee/backends/file.py new file mode 100644 index 00000000..be214d79 --- /dev/null +++ b/src/emcee/backends/file.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- + +from __future__ import division, print_function + +__all__ = ["FileBackend"] + +import os + +from .. import __version__ +from .base import BackendBase + + +class FileBackend(BackendBase): + """A backend that stores the chain in a file. + + This is a base class for file-based backends, not meant to be used directly. + + Args: + filename (str): The name of the HDF5 file where the chain will be + saved. + read_only (bool; optional): If ``True``, the backend will throw a + ``RuntimeError`` if the file is opened with write access. + + """ + + def __init__(self, filename, read_only=False): + """Initialize self given a file name. + + If ``read_only`` is ``True``, will throw a ``RuntimeError`` + if the file is opened with write access. + """ + self.filename = filename + self.read_only = read_only + + @property + def initialized(self): + """Return True if the backend has been initialized.""" + if not os.path.exists(self.filename): + return False + return True diff --git a/src/emcee/backends/hdf.py b/src/emcee/backends/hdf.py index 0fb42e94..121b224c 100644 --- a/src/emcee/backends/hdf.py +++ b/src/emcee/backends/hdf.py @@ -10,7 +10,7 @@ import numpy as np from .. import __version__ -from .base import BackendBase +from .file import FileBackend try: @@ -19,7 +19,7 @@ h5py = None -class HDFBackend(BackendBase): +class HDFBackend(FileBackend): """A backend that stores the chain in an HDF5 file using h5py .. note:: You must install `h5py `_ to use this @@ -38,6 +38,7 @@ class HDFBackend(BackendBase): def __init__(self, filename, name="mcmc", read_only=False, dtype=None): if h5py is None: raise ImportError("you must install 'h5py' to use the HDFBackend") + super().__init__(filename=filename, read_only=read_only) self.filename = filename self.name = name self.read_only = read_only @@ -50,7 +51,7 @@ def __init__(self, filename, name="mcmc", read_only=False, dtype=None): @property def initialized(self): - if not os.path.exists(self.filename): + if not super().initialized: return False try: with self.open() as f: