From 60910dd1f9749a8dff681bf8a88f0766ee3d0672 Mon Sep 17 00:00:00 2001 From: Valdanito Date: Thu, 25 Jul 2024 18:03:01 +0800 Subject: [PATCH 1/6] API: add retrieval api --- api/apps/api_app.py | 56 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 6676941df8d..bae0527d537 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -20,7 +20,7 @@ from flask import request, Response from flask_login import login_required, current_user -from api.db import FileType, ParserType, FileSource +from api.db import FileType, ParserType, FileSource, LLMType from api.db.db_models import APIToken, API4Conversation, Task, File from api.db.services import duplicate_name from api.db.services.api_service import APITokenService, API4ConversationService @@ -29,6 +29,7 @@ from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.llm_service import TenantLLMService from api.db.services.task_service import queue_tasks, TaskService from api.db.services.user_service import UserTenantService from api.settings import RetCode, retrievaler @@ -37,6 +38,7 @@ from itsdangerous import URLSafeTimedSerializer from api.utils.file_utils import filename_type, thumbnail +from rag.nlp import keyword_extraction from rag.utils.minio_conn import MINIO @@ -587,3 +589,55 @@ def fillin_conv(ans): except Exception as e: return server_error_response(e) + + +@manager.route('/retrieval', methods=['POST']) +@validate_request("kb_id", "question") +def retrieval(): + token = request.headers.get('Authorization').split()[1] + objs = APIToken.query(token=token) + if not objs: + return get_json_result( + data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) + + req = request.json + kb_id = req.get("kb_id") + doc_ids = req.get("doc_ids", []) + question = req.get("question") + page = int(req.get("page", 1)) + size = int(req.get("size", 30)) + similarity_threshold = float(req.get("similarity_threshold", 0.2)) + vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) + top = int(req.get("top_k", 1024)) + + try: + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + return get_data_error_result(retmsg="Knowledgebase not found!") + + embd_mdl = TenantLLMService.model_instance( + kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) + + rerank_mdl = None + if req.get("rerank_id"): + rerank_mdl = TenantLLMService.model_instance( + kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) + + if req.get("keyword", False): + chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) + question += keyword_extraction(chat_mdl, question) + + ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, + similarity_threshold, vector_similarity_weight, top, + doc_ids, rerank_mdl=rerank_mdl) + for c in ranks["chunks"]: + if "vector" in c: + del c["vector"] + + return get_json_result(data=ranks) + except Exception as e: + if str(e).find("not_found") > 0: + return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', + retcode=RetCode.DATA_ERROR) + return server_error_response(e) + From 933adb7a1b02e64c4a68ed75e7a83bacc234f51f Mon Sep 17 00:00:00 2001 From: Valdanito Date: Fri, 13 Sep 2024 18:37:57 +0800 Subject: [PATCH 2/6] refactor(API): Refactor datasets API --- api/apps/__init__.py | 62 ++++++++--- api/apps/apis/__init__.py | 0 api/apps/apis/datasets.py | 89 +++++++++++++++ api/apps/services/__init__.py | 0 api/apps/services/dataset_service.py | 156 +++++++++++++++++++++++++++ requirements.txt | 1 + requirements_arm.txt | 1 + sdk/python/ragflow/ragflow.py | 12 +++ sdk/python/test/test_dataset.py | 36 ++++--- 9 files changed, 327 insertions(+), 30 deletions(-) create mode 100644 api/apps/apis/__init__.py create mode 100644 api/apps/apis/datasets.py create mode 100644 api/apps/services/__init__.py create mode 100644 api/apps/services/dataset_service.py diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 4fdeb9630b1..425285afe1b 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -18,40 +18,64 @@ import sys from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from flask import Blueprint, Flask -from werkzeug.wrappers.request import Request +from typing import Union + +from apiflask import APIFlask, APIBlueprint, HTTPTokenAuth from flask_cors import CORS +from flask_login import LoginManager +from flask_session import Session +from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer +from werkzeug.wrappers.request import Request from api.db import StatusEnum -from api.db.db_models import close_connection +from api.db.db_models import close_connection, APIToken from api.db.services import UserService -from api.utils import CustomJSONEncoder, commands - -from flask_session import Session -from flask_login import LoginManager -from api.settings import SECRET_KEY, stat_logger from api.settings import API_VERSION, access_logger +from api.settings import SECRET_KEY, stat_logger +from api.utils import CustomJSONEncoder, commands from api.utils.api_utils import server_error_response -from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer __all__ = ['app'] - logger = logging.getLogger('flask.app') for h in access_logger.handlers: logger.addHandler(h) Request.json = property(lambda self: self.get_json(force=True, silent=True)) -app = Flask(__name__) -CORS(app, supports_credentials=True,max_age=2592000) +app = APIFlask(__name__) +auth = HTTPTokenAuth() + + +class AuthUser: + def __init__(self, tenant_id, token): + self.id = tenant_id + self.token = token + + def get_token(self): + return self.token + + +@auth.verify_token +def verify_token(token: str) -> Union[AuthUser, None]: + try: + objs = APIToken.query(token=token) + if objs: + api_token = objs[0] + user = AuthUser(api_token.tenant_id, api_token.token) + return user + except Exception as e: + server_error_response(e) + return None + + +CORS(app, supports_credentials=True, max_age=2592000) app.url_map.strict_slashes = False app.json_encoder = CustomJSONEncoder app.errorhandler(Exception)(server_error_response) - ## convince for dev and debug -#app.config["LOGIN_DISABLED"] = True +# app.config["LOGIN_DISABLED"] = True app.config["SESSION_PERMANENT"] = False app.config["SESSION_TYPE"] = "filesystem" app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)) @@ -66,7 +90,9 @@ def search_pages_path(pages_dir): app_path_list = [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')] api_path_list = [path for path in pages_dir.glob('*sdk/*.py') if not path.name.startswith('.')] + restful_api_path_list = [path for path in pages_dir.glob('*apis/*.py') if not path.name.startswith('.')] app_path_list.extend(api_path_list) + app_path_list.extend(restful_api_path_list) return app_path_list @@ -79,11 +105,12 @@ def register_page(page_path): spec = spec_from_file_location(module_name, page_path) page = module_from_spec(spec) page.app = app - page.manager = Blueprint(page_name, module_name) + page.manager = APIBlueprint(page_name, module_name) sys.modules[module_name] = page spec.loader.exec_module(page) page_name = getattr(page, 'page_name', page_name) - url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path else f'/{API_VERSION}/{page_name}' + url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path or "/apis/" in path \ + else f'/{API_VERSION}/{page_name}' app.register_blueprint(page.manager, url_prefix=url_prefix) return url_prefix @@ -93,6 +120,7 @@ def register_page(page_path): Path(__file__).parent, Path(__file__).parent.parent / 'api' / 'apps', Path(__file__).parent.parent / 'api' / 'apps' / 'sdk', + Path(__file__).parent.parent / 'api' / 'apps' / 'apis', ] client_urls_prefix = [ @@ -123,4 +151,4 @@ def load_user(web_request): @app.teardown_request def _db_close(exc): - close_connection() \ No newline at end of file + close_connection() diff --git a/api/apps/apis/__init__.py b/api/apps/apis/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/api/apps/apis/datasets.py b/api/apps/apis/datasets.py new file mode 100644 index 00000000000..507dcd592ef --- /dev/null +++ b/api/apps/apis/datasets.py @@ -0,0 +1,89 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from api.apps import auth +from api.apps.services import dataset_service +from api.utils.api_utils import server_error_response + + +@manager.post('') +@manager.input(dataset_service.CreateDatasetReq, location='json') +@manager.auth_required(auth) +def create_dataset(data): + try: + tenant_id = auth.current_user.id + return dataset_service.create_dataset(tenant_id, data) + except Exception as e: + return server_error_response(e) + + +@manager.put('') +@manager.input(dataset_service.UpdateDatasetReq, location='json') +@manager.auth_required(auth) +def update_dataset(data): + try: + tenant_id = auth.current_user.id + return dataset_service.update_dataset(tenant_id, data) + except Exception as e: + return server_error_response(e) + + +@manager.get('/') +@manager.auth_required(auth) +def get_dataset_by_id(kb_id): + try: + tenant_id = auth.current_user.id + return dataset_service.get_dataset_by_id(tenant_id, kb_id) + except Exception as e: + return server_error_response(e) + + +@manager.get('/search') +@manager.input(dataset_service.SearchDatasetReq, location='query') +@manager.auth_required(auth) +def get_dataset_by_name(query_data): + try: + tenant_id = auth.current_user.id + return dataset_service.get_dataset_by_name(tenant_id, query_data["name"]) + except Exception as e: + return server_error_response(e) + + +@manager.get('') +@manager.input(dataset_service.QueryDatasetReq, location='query') +@manager.auth_required(auth) +def get_all_datasets(query_data): + try: + tenant_id = auth.current_user.id + return dataset_service.get_all_datasets( + tenant_id, + query_data['page'], + query_data['page_size'], + query_data['orderby'], + query_data['desc'], + ) + except Exception as e: + return server_error_response(e) + + +@manager.delete('/') +@manager.auth_required(auth) +def delete_dataset(kb_id): + try: + tenant_id = auth.current_user.id + dataset_service.delete_dataset(tenant_id, kb_id) + except Exception as e: + return server_error_response(e) diff --git a/api/apps/services/__init__.py b/api/apps/services/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/api/apps/services/dataset_service.py b/api/apps/services/dataset_service.py new file mode 100644 index 00000000000..474845c6e61 --- /dev/null +++ b/api/apps/services/dataset_service.py @@ -0,0 +1,156 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from apiflask import Schema, fields + +from api.db import StatusEnum, FileSource +from api.db.db_models import File +from api.db.services import duplicate_name +from api.db.services.document_service import DocumentService +from api.db.services.file2document_service import File2DocumentService +from api.db.services.file_service import FileService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.user_service import TenantService +from api.settings import RetCode +from api.utils import get_uuid +from api.utils.api_utils import get_json_result, get_data_error_result + + +class QueryDatasetReq(Schema): + page = fields.Integer(load_default=1) + page_size = fields.Integer(load_default=150) + orderby = fields.String(load_default='create_time') + desc = fields.Boolean(load_default=True) + +class SearchDatasetReq(Schema): + name = fields.String(required=True) + + +class CreateDatasetReq(Schema): + name = fields.String(required=True) + + +class UpdateDatasetReq(Schema): + kb_id = fields.String(required=True) + name = fields.String() + description = fields.String() + permission = fields.String() + parser_id = fields.String() + + +def get_all_datasets(user_id, offset, count, orderby, desc): + tenants = TenantService.get_joined_tenants_by_user_id(user_id) + datasets = KnowledgebaseService.get_by_tenant_ids_by_offset( + [m["tenant_id"] for m in tenants], user_id, int(offset), int(count), orderby, desc) + return get_json_result(data=datasets) + + +def get_tenant_dataset_by_id(tenant_id, kb_id): + kbs = KnowledgebaseService.query(tenant_id=tenant_id, id=kb_id) + if not kbs: + return get_data_error_result(retmsg="Can't find this knowledgebase!") + return get_json_result(data=kbs[0].to_dict()) + + +def get_dataset_by_id(tenant_id, kb_id): + kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id) + if not kbs: + return get_data_error_result(retmsg="Can't find this knowledgebase!") + return get_json_result(data=kbs[0].to_dict()) + + +def get_dataset_by_name(tenant_id, kb_name): + e, kb = KnowledgebaseService.get_by_name(kb_name=kb_name, tenant_id=tenant_id) + if not e: + return get_json_result( + data=False, retmsg='You do not own the dataset.', + retcode=RetCode.OPERATING_ERROR) + return get_json_result(data=kb.to_dict()) + + +def create_dataset(tenant_id, data): + kb_name = data["name"].strip() + kb_name = duplicate_name( + KnowledgebaseService.query, + name=kb_name, + tenant_id=tenant_id, + status=StatusEnum.VALID.value + ) + e, t = TenantService.get_by_id(tenant_id) + if not e: + return get_data_error_result(retmsg="Tenant not found.") + kb = { + "id": get_uuid(), + "name": kb_name, + "tenant_id": tenant_id, + "created_by": tenant_id, + "embd_id": t.embd_id, + } + if not KnowledgebaseService.save(**kb): + return get_data_error_result() + return get_json_result(data={"kb_id": kb["id"]}) + + +def update_dataset(tenant_id, data): + kb_name = data["name"].strip() + kb_id = data["kb_id"].strip() + if not KnowledgebaseService.query( + created_by=tenant_id, id=kb_id): + return get_json_result( + data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) + + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + return get_data_error_result( + retmsg="Can't find this knowledgebase!") + + if kb_name.lower() != kb.name.lower() and len( + KnowledgebaseService.query(name=kb_name, tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 1: + return get_data_error_result( + retmsg="Duplicated knowledgebase name.") + + del data["kb_id"] + if not KnowledgebaseService.update_by_id(kb.id, data): + return get_data_error_result() + + e, kb = KnowledgebaseService.get_by_id(kb.id) + if not e: + return get_data_error_result( + retmsg="Database error (Knowledgebase rename)!") + + return get_json_result(data=kb.to_json()) + + +def delete_dataset(tenant_id, kb_id): + kbs = KnowledgebaseService.query(created_by=tenant_id, id=kb_id) + if not kbs: + return get_json_result( + data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) + + for doc in DocumentService.query(kb_id=kb_id): + if not DocumentService.remove_document(doc, kbs[0].tenant_id): + return get_data_error_result( + retmsg="Database error (Document removal)!") + f2d = File2DocumentService.get_by_document_id(doc.id) + FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) + File2DocumentService.delete_by_document_id(doc.id) + + if not KnowledgebaseService.delete_by_id(kb_id): + return get_data_error_result( + retmsg="Database error (Knowledgebase removal)!") + return get_json_result(data=True) diff --git a/requirements.txt b/requirements.txt index 720e44c977c..bf676feab71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -102,3 +102,4 @@ xgboost==2.1.0 xpinyin==0.7.6 yfinance==0.2.43 zhipuai==2.0.1 +apiflask==2.2.1 diff --git a/requirements_arm.txt b/requirements_arm.txt index 90b0de1f039..a0197db3f73 100644 --- a/requirements_arm.txt +++ b/requirements_arm.txt @@ -172,3 +172,4 @@ yfinance==0.2.43 pywencai==0.12.2 akshare==1.14.72 ranx==0.3.20 +apiflask==2.2.1 diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py index 80ac415fc6d..957391d41c6 100644 --- a/sdk/python/ragflow/ragflow.py +++ b/sdk/python/ragflow/ragflow.py @@ -142,3 +142,15 @@ def list_assistants(self) -> List[Assistant]: result_list.append(Assistant(self, data)) return result_list raise Exception(res["retmsg"]) + + def get_all_datasets( + self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True + ) -> List[DataSet]: + res = self.get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) + res = res.json() + result_list = [] + if res.get("retmsg") == "success": + for data in res['data']: + result_list.append(DataSet(self, data)) + return result_list + raise Exception(res["retmsg"]) diff --git a/sdk/python/test/test_dataset.py b/sdk/python/test/test_dataset.py index 8c2084a9053..c3ab2082c1d 100644 --- a/sdk/python/test/test_dataset.py +++ b/sdk/python/test/test_dataset.py @@ -22,12 +22,13 @@ def setup_method(self): Delete all the datasets. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - listed_data = ragflow.list_dataset() + listed_data = ragflow.list_datasets() listed_data = listed_data['data'] listed_names = {d['name'] for d in listed_data} for name in listed_names: - ragflow.delete_dataset(name) + print(f'--dataset-- {name}') + # ragflow.delete_dataset(name) # -----------------------create_dataset--------------------------------- def test_create_dataset_with_success(self): @@ -146,7 +147,7 @@ def test_list_dataset_success(self): """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) # Call the list_datasets method - response = ragflow.list_dataset() + response = ragflow.list_datasets() assert response['code'] == RetCode.SUCCESS def test_list_dataset_with_checking_size_and_name(self): @@ -163,7 +164,7 @@ def test_list_dataset_with_checking_size_and_name(self): dataset_name = response['data']['dataset_name'] real_name_to_create.add(dataset_name) - response = ragflow.list_dataset(0, 3) + response = ragflow.list_datasets(0, 3) listed_data = response['data'] listed_names = {d['name'] for d in listed_data} @@ -185,7 +186,7 @@ def test_list_dataset_with_getting_empty_result(self): dataset_name = response['data']['dataset_name'] real_name_to_create.add(dataset_name) - response = ragflow.list_dataset(0, 0) + response = ragflow.list_datasets(0, 0) listed_data = response['data'] listed_names = {d['name'] for d in listed_data} @@ -208,7 +209,7 @@ def test_list_dataset_with_creating_100_knowledge_bases(self): dataset_name = response['data']['dataset_name'] real_name_to_create.add(dataset_name) - res = ragflow.list_dataset(0, 100) + res = ragflow.list_datasets(0, 100) listed_data = res['data'] listed_names = {d['name'] for d in listed_data} @@ -221,7 +222,7 @@ def test_list_dataset_with_showing_one_dataset(self): Test listing one dataset and verify the size of the dataset. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - response = ragflow.list_dataset(0, 1) + response = ragflow.list_datasets(0, 1) datasets = response['data'] assert len(datasets) == 1 and response['code'] == RetCode.SUCCESS @@ -230,7 +231,7 @@ def test_list_dataset_failure(self): Test listing datasets with IndexError. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - response = ragflow.list_dataset(-1, -1) + response = ragflow.list_datasets(-1, -1) assert "IndexError" in response['message'] and response['code'] == RetCode.EXCEPTION_ERROR def test_list_dataset_for_empty_datasets(self): @@ -238,7 +239,7 @@ def test_list_dataset_for_empty_datasets(self): Test listing datasets when the datasets are empty. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - response = ragflow.list_dataset() + response = ragflow.list_datasets() datasets = response['data'] assert len(datasets) == 0 and response['code'] == RetCode.SUCCESS @@ -263,7 +264,8 @@ def test_delete_dataset_with_not_existing_dataset(self): """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) res = ragflow.delete_dataset("weird_dataset") - assert res['code'] == RetCode.OPERATING_ERROR and res['message'] == 'The dataset cannot be found for your current account.' + assert res['code'] == RetCode.OPERATING_ERROR and res[ + 'message'] == 'The dataset cannot be found for your current account.' def test_delete_dataset_with_creating_100_datasets_and_deleting_100_datasets(self): """ @@ -346,7 +348,7 @@ def test_delete_dataset_with_name_with_space_in_the_head_and_tail_and_length_exc assert (res['code'] == RetCode.OPERATING_ERROR and res['message'] == 'The dataset cannot be found for your current account.') -# ---------------------------------get_dataset----------------------------------------- + # ---------------------------------get_dataset----------------------------------------- def test_get_dataset_with_success(self): """ @@ -366,7 +368,7 @@ def test_get_dataset_with_failure(self): res = ragflow.get_dataset("weird_dataset") assert res['code'] == RetCode.DATA_ERROR and res['message'] == "Can't find this dataset!" -# ---------------------------------update a dataset----------------------------------- + # ---------------------------------update a dataset----------------------------------- def test_update_dataset_without_existing_dataset(self): """ @@ -435,7 +437,7 @@ def test_update_dataset_with_empty_parameter(self): assert (res['code'] == RetCode.DATA_ERROR and res['message'] == 'Please input at least one parameter that you want to update!') -# ---------------------------------mix the different methods-------------------------- + # ---------------------------------mix the different methods-------------------------- def test_create_and_delete_dataset_together(self): """ @@ -466,3 +468,11 @@ def test_create_and_delete_dataset_together(self): res = ragflow.delete_dataset(name) assert res["code"] == RetCode.SUCCESS + def test_list_dataset_success(self): + """ + Test listing datasets with a successful outcome. + """ + ragflow = RAGFlow(API_KEY, HOST_ADDRESS) + # Call the get_all_datasets method + response = ragflow.get_all_datasets() + assert isinstance(response, list) From 259df2e26b17d5ed73443f64c2cec311e08c6879 Mon Sep 17 00:00:00 2001 From: Valdanito Date: Sat, 14 Sep 2024 11:09:47 +0800 Subject: [PATCH 3/6] refactor(API): Add request data validators --- api/apps/apis/datasets.py | 10 +++++----- api/apps/services/dataset_service.py | 17 +++++++++++------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/api/apps/apis/datasets.py b/api/apps/apis/datasets.py index 507dcd592ef..2103f7dd01f 100644 --- a/api/apps/apis/datasets.py +++ b/api/apps/apis/datasets.py @@ -22,10 +22,10 @@ @manager.post('') @manager.input(dataset_service.CreateDatasetReq, location='json') @manager.auth_required(auth) -def create_dataset(data): +def create_dataset(json_data): try: tenant_id = auth.current_user.id - return dataset_service.create_dataset(tenant_id, data) + return dataset_service.create_dataset(tenant_id, json_data) except Exception as e: return server_error_response(e) @@ -33,10 +33,10 @@ def create_dataset(data): @manager.put('') @manager.input(dataset_service.UpdateDatasetReq, location='json') @manager.auth_required(auth) -def update_dataset(data): +def update_dataset(json_data): try: tenant_id = auth.current_user.id - return dataset_service.update_dataset(tenant_id, data) + return dataset_service.update_dataset(tenant_id, json_data) except Exception as e: return server_error_response(e) @@ -84,6 +84,6 @@ def get_all_datasets(query_data): def delete_dataset(kb_id): try: tenant_id = auth.current_user.id - dataset_service.delete_dataset(tenant_id, kb_id) + return dataset_service.delete_dataset(tenant_id, kb_id) except Exception as e: return server_error_response(e) diff --git a/api/apps/services/dataset_service.py b/api/apps/services/dataset_service.py index 474845c6e61..d6c53cef5b0 100644 --- a/api/apps/services/dataset_service.py +++ b/api/apps/services/dataset_service.py @@ -14,9 +14,9 @@ # limitations under the License. # -from apiflask import Schema, fields +from apiflask import Schema, fields, validators -from api.db import StatusEnum, FileSource +from api.db import StatusEnum, FileSource, ParserType from api.db.db_models import File from api.db.services import duplicate_name from api.db.services.document_service import DocumentService @@ -35,6 +35,7 @@ class QueryDatasetReq(Schema): orderby = fields.String(load_default='create_time') desc = fields.Boolean(load_default=True) + class SearchDatasetReq(Schema): name = fields.String(required=True) @@ -45,10 +46,14 @@ class CreateDatasetReq(Schema): class UpdateDatasetReq(Schema): kb_id = fields.String(required=True) - name = fields.String() - description = fields.String() - permission = fields.String() - parser_id = fields.String() + name = fields.String(validate=validators.Length(min=1, max=128)) + description = fields.String(allow_none=True) + permission = fields.String(validate=validators.OneOf(['me', 'team'])) + embd_id = fields.String(validate=validators.Length(min=1, max=128)) + language = fields.String(validate=validators.OneOf(['Chinese', 'English'])) + parser_id = fields.String(validate=validators.OneOf([parser_type.value for parser_type in ParserType])) + parser_config = fields.Dict() + avatar = fields.String() def get_all_datasets(user_id, offset, count, orderby, desc): From 4e0f463720de09333f7f60c11410b9d6d9c27150 Mon Sep 17 00:00:00 2001 From: Valdanito Date: Sat, 14 Sep 2024 16:27:55 +0800 Subject: [PATCH 4/6] refactor(API): Add http_basic_auth_required --- api/apps/__init__.py | 16 ++++++++++------ api/apps/apis/datasets.py | 29 +++++++++++++++-------------- api/utils/api_utils.py | 22 +++++++++++++++++++++- sdk/python/test/test_dataset.py | 10 +++++----- 4 files changed, 51 insertions(+), 26 deletions(-) diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 425285afe1b..fdda82830e0 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -30,7 +30,7 @@ from api.db import StatusEnum from api.db.db_models import close_connection, APIToken from api.db.services import UserService -from api.settings import API_VERSION, access_logger +from api.settings import API_VERSION, access_logger, RAG_FLOW_SERVICE_NAME from api.settings import SECRET_KEY, stat_logger from api.utils import CustomJSONEncoder, commands from api.utils.api_utils import server_error_response @@ -43,8 +43,8 @@ Request.json = property(lambda self: self.get_json(force=True, silent=True)) -app = APIFlask(__name__) -auth = HTTPTokenAuth() +app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs') +http_token_auth = HTTPTokenAuth() class AuthUser: @@ -56,7 +56,7 @@ def get_token(self): return self.token -@auth.verify_token +@http_token_auth.verify_token def verify_token(token: str) -> Union[AuthUser, None]: try: objs = APIToken.query(token=token) @@ -109,8 +109,12 @@ def register_page(page_path): sys.modules[module_name] = page spec.loader.exec_module(page) page_name = getattr(page, 'page_name', page_name) - url_prefix = f'/api/{API_VERSION}/{page_name}' if "/sdk/" in path or "/apis/" in path \ - else f'/{API_VERSION}/{page_name}' + if "/sdk/" in path or "/apis/" in path: + url_prefix = f'/api/{API_VERSION}/{page_name}' + # elif "/apis/" in path: + # url_prefix = f'/{API_VERSION}/api/{page_name}' + else: + url_prefix = f'/{API_VERSION}/{page_name}' app.register_blueprint(page.manager, url_prefix=url_prefix) return url_prefix diff --git a/api/apps/apis/datasets.py b/api/apps/apis/datasets.py index 2103f7dd01f..ff96a3977a1 100644 --- a/api/apps/apis/datasets.py +++ b/api/apps/apis/datasets.py @@ -14,17 +14,17 @@ # limitations under the License. # -from api.apps import auth +from api.apps import http_token_auth from api.apps.services import dataset_service -from api.utils.api_utils import server_error_response +from api.utils.api_utils import server_error_response, http_basic_auth_required @manager.post('') @manager.input(dataset_service.CreateDatasetReq, location='json') -@manager.auth_required(auth) +@manager.auth_required(http_token_auth) def create_dataset(json_data): try: - tenant_id = auth.current_user.id + tenant_id = http_token_auth.current_user.id return dataset_service.create_dataset(tenant_id, json_data) except Exception as e: return server_error_response(e) @@ -32,20 +32,20 @@ def create_dataset(json_data): @manager.put('') @manager.input(dataset_service.UpdateDatasetReq, location='json') -@manager.auth_required(auth) +@manager.auth_required(http_token_auth) def update_dataset(json_data): try: - tenant_id = auth.current_user.id + tenant_id = http_token_auth.current_user.id return dataset_service.update_dataset(tenant_id, json_data) except Exception as e: return server_error_response(e) @manager.get('/') -@manager.auth_required(auth) +@manager.auth_required(http_token_auth) def get_dataset_by_id(kb_id): try: - tenant_id = auth.current_user.id + tenant_id = http_token_auth.current_user.id return dataset_service.get_dataset_by_id(tenant_id, kb_id) except Exception as e: return server_error_response(e) @@ -53,10 +53,10 @@ def get_dataset_by_id(kb_id): @manager.get('/search') @manager.input(dataset_service.SearchDatasetReq, location='query') -@manager.auth_required(auth) +@manager.auth_required(http_token_auth) def get_dataset_by_name(query_data): try: - tenant_id = auth.current_user.id + tenant_id = http_token_auth.current_user.id return dataset_service.get_dataset_by_name(tenant_id, query_data["name"]) except Exception as e: return server_error_response(e) @@ -64,10 +64,11 @@ def get_dataset_by_name(query_data): @manager.get('') @manager.input(dataset_service.QueryDatasetReq, location='query') -@manager.auth_required(auth) +@http_basic_auth_required +@manager.auth_required(http_token_auth) def get_all_datasets(query_data): try: - tenant_id = auth.current_user.id + tenant_id = http_token_auth.current_user.id return dataset_service.get_all_datasets( tenant_id, query_data['page'], @@ -80,10 +81,10 @@ def get_all_datasets(query_data): @manager.delete('/') -@manager.auth_required(auth) +@manager.auth_required(http_token_auth) def delete_dataset(kb_id): try: - tenant_id = auth.current_user.id + tenant_id = http_token_auth.current_user.id return dataset_service.delete_dataset(tenant_id, kb_id) except Exception as e: return server_error_response(e) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index c5b93d56f0a..105038cca26 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -27,8 +27,10 @@ import requests from flask import ( Response, jsonify, send_file, make_response, - request as flask_request, + request as flask_request, current_app, ) +from flask_login import current_user +from flask_login.config import EXEMPT_METHODS from werkzeug.http import HTTP_STATUS_CODES from api.db.db_models import APIToken @@ -288,3 +290,21 @@ def decorated_function(*args, **kwargs): return func(*args, **kwargs) return decorated_function + + +def http_basic_auth_required(func): + @wraps(func) + def decorated_view(*args, **kwargs): + if 'Authorization' in flask_request.headers: + # 如果请求中包含 token,则跳过用户名密码验证 + return func(*args, **kwargs) + if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"): + pass + elif not current_user.is_authenticated: + return current_app.login_manager.unauthorized() + + if callable(getattr(current_app, "ensure_sync", None)): + return current_app.ensure_sync(func)(*args, **kwargs) + return func(*args, **kwargs) + + return decorated_view diff --git a/sdk/python/test/test_dataset.py b/sdk/python/test/test_dataset.py index c3ab2082c1d..55ca0db863b 100644 --- a/sdk/python/test/test_dataset.py +++ b/sdk/python/test/test_dataset.py @@ -22,12 +22,12 @@ def setup_method(self): Delete all the datasets. """ ragflow = RAGFlow(API_KEY, HOST_ADDRESS) - listed_data = ragflow.list_datasets() - listed_data = listed_data['data'] + # listed_data = ragflow.list_datasets() + # listed_data = listed_data['data'] - listed_names = {d['name'] for d in listed_data} - for name in listed_names: - print(f'--dataset-- {name}') + # listed_names = {d['name'] for d in listed_data} + # for name in listed_names: + # print(f'--dataset-- {name}') # ragflow.delete_dataset(name) # -----------------------create_dataset--------------------------------- From 905319830da755f9f86a07095990238587967534 Mon Sep 17 00:00:00 2001 From: Valdanito Date: Sat, 14 Sep 2024 18:09:52 +0800 Subject: [PATCH 5/6] refactor(API): Add source code comments --- api/apps/__init__.py | 5 +++++ api/apps/apis/datasets.py | 6 ++++++ api/utils/api_utils.py | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/api/apps/__init__.py b/api/apps/__init__.py index fdda82830e0..3d02b847bdc 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -43,10 +43,13 @@ Request.json = property(lambda self: self.get_json(force=True, silent=True)) +# Integrate APIFlask: Flask class -> APIFlask class. app = APIFlask(__name__, title=RAG_FLOW_SERVICE_NAME, version=API_VERSION, docs_path=f'/{API_VERSION}/docs') +# Integrate APIFlask: Use apiflask.HTTPTokenAuth for the HTTP Bearer or API Keys authentication. http_token_auth = HTTPTokenAuth() +# Current logged-in user class class AuthUser: def __init__(self, tenant_id, token): self.id = tenant_id @@ -56,6 +59,7 @@ def get_token(self): return self.token +# Verify if the token is valid @http_token_auth.verify_token def verify_token(token: str) -> Union[AuthUser, None]: try: @@ -105,6 +109,7 @@ def register_page(page_path): spec = spec_from_file_location(module_name, page_path) page = module_from_spec(spec) page.app = app + # Integrate APIFlask: Blueprint class -> APIBlueprint class page.manager = APIBlueprint(page_name, module_name) sys.modules[module_name] = page spec.loader.exec_module(page) diff --git a/api/apps/apis/datasets.py b/api/apps/apis/datasets.py index ff96a3977a1..73c71af981c 100644 --- a/api/apps/apis/datasets.py +++ b/api/apps/apis/datasets.py @@ -23,6 +23,7 @@ @manager.input(dataset_service.CreateDatasetReq, location='json') @manager.auth_required(http_token_auth) def create_dataset(json_data): + """Creates a new Dataset(Knowledgebase).""" try: tenant_id = http_token_auth.current_user.id return dataset_service.create_dataset(tenant_id, json_data) @@ -34,6 +35,7 @@ def create_dataset(json_data): @manager.input(dataset_service.UpdateDatasetReq, location='json') @manager.auth_required(http_token_auth) def update_dataset(json_data): + """Updates a Dataset(Knowledgebase).""" try: tenant_id = http_token_auth.current_user.id return dataset_service.update_dataset(tenant_id, json_data) @@ -44,6 +46,7 @@ def update_dataset(json_data): @manager.get('/') @manager.auth_required(http_token_auth) def get_dataset_by_id(kb_id): + """Query Dataset(Knowledgebase) by Dataset(Knowledgebase) ID.""" try: tenant_id = http_token_auth.current_user.id return dataset_service.get_dataset_by_id(tenant_id, kb_id) @@ -55,6 +58,7 @@ def get_dataset_by_id(kb_id): @manager.input(dataset_service.SearchDatasetReq, location='query') @manager.auth_required(http_token_auth) def get_dataset_by_name(query_data): + """Query Dataset(Knowledgebase) by Dataset(Knowledgebase) Name.""" try: tenant_id = http_token_auth.current_user.id return dataset_service.get_dataset_by_name(tenant_id, query_data["name"]) @@ -67,6 +71,7 @@ def get_dataset_by_name(query_data): @http_basic_auth_required @manager.auth_required(http_token_auth) def get_all_datasets(query_data): + """Query all Datasets(Knowledgebase)""" try: tenant_id = http_token_auth.current_user.id return dataset_service.get_all_datasets( @@ -83,6 +88,7 @@ def get_all_datasets(query_data): @manager.delete('/') @manager.auth_required(http_token_auth) def delete_dataset(kb_id): + """Deletes a Dataset(Knowledgebase).""" try: tenant_id = http_token_auth.current_user.id return dataset_service.delete_dataset(tenant_id, kb_id) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 105038cca26..cbe0343b35a 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -296,7 +296,7 @@ def http_basic_auth_required(func): @wraps(func) def decorated_view(*args, **kwargs): if 'Authorization' in flask_request.headers: - # 如果请求中包含 token,则跳过用户名密码验证 + # If the request header contains a token, skip username and password verification return func(*args, **kwargs) if flask_request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"): pass From 5a173072cdcb61003a76437cafa3b56ebff14bae Mon Sep 17 00:00:00 2001 From: Valdanito Date: Wed, 18 Sep 2024 14:40:33 +0800 Subject: [PATCH 6/6] refactor(API): Refactor documentss API --- api/apps/apis/documents.py | 64 ++++++++++ api/apps/services/document_service.py | 161 ++++++++++++++++++++++++++ sdk/python/ragflow/ragflow.py | 68 +++++++++-- 3 files changed, 285 insertions(+), 8 deletions(-) create mode 100644 api/apps/apis/documents.py create mode 100644 api/apps/services/document_service.py diff --git a/api/apps/apis/documents.py b/api/apps/apis/documents.py new file mode 100644 index 00000000000..312a4cdff9d --- /dev/null +++ b/api/apps/apis/documents.py @@ -0,0 +1,64 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from api.apps import http_token_auth +from api.apps.services import document_service +from api.utils.api_utils import server_error_response + + +@manager.route('/change_parser', methods=['POST']) +@manager.input(document_service.ChangeDocumentParserReq, location='json') +@manager.auth_required(http_token_auth) +def change_document_parser(json_data): + """Change document file parser.""" + try: + return document_service.change_document_parser(json_data) + except Exception as e: + return server_error_response(e) + + +@manager.route('/run', methods=['POST']) +@manager.input(document_service.RunParsingReq, location='json') +@manager.auth_required(http_token_auth) +def run_parsing(json_data): + """Run parsing documents file.""" + try: + return document_service.run_parsing(json_data) + except Exception as e: + return server_error_response(e) + + +@manager.post('/upload') +@manager.input(document_service.UploadDocumentsReq, location='form_and_files') +@manager.auth_required(http_token_auth) +def upload_documents_2_dataset(form_and_files_data): + """Upload documents file a Dataset(Knowledgebase).""" + try: + tenant_id = http_token_auth.current_user.id + return document_service.upload_documents_2_dataset(form_and_files_data, tenant_id) + except Exception as e: + return server_error_response(e) + + +@manager.get('') +@manager.input(document_service.QueryDocumentsReq, location='query') +@manager.auth_required(http_token_auth) +def get_all_documents(query_data): + """Query documents file in Dataset(Knowledgebase).""" + try: + tenant_id = http_token_auth.current_user.id + return document_service.get_all_documents(query_data, tenant_id) + except Exception as e: + return server_error_response(e) diff --git a/api/apps/services/document_service.py b/api/apps/services/document_service.py new file mode 100644 index 00000000000..12be2599d03 --- /dev/null +++ b/api/apps/services/document_service.py @@ -0,0 +1,161 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import re + +from apiflask import Schema, fields, validators +from elasticsearch_dsl import Q + +from api.db import FileType, TaskStatus, ParserType +from api.db.db_models import Task +from api.db.services.document_service import DocumentService +from api.db.services.file2document_service import File2DocumentService +from api.db.services.file_service import FileService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.task_service import TaskService, queue_tasks +from api.db.services.user_service import UserTenantService +from api.settings import RetCode +from api.utils.api_utils import get_data_error_result +from api.utils.api_utils import get_json_result +from rag.nlp import search +from rag.utils.es_conn import ELASTICSEARCH + + +class QueryDocumentsReq(Schema): + kb_id = fields.String(required=True, error='Invalid kb_id parameter!') + keywords = fields.String(load_default='') + page = fields.Integer(load_default=1) + page_size = fields.Integer(load_default=150) + orderby = fields.String(load_default='create_time') + desc = fields.Boolean(load_default=True) + + +class ChangeDocumentParserReq(Schema): + doc_id = fields.String(required=True) + parser_id = fields.String( + required=True, validate=validators.OneOf([parser_type.value for parser_type in ParserType]) + ) + parser_config = fields.Dict() + + +class RunParsingReq(Schema): + doc_ids = fields.List(required=True) + run = fields.Integer(default=1) + + +class UploadDocumentsReq(Schema): + kb_id = fields.String(required=True) + file = fields.File(required=True) + + +def get_all_documents(query_data, tenant_id): + kb_id = query_data["kb_id"] + tenants = UserTenantService.query(user_id=tenant_id) + for tenant in tenants: + if KnowledgebaseService.query( + tenant_id=tenant.tenant_id, id=kb_id): + break + else: + return get_json_result( + data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) + keywords = query_data["keywords"] + + page_number = query_data["page"] + items_per_page = query_data["page_size"] + orderby = query_data["orderby"] + desc = query_data["desc"] + docs, tol = DocumentService.get_by_kb_id( + kb_id, page_number, items_per_page, orderby, desc, keywords) + return get_json_result(data={"total": tol, "docs": docs}) + + +def upload_documents_2_dataset(form_and_files_data, tenant_id): + file_objs = form_and_files_data['file'] + dataset_id = form_and_files_data['kb_id'] + for file_obj in file_objs: + if file_obj.filename == '': + return get_json_result( + data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) + e, kb = KnowledgebaseService.get_by_id(dataset_id) + if not e: + raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!") + err, _ = FileService.upload_document(kb, file_objs, tenant_id) + if err: + return get_json_result( + data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR) + return get_json_result(data=True) + + +def change_document_parser(json_data): + e, doc = DocumentService.get_by_id(json_data["doc_id"]) + if not e: + return get_data_error_result(retmsg="Document not found!") + if doc.parser_id.lower() == json_data["parser_id"].lower(): + if "parser_config" in json_data: + if json_data["parser_config"] == doc.parser_config: + return get_json_result(data=True) + else: + return get_json_result(data=True) + + if doc.type == FileType.VISUAL or re.search( + r"\.(ppt|pptx|pages)$", doc.name): + return get_data_error_result(retmsg="Not supported yet!") + + e = DocumentService.update_by_id(doc.id, + {"parser_id": json_data["parser_id"], "progress": 0, "progress_msg": "", + "run": TaskStatus.UNSTART.value}) + if not e: + return get_data_error_result(retmsg="Document not found!") + if "parser_config" in json_data: + DocumentService.update_parser_config(doc.id, json_data["parser_config"]) + if doc.token_num > 0: + e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, + doc.process_duation * -1) + if not e: + return get_data_error_result(retmsg="Document not found!") + tenant_id = DocumentService.get_tenant_id(json_data["doc_id"]) + if not tenant_id: + return get_data_error_result(retmsg="Tenant not found!") + ELASTICSEARCH.deleteByQuery( + Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) + + return get_json_result(data=True) + + +def run_parsing(json_data): + for id in json_data["doc_ids"]: + run = str(json_data["run"]) + info = {"run": run, "progress": 0} + if run == TaskStatus.RUNNING.value: + info["progress_msg"] = "" + info["chunk_num"] = 0 + info["token_num"] = 0 + DocumentService.update_by_id(id, info) + tenant_id = DocumentService.get_tenant_id(id) + if not tenant_id: + return get_data_error_result(retmsg="Tenant not found!") + ELASTICSEARCH.deleteByQuery( + Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) + + if run == TaskStatus.RUNNING.value: + TaskService.filter_delete([Task.doc_id == id]) + e, doc = DocumentService.get_by_id(id) + doc = doc.to_dict() + doc["tenant_id"] = tenant_id + bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"]) + queue_tasks(doc, bucket, name) + + return get_json_result(data=True) diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py index cca8b77338e..6cda66e9147 100644 --- a/sdk/python/ragflow/ragflow.py +++ b/sdk/python/ragflow/ragflow.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Union import requests @@ -80,11 +80,66 @@ def get_all_datasets( res = self.get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) res = res.json() - result_list = [] if res.get("retmsg") == "success": - for data in res['data']: - result_list.append(DataSet(self, data)) - return result_list + return res['data'] + raise Exception(res["retmsg"]) + + def get_dataset_by_name(self, name: str) -> List[DataSet]: + res = self.get("/datasets/search", + {"name": name}) + res = res.json() + if res.get("retmsg") == "success": + return res['data'] + raise Exception(res["retmsg"]) + + def change_document_parser(self, doc_id: str, parser_id: str, parser_config: dict): + res = self.post( + "/documents/change_parser", + { + "doc_id": doc_id, + "parser_id": parser_id, + "parser_config": parser_config, + } + ) + res = res.json() + if res.get("retmsg") == "success": + return res['data'] + raise Exception(res["retmsg"]) + + def upload_documents_2_dataset(self, kb_id: str, files: Union[dict, List[bytes]]): + files_data = {} + if isinstance(files, dict): + files_data = files + elif isinstance(files, list): + for idx, file in enumerate(files): + files_data[f'file_{idx}'] = file + else: + files_data['file'] = files + data = { + 'kb_id': kb_id, + } + res = requests.post(url=self.api_url + "/documents/upload", data=data, files=files_data) + res = res.json() + if res.get("retmsg") == "success": + return res['data'] + raise Exception(res["retmsg"]) + + def documents_run_parsing(self, doc_ids: list): + res = self.post("/documents/run", + {"doc_ids": doc_ids}) + res = res.json() + if res.get("retmsg") == "success": + return res['data'] + raise Exception(res["retmsg"]) + + def get_all_documents( + self, keywords: str = '', page: int = 1, page_size: int = 1024, + orderby: str = "create_time", desc: bool = True): + res = self.get("/documents", + {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) + res = res.json() + if res.get("retmsg") == "success": + return res['data'] raise Exception(res["retmsg"]) def get_dataset(self, id: str = None, name: str = None) -> DataSet: @@ -220,7 +275,6 @@ def async_cancel_parse_documents(self, doc_ids): raise ValueError("doc_ids must be a non-empty list of document IDs") data = {"doc_ids": doc_ids, "run": 2} - res = self.post(f'/doc/run', data) if res.status_code != 200: @@ -231,5 +285,3 @@ def async_cancel_parse_documents(self, doc_ids): except Exception as e: print(f"Error occurred during canceling parsing for documents: {str(e)}") raise - -