Skip to content

Commit

Permalink
Added a t-SNE plot
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelbicudo committed Oct 14, 2023
1 parent 74e55f1 commit 2bac548
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 13 deletions.
2 changes: 1 addition & 1 deletion clusttraj/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def parse_args(args: argparse.Namespace) -> ClustOptions:
"plot": bool(args.plot),
"evo_name": basenameout + "_evo.pdf" if args.plot else None,
"dendrogram_name": basenameout + "_dendrogram.pdf" if args.plot else None,
"mds_name": basenameout + ".pdf" if args.plot else None,
"mds_name": basenameout + "_mds.pdf" if args.plot else None,
"trajfile": args.trajectory_file,
"min_rmsd": args.min_rmsd,
"method": args.method,
Expand Down
4 changes: 3 additions & 1 deletion clusttraj/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import List
from .io import Logger, configure_runtime, save_clusters_config
from .distmat import get_distmat
from .plot import plot_clust_evo, plot_dendrogram, plot_mds
from .plot import plot_clust_evo, plot_dendrogram, plot_mds, plot_tsne
from .classify import classify_structures, classify_structures_silhouette
from .metrics import compute_metrics

Expand Down Expand Up @@ -66,6 +66,8 @@ def main(args: List[str] = None) -> None:

plot_mds(clust_opt, clusters, distmat)

plot_tsne(clust_opt, clusters, distmat)

# print the cluster sizes
outclust_str = f"A total {len(clusters)} snapshots were read and {max(clusters)} cluster(s) was(were) found.\n"
outclust_str += "The cluster sizes are:\nCluster\tSize\n"
Expand Down
17 changes: 7 additions & 10 deletions clusttraj/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,30 @@
def compute_metrics(
    clust_opt: ClustOptions, distmat: np.ndarray, z_matrix: np.ndarray, clusters: np.ndarray
) -> Tuple[np.float64, np.float64, np.float64, np.float64]:
    """Compute metrics to assess the performance of the clustering procedure.

    Args:
        clust_opt (ClustOptions): The clustering options (kept for interface
            compatibility; not read by this function).
        distmat (np.ndarray): The pairwise distance matrix in condensed form.
            It is expanded with ``squareform`` before being handed to the
            scikit-learn scores.
        z_matrix (np.ndarray): The Z-matrix from hierarchical clustering procedure.
        clusters (np.ndarray): The cluster label assigned to each structure.

    Returns:
        ss (np.float64): The silhouette score.
        ch (np.float64): The Calinski Harabasz score.
        db (np.float64): The Davies-Bouldin score.
        cpcc (np.float64): The cophenetic correlation coefficient.
    """
    # Expand the condensed distances once and reuse the square matrix below.
    square_distmat = squareform(distmat)

    # Compute the silhouette score (the square matrix holds precomputed distances)
    ss = silhouette_score(square_distmat, clusters, metric="precomputed")

    # Compute the Calinski Harabasz score
    # NOTE(review): this score has no "precomputed" mode, so the rows of the
    # distance matrix are used as feature vectors here -- confirm this is intended.
    ch = calinski_harabasz_score(square_distmat, clusters)

    # Compute the Davies-Bouldin score (same remark as above applies)
    db = davies_bouldin_score(square_distmat, clusters)

    # Compute the cophenetic correlation coefficient.
    # The original distances must be passed as the second argument: without
    # them scipy returns only the cophenetic distance array, and indexing it
    # with [0] would yield a single distance instead of the coefficient.
    cpcc = cophenet(z_matrix, distmat)[0]

    return ss, ch, db, cpcc

59 changes: 58 additions & 1 deletion clusttraj/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,69 @@ def plot_mds(clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray)
right=False,
labelbottom=False,
labelleft=False,
)
)

# Scatter plot the coordinates with cluster colors
plt.scatter(
coords[:, 0], coords[:, 1], marker="o", c=clusters, cmap=plt.cm.nipy_spectral
)

plt.title('MDS Visualization')

# Save the plot
plt.savefig(clust_opt.mds_name, bbox_inches="tight")


def plot_tsne(clust_opt: ClustOptions, clusters: np.ndarray, distmat: np.ndarray) -> None:
    """Plot the t-distributed Stochastic Neighbor Embedding 2D plot of the clustering.

    The embedding is computed from the square form of ``distmat``, every
    cluster is drawn with its own color, and the figure is written next to
    the MDS plot (same base name, ``tsne.pdf`` suffix).

    Args:
        clust_opt (ClustOptions): The clustering options. ``n_workers`` and
            ``mds_name`` are read from it.
        clusters (np.ndarray): The cluster labels.
        distmat (np.ndarray): The distance matrix.

    Returns:
        None
    """
    square_distmat = squareform(distmat)

    # scikit-learn requires the perplexity to be smaller than the number of
    # samples; clamp the default of 30 so short trajectories do not raise.
    n_samples = square_distmat.shape[0]
    perplexity = min(30, max(n_samples - 1, 1))

    # Initialize the tSNE model
    # NOTE(review): the rows of the distance matrix are used as feature
    # vectors here (no metric="precomputed") -- confirm this is intended.
    tsne = manifold.TSNE(
        n_components=2,
        perplexity=perplexity,
        learning_rate=200,
        random_state=666,
        n_jobs=clust_opt.n_workers,
    )

    # Perform the t-SNE and get the 2D representation
    coords = tsne.fit_transform(square_distmat)

    # Define a list of unique colors for each cluster
    unique_clusters = np.unique(clusters)
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_clusters)))

    # Create a new figure
    plt.figure()

    # Hide ticks and tick labels: the embedding axes carry no physical meaning
    plt.tick_params(
        axis="both",
        which="both",
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,
        labelleft=False,
    )

    # Create a scatter plot with different colors for each cluster
    for i, cluster in enumerate(unique_clusters):
        cluster_data = coords[clusters == cluster]
        plt.scatter(cluster_data[:, 0], cluster_data[:, 1], color=colors[i])

    plt.title('t-SNE Visualization')

    # Derive the output name from the MDS plot name: "<base>_mds.pdf" becomes
    # "<base>_tsne.pdf". Fall back to replacing the extension if the MDS name
    # does not follow that pattern, instead of blindly slicing characters off.
    mds_suffix = "mds.pdf"
    mds_name = clust_opt.mds_name
    if mds_name.endswith(mds_suffix):
        tsne_name = mds_name[: -len(mds_suffix)] + "tsne.pdf"
    else:
        tsne_name = mds_name.rsplit(".", 1)[0] + "_tsne.pdf"

    # Save the plot
    plt.savefig(tsne_name, bbox_inches="tight")

0 comments on commit 2bac548

Please sign in to comment.