This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit 53a995f: Add utility func to get prompt from db

Akshaj000 committed Oct 6, 2023
1 parent 1caf502
Showing 7 changed files with 65 additions and 11 deletions.
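
Taken together, the changes below add a get_prompt_from_db helper in genai_stack/genai_server/utils/utils.py, give get_current_stack an optional engine parameter, and have each service pass its open database session as that engine, so prompt templates stored for a session are merged into the stack config before the stack is assembled. A rough sketch of the caller pattern the services adopt (the wrapper function, the sqlalchemy import, and the StackSessionSchema import path are assumptions for illustration, not code from this commit):

from sqlalchemy.orm import Session  # assumption: the services use SQLAlchemy-style sessions

from genai_stack.genai_server.schemas import StackSessionSchema  # import path assumed
from genai_stack.genai_server.utils.utils import get_current_stack


def build_stack_for_session(engine, session_id: int, stack_config: dict):
    """Illustrative only: mirrors the pattern each service switches to in this commit."""
    with Session(engine) as session:
        stack_session = session.get(StackSessionSchema, session_id)
        # New in this commit: the open DB session is passed as `engine`, so
        # get_current_stack can load any prompt templates stored for this session.
        return get_current_stack(
            config=stack_config,
            engine=session,
            session=stack_session,
        )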

genai_stack/genai_server/services/etl_service.py (2 changes: 1 addition & 1 deletion)

@@ -20,7 +20,7 @@ def submit_job(self, data: Any, stack_session_id: Optional[int] = None) -> ETLJo

             data = ETLUtil(data).save_request(etl_job.id)

-            stack = get_current_stack(config=stack_config, session=stack_session)
+            stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
             get_etl_platform(stack=stack).handle_job(**data)

             etl_job.data = data

genai_stack/genai_server/services/prompt_engine_service.py (9 changes: 5 additions & 4 deletions)

@@ -40,6 +40,7 @@ def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetRespon
                     input_variables.remove("history")
                 prompt = PromptTemplate(template=template, input_variables=input_variables)
                 stack = get_current_stack(
+                    engine=session,
                     config=stack_config,
                     session=stack_session,
                     overide_config={

@@ -50,7 +51,7 @@ def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetRespon
                     }
                 )
             else:
-                stack = get_current_stack(config=stack_config, session=stack_session)
+                stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
             prompt = stack.prompt_engine.get_prompt_template(promptType=data.type, query=data.query)
             return PromptEngineGetResponseModel(
                 template=prompt.template,
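
The remove("context")/remove("history") calls above, and the same checks in set_prompt below, reflect that each prompt type expects a different set of template variables (the lists match the prompt_type_map added in utils.py further down). A small, self-contained illustration of building a template for one of the types; the template string here is made up:

from langchain.prompts import PromptTemplate

ALL_VARIABLES = ["context", "history", "query"]

# Simple chat prompts drop the retrieved context; contextual QA prompts drop the history.
simple_chat_vars = [v for v in ALL_VARIABLES if v != "context"]    # ["history", "query"]
contextual_qa_vars = [v for v in ALL_VARIABLES if v != "history"]  # ["context", "query"]

prompt = PromptTemplate(
    template="Context: {context}\nQuestion: {query}",  # illustrative template only
    input_variables=contextual_qa_vars,
)
print(prompt.input_variables)  # ['context', 'query']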

@@ -64,12 +65,12 @@ def set_prompt(self, data: PromptEngineSetRequestModel) -> PromptEngineSetRespon
             if stack_session is None:
                 raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
             input_variables = ["context", "history", "query"]
-            if data.type == PromptTypeEnum.SIMPLE_CHAT_PROMPT:
+            if data.type.value == PromptTypeEnum.SIMPLE_CHAT_PROMPT.value:
                 input_variables.remove("context")
-            elif data.type == PromptTypeEnum.CONTEXTUAL_QA_PROMPT:
+            elif data.type.value == PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value:
                 input_variables.remove("history")
             for variable in input_variables:
-                if f"{variable}" not in data.template:
+                if variable not in data.template:
                     raise HTTPException(status_code=400, detail=f"Input variable {variable} not found in template")
             for variable in data.template.split("{"):
                 if "}" in variable and variable.split("}")[0] not in input_variables:
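
A subtlety in the set_prompt change above is the switch from comparing enum members to comparing their values. If data.type is defined on the request model as its own Enum class rather than as PromptTypeEnum itself, member equality is always False even when the underlying values match; comparing .value sidesteps that. A minimal sketch (the value strings are made up):

from enum import Enum


class PromptTypeEnum(Enum):           # stands in for the engine-side enum named in the diff
    SIMPLE_CHAT_PROMPT = "simple_chat_prompt"


class RequestPromptType(Enum):        # hypothetical request-model enum with matching values
    SIMPLE_CHAT_PROMPT = "simple_chat_prompt"


# Members of two different Enum classes never compare equal, even with identical values:
print(RequestPromptType.SIMPLE_CHAT_PROMPT == PromptTypeEnum.SIMPLE_CHAT_PROMPT)              # False
# Comparing .value, as set_prompt now does, works across both definitions:
print(RequestPromptType.SIMPLE_CHAT_PROMPT.value == PromptTypeEnum.SIMPLE_CHAT_PROMPT.value)  # True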

genai_stack/genai_server/services/retriever_service.py (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@ def retrieve(self, data: RetrieverRequestModel) -> RetrieverResponseModel:
             stack_session = session.get(StackSessionSchema, data.session_id)
             if stack_session is None:
                 raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
-            stack = get_current_stack(config=stack_config, session=stack_session)
+            stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
             response = stack.retriever.retrieve(data.query)
             return RetrieverResponseModel(
                 output=response['output'],

genai_stack/genai_server/services/session_service.py (2 changes: 1 addition & 1 deletion)

@@ -24,7 +24,7 @@ def create_session(self) -> StackSessionResponseModel:
         created_at : datetime
         modified_at : None
         """
-        stack = get_current_stack(stack_config, default_session=False)
+        stack = get_current_stack(config=stack_config, default_session=False)

         with Session(self.engine) as session:
             stack_session = StackSessionSchema(stack_id=1, meta_data={})

genai_stack/genai_server/services/vectordb_service.py (4 changes: 2 additions & 2 deletions)

@@ -18,7 +18,7 @@ def add_documents(self, data: RetrieverAddDocumentsRequestModel) -> RetrieverAdd
             stack_session = session.get(StackSessionSchema, data.session_id)
             if stack_session is None:
                 raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
-            stack = get_current_stack(config=stack_config, session=stack_session)
+            stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
             stack.vectordb.add_documents(data.documents)
             return RetrieverAddDocumentsResponseModel(
                 documents=[

@@ -34,7 +34,7 @@ def search(self, data: RetrieverSearchRequestModel) -> RetrieverSearchResponseMo

         with Session(self.engine) as session:
             stack_session = session.get(StackSessionSchema, data.session_id)
-            stack = get_current_stack(config=stack_config, session=stack_session)
+            stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
             if stack_session is None:
                 raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
             documents = stack.vectordb.search(data.query)

genai_stack/genai_server/utils/utils.py (55 changes: 54 additions & 1 deletion)

@@ -1,6 +1,10 @@
 import string
 import random

+from langchain.prompts import PromptTemplate
+
+from genai_stack.genai_server.schemas import PromptSchema
+from genai_stack.prompt_engine.utils import PromptTypeEnum
 from genai_stack.utils import import_class
 from genai_stack.enums import StackComponentType
 from genai_stack.genai_server.models.session_models import StackSessionResponseModel

@@ -68,7 +72,56 @@ def create_indexes(stack, stack_id: int, session_id: int) -> dict:
     return meta_data


-def get_current_stack(config: dict, session=None, default_session: bool = True, overide_config: dict = None):
+def get_prompt_from_db(session, session_id, stack_config):
+    prompt_sessions = (
+        session.query(PromptSchema)
+        .filter_by(stack_session=session_id)
+    )
+    prompt_type_map = {
+        PromptTypeEnum.SIMPLE_CHAT_PROMPT.value: {
+            "field": "simple_chat_prompt_template",
+            "input_variables": ["history", "query"]
+        },
+        PromptTypeEnum.CONTEXTUAL_CHAT_PROMPT.value: {
+            "field": "contextual_chat_prompt_template",
+            "input_variables": ["context", "history", "query"]
+        },
+        PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value: {
+            "field": "contextual_qa_prompt_template",
+            "input_variables": ["context", "query"]
+        }
+    }
+    for prompt_session in prompt_sessions:
+        if "prompt_engine" not in stack_config["components"]:
+            stack_config["components"]["prompt_engine"] = {
+                "name": "PromptEngine",
+                "config": {}
+            }
+        if "config" not in stack_config["components"]["prompt_engine"]:
+            stack_config["components"]["prompt_engine"]["config"] = {}
+        stack_config["components"]["prompt_engine"]["config"] = {
+            **stack_config["components"]["prompt_engine"]["config"],
+            prompt_type_map[prompt_session.type.value]["field"]: PromptTemplate(
+                template=prompt_session.template,
+                input_variables=prompt_type_map[prompt_session.type.value]["input_variables"]
+            )
+        }
+    return stack_config
+
+
+def get_current_stack(
+    config: dict,
+    engine=None,
+    session=None,
+    default_session: bool = True,
+    overide_config: dict = None
+):
+    if engine is not None:
+        config = get_prompt_from_db(
+            session=engine,
+            session_id=session.id,
+            stack_config=config
+        )
     components = {}
     if session is None and default_session:
         from genai_stack.genai_server.settings.settings import settings
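
To make the effect of the new get_prompt_from_db concrete: for each stored prompt row it injects a ready-made PromptTemplate into the prompt_engine component of the stack config. A rough sketch of that merge with the database query replaced by an in-memory stand-in (the row object and its field values are illustrative, not the repo's PromptSchema):

from types import SimpleNamespace

from langchain.prompts import PromptTemplate

# Stand-in for one PromptSchema row that session.query(...).filter_by(...) would return.
row = SimpleNamespace(
    type=SimpleNamespace(value="contextual_qa_prompt"),  # mimics a PromptTypeEnum member
    template="Answer from the context.\nContext: {context}\nQuestion: {query}",
)

stack_config = {"components": {}}

# The same merge the helper performs per row: ensure the prompt_engine component
# exists, then attach the stored template under the field for its prompt type.
engine_cfg = stack_config["components"].setdefault(
    "prompt_engine", {"name": "PromptEngine", "config": {}}
)
engine_cfg["config"]["contextual_qa_prompt_template"] = PromptTemplate(
    template=row.template,
    input_variables=["context", "query"],
)

print(stack_config["components"]["prompt_engine"]["config"])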

tests/api/test_genai_server/test_retriever.py (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@ def setUp(self) -> None:
     def test_retrieve(self):
         response = requests.get(
             url=self.base_url + "/retrieve",
-            params={"session_id": 1, "query": "Where is sunil from ?"},
+            params={"session_id": 2, "query": "Where is sunil from ?"},
         )
         assert response.status_code == 200
         assert response.json()
