Skip to content

Commit

Permalink
feat: Updated UTs, models and API
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaansehgal99 committed Oct 23, 2024
1 parent e652935 commit 7748420
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 141 deletions.
27 changes: 7 additions & 20 deletions ragengine/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from typing import Dict, List

from llama_index.core.schema import TextNode

from vector_store_manager.manager import VectorStoreManager
from embedding.huggingface_local import LocalHuggingFaceEmbedding
from embedding.huggingface_remote import RemoteHuggingFaceEmbedding
from llama_index.core.storage.docstore.types import RefDocInfo
from fastapi import FastAPI, HTTPException
from models import (IndexRequest, ListDocumentsResponse,
QueryRequest, Document)
QueryRequest, QueryResponse, DocumentResponse)
from vector_store.faiss_store import FaissVectorStoreHandler

from config import ACCESS_SECRET, EMBEDDING_TYPE, MODEL_ID
from ragengine.config import ACCESS_SECRET, EMBEDDING_TYPE, MODEL_ID

app = FastAPI()

Expand All @@ -30,40 +26,31 @@
# Initialize RAG operations
rag_ops = VectorStoreManager(vector_store_handler)

# POST /index: index the request's documents under request.index_name and
# return one entry per submitted document, carrying the ID assigned to it.
# NOTE(review): this span is a rendered diff hunk. The paired
# Document / DocumentResponse lines appear to be the removed line followed by
# its replacement, not two statements meant to coexist.
@app.post("/index", response_model=List[Document])
@app.post("/index", response_model=List[DocumentResponse])
async def index_documents(request: IndexRequest): # TODO: Research async/sync what to use (inference is calling)
try:
# The returned IDs are zipped with request.documents below, so this assumes
# rag_ops.create returns them in submission order; confirm in the manager.
doc_ids = rag_ops.create(request.index_name, request.documents)
documents = [
Document(doc_id=doc_id, text=doc.text, metadata=doc.metadata)
DocumentResponse(doc_id=doc_id, text=doc.text, metadata=doc.metadata)
for doc_id, doc in zip(doc_ids, request.documents)
]
return documents
# Every failure is reported to the client as HTTP 500 with the exception text
# as the detail.
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# POST /query: run request.query against request.index_name with the given
# top_k and optional LLM parameters.
# NOTE(review): rendered diff hunk. The two decorators and the two `return`
# statements appear to be the old and new versions of the same line.
@app.post("/query", response_model=Dict[str, str])
@app.post("/query", response_model=QueryResponse)
async def query_index(request: QueryRequest):
try:
llm_params = request.llm_params or {} # Default to empty dict if no params provided
response = rag_ops.read(request.index_name, request.query, request.top_k, llm_params)
# Older shape: the result is stringified and wrapped under a single "response" key.
return {"response": str(response)}
# Newer shape: whatever rag_ops.read returns is handed back directly and must
# therefore satisfy the QueryResponse response_model.
return rag_ops.read(request.index_name, request.query, request.top_k, llm_params)
# Every failure is reported to the client as HTTP 500 with the exception text
# as the detail.
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# GET /indexed-documents: return all indexed documents wrapped in
# ListDocumentsResponse.
# NOTE(review): rendered diff hunk. The comprehension and the first `return`
# appear to be the removed implementation; the second `return` is its replacement.
@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
try:
documents = rag_ops.list_all_indexed_documents()
# Older path: walk each index's docstore and keep only the text and hash of
# every document, keyed by index name and then by document name.
serialized_documents = {
index_name: {
doc_name: {
"text": doc_info.text, "hash": doc_info.hash
} for doc_name, doc_info in vector_store_index.docstore.docs.items()
}
for index_name, vector_store_index in documents.items()
}
return ListDocumentsResponse(documents=serialized_documents)
# Newer path: the manager's result is passed through unchanged, so it is
# presumably already in the serialized nested-dict form; confirm in the manager.
return ListDocumentsResponse(documents=documents)
# Every failure is reported to the client as HTTP 500 with the exception text
# as the detail.
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

