HybridSearchDemo.py
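"""Demo of hybrid (dense + sparse) search over an in-memory Qdrant collection with
LlamaIndex: chunks are enriched with document-level context by
DocumentContextExtractor, and retrieved results are re-scored with an LLM reranker.
"""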
import os

from llama_index.core import VectorStoreIndex, StorageContext, SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.storage.index_store.simple_index_store import SimpleIndexStore
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient, AsyncQdrantClient

from DocumentContextExtractor import DocumentContextExtractor

# TODO: add 'query context' to this
class HybridSearchWithContext:
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 50
    SIMILARITY_TOP_K = 10
    SPARSE_TOP_K = 20
    RERANKER_TOP_N = 3

    def __init__(self, name: str):
        """
        :param name: The name of the index, required for the underlying vector store
        """
        # Initialize Qdrant clients (in-memory, so vectors are not persisted between runs)
        client = QdrantClient(":memory:")
        aclient = AsyncQdrantClient(":memory:")

        self.index_store_path = f"{name}"
        if not os.path.exists(self.index_store_path):
            os.makedirs(self.index_store_path)

        # Set up the models
        self.context_llm = OpenAI(model="gpt-4o-mini")
        self.answering_llm = OpenAI(model="gpt-4o-mini")
        self.embed_model = OpenAIEmbedding(model="text-embedding-3-small")

        # Embed a sample string once to discover the embedding dimension
        sample_embedding = self.embed_model.get_query_embedding("sample text")
        self.embed_size = len(sample_embedding)

        self.reranker = LLMRerank(
            choice_batch_size=5,
            top_n=self.RERANKER_TOP_N,
            llm=self.context_llm
        )

        # Create the hybrid (dense + sparse) vector store
        self.vector_store = QdrantVectorStore(
            name,
            client=client,
            aclient=aclient,
            enable_hybrid=True,
            batch_size=20,
            dim=self.embed_size
        )

        # Initialize the storage context, reloading the index store if it was persisted earlier
        if os.path.exists(os.path.join(self.index_store_path, "index_store.json")):
            index_store = SimpleIndexStore.from_persist_dir(persist_dir=self.index_store_path)
        else:
            index_store = SimpleIndexStore()
        self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store,
                                                            index_store=index_store)

        # Create the text splitter
        self.text_splitter = SentenceSplitter(
            chunk_size=self.CHUNK_SIZE,
            chunk_overlap=self.CHUNK_OVERLAP
        )

        # DocumentContextExtractor requires a document store; the docstore and llm
        # arguments are required. max_contextual_tokens plus chunk_size should stay
        # a little below your embedding model's maximum input size to leave headroom.
        self.document_context_extractor = DocumentContextExtractor(
            docstore=self.storage_context.docstore,
            llm=self.context_llm,
            max_context_length=128000,
            max_contextual_tokens=512,
            oversized_document_strategy="truncate_first"
        )

        self.index = VectorStoreIndex.from_vector_store(
            vector_store=self.vector_store,
            embed_model=self.embed_model,
            storage_context=self.storage_context,
            transformations=[self.text_splitter, self.document_context_extractor]
        )
        self.storage_context.persist(persist_dir=self.index_store_path)
    def add_directory(self, directory):
        reader = SimpleDirectoryReader(directory)
        documents = reader.load_data()

        # Register whole documents in the docstore first so the
        # DocumentContextExtractor can find them, then insert their chunks
        self.storage_context.docstore.add_documents(documents)
        for doc in documents:
            self.index.insert(doc)

        # Rebuild the query engine and retriever so they see the new documents
        self.query_engine = self.index.as_query_engine(
            similarity_top_k=self.SIMILARITY_TOP_K,
            sparse_top_k=self.SPARSE_TOP_K,
            vector_store_query_mode="hybrid",
            llm=self.answering_llm,
            node_postprocessors=[self.reranker]
        )
        self.retriever = self.index.as_retriever(
            similarity_top_k=self.SIMILARITY_TOP_K,
            sparse_top_k=self.SPARSE_TOP_K,
            vector_store_query_mode="hybrid"
        )
        self.storage_context.persist(persist_dir=self.index_store_path)
    def get_raw_search_results(self, question):
        # Hybrid retrieval only, with no reranking or answer synthesis;
        # requires add_directory() to have been called first
        retrieved_nodes = self.retriever.retrieve(question)
        return retrieved_nodes
    def query(self, question):
        # Full RAG query through the reranked hybrid query engine;
        # requires add_directory() to have been called first
        response = self.query_engine.query(question)
        return response

if __name__ == '__main__':
    from dotenv import load_dotenv
    load_dotenv()

    hybrid_search = HybridSearchWithContext(name="hybriddemo")
    hybrid_search.add_directory("./data")

    question = "Why was this document written?"
    print(hybrid_search.get_raw_search_results(question))
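    # End-to-end answer via the reranked hybrid query engine defined above
    print(hybrid_search.query(question))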