diff --git a/dummy.ipynb b/dummy.ipynb
new file mode 100644
index 00000000..8731515b
--- /dev/null
+++ b/dummy.ipynb
@@ -0,0 +1,138 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "initial_id",
+   "metadata": {
+    "collapsed": true,
+    "ExecuteTime": {
+     "end_time": "2023-10-05T13:43:29.749167090Z",
+     "start_time": "2023-10-05T13:43:24.752225360Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from genai_stack.stack.stack import Stack\n",
+    "from genai_stack.model import HuggingFaceModel\n",
+    "from genai_stack.etl.langchain import LangchainETL\n",
+    "from genai_stack.embedding.langchain import LangchainEmbedding\n",
+    "from genai_stack.vectordb.chromadb import ChromaDB\n",
+    "from genai_stack.prompt_engine.engine import PromptEngine\n",
+    "from genai_stack.retriever.langchain import LangChainRetriever\n",
+    "from genai_stack.memory.langchain import ConversationBufferMemory"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "outputs": [],
+   "source": [
+    "# Extract and load source documents from the web\n",
+    "etl = LangchainETL.from_kwargs(\n",
+    "    name=\"WebBaseLoader\",\n",
+    "    fields={\"web_path\": [\n",
+    "        \"https://aiplanet.com\",\n",
+    "    ]\n",
+    "    }\n",
+    ")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-10-05T13:43:29.752762828Z",
+     "start_time": "2023-10-05T13:43:29.749885841Z"
+    }
+   },
+   "id": "58197b4e9c357f27"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "outputs": [],
+   "source": [
+    "# Language model, embedding model, and vector store components\n",
+    "llm = HuggingFaceModel.from_kwargs(model=\"skt/ko-gpt-trinity-1.2B-v0.5\")\n",
+    "config = {\n",
+    "    \"model_name\": \"sentence-transformers/all-mpnet-base-v2\",\n",
+    "    \"model_kwargs\": {\"device\": \"cpu\"},\n",
+    "    \"encode_kwargs\": {\"normalize_embeddings\": False},\n",
+    "}\n",
+    "embedding = LangchainEmbedding.from_kwargs(name=\"HuggingFaceEmbeddings\", fields=config)\n",
+    "chromadb = ChromaDB.from_kwargs()"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-10-05T13:44:37.250590Z",
+     "start_time": "2023-10-05T13:44:37.223926726Z"
+    }
+   },
+   "id": "42fa2dcd7085ac80"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "outputs": [],
+   "source": [
+    "# Wire all of the components together into a stack\n",
+    "prompt_engine = PromptEngine.from_kwargs(should_validate=False)\n",
+    "retriever = LangChainRetriever.from_kwargs()\n",
+    "memory = ConversationBufferMemory.from_kwargs()\n",
+    "stack = Stack(\n",
+    "    etl=etl,\n",
+    "    embedding=embedding,\n",
+    "    vectordb=chromadb,\n",
+    "    model=llm,\n",
+    "    prompt_engine=prompt_engine,\n",
+    "    retriever=retriever,\n",
+    "    memory=memory\n",
+    ")"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2023-10-05T13:45:48.582398487Z",
+     "start_time": "2023-10-05T13:45:48.417937857Z"
+    }
+   },
+   "id": "81e991b63b233e6d"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "prompt1 = \"Why choose models from AI Marketplace?\""
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "start_time": "2023-10-05T13:43:49.783528419Z"
+    }
+   },
+   "id": "2e815c9b106810f0"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/genai_stack/genai_server/services/prompt_engine_service.py b/genai_stack/genai_server/services/prompt_engine_service.py
index 47364323..9231bd30 100644
--- a/genai_stack/genai_server/services/prompt_engine_service.py
+++ b/genai_stack/genai_server/services/prompt_engine_service.py
@@ -21,37 +21,7 @@ def get_prompt(self, data: PromptEngineGetRequestModel) -> PromptEngineGetRespon
         stack_session = session.get(StackSessionSchema, data.session_id)
         if stack_session is None:
             raise HTTPException(status_code=404, detail=f"Session {data.session_id} not found")
-        prompt_session = (
-            session.query(PromptSchema)
-            .filter_by(stack_session=data.session_id, type=data.type.value)
-            .first()
-        )
-        if prompt_session is not None:
-            template = prompt_session.template
-            prompt_type_map = {
-                PromptTypeEnum.SIMPLE_CHAT_PROMPT.value: "simple_chat_prompt_template",
-                PromptTypeEnum.CONTEXTUAL_CHAT_PROMPT.value: "contextual_chat_prompt_template",
-                PromptTypeEnum.CONTEXTUAL_QA_PROMPT.value: "contextual_qa_prompt_template",
-            }
-            input_variables = ["context", "history", "query"]
-            if data.type == PromptTypeEnum.SIMPLE_CHAT_PROMPT:
-                input_variables.remove("context")
-            elif data.type == PromptTypeEnum.CONTEXTUAL_QA_PROMPT:
-                input_variables.remove("history")
-            prompt = PromptTemplate(template=template, input_variables=input_variables)
-            stack = get_current_stack(
-                engine=session,
-                config=stack_config,
-                session=stack_session,
-                overide_config={
-                    "prompt_engine": {
-                        "should_validate": data.should_validate,
-                        prompt_type_map[data.type.value]: prompt
-                    }
-                }
-            )
-        else:
-            stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
+        stack = get_current_stack(config=stack_config, engine=session, session=stack_session)
         prompt = stack.prompt_engine.get_prompt_template(promptType=data.type, query=data.query)
         return PromptEngineGetResponseModel(
             template=prompt.template,