Expand Down
19 changes: 18 additions & 1 deletion ragengine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ class Document(BaseModel):
text: str
metadata: Optional[dict] = {}

# Response item for POST /index: the submitted document's text and metadata
# together with the ID paired with it by the handler.
class DocumentResponse(BaseModel):
# ID the /index handler pairs with this document.
doc_id: str
text: str
# Defaults to None when omitted (Document.metadata above defaults to {}).
metadata: Optional[dict] = None

class IndexRequest(BaseModel):
index_name: str
documents: List[Document]
Expand All @@ -17,4 +22,16 @@ class QueryRequest(BaseModel):
llm_params: Optional[Dict] = None # Accept a dictionary for parameters

# Response body for GET /indexed-documents.
# NOTE(review): the field line is shown twice because this is a rendered diff;
# the only difference between the two is the space after the colon.
class ListDocumentsResponse(BaseModel):
# Three-level string mapping. The serializer removed from the endpoint built it
# as index name -> document name -> {"text": ..., "hash": ...}.
documents:Dict[str, Dict[str, Dict[str, str]]]
documents: Dict[str, Dict[str, Dict[str, str]]]

# Response models for /query: NodeWithScore (one retrieved source node) and QueryResponse (the full reply)
# One entry of QueryResponse.source_nodes: a retrieved node and its score.
class NodeWithScore(BaseModel):
node_id: str
# Text content of the node.
text: str
# Score attached to this node; the tests in this change compare it against
# exact float values.
score: float
# Optional; None when not supplied.
metadata: Optional[dict] = None

# Response body for POST /query.
class QueryResponse(BaseModel):
# Answer text. The API test in this change expects the stringified completion
# payload here.
response: str
# Source nodes returned with the answer, each carrying its own score.
source_nodes: List[NodeWithScore]
# Optional; None when not supplied.
metadata: Optional[dict] = None
55 changes: 14 additions & 41 deletions ragengine/tests/api/test_main.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import os
from tempfile import TemporaryDirectory
from unittest.mock import patch

import pytest
from vector_store.faiss_store import FaissVectorStoreHandler
from models import Document
from embedding.huggingface_local import LocalHuggingFaceEmbedding
from config import MODEL_ID, INFERENCE_URL, INFERENCE_ACCESS_SECRET
from llama_index.core.storage.index_store import SimpleIndexStore

from main import app, rag_ops
from ragengine.main import app, vector_store_handler, rag_ops
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
import pytest

AUTO_GEN_DOC_ID_LEN = 36
AUTO_GEN_DOC_ID_LEN = 64

client = TestClient(app)

# autouse: pytest runs this before every test in the module without it being
# requested, presumably so indices created by one test do not leak into the next.
@pytest.fixture(autouse=True)
def clear_index():
# Empty the in-memory index map on the handler shared with ragengine.main...
vector_store_handler.index_map.clear()
# ...and replace its index store with a new SimpleIndexStore instance.
vector_store_handler.index_store = SimpleIndexStore()

