feat: Part 4 (Final) - Introduce Main RAG Service API and its tests #603

Merged
merged 50 commits into from
Oct 23, 2024
50 commits
d369d85
feat: Add RAGEngine CRD
Fei-Guo Sep 17, 2024
7740def
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/RAG
ishaansehgal99 Sep 19, 2024
47c1ce6
feat: New RAG Service
ishaansehgal99 Sep 20, 2024
c8bfa18
feat: New RAG Service
ishaansehgal99 Sep 20, 2024
a28a8d5
fix: Use local index object
ishaansehgal99 Sep 20, 2024
63ef83d
fix: Load Index
ishaansehgal99 Sep 20, 2024
ff03456
fix: Add ChromaDB VectorStore
ishaansehgal99 Sep 24, 2024
d02391a
fix: Add TODOs and comments
ishaansehgal99 Sep 24, 2024
d82897d
fix: Add function for getting embedding dim
ishaansehgal99 Sep 24, 2024
cd9cbab
fix: Updates faiss store to handle multiple indices and dynamically g…
ishaansehgal99 Sep 24, 2024
33669fc
feat: Add requirements
ishaansehgal99 Sep 27, 2024
7165ccf
feat: fix typos, syntax errors and bugs
ishaansehgal99 Sep 27, 2024
7f39939
fix: Bugs fixed for managing embeddings
ishaansehgal99 Sep 28, 2024
1e07beb
feat: Use a global SimpleIndexStore and seperate StorageContexts
ishaansehgal99 Oct 1, 2024
746c156
feat: Add the load and list indexing functions
ishaansehgal99 Oct 1, 2024
3a83f26
feat: Remove chromadb from PR
ishaansehgal99 Oct 2, 2024
cb80f3e
feat: Add CustomLLM Inference
ishaansehgal99 Oct 2, 2024
a0d1186
fix: Introduce Custom LLM class and top_k query
ishaansehgal99 Oct 2, 2024
4c66387
fix: Update tests to handle faiss delete not implemented yet
ishaansehgal99 Oct 2, 2024
35b5113
fix: Update tests to handle refresh documents
ishaansehgal99 Oct 2, 2024
742485e
fix: Update tests for loading and persisting data
ishaansehgal99 Oct 2, 2024
ba03cdd
Merge branch 'main' into Ishaan/RAG
ishaansehgal99 Oct 2, 2024
51c7035
fix: Update tests for loading index
ishaansehgal99 Oct 2, 2024
6e7b827
feat: Move to ragengine folder and remove unneeded CRUD operations (r…
ishaansehgal99 Oct 3, 2024
aaaa21b
fix: Update to include rag unit tests
ishaansehgal99 Oct 3, 2024
be9d6ed
fix: Update persisting and loading logic
ishaansehgal99 Oct 4, 2024
cf24953
feat: Custom params for llm
ishaansehgal99 Oct 7, 2024
eeef54a
Merge branch 'main' into Ishaan/RAG
ishaansehgal99 Oct 8, 2024
eff5b37
feat: massive update, improvements all across service and enhanced un…
ishaansehgal99 Oct 10, 2024
83ab9a3
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/RAG
ishaansehgal99 Oct 10, 2024
9f52ee8
fix: Slight fix no need to parse inference result
ishaansehgal99 Oct 11, 2024
a232d67
nit
ishaansehgal99 Oct 11, 2024
cee740b
Merge branch 'main' into Ishaan/RAG
ishaansehgal99 Oct 11, 2024
afb8606
nit
ishaansehgal99 Oct 11, 2024
2455dfd
fix: remove unused files
ishaansehgal99 Oct 11, 2024
d32169f
fix: Example of live test
ishaansehgal99 Oct 11, 2024
5520950
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/RAG
ishaansehgal99 Oct 21, 2024
42f288b
Update endpoints and remove old class
ishaansehgal99 Oct 21, 2024
e652935
pytest fix target
ishaansehgal99 Oct 21, 2024
7748420
feat: Updated UTs, models and API
ishaansehgal99 Oct 23, 2024
1b0a7a0
feat: Updated UTs, models and API
ishaansehgal99 Oct 23, 2024
1c34fb0
feat: Updated UTs, models and API
ishaansehgal99 Oct 23, 2024
3cba796
Merge branch 'main' into Ishaan/RAG
ishaansehgal99 Oct 23, 2024
bc94669
feat: Updated UTs, models and API
ishaansehgal99 Oct 23, 2024
7390aa4
Merge branch 'Ishaan/RAG' of https://github.com/Azure/kaito into Isha…
ishaansehgal99 Oct 23, 2024
bc076bd
feat: Updated UTs, models and API
ishaansehgal99 Oct 23, 2024
f93669b
feat: Updated UTs, models and API
ishaansehgal99 Oct 23, 2024
a5dd527
approx
ishaansehgal99 Oct 23, 2024
1000732
fix: add ut dependency
ishaansehgal99 Oct 23, 2024
3d6a623
fix: renamed
ishaansehgal99 Oct 23, 2024
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
@@ -45,8 +45,10 @@ jobs:
       - name: Run unit tests & Generate coverage
         run: |
           make unit-test
+          make rag-service-test
+          make tuning-metrics-server-test

-      - name: Run inference api unit tests
+      - name: Run inference api e2e tests
         run: |
           make inference-api-e2e

13 changes: 12 additions & 1 deletion Makefile
@@ -97,11 +97,22 @@ unit-test: ## Run unit tests.
 	-race -coverprofile=coverage.txt -covermode=atomic
 	go tool cover -func=coverage.txt

+.PHONY: rag-service-test
+rag-service-test:
+	pip install -r ragengine/requirements.txt
+	pytest -o log_cli=true -o log_cli_level=INFO ragengine/tests
+
+.PHONY: tuning-metrics-server-test
+tuning-metrics-server-test:
+	pip install -r presets/inference/text-generation/requirements.txt
+	pytest -o log_cli=true -o log_cli_level=INFO presets/tuning/text-generation/metrics
+
 ## --------------------------------------
 ## E2E tests
 ## --------------------------------------

-inference-api-e2e:
+.PHONY: inference-api-e2e
+inference-api-e2e:
 	pip install -r presets/inference/text-generation/requirements.txt
 	pytest -o log_cli=true -o log_cli_level=INFO presets/inference/text-generation/tests

Empty file added ragengine/README.md
Empty file.
59 changes: 59 additions & 0 deletions ragengine/main.py
@@ -0,0 +1,59 @@
from typing import List
from vector_store_manager.manager import VectorStoreManager
from embedding.huggingface_local import LocalHuggingFaceEmbedding
from embedding.huggingface_remote import RemoteHuggingFaceEmbedding
from fastapi import FastAPI, HTTPException
from models import (IndexRequest, ListDocumentsResponse,
                    QueryRequest, QueryResponse, DocumentResponse)
from vector_store.faiss_store import FaissVectorStoreHandler

from ragengine.config import ACCESS_SECRET, EMBEDDING_TYPE, MODEL_ID

app = FastAPI()

# Initialize embedding model
if EMBEDDING_TYPE.lower() == "local":
    embedding_manager = LocalHuggingFaceEmbedding(MODEL_ID)
elif EMBEDDING_TYPE.lower() == "remote":
    embedding_manager = RemoteHuggingFaceEmbedding(MODEL_ID, ACCESS_SECRET)
else:
    raise ValueError("Invalid Embedding Type Specified (Must be Local or Remote)")

# Initialize vector store
# TODO: Dynamically set VectorStore from EnvVars (which ultimately comes from CRD StorageSpec)
vector_store_handler = FaissVectorStoreHandler(embedding_manager)

# Initialize RAG operations
rag_ops = VectorStoreManager(vector_store_handler)

@app.post("/index", response_model=List[DocumentResponse])
async def index_documents(request: IndexRequest): # TODO: Research async/sync what to use (inference is calling)
    try:
        doc_ids = rag_ops.index(request.index_name, request.documents)
        documents = [
            DocumentResponse(doc_id=doc_id, text=doc.text, metadata=doc.metadata)
            for doc_id, doc in zip(doc_ids, request.documents)
        ]
        return documents
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/query", response_model=QueryResponse)
async def query_index(request: QueryRequest):
    try:
        llm_params = request.llm_params or {} # Default to empty dict if no params provided
        return rag_ops.query(request.index_name, request.query, request.top_k, llm_params)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
    try:
        documents = rag_ops.list_all_indexed_documents()
        return ListDocumentsResponse(documents=documents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
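
The embedding selection at the top of `main.py` can be read as a small factory keyed on a case-insensitive type string. The following is a minimal sketch of that idea, not the code in this PR: `LocalEmbedding`, `RemoteEmbedding`, and `build_embedding` are hypothetical stand-ins for the real Hugging Face wrappers, and the check for a missing access secret is an extra assumption added for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class LocalEmbedding:
    """Hypothetical stand-in for LocalHuggingFaceEmbedding."""
    model_id: str


@dataclass
class RemoteEmbedding:
    """Hypothetical stand-in for RemoteHuggingFaceEmbedding."""
    model_id: str
    access_secret: str


def build_embedding(
    embedding_type: str,
    model_id: str,
    access_secret: Optional[str] = None,
) -> Union[LocalEmbedding, RemoteEmbedding]:
    """Pick an embedding backend from a case-insensitive type string."""
    kind = embedding_type.strip().lower()
    if kind == "local":
        return LocalEmbedding(model_id)
    if kind == "remote":
        # Assumption: a remote backend is unusable without a secret,
        # so fail early instead of at the first request.
        if not access_secret:
            raise ValueError("Remote embeddings require an access secret")
        return RemoteEmbedding(model_id, access_secret)
    raise ValueError(
        f"Invalid embedding type {embedding_type!r} (must be 'local' or 'remote')"
    )
```

Wrapping the branch in a function like this would let a unit test exercise each path without importing the FastAPI app, which constructs its embedding model at import time.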
19 changes: 18 additions & 1 deletion ragengine/models.py
@@ -6,6 +6,11 @@ class Document(BaseModel):
     text: str
     metadata: Optional[dict] = {}

+class DocumentResponse(BaseModel):
+    doc_id: str
+    text: str
+    metadata: Optional[dict] = None
+
 class IndexRequest(BaseModel):
     index_name: str
     documents: List[Document]
@@ -17,4 +22,16 @@ class QueryRequest(BaseModel):
     llm_params: Optional[Dict] = None # Accept a dictionary for parameters

 class ListDocumentsResponse(BaseModel):
-    documents:Dict[str, Dict[str, Dict[str, str]]]
+    documents: Dict[str, Dict[str, Dict[str, str]]]
+
+# Define models for NodeWithScore, and QueryResponse
+class NodeWithScore(BaseModel):
+    node_id: str
+    text: str
+    score: float
+    metadata: Optional[dict] = None
+
+class QueryResponse(BaseModel):
+    response: str
+    source_nodes: List[NodeWithScore]
+    metadata: Optional[dict] = None
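
To make the wire format of these models concrete, here is a rough sketch of the JSON shapes that `QueryRequest` and `QueryResponse` describe, written with plain dictionaries so it runs without pydantic. The field names come from the models and tests in this PR; every value, and the helper `looks_like_query_response`, is hypothetical and only illustrates the structure.

```python
import json
from typing import Any, Dict

# Illustrative request body for POST /query (values are made up).
query_request: Dict[str, Any] = {
    "index_name": "test_index",
    "query": "test query",
    "top_k": 1,
    "llm_params": {"temperature": 0.7},
}

# Illustrative response body matching the QueryResponse fields.
query_response: Dict[str, Any] = {
    "response": "example completion text",
    "source_nodes": [
        {
            "node_id": "node-1",
            "text": "This is a test document",
            "score": 0.53,
            "metadata": {},
        }
    ],
    "metadata": None,
}


def looks_like_query_response(payload: Dict[str, Any]) -> bool:
    """Hypothetical structural check mirroring the required fields."""
    if not isinstance(payload.get("response"), str):
        return False
    nodes = payload.get("source_nodes")
    if not isinstance(nodes, list):
        return False
    for node in nodes:
        if not isinstance(node, dict):
            return False
        if not isinstance(node.get("node_id"), str):
            return False
        if not isinstance(node.get("text"), str):
            return False
        if not isinstance(node.get("score"), (int, float)):
            return False
    return True


# The payload survives a JSON round trip unchanged in structure.
round_tripped = json.loads(json.dumps(query_response))
```

In the service itself pydantic performs this validation; the hand-written check is only a way to show which fields are required (`response`, `source_nodes`, and each node's `node_id`, `text`, `score`) and which are optional (`metadata`).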
8 changes: 8 additions & 0 deletions ragengine/requirements.txt
@@ -1,7 +1,15 @@
# RAG Library Requirements
llama-index
# HF Embeddings
llama-index-embeddings-huggingface
llama-index-embeddings-huggingface-api
# HF LLMs
llama-index-llms-huggingface
llama-index-llms-huggingface-api

fastapi
faiss-cpu
llama-index-vector-stores-faiss
uvicorn
# For UTs
pytest
Empty file.
6 changes: 6 additions & 0 deletions ragengine/tests/api/conftest.py
@@ -0,0 +1,6 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU-only execution for testing
os.environ["OMP_NUM_THREADS"] = "1" # Force single-threaded for testing to prevent segfault while loading embedding model
os.environ["MKL_NUM_THREADS"] = "1" # Force MKL to use a single thread
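
The settings in `conftest.py` only help if they are in place before the numerical libraries initialize, since thread-count and device variables are generally read when those runtimes start up. pytest loads `conftest.py` before it collects the test modules in the same directory, which is presumably why the variables live there. A minimal sketch of the same pattern, with a hypothetical helper name `force_cpu_single_thread`:

```python
import os
from typing import MutableMapping


def force_cpu_single_thread(env: MutableMapping[str, str] = os.environ) -> None:
    """Set the same three variables conftest.py sets, on a given mapping."""
    env["CUDA_VISIBLE_DEVICES"] = "-1"  # hide GPUs so everything runs on CPU
    env["OMP_NUM_THREADS"] = "1"        # single OpenMP thread
    env["MKL_NUM_THREADS"] = "1"        # single MKL thread


# Call this first; any import that pulls in an embedding model or a
# BLAS/OpenMP runtime should come after this point.
force_cpu_single_thread()
```

Accepting the mapping as a parameter is a design choice for this sketch only: it allows the helper to be checked against a plain dictionary without touching the real process environment.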
143 changes: 143 additions & 0 deletions ragengine/tests/api/test_main.py
@@ -0,0 +1,143 @@
from unittest.mock import patch

from llama_index.core.storage.index_store import SimpleIndexStore

from ragengine.main import app, vector_store_handler, rag_ops
from fastapi.testclient import TestClient
import pytest

AUTO_GEN_DOC_ID_LEN = 64

client = TestClient(app)

@pytest.fixture(autouse=True)
def clear_index():
    vector_store_handler.index_map.clear()
    vector_store_handler.index_store = SimpleIndexStore()

def test_index_documents_success():
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200
    doc1, doc2 = response.json()
    assert (doc1["text"] == "This is a test document")
    assert len(doc1["doc_id"]) == AUTO_GEN_DOC_ID_LEN
    assert not doc1["metadata"]

    assert (doc2["text"] == "Another test document")
    assert len(doc2["doc_id"]) == AUTO_GEN_DOC_ID_LEN
    assert not doc2["metadata"]

@patch('requests.post')
def test_query_index_success(mock_post):
    # Define Mock Response for Custom Inference API
    mock_response = {
        "result": "This is the completion from the API"
    }
    mock_post.return_value.json.return_value = mock_response
    # Index
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    # Query
    request_data = {
        "index_name": "test_index",
        "query": "test query",
        "top_k": 1,
        "llm_params": {"temperature": 0.7}
    }

    response = client.post("/query", json=request_data)
    assert response.status_code == 200
    assert response.json()["response"] == "{'result': 'This is the completion from the API'}"
    assert len(response.json()["source_nodes"]) == 1
    assert response.json()["source_nodes"][0]["text"] == "This is a test document"
    assert response.json()["source_nodes"][0]["score"] == pytest.approx(0.5354418754577637, rel=1e-6)
    assert response.json()["source_nodes"][0]["metadata"] == {}
    assert mock_post.call_count == 1

def test_query_index_failure():
    # Prepare request data for querying.
    request_data = {
        "index_name": "non_existent_index", # Use an index name that doesn't exist
        "query": "test query",
        "top_k": 1,
        "llm_params": {"temperature": 0.7}
    }

    response = client.post("/query", json=request_data)
    assert response.status_code == 500
    assert response.json()["detail"] == "No such index: 'non_existent_index' exists."


def test_list_all_indexed_documents_success():
    response = client.get("/indexed-documents")
    assert response.status_code == 200
    assert response.json() == {'documents': {}}

    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    response = client.get("/indexed-documents")
    assert response.status_code == 200
    assert "test_index" in response.json()["documents"]
    response_idx = response.json()["documents"]["test_index"]
    assert len(response_idx) == 2 # Two Documents Indexed
    assert ({item["text"] for item in response_idx.values()}
            == {item["text"] for item in request_data["documents"]})


"""
Example of a live query test. This test is currently commented out as it requires a valid
INFERENCE_URL in config.py. To run the test, ensure that a valid INFERENCE_URL is provided.
Upon execution, RAG results should be observed.

def test_live_query_test():
    # Index
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "Polar bear – can lift 450Kg (approximately 0.7 times their body weight) \
                Adult male polar bears can grow to be anywhere between 300 and 700kg"},
            {"text": "Giraffes are the tallest mammals and are well-adapted to living in trees. \
                They have few predators as adults."}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    # Query
    request_data = {
        "index_name": "test_index",
        "query": "What is the strongest bear?",
        "top_k": 1,
        "llm_params": {"temperature": 0.7}
    }

    response = client.post("/query", json=request_data)
    assert response.status_code == 200
"""
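
The mocking in `test_query_index_success` relies on how `unittest.mock` chains `return_value`: setting `mock_post.return_value.json.return_value` controls what `mock_post(...).json()` yields. The sketch below isolates that mechanism with the standard library only. The helper `call_inference` and the URL are hypothetical, and the `str(...)` call is an assumption about how the service could end up returning `"{'result': ...}"` as its response text; the PR does not show that code path.

```python
from typing import Any, Callable, Dict
from unittest.mock import MagicMock

# Stand-in for the patched requests.post used in the test.
mock_post = MagicMock()
mock_post.return_value.json.return_value = {
    "result": "This is the completion from the API"
}


def call_inference(post: Callable[..., Any], url: str, payload: Dict[str, Any]) -> str:
    """Hypothetical helper: POST a payload and stringify the decoded JSON."""
    response = post(url, json=payload)
    # Assumption: the decoded body is converted with str(), which gives
    # the Python repr of the dict (single quotes), not JSON text.
    return str(response.json())


url = "http://inference.example/chat"  # placeholder, never contacted
payload = {"prompt": "test query", "temperature": 0.7}
completion = call_inference(mock_post, url, payload)
```

After this runs, `mock_post.call_count` is 1 and `mock_post.assert_called_once_with(url, json=payload)` passes, which is the same kind of check the test makes with `assert mock_post.call_count == 1`.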