From 336a639164685788428b1c4e7381a538e10f59e2 Mon Sep 17 00:00:00 2001 From: LiuHua <10215101452@stu.ecnu.edu.cn> Date: Mon, 9 Sep 2024 17:18:08 +0800 Subject: [PATCH] SDK for session (#2312) ### What problem does this PR solve? SDK for session #1102 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn> Co-authored-by: Kevin Hu --- api/apps/sdk/assistant.py | 43 +++-- api/apps/sdk/dataset.py | 6 +- api/apps/sdk/session.py | 168 +++++++++++++++++++ sdk/python/ragflow/modules/chat_assistant.py | 21 ++- sdk/python/ragflow/modules/session.py | 64 +++++++ sdk/python/ragflow/ragflow.py | 5 +- sdk/python/test/t_assistant.py | 26 +-- sdk/python/test/t_session.py | 27 +++ 8 files changed, 325 insertions(+), 35 deletions(-) create mode 100644 api/apps/sdk/session.py create mode 100644 sdk/python/ragflow/modules/session.py create mode 100644 sdk/python/test/t_session.py diff --git a/api/apps/sdk/assistant.py b/api/apps/sdk/assistant.py index c71b1e6705c..fe059ccc03b 100644 --- a/api/apps/sdk/assistant.py +++ b/api/apps/sdk/assistant.py @@ -16,9 +16,10 @@ from flask import request from api.db import StatusEnum +from api.db.db_models import TenantLLM from api.db.services.dialog_service import DialogService -from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.llm_service import LLMService, TenantLLMService from api.db.services.user_service import TenantService from api.settings import RetCode from api.utils import get_uuid @@ -30,7 +31,6 @@ @token_required def save(tenant_id): req = request.json - id = req.get("id") # dataset if req.get("knowledgebases") == []: return get_data_error_result(retmsg="knowledgebases can not be empty list") @@ -41,8 +41,8 @@ def save(tenant_id): return get_data_error_result(retmsg="knowledgebase needs id") if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id): return get_data_error_result(retmsg="you do not own the knowledgebase") - if not DocumentService.query(kb_id=kb["id"]): - return get_data_error_result(retmsg="There is a invalid knowledgebase") + # if not DocumentService.query(kb_id=kb["id"]): + # return get_data_error_result(retmsg="There is a invalid knowledgebase") kb_list.append(kb["id"]) req["kb_ids"] = kb_list # llm @@ -72,10 +72,10 @@ def save(tenant_id): req[key] = prompt.pop(key) req["prompt_config"] = req.pop("prompt") # create - if not id: + if "id" not in req: # dataset if not kb_list: - return get_data_error_result(retmsg="knowledgebase is required!") + return get_data_error_result(retmsg="knowledgebases are required!") # init req["id"] = get_uuid() req["description"] = req.get("description", "A helpful Assistant") @@ -83,7 +83,11 @@ def save(tenant_id): req["top_n"] = req.get("top_n", 6) req["top_k"] = req.get("top_k", 1024) req["rerank_id"] = req.get("rerank_id", "") - req["llm_id"] = req.get("llm_id", tenant.llm_id) + if req.get("llm_id"): + if not TenantLLMService.query(llm_name=req["llm_id"]): + return get_data_error_result(retmsg="the model_name does not exist.") + else: + req["llm_id"] = tenant.llm_id if not req.get("name"): return get_data_error_result(retmsg="name is required.") if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): @@ -149,14 +153,20 @@ def save(tenant_id): if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value): return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR) # prompt + if not req["id"]: + return get_data_error_result(retmsg="id can not be empty") e, res = DialogService.get_by_id(req["id"]) res = res.to_json() + if "llm_id" in req: + if not TenantLLMService.query(llm_name=req["llm_id"]): + return get_data_error_result(retmsg="the model_name does not exist.") if "name" in req: if not req.get("name"): return get_data_error_result(retmsg="name is not empty.") if req["name"].lower() != res["name"].lower() \ - and len(DialogService.query(name=req["name"], tenant_id=tenant_id,status=StatusEnum.VALID.value)) > 0: - return get_data_error_result(retmsg="Duplicated knowledgebase name in updating dataset.") + and len( + DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: + return get_data_error_result(retmsg="Duplicated assistant name in updating dataset.") if "prompt_config" in req: res["prompt_config"].update(req["prompt_config"]) for p in res["prompt_config"]["parameters"]: @@ -186,7 +196,7 @@ def delete(tenant_id): if "id" not in req: return get_data_error_result(retmsg="id is required") id = req['id'] - if not DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value): + if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value): return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR) temp_dict = {"status": StatusEnum.INVALID.value} @@ -200,21 +210,22 @@ def get(tenant_id): req = request.args if "id" in req: id = req["id"] - ass = DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value) + ass = DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value) if not ass: return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR) if "name" in req: name = req["name"] if ass[0].name != name: return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR) - res=ass[0].to_json() + res = ass[0].to_json() else: if "name" in req: name = req["name"] - ass = DialogService.query(name=name, tenant_id=tenant_id,status=StatusEnum.VALID.value) + ass = DialogService.query(name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value) if not ass: - return get_json_result(data=False, retmsg='You do not own the dataset.',retcode=RetCode.OPERATING_ERROR) - res=ass[0].to_json() + return get_json_result(data=False, retmsg='You do not own the assistant.', + retcode=RetCode.OPERATING_ERROR) + res = ass[0].to_json() else: return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.") renamed_dict = {} @@ -258,7 +269,7 @@ def list_assistants(tenant_id): reverse=True, order_by=DialogService.model.create_time) assts = [d.to_dict() for d in assts] - list_assts=[] + list_assts = [] renamed_dict = {} key_mapping = {"parameters": "variables", "prologue": "opener", diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index bd14587c13c..94bdf63839f 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -60,7 +60,7 @@ def save(tenant_id): req.update(mapped_keys) if not KnowledgebaseService.save(**req): return get_data_error_result(retmsg="Create dataset error.(Database error)") - renamed_data={} + renamed_data = {} e, k = KnowledgebaseService.get_by_id(req["id"]) for key, value in k.to_dict().items(): new_key = key_mapping.get(key, key) @@ -88,6 +88,9 @@ def save(tenant_id): data=False, retmsg='You do not own the dataset.', retcode=RetCode.OPERATING_ERROR) + if not req["id"]: + return get_data_error_result( + retmsg="id can not be empty.") e, kb = KnowledgebaseService.get_by_id(req["id"]) if "chunk_count" in req: @@ -108,6 +111,7 @@ def save(tenant_id): retmsg="If chunk count is not 0, parse method is not changable.") req['parser_id'] = req.pop('parse_method') if "name" in req: + req["name"] = req["name"].strip() if req["name"].lower() != kb.name.lower() \ and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py new file mode 100644 index 00000000000..8df43ee3840 --- /dev/null +++ b/api/apps/sdk/session.py @@ -0,0 +1,168 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +from copy import deepcopy +from uuid import uuid4 + +from flask import request, Response + +from api.db import StatusEnum +from api.db.services.dialog_service import DialogService, ConversationService, chat +from api.utils import get_uuid +from api.utils.api_utils import get_data_error_result +from api.utils.api_utils import get_json_result, token_required + + +@manager.route('/save', methods=['POST']) +@token_required +def set_conversation(tenant_id): + req = request.json + conv_id = req.get("id") + if "messages" in req: + req["message"] = req.pop("messages") + if req["message"]: + for message in req["message"]: + if "reference" in message: + req["reference"] = message.pop("reference") + if "assistant_id" in req: + req["dialog_id"] = req.pop("assistant_id") + if "id" in req: + del req["id"] + conv = ConversationService.query(id=conv_id) + if not conv: + return get_data_error_result(retmsg="Session does not exist") + if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): + return get_data_error_result(retmsg="You do not own the session") + if req.get("dialog_id"): + dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) + if not dia: + return get_data_error_result(retmsg="You do not own the assistant") + if "dialog_id" in req and not req.get("dialog_id"): + return get_data_error_result(retmsg="assistant_id can not be empty.") + if "name" in req and not req.get("name"): + return get_data_error_result(retmsg="name can not be empty.") + if "message" in req and not req.get("message"): + return get_data_error_result(retmsg="messages can not be empty") + if not ConversationService.update_by_id(conv_id, req): + return get_data_error_result(retmsg="Session updates error") + return get_json_result(data=True) + + if not req.get("dialog_id"): + return get_data_error_result(retmsg="assistant_id is required.") + dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) + if not dia: + return get_data_error_result(retmsg="You do not own the assistant") + conv = { + "id": get_uuid(), + "dialog_id": req["dialog_id"], + "name": req.get("name", "New session"), + "message": req.get("message", [{"role": "assistant", "content": dia[0].prompt_config["prologue"]}]), + "reference": req.get("reference", []) + } + if not conv.get("name"): + return get_data_error_result(retmsg="name can not be empty.") + if not conv.get("message"): + return get_data_error_result(retmsg="messages can not be empty") + ConversationService.save(**conv) + e, conv = ConversationService.get_by_id(conv["id"]) + if not e: + return get_data_error_result(retmsg="Fail to new session!") + conv = conv.to_dict() + conv["messages"] = conv.pop("message") + conv["assistant_id"] = conv.pop("dialog_id") + for message in conv["messages"]: + message["reference"] = conv.get("reference") + del conv["reference"] + return get_json_result(data=conv) + + +@manager.route('/completion', methods=['POST']) +@token_required +def completion(tenant_id): + req = request.json + # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ + # {"role": "user", "content": "上海有吗?"} + # ]} + msg = [] + question = { + "content": req.get("question"), + "role": "user", + "id": str(uuid4()) + } + req["messages"].append(question) + for m in req["messages"]: + if m["role"] == "system": continue + if m["role"] == "assistant" and not msg: continue + m["id"] = m.get("id", str(uuid4())) + msg.append(m) + message_id = msg[-1].get("id") + conv = ConversationService.query(id=req["id"]) + conv = conv[0] + if not conv: + return get_data_error_result(retmsg="Session does not exist") + if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): + return get_data_error_result(retmsg="You do not own the session") + conv.message = deepcopy(req["messages"]) + e, dia = DialogService.get_by_id(conv.dialog_id) + if not e: + return get_data_error_result(retmsg="Dialog not found!") + del req["id"] + del req["messages"] + + if not conv.reference: + conv.reference = [] + conv.message.append({"role": "assistant", "content": "", "id": message_id}) + conv.reference.append({"chunks": [], "doc_aggs": []}) + + def fillin_conv(ans): + nonlocal conv, message_id + if not conv.reference: + conv.reference.append(ans["reference"]) + else: + conv.reference[-1] = ans["reference"] + conv.message[-1] = {"role": "assistant", "content": ans["answer"], + "id": message_id, "prompt": ans.get("prompt", "")} + ans["id"] = message_id + + def stream(): + nonlocal dia, msg, req, conv + try: + for ans in chat(dia, msg, **req): + fillin_conv(ans) + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" + ConversationService.update_by_id(conv.id, conv.to_dict()) + except Exception as e: + yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), + "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" + + if req.get("stream", True): + resp = Response(stream(), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + else: + answer = None + for ans in chat(dia, msg, **req): + answer = ans + fillin_conv(ans) + ConversationService.update_by_id(conv.id, conv.to_dict()) + break + return get_json_result(data=answer) diff --git a/sdk/python/ragflow/modules/chat_assistant.py b/sdk/python/ragflow/modules/chat_assistant.py index d5ec05bdfb1..515d8b59efd 100644 --- a/sdk/python/ragflow/modules/chat_assistant.py +++ b/sdk/python/ragflow/modules/chat_assistant.py @@ -1,9 +1,12 @@ +from typing import List + from .base import Base +from .session import Session, Message class Assistant(Base): def __init__(self, rag, res_dict): - self.id="" + self.id = "" self.name = "assistant" self.avatar = "path/to/avatar" self.knowledgebases = ["kb1"] @@ -41,8 +44,8 @@ def __init__(self, rag, res_dict): def save(self) -> bool: res = self.post('/assistant/save', - {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases":self.knowledgebases, - "llm":self.llm.to_json(),"prompt":self.prompt.to_json() + {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases, + "llm": self.llm.to_json(), "prompt": self.prompt.to_json() }) res = res.json() if res.get("retmsg") == "success": return True @@ -54,3 +57,15 @@ def delete(self) -> bool: res = res.json() if res.get("retmsg") == "success": return True raise Exception(res["retmsg"]) + + def create_session(self, name: str = "New session", messages: List[Message] = [ + {"role": "assistant", "reference": [], + "content": "您好,我是您的助手小樱,长得可爱又善良,can I help you?"}]) -> Session: + res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id, }) + res = res.json() + if res.get("retmsg") == "success": + return Session(self.rag, res['data']) + raise Exception(res["retmsg"]) + + def get_prologue(self): + return self.prompt.opener diff --git a/sdk/python/ragflow/modules/session.py b/sdk/python/ragflow/modules/session.py new file mode 100644 index 00000000000..3b29c7a245a --- /dev/null +++ b/sdk/python/ragflow/modules/session.py @@ -0,0 +1,64 @@ +import json + +from .base import Base + + +class Session(Base): + def __init__(self, rag, res_dict): + self.id = None + self.name = "New session" + self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] + + self.assistant_id = None + super().__init__(rag, res_dict) + + def chat(self, question: str, stream: bool = False): + res = self.post("/session/completion", + {"id": self.id, "question": question, "stream": stream, "messages": self.messages}) + res = res.text + response_lines = res.splitlines() + message_list = [] + for line in response_lines: + if line.startswith("data:"): + json_data = json.loads(line[5:]) + if json_data["data"] != True: + answer = json_data["data"]["answer"] + reference = json_data["data"]["reference"] + temp_dict = { + "content": answer, + "role": "assistant", + "reference": reference + } + message = Message(self.rag, temp_dict) + message_list.append(message) + return message_list + + def save(self): + res = self.post("/session/save", + {"id": self.id, "dialog_id": self.assistant_id, "name": self.name, "message": self.messages}) + res = res.json() + if res.get("retmsg") == "success": return True + raise Exception(res.get("retmsg")) + +class Message(Base): + def __init__(self, rag, res_dict): + self.content = "您好,我是您的助手小樱,长得可爱又善良,can I help you?" + self.reference = [] + self.role = "assistant" + self.prompt=None + super().__init__(rag, res_dict) + + +class Chunk(Base): + def __init__(self, rag, res_dict): + self.id = None + self.content = None + self.document_id = None + self.document_name = None + self.knowledgebase_id = None + self.image_id = None + self.similarity = None + self.vector_similarity = None + self.term_similarity = None + self.positions = None + super().__init__(rag, res_dict) diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py index 68713f03ad1..57cd62f167f 100644 --- a/sdk/python/ragflow/ragflow.py +++ b/sdk/python/ragflow/ragflow.py @@ -17,7 +17,6 @@ import requests - from .modules.chat_assistant import Assistant from .modules.dataset import DataSet @@ -88,7 +87,7 @@ def create_assistant(self, name: str = "assistant", avatar: str = "path", knowle datasets.append(dataset.to_json()) if llm is None: - llm = Assistant.LLM(self, {"model_name": "deepseek-chat", + llm = Assistant.LLM(self, {"model_name": None, "temperature": 0.1, "top_p": 0.3, "presence_penalty": 0.4, @@ -142,4 +141,4 @@ def list_assistants(self) -> List[Assistant]: for data in res['data']: result_list.append(Assistant(self, data)) return result_list - raise Exception(res["retmsg"]) \ No newline at end of file + raise Exception(res["retmsg"]) diff --git a/sdk/python/test/t_assistant.py b/sdk/python/test/t_assistant.py index 7d70a337b90..ef91f5d7810 100644 --- a/sdk/python/test/t_assistant.py +++ b/sdk/python/test/t_assistant.py @@ -10,10 +10,10 @@ def test_create_assistant_with_success(self): Test creating an assistant with success """ rag = RAGFlow(API_KEY, HOST_ADDRESS) - kb = rag.get_dataset(name="God") - assistant = rag.create_assistant("God",knowledgebases=[kb]) + kb = rag.create_dataset(name="test_create_assistant") + assistant = rag.create_assistant("test_create", knowledgebases=[kb]) if isinstance(assistant, Assistant): - assert assistant.name == "God", "Name does not match." + assert assistant.name == "test_create", "Name does not match." else: assert False, f"Failed to create assistant, error: {assistant}" @@ -22,11 +22,11 @@ def test_update_assistant_with_success(self): Test updating an assistant with success. """ rag = RAGFlow(API_KEY, HOST_ADDRESS) - kb = rag.get_dataset(name="God") - assistant = rag.create_assistant("ABC",knowledgebases=[kb]) + kb = rag.create_dataset(name="test_update_assistant") + assistant = rag.create_assistant("test_update", knowledgebases=[kb]) if isinstance(assistant, Assistant): - assert assistant.name == "ABC", "Name does not match." - assistant.name = 'DEF' + assert assistant.name == "test_update", "Name does not match." + assistant.name = 'new_assistant' res = assistant.save() assert res is True, f"Failed to update assistant, error: {res}" else: @@ -37,10 +37,10 @@ def test_delete_assistant_with_success(self): Test deleting an assistant with success """ rag = RAGFlow(API_KEY, HOST_ADDRESS) - kb = rag.get_dataset(name="God") - assistant = rag.create_assistant("MA",knowledgebases=[kb]) + kb = rag.create_dataset(name="test_delete_assistant") + assistant = rag.create_assistant("test_delete", knowledgebases=[kb]) if isinstance(assistant, Assistant): - assert assistant.name == "MA", "Name does not match." + assert assistant.name == "test_delete", "Name does not match." res = assistant.delete() assert res is True, f"Failed to delete assistant, error: {res}" else: @@ -61,6 +61,8 @@ def test_get_detail_assistant_with_success(self): Test getting an assistant's detail with success """ rag = RAGFlow(API_KEY, HOST_ADDRESS) - assistant = rag.get_assistant(name="God") + kb = rag.create_dataset(name="test_get_assistant") + rag.create_assistant("test_get_assistant", knowledgebases=[kb]) + assistant = rag.get_assistant(name="test_get_assistant") assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}." - assert assistant.name == "God", "Name does not match" + assert assistant.name == "test_get_assistant", "Name does not match" diff --git a/sdk/python/test/t_session.py b/sdk/python/test/t_session.py new file mode 100644 index 00000000000..6fd2e36c6f2 --- /dev/null +++ b/sdk/python/test/t_session.py @@ -0,0 +1,27 @@ +from ragflow import RAGFlow + +from common import API_KEY, HOST_ADDRESS + + +class TestChatSession: + def test_create_session(self): + rag = RAGFlow(API_KEY, HOST_ADDRESS) + kb = rag.create_dataset(name="test_create_session") + assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb]) + session = assistant.create_session() + assert assistant is not None, "Failed to get the assistant." + assert session is not None, "Failed to create a session." + + def test_create_chat_with_success(self): + rag = RAGFlow(API_KEY, HOST_ADDRESS) + kb = rag.create_dataset(name="test_create_chat") + assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb]) + session = assistant.create_session() + assert session is not None, "Failed to create a session." + prologue = assistant.get_prologue() + assert isinstance(prologue, str), "Prologue is not a string." + assert len(prologue) > 0, "Prologue is empty." + question = "What is AI" + ans = session.chat(question, stream=True) + response = ans[-1].content + assert len(response) > 0, "Assistant did not return any response."