
Commit

fix: Updates faiss store to handle multiple indices and dynamically get embedding dim

Signed-off-by: ishaansehgal99 <[email protected]>
ishaansehgal99 committed Sep 24, 2024
1 parent d82897d commit cd9cbab
Showing 4 changed files with 84 additions and 66 deletions.
22 changes: 11 additions & 11 deletions presets/rag_service/vector_store/base.py
@@ -6,45 +6,45 @@

class BaseVectorStore(ABC):
@abstractmethod
def index_documents(self, documents: List[Document]) -> List[str]:
def index_documents(self, documents: List[Document], index_name: str) -> List[str]:
pass

@abstractmethod
def query(self, query: str, top_k: int):
def query(self, query: str, top_k: int, index_name: str):
pass

@abstractmethod
def add_document(self, document: Document):
def add_document(self, document: Document, index_name: str):
pass

@abstractmethod
def delete_document(self, doc_id: str):
def delete_document(self, doc_id: str, index_name: str):
pass

@abstractmethod
def update_document(self, document: Document) -> str:
def update_document(self, document: Document, index_name: str) -> str:
pass

@abstractmethod
def get_document(self, doc_id: str) -> Document:
def get_document(self, doc_id: str, index_name: str) -> Document:
pass

@abstractmethod
def list_documents(self) -> Dict[str, Document]:
def list_documents(self, index_name: str) -> Dict[str, Document]:
pass

@abstractmethod
def document_exists(self, doc_id: str) -> bool:
def document_exists(self, doc_id: str, index_name: str) -> bool:
pass

@abstractmethod
def refresh_documents(self, documents: List[Document]) -> List[bool]:
def refresh_documents(self, documents: List[Document], index_name: str) -> List[bool]:
pass

@abstractmethod
def list_documents(self) -> Dict[str, Document]:
def list_documents(self, index_name: str) -> Dict[str, Document]:
pass

@abstractmethod
def document_exists(self, doc_id: str) -> bool:
def document_exists(self, doc_id: str, index_name: str) -> bool:
pass
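
For orientation, here is a minimal usage sketch of the index-name-aware interface against the FAISS implementation further down. Module paths and the embedding setup are assumptions, not part of this commit; any embed model that exposes get_embedding_dimension(), as the constructor below requires, would do.

# Assumed imports; the actual module paths in this repo may differ.
from rag_service.models import Document
from rag_service.vector_store.faiss_store import FaissVectorStoreManager

# embed_model: any embedding wrapper exposing get_embedding_dimension().
store = FaissVectorStoreManager(embed_model=embed_model)

docs = [Document(text="FAISS stores dense vectors.", metadata={"source": "notes"}, doc_id="doc-1")]
store.index_documents(docs, index_name="notes")  # creates (or overwrites) the "notes" index
response = store.query("What stores dense vectors?", top_k=1, index_name="notes")
assert store.document_exists("doc-1", index_name="notes")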
9 changes: 5 additions & 4 deletions presets/rag_service/vector_store/chromadb_store.py
@@ -22,15 +22,16 @@ def __init__(self, embed_model):
self.chroma_collection = self.chroma_client.create_collection(self.collection_name)
self.vector_store = ChromaVectorStore(chroma_collection=self.chroma_collection)
self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
self.index = None # Use to store the in-memory index # TODO: Multiple indexes via name (e.g. namespace)
self.indices = {} # Used to store in-memory indices by namespace (e.g. namespace -> index)

if not os.path.exists(PERSIST_DIR):
os.makedirs(PERSIST_DIR)

def index_documents(self, documents: List[Document]):
"""Recreates the entire ChromaDB index and vector store with new documents."""
def index_documents(self, documents: List[Document], index_name: str):
"""Recreates the entire FAISS index and vector store with new documents."""
llama_docs = [LlamaDocument(text=doc.text, metadata=doc.metadata, id_=doc.doc_id) for doc in documents]
self.index = VectorStoreIndex.from_documents(llama_docs, storage_context=self.storage_context, embed_model=self.embed_model)
# Creates the actual vector-based index using indexing method, vector store, storage method and embedding model specified above
self.indices[index_name] = VectorStoreIndex.from_documents(llama_docs, storage_context=self.storage_context, embed_model=self.embed_model)
self._persist()
# Return the document IDs that were indexed
return [doc.doc_id for doc in documents]
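
The ChromaDB store keeps the same namespace-to-index map. One way to keep data for different index names in separate collections (a sketch only; the helper and its naming scheme are not part of this commit) would be a method on the ChromaDB store class along these lines:

def _get_or_create_storage_context(self, index_name: str):
    # Hypothetical helper: back each logical index with its own Chroma collection
    # so documents indexed under different names never share a collection.
    collection = self.chroma_client.get_or_create_collection(f"{self.collection_name}_{index_name}")
    vector_store = ChromaVectorStore(chroma_collection=collection)
    return StorageContext.from_defaults(vector_store=vector_store)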
119 changes: 68 additions & 51 deletions presets/rag_service/vector_store/faiss_store.py
@@ -15,95 +15,112 @@


class FaissVectorStoreManager(BaseVectorStore):
def __init__(self, dimension: int, embed_model):
self.dimension = dimension # TODO: Automatically needs to configure dim based on embed_model
def __init__(self, embed_model):
self.embed_model = embed_model
self.faiss_index = faiss.IndexFlatL2(self.dimension)
self.vector_store = FaissVectorStore(faiss_index=self.faiss_index)
self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
self.index = None # Use to store the in-memory index # TODO: Multiple indexes via name (e.g. namespace)
self.dimension = self.embed_model.get_embedding_dimension()
# TODO: Consider allowing user custom indexing method e.g.
"""
# Choose the FAISS index type based on the provided index_method
if index_method == 'FlatL2':
faiss_index = faiss.IndexFlatL2(self.dimension) # L2 (Euclidean distance) index
elif index_method == 'FlatIP':
faiss_index = faiss.IndexFlatIP(self.dimension) # Inner product (cosine similarity) index
elif index_method == 'IVFFlat':
quantizer = faiss.IndexFlatL2(self.dimension) # Quantizer for IVF
faiss_index = faiss.IndexIVFFlat(quantizer, self.dimension, 100) # IVF with flat quantization
elif index_method == 'HNSW':
faiss_index = faiss.IndexHNSWFlat(self.dimension, 32) # HNSW index with 32 neighbors
else:
raise ValueError(f"Unknown index method: {index_method}")
"""
# TODO: We need to test if sharing storage_context is viable/correct or if we should make a new one for each index
self.faiss_index = faiss.IndexFlatL2(self.dimension) # Specifies FAISS indexing method (https://github.com/facebookresearch/faiss/wiki/Faiss-indexes)
self.vector_store = FaissVectorStore(faiss_index=self.faiss_index) # Specifies in-memory data structure for storing and retrieving document embeddings
self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store) # Used to persist the vector store and its underlying data across sessions
self.indices = {} # Used to store in-memory indices by namespace (e.g. namespace -> index)

if not os.path.exists(PERSIST_DIR):
os.makedirs(PERSIST_DIR)

def index_documents(self, documents: List[Document]):
def index_documents(self, documents: List[Document], index_name: str):
"""Recreates the entire FAISS index and vector store with new documents."""
if index_name in self.indices:
print(f"Index {index_name} already exists. Overwriting.")
llama_docs = [LlamaDocument(text=doc.text, metadata=doc.metadata, id_=doc.doc_id) for doc in documents]
self.index = VectorStoreIndex.from_documents(llama_docs, storage_context=self.storage_context, embed_model=self.embed_model)
self._persist()
# Creates the actual vector-based index using indexing method, vector store, storage method and embedding model specified above
self.indices[index_name] = VectorStoreIndex.from_documents(llama_docs, storage_context=self.storage_context, embed_model=self.embed_model)
self._persist(index_name)
# Return the document IDs that were indexed
return [doc.doc_id for doc in documents]

def add_document(self, document: Document):
def add_document(self, document: Document, index_name: str):
"""Inserts a single document into the existing FAISS index."""
if self.index is None:
self.index = self._load_index() # Load if not already in memory
assert index_name in self.indices, f"No such index: '{index_name}' exists."
llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=document.doc_id)
self.index.insert(llama_doc)
self.storage_context.persist(persist_dir=PERSIST_DIR)
self.indices[index_name].insert(llama_doc)
self.indices[index_name].storage_context.persist(persist_dir=PERSIST_DIR)

def query(self, query: str, top_k: int):
def query(self, query: str, top_k: int, index_name: str):
"""Queries the FAISS vector store."""
if self.index is None:
self.index = self._load_index() # Load if not already in memory
query_engine = self.index.as_query_engine(top_k=top_k)
assert index_name in self.indices, f"No such index: '{index_name}' exists."
query_engine = self.indices[index_name].as_query_engine(top_k=top_k)
return query_engine.query(query)

def delete_document(self, doc_id: str):
def delete_document(self, doc_id: str, index_name: str):
"""Deletes a document from the FAISS vector store."""
if self.index is None:
self.index = self._load_index() # Load if not already in memory
self.index.delete_ref_doc(doc_id, delete_from_docstore=True)
self.storage_context.persist(persist_dir=PERSIST_DIR)
assert index_name in self.indices, f"No such index: '{index_name}' exists."
self.indices[index_name].delete_ref_doc(doc_id, delete_from_docstore=True)
self.indices[index_name].storage_context.persist(persist_dir=PERSIST_DIR)

def update_document(self, document: Document):
def update_document(self, document: Document, index_name: str):
"""Updates an existing document in the FAISS vector store."""
if self.index is None:
self.index = self._load_index() # Load if not already in memory
assert index_name in self.indices, f"No such index: '{index_name}' exists."
llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=document.doc_id)
self.index.update_ref_doc(llama_doc)
self.storage_context.persist(persist_dir=PERSIST_DIR)
self.indices[index_name].update_ref_doc(llama_doc)
self.indices[index_name].storage_context.persist(persist_dir=PERSIST_DIR)

def get_document(self, doc_id: str):
def get_document(self, doc_id: str, index_name: str):
"""Retrieves a document by its ID."""
if self.index is None:
self.index = self._load_index() # Load if not already in memory
doc = self.index.docstore.get_document(doc_id)
assert index_name in self.indices, f"No such index: '{index_name}' exists."
doc = self.indices[index_name].docstore.get_document(doc_id)
if not doc:
raise ValueError(f"Document with ID {doc_id} not found.")
return doc

def refresh_documents(self, documents: List[Document]) -> List[bool]:
def refresh_documents(self, documents: List[Document], index_name: str) -> List[bool]:
"""Updates existing documents and inserts new documents in the vector store."""
if self.index is None:
self.index = self._load_index() # Load if not already in memory
assert index_name in self.indices, f"No such index: '{index_name}' exists."
llama_docs = [LlamaDocument(text=doc.text, metadata=doc.metadata, id_=doc.doc_id) for doc in documents]
refresh_results = self.index.refresh_ref_docs(llama_docs)
self._persist()
refresh_results = self.indices[index_name].refresh_ref_docs(llama_docs)
self._persist(index_name)
# Returns a list of booleans indicating whether each document was successfully refreshed.
return refresh_results

def list_documents(self) -> Dict[str, Document]:
def list_documents(self, index_name: str) -> Dict[str, Document]:
"""Lists all documents in the vector store."""
if self.index is None:
self.index = self._load_index() # Load if not already in memory
assert index_name in self.indices, f"No such index: '{index_name}' exists."
return {doc_id: Document(text=doc.text, metadata=doc.metadata, doc_id=doc_id)
for doc_id, doc in self.index.docstore.docs.items()}
for doc_id, doc in self.indices[index_name].docstore.docs.items()}

def document_exists(self, doc_id: str) -> bool:
def document_exists(self, doc_id: str, index_name: str) -> bool:
"""Checks if a document exists in the vector store."""
if self.index is None:
self.index = self._load_index() # Load if not already in memory
return doc_id in self.index.docstore.docs
assert index_name in self.indices, f"No such index: '{index_name}' exists."
return doc_id in self.indices[index_name].docstore.docs

def _load_index(self):
def _load_index(self, index_name: str):
"""Loads the existing FAISS index from disk."""
vector_store = FaissVectorStore.from_persist_dir(PERSIST_DIR)
persist_dir = os.path.join(PERSIST_DIR, index_name)
if not os.path.exists(persist_dir):
raise ValueError(f"No persisted index found for '{index_name}'")
vector_store = FaissVectorStore.from_persist_dir(persist_dir)
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir=PERSIST_DIR
vector_store=vector_store, persist_dir=persist_dir
)
return load_index_from_storage(storage_context=storage_context)
self.indices[index_name] = load_index_from_storage(storage_context=storage_context)
return self.indices[index_name]

def _persist(self):
def _persist(self, index_name: str):
"""Saves the existing FAISS index to disk."""
self.storage_context.persist(persist_dir=PERSIST_DIR)
assert index_name in self.indices, f"No such index: '{index_name}' exists."
storage_context = self.indices[index_name].storage_context
storage_context.persist(persist_dir=os.path.join(PERSIST_DIR, index_name))
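
Two of the TODOs above could be addressed with small helpers on FaissVectorStoreManager: whether a single storage_context should be shared across indices, and the reliance on the embed model exposing get_embedding_dimension(). A sketch only, under those assumptions; the helper names and the probe-embedding fallback are not part of this commit:

def _detect_dimension(self):
    # Assumed fallback: if the embed model lacks get_embedding_dimension(),
    # embed a probe string and take the length of the resulting vector.
    if hasattr(self.embed_model, "get_embedding_dimension"):
        return self.embed_model.get_embedding_dimension()
    return len(self.embed_model.get_text_embedding("dimension probe"))

def _new_storage_context(self):
    # Assumed per-index setup: a fresh FAISS index and StorageContext for each
    # index name, so persisted state is never shared between indices.
    faiss_index = faiss.IndexFlatL2(self.dimension)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    return StorageContext.from_defaults(vector_store=vector_store)

index_documents could then build each new index with self._new_storage_context() instead of the shared self.storage_context.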
