From cd9cbab69800322ce5fe3f917b4e86027b6a04d7 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 24 Sep 2024 11:46:04 -0700 Subject: [PATCH] fix: Updates faiss store to handle multiple indices and dynamically get embedding dim Signed-off-by: ishaansehgal99 --- presets/rag_service/vector_store/base.py | 22 ++-- .../vector_store/chromadb_store.py | 9 +- .../rag_service/vector_store/faiss_store.py | 119 ++++++++++-------- .../{ => playground}/chromadb_playground.py | 0 4 files changed, 84 insertions(+), 66 deletions(-) rename presets/rag_service/vector_store/{ => playground}/chromadb_playground.py (100%) diff --git a/presets/rag_service/vector_store/base.py b/presets/rag_service/vector_store/base.py index b791bb7e6..d9b92315c 100644 --- a/presets/rag_service/vector_store/base.py +++ b/presets/rag_service/vector_store/base.py @@ -6,45 +6,45 @@ class BaseVectorStore(ABC): @abstractmethod - def index_documents(self, documents: List[Document]) -> List[str]: + def index_documents(self, documents: List[Document], index_name: str) -> List[str]: pass @abstractmethod - def query(self, query: str, top_k: int): + def query(self, query: str, top_k: int, index_name: str): pass @abstractmethod - def add_document(self, document: Document): + def add_document(self, document: Document, index_name: str): pass @abstractmethod - def delete_document(self, doc_id: str): + def delete_document(self, doc_id: str, index_name: str): pass @abstractmethod - def update_document(self, document: Document) -> str: + def update_document(self, document: Document, index_name: str) -> str: pass @abstractmethod - def get_document(self, doc_id: str) -> Document: + def get_document(self, doc_id: str, index_name: str) -> Document: pass @abstractmethod - def list_documents(self) -> Dict[str, Document]: + def list_documents(self, index_name: str) -> Dict[str, Document]: pass @abstractmethod - def document_exists(self, doc_id: str) -> bool: + def document_exists(self, doc_id: str, index_name: str) -> 
bool: pass @abstractmethod - def refresh_documents(self, documents: List[Document]) -> List[bool]: + def refresh_documents(self, documents: List[Document], index_name: str) -> List[bool]: pass @abstractmethod - def list_documents(self) -> Dict[str, Document]: + def list_documents(self, index_name: str) -> Dict[str, Document]: pass @abstractmethod - def document_exists(self, doc_id: str) -> bool: + def document_exists(self, doc_id: str, index_name: str) -> bool: pass \ No newline at end of file diff --git a/presets/rag_service/vector_store/chromadb_store.py b/presets/rag_service/vector_store/chromadb_store.py index acb940747..927318202 100644 --- a/presets/rag_service/vector_store/chromadb_store.py +++ b/presets/rag_service/vector_store/chromadb_store.py @@ -22,15 +22,16 @@ def __init__(self, embed_model): self.chroma_collection = self.chroma_client.create_collection(self.collection_name) self.vector_store = ChromaVectorStore(chroma_collection=self.chroma_collection) self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store) - self.index = None # Use to store the in-memory index # TODO: Multiple indexes via name (e.g. namespace) + self.indices = {} # Use to store the in-memory index via namespace (e.g. 
namespace -> index) if not os.path.exists(PERSIST_DIR): os.makedirs(PERSIST_DIR) - def index_documents(self, documents: List[Document]): - """Recreates the entire ChromaDB index and vector store with new documents.""" + def index_documents(self, documents: List[Document], index_name: str): + """Recreates the entire ChromaDB index and vector store with new documents.""" llama_docs = [LlamaDocument(text=doc.text, metadata=doc.metadata, id_=doc.doc_id) for doc in documents] - self.index = VectorStoreIndex.from_documents(llama_docs, storage_context=self.storage_context, embed_model=self.embed_model) + # Creates the actual vector-based index using indexing method, vector store, storage method and embedding model specified above + self.indices[index_name] = VectorStoreIndex.from_documents(llama_docs, storage_context=self.storage_context, embed_model=self.embed_model) self._persist() # Return the document IDs that were indexed return [doc.doc_id for doc in documents] diff --git a/presets/rag_service/vector_store/faiss_store.py b/presets/rag_service/vector_store/faiss_store.py index df44e6c8f..e33b3904a 100644 --- a/presets/rag_service/vector_store/faiss_store.py +++ b/presets/rag_service/vector_store/faiss_store.py @@ -15,95 +15,112 @@ class FaissVectorStoreManager(BaseVectorStore): - def __init__(self, dimension: int, embed_model): - self.dimension = dimension # TODO: Automatically needs to configure dim based on embed_model + def __init__(self, embed_model): self.embed_model = embed_model - self.faiss_index = faiss.IndexFlatL2(self.dimension) - self.vector_store = FaissVectorStore(faiss_index=self.faiss_index) - self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store) - self.index = None # Use to store the in-memory index # TODO: Multiple indexes via name (e.g. namespace) + self.dimension = self.embed_model.get_embedding_dimension() + # TODO: Consider allowing user custom indexing method e.g. 
+ """ + # Choose the FAISS index type based on the provided index_method + if index_method == 'FlatL2': + faiss_index = faiss.IndexFlatL2(self.dimension) # L2 (Euclidean distance) index + elif index_method == 'FlatIP': + faiss_index = faiss.IndexFlatIP(self.dimension) # Inner product (cosine similarity) index + elif index_method == 'IVFFlat': + quantizer = faiss.IndexFlatL2(self.dimension) # Quantizer for IVF + faiss_index = faiss.IndexIVFFlat(quantizer, self.dimension, 100) # IVF with flat quantization + elif index_method == 'HNSW': + faiss_index = faiss.IndexHNSWFlat(self.dimension, 32) # HNSW index with 32 neighbors + else: + raise ValueError(f"Unknown index method: {index_method}") + """ + # TODO: We need to test if sharing storage_context is viable/correct or if we should make a new one for each index + self.faiss_index = faiss.IndexFlatL2(self.dimension) # Specifies FAISS indexing method (https://github.com/facebookresearch/faiss/wiki/Faiss-indexes) + self.vector_store = FaissVectorStore(faiss_index=self.faiss_index) # Specifies in-memory data structure for storing and retrieving document embeddings + self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store) # Used to persist the vector store and its underlying data across sessions + self.indices = {} # Use to store the in-memory index via namespace (e.g. namespace -> index) if not os.path.exists(PERSIST_DIR): os.makedirs(PERSIST_DIR) - def index_documents(self, documents: List[Document]): + def index_documents(self, documents: List[Document], index_name: str): """Recreates the entire FAISS index and vector store with new documents.""" + if index_name in self.indices: + print(f"Index {index_name} already exists. 
Overwriting.") llama_docs = [LlamaDocument(text=doc.text, metadata=doc.metadata, id_=doc.doc_id) for doc in documents] - self.index = VectorStoreIndex.from_documents(llama_docs, storage_context=self.storage_context, embed_model=self.embed_model) - self._persist() + # Creates the actual vector-based index using indexing method, vector store, storage method and embedding model specified above + self.indices[index_name] = VectorStoreIndex.from_documents(llama_docs, storage_context=self.storage_context, embed_model=self.embed_model) + self._persist(index_name) # Return the document IDs that were indexed return [doc.doc_id for doc in documents] - def add_document(self, document: Document): + def add_document(self, document: Document, index_name: str): """Inserts a single document into the existing FAISS index.""" - if self.index is None: - self.index = self._load_index() # Load if not already in memory + assert index_name in self.indices, f"No such index: '{index_name}' exists." llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=document.doc_id) - self.index.insert(llama_doc) - self.storage_context.persist(persist_dir=PERSIST_DIR) + self.indices[index_name].insert(llama_doc) + self.indices[index_name].storage_context.persist(persist_dir=PERSIST_DIR) - def query(self, query: str, top_k: int): + def query(self, query: str, top_k: int, index_name: str): """Queries the FAISS vector store.""" - if self.index is None: - self.index = self._load_index() # Load if not already in memory - query_engine = self.index.as_query_engine(top_k=top_k) + assert index_name in self.indices, f"No such index: '{index_name}' exists." 
+ query_engine = self.indices[index_name].as_query_engine(top_k=top_k) return query_engine.query(query) - def delete_document(self, doc_id: str): + def delete_document(self, doc_id: str, index_name: str): """Deletes a document from the FAISS vector store.""" - if self.index is None: - self.index = self._load_index() # Load if not already in memory - self.index.delete_ref_doc(doc_id, delete_from_docstore=True) - self.storage_context.persist(persist_dir=PERSIST_DIR) + assert index_name in self.indices, f"No such index: '{index_name}' exists." + self.indices[index_name].delete_ref_doc(doc_id, delete_from_docstore=True) + self.indices[index_name].storage_context.persist(persist_dir=PERSIST_DIR) - def update_document(self, document: Document): + def update_document(self, document: Document, index_name: str): """Updates an existing document in the FAISS vector store.""" - if self.index is None: - self.index = self._load_index() # Load if not already in memory + assert index_name in self.indices, f"No such index: '{index_name}' exists." llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=document.doc_id) - self.index.update_ref_doc(llama_doc) - self.storage_context.persist(persist_dir=PERSIST_DIR) + self.indices[index_name].update_ref_doc(llama_doc) + self.indices[index_name].storage_context.persist(persist_dir=PERSIST_DIR) - def get_document(self, doc_id: str): + def get_document(self, doc_id: str, index_name: str): """Retrieves a document by its ID.""" - if self.index is None: - self.index = self._load_index() # Load if not already in memory - doc = self.index.docstore.get_document(doc_id) + assert index_name in self.indices, f"No such index: '{index_name}' exists." 
+ doc = self.indices[index_name].docstore.get_document(doc_id) if not doc: raise ValueError(f"Document with ID {doc_id} not found.") return doc - def refresh_documents(self, documents: List[Document]) -> List[bool]: + def refresh_documents(self, documents: List[Document], index_name: str) -> List[bool]: """Updates existing documents and inserts new documents in the vector store.""" - if self.index is None: - self.index = self._load_index() # Load if not already in memory + assert index_name in self.indices, f"No such index: '{index_name}' exists." llama_docs = [LlamaDocument(text=doc.text, metadata=doc.metadata, id_=doc.doc_id) for doc in documents] - refresh_results = self.index.refresh_ref_docs(llama_docs) - self._persist() + refresh_results = self.indices[index_name].refresh_ref_docs(llama_docs) + self._persist(index_name) # Returns a list of booleans indicating whether each document was successfully refreshed. return refresh_results - def list_documents(self) -> Dict[str, Document]: + def list_documents(self, index_name: str) -> Dict[str, Document]: """Lists all documents in the vector store.""" - if self.index is None: - self.index = self._load_index() # Load if not already in memory + assert index_name in self.indices, f"No such index: '{index_name}' exists." return {doc_id: Document(text=doc.text, metadata=doc.metadata, doc_id=doc_id) - for doc_id, doc in self.index.docstore.docs.items()} + for doc_id, doc in self.indices[index_name].docstore.docs.items()} - def document_exists(self, doc_id: str) -> bool: + def document_exists(self, doc_id: str, index_name: str) -> bool: """Checks if a document exists in the vector store.""" - if self.index is None: - self.index = self._load_index() # Load if not already in memory - return doc_id in self.index.docstore.docs + assert index_name in self.indices, f"No such index: '{index_name}' exists." 
+ return doc_id in self.indices[index_name].docstore.docs - def _load_index(self): + def _load_index(self, index_name: str): """Loads the existing FAISS index from disk.""" - vector_store = FaissVectorStore.from_persist_dir(PERSIST_DIR) + persist_dir = os.path.join(PERSIST_DIR, index_name) + if not os.path.exists(persist_dir): + raise ValueError(f"No persisted index found for '{index_name}'") + vector_store = FaissVectorStore.from_persist_dir(persist_dir) storage_context = StorageContext.from_defaults( - vector_store=vector_store, persist_dir=PERSIST_DIR + vector_store=vector_store, persist_dir=persist_dir ) - return load_index_from_storage(storage_context=storage_context) + self.indices[index_name] = load_index_from_storage(storage_context=storage_context) + return self.indices[index_name] - def _persist(self): + def _persist(self, index_name: str): """Saves the existing FAISS index to disk.""" - self.storage_context.persist(persist_dir=PERSIST_DIR) + assert index_name in self.indices, f"No such index: '{index_name}' exists." + storage_context = self.indices[index_name].storage_context + storage_context.persist(persist_dir=os.path.join(PERSIST_DIR, index_name)) diff --git a/presets/rag_service/vector_store/chromadb_playground.py b/presets/rag_service/vector_store/playground/chromadb_playground.py similarity index 100% rename from presets/rag_service/vector_store/chromadb_playground.py rename to presets/rag_service/vector_store/playground/chromadb_playground.py