From 6e0f2bd3ea50322e02f3758cd66573d1452a192b Mon Sep 17 00:00:00 2001 From: Anurag Date: Mon, 14 Oct 2024 13:28:52 +0530 Subject: [PATCH] Revert "refactor: code changes documents -> items" This reverts commit bcd4e8cb9f67aa7f5861be69a459b68385e347c9. --- docetl/operations/clustering_utils.py | 12 ++++++------ docetl/operations/reduce.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docetl/operations/clustering_utils.py b/docetl/operations/clustering_utils.py index 05fed867..7663b892 100644 --- a/docetl/operations/clustering_utils.py +++ b/docetl/operations/clustering_utils.py @@ -60,8 +60,8 @@ def get_embeddings_for_clustering_with_st( return embeddings, 0 -def cluster_items( - items: List[Dict], +def cluster_documents( + documents: List[Dict], sampling_config: Dict, sample_size: int, api_wrapper: APIWrapper, @@ -70,7 +70,7 @@ def cluster_items( Cluster documents using KMeans clustering algorithm. Args: - items (List[Dict]): The list of documents to cluster. + documents (List[Dict]): The list of documents to cluster. sampling_config (Dict): The sampling configuration. Must contain embedding_model. If embedding_keys is not specified, it will use all keys in the document. If embedding_model is not specified, it will use text-embedding-3-small. If embedding_model is sentence-transformer, it will use all-MiniLM-L6-v2. sample_size (int): The number of clusters to create. api_wrapper (APIWrapper): The API wrapper to use for embedding. @@ -78,17 +78,17 @@ def cluster_items( Dict[int, List[Dict]]: A dictionary of clusters, where each cluster is a list of documents. """ embeddings, cost = get_embeddings_for_clustering( - items, sampling_config, api_wrapper + documents, sampling_config, api_wrapper ) from sklearn.cluster import KMeans - num_clusters = min(sample_size, len(items)) + num_clusters = min(sample_size, len(documents)) kmeans = KMeans(n_clusters=num_clusters, random_state=42) cluster_labels = kmeans.fit_predict(embeddings) clusters = {i: [] for i in range(num_clusters)} for idx, label in enumerate(cluster_labels): - clusters[label].append(items[idx]) + clusters[label].append(documents[idx]) return clusters, cost diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index 5d7f5012..cd3cee78 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -20,7 +20,7 @@ from docetl.operations.base import BaseOperation from docetl.operations.clustering_utils import ( - cluster_items, + cluster_documents, get_embeddings_for_clustering, ) from docetl.operations.utils import rich_as_completed @@ -428,7 +428,7 @@ def _cluster_based_sampling( if sample_size >= len(group_list): return group_list, 0 - clusters, cost = cluster_items( + clusters, cost = cluster_documents( group_list, value_sampling, sample_size, self.runner.api )