diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 81716c68f24..aca9a2ddf71 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -27,7 +27,7 @@
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
@@ -141,8 +141,7 @@ def set():
         if not tenant_id:
             return get_data_error_result(retmsg="Tenant not found!")
         embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = TenantLLMService.model_instance(
-            tenant_id, LLMType.EMBEDDING.value, embd_id)
+        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
 
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
@@ -235,8 +234,7 @@ def create():
         if not tenant_id:
             return get_data_error_result(retmsg="Tenant not found!")
         embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = TenantLLMService.model_instance(
-            tenant_id, LLMType.EMBEDDING.value, embd_id)
+        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
 
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
@@ -281,16 +279,14 @@ def retrieval_test():
         if not e:
             return get_data_error_result(retmsg="Knowledgebase not found!")
 
-        embd_mdl = TenantLLMService.model_instance(
-            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+        embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
         rerank_mdl = None
         if req.get("rerank_id"):
-            rerank_mdl = TenantLLMService.model_instance(
-                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
 
         if req.get("keyword", False):
-            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+            chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
 
         retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 993c02ee928..18b8a7dba80 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -78,6 +78,7 @@ def count():
 
 
 def llm_id2llm_type(llm_id):
+    llm_id = llm_id.split("@")[0]
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:
@@ -89,9 +90,15 @@ def chat(dialog, messages, stream=True, **kwargs):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    llm = LLMService.query(llm_name=dialog.llm_id)
+    tmp = dialog.llm_id.split("@")
+    fid = None
+    llm_id = tmp[0]
+    if len(tmp) > 1: fid = tmp[1]
+
+    llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
-        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
+        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
+            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
     if not llm:
         raise LookupError("LLM(%s) not found" % dialog.llm_id)
     max_tokens = 8192
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 0309e38588b..87bd59be9b7 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -17,7 +17,7 @@
 from api.settings import database_logger
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
-from api.db.db_models import DB, UserTenant
+from api.db.db_models import DB
 from api.db.db_models import LLMFactories, LLM, TenantLLM
 from api.db.services.common_service import CommonService
 
@@ -36,7 +36,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        else:
+            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
         if not objs:
             return
         return objs[0]
@@ -81,14 +85,17 @@ def model_instance(cls, tenant_id, llm_type,
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
+        tmp = mdlnm.split("@")
+        fid = None if len(tmp) < 2 else tmp[1]
+        mdlnm = tmp[0]
         if model_config:
             model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
-                llm = LLMService.query(llm_name=llm_name if llm_name else mdlnm)
+                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                 if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
-                    model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name if llm_name else mdlnm, "api_base": ""}
+                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
             if not model_config:
-                if llm_name == "flag-embedding":
+                if mdlnm == "flag-embedding":
                     model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
                 else:
diff --git a/rag/app/naive.py b/rag/app/naive.py
index 54bfd77c340..9e0724de5d4 100644
--- a/rag/app/naive.py
+++ b/rag/app/naive.py
@@ -76,7 +76,7 @@ def __call__(self, filename, binary=None, from_page=0, to_page=100000):
                 if last_image:
                     image_list.insert(0, last_image)
                     last_image = None
-                lines.append((self.__clean(p.text), image_list, p.style.name))
+                lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
             else:
                 if current_image := self.get_picture(self.doc, p):
                     if lines:
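
Review note, the "name@factory" convention: the common thread in the dialog_service.py and llm_service.py hunks is a composite model identifier of the form "<model_name>@<factory>". get_api_key(), model_instance(), llm_id2llm_type() and chat() each split on "@" and, when a factory suffix is present, include it in the lookup, so two factories exposing a model with the same name no longer collide. A minimal sketch of the rule as the patch applies it; split_model_id is a hypothetical helper for illustration only, since the patch inlines this split at every call site:

    # Hypothetical helper, not part of the patch: the diff inlines this
    # split at each call site instead of sharing one function.
    def split_model_id(model_id):
        """Return (model_name, factory); factory is None when the id
        carries no "@" suffix, i.e. the pre-existing plain format."""
        parts = model_id.split("@")
        return (parts[0], parts[1]) if len(parts) > 1 else (parts[0], None)

    assert split_model_id("bge-large-zh-v1.5@BAAI") == ("bge-large-zh-v1.5", "BAAI")
    assert split_model_id("flag-embedding") == ("flag-embedding", None)

Like the patch, the sketch keeps only parts[1], so an id containing a second "@" would silently drop its tail; worth a follow-up only if model names can ever contain "@" themselves.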
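
Review note, the LLMBundle switch in chunk_app.py: every TenantLLMService.model_instance(...) call site becomes an LLMBundle(...) construction. Assuming LLMBundle keeps the shape it has in api/db/services/llm_service.py, a thin wrapper that resolves the model through model_instance() and adds per-tenant usage accounting, the call sites shrink to one line without losing behavior. A simplified sketch under that assumption, not a copy of the real class:

    # Assumed shape of the existing wrapper in api/db/services/llm_service.py;
    # simplified for illustration, not introduced by this diff.
    class LLMBundle:
        def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
            self.tenant_id = tenant_id
            self.llm_type = llm_type
            # Resolves the concrete model exactly as the old call sites did.
            self.mdl = TenantLLMService.model_instance(
                tenant_id, llm_type, llm_name, lang=lang)

        def encode(self, texts, batch_size=32):
            # Same (vectors, token_count) return shape as the raw model,
            # plus tenant-level token accounting the bare call sites skipped.
            vectors, used_tokens = self.mdl.encode(texts, batch_size)
            TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
            return vectors, used_tokens

One small consistency nit while here: set() passes LLMType.EMBEDDING where create() and retrieval_test() pass LLMType.EMBEDDING.value. Both appear to work because LLMType is string-backed, as its comparisons against .value elsewhere suggest, but settling on one form would read better.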