-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathastra.py
78 lines (70 loc) · 2.63 KB
/
astra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import List, Optional

from astrapy.db import AstraDB
from semantic_router.encoders import BaseEncoder
from tqdm import tqdm

from models.document import BaseDocumentChunk
from models.query import Filter
from vectordbs.base import BaseVectorDatabase
class AstraService(BaseVectorDatabase):
    """Vector-database adapter backed by DataStax Astra DB.

    Wraps an ``astrapy`` collection so document chunks can be upserted,
    queried by vector similarity, and deleted by source file URL.
    """

    def __init__(
        self, index_name: str, dimension: int, credentials: dict, encoder: BaseEncoder
    ):
        """Connect to Astra and bind the named collection, creating it if absent.

        Args:
            index_name: Name of the Astra collection to use.
            dimension: Dimensionality of the stored vectors.
            credentials: Must contain ``api_key`` (token) and ``host``
                (API endpoint URL).
            encoder: Encoder the base class uses to embed queries.
        """
        super().__init__(
            index_name=index_name,
            dimension=dimension,
            credentials=credentials,
            encoder=encoder,
        )
        self.client = AstraDB(
            token=credentials["api_key"],
            api_endpoint=credentials["host"],
        )
        # Create the collection on first use. The handle is always fetched
        # afterwards, so the create_collection return value is not kept
        # (the original assigned it to self.collection only to overwrite it).
        collections = self.client.get_collections()
        if self.index_name not in collections["status"]["collections"]:
            self.client.create_collection(
                dimension=dimension, collection_name=index_name
            )
        self.collection = self.client.collection(collection_name=self.index_name)

    # TODO: remove this
    async def convert_to_rerank_format(self, chunks: List[dict]) -> List[dict]:
        """Project raw chunk dicts into the shape the reranker expects."""
        return [
            {
                "content": chunk.get("text"),
                "page_label": chunk.get("page_label"),
                "file_url": chunk.get("file_url"),
            }
            for chunk in chunks
        ]

    async def upsert(
        self, chunks: List[BaseDocumentChunk], batch_size: int = 5
    ) -> None:
        """Insert document chunks into the collection in small batches.

        Args:
            chunks: Chunks whose id, content, dense embedding, and metadata
                are written as Astra documents.
            batch_size: Documents per ``insert_many`` call. The default of 5
                preserves the previous hard-coded behavior.
        """
        documents = [
            {
                "_id": chunk.id,
                "text": chunk.content,
                "$vector": chunk.dense_embedding,
                **chunk.metadata,
            }
            for chunk in tqdm(chunks, desc="Upserting to Astra")
        ]
        # Handles the empty-chunks case naturally: range(0, 0, n) is empty.
        for start in range(0, len(documents), batch_size):
            self.collection.insert_many(
                documents=documents[start : start + batch_size]
            )

    async def query(
        self, input: str, filter: Optional[Filter] = None, top_k: int = 4
    ) -> List[BaseDocumentChunk]:
        """Embed ``input`` and return the ``top_k`` most similar chunks.

        Args:
            input: Query text to embed.
            filter: Optional metadata filter forwarded to Astra as-is.
                NOTE(review): confirm ``Filter`` serializes to the dict
                shape ``vector_find`` expects.
            top_k: Maximum number of results to return.

        Returns:
            Matching chunks reconstructed from the stored fields.
        """
        vectors = await self._generate_vectors(input=input)
        results = self.collection.vector_find(
            vector=vectors[0],
            limit=top_k,
            fields={"text", "page_number", "source", "document_id"},
            filter=filter,
        )
        return [
            BaseDocumentChunk(
                id=result.get("_id"),
                document_id=result.get("document_id"),
                content=result.get("text"),
                doc_url=result.get("source"),
                page_number=result.get("page_number"),
            )
            for result in results
        ]

    async def delete(self, file_url: str) -> None:
        """Remove every document whose ``file_url`` metadata matches."""
        self.collection.delete_many(filter={"file_url": file_url})