Skip to content

Commit

Permalink
feat: Update VectorStore Base class (#673)
Browse files Browse the repository at this point in the history
Reason for Change:
Cleans up the vector store classes by putting the common logic in the base class
and having the concrete stores inherit from it.

This makes the code cleaner, more readable, and more modular for future work.

<i>PR also adds license headers</i>
  • Loading branch information
ishaansehgal99 authored Nov 5, 2024
1 parent b97ab11 commit ad0dde9
Show file tree
Hide file tree
Showing 17 changed files with 189 additions and 187 deletions.
3 changes: 3 additions & 0 deletions pkg/ragengine/services/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# config.py

# Variables are set via environment variables from the RAGEngine CR
Expand Down
3 changes: 3 additions & 0 deletions pkg/ragengine/services/embedding/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import ABC, abstractmethod


Expand Down
3 changes: 3 additions & 0 deletions pkg/ragengine/services/embedding/huggingface_local.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from .base import BaseEmbeddingModel
Expand Down
3 changes: 3 additions & 0 deletions pkg/ragengine/services/embedding/huggingface_remote.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from llama_index.embeddings.huggingface_api import \
HuggingFaceInferenceAPIEmbedding

Expand Down
3 changes: 3 additions & 0 deletions pkg/ragengine/services/inference/inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any
from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen
from llama_index.llms.openai import OpenAI
Expand Down
3 changes: 3 additions & 0 deletions pkg/ragengine/services/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List
from vector_store_manager.manager import VectorStoreManager
from embedding.huggingface_local import LocalHuggingFaceEmbedding
Expand Down
3 changes: 3 additions & 0 deletions pkg/ragengine/services/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, List, Optional

from pydantic import BaseModel
Expand Down
1 change: 1 addition & 0 deletions pkg/ragengine/services/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ llama-index-llms-huggingface-api
fastapi
faiss-cpu
llama-index-vector-stores-faiss
llama-index-vector-stores-azurecosmosmongo
uvicorn
# For UTs
pytest
3 changes: 3 additions & 0 deletions pkg/ragengine/services/tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
Expand Down
3 changes: 3 additions & 0 deletions pkg/ragengine/services/tests/api/test_main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import patch

from llama_index.core.storage.index_store import SimpleIndexStore
Expand Down
3 changes: 3 additions & 0 deletions pkg/ragengine/services/tests/vector_store/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from tempfile import TemporaryDirectory
from unittest.mock import patch
Expand Down Expand Up @@ -99,7 +102,7 @@ def test_add_document(vector_store_manager):
vector_store_manager.index_documents("test_index", new_document)

# Assert that the document exists
assert vector_store_manager.document_exists("test_index",
assert vector_store_manager.document_exists("test_index", new_document[0],
BaseVectorStore.generate_doc_id("Fourth document"))

def test_persist_index_1(vector_store_manager):
Expand Down
149 changes: 138 additions & 11 deletions pkg/ragengine/services/vector_store/base.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,158 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from abc import ABC, abstractmethod
from typing import Dict, List
import hashlib
import os

from llama_index.core import Document as LlamaDocument
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core import (StorageContext, VectorStoreIndex)

from services.models import Document
import hashlib
from services.embedding.base import BaseEmbeddingModel
from services.inference.inference import Inference
from services.config import PERSIST_DIR

# Configure logging
# NOTE: basicConfig runs at import time and sets up the root logger at INFO
# level (it is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
# Module-level logger used by BaseVectorStore for progress and error messages.
logger = logging.getLogger(__name__)

class BaseVectorStore(ABC):
    """Base class holding the logic shared by all vector store backends.

    Indexing, querying, listing, existence checks and persistence are
    implemented here. A concrete store only has to implement
    ``_create_new_index``; ``_create_index_common`` is provided as a helper
    for building the index once the backend-specific vector store exists.
    """

    def __init__(self, embedding_manager: BaseEmbeddingModel):
        """Set up the embedding model, the index registry and the LLM client.

        Args:
            embedding_manager: Embedding wrapper whose ``model`` attribute is
                used as the embed model when indexes are built.
        """
        self.embedding_manager = embedding_manager
        self.embed_model = self.embedding_manager.model
        # Maps index name -> index object for every index known to this store.
        self.index_map = {}
        self.index_store = SimpleIndexStore()
        self.llm = Inference()

    @staticmethod
    def generate_doc_id(text: str) -> str:
        """Generates a unique document ID based on the hash of the document text."""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def index_documents(self, index_name: str, documents: List[Document]) -> List[str]:
        """Common indexing logic for all vector stores.

        Appends to the index when ``index_name`` is already registered in
        ``index_map``; otherwise delegates to ``_create_new_index``.

        Returns:
            The list of document IDs returned by the chosen path.
        """
        if index_name in self.index_map:
            return self._append_documents_to_index(index_name, documents)
        return self._create_new_index(index_name, documents)

    def _append_documents_to_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Common logic for appending documents to existing index.

        Documents whose ID is already present in the index are skipped.

        Returns:
            IDs of the documents that were actually added (unordered).
        """
        logger.info(f"Index {index_name} already exists. Appending documents to existing index.")
        indexed_doc_ids = set()

        for doc in documents:
            doc_id = self.generate_doc_id(doc.text)
            if not self.document_exists(index_name, doc, doc_id):
                self.add_document_to_index(index_name, doc, doc_id)
                indexed_doc_ids.add(doc_id)
            else:
                logger.info(f"Document {doc_id} already exists in index {index_name}. Skipping.")

        # Only touch the disk when something was actually added.
        if indexed_doc_ids:
            self._persist(index_name)
        return list(indexed_doc_ids)

    @abstractmethod
    def _create_new_index(self, index_name: str, documents: List[Document]) -> List[str]:
        """Create a new index - implementation specific to each vector store."""
        pass

    def _create_index_common(self, index_name: str, documents: List[Document], vector_store) -> List[str]:
        """Common logic for creating a new index with documents.

        Args:
            index_name: Name under which the index is registered and persisted.
            documents: Documents to index; each ID is derived from its text.
            vector_store: Backend vector store passed to the storage context.

        Returns:
            IDs of the given documents (unordered). When ``documents`` is
            empty no index is created and an empty list is returned.
        """
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        llama_docs = []
        indexed_doc_ids = set()

        for doc in documents:
            doc_id = self.generate_doc_id(doc.text)
            llama_doc = LlamaDocument(id_=doc_id, text=doc.text, metadata=doc.metadata)
            llama_docs.append(llama_doc)
            indexed_doc_ids.add(doc_id)

        if llama_docs:
            index = VectorStoreIndex.from_documents(
                llama_docs,
                storage_context=storage_context,
                embed_model=self.embed_model,
            )
            index.set_index_id(index_name)
            self.index_map[index_name] = index
            self.index_store.add_index_struct(index.index_struct)
            self._persist(index_name)
        return list(indexed_doc_ids)

    def query(self, index_name: str, query: str, top_k: int, llm_params: dict):
        """Common query logic for all vector stores.

        Args:
            index_name: Name of the index to query.
            query: Query text handed to the query engine.
            top_k: Value passed as ``similarity_top_k`` to the query engine.
            llm_params: Parameters forwarded to ``self.llm.set_params``.

        Returns:
            A dict with the ``response``, the ``source_nodes`` (node id, text,
            score and metadata for each) and the result ``metadata``.

        Raises:
            ValueError: If ``index_name`` is not a known index.
        """
        if index_name not in self.index_map:
            raise ValueError(f"No such index: '{index_name}' exists.")
        self.llm.set_params(llm_params)

        query_engine = self.index_map[index_name].as_query_engine(
            llm=self.llm,
            similarity_top_k=top_k
        )
        query_result = query_engine.query(query)
        return {
            "response": query_result.response,
            "source_nodes": [
                {
                    "node_id": node.node_id,
                    "text": node.text,
                    "score": node.score,
                    "metadata": node.metadata
                }
                for node in query_result.source_nodes
            ],
            "metadata": query_result.metadata,
        }

    def add_document_to_index(self, index_name: str, document: Document, doc_id: str):
        """Common logic for adding a single document.

        Raises:
            ValueError: If ``index_name`` is not a known index.
        """
        if index_name not in self.index_map:
            raise ValueError(f"No such index: '{index_name}' exists.")
        llama_doc = LlamaDocument(text=document.text, metadata=document.metadata, id_=doc_id)
        self.index_map[index_name].insert(llama_doc)

    def list_all_indexed_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
        """Common logic for listing all documents.

        Returns:
            A mapping of index name -> ``ref_doc_id`` -> ``{"text", "hash"}``
            built from every entry in each index's docstore.
        """
        return {
            index_name: {
                doc_info.ref_doc_id: {
                    "text": doc_info.text,
                    "hash": doc_info.hash
                } for doc_info in vector_store_index.docstore.docs.values()
            }
            for index_name, vector_store_index in self.index_map.items()
        }

    def document_exists(self, index_name: str, doc: Document, doc_id: str) -> bool:
        """Common logic for checking document existence.

        ``doc`` is not used by this base implementation; only ``doc_id`` is
        looked up in the index's ``ref_doc_info``.

        Returns:
            ``False`` (with a warning logged) when the index is unknown,
            otherwise whether ``doc_id`` is present in the index.
        """
        if index_name not in self.index_map:
            logger.warning(f"No such index: '{index_name}' exists in vector store.")
            return False
        return doc_id in self.index_map[index_name].ref_doc_info

    def _persist_all(self):
        """Common persistence logic.

        Persists the index store, then every index it lists.
        """
        logger.info("Persisting all indexes.")
        self.index_store.persist(os.path.join(PERSIST_DIR, "store.json"))
        for idx in self.index_store.index_structs():
            self._persist(idx.index_id)

    def _persist(self, index_name: str):
        """Common persistence logic for individual index.

        Best effort: any failure is logged and swallowed, never raised.
        """
        try:
            logger.info(f"Persisting index {index_name}.")
            self.index_store.persist(os.path.join(PERSIST_DIR, "store.json"))
            # Explicit check rather than ``assert`` so the validation (and its
            # message) is not stripped when Python runs with -O.
            if index_name not in self.index_map:
                raise ValueError(f"No such index: '{index_name}' exists.")
            storage_context = self.index_map[index_name].storage_context
            # Persist the specific index
            storage_context.persist(persist_dir=os.path.join(PERSIST_DIR, index_name))
            logger.info(f"Successfully persisted index {index_name}.")
        except Exception as e:
            logger.error(f"Failed to persist index {index_name}. Error: {str(e)}")
Loading

0 comments on commit ad0dde9

Please sign in to comment.