Skip to content

Commit

Permalink
Merge pull request #9 from rafaelbicudo/metrics
Browse files Browse the repository at this point in the history
Examples and metrics
  • Loading branch information
hmcezar authored Jan 12, 2024
2 parents c481cf0 + 832abb4 commit 64147b9
Show file tree
Hide file tree
Showing 12 changed files with 238 additions and 13 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ jobs:
- name: Install package
run: |
pip install -e .
- name: Lint with ruff
run: |
# stop the build if there are Python syntax errors or undefined names
ruff --format=github --select=E9,F63,F7,F82 --target-version=py37 .
# default set of ruff rules with GitHub Annotations
ruff --format=github --target-version=py37 .
- name: Test with pytest
run: |
pytest --cov=./ --cov-report=xml
Expand Down
10 changes: 9 additions & 1 deletion clusttraj/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ClustOptions:
overwrite: bool = None
final_kabsch: bool = None
silhouette_score: bool = None
metrics: bool = None
distmat_name: str = None
out_clust_name: str = None
evo_name: str = None
Expand Down Expand Up @@ -305,6 +306,12 @@ def configure_runtime(args_in: List[str]) -> ClustOptions:
help="log file (default: clusttraj.log)",
)

parser.add_argument(
"--metrics",
action="store_true",
help="compute metrics to evaluate the clustering procedure quality.",
)

rmsd_criterion = parser.add_mutually_exclusive_group(required=True)

rmsd_criterion.add_argument(
Expand Down Expand Up @@ -552,7 +559,7 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:
"plot": bool(args.plot),
"evo_name": basenameout + "_evo.pdf" if args.plot else None,
"dendrogram_name": basenameout + "_dendrogram.pdf" if args.plot else None,
"mds_name": basenameout + ".pdf" if args.plot else None,
"mds_name": basenameout + "_mds.pdf" if args.plot else None,
"trajfile": args.trajectory_file,
"min_rmsd": args.min_rmsd,
"method": args.method,
Expand All @@ -562,6 +569,7 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:
"overwrite": args.force,
"final_kabsch": args.final_kabsch,
"silhouette_score": args.silhouette_score,
"metrics": args.metrics,
}

if args.reorder:
Expand Down
14 changes: 13 additions & 1 deletion clusttraj/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from typing import List
from .io import Logger, configure_runtime, save_clusters_config
from .distmat import get_distmat
from .plot import plot_clust_evo, plot_dendrogram, plot_mds
from .plot import plot_clust_evo, plot_dendrogram, plot_mds, plot_tsne
from .classify import classify_structures, classify_structures_silhouette
from .metrics import compute_metrics


def main(args: List[str] = None) -> None:
Expand Down Expand Up @@ -65,6 +66,8 @@ def main(args: List[str] = None) -> None:

plot_mds(clust_opt, clusters, distmat)

plot_tsne(clust_opt, clusters, distmat)

# print the cluster sizes
outclust_str = f"A total {len(clusters)} snapshots were read and {max(clusters)} cluster(s) was(were) found.\n"
outclust_str += "The cluster sizes are:\nCluster\tSize\n"
Expand All @@ -74,6 +77,15 @@ def main(args: List[str] = None) -> None:
outclust_str += f"{label}\t{size}\n"
Logger.logger.info(outclust_str)

# Compute the evaluation metrics
if clust_opt.metrics:
ss, ch, db, cpcc = compute_metrics(clust_opt, distmat, Z, clusters)

outclust_str += f"\nSilhouette score: {ss:.3f}\n"
outclust_str += f"Calinski Harabsz score: {ch:.3f}\n"
outclust_str += f"Davies-Bouldin score: {db:.3f}\n"
outclust_str += f"Cophenetic correlation coefficient: {cpcc:.3f}\n\n"

