diff --git a/genai_stack/genai_server/services/etl_service.py b/genai_stack/genai_server/services/etl_service.py index 69b5d4de..6aa419c5 100644 --- a/genai_stack/genai_server/services/etl_service.py +++ b/genai_stack/genai_server/services/etl_service.py @@ -20,7 +20,7 @@ def submit_job(self, data: Any, stack_session_id: Optional[int] = None) -> ETLJo data = ETLUtil(data).save_request(etl_job.id) - stack = get_current_stack(config=stack_config, session=stack_session) + stack = get_current_stack(config=stack_config, engine=session, session=stack_session) get_etl_platform(stack=stack).handle_job(**data) etl_job.data = data diff --git a/genai_stack/genai_server/services/prompt_engine_service.py b/genai_stack/genai_server/services/prompt_engine_service.py index 10472494..47364323 100644 --- a/genai_stack/genai_server/services/prompt_engine_service.py +++ b/genai_stack/genai_server/services/prompt_engine_service.py @@ -40,6 +40,7 @@ def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetRespon input_variables.remove("history") prompt = PromptTemplate(template=template, input_variables=input_variables) stack = get_current_stack( + engine=session, config=stack_config, session=stack_session, overide_config={ @@ -50,7 +51,7 @@ def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetRespon } ) else: - stack = get_current_stack(config=stack_config, session=stack_session) + stack = get_current_stack(config=stack_config, engine=session, session=stack_session) prompt = stack.prompt_engine.get_prompt_template(promptType=data.type, query=data.query) return PromptEngineGetResponseModel( template=prompt.template, @@ -64,12 +65,12 @@ def set_prompt(self, data: PromptEngineSetRequestModel) -> PromptEngineSetRespon if stack_session is None: raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found") input_variables = ["context", "history", "query"] - if data.type == PromptTypeEnum.SIMPLE_CHAT_PROMPT: + if data.type.value == 
PromptTypeEnum.SIMPLE_CHAT_PROMPT.value: input_variables.remove("context") - elif data.type == PromptTypeEnum.CONTEXTUAL_QA_PROMPT: + elif data.type.value == PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value: input_variables.remove("history") for variable in input_variables: - if f"{variable}" not in data.template: + if variable not in data.template: raise HTTPException(status_code=400, detail=f"Input variable {variable} not found in template") for variable in data.template.split("{"): if "}" in variable and variable.split("}")[0] not in input_variables: diff --git a/genai_stack/genai_server/services/retriever_service.py b/genai_stack/genai_server/services/retriever_service.py index dab1c83c..3bd10e61 100644 --- a/genai_stack/genai_server/services/retriever_service.py +++ b/genai_stack/genai_server/services/retriever_service.py @@ -15,7 +15,7 @@ def retrieve(self, data: RetrieverRequestModel) -> RetrieverResponseModel: stack_session = session.get(StackSessionSchema, data.session_id) if stack_session is None: raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found") - stack = get_current_stack(config=stack_config, session=stack_session) + stack = get_current_stack(config=stack_config, engine=session, session=stack_session) response = stack.retriever.retrieve(data.query) return RetrieverResponseModel( output=response['output'], diff --git a/genai_stack/genai_server/services/session_service.py b/genai_stack/genai_server/services/session_service.py index 9e006fcf..399661be 100644 --- a/genai_stack/genai_server/services/session_service.py +++ b/genai_stack/genai_server/services/session_service.py @@ -24,7 +24,7 @@ def create_session(self) -> StackSessionResponseModel: created_at : datetime modified_at : None """ - stack = get_current_stack(stack_config, default_session=False) + stack = get_current_stack(config=stack_config, default_session=False) with Session(self.engine) as session: stack_session = StackSessionSchema(stack_id=1, meta_data={}) diff 
--git a/genai_stack/genai_server/services/vectordb_service.py b/genai_stack/genai_server/services/vectordb_service.py index 8ccf4295..9268e7dc 100644 --- a/genai_stack/genai_server/services/vectordb_service.py +++ b/genai_stack/genai_server/services/vectordb_service.py @@ -18,7 +18,7 @@ def add_documents(self, data: RetrieverAddDocumentsRequestModel) -> RetrieverAdd stack_session = session.get(StackSessionSchema, data.session_id) if stack_session is None: raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found") - stack = get_current_stack(config=stack_config, session=stack_session) + stack = get_current_stack(config=stack_config, engine=session, session=stack_session) stack.vectordb.add_documents(data.documents) return RetrieverAddDocumentsResponseModel( documents=[ @@ -34,7 +34,7 @@ def search(self, data: RetrieverSearchRequestModel) -> RetrieverSearchResponseMo with Session(self.engine) as session: stack_session = session.get(StackSessionSchema, data.session_id) - stack = get_current_stack(config=stack_config, session=stack_session) + stack = get_current_stack(config=stack_config, engine=session, session=stack_session) if stack_session is None: raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found") documents = stack.vectordb.search(data.query) diff --git a/genai_stack/genai_server/utils/utils.py b/genai_stack/genai_server/utils/utils.py index 6961c279..43f0b6f7 100644 --- a/genai_stack/genai_server/utils/utils.py +++ b/genai_stack/genai_server/utils/utils.py @@ -1,6 +1,10 @@ import string import random +from langchain.prompts import PromptTemplate + +from genai_stack.genai_server.schemas import PromptSchema +from genai_stack.prompt_engine.utils import PromptTypeEnum from genai_stack.utils import import_class from genai_stack.enums import StackComponentType from genai_stack.genai_server.models.session_models import StackSessionResponseModel @@ -68,7 +72,56 @@ def create_indexes(stack, stack_id: int, 
def get_prompt_from_db(session, session_id, stack_config):
    """Overlay the prompt templates stored for a stack session onto a config.

    Queries ``PromptSchema`` rows whose ``stack_session`` equals
    ``session_id`` and, for each row, builds a ``PromptTemplate`` that is
    placed under ``components.prompt_engine.config`` using the field name
    that corresponds to the row's prompt type.

    The input ``stack_config`` is never mutated: the nested dictionaries on
    the path to the prompt-engine config are copied before being changed.
    Callers pass the same config object for every request, so writing the
    templates in place would let one session's prompts show up in requests
    made for other sessions.

    Args:
        session: Database session object exposing ``query(...)``.
        session_id: Identifier of the stack session whose prompts are loaded.
        stack_config: Stack configuration dict; must contain a
            ``"components"`` mapping when stored prompts exist.

    Returns:
        ``stack_config`` itself when no prompts are stored for the session,
        otherwise a new config dict with the stored templates merged into
        ``components["prompt_engine"]["config"]``.

    Raises:
        KeyError: If a stored prompt has a type that is not in the type map,
            or if ``stack_config`` has no ``"components"`` key while stored
            prompts exist.
    """
    prompt_type_map = {
        PromptTypeEnum.SIMPLE_CHAT_PROMPT.value: {
            "field": "simple_chat_prompt_template",
            "input_variables": ["history", "query"],
        },
        PromptTypeEnum.CONTEXTUAL_CHAT_PROMPT.value: {
            "field": "contextual_chat_prompt_template",
            "input_variables": ["context", "history", "query"],
        },
        PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value: {
            "field": "contextual_qa_prompt_template",
            "input_variables": ["context", "query"],
        },
    }

    stored_prompts = (
        session.query(PromptSchema)
        .filter_by(stack_session=session_id)
    )

    # Collect the overrides first; a later row for the same prompt type
    # replaces an earlier one, exactly as repeated assignment would.
    templates = {}
    for stored_prompt in stored_prompts:
        spec = prompt_type_map[stored_prompt.type.value]
        templates[spec["field"]] = PromptTemplate(
            template=stored_prompt.template,
            input_variables=list(spec["input_variables"]),
        )

    if not templates:
        # Nothing stored for this session: leave the config as it is.
        return stack_config

    # Copy-on-write along the path we touch so the caller's config (and any
    # other request sharing it) is left unchanged.
    components = dict(stack_config["components"])
    if "prompt_engine" in components:
        prompt_engine = dict(components["prompt_engine"])
    else:
        prompt_engine = {"name": "PromptEngine"}
    prompt_engine["config"] = {
        **(prompt_engine.get("config") or {}),
        **templates,
    }
    components["prompt_engine"] = prompt_engine

    return {**stack_config, "components": components}
b/tests/api/test_genai_server/test_retriever.py @@ -14,7 +14,7 @@ def setUp(self) -> None: def test_retrieve(self): response = requests.get( url=self.base_url + "/retrieve", - params={"session_id": 1, "query": "Where is sunil from ?"}, + params={"session_id": 2, "query": "Where is sunil from ?"}, ) assert response.status_code == 200 assert response.json()