diff --git a/backend/modules/vector_db/weaviate.py b/backend/modules/vector_db/weaviate.py index 2aad43e5..94ceb8ac 100644 --- a/backend/modules/vector_db/weaviate.py +++ b/backend/modules/vector_db/weaviate.py @@ -1,14 +1,18 @@ from typing import Any, Dict, List import weaviate +import weaviate.classes as wvc from langchain.embeddings.base import Embeddings -from langchain_community.vectorstores.weaviate import Weaviate +from langchain_weaviate.vectorstores import WeaviateVectorStore from langchain_core.documents import Document from backend.constants import DATA_POINT_FQN_METADATA_KEY from backend.modules.vector_db.base import BaseVectorDB from backend.types import DataPointVector, VectorDBConfig +from backend.logger import logger +BATCH_SIZE = 1000 +MAX_SCROLL_LIMIT = int(1e6) def decapitalize(s): if not s: @@ -18,29 +22,54 @@ def decapitalize(s): class WeaviateVectorDB(BaseVectorDB): def __init__(self, config: VectorDBConfig): - self.url = config.url - self.api_key = config.api_key - self.weaviate_client = weaviate.Client( - url=self.url, - **( - {"auth_client_secret": weaviate.AuthApiKey(api_key=self.api_key)} - if self.api_key - else {} + logger.debug(f"[Weaviate] Connecting using config: {config.model_dump()}") + if config.local is True: + self.weaviate_client = weaviate.connect_to_local() + else: + self.weaviate_client = weaviate.connect_to_weaviate_cloud( + cluster_url=config.url, + auth_credentials=wvc.init.Auth.api_key(config.api_key) + ) + + def create_collection(self, collection_name: str, embeddings: Embeddings): + logger.debug(f"[Weaviate] Creating new collection {collection_name}") + self.weaviate_client.collections.create( + name=collection_name.capitalize(), + replication_config=wvc.config.Configure.replication( + factor=1 ), + vectorizer_config=wvc.config.Configure.Vectorizer.none(), + properties=[ + wvc.config.Property(name=DATA_POINT_FQN_METADATA_KEY, data_type=wvc.config.DataType.TEXT) + ] ) + logger.debug(f"[Weaviate] Created new collection {collection_name}") - def create_collection(self, collection_name: str, embeddings: Embeddings): - self.weaviate_client.schema.create_class( - { - "class": collection_name.capitalize(), - "properties": [ - { - "name": f"{DATA_POINT_FQN_METADATA_KEY}", - "dataType": ["text"], - }, - ], - } + def _get_records_to_be_updated(self, collection_name: str, data_point_fqns: List[str]): + logger.debug( + f"[Weaviate] Incremental Ingestion: Fetching documents for {len(data_point_fqns)} data point fqns for collection {collection_name}" ) + stop = False + offset = 0 + record_ids_to_be_updated = [] + while stop is not True: + records = self.weaviate_client.collections \ + .get(collection_name.capitalize()).query \ + .fetch_objects( + limit=BATCH_SIZE, + filters=wvc.query.Filter.by_property(DATA_POINT_FQN_METADATA_KEY).contains_any(data_point_fqns), + offset=offset, + return_properties=[DATA_POINT_FQN_METADATA_KEY] + ) + if not records or len(records.objects) < BATCH_SIZE or len(record_ids_to_be_updated) > MAX_SCROLL_LIMIT: + stop = True + for record in records.objects: + record_ids_to_be_updated.append(record.uuid) + offset += BATCH_SIZE + logger.debug( + f"[Weaviate] Incremental Ingestion: collection={collection_name} Addition={len(data_point_fqns)}, Updates={len(record_ids_to_be_updated)}" + ) + return record_ids_to_be_updated def upsert_documents( self, @@ -59,25 +88,59 @@ def upsert_documents( Returns: - None """ - Weaviate.from_documents( + if len(documents) == 0: + logger.warning("No documents to index") + return + logger.debug( + f"[Weaviate] Adding {len(documents)} documents to collection {collection_name}" + ) + + data_point_fqns = [] + for document in documents: + if document.metadata.get(DATA_POINT_FQN_METADATA_KEY): + data_point_fqns.append( + document.metadata.get(DATA_POINT_FQN_METADATA_KEY) + ) + records_to_be_updated:List[str] + if incremental: + records_to_be_updated = self._get_records_to_be_updated(collection_name, data_point_fqns) + + WeaviateVectorStore.from_documents( documents=documents, embedding=embeddings, client=self.weaviate_client, index_name=collection_name.capitalize(), ) + logger.debug( + f"[Weaviate] Added {len(documents)} documents to collection {collection_name}" + ) + + if len(records_to_be_updated) > 0: + logger.debug( + f"[Weaviate] Deleting {len(records_to_be_updated)} outdated documents from collection {collection_name}" + ) + collection = self.weaviate_client.collections.get(collection_name.capitalize()) + for i in range(0, len(records_to_be_updated), BATCH_SIZE): + record_ids_to_be_processed = records_to_be_updated[i : i + BATCH_SIZE] + collection.data.delete_many( + where=wvc.query.Filter.by_id().contains_any(record_ids_to_be_processed) + ) + logger.debug( + f"[Weaviate] Deleted {len(records_to_be_updated)} outdated documents from collection {collection_name}" + ) def get_collections(self) -> List[str]: - collections = self.weaviate_client.schema.get().get("classes", []) - return [decapitalize(collection["class"]) for collection in collections] + collections = self.weaviate_client.collections.list_all(simple=True) + return list(collections.keys()) def delete_collection( self, collection_name: str, ): - return self.weaviate_client.schema.delete_class(collection_name.capitalize()) + return self.weaviate_client.collections.delete(collection_name.capitalize()) def get_vector_store(self, collection_name: str, embeddings: Embeddings): - return Weaviate( + return WeaviateVectorStore( client=self.weaviate_client, embedding=embeddings, index_name=collection_name.capitalize(), # Weaviate stores the index name as capitalized @@ -86,47 +149,6 @@ def get_vector_store(self, collection_name: str, embeddings: Embeddings): attributes=[f"{DATA_POINT_FQN_METADATA_KEY}"], ) - def list_documents_in_collection( - self, collection_name: str, base_document_id: str = None - ) -> List[str]: - """ - List all documents in a collection - """ - # https://weaviate.io/developers/weaviate/search/aggregate#retrieve-groupedby-properties - response = ( - self.weaviate_client.query.aggregate(collection_name.capitalize()) - .with_group_by_filter([f"{DATA_POINT_FQN_METADATA_KEY}"]) - .with_fields("groupedBy { value }") - .do() - ) - groups: List[Dict[Any, Any]] = ( - response.get("data", {}) - .get("Aggregate", {}) - .get(collection_name.capitalize(), []) - ) - document_ids = set() - for group in groups: - # TODO (chiragjn): Revisit this, we should not be letting `value` be empty - document_ids.add(group.get("groupedBy", {}).get("value", "") or "") - return list(document_ids) - - def delete_documents(self, collection_name: str, document_ids: List[str]): - """ - Delete documents from the collection that match given `document_id_match` - """ - # https://weaviate.io/developers/weaviate/manage-data/delete#delete-multiple-objects - res = self.weaviate_client.batch.delete_objects( - class_name=collection_name.capitalize(), - where={ - "path": [f"{DATA_POINT_FQN_METADATA_KEY}"], - "operator": "ContainsAny", - "valueTextArray": document_ids, - }, - ) - deleted_vectors = res.get("results", {}).get("successful", None) - if deleted_vectors: - print(f"Deleted {len(document_ids)} documents from the collection") - def get_vector_client(self): return self.weaviate_client @@ -136,7 +158,8 @@ def list_data_point_vectors( data_source_fqn: str, batch_size: int = 1000, ) -> List[DataPointVector]: - pass + document_vector_points: List[DataPointVector] = [] + return document_vector_points def delete_data_point_vectors( self, diff --git a/backend/requirements.txt b/backend/requirements.txt index f443bc00..a8f00239 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -3,6 +3,7 @@ langchain==0.1.9 langchain-community==0.0.24 langchain-openai==0.1.7 langchain-core==0.1.46 +langchain-weaviate==0.0.3 openai==1.35.3 tiktoken==0.7.0 uvicorn[standard]==0.23.2 @@ -14,7 +15,7 @@ pydantic-settings==2.3.3 orjson==3.9.15 PyMuPDF==1.23.6 beautifulsoup4==4.12.2 -truefoundry[ml]==0.3.1 +truefoundry==0.4.1 markdownify==0.11.6 gunicorn==22.0.0 markdown-crawler==0.0.8 diff --git a/backend/vectordb.requirements.txt b/backend/vectordb.requirements.txt index aeaa4bef..c754cdff 100644 --- a/backend/vectordb.requirements.txt +++ b/backend/vectordb.requirements.txt @@ -1,5 +1,5 @@ #### singlestore db singlestoredb==1.0.4 -### Weaviate client (in progress) -weaviate-client==3.25.3 +### Weaviate client +weaviate-client==4.7.1