# save summary
with open(clust_opt.summary_name, "w") as f:
f.write(str(clust_opt))
Expand Down
46 changes: 46 additions & 0 deletions clusttraj/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Functions to compute evaluation metrics of the clustering procedure"""

from sklearn.metrics import (
silhouette_score,
calinski_harabasz_score,
davies_bouldin_score,
)
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import cophenet
from typing import Tuple
import numpy as np
from .io import ClustOptions


def compute_metrics(
    clust_opt: ClustOptions,
    distmat: np.ndarray,
    z_matrix: np.ndarray,
    clusters: np.ndarray,
) -> Tuple[np.float64, np.float64, np.float64, np.float64]:
    """Compute metrics to assess the performance of the clustering procedure.

    Args:
        clust_opt (ClustOptions): The clustering options. Not read by this
            function; kept so the call signature stays stable.
        distmat (np.ndarray): The distance matrix, expected in condensed form
            (it is expanded with ``squareform`` before the precomputed
            silhouette computation).
        z_matrix (np.ndarray): The Z-matrix from hierarchical clustering procedure.
        clusters (np.ndarray): The cluster label assigned to each snapshot.

    Returns:
        ss (np.float64): The silhouette score.
        ch (np.float64): The Calinski Harabasz score.
        db (np.float64): The Davies-Bouldin score.
        cpcc (np.float64): The cophenetic correlation coefficient.
    """

    # Expand the distance matrix once and reuse it for every score below
    square_distmat = squareform(distmat)

    # Compute the silhouette score directly from the pairwise distances
    ss = silhouette_score(square_distmat, clusters, metric="precomputed")

    # NOTE(review): the two scores below take a feature matrix, so each row
    # of the square distance matrix is used as the feature vector of a
    # snapshot -- confirm this is the intended input.

    # Compute the Calinski Harabasz score
    ch = calinski_harabasz_score(square_distmat, clusters)

    # Compute the Davies-Bouldin score
    db = davies_bouldin_score(square_distmat, clusters)

    # Compute the cophenetic correlation coefficient. The original condensed
    # distances must be passed as the second argument: only then does
    # cophenet return (coefficient, cophenetic distances). Without them it
    # returns just the cophenetic distances and [0] would pick one distance.
    cpcc = cophenet(z_matrix, distmat)[0]

    return ss, ch, db, cpcc
58 changes: 58 additions & 0 deletions clusttraj/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,63 @@ def plot_mds(clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray)
coords[:, 0], coords[:, 1], marker="o", c=clusters, cmap=plt.cm.nipy_spectral
)

plt.title("MDS Visualization")

# Save the plot
plt.savefig(clust_opt.mds_name, bbox_inches="tight")


def plot_tsne(
    clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray
) -> None:
    """Plot the t-distributed Stochastic Neighbor Embedding 2D plot of the clustering.

    Args:
        clust_opt (ClustOptions): The clustering options. ``n_workers`` sets
            the number of parallel jobs and ``mds_name`` is used to derive
            the output file name.
        clusters (np.ndarray): The cluster labels.
        distmat (np.ndarray): The distance matrix (expanded with
            ``squareform`` before being embedded).

    Returns:
        None
    """

    # Expand the distance matrix once; it is also used to count the samples
    square_distmat = squareform(distmat)
    n_samples = square_distmat.shape[0]

    # scikit-learn requires perplexity < n_samples, so the usual value of 30
    # is clamped to keep short trajectories (<= 30 snapshots) from raising
    perplexity = min(30, n_samples - 1)

    # Initialize the tSNE model
    tsne = manifold.TSNE(
        n_components=2,
        perplexity=perplexity,
        learning_rate=200,
        random_state=666,
        n_jobs=clust_opt.n_workers,
    )

    # Perform the t-SNE and get the 2D representation
    coords = tsne.fit_transform(square_distmat)

    # Define a list of unique colors for each cluster
    unique_clusters = np.unique(clusters)
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_clusters)))

    # Create a new figure
    plt.figure()

    # Hide ticks and tick labels: the embedding axes carry no physical meaning
    plt.tick_params(
        axis="both",
        which="both",
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,
        labelleft=False,
    )

    # Create a scatter plot with different colors for each cluster
    for color, cluster in zip(colors, unique_clusters):
        cluster_data = coords[clusters == cluster]
        plt.scatter(cluster_data[:, 0], cluster_data[:, 1], color=color)

    plt.title("t-SNE Visualization")

    # Save the plot next to the MDS one: dropping the last 7 characters
    # ("mds.pdf") from mds_name keeps the "<basename>_" prefix.
    # NOTE(review): assumes mds_name ends in "mds.pdf" -- confirm if the
    # naming scheme changes.
    plt.savefig(clust_opt.mds_name[:-7] + "tsne.pdf", bbox_inches="tight")
112 changes: 108 additions & 4 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,17 @@ Since we already computed the distance matrix, we can provide it as input using

- ``clusters_mds.pdf`` plots the multidimensional scaling (MDS) of the distance matrix.

.. image:: images/clusters.pdf
.. image:: images/average_full_mds.pdf
:width: 300pt

- ``clusters_dendrogram.pdf`` plots the hierarchical clustering dendrogram.

.. image:: images/clusters_dendrogram.pdf
.. image:: images/average_full_dend.pdf
:width: 300pt

- ``clusters_evo.pdf`` plots the evolution of cluster populations during the simulation.

.. image:: images/clusters_evo.pdf
.. image:: images/average_full_evo.pdf
:width: 300pt

The highest silhouette score is printed in the ``clusttraj.log`` file, along with the corresponding RMSD threshold:
Expand Down Expand Up @@ -218,19 +218,123 @@ To adopt the ``median`` method we can run:
In this case the highest silhouette score of 0.075 indicates that the points are located near the edge of the clusters. The distribution of population among the 2 clusters (1/99) also indicates the limitations of the method. Finally, visual inspection of the dendrogram shows anomalous behavior.

.. image:: images/anomalous_dendrogram.pdf
.. image:: images/anomalous_dend.pdf
:width: 300pt

.. .. raw:: html
.. <iframe src='/Users/Rafael/Coisas/Doutorado/clusttraj/clusttraj/docs/build/html/_images/anomalous_dendrogram.pdf' width="100%" height="500"></iframe>
The reader is encouraged to verify that adding the ``-odl`` flag for `optimal visualization <https://academic.oup.com/bioinformatics/article/17/suppl_1/S22/261423?login=true>`_ does not remove the dendrogram crossings.


Accounting for molecule permutation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

As an attempt to avoid separating similar configurations due to permutation of identical molecules, we can reorder the atoms using the ``-e`` flag.

.. code-block:: console
python -m clusttraj h2o_traj.xyz -ss -p -m average -e -f
For this system the reordering compromised the statistical quality of the clustering. The number of clusters was increased from 2 to 35 while the optimal silhouette score was reduced from 0.217 to 0.119:

.. code-block:: console
╰─○ cat clusttraj.log
2023-10-02 19:53:20,618 INFO [distmat.py:34] <get_distmat> Calculating distance matrix using 4 threads
2023-10-02 19:54:00,821 INFO [distmat.py:38] <get_distmat> Saving condensed distance matrix to distmat.npy
2023-10-02 19:54:00,823 INFO [classify.py:27] <classify_structures_silhouette> Clustering using 'average' method to join the clusters
2023-10-02 19:54:00,855 INFO [classify.py:61] <classify_structures_silhouette> Highest silhouette score: 0.11873407875769024
2023-10-02 19:54:00,856 INFO [classify.py:71] <classify_structures_silhouette> Optimal RMSD threshold value: 1.237013337787396
2023-10-02 19:54:00,856 INFO [classify.py:76] <classify_structures_silhouette> Saving clustering classification to clusters.dat
2023-10-02 19:54:06,676 INFO [main.py:75] <main> A total 100 snapshots were read and 35 cluster(s) was(were) found.
The cluster sizes are:
Cluster Size
1 2
2 4
3 3
4 1
5 1
6 1
7 2
8 2
9 3
10 2
11 7
12 3
13 7
14 7
15 3
16 5
17 4
18 3
19 2
20 4
21 2
22 3
23 3
24 1
25 2
26 3
27 2
28 1
29 2
30 2
31 5
32 4
33 2
34 1
35 1
This functionality is especially useful in the case of solvated systems. In our case, we can treat one water molecule as the solute and the others as solvent. For example, considering the first water molecule as the solute:

.. code-block:: console
python -m clusttraj h2o_traj.xyz -ss -p -m average -e -f -ns 3
The number of solute atoms must be specified using the ``-ns`` flag (here 3, the atoms of the first water molecule), and as a result we managed to increase the silhouette coefficient to 0.247 with a significant change in the cluster populations:

.. code-block:: console
╰─○ cat clusttraj.log
2023-10-02 20:13:52,041 INFO [distmat.py:38] <get_distmat> Saving condensed distance matrix to distmat.npy
2023-10-02 20:13:52,044 INFO [classify.py:27] <classify_structures_silhouette> Clustering using 'average' method to join the clusters
2023-10-02 20:13:52,101 INFO [classify.py:61] <classify_structures_silhouette> Highest silhouette score: 0.24735123044958368
2023-10-02 20:13:52,102 INFO [classify.py:65] <classify_structures_silhouette> The following RMSD threshold values yielded the same optimial silhouette score: 3.035586843407412, 3.135586843407412, 3.235586843407412, 3.335586843407412
2023-10-02 20:13:52,102 INFO [classify.py:68] <classify_structures_silhouette> The smallest RMSD of 3.035586843407412 has been adopted
2023-10-02 20:13:52,102 INFO [classify.py:76] <classify_structures_silhouette> Saving clustering classification to clusters.dat
2023-10-02 20:13:57,498 INFO [main.py:75] <main> A total 100 snapshots were read and 2 cluster(s) was(were) found.
The cluster sizes are:
Cluster Size
1 3
2 97
Final Kabsch rotation
^^^^^^^^^^^^^^^^^^^^^

We can also add a final Kabsch rotation to minimize the RMSD after reordering the solvent atoms:

.. code-block:: console
python -m clusttraj h2o_traj.xyz -ss -p -m average -e -f -ns 3 --final-kabsch
For this system no significant changes were observed, as the silhouette coefficient and cluster populations remain almost identical.

Removing hydrogen atoms
^^^^^^^^^^^^^^^^^^^^^^^



Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ Welcome to ClustTraj's documentation!
intro
install
clusttraj

usage
examples

Indices and tables
==================
Expand Down
2 changes: 2 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_parse_args():
force=True,
final_kabsch=True,
silhouette_score=False,
metrics=False,
)
clust_opt = parse_args(args)

Expand All @@ -107,6 +108,7 @@ def test_parse_args():
force=True,
final_kabsch=True,
silhouette_score=False,
metrics=False,
)
clust_opt = parse_args(args)

Expand Down

0 comments on commit 64147b9

Please sign in to comment.