From 77bf4d698671c221c74840a1e84ced0d735f06fb Mon Sep 17 00:00:00 2001 From: ShreehariVaasishta Date: Wed, 4 Oct 2023 13:47:26 +0530 Subject: [PATCH] Add APIs for Memory component --- genai_stack/constant.py | 1 + genai_stack/genai_server/models/memory.py | 60 +++++++++ .../genai_server/routers/memory_routes.py | 30 +++++ genai_stack/genai_server/server.py | 9 +- genai_stack/genai_server/services/memory.py | 123 ++++++++++++++++++ genai_stack/memory/langchain.py | 26 ++-- genai_stack/memory/utils.py | 45 ++++++- 7 files changed, 279 insertions(+), 15 deletions(-) create mode 100644 genai_stack/genai_server/models/memory.py create mode 100644 genai_stack/genai_server/routers/memory_routes.py create mode 100644 genai_stack/genai_server/services/memory.py diff --git a/genai_stack/constant.py b/genai_stack/constant.py index 4a63e152..a0864ebf 100644 --- a/genai_stack/constant.py +++ b/genai_stack/constant.py @@ -7,3 +7,4 @@ VECTORDB = "/vectordb" ETL = "/etl" PROMPT_ENGINE = "/prompt-engine" +MEMORY = "/memory" diff --git a/genai_stack/genai_server/models/memory.py b/genai_stack/genai_server/models/memory.py new file mode 100644 index 00000000..28cf9732 --- /dev/null +++ b/genai_stack/genai_server/models/memory.py @@ -0,0 +1,60 @@ +from typing import List, Optional + +from pydantic import BaseModel + + +class Message(BaseModel): + """ + Represents a message with user text and model text. + + Attributes: + user_text (str): The text provided by the user. + model_text (str): The text generated by the model. + """ + + user_text: str + model_text: str + + +class MemoryBaseModel(BaseModel): + """ + Represents a base model for memory with a session ID. + + Attributes: + session_id (int): The ID of the session. + """ + + session_id: int + + +class MemoryAddTextRequestModel(MemoryBaseModel): + """ + Represents a request model for adding text to memory, extending the MemoryBaseModel. + + Attributes: + message (Message): The message containing user text and model text. 
+ """ + + message: Message + + +class MemoryLatestMessageResponseModel(MemoryBaseModel): + """ + Represents a response model for the latest message in memory, extending the MemoryBaseModel. + + Attributes: + message (Optional[Message]): The latest message in memory, if available. + """ + + message: Optional[Message] + + +class MemoryHistoryResponseModel(MemoryBaseModel): + """ + Represents a response model for the history of messages in memory, extending the MemoryBaseModel. + + Attributes: + messages (Optional[List[Message]]): The list of messages in memory, if available. + """ + + messages: Optional[List[Message]] diff --git a/genai_stack/genai_server/routers/memory_routes.py b/genai_stack/genai_server/routers/memory_routes.py new file mode 100644 index 00000000..0c97e007 --- /dev/null +++ b/genai_stack/genai_server/routers/memory_routes.py @@ -0,0 +1,30 @@ +from fastapi.routing import APIRouter + +from genai_stack.constant import API, MEMORY +from genai_stack.genai_server.models.memory import ( + MemoryAddTextRequestModel, + MemoryBaseModel, + MemoryHistoryResponseModel, + MemoryLatestMessageResponseModel, +) +from genai_stack.genai_server.services.memory import MemoryService +from genai_stack.genai_server.settings.settings import settings + +service = MemoryService(store=settings.STORE) + +router = APIRouter(prefix=API + MEMORY, tags=["memory"]) + + +@router.post("/add-text", response_model=MemoryLatestMessageResponseModel) +def add_texts_to_memory(data: MemoryAddTextRequestModel): + return service.add_to_memory(data=data) + + +@router.post("/get-latest-text", response_model=MemoryLatestMessageResponseModel) +def get_latest_text_from_memory(data: MemoryBaseModel): + return service.get_latest_message_from_memory(data=data) + + +@router.post("/get-history", response_model=MemoryHistoryResponseModel) +def get_history_from_memory(data: MemoryBaseModel): + return service.get_memory_history(data=data) diff --git a/genai_stack/genai_server/server.py
b/genai_stack/genai_server/server.py index fff17a2a..a219bb50 100644 --- a/genai_stack/genai_server/server.py +++ b/genai_stack/genai_server/server.py @@ -1,6 +1,12 @@ from fastapi import FastAPI -from genai_stack.genai_server.routers import session_routes, retriever_routes, vectordb_routes, etl_routes +from genai_stack.genai_server.routers import ( + session_routes, + retriever_routes, + vectordb_routes, + etl_routes, + memory_routes, +) def get_genai_server_app(): @@ -20,5 +26,6 @@ def get_genai_server_app(): app.include_router(retriever_routes.router) app.include_router(vectordb_routes.router) app.include_router(etl_routes.router) + app.include_router(memory_routes.router) return app diff --git a/genai_stack/genai_server/services/memory.py b/genai_stack/genai_server/services/memory.py new file mode 100644 index 00000000..4c5cfd87 --- /dev/null +++ b/genai_stack/genai_server/services/memory.py @@ -0,0 +1,123 @@ +from fastapi import status +from fastapi.exceptions import HTTPException +from sqlalchemy.orm import Session + +from genai_stack.genai_server.models.memory import ( + MemoryAddTextRequestModel, + MemoryHistoryResponseModel, + MemoryLatestMessageResponseModel, +) +from genai_stack.genai_server.schemas import StackSessionSchema +from genai_stack.genai_server.settings.config import stack_config +from genai_stack.genai_server.utils.utils import get_current_stack + + +class MemoryService: + """ + Represents a service for managing memory for chat like conversation. + """ + + def _get_stack_session(self, session, data): + """ + Retrieves the stack session based on the provided session and data. + + Args: + session: The session object. + data: The data containing the session ID. + + Returns: + The result of retrieving the current stack. + + Raises: + HTTPException: Raised when the stack session is not found. 
+ + """ + + stack_session = session.get(StackSessionSchema, data.session_id) + if stack_session is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Session {data.session_id} not found", + ) + return get_current_stack(config=stack_config, session=stack_session) + + def add_to_memory(self, data: MemoryAddTextRequestModel): + """ + Adds text to memory based on the provided data. + + Args: + data (MemoryAddTextRequestModel): The data containing the text to be added. + + Returns: + MemoryLatestMessageResponseModel: The response model containing the latest message in memory. + + Raises: + HTTPException: Raised when both human and model texts are not provided. + """ + + with Session(self.engine) as session: + stack = self._get_stack_session(session, data) + human_text, model_text = data.message.user_text, data.message.model_text + + if not human_text or not model_text: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Both Human and Model texts are required.", + ) + + stack.memory.add_text(user_text=human_text, model_text=model_text) + response = stack.memory.get_text() + return MemoryLatestMessageResponseModel( + message=response, + session_id=data.session_id, + ) + + def get_memory_history(self, data): + """ + Retrieves the history of messages in memory based on the provided data. + + Args: + data: The data containing the session ID. + + Returns: + MemoryHistoryResponseModel: The response model containing the history of messages in memory. + """ + + with Session(self.engine) as session: + stack = self._get_stack_session(session, data) + response = stack.memory.get_chat_history_list() + return MemoryHistoryResponseModel( + messages=response, + session_id=data.session_id, + ) + + def get_latest_message_from_memory(self, data): + """ + Retrieves the latest message from memory based on the provided data. + + Args: + data: The data containing the session ID. 
+ + Returns: + MemoryLatestMessageResponseModel: The response model containing the latest message in memory. + + Example: + ```python + data = DataImplementation() + + # Creating an instance of MemoryService + service = MemoryService() + + # Retrieving the latest message from memory + response = service.get_latest_message_from_memory(data) + print(response) + ``` + """ + + with Session(self.engine) as session: + stack = self._get_stack_session(session, data) + response = stack.memory.get_text() + return MemoryLatestMessageResponseModel( + message=response, + session_id=data.session_id, + ) diff --git a/genai_stack/memory/langchain.py b/genai_stack/memory/langchain.py index 26fbc894..ee46e4c4 100644 --- a/genai_stack/memory/langchain.py +++ b/genai_stack/memory/langchain.py @@ -1,12 +1,19 @@ +from typing import List + from langchain.memory import ConversationBufferMemory as cbm -from genai_stack.memory.base import BaseMemoryConfigModel, BaseMemoryConfig, BaseMemory -from genai_stack.memory.utils import parse_chat_conversation_history + +from genai_stack.memory.base import BaseMemory, BaseMemoryConfig, BaseMemoryConfigModel +from genai_stack.memory.utils import ( + get_chat_conversation_history_dict, + parse_chat_conversation_history, +) class ConversationBufferMemoryConfigModel(BaseMemoryConfigModel): """ Data Model for the configs """ + pass @@ -28,17 +35,20 @@ def get_user_text(self): if len(self.memory.chat_memory.messages) == 0: return None return self.memory.chat_memory.messages[-2].content - + def get_model_text(self): if len(self.memory.chat_memory.messages) == 0: return None return self.memory.chat_memory.messages[-1].content - + def get_text(self): return { - "user_text":self.get_user_text(), - "model_text":self.get_model_text() + "user_text": self.get_user_text(), + "model_text": self.get_model_text(), } - def get_chat_history(self): - return parse_chat_conversation_history(self.memory.chat_memory.messages) \ No newline at end of file + def 
get_chat_history(self) -> str: + return parse_chat_conversation_history(self.memory.chat_memory.messages) + + def get_chat_history_list(self) -> List: + return get_chat_conversation_history_dict(self.memory.chat_memory.messages) diff --git a/genai_stack/memory/utils.py b/genai_stack/memory/utils.py index 4346aae2..667f6733 100644 --- a/genai_stack/memory/utils.py +++ b/genai_stack/memory/utils.py @@ -1,9 +1,42 @@ -def parse_chat_conversation_history(response:list) -> str: +from typing import List, Dict +from langchain.schema.messages import BaseMessage + + +def get_chat_conversation_history_dict(messages: List[BaseMessage]) -> List[Dict]: + """ + Converts a list of messages into a list of dictionaries representing a chat conversation history. + + Args: + messages (List[BaseMessage]): The list of messages. + + Returns: + List[Dict]: The formatted chat conversation history. + + Example: + ```python + messages = [BaseMessage(), BaseMessage()] + + # Converting the list of messages into a chat conversation history + history = get_chat_conversation_history_dict(messages) + print(history) + ``` + """ + + formatted_messages = [] + for i in range(0, len(messages), 2): + user_message = messages[i].content + model_message = messages[i + 1].content + + formatted_messages.append({"user_text": user_message, "model_text": model_message}) + return formatted_messages + + +def parse_chat_conversation_history(response: list) -> str: history = "" for i in range(len(response)): - if i%2 == 0: - history+=f"HUMAN : {response[i].content}\n" + if i % 2 == 0: + history += f"HUMAN : {response[i].content}\n" else: - history+=f"YOU : {response[i].content}\n" - - return history \ No newline at end of file + history += f"YOU : {response[i].content}\n" + + return history