fix duplicated llm name between different suppliers (infiniflow#2477)
### What problem does this PR solve?

infiniflow#2465: when different suppliers (factories) offer a model under the same name, a bare model name resolves ambiguously. Model references can now carry an optional `@<factory>` suffix, and the lookup services split on `@` so queries can match both the model name and the factory.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
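
To illustrate the `<llm_name>@<factory>` convention the fix relies on, here is a minimal, self-contained sketch; the model and factory names are made up, not taken from `llm_factories.json`:

```python
# Hypothetical registry entries: the same llm_name offered by two factories.
records = [
    {"llm_name": "bge-reranker-v2-m3", "fid": "BAAI"},
    {"llm_name": "bge-reranker-v2-m3", "fid": "Xinference"},
]

def lookup(llm_id):
    """Resolve 'name' or 'name@factory' against the records above."""
    name, _, fid = llm_id.partition("@")
    return [r for r in records if r["llm_name"] == name and (not fid or r["fid"] == fid)]

print(len(lookup("bge-reranker-v2-m3")))       # 2 -> ambiguous before this fix
print(len(lookup("bge-reranker-v2-m3@BAAI")))  # 1 -> unique with the factory suffix
```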
KevinHuSh authored Sep 18, 2024
1 parent 71c90d5 commit cdd8a5f
Showing 4 changed files with 28 additions and 18 deletions.
16 changes: 6 additions & 10 deletions api/apps/chunk_app.py
@@ -27,7 +27,7 @@
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
@@ -141,8 +141,7 @@ def set():
             return get_data_error_result(retmsg="Tenant not found!")
 
         embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = TenantLLMService.model_instance(
-            tenant_id, LLMType.EMBEDDING.value, embd_id)
+        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
 
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
@@ -235,8 +234,7 @@ def create():
             return get_data_error_result(retmsg="Tenant not found!")
 
         embd_id = DocumentService.get_embd_id(req["doc_id"])
-        embd_mdl = TenantLLMService.model_instance(
-            tenant_id, LLMType.EMBEDDING.value, embd_id)
+        embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
 
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
@@ -281,16 +279,14 @@ def retrieval_test():
         if not e:
             return get_data_error_result(retmsg="Knowledgebase not found!")
 
-        embd_mdl = TenantLLMService.model_instance(
-            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+        embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
         rerank_mdl = None
         if req.get("rerank_id"):
-            rerank_mdl = TenantLLMService.model_instance(
-                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+            rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
 
         if req.get("keyword", False):
-            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+            chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
 
         retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
11 changes: 9 additions & 2 deletions api/db/services/dialog_service.py
@@ -78,6 +78,7 @@ def count():
 
 
 def llm_id2llm_type(llm_id):
+    llm_id = llm_id.split("@")[0]
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:
@@ -89,9 +90,15 @@ def llm_id2llm_type(llm_id):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    llm = LLMService.query(llm_name=dialog.llm_id)
+    tmp = dialog.llm_id.split("@")
+    fid = None
+    llm_id = tmp[0]
+    if len(tmp)>1: fid = tmp[1]
+
+    llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
-        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
+        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
+            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
     if not llm:
         raise LookupError("LLM(%s) not found" % dialog.llm_id)
     max_tokens = 8192
17 changes: 12 additions & 5 deletions api/db/services/llm_service.py
@@ -17,7 +17,7 @@
 from api.settings import database_logger
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
-from api.db.db_models import DB, UserTenant
+from api.db.db_models import DB
 from api.db.db_models import LLMFactories, LLM, TenantLLM
 from api.db.services.common_service import CommonService
 
@@ -36,7 +36,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        else:
+            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
         if not objs:
             return
         return objs[0]
@@ -81,14 +85,17 @@ def model_instance(cls, tenant_id, llm_type,
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
+        tmp = mdlnm.split("@")
+        fid = None if len(tmp) < 2 else tmp[1]
+        mdlnm = tmp[0]
         if model_config: model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
-                llm = LLMService.query(llm_name=llm_name if llm_name else mdlnm)
+                llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
                 if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
-                    model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name if llm_name else mdlnm, "api_base": ""}
+                    model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
             if not model_config:
-                if llm_name == "flag-embedding":
+                if mdlnm == "flag-embedding":
                     model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
                                     "llm_name": llm_name, "api_base": ""}
                 else:
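
Both hunks above apply the same rule: split the incoming model reference on `@`; when a factory suffix is present, constrain the lookup by `llm_factory` as well, otherwise keep the old name-only query. A hedged, self-contained sketch of that rule, with a plain dict standing in for the tenant_llm table and made-up tenant, model, and factory values:

```python
# Stand-in for the tenant_llm table: (tenant_id, llm_name, llm_factory) -> config.
tenant_llm = {
    ("tenant-42", "bge-m3", "Xinference"): {"api_key": "xinf-key", "api_base": "http://localhost:9997"},
    ("tenant-42", "bge-m3", "BAAI"):       {"api_key": "",         "api_base": ""},
}

def get_api_key(tenant_id, model_name):
    """Mimic the new dispatch: 'name@factory' filters by factory, 'name' does not."""
    name, _, factory = model_name.partition("@")
    for (tid, llm_name, llm_factory), cfg in tenant_llm.items():
        if tid == tenant_id and llm_name == name and (not factory or llm_factory == factory):
            return cfg
    return None

print(get_api_key("tenant-42", "bge-m3@Xinference"))  # picks the Xinference row
print(get_api_key("tenant-42", "bge-m3"))             # name-only lookup, first match wins
```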
2 changes: 1 addition & 1 deletion rag/app/naive.py
@@ -76,7 +76,7 @@ def __call__(self, filename, binary=None, from_page=0, to_page=100000):
                     if last_image:
                         image_list.insert(0, last_image)
                         last_image = None
-                    lines.append((self.__clean(p.text), image_list, p.style.name))
+                    lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
                 else:
                     if current_image := self.get_picture(self.doc, p):
                         if lines:
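
The rag/app/naive.py hunk guards the style lookup: the change implies that for some documents python-docx yields no style object for a paragraph, so `p.style.name` would raise an AttributeError; the guarded form falls back to an empty string. A tiny stand-in illustration; `FakeParagraph` is only for the example and is not part of the codebase:

```python
class FakeParagraph:
    """Minimal stand-in for docx.text.paragraph.Paragraph in this example."""
    def __init__(self, style):
        self.style = style

p = FakeParagraph(style=None)
style_name = p.style.name if p.style else ""   # the pattern the diff introduces
print(repr(style_name))                        # -> ''
```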
