Skip to content

Commit

Permalink
Revert "refactor: code changes documents -> items"
Browse files Browse the repository at this point in the history
This reverts commit bcd4e8c.
  • Loading branch information
garuna-m6 committed Oct 14, 2024
1 parent 283f79b commit 6e0f2bd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions docetl/operations/clustering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def get_embeddings_for_clustering_with_st(
return embeddings, 0


def cluster_items(
items: List[Dict],
def cluster_documents(
documents: List[Dict],
sampling_config: Dict,
sample_size: int,
api_wrapper: APIWrapper,
Expand All @@ -70,25 +70,25 @@ def cluster_items(
Cluster documents using KMeans clustering algorithm.
Args:
items (List[Dict]): The list of documents to cluster.
documents (List[Dict]): The list of documents to cluster.
sampling_config (Dict): The sampling configuration. May contain embedding_keys and embedding_model. If embedding_keys is not specified, it will use all keys in the document. If embedding_model is not specified, it will use text-embedding-3-small. If embedding_model is sentence-transformer, it will use all-MiniLM-L6-v2.
sample_size (int): The number of clusters to create (capped at the number of documents).
api_wrapper (APIWrapper): The API wrapper to use for embedding.
Returns:
Tuple of (Dict[int, List[Dict]], cost): A dictionary of clusters, where each cluster is a list of documents, together with the cost returned by the embedding step.
"""
embeddings, cost = get_embeddings_for_clustering(
items, sampling_config, api_wrapper
documents, sampling_config, api_wrapper
)

from sklearn.cluster import KMeans

num_clusters = min(sample_size, len(items))
num_clusters = min(sample_size, len(documents))
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(embeddings)

clusters = {i: [] for i in range(num_clusters)}
for idx, label in enumerate(cluster_labels):
clusters[label].append(items[idx])
clusters[label].append(documents[idx])

return clusters, cost
4 changes: 2 additions & 2 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from docetl.operations.base import BaseOperation
from docetl.operations.clustering_utils import (
cluster_items,
cluster_documents,
get_embeddings_for_clustering,
)
from docetl.operations.utils import rich_as_completed
Expand Down Expand Up @@ -428,7 +428,7 @@ def _cluster_based_sampling(
if sample_size >= len(group_list):
return group_list, 0

clusters, cost = cluster_items(
clusters, cost = cluster_documents(
group_list, value_sampling, sample_size, self.runner.api
)

Expand Down

0 comments on commit 6e0f2bd

Please sign in to comment.