-
Notifications
You must be signed in to change notification settings - Fork 292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for weaviate vector database #353
base: main
Are you sure you want to change the base?
Changes from 1 commit
8e7ae29
5b4cdda
4e5cca9
de6c83b
51f97c8
0a558c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,18 @@ | ||
from typing import Any, Dict, List | ||
|
||
import weaviate | ||
import weaviate.classes as wvc | ||
from langchain.embeddings.base import Embeddings | ||
from langchain_community.vectorstores.weaviate import Weaviate | ||
from langchain_weaviate.vectorstores import WeaviateVectorStore | ||
from langchain_core.documents import Document | ||
|
||
from backend.constants import DATA_POINT_FQN_METADATA_KEY | ||
from backend.modules.vector_db.base import BaseVectorDB | ||
from backend.types import DataPointVector, VectorDBConfig | ||
from backend.logger import logger | ||
|
||
BATCH_SIZE = 1000 | ||
MAX_SCROLL_LIMIT = int(1e6) | ||
|
||
def decapitalize(s): | ||
if not s: | ||
|
@@ -18,29 +22,56 @@ def decapitalize(s): | |
|
||
class WeaviateVectorDB(BaseVectorDB): | ||
def __init__(self, config: VectorDBConfig): | ||
self.url = config.url | ||
self.api_key = config.api_key | ||
self.weaviate_client = weaviate.Client( | ||
url=self.url, | ||
**( | ||
{"auth_client_secret": weaviate.AuthApiKey(api_key=self.api_key)} | ||
if self.api_key | ||
else {} | ||
logger.debug(f"[Weaviate] Connecting using config: {config.model_dump()}") | ||
if config.local is True: | ||
self.weaviate_client = weaviate.connect_to_local() | ||
else: | ||
self.weaviate_client = weaviate.connect_to_weaviate_cloud( | ||
cluster_url=config.url, | ||
auth_credentials=wvc.init.Auth.api_key(config.api_key) | ||
) | ||
|
||
def create_collection(self, collection_name: str, embeddings: Embeddings): | ||
logger.debug(f"[Weaviate] Creating new collection {collection_name}") | ||
self.weaviate_client.collections.create( | ||
name=collection_name.capitalize(), | ||
replication_config=wvc.config.Configure.replication( | ||
factor=1 | ||
), | ||
vectorizer_config=wvc.config.Configure.Vectorizer.none(), | ||
properties=[ | ||
wvc.config.Property(name=DATA_POINT_FQN_METADATA_KEY, data_type=wvc.config.DataType.TEXT) | ||
] | ||
) | ||
logger.debug(f"[Weaviate] Created new collection {collection_name}") | ||
|
||
def create_collection(self, collection_name: str, embeddings: Embeddings): | ||
self.weaviate_client.schema.create_class( | ||
{ | ||
"class": collection_name.capitalize(), | ||
"properties": [ | ||
{ | ||
"name": f"{DATA_POINT_FQN_METADATA_KEY}", | ||
"dataType": ["text"], | ||
}, | ||
], | ||
} | ||
def _get_records_to_be_updated(self, collection_name: str, data_point_fqns: List[str]): | ||
logger.debug( | ||
f"[Weaviate] Incremental Ingestion: Fetching documents for {len(data_point_fqns)} data point fqns for collection {collection_name}" | ||
) | ||
stop = False | ||
offset = 0 | ||
record_ids_to_be_updated = [] | ||
while stop is not True: | ||
records = self.weaviate_client.collections \ | ||
.get(collection_name.capitalize()).query \ | ||
.fetch_objects( | ||
limit=BATCH_SIZE, | ||
filters=wvc.query.Filter.by_property(DATA_POINT_FQN_METADATA_KEY).contains_any(data_point_fqns), | ||
offset=offset, | ||
return_properties=[DATA_POINT_FQN_METADATA_KEY] | ||
) | ||
if not records or len(records.objects) < BATCH_SIZE or len(record_ids_to_be_updated) > MAX_SCROLL_LIMIT: | ||
stop = True | ||
for record in records.objects: | ||
record_ids_to_be_updated.append(record.uuid) | ||
offset += BATCH_SIZE | ||
logger.debug( | ||
f"[Weaviate] Incremental Ingestion: collection={collection_name} Addition={len(data_point_fqns)}, Updates={len(record_ids_to_be_updated)}" | ||
) | ||
return record_ids_to_be_updated | ||
|
||
|
||
|
||
def upsert_documents( | ||
self, | ||
|
@@ -59,25 +90,59 @@ def upsert_documents( | |
Returns: | ||
- None | ||
""" | ||
Weaviate.from_documents( | ||
if len(documents) == 0: | ||
logger.warning("No documents to index") | ||
return | ||
logger.debug( | ||
f"[Weaviate] Adding {len(documents)} documents to collection {collection_name}" | ||
) | ||
|
||
data_point_fqns = [] | ||
for document in documents: | ||
if document.metadata.get(DATA_POINT_FQN_METADATA_KEY): | ||
data_point_fqns.append( | ||
document.metadata.get(DATA_POINT_FQN_METADATA_KEY) | ||
) | ||
records_to_be_updated:List[str] | ||
if incremental: | ||
records_to_be_updated = self._get_records_to_be_updated(collection_name, data_point_fqns) | ||
|
||
WeaviateVectorStore.from_documents( | ||
documents=documents, | ||
embedding=embeddings, | ||
client=self.weaviate_client, | ||
index_name=collection_name.capitalize(), | ||
) | ||
logger.debug( | ||
f"[Weaviate] Added {len(documents)} documents to collection {collection_name}" | ||
) | ||
|
||
if len(records_to_be_updated) > 0: | ||
logger.debug( | ||
f"[Weaviate] Deleting {len(records_to_be_updated)} outdated documents from collection {collection_name}" | ||
) | ||
collection = self.weaviate_client.collections.get(collection_name.capitalize()) | ||
for i in range(0, len(records_to_be_updated), BATCH_SIZE): | ||
record_ids_to_be_processed = records_to_be_updated[i : i + BATCH_SIZE] | ||
collection.data.delete_many( | ||
where=wvc.query.Filter.by_id().contains_any(record_ids_to_be_processed) | ||
) | ||
logger.debug( | ||
f"[Weaviate] Deleted {len(records_to_be_updated)} outdated documents from collection {collection_name}" | ||
) | ||
|
||
def get_collections(self) -> List[str]: | ||
collections = self.weaviate_client.schema.get().get("classes", []) | ||
return [decapitalize(collection["class"]) for collection in collections] | ||
collections = self.weaviate_client.collections.list_all(simple=True) | ||
return list(collections.keys()) | ||
|
||
def delete_collection( | ||
self, | ||
collection_name: str, | ||
): | ||
return self.weaviate_client.schema.delete_class(collection_name.capitalize()) | ||
return self.weaviate_client.collections.delete(collection_name.capitalize()) | ||
|
||
def get_vector_store(self, collection_name: str, embeddings: Embeddings): | ||
return Weaviate( | ||
return WeaviateVectorStore( | ||
client=self.weaviate_client, | ||
embedding=embeddings, | ||
index_name=collection_name.capitalize(), # Weaviate stores the index name as capitalized | ||
|
@@ -92,40 +157,13 @@ def list_documents_in_collection( | |
""" | ||
List all documents in a collection | ||
""" | ||
# https://weaviate.io/developers/weaviate/search/aggregate#retrieve-groupedby-properties | ||
response = ( | ||
self.weaviate_client.query.aggregate(collection_name.capitalize()) | ||
.with_group_by_filter([f"{DATA_POINT_FQN_METADATA_KEY}"]) | ||
.with_fields("groupedBy { value }") | ||
.do() | ||
) | ||
groups: List[Dict[Any, Any]] = ( | ||
response.get("data", {}) | ||
.get("Aggregate", {}) | ||
.get(collection_name.capitalize(), []) | ||
) | ||
document_ids = set() | ||
for group in groups: | ||
# TODO (chiragjn): Revisit this, we should not be letting `value` be empty | ||
document_ids.add(group.get("groupedBy", {}).get("value", "") or "") | ||
return list(document_ids) | ||
pass | ||
|
||
def delete_documents(self, collection_name: str, document_ids: List[str]): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function can be removed as it is not in the base class. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
""" | ||
Delete documents from the collection that match given `document_id_match` | ||
""" | ||
# https://weaviate.io/developers/weaviate/manage-data/delete#delete-multiple-objects | ||
res = self.weaviate_client.batch.delete_objects( | ||
class_name=collection_name.capitalize(), | ||
where={ | ||
"path": [f"{DATA_POINT_FQN_METADATA_KEY}"], | ||
"operator": "ContainsAny", | ||
"valueTextArray": document_ids, | ||
}, | ||
) | ||
deleted_vectors = res.get("results", {}).get("successful", None) | ||
if deleted_vectors: | ||
print(f"Deleted {len(document_ids)} documents from the collection") | ||
pass | ||
|
||
def get_vector_client(self): | ||
return self.weaviate_client | ||
|
@@ -136,7 +174,8 @@ def list_data_point_vectors( | |
data_source_fqn: str, | ||
batch_size: int = 1000, | ||
) -> List[DataPointVector]: | ||
pass | ||
document_vector_points: List[DataPointVector] = [] | ||
return document_vector_points | ||
|
||
def delete_data_point_vectors( | ||
self, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,4 @@ | |
singlestoredb==1.0.4 | ||
|
||
### Weaviate client (in progress) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can remove this in progress comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
weaviate-client==3.25.3 | ||
weaviate-client==4.7.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function can be removed as it is not in the base class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done