This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Add llm cache APIs #99

Open · wants to merge 3 commits into base: main
Changes from 2 commits
1 change: 1 addition & 0 deletions genai_stack/constant.py
@@ -7,4 +7,5 @@
VECTORDB = "/vectordb"
ETL = "/etl"
PROMPT_ENGINE = "/prompt-engine"
LLM_CACHE = "/llm-cache"
MODEL = "/model"
19 changes: 19 additions & 0 deletions genai_stack/genai_server/models/cache_models.py
@@ -0,0 +1,19 @@
from typing import Optional

from pydantic import BaseModel


class BaseCacheRequestModel(BaseModel):
    session_id: int
    query: str
    metadata: Optional[dict] = None


class GetCacheRequestModel(BaseCacheRequestModel):
    pass


class SetCacheRequestModel(BaseCacheRequestModel):
    response: str


class CacheResponseModel(BaseCacheRequestModel):
    response: str
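
For reference, a request body for these models can be built and serialized as below. A minimal sketch; the field values are illustrative, and `.json()` assumes pydantic v1, which the import style in this diff suggests.

from genai_stack.genai_server.models.cache_models import SetCacheRequestModel

# Illustrative values only; session 1 must already exist server-side.
body = SetCacheRequestModel(
    session_id=1,
    query="Where is sunil from ?",
    response="Sunil is from Hyderabad.",
    metadata={"source": "/path", "page": 1},
)

# Serialize to the JSON payload POSTed to /set-cache (pydantic v1 API).
print(body.json())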
22 changes: 22 additions & 0 deletions genai_stack/genai_server/routers/cache_routes.py
@@ -0,0 +1,22 @@
from fastapi import APIRouter

from genai_stack.constant import API, LLM_CACHE
from genai_stack.genai_server.settings.settings import settings
from genai_stack.genai_server.models.cache_models import (
    GetCacheRequestModel, SetCacheRequestModel, CacheResponseModel
)
from genai_stack.genai_server.services.cache_service import LLMCacheService

service = LLMCacheService(store=settings.STORE)

router = APIRouter(prefix=API + LLM_CACHE, tags=["llm_cache"])


@router.get("/get-cache")
def get_cache(data: GetCacheRequestModel) -> CacheResponseModel:
    return service.get_cache(data=data)


@router.post("/set-cache")
def set_cache(data: SetCacheRequestModel) -> CacheResponseModel:
    return service.set_cache(data=data)
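
A note on usage: the router mounts at API + LLM_CACHE, which the tests below resolve to /api/llm-cache. Because get_cache takes a Pydantic model on a GET route, FastAPI reads its fields from the request body, so clients must send JSON with the GET. A sketch of both calls, assuming the server runs at the address the tests use:

import requests

BASE = "http://127.0.0.1:5000/api/llm-cache"  # address used by the tests below

# Cache a response for a query.
requests.post(BASE + "/set-cache", json={
    "session_id": 1,
    "query": "Where is sunil from ?",
    "response": "Sunil is from Hyderabad.",
})

# Look it up again; note the JSON body on a GET request.
hit = requests.get(BASE + "/get-cache", json={
    "session_id": 1,
    "query": "Where is sunil from ?",
})
print(hit.json()["response"])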
2 changes: 2 additions & 0 deletions genai_stack/genai_server/server.py
@@ -1,6 +1,7 @@
from fastapi import FastAPI

from genai_stack.genai_server.routers import (
    cache_routes,
    session_routes,
    retriever_routes,
    vectordb_routes,
@@ -26,6 +27,7 @@ def get_genai_server_app():
    app.include_router(retriever_routes.router)
    app.include_router(vectordb_routes.router)
    app.include_router(etl_routes.router)
    app.include_router(cache_routes.router)
    app.include_router(model_routes.router)

    return app
46 changes: 46 additions & 0 deletions genai_stack/genai_server/services/cache_service.py
@@ -0,0 +1,46 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session

from genai_stack.genai_platform.services import BaseService
from genai_stack.genai_server.models.cache_models import GetCacheRequestModel, SetCacheRequestModel, CacheResponseModel
from genai_stack.genai_server.settings.config import stack_config
from genai_stack.genai_server.utils import get_current_stack
from genai_stack.genai_store.schemas import StackSessionSchema


class LLMCacheService(BaseService):

    def get_cache(self, data: GetCacheRequestModel) -> CacheResponseModel:
        with Session(self.engine) as session:
            stack_session = session.get(StackSessionSchema, data.session_id)
            if stack_session is None:
                raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
            stack = get_current_stack(config=stack_config, session=stack_session)
            response = stack.llm_cache.get_cache(
                query=data.query,
                metadata=data.metadata
            )
            return CacheResponseModel(
                session_id=data.session_id,
                query=data.query,
                metadata=data.metadata,
                response=response
            )

    def set_cache(self, data: SetCacheRequestModel) -> CacheResponseModel:
        with Session(self.engine) as session:
            stack_session = session.get(StackSessionSchema, data.session_id)
            if stack_session is None:
                raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
            stack = get_current_stack(config=stack_config, session=stack_session)
            stack.llm_cache.set_cache(
                query=data.query,
                response=data.response,
                metadata=data.metadata
            )
            return CacheResponseModel(
                session_id=data.session_id,
                query=data.query,
                metadata=data.metadata,
                response=data.response
            )
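
The service delegates to stack.llm_cache, whose implementation is not part of this diff. Judging from the calls above and the tests below, the component exposes get_cache(query, metadata) and set_cache(query, response, metadata), where a mismatched metadata value is a miss but an omitted one still hits. A minimal in-memory stand-in with that assumed interface:

from typing import Optional


class InMemoryLLMCache:
    """Hypothetical stand-in for the stack's llm_cache component."""

    def __init__(self) -> None:
        self._store: dict = {}  # query -> (response, metadata)

    def set_cache(self, query: str, response: str, metadata: Optional[dict] = None) -> None:
        self._store[query] = (response, metadata)

    def get_cache(self, query: str, metadata: Optional[dict] = None) -> Optional[str]:
        entry = self._store.get(query)
        if entry is None:
            return None
        response, stored_metadata = entry
        # Supplied metadata must match what was cached; omitting it always hits.
        if metadata is not None and metadata != stored_metadata:
            return None
        return response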
99 changes: 99 additions & 0 deletions tests/api/test_genai_server/test_cache.py
@@ -0,0 +1,99 @@
#!/usr/bin/env python

"""Tests for `genai_server`."""
import unittest
import requests


class TestLLMCacheAPIs(unittest.TestCase):

    def setUp(self) -> None:
        self.base_url = "http://127.0.0.1:5000/api/llm-cache"

    def test_set_cache(self):
        response = requests.post(
            url=self.base_url + "/set-cache",
            json={
                "session_id": 1,
                "query": "Where is sunil from ?",
                "response": "Sunil is from Hyderabad.",
                "metadata": {"source": "/path", "page": 1}
            }
        )
        assert response.status_code == 200
        assert response.json()
        data = response.json()
        assert "query" in data.keys()
        assert "metadata" in data.keys()
        assert "response" in data.keys()

    def test_get_cache(self):
        response = requests.get(
            url=self.base_url + "/get-cache",
            json={
                "session_id": 1,
                "query": "Where is sunil from ?"
            }
        )

        assert response.status_code == 200
        assert response.json()
        data = response.json()
        assert "query" in data.keys()
        assert "metadata" in data.keys()
        assert "response" in data.keys()

    def test_get_and_set(self):
        query = "Where is sunil from ?"
        metadata = {"source": "/path", "page": 1}
        output = "Sunil is from Hyderabad."
        response = requests.post(
            url=self.base_url + "/set-cache",
            json={
                "session_id": 1,
                "query": query,
                "response": output,
                "metadata": metadata
            }
        )
        assert response.status_code == 200
        assert response.json()
        data = response.json()
        assert "query" in data.keys() and data.get("query") == query
        assert "metadata" in data.keys() and data.get("metadata") == metadata
        assert "response" in data.keys() and data.get("response") == output

        response = requests.get(
            url=self.base_url + "/get-cache",
            json={
                "session_id": 1,
                "query": query
            }
        )

        assert response.status_code == 200
        assert response.json()
        data = response.json()
        assert "query" in data.keys() and data.get("query") == query
        assert "response" in data.keys() and data.get("response") == output

        # A metadata value that differs from what was cached should be a miss.
        response = requests.get(
            url=self.base_url + "/get-cache",
            json={
                "session_id": 1,
                "query": query,
                "metadata": {"source": "/pathdiff", "page": 1}
            }
        )
        assert response.status_code != 200

        # The original metadata should still hit the cache.
        response = requests.get(
            url=self.base_url + "/get-cache",
            json={
                "session_id": 1,
                "query": query,
                "metadata": metadata
            }
        )

        assert response.status_code == 200
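
These tests call a live server at 127.0.0.1:5000 rather than FastAPI's TestClient, so the app has to be running first, with session 1 already created. One way to launch it, assuming uvicorn is available (the exact run command is not shown in this diff):

import uvicorn

from genai_stack.genai_server.server import get_genai_server_app

# Serve on the host/port the tests expect.
uvicorn.run(get_genai_server_app(), host="127.0.0.1", port=5000)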