This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

refactor get prompt service
Akshaj000 committed Oct 6, 2023
1 parent 53a995f commit 5aa7ff4
Showing 2 changed files with 171 additions and 31 deletions.
170 changes: 170 additions & 0 deletions dummy.ipynb
@@ -0,0 +1,170 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2023-10-05T13:43:29.749167090Z",
"start_time": "2023-10-05T13:43:24.752225360Z"
}
},
"outputs": [],
"source": [
"from genai_stack.stack.stack import Stack\n",
"from genai_stack.model import HuggingFaceModel\n",
"from genai_stack.etl.langchain import LangchainETL\n",
"from genai_stack.embedding.langchain import LangchainEmbedding\n",
"from genai_stack.vectordb.chromadb import ChromaDB\n",
"from genai_stack.prompt_engine.engine import PromptEngine\n",
"from genai_stack.retriever.langchain import LangChainRetriever\n",
"from genai_stack.memory.langchain import ConversationBufferMemory"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"etl = LangchainETL.from_kwargs(\n",
" name=\"WebBaseLoader\",\n",
" fields={\"web_path\": [\n",
" \"https://aiplanet.com\",\n",
" ]\n",
" }\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-05T13:43:29.752762828Z",
"start_time": "2023-10-05T13:43:29.749885841Z"
}
},
"id": "58197b4e9c357f27"
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"llm = HuggingFaceModel.from_kwargs(model=\"skt/ko-gpt-trinity-1.2B-v0.5\")\n",
"config = {\n",
" \"model_name\": \"sentence-transformers/all-mpnet-base-v2\",\n",
" \"model_kwargs\": {\"device\": \"cpu\"},\n",
" \"encode_kwargs\": {\"normalize_embeddings\": False},\n",
"}\n",
"embedding = LangchainEmbedding.from_kwargs(name=\"HuggingFaceEmbeddings\", fields=config)\n",
"chromadb = ChromaDB.from_kwargs()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-05T13:44:37.250590Z",
"start_time": "2023-10-05T13:44:37.223926726Z"
}
},
"id": "42fa2dcd7085ac80"
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [
{
"ename": "NameError",
"evalue": "name 'PromptEngine' is not defined",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m prompt_engine \u001B[38;5;241m=\u001B[39m \u001B[43mPromptEngine\u001B[49m\u001B[38;5;241m.\u001B[39mfrom_kwargs(should_validate\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[1;32m 2\u001B[0m retriever \u001B[38;5;241m=\u001B[39m LangChainRetriever\u001B[38;5;241m.\u001B[39mfrom_kwargs()\n\u001B[1;32m 3\u001B[0m memory \u001B[38;5;241m=\u001B[39m ConversationBufferMemory\u001B[38;5;241m.\u001B[39mfrom_kwargs()\n",
"\u001B[0;31mNameError\u001B[0m: name 'PromptEngine' is not defined"
]
}
],
"source": [
"prompt_engine = PromptEngine.from_kwargs(should_validate=False)\n",
"retriever = LangChainRetriever.from_kwargs()\n",
"memory = ConversationBufferMemory.from_kwargs()\n",
"Stack(\n",
" etl=etl,\n",
" embedding=embedding,\n",
" vectordb=chromadb,\n",
" model=llm,\n",
" prompt_engine=prompt_engine,\n",
" retriever=retriever,\n",
" memory=memory\n",
")\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-05T13:45:48.582398487Z",
"start_time": "2023-10-05T13:45:48.417937857Z"
}
},
"id": "81e991b63b233e6d"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"prompt1 = \"Why choose models from AI Marketplace?\""
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-10-05T13:43:49.783528419Z"
}
},
"id": "2e815c9b106810f0"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-10-05T13:43:49.786045019Z"
}
},
"id": "80d05b20a6921f73"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "8226f4204a117563"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
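
A note on the error captured above: the failing cell has execution_count 1 while the cells before it show counts 2 and 5, which suggests the kernel was restarted without re-running the import cell, leaving PromptEngine undefined. Assuming that reading, re-running the imports (or importing inline, as in this hypothetical snippet) would clear the NameError:

# Hypothetical fix for the notebook's NameError: restore the lost import
from genai_stack.prompt_engine.engine import PromptEngine

prompt_engine = PromptEngine.from_kwargs(should_validate=False)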
32 changes: 1 addition & 31 deletions genai_stack/genai_server/services/prompt_engine_service.py
@@ -21,37 +21,7 @@ def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetResponseModel:
         stack_session = session.get(StackSessionSchema, data.session_id)
         if stack_session is None:
             raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
-        prompt_session = (
-            session.query(PromptSchema)
-            .filter_by(stack_session=data.session_id, type=data.type.value)
-            .first()
-        )
-        if prompt_session is not None:
-            template = prompt_session.template
-            prompt_type_map = {
-                PromptTypeEnum.SIMPLE_CHAT_PROMPT.value: "simple_chat_prompt_template",
-                PromptTypeEnum.CONTEXTUAL_CHAT_PROMPT.value: "contextual_chat_prompt_template",
-                PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value: "contextual_qa_prompt_template",
-            }
-            input_variables = ["context", "history", "query"]
-            if data.type == PromptTypeEnum.SIMPLE_CHAT_PROMPT:
-                input_variables.remove("context")
-            elif data.type == PromptTypeEnum.CONTEXTUAL_QA_PROMPT:
-                input_variables.remove("history")
-            prompt = PromptTemplate(template=template, input_variables=input_variables)
-            stack = get_current_stack(
-                engine=session,
-                config=stack_config,
-                session=stack_session,
-                overide_config={
-                    "prompt_engine": {
-                        "should_validate": data.should_validate,
-                        prompt_type_map[data.type.value]: prompt
-                    }
-                }
-            )
-        else:
-            stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
+        stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
         prompt = stack.prompt_engine.get_prompt_template(promptType=data.type, query=data.query)
         return PromptEngineGetResponseModel(
             template=prompt.template,
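
The net effect of the hunk above: get_prompt no longer consults the database for a per-session prompt override. A minimal sketch of the method as it reads after this commit, assembled from the visible context lines (the tail of the return statement is elided in the diff, so only the template field is shown):

def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetResponseModel:
    stack_session = session.get(StackSessionSchema, data.session_id)
    if stack_session is None:
        raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
    # The PromptSchema lookup and the overide_config branch were removed;
    # the stack is now always built from the base config.
    stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
    prompt = stack.prompt_engine.get_prompt_template(promptType=data.type, query=data.query)
    return PromptEngineGetResponseModel(
        template=prompt.template,
        # ...remaining response fields elided in the diff
    )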
