Skip to content

Commit

Permalink
Cluster progress (#195)
Browse files Browse the repository at this point in the history
* fix custom metric example

* make Index projection consistently compile-time enabled

* rework progress handling

* tram toc update

* doc fixes and tests

* notebooks update
  • Loading branch information
clonker authored Jan 24, 2022
1 parent 5a785b2 commit 2f16e0f
Show file tree
Hide file tree
Showing 19 changed files with 229 additions and 79 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ add_subdirectory(deeptime/markov/msm/tram/_bindings)
add_subdirectory(deeptime/markov/tools/estimation/dense/_bindings)
add_subdirectory(deeptime/markov/tools/estimation/sparse/_bindings)

add_subdirectory(examples/clustering_custom_metric)

if(DEEPTIME_BUILD_CPP_TESTS)
add_subdirectory(tests)
endif()
45 changes: 40 additions & 5 deletions deeptime/clustering/_kmeans.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import random
import warnings
from contextlib import nullcontext
from typing import Optional

import numpy as np

from ..base import EstimatorTransformer
from ._cluster_model import ClusterModel
from . import metrics
from ..util.callbacks import ProgressCallback

from ..util.parallel import handle_n_jobs

Expand Down Expand Up @@ -173,6 +175,10 @@ class KMeans(EstimatorTransformer):
initial_centers: None or np.ndarray[k, dim], default=None
This is used to resume the kmeans iteration. Note, that if this is set, the init_strategy is ignored and
the centers are directly passed to the kmeans iteration algorithm.
progress : object
Progress bar object that `KMeans` will call to indicate progress to the user. Tested for a tqdm progress bar.
The interface is checked
via :meth:`supports_progress_interface <deeptime.util.callbacks.supports_progress_interface>`.
References
----------
Expand All @@ -186,7 +192,7 @@ class KMeans(EstimatorTransformer):

def __init__(self, n_clusters: int, max_iter: int = 500, metric='euclidean',
tolerance=1e-5, init_strategy: str = 'kmeans++', fixed_seed=False,
n_jobs=None, initial_centers=None):
n_jobs=None, initial_centers=None, progress=None):
super(KMeans, self).__init__()

