# base.py
from abc import ABC, abstractmethod
from typing import List

from decouple import config
from semantic_router.encoders import BaseEncoder
from tqdm import tqdm

from models.delete import DeleteResponse
from models.document import BaseDocumentChunk
from models.query import Filter
from utils.logger import logger


class BaseVectorDatabase(ABC):
    """Abstract interface that every vector-store backend must implement."""

    def __init__(
        self, index_name: str, dimension: int, credentials: dict, encoder: BaseEncoder
    ):
        self.index_name = index_name
        self.dimension = dimension
        self.credentials = credentials
        self.encoder = encoder

    @abstractmethod
    async def upsert(self, chunks: List[BaseDocumentChunk]):
        pass

    @abstractmethod
    async def query(
        self, input: str, filter: Filter, top_k: int = 25
    ) -> List[BaseDocumentChunk]:
        pass

    @abstractmethod
    async def delete(self, file_url: str) -> DeleteResponse:
        pass

    async def _generate_vectors(self, input: str) -> List[List[float]]:
        # Delegates embedding to the configured encoder; encoders take a
        # batch of strings and return one vector per string.
        return self.encoder([input])
    async def rerank(
        self, query: str, documents: list[BaseDocumentChunk], top_n: int = 5
    ) -> list[BaseDocumentChunk]:
        from cohere import Client

        api_key = config("COHERE_API_KEY")
        if not api_key:
            raise ValueError("API key for Cohere is not present.")
        cohere_client = Client(api_key=api_key)

        # Deduplicate documents by content while preserving order.
        # TODO: fix ingestion so duplicate chunks never get this far.
        seen = set()
        deduplicated_documents = [
            doc
            for doc in documents
            if doc.content not in seen and not seen.add(doc.content)
        ]
        docs_text = [
            doc.content
            for doc in tqdm(
                deduplicated_documents,
                desc=f"Reranking {len(deduplicated_documents)} documents",
            )
        ]
        try:
            re_ranked = cohere_client.rerank(
                model="rerank-multilingual-v2.0",
                query=query,
                documents=docs_text,
                top_n=top_n,
            ).results
            # Each reranked result carries the index of the document it was
            # built from, so map it back to the deduplicated chunk.
            results = []
            for r in tqdm(re_ranked, desc="Processing reranked results"):
                results.append(deduplicated_documents[r.index])
            return results
        except Exception as e:
            logger.error(f"Error while reranking: {e}")
            raise Exception(f"Error while reranking: {e}") from e
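

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal in-memory subclass showing how the abstract interface above can
# be satisfied. Everything named here is a hypothetical simplification for
# illustration: InMemoryVectorDatabase itself, the `doc_url` attribute read
# in delete(), and the `num_of_deleted_chunks` field passed to
# DeleteResponse are assumptions, not guarantees about the real models.

class InMemoryVectorDatabase(BaseVectorDatabase):
    """Toy backend that keeps chunks in a plain list; no real vector math."""

    def __init__(self, index_name: str, dimension: int, encoder: BaseEncoder):
        super().__init__(index_name, dimension, credentials={}, encoder=encoder)
        self._chunks: List[BaseDocumentChunk] = []

    async def upsert(self, chunks: List[BaseDocumentChunk]):
        self._chunks.extend(chunks)

    async def query(
        self, input: str, filter: Filter, top_k: int = 25
    ) -> List[BaseDocumentChunk]:
        # A real backend would embed `input` via self._generate_vectors()
        # and run a similarity search; this sketch just returns the first
        # top_k chunks and ignores `filter`.
        return self._chunks[:top_k]

    async def delete(self, file_url: str) -> DeleteResponse:
        before = len(self._chunks)
        self._chunks = [
            c for c in self._chunks if getattr(c, "doc_url", None) != file_url
        ]
        # `num_of_deleted_chunks` is an assumed field name on DeleteResponse.
        return DeleteResponse(num_of_deleted_chunks=before - len(self._chunks))


# Usage sketch (assumes `my_encoder` is a configured BaseEncoder and
# `chunks` is a list of BaseDocumentChunk instances):
#     db = InMemoryVectorDatabase("demo", dimension=384, encoder=my_encoder)
#     asyncio.run(db.upsert(chunks))
#     top = asyncio.run(db.rerank("what is a vector db?", chunks, top_n=3))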