diff --git a/docetl/operations/clustering_utils.py b/docetl/operations/clustering_utils.py index 05fed867..7663b892 100644 --- a/docetl/operations/clustering_utils.py +++ b/docetl/operations/clustering_utils.py @@ -60,8 +60,8 @@ def get_embeddings_for_clustering_with_st( return embeddings, 0 -def cluster_items( - items: List[Dict], +def cluster_documents( + documents: List[Dict], sampling_config: Dict, sample_size: int, api_wrapper: APIWrapper, @@ -70,7 +70,7 @@ def cluster_items( Cluster documents using KMeans clustering algorithm. Args: - items (List[Dict]): The list of documents to cluster. + documents (List[Dict]): The list of documents to cluster. sampling_config (Dict): The sampling configuration. Must contain embedding_model. If embedding_keys is not specified, it will use all keys in the document. If embedding_model is not specified, it will use text-embedding-3-small. If embedding_model is sentence-transformer, it will use all-MiniLM-L6-v2. sample_size (int): The number of clusters to create. api_wrapper (APIWrapper): The API wrapper to use for embedding. @@ -78,17 +78,17 @@ def cluster_items( Dict[int, List[Dict]]: A dictionary of clusters, where each cluster is a list of documents. """ embeddings, cost = get_embeddings_for_clustering( - items, sampling_config, api_wrapper + documents, sampling_config, api_wrapper ) from sklearn.cluster import KMeans - num_clusters = min(sample_size, len(items)) + num_clusters = min(sample_size, len(documents)) kmeans = KMeans(n_clusters=num_clusters, random_state=42) cluster_labels = kmeans.fit_predict(embeddings) clusters = {i: [] for i in range(num_clusters)} for idx, label in enumerate(cluster_labels): - clusters[label].append(items[idx]) + clusters[label].append(documents[idx]) return clusters, cost diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index 5d7f5012..cd3cee78 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -20,7 +20,7 @@ from docetl.operations.base import BaseOperation from docetl.operations.clustering_utils import ( - cluster_items, + cluster_documents, get_embeddings_for_clustering, ) from docetl.operations.utils import rich_as_completed @@ -428,7 +428,7 @@ def _cluster_based_sampling( if sample_size >= len(group_list): return group_list, 0 - clusters, cost = cluster_items( + clusters, cost = cluster_documents( group_list, value_sampling, sample_size, self.runner.api )