Skip to content

Commit

Permalink
API: retrieval api (infiniflow#1763)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Add a retrieval API for querying a specific knowledge base.


![ragflow](https://github.com/user-attachments/assets/dc30a4c3-03c5-4d34-bb7c-60b8830f1225)

infiniflow#1102

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
Valdanitooooo authored Aug 1, 2024
1 parent 7fbc182 commit 115fdaa
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion api/apps/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flask import request, Response
from flask_login import login_required, current_user

from api.db import FileType, ParserType, FileSource
from api.db import FileType, ParserType, FileSource, LLMType
from api.db.db_models import APIToken, API4Conversation, Task, File
from api.db.services import duplicate_name
from api.db.services.api_service import APITokenService, API4ConversationService
Expand All @@ -29,6 +29,7 @@
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from api.settings import RetCode, retrievaler
Expand All @@ -37,6 +38,7 @@
from itsdangerous import URLSafeTimedSerializer

from api.utils.file_utils import filename_type, thumbnail
from rag.nlp import keyword_extraction
from rag.utils.minio_conn import MINIO


Expand Down Expand Up @@ -587,3 +589,55 @@ def fillin_conv(ans):

except Exception as e:
return server_error_response(e)


@manager.route('/retrieval', methods=['POST'])
@validate_request("kb_id", "question")
def retrieval():
    """Retrieve chunks from a single knowledge base for a question.

    The caller authenticates with an API token taken from the second
    whitespace-separated field of the ``Authorization`` header.

    JSON body fields:
        kb_id: id of the knowledge base to search (required).
        question: query text (required).
        doc_ids: optional list passed through to the retriever (default ``[]``).
        page: page number (default 1).
        size: page size (default 30).
        similarity_threshold: default 0.2.
        vector_similarity_weight: default 0.3.
        top_k: default 1024.
        rerank_id: optional rerank model name; when set, a rerank model is
            instantiated and handed to the retriever.
        keyword: when truthy, the output of ``keyword_extraction`` (run with
            the tenant's chat model) is appended to the question.

    Returns:
        A JSON result wrapping the retriever output with any ``"vector"``
        entry removed from each chunk, or a JSON error result when the token
        is invalid, a parameter is malformed, the knowledge base is missing,
        or retrieval fails.
    """
    # A missing or single-field Authorization header must yield the same
    # authentication error as an unknown token, not an unhandled exception.
    auth_fields = request.headers.get('Authorization', '').split()
    if len(auth_fields) < 2:
        return get_json_result(
            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)

    objs = APIToken.query(token=auth_fields[1])
    if not objs:
        return get_json_result(
            data=False, retmsg='Token is not valid!', retcode=RetCode.AUTHENTICATION_ERROR)

    req = request.json
    kb_id = req.get("kb_id")
    doc_ids = req.get("doc_ids", [])
    question = req.get("question")

    # Coerce numeric parameters up front so that malformed values produce a
    # data error instead of escaping the view as ValueError/TypeError.
    try:
        page = int(req.get("page", 1))
        size = int(req.get("size", 30))
        similarity_threshold = float(req.get("similarity_threshold", 0.2))
        vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
        top = int(req.get("top_k", 1024))
    except (TypeError, ValueError) as err:
        return get_data_error_result(retmsg=f"Invalid numeric parameter: {err}")

    try:
        found, kb = KnowledgebaseService.get_by_id(kb_id)
        if not found:
            return get_data_error_result(retmsg="Knowledgebase not found!")

        # NOTE(review): nothing here verifies that the token's tenant owns
        # this knowledge base -- confirm whether an ownership check is needed.
        embd_mdl = TenantLLMService.model_instance(
            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

        rerank_mdl = None
        if req.get("rerank_id"):
            rerank_mdl = TenantLLMService.model_instance(
                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

        if req.get("keyword", False):
            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)

        ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
                                      similarity_threshold, vector_similarity_weight, top,
                                      doc_ids, rerank_mdl=rerank_mdl)
        # Drop the "vector" entry from every chunk before returning it.
        for chunk in ranks["chunks"]:
            chunk.pop("vector", None)

        return get_json_result(data=ranks)
    except Exception as e:
        # Membership test rather than `find(...) > 0`, which would miss a
        # message that starts with "not_found".
        if "not_found" in str(e):
            return get_json_result(data=False, retmsg='No chunk found! Check the chunk status please!',
                                   retcode=RetCode.DATA_ERROR)
        return server_error_response(e)

0 comments on commit 115fdaa

Please sign in to comment.