self.n_clusters = n_clusters
Expand All @@ -198,6 +204,7 @@ def __init__(self, n_clusters: int, max_iter: int = 500, metric='euclidean',
self.random_state = np.random.RandomState(self.fixed_seed)
self.n_jobs = handle_n_jobs(n_jobs)
self.initial_centers = initial_centers
self.progress = progress

@property
def initial_centers(self) -> Optional[np.ndarray]:
Expand Down Expand Up @@ -421,14 +428,30 @@ def fit(self, data, initial_centers=None, callback_init_centers=None, callback_l
if initial_centers is not None:
self.initial_centers = initial_centers
if self.initial_centers is None:
self.initial_centers = self._pick_initial_centers(data, self.init_strategy, n_jobs, callback_init_centers)
if self.progress is not None:
callback = KMeansCallback(self.progress, "KMeans++ initialization", self.n_clusters,
callback_init_centers)
context = callback
else:
callback = callback_init_centers
context = nullcontext()
with context:
self.initial_centers = self._pick_initial_centers(data, self.init_strategy, n_jobs, callback)

# run k-means with all the data
converged = False
impl = metrics[self.metric]
cluster_centers, code, iterations, cost = impl.kmeans.cluster_loop(
data, self.initial_centers.copy(), n_jobs, self.max_iter,
self.tolerance, callback_loop)

if self.progress is not None:
callback = KMeansCallback(self.progress, "KMeans iterations", self.max_iter, callback_loop)
context = callback
else:
callback = callback_loop
context = nullcontext()
with context:
cluster_centers, code, iterations, cost = impl.kmeans.cluster_loop(
data, self.initial_centers.copy(), n_jobs, self.max_iter,
self.tolerance, callback)
if code == 0:
converged = True
else:
Expand Down Expand Up @@ -526,3 +549,15 @@ def partial_fit(self, data, n_jobs=None):
self._model._converged = True

return self


class KMeansCallback(ProgressCallback):

def __init__(self, progress, description, total, parent_callback=None):
super().__init__(progress, description=description, total=total)
self._parent_callback = parent_callback

def __call__(self, *args, **kw):
super().__call__(*args, **kw)
if self._parent_callback is not None:
self._parent_callback(*args, **kw)
21 changes: 11 additions & 10 deletions deeptime/markov/msm/tram/_tram.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ class TRAM(_MSMBaseEstimator):
`track_log_likelihoods=true`, the log-likelihood are also stored. If `callback_interval=0`, no call to the
callback function is done.
progress : object
Progress bar object that `TRAM` will call to indicate progress to the user.
Tested for a tqdm progress bar. Should implement `update()` and `close()` and have `total` and `desc`
properties.
Progress bar object that `TRAM` will call to indicate progress to the user. Tested for a tqdm progress bar.
The interface is checked
via :meth:`supports_progress_interface <deeptime.util.callbacks.supports_progress_interface>`.
See also
--------
Expand Down Expand Up @@ -192,8 +192,7 @@ def fit(self, data, model=None, *args, **kw):
return self

def _run_estimation(self, tram_input):
""" Estimate the free energies using self-consistent iteration as described in the TRAM paper.
"""
""" Estimate the free energies using self-consistent iteration as described in the TRAM paper. """
with TRAMCallback(self.progress, self.maxiter, self.log_likelihoods, self.increments,
self.callback_interval > 0) as callback:
self._tram_estimator.estimate(tram_input, self.maxiter, self.maxerr,
Expand All @@ -202,26 +201,28 @@ def _run_estimation(self, tram_input):

if callback.last_increment > self.maxerr:
warnings.warn(
f"TRAM did not converge after {self.maxiter} iteration. Last increment: {callback.last_increment}",
ConvergenceWarning)
f"TRAM did not converge after {self.maxiter} iteration(s). "
f"Last increment: {callback.last_increment}", ConvergenceWarning)


class TRAMCallback(callbacks.Callback):
class TRAMCallback(callbacks.ProgressCallback):
"""Callback for the TRAM estimate process. Increments a progress bar and optionally saves iteration increments and
log likelihoods to a list.
Parameters
----------
log_likelihoods_list : list, optional
A list to append the log-likelihoods to that are passed to the callback.__call__() method.
total : int
Maximum number of callbacks.
increments : list, optional
A list to append the increments to that are passed to the callback.__call__() method.
store_convergence_info : bool, default=False
If True, log_likelihoods and increments are appended to their respective lists each time callback.__call__() is
called. If false, no values are appended, only the last increment is stored.
"""
def __init__(self, progress, n_iter, log_likelihoods_list=None, increments=None, store_convergence_info=False):
super().__init__(progress, n_iter=n_iter, display_text="Running TRAM estimate")
def __init__(self, progress, total, log_likelihoods_list=None, increments=None, store_convergence_info=False):
super().__init__(progress, total=total, description="Running TRAM estimate")
self.log_likelihoods = log_likelihoods_list
self.increments = increments
self.store_convergence_info = store_convergence_info
Expand Down
11 changes: 6 additions & 5 deletions deeptime/markov/msm/tram/_tram_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,10 @@ def restrict_to_largest_connected_set(self, connectivity='post_hoc_RE', connecti
Only needed if connectivity="post_hoc_RE" or "BAR_variance". Values greater than 1.0 weaken the connectivity
conditions. For 'post_hoc_RE' this multiplies the number of hypothetically observed transitions. For
'BAR_variance' this scales the threshold for the minimal allowed variance of free energy differences.
progress : object, default=None
Progress bar that TRAMDataset will call to indicate progress to the user.
Tested for a tqdm progress bar. Should implement `update()` and `close()`.
progress : object
Progress bar object that `TRAMDataset` will call to indicate progress to the user.
Tested for a tqdm progress bar. The interface is checked
via :meth:`supports_progress_interface <deeptime.util.callbacks.supports_progress_interface>`.
Raises
------
Expand Down Expand Up @@ -416,8 +417,8 @@ def _find_largest_connected_set(self, connectivity, connectivity_factor, progres
else:
connectivity_fn = tram.find_state_transitions_BAR_variance

with callbacks.Callback(progress, self.n_therm_states * self.n_markov_states,
"Finding connected sets") as callback:
with callbacks.ProgressCallback(progress, "Finding connected sets",
self.n_therm_states * self.n_markov_states) as callback:
(i_s, j_s) = connectivity_fn(self.ttrajs, self.dtrajs, self.bias_matrices, all_state_counts,
self.n_therm_states, self.n_markov_states, connectivity_factor,
callback)
Expand Down
15 changes: 8 additions & 7 deletions deeptime/src/include/deeptime/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ struct ComputeIndex {
static constexpr auto compute(const Arr &strides, const std::tuple<Ix...> &tup, std::index_sequence<I...>) {
return (0 + ... + (strides[I] * std::get<I>(tup)));
}

template<typename Arr, typename Arr2, std::size_t... I>
static constexpr auto computeContainer(const Arr &strides, const Arr2 &tup, std::index_sequence<I...>) {
return (0 + ... + (strides[I] * std::get<I>(tup)));
}
};

}
Expand Down Expand Up @@ -152,13 +157,9 @@ class Index {
* @param indices the Dims-dimensional index
* @return the 1D index
*/
template<typename Arr>
value_type index(const Arr &indices) const {
std::size_t result{0};
for (std::size_t i = 0; i < Dims; ++i) {
result += _cum_size[i] * indices[i];
}
return result;
template<typename Arr, typename Indices = std::make_index_sequence<Dims>>
constexpr value_type index(const Arr &indices) const {
return detail::ComputeIndex<>::computeContainer(_cum_size, indices, Indices{});
}

/**
Expand Down
8 changes: 8 additions & 0 deletions deeptime/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,18 @@
parallel.handle_n_jobs
decorators.cached_property
decorators.plotting_function
callbacks.supports_progress_interface
callbacks.ProgressCallback
platform.module_available
platform.handle_progress_bar
"""

from .stats import QuantityStatistics, confidence_interval
from ._validation import LaggedModelValidator

from . import data
from . import types
from . import callbacks
from . import platform
68 changes: 49 additions & 19 deletions deeptime/util/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,58 @@
import copy
from .platform import handle_progress_bar


class Callback:
"""Base callback function for the c++ bindings to indicate progress by incrementing a progress bar.
def supports_progress_interface(bar):
r""" Method to check if a progress bar supports the deeptime interface, meaning that it
has `update`, `close`, and `set_description` methods as well as a `total` attribute.
Parameters
----------
progress_bar : object
Tested for a tqdm progress bar. Should implement update() and close() and have .total and .desc properties.
n_iter : int
Number of iterations to completion.
display_text : string
text to display in front of the progress bar.
"""
Parameters
----------
bar : object, optional
The progress bar implementation to check, can be None.
def __init__(self, progress, n_iter=None, display_text=None):
self.progress_bar = handle_progress_bar(progress)()
if display_text is not None:
self.progress_bar.desc = display_text
if n_iter is not None:
self.progress_bar.total = n_iter
Returns
-------
supports : bool
Whether the progress bar is supported.
def __call__(self):
See Also
--------
ProgressCallback
"""
has_methods = all(callable(getattr(bar, method, None)) for method in supports_progress_interface.required_methods)
return has_methods


supports_progress_interface.required_methods = ['update', 'close', 'set_description']


class ProgressCallback:
r"""Base callback function for the c++ bindings to indicate progress by incrementing a progress bar.
Parameters
----------
progress : object
Tested for a tqdm progress bar. Should implement `update()`, `set_description()`, and `close()`. Should
also possess a `total` constructor keyword argument.
total : int
Number of iterations to completion.
description : string
text to display in front of the progress bar.
See Also
--------
supports_progress_interface
"""

def __init__(self, progress, description=None, total=None):
self.progress_bar = handle_progress_bar(progress)(total=total)
assert supports_progress_interface(self.progress_bar), \
f"Progress bar did not satisfy interface! It should at least have " \
f"the method(s) {supports_progress_interface.required_methods}."
if description is not None:
self.progress_bar.set_description(description)

def __call__(self, *args, **kw):
self.progress_bar.update()

def __enter__(self):
Expand Down
10 changes: 4 additions & 6 deletions deeptime/util/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,17 @@ def handle_progress_bar(progress):
class progress:
def __init__(self, x=None, **_):
self._x = x
self.total = None

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
return False
def __enter__(self): return self
def __exit__(self, exc_type, exc_val, exc_tb): return False

def __iter__(self):
for x in self._x:
yield x

def update(self): pass

def close(self): pass
def set_description(self, *_): pass

return progress
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ memory_profiler
mdshare
nbconvert
jupyter
tqdm
12 changes: 11 additions & 1 deletion docs/source/index_msm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,24 @@ over the encountered state transitions. This is covered in `transition counting
notebooks/mlmsm
notebooks/pcca
notebooks/tpt
notebooks/tram

Furthermore, deeptime implements :class:`Augmented Markov models <deeptime.markov.msm.AugmentedMSMEstimator>`
:footcite:`olsson2017combining` which can be used when experimental data is available, as well as
:class:`Observable Operator Model MSMs <deeptime.markov.msm.OOMReweightedMSM>` :footcite:`nuske2017markov` which is
an unbiased estimator for the MSM transition matrix that corrects for the effect of starting out of equilibrium,
even when short lag times are used.

.. rubric:: Multiensemble MSMs

Deeptime offers the TRAM method :footcite:`wu2016multiensemble` for estimating multiensemble MSMs. These are collections
of MSMs based on simulations that are governed by biased dynamics (i.e., replica exchange simulations
and umbrella sampling).

.. toctree::
:maxdepth: 1

notebooks/tram

.. rubric:: References

.. footbibliography::
2 changes: 1 addition & 1 deletion docs/source/notebooks
Submodule notebooks updated 1 files
+62 −40 tram.ipynb
6 changes: 6 additions & 0 deletions examples/clustering_custom_metric/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(SRC bindings.cpp)
pybind11_add_module(custom_metric ${SRC})
target_link_libraries(custom_metric PUBLIC deeptime::deeptime)
if(OpenMP_FOUND)
target_link_libraries(custom_metric PUBLIC OpenMP::OpenMP_CXX)
endif()
2 changes: 1 addition & 1 deletion examples/clustering_custom_metric/bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "register_clustering.h"
#include <deeptime/clustering/register_clustering.h>

struct MaximumMetric {

Expand Down
Loading

0 comments on commit 2f16e0f

Please sign in to comment.