def test_index_documents_success():
request_data = {
"index_name": "test_index",
Expand Down Expand Up @@ -65,7 +64,11 @@ def test_query_index_success(mock_post):

response = client.post("/query", json=request_data)
assert response.status_code == 200
assert response.json() == {"response": "This is the completion from the API"}
assert response.json()["response"] == "{'result': 'This is the completion from the API'}"
assert len(response.json()["source_nodes"]) == 1
assert response.json()["source_nodes"][0]["text"] == "This is a test document"
assert response.json()["source_nodes"][0]["score"] == 0.5354418754577637
assert response.json()["source_nodes"][0]["metadata"] == {}
assert mock_post.call_count == 1

def test_query_index_failure():
Expand All @@ -82,36 +85,6 @@ def test_query_index_failure():
assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


# Index two documents (one with an explicit doc_id, one without), then fetch the
# explicit one through GET /document/{index_name}/{doc_id}.
# NOTE(review): judging by the hunk header for this file, this test appears to
# be removed by the commit.
def test_get_document_success():
request_data = {
"index_name": "test_index",
"documents": [
# {"doc_id": "doc1", "text": "This is a test document"},
{"doc_id": "doc1", "text": "This is a test document"},
{"text": "Another test document"}
]
}

index_response = client.post("/index", json=request_data)
assert index_response.status_code == 200

# Call the GET document endpoint.
get_response = client.get("/document/test_index/doc1")
assert get_response.status_code == 200

response_json = get_response.json()

# The payload must expose exactly these two keys, with empty metadata.
assert response_json.keys() == {"node_ids", 'metadata'}
assert response_json['metadata'] == {}

# Exactly one node ID is expected for the fetched document.
assert isinstance(response_json["node_ids"], list) and len(response_json["node_ids"]) == 1


# Nothing is indexed inside this test, so looking up doc1 is expected to yield 404.
# NOTE(review): judging by the hunk header for this file, this test appears to
# be removed by the commit.
def test_get_document_failure():
# Call the GET document endpoint.
response = client.get("/document/test_index/doc1")
assert response.status_code == 404

def test_list_all_indexed_documents_success():
response = client.get("/indexed-documents")
assert response.status_code == 200
Expand Down
95 changes: 38 additions & 57 deletions ragengine/tests/vector_store/test_faiss_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from unittest.mock import patch

import pytest

from ragengine.vector_store.base import BaseVectorStore
from ragengine.vector_store.faiss_store import FaissVectorStoreHandler
from ragengine.models import Document
from ragengine.embedding.huggingface_local import LocalHuggingFaceEmbedding
from ragengine.config import MODEL_ID, INFERENCE_URL, INFERENCE_ACCESS_SECRET
from ragengine.config import PERSIST_DIR

@pytest.fixture(scope='session')
def init_embed_manager():
Expand All @@ -21,40 +24,39 @@ def vector_store_manager(init_embed_manager):
yield FaissVectorStoreHandler(init_embed_manager)

# Indexing two documents must return two IDs.
# NOTE(review): rendered diff hunk. The Document(doc_id=...) lines and the
# `doc_ids == ["1", "2"]` assertion appear to be the removed versions.
def test_index_documents(vector_store_manager):
first_doc_text, second_doc_text = "First document", "Second document"
documents = [
Document(doc_id="1", text="First document", metadata={"type": "text"}),
Document(doc_id="2", text="Second document", metadata={"type": "text"})
Document(text=first_doc_text, metadata={"type": "text"}),
Document(text=second_doc_text, metadata={"type": "text"})
]

doc_ids = vector_store_manager.index_documents("test_index", documents)

assert len(doc_ids) == 2
assert doc_ids == ["1", "2"]
# Expected IDs are computed from the document text via
# BaseVectorStore.generate_doc_id and compared as a set, so order is not checked.
assert set(doc_ids) == {BaseVectorStore.generate_doc_id(first_doc_text),
BaseVectorStore.generate_doc_id(second_doc_text)}

# Documents indexed under different index names must stay separate.
# NOTE(review): rendered diff hunk. The doc_id-based setup and the get_document
# assertions appear to be the removed version; the final
# list_all_indexed_documents comparison is the replacement.
def test_index_documents_isolation(vector_store_manager):
doc_1_id, doc_2_id = "1", "2"
documents1 = [
Document(doc_id=doc_1_id, text="First document in index1", metadata={"type": "text"}),
Document(text="First document in index1", metadata={"type": "text"}),
]
documents2 = [
Document(doc_id=doc_2_id, text="First document in index2", metadata={"type": "text"}),
Document(text="First document in index2", metadata={"type": "text"}),
]

# Index documents in separate indices
index_name_1, index_name_2 = "index1", "index2"
vector_store_manager.index_documents(index_name_1, documents1)
vector_store_manager.index_documents(index_name_2, documents2)

# Ensure documents are correctly persisted and separated by index
doc_1 = vector_store_manager.get_document(index_name_1, doc_1_id)
assert doc_1 and doc_1.node_ids # Ensure documents were created

doc_2 = vector_store_manager.get_document(index_name_2, doc_2_id)
assert doc_2 and doc_2.node_ids # Ensure documents were created

# Ensure that the documents do not mix between indices
assert vector_store_manager.get_document(index_name_2, doc_1_id) is None, f"Document {doc_1_id} should not exist in {index_name_2}"
assert vector_store_manager.get_document(index_name_1, doc_2_id) is None, f"Document {doc_2_id} should not exist in {index_name_1}"
# Each index must list only its own document. Keys and hashes are hard-coded
# 64-character hex strings, presumably SHA-256 digests; confirm against
# generate_doc_id. Any change to ID or hash derivation will break this literal.
assert vector_store_manager.list_all_indexed_documents() == {
'index1': {"87117028123498eb7d757b1507aa3e840c63294f94c27cb5ec83c939dedb32fd":
{'hash': '1e64a170be48c45efeaa8667ab35919106da0489ec99a11d0029f2842db133aa',
'text': 'First document in index1'}},
'index2': {"49b198c0e126a99e1975f17b564756c25b4ad691a57eda583e232fd9bee6de91":
{'hash': 'a222f875b83ce8b6eb72b3cae278b620de9bcc7c6b73222424d3ce979d1a463b',
'text': 'First document in index2'}}
}

@patch('requests.post')
def test_query_documents(mock_post, vector_store_manager):
Expand All @@ -67,17 +69,19 @@ def test_query_documents(mock_post, vector_store_manager):

# Add documents to index
documents = [
Document(doc_id="1", text="First document", metadata={"type": "text"}),
Document(doc_id="2", text="Second document", metadata={"type": "text"})
Document(text="First document", metadata={"type": "text"}),
Document(text="Second document", metadata={"type": "text"})
]
vector_store_manager.index_documents("test_index", documents)

params = {"temperature": 0.7}
# Mock query and results
query_result = vector_store_manager.query("test_index", "First", top_k=1, params=params)
query_result = vector_store_manager.query("test_index", "First", top_k=1, llm_params=params)

assert query_result is not None
assert query_result.response == "This is the completion from the API"
assert query_result["response"] == "{'result': 'This is the completion from the API'}"
assert query_result["source_nodes"][0]["text"] == "First document"
assert query_result["source_nodes"][0]["score"] == 0.5795239210128784

mock_post.assert_called_once_with(
INFERENCE_URL,
Expand All @@ -86,57 +90,34 @@ def test_query_documents(mock_post, vector_store_manager):
headers={"Authorization": f"Bearer {INFERENCE_ACCESS_SECRET}"}
)

# Indexing into an already-populated index must make the new document findable.
# NOTE(review): rendered diff hunk. Each pair of near-identical lines (including
# the two `def` lines) appears to be the removed version followed by its replacement.
def test_add_document(vector_store_manager, capsys):
documents = [Document(doc_id="3", text="Third document", metadata={"type": "text"})]
def test_add_document(vector_store_manager):
documents = [Document(text="Third document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)

# Add a document to the existing index
# The older line passes a bare Document; the newer one wraps it in a list,
# matching every other index_documents call in this file.
new_document = Document(doc_id="4", text="Fourth document", metadata={"type": "text"})
new_document = [Document(text="Fourth document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", new_document)

# Assert that the document exists
assert vector_store_manager.document_exists("test_index", "4")
# The newer lookup uses the ID computed from the document's text.
assert vector_store_manager.document_exists("test_index",
BaseVectorStore.generate_doc_id("Fourth document"))

# NOTE(review): rendered diff hunk. test_persist_and_load_index_store and its
# reload assertions appear to be the removed version; test_persist_index_1 with
# the PERSIST_DIR check is the replacement.
def test_persist_and_load_index_store(vector_store_manager):
"""Test that the index store is persisted and loaded correctly."""
def test_persist_index_1(vector_store_manager):
"""Test that the index store is persisted."""
# Add a document and persist the index
documents = [Document(doc_id="1", text="Test document", metadata={"type": "text"})]
documents = [Document(text="Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
vector_store_manager._persist("test_index")
# NOTE(review): this only checks that the directory exists. It would also pass
# if PERSIST_DIR was left behind by an earlier run; consider a per-test directory.
assert os.path.exists(PERSIST_DIR)

# Simulate a fresh load of the index store (clearing in-memory state)
vector_store_manager.index_store = None # Clear current in-memory store
vector_store_manager._load_index_store()

# Verify that the store was reloaded and contains the expected index structure
assert vector_store_manager.index_store is not None
assert len(vector_store_manager.index_store.index_structs()) > 0

# TODO: Prevent default re-indexing from load_index_from_storage
# NOTE(review): rendered diff hunk. test_persist_and_load_index and the
# reload/document_exists assertions appear to be the removed version;
# test_persist_index_2 with the final PERSIST_DIR check is the replacement.
def test_persist_and_load_index(vector_store_manager):
"""Test that an index is persisted and then loaded correctly."""
def test_persist_index_2(vector_store_manager):
"""Test that an index store is persisted."""
# Add a document and persist the index
documents = [Document(doc_id="1", text="Test document", metadata={"type": "text"})]
documents = [Document(text="Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)

documents = [Document(doc_id="1", text="Another Test document", metadata={"type": "text"})]
documents = [Document(text="Another Test document", metadata={"type": "text"})]
vector_store_manager.index_documents("another_test_index", documents)

# Persist both indices in one call.
vector_store_manager._persist_all()

# Simulate a fresh load of the index (clearing in-memory state)
vector_store_manager.index_map = {} # Clear current in-memory index map
loaded_indices = vector_store_manager._load_indices()

# Verify that the index was reloaded and contains the expected document
assert loaded_indices is not None
assert vector_store_manager.document_exists("test_index", "1")
assert vector_store_manager.document_exists("another_test_index", "1")

vector_store_manager.index_map = {} # Clear current in-memory index map
loaded_index = vector_store_manager._load_index("test_index")

assert loaded_index is not None
assert vector_store_manager.document_exists("test_index", "1")
assert not vector_store_manager.document_exists("another_test_index", "1") # Since we didn't load this index

# NOTE(review): a bare existence check does not show that either index was
# written; it would also pass if the directory already existed before the test.
assert os.path.exists(PERSIST_DIR)
4 changes: 2 additions & 2 deletions ragengine/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ def query(self, index_name: str, query: str, top_k: int, params: dict):
pass

# Abstract hook for adding one document to the named index.
# NOTE(review): rendered diff hunk. The two `def` lines appear to be the old and
# new signature; the newer one is renamed and takes the document's ID explicitly.
@abstractmethod
def add_document(self, index_name: str, document: Document):
def add_document_to_index(self, index_name: str, document: Document, doc_id: str):
pass

# Abstract hook for listing the documents of all indices.
# NOTE(review): rendered diff hunk. The two `def` lines appear to be the old and
# new signature. The newer return annotation is a nested string mapping rather
# than VectorStoreIndex objects, matching ListDocumentsResponse.documents.
@abstractmethod
def list_all_indexed_documents(self) -> Dict[str, VectorStoreIndex]:
def list_all_indexed_documents(self) -> Dict[str, Dict[str, Dict[str, str]]]:
pass

@abstractmethod
Expand Down
Loading

0 comments on commit 7748420

Please sign in to comment.