From c91e8b1737f28c9997ce7bd9c1e8ac4e377ef1b7 Mon Sep 17 00:00:00 2001
From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com>
Date: Tue, 24 Dec 2024 16:26:47 +0800
Subject: [PATCH 01/65] fix: modal bg color (#12042)

---
 web/app/components/base/modal/index.tsx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/app/components/base/modal/index.tsx b/web/app/components/base/modal/index.tsx
index 3040cdb00b502a..26cde5fce3dd6a 100644
--- a/web/app/components/base/modal/index.tsx
+++ b/web/app/components/base/modal/index.tsx
@@ -39,7 +39,7 @@ export default function Modal({
        leaveFrom="opacity-100"
        leaveTo="opacity-0"
      >
-
+

Date: Tue, 24 Dec 2024 18:38:51 +0800
Subject: [PATCH 02/65] feat: mypy for all type check (#10921)

---
 .github/workflows/api-tests.yml | 6 +
 api/commands.py | 13 +-
 api/configs/feature/__init__.py | 14 +-
 api/configs/middleware/__init__.py | 4 -
 .../remote_settings_sources/apollo/client.py | 5 +-
 api/constants/model_template.py | 3 +-
 api/controllers/common/fields.py | 2 +-
 api/controllers/console/__init__.py | 95 +++++++-
 api/controllers/console/admin.py | 2 +-
 api/controllers/console/apikey.py | 23 +-
 .../console/app/advanced_prompt_template.py | 2 +-
 api/controllers/console/app/agent.py | 2 +-
 api/controllers/console/app/annotation.py | 6 +-
 api/controllers/console/app/app.py | 4 +-
 api/controllers/console/app/app_import.py | 4 +-
 api/controllers/console/app/audio.py | 2 +-
 api/controllers/console/app/completion.py | 4 +-
 api/controllers/console/app/conversation.py | 15 +-
 .../console/app/conversation_variables.py | 2 +-
 api/controllers/console/app/generator.py | 4 +-
 api/controllers/console/app/message.py | 6 +-
 api/controllers/console/app/model_config.py | 13 +-
 api/controllers/console/app/ops_trace.py | 2 +-
 api/controllers/console/app/site.py | 6 +-
 api/controllers/console/app/statistic.py | 4 +-
 api/controllers/console/app/workflow.py | 2 +-
 .../console/app/workflow_app_log.py | 4 +-
 api/controllers/console/app/workflow_run.py | 4 +-
 .../console/app/workflow_statistic.py | 4 +-
 api/controllers/console/app/wraps.py | 2 +-
 api/controllers/console/auth/activate.py | 6 +-
 .../console/auth/data_source_bearer_auth.py | 4 +-
 .../console/auth/data_source_oauth.py | 8 +-
 .../console/auth/forgot_password.py | 6 +-
 api/controllers/console/auth/login.py | 4 +-
 api/controllers/console/auth/oauth.py | 7 +-
 api/controllers/console/billing/billing.py | 4 +-
 .../console/datasets/data_source.py | 4 +-
 api/controllers/console/datasets/datasets.py | 6 +-
 .../console/datasets/datasets_document.py | 10 +-
 .../console/datasets/datasets_segments.py | 4 +-
 api/controllers/console/datasets/external.py | 4 +-
 .../console/datasets/hit_testing.py | 2 +-
 .../console/datasets/hit_testing_base.py | 4 +-
 api/controllers/console/datasets/website.py | 2 +-
 api/controllers/console/explore/audio.py | 9 +-
 api/controllers/console/explore/completion.py | 23 +-
 .../console/explore/conversation.py | 32 +--
 .../console/explore/installed_app.py | 11 +-
 api/controllers/console/explore/message.py | 25 +-
 api/controllers/console/explore/parameter.py | 2 +-
 .../console/explore/recommended_app.py | 4 +-
 .../console/explore/saved_message.py | 6 +-
 api/controllers/console/explore/workflow.py | 9 +-
 api/controllers/console/explore/wraps.py | 4 +-
 api/controllers/console/extension.py | 4 +-
 api/controllers/console/feature.py | 4 +-
 api/controllers/console/files.py | 9 +-
 api/controllers/console/init_validate.py | 2 +-
 api/controllers/console/ping.py | 2 +-
 api/controllers/console/remote_files.py | 4 +-
 api/controllers/console/setup.py | 2 +-
 api/controllers/console/tag/tags.py | 6 +-
 api/controllers/console/version.py | 2 +-
 api/controllers/console/workspace/account.py | 4 +-
 .../workspace/load_balancing_config.py | 6 +-
 api/controllers/console/workspace/members.py | 31 +--
 .../console/workspace/model_providers.py | 9 +-
 api/controllers/console/workspace/models.py | 6 +-
 .../console/workspace/tool_providers.py | 4 +-
 .../console/workspace/workspace.py | 14 +-
 api/controllers/console/wraps.py | 6 +-
 api/controllers/files/image_preview.py | 2 +-
 api/controllers/files/tool_files.py | 2 +-
 .../inner_api/workspace/workspace.py | 2 +-
 api/controllers/inner_api/wraps.py | 6 +-
 api/controllers/service_api/app/app.py | 2 +-
 api/controllers/service_api/app/audio.py | 4 +-
 api/controllers/service_api/app/completion.py | 2 +-
 .../service_api/app/conversation.py | 4 +-
 api/controllers/service_api/app/file.py | 2 +-
 api/controllers/service_api/app/message.py | 4 +-
 api/controllers/service_api/app/workflow.py | 4 +-
 .../service_api/dataset/dataset.py | 2 +-
 .../service_api/dataset/document.py | 2 +-
 .../service_api/dataset/segment.py | 4 +-
 api/controllers/service_api/index.py | 2 +-
 api/controllers/service_api/wraps.py | 10 +-
 api/controllers/web/app.py | 2 +-
 api/controllers/web/audio.py | 4 +-
 api/controllers/web/completion.py | 2 +-
 api/controllers/web/conversation.py | 4 +-
 api/controllers/web/feature.py | 2 +-
 api/controllers/web/files.py | 4 +-
 api/controllers/web/message.py | 4 +-
 api/controllers/web/passport.py | 2 +-
 api/controllers/web/remote_files.py | 2 +-
 api/controllers/web/saved_message.py | 4 +-
 api/controllers/web/site.py | 2 +-
 api/controllers/web/workflow.py | 2 +-
 api/controllers/web/wraps.py | 2 +-
 api/core/agent/base_agent_runner.py | 47 ++--
 api/core/agent/cot_agent_runner.py | 82 ++++---
 api/core/agent/cot_chat_agent_runner.py | 6 +
 api/core/agent/cot_completion_agent_runner.py | 21 +-
 api/core/agent/entities.py | 2 +-
 api/core/agent/fc_agent_runner.py | 46 ++--
 .../agent/output_parser/cot_output_parser.py | 10 +-
 .../easy_ui_based_app/dataset/manager.py | 2 +
 .../easy_ui_based_app/model_config/manager.py | 2 +-
 .../features/opening_statement/manager.py | 4 +-
 .../app/apps/advanced_chat/app_generator.py | 5 +-
 .../app_generator_tts_publisher.py | 23 +-
 api/core/app/apps/advanced_chat/app_runner.py | 10 +-
 .../advanced_chat/generate_task_pipeline.py | 28 ++-
 .../app/apps/agent_chat/app_config_manager.py | 2 +-
 api/core/app/apps/agent_chat/app_generator.py | 13 +-
 api/core/app/apps/agent_chat/app_runner.py | 19 +-
 .../agent_chat/generate_response_converter.py | 14 +-
 api/core/app/apps/base_app_queue_manager.py | 2 +-
 api/core/app/apps/base_app_runner.py | 45 ++--
 api/core/app/apps/chat/app_generator.py | 11 +-
 .../apps/chat/generate_response_converter.py | 10 +-
 .../app/apps/completion/app_config_manager.py | 2 +-
 api/core/app/apps/completion/app_generator.py | 18 +-
 api/core/app/apps/completion/app_runner.py | 4 +-
 .../completion/generate_response_converter.py | 10 +-
 .../app/apps/message_based_app_generator.py | 12 +-
 api/core/app/apps/workflow/app_generator.py | 2 +-
 .../workflow/generate_response_converter.py | 12 +-
 api/core/app/apps/workflow_app_runner.py | 14 +-
 api/core/app/entities/app_invoke_entities.py | 4 +-
 api/core/app/entities/queue_entities.py | 2 +-
 api/core/app/entities/task_entities.py | 14 +-
 .../annotation_reply/annotation_reply.py | 2 +-
 .../app/features/rate_limiting/rate_limit.py | 2 +-
 .../based_generate_task_pipeline.py | 2 +
 .../easy_ui_based_generate_task_pipeline.py | 51 ++--
 .../app/task_pipeline/message_cycle_manage.py | 2 +-
 .../task_pipeline/workflow_cycle_manage.py | 22 +-
 .../agent_tool_callback_handler.py | 2 +-
 .../index_tool_callback_handler.py | 17 +-
 api/core/entities/model_entities.py | 3 +-
 api/core/entities/provider_configuration.py | 110 +++++----
 .../api_based_extension_requestor.py | 6 +-
 api/core/extension/extensible.py | 9 +-
 api/core/extension/extension.py | 10 +-
 api/core/external_data_tool/api/api.py | 3 +
 .../external_data_tool/external_data_fetch.py | 24 +-
 api/core/external_data_tool/factory.py | 10 +-
 api/core/file/file_manager.py | 5 +-
 api/core/file/tool_file_parser.py | 4 +-
 .../helper/code_executor/code_executor.py | 16 +-
 .../code_executor/jinja2/jinja2_formatter.py | 7 +-
 .../code_executor/template_transformer.py | 3 +-
 api/core/helper/lru_cache.py | 2 +-
 api/core/helper/model_provider_cache.py | 2 +-
 api/core/helper/moderation.py | 4 +-
 api/core/helper/module_import_helper.py | 13 +-
 api/core/helper/tool_parameter_cache.py | 2 +-
 api/core/helper/tool_provider_cache.py | 2 +-
 api/core/hosting_configuration.py | 14 +-
 api/core/indexing_runner.py | 66 ++---
 api/core/llm_generator/llm_generator.py | 89 ++++---
 api/core/memory/token_buffer_memory.py | 2 +-
 api/core/model_manager.py | 149 +++++++-----
 .../callbacks/logging_callback.py | 13 +-
 .../entities/message_entities.py | 3 +-
 .../model_providers/__base/ai_model.py | 13 +-
 .../__base/large_language_model.py | 14 +-
 .../model_providers/__base/model_provider.py | 3 +-
 .../__base/text_embedding_model.py | 6 +-
 .../__base/tokenizers/gpt2_tokenzier.py | 4 +-
 .../model_providers/__base/tts_model.py | 9 +-
 .../azure_openai/speech2text/speech2text.py | 4 +-
 .../model_providers/azure_openai/tts/tts.py | 2 +
 .../model_providers/bedrock/llm/llm.py | 6 +-
 .../model_providers/cohere/rerank/rerank.py | 4 +-
 .../model_providers/fireworks/_common.py | 4 +-
 .../text_embedding/text_embedding.py | 3 +-
 .../model_providers/gitee_ai/_common.py | 2 +-
 .../model_providers/gitee_ai/rerank/rerank.py | 4 +-
 .../gitee_ai/text_embedding/text_embedding.py | 2 +-
 .../model_providers/gitee_ai/tts/tts.py | 8 +-
 .../model_providers/google/llm/llm.py | 2 +-
 .../huggingface_hub/_common.py | 2 +-
 .../huggingface_hub/llm/llm.py | 6 +-
 .../text_embedding/text_embedding.py | 2 +-
 .../model_providers/hunyuan/llm/llm.py | 12 +-
 .../hunyuan/text_embedding/text_embedding.py | 10 +-
 .../jina/text_embedding/jina_tokenizer.py | 4 +-
 .../minimax/llm/chat_completion.py | 30 +--
 .../minimax/llm/chat_completion_pro.py | 26 +-
 .../model_providers/minimax/llm/types.py | 4 +-
 .../nomic/text_embedding/text_embedding.py | 4 +-
 .../model_providers/oci/llm/llm.py | 4 +-
 .../oci/text_embedding/text_embedding.py | 2 +-
 .../ollama/text_embedding/text_embedding.py | 1 +
 .../model_providers/openai/_common.py | 4 +-
 .../openai/moderation/moderation.py | 6 +-
 .../model_providers/openai/openai.py | 3 +-
 .../speech2text/speech2text.py | 1 +
 .../text_embedding/text_embedding.py | 1 +
 .../openai_api_compatible/tts/tts.py | 1 +
 .../openllm/llm/openllm_generate.py | 16 +-
 .../text_embedding/text_embedding.py | 7 +-
 .../model_providers/replicate/_common.py | 2 +-
 .../model_providers/replicate/llm/llm.py | 6 +-
 .../text_embedding/text_embedding.py | 6 +-
 .../model_providers/sagemaker/llm/llm.py | 8 +-
 .../sagemaker/rerank/rerank.py | 3 +-
 .../sagemaker/speech2text/speech2text.py | 3 +-
 .../text_embedding/text_embedding.py | 3 +-
 .../model_providers/sagemaker/tts/tts.py | 2 +-
 .../model_providers/siliconflow/llm/llm.py | 2 +-
 .../model_providers/spark/llm/llm.py | 4 +-
 .../model_providers/togetherai/llm/llm.py | 3 +-
 .../model_providers/tongyi/_common.py | 2 +-
 .../model_providers/tongyi/llm/llm.py | 6 +-
 .../model_providers/tongyi/rerank/rerank.py | 8 +-
 .../tongyi/text_embedding/text_embedding.py | 2 +-
 .../model_providers/tongyi/tts/tts.py | 8 +-
 .../model_providers/upstage/_common.py | 4 +-
 .../model_providers/upstage/llm/llm.py | 2 +-
 .../upstage/text_embedding/text_embedding.py | 5 +-
 .../model_providers/vertex_ai/_common.py | 2 +-
 .../model_providers/vertex_ai/llm/llm.py | 2 +-
 .../model_providers/vessl_ai/llm/llm.py | 4 +-
 .../model_providers/volcengine_maas/client.py | 12 +-
 .../volcengine_maas/legacy/errors.py | 3 +-
 .../volcengine_maas/llm/llm.py | 2 +-
 .../volcengine_maas/llm/models.py | 6 +-
 .../model_providers/wenxin/llm/ernie_bot.py | 5 +-
 .../wenxin/text_embedding/text_embedding.py | 9 +-
 .../model_providers/xinference/llm/llm.py | 2 +-
 .../xinference/rerank/rerank.py | 2 +-
 .../xinference/speech2text/speech2text.py | 2 +-
 .../text_embedding/text_embedding.py | 4 +-
 .../model_providers/xinference/tts/tts.py | 7 +-
 .../xinference/xinference_helper.py | 12 +-
 .../model_providers/yi/llm/llm.py | 2 +-
 .../model_providers/zhipuai/llm/llm.py | 6 +-
 .../zhipuai/text_embedding/text_embedding.py | 2 +-
 .../schema_validators/common_validator.py | 7 +-
 api/core/model_runtime/utils/encoders.py | 3 +-
 api/core/model_runtime/utils/helper.py | 3 +-
 api/core/moderation/api/api.py | 14 +-
 api/core/moderation/base.py | 4 +-
 api/core/moderation/factory.py | 3 +-
 api/core/moderation/input_moderation.py | 8 +-
 api/core/moderation/keywords/keywords.py | 6 +-
 .../openai_moderation/openai_moderation.py | 4 +
 api/core/moderation/output_moderation.py | 2 +-
 api/core/ops/entities/trace_entity.py | 5 +-
 api/core/ops/langfuse_trace/langfuse_trace.py | 12 +-
 .../entities/langsmith_trace_entity.py | 1 -
 .../ops/langsmith_trace/langsmith_trace.py | 155 ++++++++++--
 api/core/ops/ops_trace_manager.py | 52 ++--
 api/core/prompt/advanced_prompt_transform.py | 37 +--
 .../prompt/agent_history_prompt_transform.py | 2 +-
 api/core/prompt/prompt_transform.py | 22 +-
 api/core/prompt/simple_prompt_transform.py | 25 +-
 api/core/prompt/utils/prompt_message_util.py | 8 +-
 .../prompt/utils/prompt_template_parser.py | 3 +-
 api/core/provider_manager.py | 37 ++-
 .../rag/datasource/keyword/jieba/jieba.py | 37 ++-
 .../jieba/jieba_keyword_table_handler.py | 4 +-
 .../rag/datasource/keyword/keyword_base.py | 4 +-
 api/core/rag/datasource/retrieval_service.py | 23 +-
 .../vdb/analyticdb/analyticdb_vector.py | 35 +--
 .../analyticdb/analyticdb_vector_openapi.py | 37 +--
 .../vdb/analyticdb/analyticdb_vector_sql.py | 24 +-
 .../rag/datasource/vdb/baidu/baidu_vector.py | 28 ++-
 .../datasource/vdb/chroma/chroma_vector.py | 24 +-
 .../vdb/couchbase/couchbase_vector.py | 24 +-
 .../vdb/elasticsearch/elasticsearch_vector.py | 22 +-
 .../datasource/vdb/lindorm/lindorm_vector.py | 33 ++-
 .../datasource/vdb/milvus/milvus_vector.py | 22 +-
 .../datasource/vdb/myscale/myscale_vector.py | 19 +-
 .../vdb/oceanbase/oceanbase_vector.py | 4 +-
 .../vdb/opensearch/opensearch_vector.py | 4 +-
 .../rag/datasource/vdb/oracle/oraclevector.py | 42 ++--
 .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 14 +-
 .../rag/datasource/vdb/pgvector/pgvector.py | 31 +--
 .../datasource/vdb/qdrant/qdrant_vector.py | 21 +-
 .../rag/datasource/vdb/relyt/relyt_vector.py | 26 +-
 .../datasource/vdb/tencent/tencent_vector.py | 30 +--
 .../tidb_on_qdrant/tidb_on_qdrant_vector.py | 31 ++-
 .../vdb/tidb_on_qdrant/tidb_service.py | 11 +-
 .../datasource/vdb/tidb_vector/tidb_vector.py | 12 +-
 api/core/rag/datasource/vdb/vector_base.py | 11 +-
 api/core/rag/datasource/vdb/vector_factory.py | 9 +-
 .../vdb/vikingdb/vikingdb_vector.py | 11 +-
 .../vdb/weaviate/weaviate_vector.py | 14 +-
 api/core/rag/docstore/dataset_docstore.py | 9 +-
 api/core/rag/embedding/cached_embedding.py | 16 +-
 .../rag/extractor/entity/extract_setting.py | 2 +-
 api/core/rag/extractor/excel_extractor.py | 8 +-
 api/core/rag/extractor/extract_processor.py | 20 +-
 .../rag/extractor/firecrawl/firecrawl_app.py | 21 +-
 api/core/rag/extractor/html_extractor.py | 3 +-
 api/core/rag/extractor/notion_extractor.py | 14 +-
 api/core/rag/extractor/pdf_extractor.py | 6 +-
 .../unstructured_eml_extractor.py | 2 +-
 .../unstructured_epub_extractor.py | 3 +
 .../unstructured_ppt_extractor.py | 4 +-
 .../unstructured_pptx_extractor.py | 11 +-
 api/core/rag/extractor/word_extractor.py | 6 +
 .../index_processor/index_processor_base.py | 1 +
 .../index_processor_factory.py | 2 +-
 .../processor/paragraph_index_processor.py | 10 +-
 .../processor/qa_index_processor.py | 27 +-
 api/core/rag/rerank/rerank_model.py | 11 +-
 api/core/rag/rerank/weight_rerank.py | 16 +-
 api/core/rag/retrieval/dataset_retrieval.py | 111 +++++----
 .../multi_dataset_function_call_router.py | 16 +-
 .../router/multi_dataset_react_route.py | 22 +-
 api/core/rag/splitter/fixed_text_splitter.py | 4 +-
 api/core/rag/splitter/text_splitter.py | 4 +-
 api/core/tools/entities/api_entities.py | 2 +-
 api/core/tools/entities/tool_bundle.py | 2 +-
 api/core/tools/entities/tool_entities.py | 12 +-
 api/core/tools/provider/api_tool_provider.py | 70 +++---
 api/core/tools/provider/app_tool_provider.py | 15 +-
 api/core/tools/provider/builtin/_positions.py | 2 +-
 .../provider/builtin/aippt/tools/aippt.py | 36 +--
 .../builtin/arxiv/tools/arxiv_search.py | 2 +-
 .../tools/provider/builtin/audio/tools/tts.py | 20 +-
 .../builtin/aws/tools/apply_guardrail.py | 4 +-
 .../aws/tools/lambda_translate_utils.py | 2 +-
 .../builtin/aws/tools/lambda_yaml_to_json.py | 2 +-
 .../aws/tools/sagemaker_text_rerank.py | 6 +-
 .../builtin/aws/tools/sagemaker_tts.py | 4 +-
 .../builtin/cogview/tools/cogvideo.py | 2 +-
 .../builtin/cogview/tools/cogvideo_job.py | 2 +-
 .../builtin/cogview/tools/cogview3.py | 2 +-
 .../feishu_base/tools/search_records.py | 20 +-
 .../feishu_base/tools/update_records.py | 12 +-
 .../tools/add_event_attendees.py | 8 +-
 .../feishu_calendar/tools/delete_event.py | 6 +-
 .../tools/get_primary_calendar.py | 4 +
 .../feishu_calendar/tools/list_events.py | 12 +-
 .../feishu_calendar/tools/update_event.py | 14 +-
 .../feishu_document/tools/create_document.py | 10 +-
 .../tools/list_document_blocks.py | 6 +-
 .../builtin/json_process/tools/delete.py | 2 +-
 .../builtin/json_process/tools/insert.py | 2 +-
 .../builtin/json_process/tools/parse.py | 2 +-
 .../builtin/json_process/tools/replace.py | 2 +-
 .../builtin/maths/tools/eval_expression.py | 2 +-
 .../builtin/novitaai/_novita_tool_base.py | 2 +-
 .../novitaai/tools/novitaai_createtile.py | 2 +-
 .../novitaai/tools/novitaai_txt2img.py | 2 +-
 .../tools/podcast_audio_generator.py | 2 +-
 .../builtin/qrcode/tools/qrcode_generator.py | 8 +-
 .../builtin/transcript/tools/transcript.py | 2 +-
 .../builtin/twilio/tools/send_message.py | 2 +-
 .../tools/provider/builtin/twilio/twilio.py | 4 +-
 .../provider/builtin/vanna/tools/vanna.py | 5 +-
 .../wikipedia/tools/wikipedia_search.py | 2 +-
 .../provider/builtin/yahoo/tools/analytics.py | 2 +-
 .../provider/builtin/yahoo/tools/news.py | 2 +-
 .../provider/builtin/yahoo/tools/ticker.py | 2 +-
 .../provider/builtin/youtube/tools/videos.py | 2 +-
 .../tools/provider/builtin_tool_provider.py | 70 +++---
 api/core/tools/provider/tool_provider.py | 63 ++---
 .../tools/provider/workflow_tool_provider.py | 13 +-
 api/core/tools/tool/api_tool.py | 14 +-
 api/core/tools/tool/builtin_tool.py | 35 ++-
 .../dataset_multi_retriever_tool.py | 13 +-
 .../dataset_retriever_base_tool.py | 2 +-
 .../dataset_retriever_tool.py | 57 +++--
 api/core/tools/tool/dataset_retriever_tool.py | 11 +-
 api/core/tools/tool/tool.py | 17 +-
 api/core/tools/tool/workflow_tool.py | 14 +-
 api/core/tools/tool_engine.py | 24 +-
 api/core/tools/tool_label_manager.py | 8 +-
 api/core/tools/tool_manager.py | 126 ++++++----
 api/core/tools/utils/configuration.py | 18 +-
 api/core/tools/utils/feishu_api_utils.py | 179 ++++++++------
 api/core/tools/utils/lark_api_utils.py | 193 +++++++++------
 api/core/tools/utils/message_transformer.py | 12 +-
 .../tools/utils/model_invocation_utils.py | 23 +-
 api/core/tools/utils/parser.py | 17 +-
 api/core/tools/utils/web_reader_tool.py | 15 +-
 .../utils/workflow_configuration_sync.py | 4 +-
 api/core/tools/utils/yaml_utils.py | 2 +-
 api/core/variables/variables.py | 3 +-
 .../callbacks/workflow_logging_callback.py | 2 +-
 api/core/workflow/entities/node_entities.py | 4 +-
 .../condition_handlers/condition_handler.py | 2 +-
 .../workflow/graph_engine/entities/graph.py | 16 +-
 .../workflow/graph_engine/graph_engine.py | 55 +++--
 .../nodes/answer/answer_stream_processor.py | 4 +-
 .../nodes/answer/base_stream_processor.py | 8 +-
 api/core/workflow/nodes/base/entities.py | 5 +-
 api/core/workflow/nodes/code/code_node.py | 2 +-
 api/core/workflow/nodes/code/entities.py | 2 +-
 .../workflow/nodes/document_extractor/node.py | 7 +-
 .../nodes/end/end_stream_generate_router.py | 5 +-
 .../nodes/end/end_stream_processor.py | 2 +-
 api/core/workflow/nodes/event/event.py | 2 +-
 .../workflow/nodes/http_request/executor.py | 44 ++--
 api/core/workflow/nodes/http_request/node.py | 8 +-
 .../nodes/iteration/iteration_node.py | 15 +-
 .../knowledge_retrieval_node.py | 16 +-
 api/core/workflow/nodes/list_operator/node.py | 34 +--
 api/core/workflow/nodes/llm/node.py | 17 +-
 api/core/workflow/nodes/loop/loop_node.py | 6 +-
 .../nodes/parameter_extractor/entities.py | 4 +-
 .../parameter_extractor_node.py | 16 +-
 .../nodes/parameter_extractor/prompts.py | 4 +-
 .../question_classifier_node.py | 12 +-
 api/core/workflow/nodes/tool/tool_node.py | 9 +-
 .../nodes/variable_assigner/v1/node.py | 2 +
 .../nodes/variable_assigner/v2/node.py | 6 +-
 api/core/workflow/workflow_entry.py | 15 +-
 .../event_handlers/create_document_index.py | 2 +-
 .../create_site_record_when_app_created.py | 29 +--
 .../deduct_quota_when_message_created.py | 2 +-
 ...rameters_cache_when_sync_draft_workflow.py | 5 +-
 ...aset_join_when_app_model_config_updated.py | 10 +-
 ...oin_when_app_published_workflow_updated.py | 10 +-
 api/extensions/__init__.py | 0
 api/extensions/ext_app_metrics.py | 14 +-
 api/extensions/ext_celery.py | 6 +-
 api/extensions/ext_compress.py | 2 +-
 api/extensions/ext_logging.py | 5 +-
 api/extensions/ext_login.py | 2 +-
 api/extensions/ext_mail.py | 8 +-
 api/extensions/ext_migrate.py | 2 +-
 api/extensions/ext_proxy_fix.py | 2 +-
 api/extensions/ext_sentry.py | 2 +-
 api/extensions/ext_storage.py | 8 +-
 api/extensions/storage/aliyun_oss_storage.py | 12 +-
 api/extensions/storage/aws_s3_storage.py | 8 +-
 api/extensions/storage/azure_blob_storage.py | 8 +-
 api/extensions/storage/baidu_obs_storage.py | 9 +-
 .../storage/google_cloud_storage.py | 4 +-
 api/extensions/storage/huawei_obs_storage.py | 4 +-
 api/extensions/storage/opendal_storage.py | 8 +-
 api/extensions/storage/oracle_oci_storage.py | 6 +-
 api/extensions/storage/supabase_storage.py | 2 +-
 api/extensions/storage/tencent_cos_storage.py | 4 +-
 .../storage/volcengine_tos_storage.py | 4 +-
 api/factories/__init__.py | 0
 api/factories/file_factory.py | 5 +-
 api/factories/variable_factory.py | 21 +-
 api/fields/annotation_fields.py | 2 +-
 api/fields/api_based_extension_fields.py | 2 +-
 api/fields/app_fields.py | 2 +-
 api/fields/conversation_fields.py | 2 +-
 api/fields/conversation_variable_fields.py | 2 +-
 api/fields/data_source_fields.py | 2 +-
 api/fields/dataset_fields.py | 2 +-
 api/fields/document_fields.py | 2 +-
 api/fields/end_user_fields.py | 2 +-
 api/fields/external_dataset_fields.py | 2 +-
 api/fields/file_fields.py | 2 +-
 api/fields/hit_testing_fields.py | 2 +-
 api/fields/installed_app_fields.py | 2 +-
 api/fields/member_fields.py | 2 +-
 api/fields/message_fields.py | 2 +-
 api/fields/raws.py | 2 +-
 api/fields/segment_fields.py | 2 +-
 api/fields/tag_fields.py | 2 +-
 api/fields/workflow_app_log_fields.py | 2 +-
 api/fields/workflow_fields.py | 2 +-
 api/fields/workflow_run_fields.py | 2 +-
 api/libs/external_api.py | 7 +-
 api/libs/gmpy2_pkcs10aep_cipher.py | 8 +-
 api/libs/helper.py | 6 +-
 api/libs/json_in_md_parser.py | 1 +
 api/libs/login.py | 15 +-
 api/libs/oauth.py | 6 +-
 api/libs/oauth_data_source.py | 7 +-
 api/libs/threadings_utils.py | 4 +-
 api/models/account.py | 32 +--
 api/models/api_based_extension.py | 2 +-
 api/models/dataset.py | 33 +--
 api/models/model.py | 73 +++---
 api/models/provider.py | 14 +-
 api/models/source.py | 4 +-
 api/models/task.py | 6 +-
 api/models/tools.py | 22 +-
 api/models/web.py | 4 +-
 api/models/workflow.py | 19 +-
 api/mypy.ini | 10 +
 api/poetry.lock | 230 ++++++++++++------
 api/pyproject.toml | 3 +
 api/schedule/clean_messages.py | 3 +-
 api/schedule/clean_unused_datasets_task.py | 6 +-
 api/schedule/create_tidb_serverless_task.py | 15 +-
 .../update_tidb_serverless_status_task.py | 13 +-
 api/services/account_service.py | 31 ++-
 .../advanced_prompt_template_service.py | 4 +
 api/services/agent_service.py | 13 +-
 api/services/annotation_service.py | 14 +-
 api/services/app_dsl_service.py | 9 +-
 api/services/app_generate_service.py | 4 +-
 api/services/app_service.py | 24 +-
 api/services/audio_service.py | 6 +
 api/services/auth/firecrawl/firecrawl.py | 4 +-
 api/services/auth/jina.py | 2 +-
 api/services/auth/jina/jina.py | 2 +-
 api/services/billing_service.py | 6 +-
 api/services/conversation_service.py | 3 +-
 api/services/dataset_service.py | 42 +++-
 api/services/enterprise/base.py | 4 +-
 .../entities/model_provider_entities.py | 8 +-
 api/services/external_knowledge_service.py | 39 +--
 api/services/file_service.py | 6 +-
 api/services/hit_testing_service.py | 17 +-
 api/services/knowledge_service.py | 2 +-
 api/services/message_service.py | 6 +-
 api/services/model_load_balancing_service.py | 49 ++--
 api/services/model_provider_service.py | 25 +-
 api/services/moderation_service.py | 4 +-
 api/services/ops_service.py | 26 +-
 .../buildin/buildin_retrieval.py | 8 +-
 .../recommend_app/remote/remote_retrieval.py | 6 +-
 api/services/recommended_app_service.py | 2 +-
 api/services/saved_message_service.py | 6 +
 api/services/tag_service.py | 4 +-
 .../tools/api_tools_manage_service.py | 33 +--
 .../tools/builtin_tools_manage_service.py | 8 +-
 api/services/tools/tools_transform_service.py | 34 ++-
 .../tools/workflow_tools_manage_service.py | 63 +++--
 api/services/web_conversation_service.py | 6 +
 api/services/website_service.py | 33 ++-
 api/services/workflow/workflow_converter.py | 18 +-
 api/services/workflow_run_service.py | 4 +-
 api/services/workflow_service.py | 6 +-
 api/services/workspace_service.py | 3 +-
 api/tasks/__init__.py | 0
 api/tasks/add_document_to_index_task.py | 2 +-
 .../add_annotation_to_index_task.py | 2 +-
 .../batch_import_annotations_task.py | 2 +-
 .../delete_annotation_index_task.py | 2 +-
 .../disable_annotation_reply_task.py | 2 +-
 .../enable_annotation_reply_task.py | 2 +-
 .../update_annotation_to_index_task.py | 2 +-
 .../batch_create_segment_to_index_task.py | 10 +-
 api/tasks/clean_dataset_task.py | 4 +-
 api/tasks/clean_document_task.py | 4 +-
 api/tasks/clean_notion_document_task.py | 2 +-
 api/tasks/create_segment_to_index_task.py | 2 +-
 api/tasks/deal_dataset_vector_index_task.py | 2 +-
 api/tasks/delete_segment_from_index_task.py | 2 +-
 api/tasks/disable_segment_from_index_task.py | 2 +-
 api/tasks/document_indexing_sync_task.py | 2 +-
 api/tasks/document_indexing_task.py | 2 +-
 api/tasks/document_indexing_update_task.py | 2 +-
 api/tasks/duplicate_document_indexing_task.py | 4 +-
 api/tasks/enable_segment_to_index_task.py | 2 +-
 api/tasks/mail_email_code_login.py | 2 +-
 api/tasks/mail_invite_member_task.py | 2 +-
 api/tasks/mail_reset_password_task.py | 2 +-
 api/tasks/ops_trace_task.py | 2 +-
 api/tasks/recover_document_indexing_task.py | 2 +-
 api/tasks/remove_app_and_related_data_task.py | 2 +-
 api/tasks/remove_document_from_index_task.py | 2 +-
 api/tasks/retry_document_indexing_task.py | 45 ++--
 .../sync_website_document_indexing_task.py | 42 ++--
 .../dependencies/test_dependencies_sorted.py | 4 +-
 .../controllers/test_controllers.py | 2 +-
 .../model_runtime/__mock/google.py | 4 +-
 .../model_runtime/__mock/huggingface.py | 2 +-
 .../model_runtime/__mock/huggingface_chat.py | 6 +-
 .../model_runtime/__mock/nomic_embeddings.py | 2 +-
 .../model_runtime/__mock/xinference.py | 4 +-
 .../model_runtime/tongyi/test_rerank.py | 2 +-
 .../tools/__mock_server/openapi_todo.py | 2 +-
 .../vdb/__mock/baiduvectordb.py | 10 +-
 .../vdb/__mock/tcvectordb.py | 12 +-
 .../integration_tests/vdb/__mock/vikingdb.py | 2 +-
 api/tests/unit_tests/oss/__mock/aliyun_oss.py | 4 +-
 .../unit_tests/oss/__mock/tencent_cos.py | 4 +-
 .../unit_tests/oss/__mock/volcengine_tos.py | 4 +-
 .../aliyun_oss/aliyun_oss/test_aliyun_oss.py | 2 +-
 .../oss/tencent_cos/test_tencent_cos.py | 2 +-
 .../oss/volcengine_tos/test_volcengine_tos.py | 2 +-
 .../unit_tests/utils/yaml/test_yaml_utils.py | 2 +-
 sdks/python-client/dify_client/client.py | 27 +-
 584 files changed, 3980 insertions(+), 2831 deletions(-)
 create mode 100644 api/extensions/__init__.py
 create mode 100644 api/factories/__init__.py
 create mode 100644 api/mypy.ini
 create mode 100644 api/tasks/__init__.py

diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 2cd0b2a7d430de..fd98db24b961b4 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -56,6 +56,12 @@ jobs:
       - name: Run Tool
         run: poetry run -C api bash dev/pytest/pytest_tools.sh

+      - name: Run mypy
+        run: |
+          pushd api
+          poetry run python -m mypy --install-types --non-interactive .
+          popd
+
       - name: Set up dotenvs
         run: |
           cp docker/.env.example docker/.env
diff --git a/api/commands.py b/api/commands.py
index bf013cc77e0627..ad7ad972f3fd01 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -159,8 +159,7 @@ def migrate_annotation_vector_database():
        try:
            # get apps info
            apps = (
-                db.session.query(App)
-                .filter(App.status == "normal")
+                App.query.filter(App.status == "normal")
                .order_by(App.created_at.desc())
                .paginate(page=page, per_page=50)
            )
@@ -285,8 +284,7 @@ def migrate_knowledge_vector_database():
    while True:
        try:
            datasets = (
-                db.session.query(Dataset)
-                .filter(Dataset.indexing_technique == "high_quality")
+                Dataset.query.filter(Dataset.indexing_technique == "high_quality")
                .order_by(Dataset.created_at.desc())
                .paginate(page=page, per_page=50)
            )
@@ -450,7 +448,8 @@ def convert_to_agent_apps():
            if app_id not in proceeded_app_ids:
                proceeded_app_ids.append(app_id)
                app = db.session.query(App).filter(App.id == app_id).first()
-                apps.append(app)
+                if app is not None:
+                    apps.append(app)

        if len(apps) == 0:
            break
@@ -621,6 +620,10 @@ def fix_app_site_missing():
        try:
            app = db.session.query(App).filter(App.id == app_id).first()
+            if not app:
+                print(f"App {app_id} not found")
+                continue
+
            tenant = app.tenant
            if tenant:
                accounts = tenant.get_accounts()
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 73f8a95989baaf..74cdf944865796 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -239,7 +239,6 @@ class HttpConfig(BaseSettings):
    )

    @computed_field
-    @property
    def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
        return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")

@@ -250,7 +249,6 @@ def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
    )

    @computed_field
-    @property
    def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
        return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")

@@ -715,27 +713,27 @@ class PositionConfig(BaseSettings):
        default="",
    )

-    @computed_field
+    @property
    def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
        return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]

-    @computed_field
+    @property
    def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
        return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}

-    @computed_field
+    @property
    def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
        return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}

-    @computed_field
+    @property
    def POSITION_TOOL_PINS_LIST(self) -> list[str]:
        return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]

-    @computed_field
+    @property
    def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
        return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}

-    @computed_field
+    @property
    def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
        return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 9265a48d9bc53c..f6a44eaa471e62 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -130,7 +130,6 @@ class DatabaseConfig(BaseSettings):
    )

    @computed_field
-    @property
    def SQLALCHEMY_DATABASE_URI(self) -> str:
        db_extras = (
            f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
        )
@@ -168,7 +167,6 @@ def SQLALCHEMY_DATABASE_URI(self) -> str:
    )

    @computed_field
-    @property
    def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
        return {
            "pool_size": self.SQLALCHEMY_POOL_SIZE,
@@ -206,7 +204,6 @@ class CeleryConfig(DatabaseConfig):
    )

    @computed_field
-    @property
    def CELERY_RESULT_BACKEND(self) -> str | None:
        return (
            "db+{}".format(self.SQLALCHEMY_DATABASE_URI)
@@ -214,7 +211,6 @@ def CELERY_RESULT_BACKEND(self) -> str | None:
            else self.CELERY_BROKER_URL
        )

-    @computed_field
    @property
    def BROKER_USE_SSL(self) -> bool:
        return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
diff --git a/api/configs/remote_settings_sources/apollo/client.py b/api/configs/remote_settings_sources/apollo/client.py
index d1f6781ed370dd..03c64ea00f0185 100644
--- a/api/configs/remote_settings_sources/apollo/client.py
+++ b/api/configs/remote_settings_sources/apollo/client.py
@@ -4,6 +4,7 @@
 import os
 import threading
 import time
+from collections.abc import Mapping
 from pathlib import Path

 from .python_3x import http_request, makedirs_wrapper
@@ -255,8 +256,8 @@ def _listener(self):
            logger.info("stopped, long_poll")

    # add the need for endorsement to the header
-    def _sign_headers(self, url):
-        headers = {}
+    def _sign_headers(self, url: str) -> Mapping[str, str]:
+        headers: dict[str, str] = {}
        if self.secret == "":
            return headers
        uri = url[len(self.config_url) : len(url)]
diff --git a/api/constants/model_template.py b/api/constants/model_template.py
index 7e1a196356c4e2..c26d8c018610d0 100644
--- a/api/constants/model_template.py
+++ b/api/constants/model_template.py
@@ -1,8 +1,9 @@
 import json
+from collections.abc import Mapping

 from models.model import AppMode

-default_app_templates = {
+default_app_templates: Mapping[AppMode, Mapping] = {
    # workflow default mode
    AppMode.WORKFLOW: {
        "app": {
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
index 79869916eda062..b1ebc444a51868 100644
--- a/api/controllers/common/fields.py
+++ b/api/controllers/common/fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields  # type: ignore

 parameters__system_parameters = {
    "image_file_size_limit": fields.Integer,
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index f46d5b6b138d59..cb6b0d097b1fc9 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -3,6 +3,25 @@
 from libs.external_api import ExternalApi

 from .app.app_import import AppImportApi, AppImportConfirmApi
+from .explore.audio import ChatAudioApi, ChatTextApi
+from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
+from .explore.conversation import (
+    ConversationApi,
+    ConversationListApi,
+    ConversationPinApi,
+    ConversationRenameApi,
+    ConversationUnPinApi,
+)
+from .explore.message import (
+    MessageFeedbackApi,
+    MessageListApi,
+    MessageMoreLikeThisApi,
+    MessageSuggestedQuestionApi,
+)
+from .explore.workflow import (
+    InstalledAppWorkflowRunApi,
+    InstalledAppWorkflowTaskStopApi,
+)
 from .files import FileApi, FilePreviewApi, FileSupportTypeApi
 from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi

@@ -66,15 +85,81 @@

 # Import explore controllers
 from .explore import (
-    audio,
-    completion,
-    conversation,
    installed_app,
-    message,
    parameter,
    recommended_app,
    saved_message,
-    workflow,
+)
+
+# Explore Audio
+api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio")
+api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text")
+
+# Explore Completion
+api.add_resource(
+    CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion"
+)
+api.add_resource(
+    CompletionStopApi,
+    "/installed-apps//completion-messages//stop",
+    endpoint="installed_app_stop_completion",
+)
+api.add_resource(
+    ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion"
+)
+api.add_resource(
+    ChatStopApi,
+    "/installed-apps//chat-messages//stop",
+    endpoint="installed_app_stop_chat_completion",
+)
+
+# Explore Conversation
+api.add_resource(
+    ConversationRenameApi,
+    "/installed-apps//conversations//name",
+    endpoint="installed_app_conversation_rename",
+)
+api.add_resource(
+    ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations"
+)
+api.add_resource(
+    ConversationApi,
+    "/installed-apps//conversations/",
+    endpoint="installed_app_conversation",
+)
+api.add_resource(
+    ConversationPinApi,
+    "/installed-apps//conversations//pin",
+    endpoint="installed_app_conversation_pin",
+)
+api.add_resource(
+    ConversationUnPinApi,
+    "/installed-apps//conversations//unpin",
+    endpoint="installed_app_conversation_unpin",
+)
+
+
+# Explore Message
+api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages")
+api.add_resource(
+    MessageFeedbackApi,
+    "/installed-apps//messages//feedbacks",
+    endpoint="installed_app_message_feedback",
+)
+api.add_resource(
+    MessageMoreLikeThisApi,
+    "/installed-apps//messages//more-like-this",
+    endpoint="installed_app_more_like_this",
+)
+api.add_resource(
+    MessageSuggestedQuestionApi,
+    "/installed-apps//messages//suggested-questions",
+    endpoint="installed_app_suggested_question",
+)
+# Explore Workflow
+api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run")
+api.add_resource(
+    InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop"
 )

 # Import tag controllers
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index 8c0bf8710d3964..52e0bb6c56bdc2 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -1,7 +1,7 @@
 from functools import wraps

 from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore
 from werkzeug.exceptions import NotFound, Unauthorized

 from configs import dify_config
diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py
index 953770868904d3..ca8ddc32094ac5 100644
--- a/api/controllers/console/apikey.py
+++ b/api/controllers/console/apikey.py
@@ -1,5 +1,7 @@
-import flask_restful
-from flask_login import current_user
+from typing import Any
+
+import flask_restful  # type: ignore
+from flask_login import current_user  # type: ignore
 from flask_restful import Resource, fields, marshal_with
 from werkzeug.exceptions import Forbidden

@@ -35,14 +37,15 @@ def _get_resource(resource_id, tenant_id, resource_model):
 class BaseApiKeyListResource(Resource):
    method_decorators = [account_initialization_required, login_required, setup_required]

-    resource_type = None
-    resource_model = None
-    resource_id_field = None
-    token_prefix = None
+    resource_type: str | None = None
+    resource_model: Any = None
+    resource_id_field: str | None = None
+    token_prefix: str | None = None
    max_keys = 10

    @marshal_with(api_key_list)
    def get(self, resource_id):
+        assert self.resource_id_field is not None, "resource_id_field must be set"
        resource_id = str(resource_id)
        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
        keys = (
@@ -54,6 +57,7 @@ def get(self, resource_id):

    @marshal_with(api_key_fields)
    def post(self, resource_id):
+        assert self.resource_id_field is not None, "resource_id_field must be set"
        resource_id = str(resource_id)
        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
        if not current_user.is_editor:
@@ -86,11 +90,12 @@ def post(self, resource_id):
 class BaseApiKeyResource(Resource):
    method_decorators = [account_initialization_required, login_required, setup_required]

-    resource_type = None
-    resource_model = None
-    resource_id_field = None
+    resource_type: str | None = None
+    resource_model: Any = None
+    resource_id_field: str | None = None

    def delete(self, resource_id, api_key_id):
+        assert self.resource_id_field is not None, "resource_id_field must be set"
        resource_id = str(resource_id)
        api_key_id = str(api_key_id)
        _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py
index c228743fa53591..8d0c5b84af5e37 100644
--- a/api/controllers/console/app/advanced_prompt_template.py
+++ b/api/controllers/console/app/advanced_prompt_template.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore

 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, setup_required
diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py
index d4334158945e16..920cae0d859354 100644
--- a/api/controllers/console/app/agent.py
+++ b/api/controllers/console/app/agent.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore

 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index fd05cbc19bf04f..24f1020c18ec37 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -1,6 +1,6 @@
 from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal, marshal_with, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, marshal, marshal_with, reqparse  # type: ignore
 from werkzeug.exceptions import Forbidden

 from controllers.console import api
@@ -110,7 +110,7 @@ def get(self, app_id):
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
-        keyword = request.args.get("keyword", default=None, type=str)
+        keyword = request.args.get("keyword", default="", type=str)

        app_id = str(app_id)
        annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index da72b704c71bd7..9cd56cef0b7039 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,8 +1,8 @@
 import uuid
 from typing import cast

-from flask_login import current_user
-from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, inputs, marshal, marshal_with, reqparse  # type: ignore
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden, abort
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
index 244dcd75de29bc..7e2888d71c79c8 100644
--- a/api/controllers/console/app/app_import.py
+++ b/api/controllers/console/app/app_import.py
@@ -1,7 +1,7 @@
 from typing import cast

-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 695b8890e30f5c..9d26af276d2fc3 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -1,7 +1,7 @@
 import logging

 from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore
 from werkzeug.exceptions import InternalServerError

 import services
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index 9896fcaab8ad36..dba41e5c47d24f 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -1,7 +1,7 @@
 import logging

-import flask_login
-from flask_restful import Resource, reqparse
+import flask_login  # type: ignore
+from flask_restful import Resource, reqparse  # type: ignore
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index a25004be4d16ae..8827f129d99317 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,9 +1,9 @@
 from datetime import UTC, datetime

-import pytz
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+import pytz  # pip install pytz
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy import func, or_
 from sqlalchemy.orm import joinedload
 from werkzeug.exceptions import Forbidden, NotFound
@@ -77,8 +77,9 @@ def get(self, app_model):
            query = query.where(Conversation.created_at < end_datetime_utc)

+        # FIXME, the type ignore in this file
        if args["annotation_status"] == "annotated":
-            query = query.options(joinedload(Conversation.message_annotations)).join(
+            query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
                MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
            )
        elif args["annotation_status"] == "not_annotated":
@@ -222,7 +223,7 @@ def get(self, app_model):
            query = query.where(Conversation.created_at <= end_datetime_utc)

        if args["annotation_status"] == "annotated":
-            query = query.options(joinedload(Conversation.message_annotations)).join(
+            query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
                MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
            )
        elif args["annotation_status"] == "not_annotated":
@@ -234,7 +235,7 @@ def get(self, app_model):

        if args["message_count_gte"] and args["message_count_gte"] >= 1:
            query = (
-                query.options(joinedload(Conversation.messages))
+                query.options(joinedload(Conversation.messages))  # type: ignore
                .join(Message, Message.conversation_id == Conversation.id)
                .group_by(Conversation.id)
                .having(func.count(Message.id) >= args["message_count_gte"])
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index d49f433ba1f575..c0a20b7160e719 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
 from sqlalchemy import select
 from sqlalchemy.orm import Session
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 9c3cbe4e3e049e..8518d34a8e5af2 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,7 +1,7 @@
 import os

-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, reqparse  # type: ignore

 from controllers.console import api
 from controllers.console.app.error import (
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index b7a4c31a156b80..b5828b6b4b08c4 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -1,8 +1,8 @@
 import logging

-from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, fields, marshal_with, reqparse  # type: ignore
+from flask_restful.inputs import int_range  # type: ignore
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

 from controllers.console import api
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index a46bc6a8a97606..8ecc8a9db5738d 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -1,8 +1,9 @@
 import json
+from typing import cast

 from flask import request
-from flask_login import current_user
-from flask_restful import Resource
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource  # type: ignore

 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
@@ -26,7 +27,9 @@ def post(self, app_model):
        """Modify app model config"""
        # validate config
        model_configuration = AppModelConfigService.validate_configuration(
-            tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode)
+            tenant_id=current_user.current_tenant_id,
+            config=cast(dict, request.json),
+            app_mode=AppMode.value_of(app_model.mode),
        )

        new_app_model_config = AppModelConfig(
@@ -38,9 +41,11 @@ def post(self, app_model):

        if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
            # get original app model config
-            original_app_model_config: AppModelConfig = (
+            original_app_model_config = (
                db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
            )
+            if original_app_model_config is None:
+                raise ValueError("Original app model config not found")
            agent_mode = original_app_model_config.agent_mode_dict
            # decrypt agent tool parameters if it's secret-input
            parameter_map = {}
diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py
index 3f10215e702ac1..dd25af8ebf9312 100644
--- a/api/controllers/console/app/ops_trace.py
+++ b/api/controllers/console/app/ops_trace.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore
 from werkzeug.exceptions import BadRequest

 from controllers.console import api
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 407f6898199bae..db29b95c4140ff 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,7 +1,7 @@
 from datetime import UTC, datetime

-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
 from werkzeug.exceptions import Forbidden, NotFound

 from constants.languages import supported_language
@@ -50,7 +50,7 @@ def post(self, app_model):
        if not current_user.is_editor:
            raise Forbidden()

-        site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404()
+        site = Site.query.filter(Site.app_id == app_model.id).one_or_404()

        for attr_name in [
            "title",
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index db5e2824095ca0..3b21108ceaf76b 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -3,8 +3,8 @@

 import pytz
 from flask import jsonify
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, reqparse  # type: ignore

 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index f228c3ec4a0e07..26a3a022d401a4 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -2,7 +2,7 @@
 import logging

 from flask import abort, request
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

 import services
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index 2940556f84ef4e..882c53e4fb9972 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -1,5 +1,5 @@
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+from flask_restful.inputs import int_range  # type: ignore

 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 08ab61bbb9c97e..25a99c1e1594ae 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,5 +1,5 @@
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+from flask_restful.inputs import int_range  # type: ignore

 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py
index 6c7c73707bb204..097bf7d1888cf5 100644
--- a/api/controllers/console/app/workflow_statistic.py
+++ b/api/controllers/console/app/workflow_statistic.py
@@ -3,8 +3,8 @@

 import pytz
 from flask import jsonify
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, reqparse  # type: ignore

 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 63edb83079041e..9ad8c158473df9 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -8,7 +8,7 @@
 from models import App, AppMode


-def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
+def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index d2aa7c903b046c..c56f551d49be8b 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,14 +1,14 @@
 import datetime

 from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore

 from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
 from libs.helper import StrLen, email, extract_remote_ip, timezone
-from models.account import AccountStatus, Tenant
+from models.account import AccountStatus
 from services.account_service import AccountService, RegisterService


@@ -27,7 +27,7 @@ def get(self):
        invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
        if invitation:
            data = invitation.get("data", {})
-            tenant: Tenant = invitation.get("tenant", None)
+            tenant = invitation.get("tenant", None)
            workspace_name = tenant.name if tenant else None
            workspace_id = tenant.id if tenant else None
            invitee_email = data.get("email") if data else None
diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py
index 465c44e9b6dc2f..ea00c2b8c2272c 100644
--- a/api/controllers/console/auth/data_source_bearer_auth.py
+++ b/api/controllers/console/auth/data_source_bearer_auth.py
@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, reqparse  # type: ignore
 from werkzeug.exceptions import Forbidden

 from controllers.console import api
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index faca67bb177f10..e911c9a5e5b5ea 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -2,8 +2,8 @@

 import requests
 from flask import current_app, redirect, request
-from flask_login import current_user
-from flask_restful import Resource
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource  # type: ignore
 from werkzeug.exceptions import Forbidden

 from configs import dify_config
@@ -17,8 +17,8 @@ def get_oauth_providers():
    with current_app.app_context():
        notion_oauth = NotionOAuth(
-            client_id=dify_config.NOTION_CLIENT_ID,
-            client_secret=dify_config.NOTION_CLIENT_SECRET,
+            client_id=dify_config.NOTION_CLIENT_ID or "",
+            client_secret=dify_config.NOTION_CLIENT_SECRET or "",
            redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
        )
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index fb32bb2b60286d..140b9e145fa9cd 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -2,7 +2,7 @@
 import secrets

 from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore

 from constants.languages import languages
 from controllers.console import api
@@ -122,8 +122,8 @@ def post(self):
        else:
            try:
                account = AccountService.create_account_and_tenant(
-                    email=reset_data.get("email"),
-                    name=reset_data.get("email"),
+                    email=reset_data.get("email", ""),
+                    name=reset_data.get("email", ""),
                    password=password_confirm,
                    interface_language=languages[0],
                )
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index f4463ce9cb3f30..78a80fc8d7e075 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,8 +1,8 @@
 from typing import cast

-import flask_login
+import flask_login  # type: ignore
 from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse  # type: ignore

 import services
 from constants.languages import languages
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index b9188aa0798ea2..333b24142727f0 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -4,7 +4,7 @@

 import requests
 from flask import current_app, redirect, request
-from flask_restful import Resource
+from flask_restful import Resource  # type: ignore
 from werkzeug.exceptions import Unauthorized

 from configs import dify_config
@@ -77,7 +77,8 @@ def get(self, provider: str):
            token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(token)
        except requests.exceptions.RequestException as e:
-            logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
+            error_text = e.response.text if e.response else str(e)
+            logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}")
            return {"error": "OAuth process failed"}, 400

        if invite_token and RegisterService.is_valid_invite_token(invite_token):
@@ -129,7 +130,7 @@ def get(self, provider: str):


 def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
-    account = Account.get_by_openid(provider, user_info.id)
+    account: Optional[Account] = Account.get_by_openid(provider, user_info.id)

    if not account:
        account = Account.query.filter_by(email=user_info.email).first()
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index 4b0c82ae6c90c2..fd7b7bd8cb3ddd 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user  # type: ignore
import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 278295ca39a696..d7c431b95080da 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -2,8 +2,8 @@
 import json
 from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
 from werkzeug.exceptions import NotFound
 from controllers.console import api
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 95d4013e3a8f27..f3c3736b25acc5 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,7 +1,7 @@
-import flask_restful
+import flask_restful # type: ignore
 from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
 from werkzeug.exceptions import Forbidden, NotFound
 import services
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index ad4768f51959ac..ca41e504be7eda 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,12 +1,13 @@
 import logging
 from argparse import ArgumentTypeError
 from datetime import UTC, datetime
+from typing import cast
 from flask import request
-from flask_login import current_user
-from flask_restful import Resource, fields, marshal, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
 from sqlalchemy import asc, desc
-from transformers.hf_argparser import string_to_bool
+from transformers.hf_argparser import string_to_bool # type: ignore
 from werkzeug.exceptions import Forbidden, NotFound
 import services
@@ -733,8 +734,7 @@ def put(self, dataset_id, document_id):
         if not isinstance(doc_metadata, dict):
             raise ValueError("doc_metadata must be a dictionary.")
-
-        metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
+        metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
         document.doc_metadata = {}
         if doc_type == "others":
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 6f7ef86d2c3fd3..2d5933ca23609a 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -3,8 +3,8 @@
 import pandas as pd
 from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal, reqparse # type: ignore
 from werkzeug.exceptions import Forbidden, NotFound
 import services
diff --git a/api/controllers/console/datasets/external.py 
b/api/controllers/console/datasets/external.py index bc6e3687c1c99d..48f360dcd179bc 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 495f511275b4b9..18b746f547287c 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 3b4c07686361d0..bd944602c147cb 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services.dataset_service diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index 9127c8af455f6c..da995537e74753 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 9690677f61b1c2..c7f9fec326945f 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -4,7 +4,6 @@ from werkzeug.exceptions import InternalServerError import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -67,7 +66,7 @@ def post(self, installed_app): class ChatTextApi(InstalledAppResource): def post(self, installed_app): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore app_model = installed_app.app try: @@ -118,9 +117,3 @@ def post(self, installed_app): except Exception as e: logging.exception("internal server error.") raise InternalServerError() - - -api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") -api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") -# api.add_resource(ChatTextApiWithMessageId, '/installed-apps//text-to-audio/message-id', -# endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 85c43f8101028e..3331ded70f6620 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,12 +1,11 @@ 
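For context on the recurring change above: flask_restful and flask_login ship without type stubs or a py.typed marker, so a mypy run that checks every module reports an import error for each of these imports; the inline `# type: ignore` suppresses exactly that one line while the rest of the module stays fully checked. A minimal sketch of the pattern (PingApi is a hypothetical resource, not part of this patch):

# Without the trailing comment, mypy reports something like:
#   error: Skipping analyzing "flask_restful": module is installed,
#   but missing library stubs or py.typed marker  [import-untyped]
from flask_restful import Resource  # type: ignore


class PingApi(Resource):
    # The handler body is still type-checked; only the untyped
    # import itself is exempted.
    def get(self) -> dict[str, str]:
        return {"result": "pong"}

A project-wide alternative is setting `ignore_missing_imports = True` for these packages in the mypy configuration, which avoids repeating the comment on every import.
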
import logging from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import reqparse +from flask_login import current_user # type: ignore +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -147,21 +146,3 @@ def post(self, installed_app, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 - - -api.add_resource( - CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" -) -api.add_resource( - CompletionStopApi, - "/installed-apps//completion-messages//stop", - endpoint="installed_app_stop_completion", -) -api.add_resource( - ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" -) -api.add_resource( - ChatStopApi, - "/installed-apps//chat-messages//stop", - endpoint="installed_app_stop_chat_completion", -) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 5e7a3da017edd7..91916cbc1ed85f 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,10 +1,9 @@ -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom @@ -118,28 +117,3 @@ def patch(self, installed_app, c_id): WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} - - -api.add_resource( - ConversationRenameApi, - "/installed-apps//conversations//name", - endpoint="installed_app_conversation_rename", -) -api.add_resource( - ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" -) -api.add_resource( - ConversationApi, - "/installed-apps//conversations/", - endpoint="installed_app_conversation", -) -api.add_resource( - ConversationPinApi, - "/installed-apps//conversations//pin", - endpoint="installed_app_conversation_pin", -) -api.add_resource( - ConversationUnPinApi, - "/installed-apps//conversations//unpin", - endpoint="installed_app_conversation_unpin", -) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3de179164de91d..86550b2bdf44b9 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,8 +1,9 @@ from datetime import UTC, datetime +from typing import Any from flask import request -from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -34,7 +35,7 @@ def get(self): installed_apps = 
db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) - installed_apps = [ + installed_app_list: list[dict[str, Any]] = [ { "id": installed_app.id, "app": installed_app.app, @@ -47,7 +48,7 @@ def get(self): for installed_app in installed_apps if installed_app.app is not None ] - installed_apps.sort( + installed_app_list.sort( key=lambda app: ( -app["is_pinned"], app["last_used_at"] is None, @@ -55,7 +56,7 @@ def get(self): ) ) - return {"installed_apps": installed_apps} + return {"installed_apps": installed_app_list} @login_required @account_initialization_required diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 4e11d8005f138b..c3488de29929c9 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,12 +1,11 @@ import logging -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -153,21 +152,3 @@ def get(self, installed_app, message_id): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") -api.add_resource( - MessageFeedbackApi, - "/installed-apps//messages//feedbacks", - endpoint="installed_app_message_feedback", -) -api.add_resource( - MessageMoreLikeThisApi, - "/installed-apps//messages//more-like-this", - endpoint="installed_app_more_like_this", -) -api.add_resource( - MessageSuggestedQuestionApi, - "/installed-apps//messages//suggested-questions", - endpoint="installed_app_suggested_question", -) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index fee52248a698e0..5bc74d16e784af 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index ce85f495aacd50..be6b1f5d215fb4 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from constants.languages import languages from controllers.console import api diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 0fc963747981e1..9f0c4966457186 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,6 +1,6 @@ -from flask_login import current_user -from flask_restful 
import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 45f99b1db9fa9e..76d30299cd84a7 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,9 +1,8 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError -from controllers.console import api from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -73,9 +72,3 @@ def post(self, installed_app: InstalledApp, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"} - - -api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") -api.add_resource( - InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" -) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 49ea81a8a0f86d..b7ba81fba20f79 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,7 +1,7 @@ from functools import wraps -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound from controllers.console.wraps import account_initialization_required diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 4ac0aa497e0866..ed6cedb220cf4b 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from constants import HIDDEN_VALUE from controllers.console import api diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 70ab4ff865cb48..da1171412fdb2d 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from libs.login import login_required from services.feature_service import FeatureService diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index ca32d29efaa474..8cf754bbd686fd 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,6 +1,8 @@ +from typing import Literal + from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with # type: ignore from werkzeug.exceptions import Forbidden import services @@ -48,7 +50,8 @@ def get(self): @cloud_edition_billing_resource_check("documents") def post(self): file = request.files["file"] - source = 
request.form.get("source") + source_str = request.form.get("source") + source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None if "file" not in request.files: raise NoFileUploadedError() diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index ae759bb752a30e..d9ae5cf29fc626 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index cd28cc946ee288..2a116112a3227c 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index b8cf019e4f068d..30afc930a8e980 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,8 +2,8 @@ from typing import cast import httpx -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore import services from controllers.common import helpers diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e0b728d97739d3..aba6f0aad9ee54 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index ccd3293a6266fc..da83f64019161b 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -23,7 +23,7 @@ class TagListApi(Resource): @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get("type", type=str) + tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 7dea8e554edd7a..7773c99944e42c 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,7 +2,7 @@ import logging import requests -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from packaging import version from configs import dify_config diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index f704783cfff56b..96ed4b7a570256 100644 --- 
a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -2,8 +2,8 @@
 import pytz
 from flask import request
-from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
 from configs import dify_config
 from constants.languages import supported_language
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index d2b2092b75a9ff..7009343d9923da 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
 from werkzeug.exceptions import Forbidden
 from controllers.console import api
@@ -37,7 +37,7 @@ def post(self, provider: str):
         model_load_balancing_service = ModelLoadBalancingService()
         result = True
-        error = None
+        error = ""
         try:
             model_load_balancing_service.validate_load_balancing_credentials(
@@ -86,7 +86,7 @@ def post(self, provider: str, config_id: str):
         model_load_balancing_service = ModelLoadBalancingService()
         result = True
-        error = None
+        error = ""
         try:
             model_load_balancing_service.validate_load_balancing_credentials(
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 38ed2316a58935..1afb41ea87660c 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,7 +1,7 @@
 from urllib import parse
-from flask_login import current_user
-from flask_restful import Resource, abort, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, abort, marshal_with, reqparse # type: ignore
 import services
 from configs import dify_config
@@ -89,19 +89,19 @@ class MemberCancelInviteApi(Resource):
     @account_initialization_required
     def delete(self, member_id):
         member = db.session.query(Account).filter(Account.id == str(member_id)).first()
-        if not member:
+        if member is None:
             abort(404)
-
-        try:
-            TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
-        except services.errors.account.CannotOperateSelfError as e:
-            return {"code": "cannot-operate-self", "message": str(e)}, 400
-        except services.errors.account.NoPermissionError as e:
-            return {"code": "forbidden", "message": str(e)}, 403
-        except services.errors.account.MemberNotInTenantError as e:
-            return {"code": "member-not-found", "message": str(e)}, 404
-        except Exception as e:
-            raise ValueError(str(e))
+        else:
+            try:
+                TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
+            except services.errors.account.CannotOperateSelfError as e:
+                return {"code": "cannot-operate-self", "message": str(e)}, 400
+            except services.errors.account.NoPermissionError as e:
+                return {"code": "forbidden", "message": str(e)}, 403
+            except services.errors.account.MemberNotInTenantError as e:
+                return {"code": "member-not-found", "message": str(e)}, 404
+            except Exception as e:
+                raise ValueError(str(e))
         return {"result": "success"}, 204
@@ -122,10 +122,11 @@ def put(self, member_id):
             return {"code": "invalid-role", "message": "Invalid role"}, 400
         member = db.session.get(Account, str(member_id))
-        if not member:
+        if member is None:
             abort(404)
         try: 
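The `assert member is not None` added on the next line looks redundant next to the 404 guard, but it is what lets mypy accept the code: `db.session.get()` returns `Optional[Account]`, and because flask_restful's `abort` is untyped rather than declared `NoReturn`, mypy assumes execution can continue past it and keeps the `Optional` type alive. A minimal sketch of the problem; the `find_member` and `update_role` helpers are hypothetical stand-ins, not part of this patch:

from typing import Optional


class Account:
    role: str = "normal"


def abort(status_code: int) -> None:
    # Stand-in for flask_restful.abort: it always raises at runtime,
    # but its signature tells mypy it simply returns None.
    raise RuntimeError(status_code)


def find_member(member_id: str) -> Optional[Account]:
    return Account() if member_id == "1" else None


def update_role(member_id: str, new_role: str) -> str:
    member = find_member(member_id)
    if member is None:
        abort(404)
    # mypy still treats member as Optional[Account] here, because abort()
    # is not marked NoReturn; the assert performs the narrowing to Account.
    assert member is not None, "Member not found"
    member.role = new_role
    return member.role

Annotating abort with `typing.NoReturn` would let mypy narrow the type automatically; the assert is the lighter-touch fix when the untyped dependency cannot be changed.
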
+ assert member is not None, "Member not found" TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) except Exception as e: raise ValueError(str(e)) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 0e54126063be75..2d11295b0fdf61 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,8 +1,8 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -66,7 +66,7 @@ def post(self, provider: str): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.provider_credentials_validate( @@ -132,7 +132,8 @@ def get(self, provider: str, icon_type: str, lang: str): icon_type=icon_type, lang=lang, ) - + if icon is None: + raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}") return send_file(io.BytesIO(icon), mimetype=mimetype) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index f804285f008120..618262e502ab33 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -308,7 +308,7 @@ def post(self, provider: str): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.model_credentials_validate( diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 9e62a546997b12..964f3862291a2e 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,8 +1,8 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 76d76f6b58fc3c..0f99bf62e3c251 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,8 @@ import logging from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Unauthorized import services @@ -82,11 +82,7 @@ def get(self): parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = ( - db.session.query(Tenant) - .order_by(Tenant.created_at.desc()) - 
.paginate(page=args["page"], per_page=args["limit"]) - ) + tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"]) has_more = False if len(tenants.items) == args["limit"]: @@ -151,6 +147,8 @@ def post(self): raise AccountNotLinkTenantError("Account not link tenant") new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + if new_tenant is None: + raise ValueError("Tenant not found") return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} @@ -166,7 +164,7 @@ def post(self): parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() - tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() + tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { "remove_webapp_brand": args["remove_webapp_brand"], diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index d0df296c240686..111db7ccf2da04 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -3,7 +3,7 @@ from functools import wraps from flask import abort, request -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError @@ -121,8 +121,8 @@ def decorated(*args, **kwargs): utm_info = request.cookies.get("utm_info") if utm_info: - utm_info = json.loads(utm_info) - OperationService.record_utm(current_user.current_tenant_id, utm_info) + utm_info_dict: dict = json.loads(utm_info) + OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) except Exception as e: pass return view(*args, **kwargs) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 6b3ac93cdf3d8f..2357288a50ae36 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,5 +1,5 @@ from flask import Response, request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import NotFound import services diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index a298701a2f8b11..cfcce8124761f5 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,5 +1,5 @@ from flask import Response -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from controllers.files import api diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 99d32af593991f..d7346b13b10a90 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 51ffe683ff40ad..d4587235f6aef8 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -45,14 +45,14 @@ def decorated(*args, **kwargs): if " " in user_id: user_id = user_id.split(" ")[1] - inner_api_key = 
request.headers.get("X-Inner-Api-Key") + inner_api_key = request.headers.get("X-Inner-Api-Key", "") data_to_sign = f"DIFY {user_id}" signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) - signature = b64encode(signature.digest()).decode("utf-8") + signature_base64 = b64encode(signature.digest()).decode("utf-8") - if signature != token: + if signature_base64 != token: return view(*args, **kwargs) kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecff7d07e974d9..8388e2045dd34f 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 5db41636471220..e6bcc0bfd25355 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services @@ -83,7 +83,7 @@ def post(self, app_model: App, end_user: EndUser): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 8d8e356c4cb940..1be54b386bfe8c 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 32940cbc29f355..334f2c56206794 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65ef97df..27b21b9f505633 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 599401bc6f1821..522c7509b9849d 100644 --- a/api/controllers/service_api/app/message.py +++ 
b/api/controllers/service_api/app/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 96d1337632826a..c7dd4de3452970 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError from controllers.service_api import api diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 799fccc228e21d..d6a3beb6b80b9d 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound import services.dataset_service diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 5c3fc7b241175a..34afe2837f4ca5 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,7 +1,7 @@ import json from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from sqlalchemy import desc from werkzeug.exceptions import NotFound diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index e68f6b4dc40a36..34904574a8b88d 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound from controllers.service_api import api diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index d24c4597e210fb..75d9141a6d0a3a 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from configs import dify_config from controllers.service_api import api diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 2128c4c53f9909..740b92ef8e4faf 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -5,8 +5,8 @@ from typing import Optional from flask import current_app, request -from flask_login import user_logged_in -from flask_restful import Resource +from flask_login import user_logged_in # type: ignore +from flask_restful import Resource # type: ignore from pydantic import BaseModel from werkzeug.exceptions import Forbidden, Unauthorized @@ -49,6 +49,8 @@ def 
decorated_view(*args, **kwargs): raise Forbidden("The app's API service has been disabled.") tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + if tenant is None: + raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") @@ -154,8 +156,8 @@ def decorated(*args, **kwargs): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index cc8255ccf4e748..20e071c834ad5b 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index e8521307ad357a..97d980d07c13a7 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -65,7 +65,7 @@ def post(self, app_model: App, end_user): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore try: parser = reqparse.RequestParser() @@ -82,7 +82,7 @@ def post(self, app_model: App, end_user): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 45b890dfc4899d..761771a81a4bb3 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index fe0d7c74f32cff..28feb1ca47effd 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,5 +1,5 @@ -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 0563ed22382e6b..ce841a8814972d 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.web import api from services.feature_service import FeatureService diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index a282fc63a8b056..1d4474015ab648 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,5 @@ from flask import 
request -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError @@ -33,7 +33,7 @@ def post(self, app_model, end_user): content=file.read(), mimetype=file.mimetype, user=end_user, - source=source, + source="datasets" if source == "datasets" else None, ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index febaab5328e8b3..0f47e643708570 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index a01ffd861230a5..4625c1f43dfbd1 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,7 +1,7 @@ import uuid from flask import request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index ae68df6bdc4e48..d559ab8e07e736 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,7 @@ import urllib.parse import httpx -from flask_restful import marshal_with, reqparse +from flask_restful import marshal_with, reqparse # type: ignore import services from controllers.common import helpers diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index b0492e6b6f0d31..6a9b8189076c3c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,5 @@ -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.web import api diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0564b15ea39855..e68dc7aa4afba5 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from flask_restful import fields, marshal_with +from flask_restful import fields, marshal_with # type: ignore from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 55b0c3e2ab34c5..48d25e720c10c3 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError from controllers.web import api diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index c327c3df18526c..1b4d263bee4401 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource +from flask_restful import Resource # 
type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebSSOAuthRequiredError diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index ead293200ea3aa..8d69bdcec2c2ac 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,7 +1,6 @@ import json import logging import uuid -from collections.abc import Mapping, Sequence from datetime import UTC, datetime from typing import Optional, Union, cast @@ -53,6 +52,7 @@ class BaseAgentRunner(AppRunner): def __init__( self, + *, tenant_id: str, application_generate_entity: AgentChatAppGenerateEntity, conversation: Conversation, @@ -66,7 +66,7 @@ def __init__( prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance | None = None, + model_instance: ModelInstance, ) -> None: self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -117,7 +117,7 @@ def __init__( features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] - self.query = None + self.query: Optional[str] = "" self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity( @@ -145,7 +145,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P message_tool = PromptMessageTool( name=tool.tool_name, - description=tool_entity.description.llm, + description=tool_entity.description.llm if tool_entity.description else "", parameters={ "type": "object", "properties": {}, @@ -167,7 +167,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -187,8 +187,8 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe convert dataset retriever tool to prompt message tool """ prompt_tool = PromptMessageTool( - name=tool.identity.name, - description=tool.description.llm, + name=tool.identity.name if tool.identity else "unknown", + description=tool.description.llm if tool.description else "", parameters={ "type": "object", "properties": {}, @@ -210,14 +210,14 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe return prompt_tool - def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: + def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: """ Init tools """ tool_instances = {} prompt_messages_tools = [] - for tool in self.app_config.agent.tools if self.app_config.agent else []: + for tool in self.app_config.agent.tools or [] if self.app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -234,7 +234,8 @@ def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessage # save prompt tool prompt_messages_tools.append(prompt_tool) # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + if dataset_tool.identity is not 
None: + tool_instances[dataset_tool.identity.name] = dataset_tool return tool_instances, prompt_messages_tools @@ -258,7 +259,7 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -322,16 +323,21 @@ def save_agent_thought( tool_name: str, tool_input: Union[str, dict], thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], + observation: Union[str, dict, None], + tool_invoke_meta: Union[str, dict, None], answer: str, messages_ids: list[str], - llm_usage: LLMUsage = None, - ) -> MessageAgentThought: + llm_usage: LLMUsage | None = None, + ): """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + queried_thought = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + ) + if not queried_thought: + raise ValueError(f"Agent thought {agent_thought.id} not found") + agent_thought = queried_thought if thought is not None: agent_thought.thought = thought @@ -404,7 +410,7 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab """ convert tool variables to db variables """ - db_variables = ( + queried_variables = ( db.session.query(ToolConversationVariables) .filter( ToolConversationVariables.conversation_id == self.message.conversation_id, @@ -412,6 +418,11 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab .first() ) + if not queried_variables: + return + + db_variables = queried_variables + db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db.session.commit() @@ -421,7 +432,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P """ Organize agent history """ - result = [] + result: list[PromptMessage] = [] # check if there is a system message in the beginning of the conversation for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index d98ba5a3fad846..e936acb6055cb8 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod -from collections.abc import Generator -from typing import Optional, Union +from collections.abc import Generator, Mapping +from typing import Any, Optional from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageTool, ToolPromptMessage, UserPromptMessage, ) @@ -26,18 +27,18 @@ class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ["wenxin"] - _historic_prompt_messages: list[PromptMessage] = None - _agent_scratchpad: list[AgentScratchpadUnit] = None - _instruction: str = None - _query: str = None - _prompt_messages_tools: list[PromptMessage] = None + _historic_prompt_messages: list[PromptMessage] | None = None + _agent_scratchpad: list[AgentScratchpadUnit] | 
None = None + _instruction: str = "" # FIXME this must be str for now + _query: str | None = None + _prompt_messages_tools: list[PromptMessageTool] = [] def run( self, message: Message, query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + inputs: Mapping[str, str], + ) -> Generator: """ Run Cot agent application """ @@ -57,19 +58,19 @@ def run( # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs) iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 # convert tools into ModelRuntime Tool format tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -90,7 +91,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # the last iteration, remove all tools self._prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids @@ -105,7 +106,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model - chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( + chunks = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=app_generate_entity.model_conf.parameters, tools=[], @@ -115,11 +116,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): callbacks=[], ) + if not isinstance(chunks, Generator): + raise ValueError("Expected streaming response from LLM") + # check llm result if not chunks: raise ValueError("failed to invoke llm") - usage_dict = {} + usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response="", @@ -139,25 +143,30 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if isinstance(chunk, AgentScratchpadUnit.Action): action = chunk # detect action - scratchpad.agent_response += json.dumps(chunk.model_dump()) + if scratchpad.agent_response is not None: + scratchpad.agent_response += json.dumps(chunk.model_dump()) scratchpad.action_str = json.dumps(chunk.model_dump()) scratchpad.action = action else: - scratchpad.agent_response += chunk - scratchpad.thought += chunk + if scratchpad.agent_response is not None: + scratchpad.agent_response += chunk + if scratchpad.thought is not None: + scratchpad.thought += chunk yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, system_fingerprint="", delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - - 
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" - self._agent_scratchpad.append(scratchpad) + if scratchpad.thought is not None: + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" + if self._agent_scratchpad is not None: + self._agent_scratchpad.append(scratchpad) # get llm usage if "usage" in usage_dict: - increase_usage(llm_usage, usage_dict["usage"]) + if usage_dict["usage"] is not None: + increase_usage(llm_usage, usage_dict["usage"]) else: usage_dict["usage"] = LLMUsage.empty_usage() @@ -166,9 +175,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_name=scratchpad.action.action_name if scratchpad.action else "", tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation="", - answer=scratchpad.agent_response, + answer=scratchpad.agent_response or "", messages_ids=[], llm_usage=usage_dict["usage"], ) @@ -209,7 +218,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): agent_thought=agent_thought, tool_name=scratchpad.action.action_name, tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation={scratchpad.action.action_name: tool_invoke_response}, tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, @@ -247,8 +256,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): answer=final_answer, messages_ids=[], ) - - self.update_db_variables(self.variables_pool, self.db_variables_pool) + if self.variables_pool is not None and self.db_variables_pool is not None: + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -307,8 +316,9 @@ def _handle_invoke_action( # publish files for message_file_id, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) + if save_as is not None and self.variables_pool: + # FIXME the save_as type is confusing, it should be a string or not + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as)) # publish message file self.queue_manager.publish( @@ -325,7 +335,7 @@ def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: """ return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) - def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: """ fill in inputs from external data tools """ @@ -376,11 +386,13 @@ def _organize_historic_prompt_messages( """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] - current_scratchpad: AgentScratchpadUnit = None + current_scratchpad: AgentScratchpadUnit | None = None for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): if not current_scratchpad: + if not isinstance(message.content, str | None): + raise NotImplementedError("expected str type") current_scratchpad = AgentScratchpadUnit( agent_response=message.content, thought=message.content or "I am thinking about how to help you", @@ 
-399,8 +411,12 @@ def _organize_historic_prompt_messages( except: pass elif isinstance(message, ToolPromptMessage): - if current_scratchpad: + if not current_scratchpad: + continue + if isinstance(message.content, str): current_scratchpad.observation = message.content + else: + raise NotImplementedError("expected str type") elif isinstance(message, UserPromptMessage): if scratchpads: result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index d8d047fe91cdbd..6a96c349b2611c 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -19,7 +19,12 @@ def _organize_system_prompt(self) -> SystemPromptMessage: """ Organize system prompt """ + if not self.app_config.agent: + raise ValueError("Agent configuration is not set") + prompt_entity = self.app_config.agent.prompt + if not prompt_entity: + raise ValueError("Agent prompt configuration is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -75,6 +80,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: assistant_messages = [] else: assistant_message = AssistantPromptMessage(content="") + assistant_message.content = "" # FIXME: tell mypy that assistant_message.content is a str for unit in agent_scratchpad: if unit.is_final(): assistant_message.content += f"Final Answer: {unit.agent_response}" diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 0563090537e62c..3a4d31e047f5ae 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -2,7 +2,12 @@ from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.utils.encoders import jsonable_encoder @@ -11,7 +16,11 @@ def _organize_instruction_prompt(self) -> str: """ Organize instruction prompt """ + if self.app_config.agent is None: + raise ValueError("Agent configuration is not set") prompt_entity = self.app_config.agent.prompt + if prompt_entity is None: + raise ValueError("Prompt entity is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -33,7 +42,13 @@ def _organize_historic_prompt(self, current_session_messages: Optional[list[Prom if isinstance(message, UserPromptMessage): historic_prompt += f"Question: {message.content}\n\n" elif isinstance(message, AssistantPromptMessage): - historic_prompt += message.content + "\n\n" + if isinstance(message.content, str): + historic_prompt += message.content + "\n\n" + elif isinstance(message.content, list): + for content in message.content: + if not isinstance(content, TextPromptMessageContent): + continue + historic_prompt += content.data return historic_prompt @@ -50,7 +65,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: # organize current assistant messages agent_scratchpad = self._agent_scratchpad assistant_prompt = "" - for unit in agent_scratchpad: + for unit in agent_scratchpad or []: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" else: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 119a88fc7becbf..2ae87dca3f8cbd
100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -78,5 +78,5 @@ class Strategy(Enum): model: str strategy: Strategy prompt: Optional[AgentPromptEntity] = None - tools: list[AgentToolEntity] = None + tools: list[AgentToolEntity] | None = None max_iteration: int = 5 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index cd546dee124147..b862c96072aaa0 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -40,6 +40,8 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul app_generate_entity = self.application_generate_entity app_config = self.app_config + assert app_config is not None, "app_config is required" + assert app_config.agent is not None, "app_config.agent is required" # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() @@ -49,7 +51,7 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul # continue to run until there is not any tool call function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()} final_answer = "" # get tracing instance @@ -75,7 +77,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # the last iteration, remove all tools prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) @@ -105,7 +107,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): current_llm_usage = None - if self.stream_tool_call: + if self.stream_tool_call and isinstance(chunks, Generator): is_first_chunk = True for chunk in chunks: if is_first_chunk: @@ -116,7 +118,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk)) + tool_calls.extend(self.extract_tool_calls(chunk) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( @@ -131,19 +133,19 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): for content in chunk.delta.message.content: response += content.data else: - response += chunk.delta.message.content + response += str(chunk.delta.message.content) if chunk.delta.usage: increase_usage(llm_usage, chunk.delta.usage) current_llm_usage = chunk.delta.usage yield chunk - else: - result: LLMResult = chunks + elif not self.stream_tool_call and isinstance(chunks, LLMResult): + result = chunks # check if there is any tool call if self.check_blocking_tool_calls(result): function_call_state = True - tool_calls.extend(self.extract_blocking_tool_calls(result)) + tool_calls.extend(self.extract_blocking_tool_calls(result) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( @@ -162,7 +164,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): for content in result.message.content: response += content.data else: - response += result.message.content + response += str(result.message.content) if not result.message.content: result.message.content = "" @@ -181,6 +183,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): 
usage=result.usage, ), ) + else: + raise RuntimeError(f"invalid chunks type: {type(chunks)}") assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: @@ -243,7 +247,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) + if self.variables_pool: + self.variables_pool.set_file( + tool_name=tool_call_name, value=message_file_id, name=save_as + ) # publish message file self.queue_manager.publish( @@ -263,7 +270,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response["tool_response"], + content=str(tool_response["tool_response"]), tool_call_id=tool_call_id, name=tool_call_name, ) @@ -273,9 +280,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name=None, - tool_input=None, - thought=None, + tool_name="", + tool_input="", + thought="", tool_invoke_meta={ tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, @@ -283,7 +290,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, - answer=None, + answer="", messages_ids=message_file_ids, ) self.queue_manager.publish( @@ -296,7 +303,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): iteration_step += 1 - self.update_db_variables(self.variables_pool, self.db_variables_pool) + if self.variables_pool and self.db_variables_pool: + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -389,9 +397,9 @@ def _init_system_message( if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - return prompt_messages + return prompt_messages or [] - def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ @@ -449,7 +457,7 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage] def _organize_prompt_messages(self): prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) - query_prompt_messages = self._organize_user_query(self.query, []) + query_prompt_messages = self._organize_user_query(self.query or "", []) self.history_prompt_messages = AgentHistoryPromptTransform( model_config=self.model_config, diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 085bac8601b2da..61fa774ea5f390 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -38,7 +38,7 @@ def parse_action(json_str): except: return json_str or "" - def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: + def 
extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return @@ -67,15 +67,15 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, for response in llm_response: if response.delta.usage: usage_dict["usage"] = response.delta.usage - response = response.delta.message.content - if not isinstance(response, str): + response_content = response.delta.message.content + if not isinstance(response_content, str): continue # stream index = 0 - while index < len(response): + while index < len(response_content): steps = 1 - delta = response[index : index + steps] + delta = response_content[index : index + steps] yield_delta = False if delta == "`": diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index b9aae7904f5e7c..646c4badb9f73a 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -66,6 +66,8 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: dataset_configs = config.get("dataset_configs") else: dataset_configs = {"retrieval_model": "multiple"} + if dataset_configs is None: + return None query_variable = config.get("dataset_query_variable") if dataset_configs["retrieval_model"] == "single": diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 5adcf26f1486e8..6426865115126f 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -94,7 +94,7 @@ def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> config["model"]["completion_params"] ) - return config, ["model"] + return dict(config), ["model"] @classmethod def validate_model_completion_params(cls, cp: dict) -> dict: diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index b4dacbc409044a..92b4185abf0183 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -7,10 +7,10 @@ def convert(cls, config: dict) -> tuple[str, list]: :param config: model config args """ # opening statement - opening_statement = config.get("opening_statement") + opening_statement = config.get("opening_statement", "") # suggested questions - suggested_questions_list = config.get("suggested_questions") + suggested_questions_list = config.get("suggested_questions", []) return opening_statement, suggested_questions_list diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 6200299d21c869..a18b40712b7ce6 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,6 +29,7 @@ from models.account import Account from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -145,7 +146,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), 
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -313,6 +314,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message does not exist") # chatbot app runner = AdvancedChatAppRunner( diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 29709914b7cfb8..a506447671abfb 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -5,6 +5,7 @@ import re import threading from collections.abc import Iterable +from typing import Optional from core.app.entities.queue_entities import ( MessageQueueMessage, @@ -15,6 +16,7 @@ WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import TextPromptMessageContent from core.model_runtime.entities.model_entities import ModelType @@ -71,8 +73,9 @@ def __init__(self, tenant_id: str, voice: str): if not voice or voice not in values: self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 - self._last_audio_event = None - self._runtime_thread = threading.Thread(target=self._runtime).start() + self._last_audio_event: Optional[AudioTrunk] = None + # FIXME: find a better way to start this thread + threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): @@ -92,10 +95,21 @@ def _runtime(self): future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): - self.msg_text += message.event.chunk.delta.message.content + message_content = message.event.chunk.delta.message.content + if not message_content: + continue + if isinstance(message_content, str): + self.msg_text += message_content + elif isinstance(message_content, list): + for content in message_content: + if not isinstance(content, TextPromptMessageContent): + continue + self.msg_text += content.data elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): + if message.event.outputs is None: + continue self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) @@ -121,11 +135,10 @@ def check_and_get_audio(self): if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: self.executor.shutdown(wait=False) - return self.last_message + return self._last_audio_event audio = self._audio_queue.get_nowait() if audio and audio.status == "finish": self.executor.shutdown(wait=False) - self._runtime_thread = None if audio: self._last_audio_event = audio return audio diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index cf0c9d7593429a..6339d798984800 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -109,18 +109,18 @@ def run(self) -> None: ConversationVariable.conversation_id == self.conversation.id, ) with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if
not conversation_variables: + db_conversation_variables = session.scalars(stmt).all() + if not db_conversation_variables: # Create conversation variables if they don't exist. - conversation_variables = [ + db_conversation_variables = [ ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) for variable in workflow.conversation_variables ] - session.add_all(conversation_variables) + session.add_all(db_conversation_variables) # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] + conversation_variables = [item.to_variable() for item in db_conversation_variables] session.commit() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 635e482ad980ed..1073a0f2e4f706 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,6 +2,7 @@ import logging import time from collections.abc import Generator, Mapping +from threading import Thread from typing import Any, Optional, Union from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -64,6 +65,7 @@ from models.workflow import ( Workflow, WorkflowNodeExecution, + WorkflowRun, WorkflowRunStatus, ) @@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] + _conversation_name_generate_thread: Optional[Thread] = None def __init__( self, @@ -131,7 +134,7 @@ def __init__( self._conversation_name_generate_thread = None self._recorded_files: list[Mapping[str, Any]] = [] - def process(self): + def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. 
:return: @@ -262,8 +265,8 @@ def _process_stream_response( :return: """ # init fake graph runtime state - graph_runtime_state = None - workflow_run = None + graph_runtime_state: Optional[GraphRuntimeState] = None + workflow_run: Optional[WorkflowRun] = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -315,14 +318,14 @@ def _process_stream_response( workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - response = self._workflow_node_start_to_stream_response( + response_start = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if response_start: + yield response_start elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) @@ -330,18 +333,18 @@ def _process_stream_response( if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - response = self._workflow_node_finish_to_stream_response( + response_finish = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if response_finish: + yield response_finish elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - response = self._workflow_node_finish_to_stream_response( + response_finish = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -609,7 +612,10 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras + task_id=self._application_generate_entity.task_id, + id=self._message.id, + files=self._recorded_files, + metadata=extras.get("metadata", {}), ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 417d23eccfb553..55b6ee510f228c 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -61,7 +61,7 @@ def get_app_config( app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index b391169e3dbe5c..63e11bdaa27f74 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -23,6 +23,7 @@ from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -97,7 +98,7 @@ def generate( # get conversation conversation = None if args.get("conversation_id"): - conversation = 
self._get_conversation_by_user(app_model, args.get("conversation_id"), user) + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) @@ -153,7 +154,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -180,7 +181,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -199,8 +200,8 @@ def generate( user=user, stream=streaming, ) - - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # FIXME: type hinting issue here; ignore it for now and fix it later + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore def _generate_worker( self, @@ -224,6 +225,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message does not exist") # chatbot app runner = AgentChatAppRunner() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 45b1bf00934d35..ac71f02b6de03d 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -173,6 +173,8 @@ def run( return agent_entity = app_config.agent + if not agent_entity: + raise ValueError("Agent entity not found") # load tool variables tool_conversation_variables = self._load_tool_variables( @@ -200,14 +202,21 @@ def run( # change function call strategy based on LLM model llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema or not model_schema.features: + raise ValueError("Model schema not found") if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() - message = db.session.query(Message).filter(Message.id == message.id).first() + conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + if conversation_result is None: + raise ValueError("Conversation not found") + message_result = db.session.query(Message).filter(Message.id == message.id).first() + if message_result is None: + raise ValueError("Message not found") db.session.close() + runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: # check LLM mode @@ -225,12 +234,12 @@ def run( runner = runner_cls( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - conversation=conversation, + conversation=conversation_result,
app_config=app_config, model_config=application_generate_entity.model_conf, config=agent_entity, queue_manager=queue_manager, - message=message, + message=message_result, user_id=application_generate_entity.user_id, memory=memory, prompt_messages=prompt_message, @@ -257,7 +266,7 @@ def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: st """ load tool variables from database """ - tool_variables: ToolConversationVariables = ( + tool_variables: ToolConversationVariables | None = ( db.session.query(ToolConversationVariables) .filter( ToolConversationVariables.conversation_id == conversation_id, diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 629c309c065458..ce331d904cc826 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -51,8 +51,9 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR return response @classmethod - def convert_stream_full_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + def convert_stream_full_response( # type: ignore[override] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[str, None, None]: """ Convert stream full response. @@ -82,8 +83,9 @@ def convert_stream_full_response( yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + def convert_stream_simple_response( # type: ignore[override] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[str, None, None]: """ Convert stream simple response. 
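The converter hunks above and below all narrow an inherited classmethod's parameter type to the app-specific response class, which mypy reports as an unsafe override, so the patch silences each site with a targeted "# type: ignore[override]" rather than widening the signatures. A minimal standalone sketch of that diagnostic (the class and method names here are illustrative, not from the codebase):

class BaseConverter:
    @classmethod
    def convert_blocking_response(cls, blocking_response: object) -> dict:
        # Wide base signature: accepts any blocking response.
        return {}

class ChatConverter(BaseConverter):
    @classmethod
    def convert_blocking_response(cls, blocking_response: str) -> dict:  # type: ignore[override]
        # Narrowing object -> str breaks Liskov substitution, so mypy emits
        # an [override] error here unless it is explicitly ignored.
        return {"answer": blocking_response}

print(ChatConverter.convert_blocking_response("hi"))  # {'answer': 'hi'}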
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 3725c6e6ddc4fd..1842fc43033ab8 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -50,7 +50,7 @@ def listen(self): # wait for APP_MAX_EXECUTION_TIME seconds to stop listen listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() - last_ping_time = 0 + last_ping_time: int | float = 0 while True: try: message = self._q.get(timeout=1) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 609fd03f229da8..07a248d77aee86 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,5 +1,5 @@ import time -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity @@ -36,8 +36,8 @@ def get_pre_calculate_rest_tokens( app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, ) -> int: """ @@ -64,7 +64,7 @@ def get_pre_calculate_rest_tokens( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -85,7 +85,7 @@ def get_pre_calculate_rest_tokens( prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - rest_tokens = model_context_tokens - max_tokens - prompt_tokens + rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: raise InvokeBadRequestError( "Query or prefix prompt is too long, you can reduce the prefix prompt, " @@ -111,7 +111,7 @@ def recalc_llm_max_tokens( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -136,8 +136,8 @@ def organize_prompt_messages( app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None, @@ -156,6 +156,7 @@ def organize_prompt_messages( """ # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform] prompt_transform = SimplePromptTransform() prompt_messages, stop = prompt_transform.get_prompt( app_mode=AppMode.value_of(app_record.mode), @@ -171,8 +172,11 @@ def organize_prompt_messages( memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) + prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + if not advanced_completion_prompt_template: + raise InvokeBadRequestError("Advanced completion prompt template 
is required.") prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: @@ -181,6 +185,8 @@ def organize_prompt_messages( assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: + if not prompt_template_entity.advanced_chat_prompt_template: + raise InvokeBadRequestError("Advanced chat prompt template is required.") prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) @@ -246,7 +252,7 @@ def direct_output( def _handle_invoke_result( self, - invoke_result: Union[LLMResult, Generator], + invoke_result: Union[LLMResult, Generator[Any, None, None]], queue_manager: AppQueueManager, stream: bool, agent: bool = False, @@ -259,10 +265,12 @@ def _handle_invoke_result( :param agent: agent :return: """ - if not stream: + if not stream and isinstance(invoke_result, LLMResult): self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - else: + elif stream and isinstance(invoke_result, Generator): self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + else: + raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") def _handle_invoke_result_direct( self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool @@ -291,8 +299,8 @@ def _handle_invoke_result_stream( :param agent: agent :return: """ - model = None - prompt_messages = [] + model: str = "" + prompt_messages: list[PromptMessage] = [] text = "" usage = None for result in invoke_result: @@ -328,13 +336,14 @@ def _handle_invoke_result_stream( def moderation_for_inputs( self, + *, app_id: str, tenant_id: str, app_generate_entity: AppGenerateEntity, inputs: Mapping[str, Any], - query: str, + query: str | None = None, message_id: str, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. :param app_id: app id @@ -350,7 +359,7 @@ def moderation_for_inputs( app_id=app_id, tenant_id=tenant_id, app_config=app_generate_entity.app_config, - inputs=inputs, + inputs=dict(inputs), query=query or "", message_id=message_id, trace_manager=app_generate_entity.trace_manager, @@ -390,9 +399,9 @@ def fill_in_inputs_from_external_data_tools( tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. 
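Another pattern running through base_app_runner.py above is loosening mutable parameter annotations (dict, list) to the read-only ABCs Mapping and Sequence, converting back with dict(...) or list(...) only where a mutable value is genuinely needed. A hedged sketch of why this satisfies mypy (the function below is hypothetical, not part of the patch):

from collections.abc import Mapping, Sequence

def summarize(inputs: Mapping[str, str], files: Sequence[str]) -> str:
    # inputs["k"] = "v"   # mypy: unsupported target for indexed assignment
    # files.append("f")   # mypy: "Sequence[str]" has no attribute "append"
    return f"{len(inputs)} inputs, {len(files)} files"

print(summarize({"query": "hi"}, ["a.png"]))  # plain dict/list still accepted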
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 5b8debaaae6a56..6ed71fcd843083 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -24,6 +24,7 @@ from factories import file_factory from models.account import Account from models.model import App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -91,7 +92,7 @@ def generate( # get conversation conversation = None if args.get("conversation_id"): - conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) @@ -104,7 +105,7 @@ def generate( # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # always enable retriever resource in debugger mode @@ -146,7 +147,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, invoke_from=invoke_from, @@ -172,7 +173,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -216,6 +217,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message does not exist") # chatbot app runner = ChatAppRunner() diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 0fa7af0a7fa36d..9024c3a98273d1 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response.
:param blocking_response: blocking response @@ -52,7 +52,8 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR @classmethod def convert_stream_full_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -83,7 +84,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 1193c4b7a43632..02e5d475684cdc 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -42,7 +42,7 @@ def get_app_config( app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 14fd33dd398927..17d0d52497ceee 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -83,8 +83,6 @@ def generate( query = query.replace("\x00", "") inputs = args["inputs"] - extras = {} - # get conversation conversation = None @@ -99,7 +97,7 @@ def generate( # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # parse files @@ -132,11 +130,11 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=streaming, invoke_from=invoke_from, - extras=extras, + extras={}, trace_manager=trace_manager, ) @@ -157,7 +155,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, @@ -197,6 +195,8 @@ def _generate_worker( try: # get message message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError() # chatbot app runner = CompletionAppRunner() @@ -231,7 +231,7 @@ def generate_more_like_this( user: Union[Account, EndUser], invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[str, None, None]]: + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: """ Generate App response. 
@@ -293,7 +293,7 @@ def generate_more_like_this( model_conf=ModelConfigConverter.convert(app_config), inputs=message.inputs, query=message.query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=stream, invoke_from=invoke_from, @@ -317,7 +317,7 @@ def generate_more_like_this( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 908d74ff539a5a..41278b75b42bf4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -76,7 +76,7 @@ def run( tenant_id=app_config.tenant_id, app_generate_entity=application_generate_entity, inputs=inputs, - query=query, + query=query or "", message_id=message.id, ) except ModerationError as e: @@ -122,7 +122,7 @@ def run( tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_conf, config=dataset_config, - query=query, + query=query or "", invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 697f0273a5673e..73f38c3d0bcb96 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = CompletionAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -36,7 +36,7 @@ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlocking return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -51,7 +51,8 @@ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlocki @classmethod def convert_stream_full_response( - cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + cls, + stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -81,7 +82,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + cls, + stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. 
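The generator and runner hunks above lean on two Optional-narrowing idioms: passing query or "" hands callees a plain str, and lookups that may return None are checked and raised on immediately so mypy treats the value as non-None afterwards. A toy illustration of both (the lookup below is invented, not the real Message model):

from typing import Optional

def find_message(message_id: str) -> Optional[str]:
    # Stands in for a DB query that may return no row.
    return f"msg-{message_id}" if message_id else None

def run(query: Optional[str], message_id: str) -> str:
    message = find_message(message_id)
    if message is None:
        raise ValueError("Message does not exist")  # narrows message to str
    # "query or ''" collapses Optional[str] to str before .strip().
    return f"{message}: {(query or '').strip()}"

print(run("  hello  ", "42"))  # msg-42: hello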
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 95ae798ec1ac74..c2e35faf89ba15 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -2,11 +2,11 @@ import logging from collections.abc import Generator from datetime import UTC, datetime -from typing import Optional, Union +from typing import Optional, Union, cast from sqlalchemy import and_ -from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( @@ -42,7 +42,7 @@ def _handle_response( ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity, ], queue_manager: AppQueueManager, conversation: Conversation, @@ -144,7 +144,7 @@ def _init_generate_records( :conversation conversation :return: """ - app_config = application_generate_entity.app_config + app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config) # get from source end_user_id = None @@ -267,7 +267,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat except KeyError: pass - return introduction + return introduction or "" def _get_conversation(self, conversation_id: str): """ @@ -282,7 +282,7 @@ return conversation - def _get_message(self, message_id: str) -> Message: + def _get_message(self, message_id: str) -> Optional[Message]: """ Get message by message id :param message_id: message id diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index dc4ee9e566a2f3..1d5f21b9e0cc07 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -116,7 +116,7 @@ def generate( inputs=self._prepare_user_inputs( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), - files=system_files, + files=list(system_files), user_id=user.id, stream=streaming, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 08d00ee1805aa2..5cdac6ad28fdaa 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = WorkflowAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response :return: """ - return blocking_response.to_dict() + return dict(blocking_response.to_dict()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response.
:param blocking_response: blocking response @@ -36,7 +36,8 @@ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlocking @classmethod def convert_stream_full_response( - cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + cls, + stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -65,7 +66,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + cls, + stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 885283504b4175..63f516bcc60682 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -24,6 +24,7 @@ QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -190,16 +191,15 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) elif isinstance(event, NodeRunRetryEvent): node_run_result = event.route_node_state.node_run_result + inputs: Mapping[str, Any] | None = {} + process_data: Mapping[str, Any] | None = {} + outputs: Mapping[str, Any] | None = {} + execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {} if node_run_result: inputs = node_run_result.inputs process_data = node_run_result.process_data outputs = node_run_result.outputs execution_metadata = node_run_result.metadata - else: - inputs = {} - process_data = {} - outputs = {} - execution_metadata = {} self._publish_event( QueueNodeRetryEvent( node_execution_id=event.id, @@ -289,7 +289,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, error=event.route_node_state.node_run_result.error @@ -349,7 +349,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, execution_metadata=event.route_node_state.node_run_result.metadata diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 31c3a996e19286..16dc91bb777a9b 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL -from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig +from core.app.app_config.entities import 
EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity @@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: AppConfig + app_config: Any file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index d73c2eb53bfcd7..a93e533ff45d26 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -308,7 +308,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): inputs: Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: Optional[str] = None """single iteration duration map""" diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index dd088a897816ef..5e845eba2da1d3 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -70,7 +70,7 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) @@ -474,8 +474,8 @@ class Data(BaseModel): title: str created_at: int extras: dict = {} - metadata: dict = {} - inputs: dict = {} + metadata: Mapping = {} + inputs: Mapping = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None @@ -526,15 +526,15 @@ class Data(BaseModel): node_id: str node_type: str title: str - outputs: Optional[dict] = None + outputs: Optional[Mapping] = None created_at: int extras: Optional[dict] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float total_tokens: int - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping] = None finished_at: int steps: int parallel_id: Optional[str] = None @@ -628,7 +628,7 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 77b6bb554c65ec..83fd3debad4cf1 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -58,7 +58,7 @@ def query( query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) - if documents: + if documents and documents[0].metadata: annotation_id = documents[0].metadata["annotation_id"] score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 8fe1d96b37be0c..dcc2b4e55f6ae1 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -17,7 +17,7 @@ class RateLimit: _UNLIMITED_REQUEST_ID = "unlimited_request_id" _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes - 
_instance_dict = {} + _instance_dict: dict[str, "RateLimit"] = {} def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 51d610e2cbedc6..03a81353d02625 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -62,6 +62,7 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non """ logger.debug("error: %s", event.error) e = event.error + err: Exception if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") @@ -130,6 +131,7 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), queue_manager=self._queue_manager, ) + return None def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: """ diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index e26b60c4d3043e..b9f8e7ca560ce7 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,7 @@ import logging import time from collections.abc import Generator +from threading import Thread from typing import Optional, Union, cast from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -103,7 +104,7 @@ def __init__( ) ) - self._conversation_name_generate_thread = None + self._conversation_name_generate_thread: Optional[Thread] = None def process( self, @@ -123,7 +124,7 @@ def process( if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + self._conversation, self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -146,7 +147,7 @@ def _to_blocking_response( extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: extras["metadata"] = self._task_state.metadata - + response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, @@ -154,7 +155,7 @@ def _to_blocking_response( id=self._message.id, mode=self._conversation.mode, message_id=self._message.id, - answer=self._task_state.llm_result.message.content, + answer=cast(str, self._task_state.llm_result.message.content), created_at=int(self._message.created_at.timestamp()), **extras, ), @@ -167,7 +168,7 @@ def _to_blocking_response( mode=self._conversation.mode, conversation_id=self._conversation.id, message_id=self._message.id, - answer=self._task_state.llm_result.message.content, + answer=cast(str, self._task_state.llm_result.message.content), created_at=int(self._message.created_at.timestamp()), **extras, ), @@ -252,7 +253,7 @@ def _wrapper_process_stream_response( yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( 
- self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None + self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -269,13 +270,14 @@ def _process_stream_response( break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): - self._task_state.llm_result = event.llm_result + if event.llm_result: + self._task_state.llm_result = event.llm_result else: self._handle_stop(event) # handle output moderation output_moderation_answer = self._handle_output_moderation_when_task_finished( - self._task_state.llm_result.message.content + cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: self._task_state.llm_result.message.content = output_moderation_answer @@ -292,7 +294,9 @@ def _process_stream_response( if annotation: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): - yield self._agent_thought_to_stream_response(event) + agent_thought_response = self._agent_thought_to_stream_response(event) + if agent_thought_response is not None: + yield agent_thought_response elif isinstance(event, QueueMessageFileEvent): response = self._message_file_to_stream_response(event) if response: @@ -307,16 +311,18 @@ def _process_stream_response( self._task_state.llm_result.prompt_messages = chunk.prompt_messages # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(delta_text) + should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) if should_direct_answer: continue - self._task_state.llm_result.message.content += delta_text + current_content = cast(str, self._task_state.llm_result.message.content) + current_content += cast(str, delta_text) + self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response(cast(str, delta_text), self._message.id) else: - yield self._agent_message_to_stream_response(delta_text, self._message.id) + yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -336,8 +342,14 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No llm_result = self._task_state.llm_result usage = llm_result.usage - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() - self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if not message: + raise Exception(f"Message {self._message.id} not found") + self._message = message + conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + if not conversation: + raise Exception(f"Conversation {self._conversation.id} not found") + self._conversation = conversation self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages @@ -346,7 +358,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No 
self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit self._message.answer = ( - PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) if llm_result.message.content else "" ) @@ -374,6 +386,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} + and hasattr(self._application_generate_entity, "conversation_id") and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras, ) @@ -420,7 +433,9 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, **extras + task_id=self._application_generate_entity.task_id, + id=self._message.id, + metadata=extras.get("metadata", {}), ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -440,7 +455,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op :param event: agent thought event :return: """ - agent_thought: MessageAgentThought = ( + agent_thought: Optional[MessageAgentThought] = ( db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index e818a090ed7d0f..007543f6d0d1f2 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -128,7 +128,7 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti """ message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() - if message_file: + if message_file and message_file.url is not None: # get tool file id tool_file_id = message_file.url.split("/")[-1] # trim extension diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index df7dbace0ef818..f581e564f224ce 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -93,7 +93,7 @@ def _handle_workflow_run_start(self) -> WorkflowRun: ) # handle special values - inputs = WorkflowEntry.handle_special_values(inputs) + inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run with Session(db.engine, expire_on_commit=False) as session: @@ -192,7 +192,7 @@ def _handle_workflow_run_partial_success( """ workflow_run = self._refetch_workflow_run(workflow_run.id) - outputs = WorkflowEntry.handle_special_values(outputs) + outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value workflow_run.outputs = json.dumps(outputs or {}) @@ -500,7 +500,7 @@ def _workflow_start_to_stream_response( id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, + inputs=dict(workflow_run.inputs_dict or {}), 
created_at=int(workflow_run.created_at.timestamp()), ), ) @@ -545,7 +545,7 @@ def _workflow_finish_to_stream_response( workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, status=workflow_run.status, - outputs=workflow_run.outputs_dict, + outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None, error=workflow_run.error, elapsed_time=workflow_run.elapsed_time, total_tokens=workflow_run.total_tokens, @@ -553,7 +553,7 @@ def _workflow_finish_to_stream_response( created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict), + files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), exceptions_count=workflow_run.exceptions_count, ), ) @@ -655,7 +655,7 @@ def _workflow_node_retry_to_stream_response( event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeFinishStreamResponse]: + ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: """ Workflow node finish to stream response. :param event: queue node succeeded or failed event @@ -838,7 +838,7 @@ def _workflow_iteration_completed_to_stream_response( ), ) - def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]: + def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -851,9 +851,11 @@ def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping # Remove None files = [file for file in files if file] # Flatten list - files = [file for sublist in files for file in sublist] + # Flatten the list of sequences into a single list of mappings + flattened_files = [file for sublist in files if sublist for file in sublist] - return files + # Convert to tuple to match Sequence type + return tuple(flattened_files) def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: """ @@ -891,6 +893,8 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any elif isinstance(value, File): return value.to_dict() + return None + def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index d826edf6a0fc19..effc7eff9179ae 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -57,7 +57,7 @@ def on_tool_end( self, tool_name: str, tool_inputs: Mapping[str, Any], - tool_outputs: Sequence[ToolInvokeMessage], + tool_outputs: Sequence[ToolInvokeMessage] | str, message_id: Optional[str] = None, timer: Optional[Any] = None, trace_manager: Optional[TraceQueueManager] = None, diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 1481578630f63b..8f8aaa93d6f986 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -40,17 +40,18 @@ def on_query(self, query: str, dataset_id: str) -> None: def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: - query = db.session.query(DocumentSegment).filter( - 
DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 9ed5528e43b9b8..5017835565789c 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from enum import Enum from typing import Optional @@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel): label: I18nObject icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None - supported_model_types: list[ModelType] + supported_model_types: Sequence[ModelType] = [] class DefaultModelEntity(BaseModel): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d1b34db2fe7172..2e27b362d3092c 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) -original_provider_configurate_methods = {} +original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): @@ -99,7 +99,8 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional continue restrict_models = quota_configuration.restrict_models - + if self.system_configuration.credentials is None: + return None copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: @@ -124,7 +125,7 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional return credentials - def get_system_configuration_status(self) -> SystemConfigurationStatus: + def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]: """ Get system configuration status. :return: @@ -136,6 +137,8 @@ def get_system_configuration_status(self) -> SystemConfigurationStatus: current_quota_configuration = next( (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) + if current_quota_configuration is None: + return None return ( SystemConfigurationStatus.ACTIVE @@ -150,7 +153,7 @@ def is_custom_configuration_available(self) -> bool: """ return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 - def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: + def get_custom_credentials(self, obfuscated: bool = False): """ Get custom credentials. 
@@ -172,7 +175,7 @@ def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: else [], ) - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: + def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]: """ Validate custom credentials. :param credentials: provider credentials @@ -324,7 +327,7 @@ def get_custom_model_credentials( def custom_model_credentials_validate( self, model_type: ModelType, model: str, credentials: dict - ) -> tuple[ProviderModel, dict]: + ) -> tuple[Optional[ProviderModel], dict]: """ Validate custom model credentials. @@ -740,10 +743,10 @@ def get_provider_models( if model_type: model_types.append(model_type) else: - model_types = provider_instance.get_provider_schema().supported_model_types + model_types = list(provider_instance.get_provider_schema().supported_model_types) # Group model settings by model type and model - model_setting_map = defaultdict(dict) + model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) for model_setting in self.model_settings: model_setting_map[model_setting.model_type][model_setting.model] = model_setting @@ -822,54 +825,57 @@ def _get_system_provider_models( ]: # only customizable model for restrict_model in restrict_models: - copy_credentials = self.system_configuration.credentials.copy() - if restrict_model.base_model_name: - copy_credentials["base_model_name"] = restrict_model.base_model_name - - try: - custom_model_schema = provider_instance.get_model_instance( - restrict_model.model_type - ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) - except Exception as ex: - logger.warning(f"get custom model schema failed, {ex}") - continue - - if not custom_model_schema: - continue - - if custom_model_schema.model_type not in model_types: - continue - - status = ModelStatus.ACTIVE - if ( - custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] - ): - model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] - if model_setting.enabled is False: - status = ModelStatus.DISABLED - - provider_models.append( - ModelWithProviderEntity( - model=custom_model_schema.model, - label=custom_model_schema.label, - model_type=custom_model_schema.model_type, - features=custom_model_schema.features, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties=custom_model_schema.model_properties, - deprecated=custom_model_schema.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=status, + if self.system_configuration.credentials is not None: + copy_credentials = self.system_configuration.credentials.copy() + if restrict_model.base_model_name: + copy_credentials["base_model_name"] = restrict_model.base_model_name + + try: + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) + except Exception as ex: + logger.warning(f"get custom model schema failed, {ex}") + continue + + if not custom_model_schema: + continue + + if custom_model_schema.model_type not in model_types: + continue + + status = ModelStatus.ACTIVE + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): + model_setting = model_setting_map[custom_model_schema.model_type][ + 
custom_model_schema.model
+                        ]
+                        if model_setting.enabled is False:
+                            status = ModelStatus.DISABLED
+
+                    provider_models.append(
+                        ModelWithProviderEntity(
+                            model=custom_model_schema.model,
+                            label=custom_model_schema.label,
+                            model_type=custom_model_schema.model_type,
+                            features=custom_model_schema.features,
+                            fetch_from=FetchFrom.PREDEFINED_MODEL,
+                            model_properties=custom_model_schema.model_properties,
+                            deprecated=custom_model_schema.deprecated,
+                            provider=SimpleModelProviderEntity(self.provider),
+                            status=status,
+                        )
                     )
-                )
 
             # if llm name not in restricted llm list, remove it
             restrict_model_names = [rm.model for rm in restrict_models]
-            for m in provider_models:
-                if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
-                    m.status = ModelStatus.NO_PERMISSION
+            for model in provider_models:
+                if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
+                    model.status = ModelStatus.NO_PERMISSION
                 elif not quota_configuration.is_valid:
-                    m.status = ModelStatus.QUOTA_EXCEEDED
+                    model.status = ModelStatus.QUOTA_EXCEEDED
 
         return provider_models
 
@@ -1043,7 +1049,7 @@ def __iter__(self):
         return iter(self.configurations)
 
     def values(self) -> Iterator[ProviderConfiguration]:
-        return self.configurations.values()
+        return iter(self.configurations.values())
 
     def get(self, key, default=None):
         return self.configurations.get(key, default)
diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py
index 38cebb6b6b1c36..3f4e20ec245302 100644
--- a/api/core/extension/api_based_extension_requestor.py
+++ b/api/core/extension/api_based_extension_requestor.py
@@ -1,3 +1,5 @@
+from typing import cast
+
 import requests
 
 from configs import dify_config
@@ -5,7 +7,7 @@
 
 
 class APIBasedExtensionRequestor:
-    timeout: (int, int) = (5, 60)
+    timeout: tuple[int, int] = (5, 60)
     """timeout for request connect and read"""
 
     def __init__(self, api_endpoint: str, api_key: str) -> None:
@@ -51,4 +53,4 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
                 "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
             )
 
-        return response.json()
+        return cast(dict, response.json())
diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index 97dbaf2026e790..231743bf2a948c 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -38,8 +38,8 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
 
     @classmethod
     def scan_extensions(cls):
-        extensions: list[ModuleExtension] = []
-        position_map = {}
+        extensions = []
+        position_map: dict[str, int] = {}
 
         # get the path of the current class
         current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
@@ -58,7 +58,8 @@ def scan_extensions(cls):
 
                 # is builtin extension, builtin extension
                 # in the front-end page and business logic, there are special treatments.
builtin = False
-                position = None
+                # default position is 0; it cannot be None for sort_to_dict_by_position_map
+                position = 0
 
                 if "__builtin__" in file_names:
                     builtin = True
@@ -89,7 +90,7 @@ def scan_extensions(cls):
                         logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
                         continue
 
-                json_data = {}
+                json_data: dict[str, Any] = {}
                 if not builtin:
                     if "schema.json" not in file_names:
                         logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py
index 3da170455e3398..9eb9e0306b577f 100644
--- a/api/core/extension/extension.py
+++ b/api/core/extension/extension.py
@@ -1,4 +1,6 @@
-from core.extension.extensible import ExtensionModule, ModuleExtension
+from typing import cast
+
+from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension
 from core.external_data_tool.base import ExternalDataTool
 from core.moderation.base import Moderation
 
@@ -10,7 +12,8 @@ class Extension:
 
     def init(self):
         for module, module_class in self.module_classes.items():
-            self.__module_extensions[module.value] = module_class.scan_extensions()
+            m = cast(Extensible, module_class)
+            self.__module_extensions[module.value] = m.scan_extensions()
 
     def module_extensions(self, module: str) -> list[ModuleExtension]:
         module_extensions = self.__module_extensions.get(module)
@@ -35,7 +38,8 @@ def module_extension(self, module: ExtensionModule, extension_name: str) -> Modu
 
     def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
         module_extension = self.module_extension(module, extension_name)
-        return module_extension.extension_class
+        t: type = module_extension.extension_class
+        return t
 
     def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
         module_extension = self.module_extension(module, extension_name)
diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py
index 54ec97a4933a94..9989c8a09013bd 100644
--- a/api/core/external_data_tool/api/api.py
+++ b/api/core/external_data_tool/api/api.py
@@ -48,7 +48,10 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str:
         :return: the tool query result
         """
         # get params from config
+        if not self.config:
+            raise ValueError("config is required, config: {}".format(self.config))
         api_based_extension_id = self.config.get("api_based_extension_id")
+        assert api_based_extension_id is not None, "api_based_extension_id is required"
 
         # get api_based_extension
         api_based_extension = (
diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py
index 84b94e117ff5f9..6a9703a569b308 100644
--- a/api/core/external_data_tool/external_data_fetch.py
+++ b/api/core/external_data_tool/external_data_fetch.py
@@ -1,7 +1,7 @@
-import concurrent
 import logging
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional
+from collections.abc import Mapping
+from concurrent.futures import Future, ThreadPoolExecutor, as_completed
+from typing import Any, Optional
 
 from flask import Flask, current_app
 
@@ -17,9 +17,9 @@ def fetch(
         tenant_id: str,
         app_id: str,
         external_data_tools: list[ExternalDataVariableEntity],
-        inputs: dict,
+        inputs: Mapping[str, Any],
         query: str,
-    ) -> dict:
+    ) -> Mapping[str, Any]:
         """
         Fill in variable inputs from external data tools if exists.
@@ -30,13 +30,14 @@ def fetch( :param query: the query :return: the filled inputs """ - results = {} + results: dict[str, Any] = {} + inputs = dict(inputs) with ThreadPoolExecutor() as executor: futures = {} for tool in external_data_tools: - future = executor.submit( + future: Future[tuple[str | None, str | None]] = executor.submit( self._query_external_data_tool, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore tenant_id, app_id, tool, @@ -46,9 +47,10 @@ def fetch( futures[future] = tool - for future in concurrent.futures.as_completed(futures): + for future in as_completed(futures): tool_variable, result = future.result() - results[tool_variable] = result + if tool_variable is not None: + results[tool_variable] = result inputs.update(results) return inputs @@ -59,7 +61,7 @@ def _query_external_data_tool( tenant_id: str, app_id: str, external_data_tool: ExternalDataVariableEntity, - inputs: dict, + inputs: Mapping[str, Any], query: str, ) -> tuple[Optional[str], Optional[str]]: """ diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 28721098594962..245507e17c7032 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,4 +1,5 @@ -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -23,9 +24,10 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) - extension_class.validate_config(tenant_id, config) + # FIXME mypy issue here, figure out how to fix it + extension_class.validate_config(tenant_id, config) # type: ignore - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: """ Query the external data tool. 
@@ -33,4 +35,4 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: :param query: the query of chat app :return: the tool query result """ - return self.__extension_instance.query(inputs, query) + return cast(str, self.__extension_instance.query(inputs, query)) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 15eb351a7ef309..4a50fb85c9cca3 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -1,4 +1,5 @@ import base64 +from collections.abc import Mapping from configs import dify_config from core.helper import ssrf_proxy @@ -55,7 +56,7 @@ def to_prompt_message_content( if f.type == FileType.IMAGE: params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - prompt_class_map = { + prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = { FileType.IMAGE: ImagePromptMessageContent, FileType.AUDIO: AudioPromptMessageContent, FileType.VIDEO: VideoPromptMessageContent, @@ -63,7 +64,7 @@ def to_prompt_message_content( } try: - return prompt_class_map[f.type](**params) + return prompt_class_map[f.type].model_validate(params) except KeyError: raise ValueError(f"file type {f.type} is not supported") diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index a17b7be3675ab1..6fa101cf36192b 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from core.tools.tool_file_manager import ToolFileManager @@ -9,4 +9,4 @@ class ToolFileParser: @staticmethod def get_tool_file_manager() -> "ToolFileManager": - return tool_file_manager["manager"] + return cast("ToolFileManager", tool_file_manager["manager"]) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 584e3e9698a88d..15b501780e766c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -38,7 +38,7 @@ class CodeLanguage(StrEnum): class CodeExecutor: - dependencies_cache = {} + dependencies_cache: dict[str, str] = {} dependencies_cache_lock = Lock() code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = { @@ -103,19 +103,19 @@ def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: ) try: - response = response.json() + response_data = response.json() except: raise CodeExecutionError("Failed to parse response") - if (code := response.get("code")) != 0: - raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") + if (code := response_data.get("code")) != 0: + raise CodeExecutionError(f"Got error code: {code}. 
Got error msg: {response_data.get('message')}") - response = CodeExecutionResponse(**response) + response_code = CodeExecutionResponse(**response_data) - if response.data.error: - raise CodeExecutionError(response.data.error) + if response_code.data.error: + raise CodeExecutionError(response_code.data.error) - return response.data.stdout or "" + return response_code.data.stdout or "" @classmethod def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index db2eb5ebb6b19a..264947b5686d0e 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -1,9 +1,11 @@ +from collections.abc import Mapping + from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage class Jinja2Formatter: @classmethod - def format(cls, template: str, inputs: dict) -> str: + def format(cls, template: str, inputs: Mapping[str, str]) -> str: """ Format template :param template: template @@ -11,5 +13,4 @@ def format(cls, template: str, inputs: dict) -> str: :return: """ result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - - return result["result"] + return str(result.get("result", "")) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 605719747a7b56..baa792b5bc6c41 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -29,8 +29,7 @@ def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") - result = result.group(1) - return result + return result.group(1) @classmethod def transform_response(cls, response: str) -> Mapping[str, Any]: diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py index 518962c1652df7..81501d2e4e23b2 100644 --- a/api/core/helper/lru_cache.py +++ b/api/core/helper/lru_cache.py @@ -4,7 +4,7 @@ class LRUCache: def __init__(self, capacity: int): - self.cache = OrderedDict() + self.cache: OrderedDict[Any, Any] = OrderedDict() self.capacity = capacity def get(self, key: Any) -> Any: diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 5e274f8916869d..35349210bd53ab 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -30,7 +30,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index da0fd0031cc6dc..543444463b9f1a 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -22,6 +22,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) provider_name = model_config.provider if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: hosting_openai_config = hosting_configuration.provider_map["openai"] + assert hosting_openai_config is not None # 2000 text per chunk length = 2000 @@ -34,8 +35,9 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: 
model_type_instance = OpenAIModerationModel()
+                # FIXME: for the type hint, would an assert or raising a ValueError be better here?
                 moderation_result = model_type_instance.invoke(
-                    model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk
+                    model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk
                 )
 
                 if moderation_result is True:
diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py
index 1e2fefce88b632..9a041667e46df5 100644
--- a/api/core/helper/module_import_helper.py
+++ b/api/core/helper/module_import_helper.py
@@ -14,12 +14,13 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
     if existed_spec:
         spec = existed_spec
         if not spec.loader:
-            raise Exception(f"Failed to load module {module_name} from {py_file_path}")
+            raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
     else:
         # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
-        spec = importlib.util.spec_from_file_location(module_name, py_file_path)
+        # FIXME: mypy does not support the type of spec.loader
+        spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore
         if not spec or not spec.loader:
-            raise Exception(f"Failed to load module {module_name} from {py_file_path}")
+            raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
         if use_lazy_loader:
             # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
             spec.loader = importlib.util.LazyLoader(spec.loader)
@@ -29,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
         spec.loader.exec_module(module)
         return module
     except Exception as e:
-        logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'")
+        logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'")
         raise e
 
 
@@ -57,6 +58,6 @@ def load_single_subclass_from_source(
         case 1:
             return subclasses[0]
         case 0:
-            raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}")
+            raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}")
         case _:
-            raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}")
+            raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}")
diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py
index e848b46c5633ab..3b67b3f84838d3 100644
--- a/api/core/helper/tool_parameter_cache.py
+++ b/api/core/helper/tool_parameter_cache.py
@@ -33,7 +33,7 @@ def get(self) -> Optional[dict]:
             except JSONDecodeError:
                 return None
 
-            return cached_tool_parameter
+            return dict(cached_tool_parameter)
         else:
             return None
diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py
index 94b02cf98578b1..6de5e704abf4f5 100644
--- a/api/core/helper/tool_provider_cache.py
+++ b/api/core/helper/tool_provider_cache.py
@@ -28,7 +28,7 @@ def get(self) -> Optional[dict]:
             except JSONDecodeError:
                 return None
 
-            return cached_provider_credentials
+            return dict(cached_provider_credentials)
         else:
             return None
diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py
index b47ba67f2fa64f..f9fb7275f3624f 100644
--- a/api/core/hosting_configuration.py
+++ b/api/core/hosting_configuration.py
@@ -42,7 +42,7 @@ class HostedModerationConfig(BaseModel):
 
 
 class HostingConfiguration:
     provider_map:
dict[str, HostingProvider] = {} - moderation_config: HostedModerationConfig = None + moderation_config: Optional[HostedModerationConfig] = None def init_app(self, app: Flask) -> None: if dify_config.EDITION != "CLOUD": @@ -67,7 +67,7 @@ def init_azure_openai() -> HostingProvider: "base_model_name": "gpt-35-turbo", } - quotas = [] + quotas: list[HostingQuota] = [] hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, @@ -123,7 +123,7 @@ def init_azure_openai() -> HostingProvider: def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.CREDITS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT @@ -157,7 +157,7 @@ def init_openai(self) -> HostingProvider: @staticmethod def init_anthropic() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT @@ -187,7 +187,7 @@ def init_anthropic() -> HostingProvider: def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_MINIMAX_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -205,7 +205,7 @@ def init_minimax() -> HostingProvider: def init_spark() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_SPARK_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -223,7 +223,7 @@ def init_spark() -> HostingProvider: def init_zhipuai() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_ZHIPUAI_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 29e161cb747284..1f0a0d0ef1dda4 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -6,10 +6,10 @@ import threading import time import uuid -from typing import Optional, cast +from typing import Any, Optional, cast from flask import Flask, current_app -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -62,6 +62,8 @@ def run(self, dataset_documents: list[DatasetDocument]): .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("no process rule found") index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract @@ -120,6 +122,8 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("no process rule found") index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -254,7 +258,7 @@ def indexing_estimate( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - preview_texts = [] + preview_texts: list[str] = [] total_segments = 0 index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -285,7 +289,8 @@ def indexing_estimate( for upload_file_id in image_upload_file_ids: 
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() try: - storage.delete(image_file.key) + if image_file: + storage.delete(image_file.key) except Exception: logging.exception( "Delete image_files failed while indexing_estimate, \ @@ -379,8 +384,9 @@ def _extract( # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata["document_id"] = dataset_document.id - text_doc.metadata["dataset_id"] = dataset_document.dataset_id + if text_doc.metadata is not None: + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs @@ -400,6 +406,7 @@ def _get_splitter( """ Get the NodeParser object according to the processing rule. """ + character_splitter: TextSplitter if processing_rule.mode == "custom": # The user-defined segmentation rule rules = json.loads(processing_rule.rules) @@ -426,9 +433,10 @@ def _get_splitter( ) else: # Automatic segmentation + automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"]) character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], + chunk_size=automatic_rules["max_tokens"], + chunk_overlap=automatic_rules["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance, ) @@ -497,8 +505,8 @@ def _split_to_documents( """ Split the text documents into nodes. """ - all_documents = [] - all_qa_documents = [] + all_documents: list[Document] = [] + all_qa_documents: list[Document] = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) @@ -509,10 +517,11 @@ def _split_to_documents( split_documents = [] for document_node in documents: if document_node.page_content.strip(): - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content document_node.page_content = remove_leading_symbols(page_content) @@ -529,7 +538,7 @@ def _split_to_documents( document_format_thread = threading.Thread( target=self.format_qa_document, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "tenant_id": tenant_id, "document_node": doc, "all_qa_documents": all_qa_documents, @@ -557,11 +566,12 @@ def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, al qa_document = Document( page_content=result["question"], metadata=document_node.metadata.model_copy() ) - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash + if qa_document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + 
qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -575,7 +585,7 @@ def _split_to_documents_for_estimate( """ Split the text documents into nodes. """ - all_documents = [] + all_documents: list[Document] = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) @@ -588,11 +598,11 @@ def _split_to_documents_for_estimate( for document in documents: if document.page_content is None or not document.page_content.strip(): continue - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document.page_content) - - document.metadata["doc_id"] = doc_id - document.metadata["doc_hash"] = hash + if document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document.page_content) + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -648,7 +658,7 @@ def _load( # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, - args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore ) create_keyword_thread.start() if dataset.indexing_technique == "high_quality": @@ -659,7 +669,7 @@ def _load( futures.append( executor.submit( self._process_chunk, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore index_processor, chunk_documents, dataset, diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 3a92c8d9d22562..9fe3f68f2a8af5 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Optional +from typing import Optional, cast from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser @@ -13,6 +13,7 @@ WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -44,10 +45,13 @@ def generate_conversation_name( prompts = [UserPromptMessage(content=prompt)] with measure_time() as timer: - response = model_instance.invoke_llm( - prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + ), ) - answer = response.message.content + answer = cast(str, response.message.content) cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) if cleaned_answer is None: return "" @@ -94,11 +98,16 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st prompt_messages = [UserPromptMessage(content=prompt)] try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False + response = cast( + LLMResult, + 
model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"max_tokens": 256, "temperature": 0}, + stream=False, + ), ) - questions = output_parser.parse(response.message.content) + questions = output_parser.parse(cast(str, response.message.content)) except InvokeError: questions = [] except Exception as e: @@ -138,11 +147,14 @@ def generate_rule_config( ) try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["prompt"] = response.message.content + rule_config["prompt"] = cast(str, response.message.content) except InvokeError as e: error = str(e) @@ -178,15 +190,18 @@ def generate_rule_config( model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider") if model_config else None, - model=model_config.get("name") if model_config else None, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) try: try: # the first step to generate the task prompt - prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + prompt_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) except InvokeError as e: error = str(e) @@ -195,8 +210,10 @@ def generate_rule_config( return rule_config - rule_config["prompt"] = prompt_content.message.content + rule_config["prompt"] = cast(str, prompt_content.message.content) + if not isinstance(prompt_content.message.content, str): + raise NotImplementedError("prompt content is not a string") parameter_generate_prompt = parameter_template.format( inputs={ "INPUT_TEXT": prompt_content.message.content, @@ -216,19 +233,25 @@ def generate_rule_config( statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: - parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + parameter_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) + rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) except InvokeError as e: error = str(e) error_step = "generate variables" try: - statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + statement_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["opening_statement"] = statement_content.message.content + rule_config["opening_statement"] = cast(str, statement_content.message.content) except InvokeError as e: error = str(e) error_step = "generate conversation opener" @@ -267,19 +290,22 @@ def generate_code( model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider") if model_config else None, - model=model_config.get("name") if model_config else None, + provider=model_config.get("provider", ""), + 
model=model_config.get("name", ""), ) prompt_messages = [UserPromptMessage(content=prompt)] model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) - generated_code = response.message.content + generated_code = cast(str, response.message.content) return {"code": generated_code, "language": code_language, "error": ""} except InvokeError as e: @@ -303,9 +329,14 @@ def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"temperature": 0.01, "max_tokens": 2000}, + stream=False, + ), ) - answer = response.message.content + answer = cast(str, response.message.content) return answer.strip() diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 81d08dc8854f80..003a0c85b1f12e 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -68,7 +68,7 @@ def get_history_prompt_messages( messages = list(reversed(thread_messages)) - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 1986688551b601..d1e71148cd6023 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -124,17 +124,20 @@ def invoke_llm( raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, + return cast( + Union[LLMResult, Generator], + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ), ) def get_llm_num_tokens( @@ -151,12 +154,15 @@ def get_llm_num_tokens( raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - tools=tools, + return cast( + int, + self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + tools=tools, + ), ) def invoke_text_embedding( @@ -174,13 +180,16 @@ def invoke_text_embedding( raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, 
self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - texts=texts, - user=user, - input_type=input_type, + return cast( + TextEmbeddingResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + texts=texts, + user=user, + input_type=input_type, + ), ) def get_text_embedding_num_tokens(self, texts: list[str]) -> int: @@ -194,11 +203,14 @@ def get_text_embedding_num_tokens(self, texts: list[str]) -> int: raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model, - credentials=self.credentials, - texts=texts, + return cast( + int, + self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + texts=texts, + ), ) def invoke_rerank( @@ -223,15 +235,18 @@ def invoke_rerank( raise Exception("Model type instance is not RerankModel") self.model_type_instance = cast(RerankModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user, + return cast( + RerankResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user, + ), ) def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: @@ -246,12 +261,15 @@ def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: raise Exception("Model type instance is not ModerationModel") self.model_type_instance = cast(ModerationModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - text=text, - user=user, + return cast( + bool, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + text=text, + user=user, + ), ) def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: @@ -266,12 +284,15 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str raise Exception("Model type instance is not Speech2TextModel") self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - file=file, - user=user, + return cast( + str, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + file=file, + user=user, + ), ) def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: @@ -288,17 +309,20 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option raise Exception("Model type instance is not TTSModel") self.model_type_instance = cast(TTSModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - content_text=content_text, - user=user, - 
tenant_id=tenant_id, - voice=voice, + return cast( + Iterable[bytes], + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + content_text=content_text, + user=user, + tenant_id=tenant_id, + voice=voice, + ), ) - def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): + def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: """ Round-robin invoke :param function: function to invoke @@ -309,7 +333,7 @@ def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): if not self.load_balancing_manager: return function(*args, **kwargs) - last_exception = None + last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None while True: lb_config = self.load_balancing_manager.fetch_next() if not lb_config: @@ -463,7 +487,7 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: if real_index > max_index: real_index = 0 - config = self._load_balancing_configs[real_index] + config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index] if self.in_cooldown(config): cooldown_load_balancing_configs.append(config) @@ -507,8 +531,7 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - res = redis_client.exists(cooldown_cache_key) - res = cast(bool, res) + res: bool = redis_client.exists(cooldown_cache_key) return res @staticmethod diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 3b6b825244dfdc..1f21a2d3763c4a 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -1,7 +1,8 @@ import json import logging import sys -from typing import Optional +from collections.abc import Sequence +from typing import Optional, cast from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -20,7 +21,7 @@ def on_before_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -76,7 +77,7 @@ def on_new_chunk( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ): @@ -94,7 +95,7 @@ def on_new_chunk( :param stream: is stream response :param user: unique user id """ - sys.stdout.write(chunk.delta.message.content) + sys.stdout.write(cast(str, chunk.delta.message.content)) sys.stdout.flush() def on_after_invoke( @@ -106,7 +107,7 @@ def on_after_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -147,7 +148,7 @@ def on_invoke_error( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: diff --git 
a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 0efe46f87d6de9..2f682ceef578dc 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -3,7 +3,7 @@ from enum import Enum, StrEnum from typing import Optional -from pydantic import BaseModel, Field, computed_field, field_validator +from pydantic import BaseModel, Field, field_validator class PromptMessageRole(Enum): @@ -89,7 +89,6 @@ class MultiModalPromptMessageContent(PromptMessageContent): url: str = Field(default="", description="the url of multi-modal file") mime_type: str = Field(default=..., description="the mime type of multi-modal file") - @computed_field(return_type=str) @property def data(self): return self.url or f"data:{self.mime_type};base64,{self.base64_data}" diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 79a1d28ebe637e..e2b95603379348 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,7 +1,6 @@ import decimal import os from abc import ABC, abstractmethod -from collections.abc import Mapping from typing import Optional from pydantic import ConfigDict @@ -36,7 +35,7 @@ class AIModel(ABC): model_config = ConfigDict(protected_namespaces=()) @abstractmethod - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials @@ -214,7 +213,7 @@ def predefined_models(self) -> list[AIModelEntity]: return model_schemas - def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]: + def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: """ Get model schema by model name and credentials @@ -236,9 +235,7 @@ def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> return None - def get_customizable_model_schema_from_credentials( - self, model: str, credentials: Mapping - ) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -248,7 +245,7 @@ def get_customizable_model_schema_from_credentials( """ return self._get_customizable_model_schema(model, credentials) - def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema and fill in the template """ @@ -301,7 +298,7 @@ def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Op return schema - def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 8faeffa872b40f..402a30376b7546 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import re import time from 
abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Generator, Sequence from typing import Optional, Union from pydantic import ConfigDict @@ -48,7 +48,7 @@ def invoke( prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -291,12 +291,12 @@ def _code_block_mode_stream_processor( content = piece.delta.message.content piece.delta.message.content = "" yield piece - piece = content + content_piece = content else: yield piece continue new_piece: str = "" - for char in piece: + for char in content_piece: char = str(char) if state == "normal": if char == "`": @@ -350,7 +350,7 @@ def _code_block_mode_stream_processor_with_backtick( piece.delta.message.content = "" # Yield a piece with cleared content before processing it to maintain the generator structure yield piece - piece = content + content_piece = content else: # Yield pieces without content directly yield piece @@ -360,7 +360,7 @@ continue new_piece: str = "" - for char in piece: + for char in content_piece: if state == "search_start": if char == "`": backtick_count += 1 @@ -535,7 +535,7 @@ def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRu return [] - def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode: + def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode: """ Get model mode diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 4374093de4ab38..36e3e7bd557163 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -104,9 +104,10 @@ def get_model_instance(self, model_type: ModelType) -> AIModel: mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) + # FIXME: "type" has no attribute "__abstractmethods__"; ignore it for now and fix it later model_class = next( filter( - lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, + lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore get_subclasses_from_module(mod, AIModel), ), None, diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 2d38fba955fb86..33135129082b1d 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -89,7 +89,8 @@ def _get_context_size(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] + context_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] + return context_size return 1000 @@ -104,6 +105,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHUNKS in
model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + return max_chunks return 1 diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 5fe6dda6ad5d79..6dab0aaf2d41e7 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -2,9 +2,9 @@ from threading import Lock from typing import Any -from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer +from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore -_tokenizer = None +_tokenizer: Any = None _lock = Lock() diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index b394ea4e9d22fe..6ce316b137abb4 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -127,7 +127,8 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str: if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties: raise ValueError("this model does not support audio type") - return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + return audio_type def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ @@ -138,8 +139,9 @@ def _get_model_word_limit(self, model: str, credentials: dict) -> int: if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties: raise ValueError("this model does not support word limit") + word_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] - return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] + return word_limit def _get_model_workers_limit(self, model: str, credentials: dict) -> int: """ @@ -150,8 +152,9 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int: if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties: raise ValueError("this model does not support max workers") + workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] - return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] + return workers_limit @staticmethod def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index a2b14cf3dbe6d4..4aa09e61fd3599 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -64,10 +64,12 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) + if not ai_model_entity: + return None return ai_model_entity.entity @staticmethod - def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + def _get_ai_model_entity(base_model_name: str, model: str) ->
Optional[AzureBaseModel]: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: if ai_model_entity.base_model_name == base_model_name: ai_model_entity_copy = copy.deepcopy(ai_model_entity) diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 173b9d250c1743..6d50ba9163984f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -114,6 +114,8 @@ def _process_sentence(self, sentence: str, model: str, voice, credentials: dict) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) + if not ai_model_entity: + return None return ai_model_entity.entity @staticmethod diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 75ed7ad62404cb..29bd673d576fc9 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -6,9 +6,9 @@ from typing import Optional, Union, cast # 3rd import -import boto3 -from botocore.config import Config -from botocore.exceptions import ( +import boto3 # type: ignore +from botocore.config import Config # type: ignore +from botocore.exceptions import ( # type: ignore ClientError, EndpointConnectionError, NoRegionError, diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index aba8fedbc097e5..3a0a241f7ea0c0 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -44,7 +44,7 @@ def _invoke( :return: rerank result """ if len(docs) == 0: - return RerankResult(model=model, docs=docs) + return RerankResult(model=model, docs=[]) # initialize client client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) @@ -62,7 +62,7 @@ def _invoke( # format document rerank_document = RerankDocument( index=result.index, - text=result.document.text, + text=result.document.text if result.document else "", score=result.relevance_score, ) diff --git a/api/core/model_runtime/model_providers/fireworks/_common.py b/api/core/model_runtime/model_providers/fireworks/_common.py index 378ced3a4019ba..38d0a9dfbcadee 100644 --- a/api/core/model_runtime/model_providers/fireworks/_common.py +++ b/api/core/model_runtime/model_providers/fireworks/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from core.model_runtime.errors.invoke import ( @@ -13,7 +11,7 @@ class _CommonFireworks: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py index c745a7e978f4be..4c036283893fcc 100644 --- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -1,5 +1,4 @@ import time -from collections.abc import Mapping from typing import Optional, Union import numpy as np @@ -93,7 +92,7 @@ def get_num_tokens(self, model: 
str, credentials: dict, texts: list[str]) -> int """ return sum(self._get_num_tokens_by_gpt2(text) for text in texts) - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials diff --git a/api/core/model_runtime/model_providers/gitee_ai/_common.py b/api/core/model_runtime/model_providers/gitee_ai/_common.py index 0750f3b75d0542..ad6600faf7bc15 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/_common.py +++ b/api/core/model_runtime/model_providers/gitee_ai/_common.py @@ -1,4 +1,4 @@ -from dashscope.common.error import ( +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py index 832ba927406c4c..737d3d5c931221 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import httpx @@ -51,7 +51,7 @@ def _invoke( base_url = base_url.removesuffix("/") try: - body = {"model": model, "query": query, "documents": docs} + body: dict[str, Any] = {"model": model, "query": query, "documents": docs} if top_n is not None: body["top_n"] = top_n response = httpx.post( diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py index b833c5652c650a..a1fa89c5b34af6 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py @@ -24,7 +24,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: super().validate_credentials(model, credentials) @staticmethod - def _add_custom_parameters(credentials: dict, model: str) -> None: + def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None: if model is None: model = "bge-m3" diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py index 36dcea405d0974..dc91257daf9d4e 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import requests @@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel): Model class for OpenAI text2speech model. 
""" + # FIXME this Any return will be better type def _invoke( self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None - ) -> any: + ) -> Any: """ _invoke text2speech model @@ -47,7 +48,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + # FIXME this Any return will be better type + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: """ _tts_invoke_streaming text2speech model :param model: model name diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 7d19ccbb74a011..98273f60a41190 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -7,7 +7,7 @@ from typing import Optional, Union import google.ai.generativelanguage as glm -import google.generativeai as genai +import google.generativeai as genai # type: ignore import requests from google.api_core import exceptions from google.generativeai.types import ContentType, File, GenerateContentResponse diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index 3c4020b6eedf24..d8a09265e21059 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -1,4 +1,4 @@ -from huggingface_hub.utils import BadRequestError, HfHubHTTPError +from huggingface_hub.utils import BadRequestError, HfHubHTTPError # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 9d29237fdde573..cdb4103cd83712 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from huggingface_hub import InferenceClient -from huggingface_hub.hf_api import HfApi -from huggingface_hub.utils import BadRequestError +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.hf_api import HfApi # type: ignore +from huggingface_hub.utils import BadRequestError # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 8278d1e64def89..4ca5379405f4e6 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ import numpy as np import requests -from huggingface_hub import HfApi, InferenceClient +from huggingface_hub import HfApi, InferenceClient # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject diff --git 
a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 2014de8516bc11..2dd45f065d5e26 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -3,11 +3,11 @@ from collections.abc import Generator from typing import cast -from tencentcloud.common import credential -from tencentcloud.common.exception import TencentCloudSDKException -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from tencentcloud.common import credential # type: ignore +from tencentcloud.common.exception import TencentCloudSDKException # type: ignore +from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore +from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -305,7 +305,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: elif isinstance(message, ToolPromptMessage): message_text = f"{tool_prompt} {content}" elif isinstance(message, SystemPromptMessage): - message_text = content + message_text = content if isinstance(content, str) else "" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index b6d857cb37cba0..856cda90d35a22 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -3,11 +3,11 @@ import time from typing import Optional -from tencentcloud.common import credential -from tencentcloud.common.exception import TencentCloudSDKException -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from tencentcloud.common import credential # type: ignore +from tencentcloud.common.exception import TencentCloudSDKException # type: ignore +from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore +from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index d80cbfa83d6425..1fc0f8c028ba92 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -1,11 +1,11 @@ from os.path import abspath, dirname, join from threading import Lock -from transformers import AutoTokenizer +from transformers import AutoTokenizer # type: ignore class JinaTokenizer: - _tokenizer = None + _tokenizer: AutoTokenizer | None = None _lock = Lock() @classmethod 
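The hunks above repeat two mypy-driven patterns: the Any returned by the untyped round-robin dispatcher is cast back to a concrete type at each typed entry point, and attributes defaulted to None are declared Optional. A minimal, self-contained sketch of both follows; Dispatcher and Message are illustrative stand-ins, not code from this repository:

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, cast


@dataclass
class Message:
    content: str
    # "usage: dict[str, int] = None" fails strict mypy because None is not a
    # dict; declaring the field Optional (as the minimax and openllm hunks do)
    # fixes the check without changing runtime behavior.
    usage: Optional[dict[str, int]] = None


class Dispatcher:
    """Illustrative stand-in for an untyped dispatch entry point."""

    def _round_robin_invoke(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # Dispatching through an arbitrary Callable means the result can only
        # be typed as Any.
        return function(*args, **kwargs)

    def get_num_tokens(self, text: str) -> int:
        # cast() is a no-op at runtime; it only tells the type checker what
        # comes back, so this method keeps its precise return type.
        return cast(int, self._round_robin_invoke(len, text))

Both cast() and an annotated local (e.g. res: bool = redis_client.exists(...) in in_cooldown above) are purely static narrowings: neither adds a runtime check, they only tell mypy which type to assume.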
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 88cc0e8e0f32d0..357631b2dba0b9 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -40,7 +40,7 @@ def generate( url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] @@ -117,19 +117,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ handle chat generate response """ - response = response.json() - if "base_resp" in response and response["base_resp"]["status_code"] != 0: - code = response["base_resp"]["status_code"] - msg = response["base_resp"]["status_msg"] + response_data = response.json() + if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0: + code = response_data["base_resp"]["status_code"] + msg = response_data["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) + message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { "prompt_tokens": 0, - "completion_tokens": response["usage"]["total_tokens"], - "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response_data["usage"]["total_tokens"], + "total_tokens": response_data["usage"]["total_tokens"], } - message.stop_reason = response["choices"][0]["finish_reason"] + message.stop_reason = response_data["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: @@ -139,10 +139,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() - data = loads(line) + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() + data = loads(line_str) if "base_resp" in data and data["base_resp"]["status_code"] != 0: code = data["base_resp"]["status_code"] @@ -162,5 +162,5 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator continue for choice in choices: - message = choice["delta"] - yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) + message_choice = choice["delta"] + yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 8b8fdbb6bdf558..284b61829f9729 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -41,7 +41,7 @@ def generate( url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] @@ 
-122,19 +122,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ handle chat generate response """ - response = response.json() - if "base_resp" in response and response["base_resp"]["status_code"] != 0: - code = response["base_resp"]["status_code"] - msg = response["base_resp"]["status_msg"] + response_data = response.json() + if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0: + code = response_data["base_resp"]["status_code"] + msg = response_data["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) + message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { "prompt_tokens": 0, - "completion_tokens": response["usage"]["total_tokens"], - "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response_data["usage"]["total_tokens"], + "total_tokens": response_data["usage"]["total_tokens"], } - message.stop_reason = response["choices"][0]["finish_reason"] + message.stop_reason = response_data["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: @@ -144,10 +144,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() - data = loads(line) + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() + data = loads(line_str) if "base_resp" in data and data["base_resp"]["status_code"] != 0: code = data["base_resp"]["status_code"] diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index 88ebe5e2e00e7a..c248db374a2504 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -11,9 +11,9 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: dict[str, int] | None = None stop_reason: str = "" - function_call: dict[str, Any] = None + function_call: dict[str, Any] | None = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py index 56a707333c40e9..8a4c19d4d8f71b 100644 --- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -2,8 +2,8 @@ from functools import wraps from typing import Optional -from nomic import embed -from nomic import login as nomic_login +from nomic import embed # type: ignore +from nomic import login as nomic_login # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index 1e1fc5b3ea89aa..9f676573fc2ece 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -5,8 +5,8 @@ from 
collections.abc import Generator from typing import Optional, Union -import oci -from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse +import oci # type: ignore +from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 50fa63768c241b..5a428c9fed0466 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import numpy as np -import oci +import oci # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 83c4facc8db76c..3543fe58bb68d2 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -61,6 +61,7 @@ def _invoke( headers = {"Content-Type": "application/json"} endpoint_url = credentials.get("base_url") + assert endpoint_url is not None, "Base URL is required for Ollama API" if not endpoint_url.endswith("/"): endpoint_url += "/" diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 2181bb4f08fd8f..ac2b3e6881c740 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonOpenAI: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index 619044d808cdf6..227e4b0c152a05 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -93,7 +93,8 @@ def _get_max_characters_per_chunk(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + return max_characters_per_chunk return 2000 @@ -108,6 +109,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + return max_chunks return 1 diff --git 
a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index aa6f38ce9fae5a..c546441af61d9b 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -1,5 +1,4 @@ import logging -from collections.abc import Mapping from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -9,7 +8,7 @@ class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: + def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials if validate failed, raise exception diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index a490537e51a6ad..74229a089aa45e 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -33,6 +33,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "audio/transcriptions") diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 9da8f55d0a7ed9..b4d6c6c6ca9942 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -55,6 +55,7 @@ def _invoke( headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py index 8239c625f7ada8..53e895b0ecb376 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py @@ -44,6 +44,7 @@ def _invoke( # Construct endpoint URL endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "audio/speech") diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 2789a9250a1d35..e9509b544d9f4e 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -1,7 +1,7 @@ from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Optional, Union from requests import Response, post from requests.exceptions import ConnectionError, InvalidSchema, 
MissingSchema @@ -20,7 +20,7 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: Optional[dict[str, int]] = None stop_reason: str = "" def to_dict(self) -> dict[str, Any]: @@ -165,17 +165,17 @@ def _handle_chat_stream_generate_response( if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() - if line == "[DONE]": + if line_str == "[DONE]": return try: - data = loads(line) + data = loads(line_str) except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") + raise InternalServerError(f"Failed to convert response to json: {e} with text: {line_str}") output = data["outputs"] diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 7bbd31e87c595d..40ea4dc0118026 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -53,14 +53,16 @@ def _invoke( api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - + endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": endpoint_url = "https://cloud.perfxlab.cn/v1/" else: endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" + assert isinstance(endpoint_url, str) endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} @@ -142,13 +144,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if api_key: headers["Authorization"] = f"Bearer {api_key}" + endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": endpoint_url = "https://cloud.perfxlab.cn/v1/" else: endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" + assert isinstance(endpoint_url, str) endpoint_url = urljoin(endpoint_url, "embeddings") payload = {"input": "ping", "model": model} diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 915f6e0eefcd08..3e2cf2adb306db 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -1,4 +1,4 @@ -from replicate.exceptions import ModelError, ReplicateError +from replicate.exceptions import ModelError, ReplicateError # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 3641b35dc02a39..1e7858100b0429 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from replicate import Client as ReplicateClient -from replicate.exceptions import ReplicateError -from replicate.prediction import Prediction +from 
replicate import Client as ReplicateClient # type: ignore +from replicate.exceptions import ReplicateError # type: ignore +from replicate.prediction import Prediction # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 41759fe07d0cac..aaf825388a9043 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -2,11 +2,11 @@ import time from typing import Optional -from replicate import Client as ReplicateClient +from replicate import Client as ReplicateClient # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -86,7 +86,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={"context_size": 4096, "max_chunks": 1}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096, ModelPropertyKey.MAX_CHUNKS: 1}, ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 5ff00f008eb621..b8c979b1f53ce9 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Iterator from typing import Any, Optional, Union, cast -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -83,7 +83,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): sagemaker_session: Any = None predictor: Any = None - sagemaker_endpoint: str = None + sagemaker_endpoint: str | None = None def _handle_chat_generate_response( self, @@ -209,8 +209,8 @@ def _invoke( :param user: unique user id :return: full response or stream response chunk generator result """ - from sagemaker import Predictor, serializers - from sagemaker.session import Session + from sagemaker import Predictor, serializers # type: ignore + from sagemaker.session import Session # type: ignore if not self.sagemaker_session: access_key = credentials.get("aws_access_key_id") diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index df797bae265825..7daab6d8653d33 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -3,7 
+3,7 @@ import operator from typing import Any, Optional -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -114,6 +114,7 @@ def _invoke( except Exception as e: logger.exception(f"Failed to invoke rerank model, model: {model}") + raise InvokeError(f"Failed to invoke rerank model, model: {model}, error: {str(e)}") def validate_credentials(self, model: str, credentials: dict) -> None: """ diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 2d50e9c7b4c28a..a6aca130456063 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -2,7 +2,7 @@ import logging from typing import IO, Any, Optional -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -67,6 +67,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional s3_prefix = "dify/speech2text/" sagemaker_endpoint = credentials.get("sagemaker_endpoint") bucket = credentials.get("audio_s3_cache_bucket") + assert bucket is not None, "audio_s3_cache_bucket is required in credentials" s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) payload = {"audio_s3_presign_uri": s3_presign_url} diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index ef4ddcd6a72847..e7eccd997d11c1 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ import time from typing import Any, Optional -import boto3 +import boto3 # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -118,6 +118,7 @@ def _invoke( except Exception as e: logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}") + raise InvokeError(str(e)) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py index 6a5946453be07f..62231c518deef1 100644 --- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Optional -import boto3 +import boto3 # type: ignore import requests from core.model_runtime.entities.common_entities import I18nObject diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index e3a323a4965bc7..f61e8b82e4db99 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -43,7 +43,7 @@ def _add_custom_parameters(cls, credentials: dict) -> None: credentials["mode"] = "chat" credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" - def 
get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 1181ba699af886..cb6f28b6c27fa9 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -1,6 +1,6 @@ import threading from collections.abc import Generator -from typing import Optional, Union +from typing import Optional, Union, cast from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -270,7 +270,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" elif isinstance(message, SystemPromptMessage): - message_text = content + message_text = cast(str, content) else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index b96d43979ef54a..03eac194235e83 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -12,6 +12,7 @@ AIModelEntity, DefaultParameterName, FetchFrom, + ModelFeature, ModelPropertyKey, ModelType, ParameterRule, @@ -67,7 +68,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode cred_with_endpoint = self._update_endpoint_url(credentials=credentials) REPETITION_PENALTY = "repetition_penalty" TOP_K = "top_k" - features = [] + features: list[ModelFeature] = [] entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index 8a50c7aa05f38c..bb68319555007f 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -1,4 +1,4 @@ -from dashscope.common.error import ( +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 0c1f6518811aa8..61ebd45ed64a6d 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -7,9 +7,9 @@ from pathlib import Path from typing import Optional, Union, cast -from dashscope import Generation, MultiModalConversation, get_tokenizer -from dashscope.api_entities.dashscope_response import GenerationResponse -from dashscope.common.error import ( +from dashscope import Generation, MultiModalConversation, get_tokenizer # type: ignore +from dashscope.api_entities.dashscope_response import GenerationResponse # type: ignore +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py index a5ce9ead6ee3be..ed682cb0f3c1e4 100644 --- a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py +++ 
b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py @@ -1,7 +1,7 @@ from typing import Optional -import dashscope -from dashscope.common.error import ( +import dashscope # type: ignore +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, @@ -51,7 +51,7 @@ def _invoke( :return: rerank result """ if len(docs) == 0: - return RerankResult(model=model, docs=docs) + return RerankResult(model=model, docs=[]) # initialize client dashscope.api_key = credentials["dashscope_api_key"] @@ -64,7 +64,7 @@ def _invoke( return_documents=True, ) - rerank_documents = [] + rerank_documents: list[RerankDocument] = [] if not response.output: return RerankResult(model=model, docs=rerank_documents) for _, result in enumerate(response.output.results): diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 2ef7f3f5774481..8c53be413002a9 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -import dashscope +import dashscope # type: ignore import numpy as np from core.entities.embedding_type import EmbeddingInputType diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index ca3b9fbc1c3c00..a654e2d760d7c4 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -2,10 +2,10 @@ from queue import Queue from typing import Any, Optional -import dashscope -from dashscope import SpeechSynthesizer -from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse -from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult +import dashscope # type: ignore +from dashscope import SpeechSynthesizer # type: ignore +from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse # type: ignore +from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 47ebaccd84ab8a..f6609bba77129b 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonUpstage: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index a18ee906248a49..2bf6796ca5cf45 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -6,7 +6,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message 
import FunctionCall -from tokenizers import Tokenizer +from tokenizers import Tokenizer # type: ignore from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 5b340e53bbc543..87693eca768dfd 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -1,11 +1,10 @@ import base64 import time -from collections.abc import Mapping from typing import Union import numpy as np from openai import OpenAI -from tokenizers import Tokenizer +from tokenizers import Tokenizer # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType @@ -132,7 +131,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py index 8f7c859e3803c0..4e3df7574e9ce8 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/_common.py +++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py @@ -12,4 +12,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ - pass + raise NotImplementedError diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index c50e0f794616b3..85be34f3f0fe7f 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -6,7 +6,7 @@ from collections.abc import Generator from typing import TYPE_CHECKING, Optional, Union, cast -import google.auth.transport.requests +import google.auth.transport.requests # type: ignore import requests from anthropic import AnthropicVertex, Stream from anthropic.types import ( diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py index 034c066ab5f071..782e4fd6232a3b 100644 --- a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py @@ -17,14 +17,12 @@ class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - features = [] - entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - features=features, + features=[], model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), }, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index 1cffd902c7a25d..a8a015167e3227 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -1,8 +1,8 @@ from collections.abc import Generator from typing import 
Optional, cast -from volcenginesdkarkruntime import Ark -from volcenginesdkarkruntime.types.chat import ( +from volcenginesdkarkruntime import Ark # type: ignore +from volcenginesdkarkruntime.types.chat import ( # type: ignore ChatCompletion, ChatCompletionAssistantMessageParam, ChatCompletionChunk, @@ -15,10 +15,10 @@ ChatCompletionToolParam, ChatCompletionUserMessageParam, ) -from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL -from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function -from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse -from volcenginesdkarkruntime.types.shared_params import FunctionDefinition +from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL # type: ignore +from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function # type: ignore +from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse # type: ignore +from volcenginesdkarkruntime.types.shared_params import FunctionDefinition # type: ignore from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 91dbe21a616195..aa837b8318873d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -152,5 +152,6 @@ class ServiceNotOpenError(MaasError): def wrap_error(e: MaasError) -> Exception: if ErrorCodeMap.get(e.code): - return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) + # FIXME: mypy type error, try to fix it instead of using type: ignore + return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) # type: ignore return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 9e19b7dedaa5a7..f0b2b101b7be9d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -2,7 +2,7 @@ from collections.abc import Generator from typing import Optional -from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index cf3cf23cfb9cef..7c37368086e0e6 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMMode @@ -102,7 +104,7 @@ def get_model_config(credentials: dict) -> ModelConfig: def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): - req_params = {} + req_params: dict[str, Any] = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: @@ -130,7 
+132,7 @@ def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): - req_params = {} + req_params: dict[str, Any] = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 07b970f8104c8f..d2899795696aa4 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,7 +1,7 @@ from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Optional, Union from requests import Response, post @@ -22,7 +22,7 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: Optional[dict[str, int]] = None stop_reason: str = "" def to_dict(self) -> dict[str, Any]: @@ -135,6 +135,7 @@ def _build_function_calling_request_body( """ TODO: implement function calling """ + raise NotImplementedError("Function calling is not supported yet.") def _build_chat_request_body( self, diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index 19135deb27380d..816b3b98c4b8c5 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -1,6 +1,5 @@ import time from abc import abstractmethod -from collections.abc import Mapping from json import dumps from typing import Any, Optional @@ -23,12 +22,12 @@ class TextEmbedding: @abstractmethod - def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]: raise NotImplementedError class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): - def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]: access_token = self._get_access_token() url = f"{self.api_bases[model]}?access_token={access_token}" body = self._build_embed_request_body(model, texts, user) @@ -50,7 +49,7 @@ def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> } return body - def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): + def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]: data = response.json() if "error_code" in data: code = data["error_code"] @@ -147,7 +146,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: api_key = credentials["api_key"] secret_key = credentials["secret_key"] try: diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 8d86d6937d8ac9..7db1203641cad2 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ 
b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -17,7 +17,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion -from xinference_client.client.restful.restful_client import ( +from xinference_client.client.restful.restful_client import ( # type: ignore Client, RESTfulChatModelHandle, RESTfulGenerateModelHandle, diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index efaf114854b5c1..078ec0537a37f4 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -1,6 +1,6 @@ from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 3d7aefeb6dd89a..5f330ece1a5750 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -1,6 +1,6 @@ from typing import IO, Optional -from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index e51e6a941c5413..9054aabab2dd05 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -134,7 +134,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: handle = client.get_model(model_uid=model_uid) except RuntimeError as e: - raise InvokeAuthorizationError(e) + raise InvokeAuthorizationError(str(e)) if not isinstance(handle, RESTfulEmbeddingModelHandle): raise InvokeBadRequestError( diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index ad7b64efb5d2e7..8aa39d4de0d2cb 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -1,7 +1,7 @@ import concurrent.futures from typing import Any, Optional -from 
xinference_client.client.restful.restful_client import RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -74,11 +74,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") + api_key = credentials.get("api_key") + if api_key is None: + raise CredentialsValidateFailedError("api_key is required") extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials["server_url"], model_uid=credentials["model_uid"], - api_key=credentials.get("api_key"), + api_key=api_key, ) if "text-to-audio" not in extra_param.model_ability: diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index baa3ccbe8adbc0..b51423f4eda2e6 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,6 +1,6 @@ from threading import Lock from time import time -from typing import Optional +from typing import Any, Optional from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout @@ -39,13 +39,15 @@ def __init__( self.model_family = model_family -cache = {} +cache: dict[str, dict[str, Any]] = {} cache_lock = Lock() class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter( + server_url: str, model_uid: str, api_key: str | None + ) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: @@ -66,7 +68,9 @@ def _clean_cache() -> None: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter( + server_url: str, model_uid: str, api_key: str | None + ) -> XinferenceModelExtraParameter: """ get xinference model extra parameter like model_format and model_handle_type """ diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index 0642e72ed500e1..f5b61e207635bc 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -136,7 +136,7 @@ def _add_custom_parameters(credentials: dict) -> None: parsed_url = urlparse(credentials["endpoint_url"]) credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 59861507e45cd6..eef86cc52c36e8 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import 
Generator from typing import Optional, Union -from zhipuai import ZhipuAI -from zhipuai.types.chat.chat_completion import Completion -from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk +from zhipuai import ZhipuAI # type: ignore +from zhipuai.types.chat.chat_completion import Completion # type: ignore +from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 2428284ba9a8ff..a700304db7b6f3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index 029ec1a581b2e9..8cc8adfc3656ea 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Union, cast from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType @@ -38,7 +38,7 @@ def _validate_and_filter_credential_form_schemas( def _validate_credential_form_schema( self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Optional[str]: + ) -> Union[str, bool, None]: """ Validate credential form schema @@ -47,6 +47,7 @@ def _validate_credential_form_schema( :return: validated credential form schema value """ # If the variable does not exist in credentials + value: Union[str, bool, None] = None if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: @@ -61,7 +62,7 @@ def _validate_credential_form_schema( return None # Get the value corresponding to the variable from credentials - value = credentials[credential_form_schema.variable] + value = cast(str, credentials[credential_form_schema.variable]) # If max_length=0, no validation is performed if credential_form_schema.max_length: diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index ec1bad5698f2eb..03e350627140cf 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -129,7 +129,8 @@ def jsonable_encoder( sqlalchemy_safe=sqlalchemy_safe, ) if dataclasses.is_dataclass(obj): - obj_dict = dataclasses.asdict(obj) + # FIXME: mypy error, try to fix it instead of using type: ignore + obj_dict = dataclasses.asdict(obj) # type: ignore return jsonable_encoder( obj_dict, by_alias=by_alias, diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index 2067092d80f582..5e8a723ec7c510 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -4,6 +4,7 @@ def dump_model(model: BaseModel) -> dict: if hasattr(pydantic, 
"model_dump"): - return pydantic.model_dump(model) + # FIXME mypy error, try to fix it instead of using type: ignore + return pydantic.model_dump(model) # type: ignore else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 094ad7863603dc..c65a3885fd1eb9 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor @@ -43,6 +45,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) @@ -57,6 +61,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: params = ModerationOutputParams(app_id=self.app_id, text=text) @@ -69,14 +75,18 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: - extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) + if self.config is None: + raise ValueError("The config is not set.") + extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) + if not extension: + raise ValueError("API-based Extension not found. 
Please check it again.") requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key)) result = requestor.request(extension_point, params) return result @staticmethod - def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: extension = ( db.session.query(APIBasedExtension) .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 60898d5547ae3b..d8c392d0970e19 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -100,14 +100,14 @@ def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_re if not inputs_config.get("preset_response"): raise ValueError("inputs_config.preset_response is required") - if len(inputs_config.get("preset_response")) > 100: + if len(inputs_config.get("preset_response", 0)) > 100: raise ValueError("inputs_config.preset_response must be less than 100 characters") if outputs_config_enabled: if not outputs_config.get("preset_response"): raise ValueError("outputs_config.preset_response is required") - if len(outputs_config.get("preset_response")) > 100: + if len(outputs_config.get("preset_response", 0)) > 100: raise ValueError("outputs_config.preset_response must be less than 100 characters") diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 96bf2ab54b41eb..0ad4438c143870 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -22,7 +22,8 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) - extension_class.validate_config(tenant_id, config) + # FIXME: mypy error, try to fix it instead of using type: ignore + extension_class.validate_config(tenant_id, config) # type: ignore def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: """ diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 46d3963bd07f5a..3ac33966cb14bf 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,5 +1,6 @@ import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError @@ -17,11 +18,11 @@ def check( app_id: str, tenant_id: str, app_config: AppConfig, - inputs: dict, + inputs: Mapping[str, Any], query: str, message_id: str, trace_manager: Optional[TraceQueueManager] = None, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. 
:param app_id: app id @@ -33,6 +34,7 @@ def check( :param trace_manager: trace manager :return: """ + inputs = dict(inputs) if not app_config.sensitive_word_avoidance: return False, inputs, query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 00b3c56c03602d..9dd2665c3bf3d3 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -21,7 +21,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not config.get("keywords"): raise ValueError("keywords is required") - if len(config.get("keywords")) > 10000: + if len(config.get("keywords", "")) > 10000: raise ValueError("keywords length must be less than 10000") keywords_row_len = config["keywords"].split("\n") @@ -31,6 +31,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] @@ -50,6 +52,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: # Filter out empty values diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 6465de23b9a2de..d64f17b383e0b5 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -20,6 +20,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] @@ -35,6 +37,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: flagged = self._is_violated({"text": text}) diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 4635bd9c251851..e595be126c7824 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -70,7 +70,7 @@ def start_thread(self) -> threading.Thread: thread = threading.Thread( target=self.worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, }, ) diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 71ff03b6ef5160..f0e34c0cd71241 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -38,8 +39,8 @@ class 
WorkflowTraceInfo(BaseTraceInfo): workflow_run_id: str workflow_run_elapsed_time: Union[int, float] workflow_run_status: str - workflow_run_inputs: dict[str, Any] - workflow_run_outputs: dict[str, Any] + workflow_run_inputs: Mapping[str, Any] + workflow_run_outputs: Mapping[str, Any] workflow_run_version: str error: Optional[str] = None total_tokens: int diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 29fdebd8feaeb8..b9ba068b19936d 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -77,8 +77,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=trace_id, user_id=user_id, name=name, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), metadata=metadata, session_id=trace_info.conversation_id, tags=["message", "workflow"], @@ -87,8 +87,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): workflow_span_data = LangfuseSpan( id=trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), trace_id=trace_id, start_time=trace_info.start_time, end_time=trace_info.end_time, @@ -102,8 +102,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=trace_id, user_id=user_id, name=TraceTaskName.WORKFLOW_TRACE.value, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), metadata=metadata, session_id=trace_info.conversation_id, tags=["workflow"], diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 99221d669b3193..348b7ba5012b6b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -49,7 +49,6 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 672843e5a8f986..4ffd888bddf8a3 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -3,6 +3,7 @@ import os import uuid from datetime import datetime, timedelta +from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase @@ -63,6 +64,8 @@ def trace(self, trace_info: BaseTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo): trace_id = trace_info.message_id or trace_info.workflow_run_id + if trace_info.start_time is None: + trace_info.start_time = datetime.now() message_dotted_order = ( generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None ) @@ -78,8 +81,8 @@ def 
workflow_trace(self, trace_info: WorkflowTraceInfo): message_run = LangSmithRunModel( id=trace_info.message_id, name=TraceTaskName.MESSAGE_TRACE.value, - inputs=trace_info.workflow_run_inputs, - outputs=trace_info.workflow_run_outputs, + inputs=dict(trace_info.workflow_run_inputs), + outputs=dict(trace_info.workflow_run_outputs), run_type=LangSmithRunType.chain, start_time=trace_info.start_time, end_time=trace_info.end_time, @@ -90,6 +93,15 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): error=trace_info.error, trace_id=trace_id, dotted_order=message_dotted_order, + file_list=[], + serialized=None, + parent_run_id=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(message_run) @@ -98,11 +110,11 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): total_tokens=trace_info.total_tokens, id=trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, - inputs=trace_info.workflow_run_inputs, + inputs=dict(trace_info.workflow_run_inputs), run_type=LangSmithRunType.tool, start_time=trace_info.workflow_data.created_at, end_time=trace_info.workflow_data.finished_at, - outputs=trace_info.workflow_run_outputs, + outputs=dict(trace_info.workflow_run_outputs), extra={ "metadata": metadata, }, @@ -111,6 +123,13 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): parent_run_id=trace_info.message_id or None, trace_id=trace_id, dotted_order=workflow_dotted_order, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(langsmith_run) @@ -211,25 +230,35 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=node_execution_id, trace_id=trace_id, dotted_order=node_dotted_order, + error="", + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(langsmith_run) def message_trace(self, trace_info: MessageTraceInfo): # get message file data - file_list = trace_info.file_list - message_file_data: MessageFile = trace_info.message_file_data + file_list = cast(list[str], trace_info.file_list) or [] + message_file_data: Optional[MessageFile] = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) metadata = trace_info.metadata message_data = trace_info.message_data + if message_data is None: + return message_id = message_data.id user_id = message_data.from_account_id metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser = ( + end_user_data: Optional[EndUser] = ( db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -247,12 +276,20 @@ def message_trace(self, trace_info: MessageTraceInfo): start_time=trace_info.start_time, end_time=trace_info.end_time, outputs=message_data.answer, - extra={ - "metadata": metadata, - }, + extra={"metadata": metadata}, tags=["message", str(trace_info.conversation_mode)], error=trace_info.error, file_list=file_list, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + parent_run_id=None, ) self.add_run(message_run) @@ -267,17 +304,27 @@ def message_trace(self, trace_info: MessageTraceInfo): 
start_time=trace_info.start_time, end_time=trace_info.end_time, outputs=message_data.answer, - extra={ - "metadata": metadata, - }, + extra={"metadata": metadata}, parent_run_id=message_id, tags=["llm", str(trace_info.conversation_mode)], error=trace_info.error, file_list=file_list, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + id=str(uuid.uuid4()), ) self.add_run(llm_run) def moderation_trace(self, trace_info: ModerationTraceInfo): + if trace_info.message_data is None: + return langsmith_run = LangSmithRunModel( name=TraceTaskName.MODERATION_TRACE.value, inputs=trace_info.inputs, @@ -288,48 +335,82 @@ def moderation_trace(self, trace_info: ModerationTraceInfo): "inputs": trace_info.inputs, }, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["moderation"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(langsmith_run) def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_data = trace_info.message_data + if message_data is None: + return suggested_question_run = LangSmithRunModel( name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, inputs=trace_info.inputs, outputs=trace_info.suggested_question, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["suggested_question"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or message_data.created_at, end_time=trace_info.end_time or message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(suggested_question_run) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): + if trace_info.message_data is None: + return dataset_retrieval_run = LangSmithRunModel( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, run_type=LangSmithRunType.retriever, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["dataset_retrieval"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(dataset_retrieval_run) @@ -347,7 +428,18 @@ def tool_trace(self, trace_info: ToolTraceInfo): parent_run_id=trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, - file_list=[trace_info.file_url], + file_list=[cast(str, trace_info.file_url)], + id=str(uuid.uuid4()), + serialized=None, + events=[], + 
session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error=trace_info.error or "", ) self.add_run(tool_run) @@ -358,12 +450,23 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo): inputs=trace_info.inputs, outputs=trace_info.outputs, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["generate_name"], start_time=trace_info.start_time or datetime.now(), end_time=trace_info.end_time or datetime.now(), + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], + parent_run_id=None, ) self.add_run(name_run) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 4f41b6ed97047f..f538eaef5bd570 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -33,11 +33,11 @@ from core.ops.utils import get_message_data from extensions.ext_database import db from extensions.ext_storage import storage -from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig +from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog, WorkflowRun from tasks.ops_trace_task import process_trace_tasks -provider_config_map = { +provider_config_map: dict[str, dict[str, Any]] = { TracingProviderEnum.LANGFUSE.value: { "config_class": LangfuseConfig, "secret_keys": ["public_key", "secret_key"], @@ -145,7 +145,7 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -155,7 +155,11 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): return None # decrypt_token - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + app = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError("App not found") + + tenant_id = app.tenant_id decrypt_tracing_config = cls.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) @@ -178,7 +182,7 @@ def get_ops_trace_instance( if app_id is None: return None - app: App = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() if app is None: return None @@ -209,8 +213,12 @@ def get_ops_trace_instance( def get_app_config_through_message_id(cls, message_id: str): app_model_config = None message_data = db.session.query(Message).filter(Message.id == message_id).first() + if not message_data: + return None conversation_id = message_data.conversation_id conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + if not conversation_data: + return None if conversation_data.app_model_config_id: app_model_config = ( @@ -236,7 +244,9 @@ def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: if tracing_provider not in provider_config_map and tracing_provider is not None: raise 
ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App = db.session.query(App).filter(App.id == app_id).first() + app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + if not app_config: + raise ValueError("App not found") app_config.tracing = json.dumps( { "enabled": enabled, @@ -252,7 +262,9 @@ def get_app_tracing_config(cls, app_id: str): :param app_id: app id :return: """ - app: App = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError("App not found") if not app.tracing: return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) @@ -483,6 +495,8 @@ def message_trace(self, message_id): def moderation_trace(self, message_id, timer, **kwargs): moderation_result = kwargs.get("moderation_result") + if not moderation_result: + return {} inputs = kwargs.get("inputs") message_data = get_message_data(message_id) if not message_data: @@ -518,7 +532,7 @@ def moderation_trace(self, message_id, timer, **kwargs): return moderation_trace_info def suggested_question_trace(self, message_id, timer, **kwargs): - suggested_question = kwargs.get("suggested_question") + suggested_question = kwargs.get("suggested_question", []) message_data = get_message_data(message_id) if not message_data: return {} @@ -586,7 +600,7 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents], + documents=[doc.model_dump() for doc in documents] if documents else [], start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -596,9 +610,9 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get("tool_name") - tool_inputs = kwargs.get("tool_inputs") - tool_outputs = kwargs.get("tool_outputs") + tool_name = kwargs.get("tool_name", "") + tool_inputs = kwargs.get("tool_inputs", {}) + tool_outputs = kwargs.get("tool_outputs", {}) message_data = get_message_data(message_id) if not message_data: return {} @@ -608,7 +622,7 @@ def tool_trace(self, message_id, timer, **kwargs): tool_parameters = {} created_time = message_data.created_at end_time = message_data.updated_at - agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts + agent_thoughts = message_data.agent_thoughts for agent_thought in agent_thoughts: if tool_name in agent_thought.tools: created_time = agent_thought.created_at @@ -672,6 +686,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs): generate_conversation_name = kwargs.get("generate_conversation_name") inputs = kwargs.get("inputs") tenant_id = kwargs.get("tenant_id") + if not tenant_id: + return {} start_time = timer.get("start") end_time = timer.get("end") @@ -693,8 +709,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs): return generate_name_trace_info -trace_manager_timer = None -trace_manager_queue = queue.Queue() +trace_manager_timer: Optional[threading.Timer] = None +trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) @@ -706,7 +722,7 @@ def __init__(self, app_id=None, 
user_id=None): self.app_id = app_id self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) - self.flask_app = current_app._get_current_object() + self.flask_app = current_app._get_current_object() # type: ignore if trace_manager_timer is None: self.start_timer() @@ -723,7 +739,7 @@ def add_trace_task(self, trace_task: TraceTask): def collect_tasks(self): global trace_manager_queue - tasks = [] + tasks: list[TraceTask] = [] while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty(): task = trace_manager_queue.get_nowait() tasks.append(task) @@ -749,6 +765,8 @@ def start_timer(self): def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: + if task.app_id is None: + continue file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0f3f8249661bf0..87c7a79fb01201 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,5 +1,5 @@ -from collections.abc import Sequence -from typing import Optional +from collections.abc import Mapping, Sequence +from typing import Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import file_manager @@ -39,7 +39,7 @@ def get_prompt( self, *, prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate, - inputs: dict[str, str], + inputs: Mapping[str, str], query: str, files: Sequence[File], context: Optional[str], @@ -77,7 +77,7 @@ def get_prompt( def _get_completion_model_prompt_messages( self, prompt_template: CompletionModelPromptTemplate, - inputs: dict, + inputs: Mapping[str, str], query: Optional[str], files: Sequence[File], context: Optional[str], @@ -90,15 +90,15 @@ def _get_completion_model_prompt_messages( """ raw_prompt = prompt_template.text - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] if prompt_template.edition_type == "basic" or not prompt_template.edition_type: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) - if memory and memory_config: + if memory and memory_config and memory_config.role_prefix: role_prefix = memory_config.role_prefix prompt_inputs = self._set_histories_variable( memory=memory, @@ -135,7 +135,7 @@ def _get_completion_model_prompt_messages( def _get_chat_model_prompt_messages( self, prompt_template: list[ChatModelMessage], - inputs: dict, + inputs: Mapping[str, str], query: Optional[str], files: Sequence[File], context: Optional[str], @@ -146,7 +146,7 @@ def _get_chat_model_prompt_messages( """ Get chat model prompt messages. 
""" - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] for prompt_item in prompt_template: raw_prompt = prompt_item.text @@ -160,7 +160,7 @@ def _get_chat_model_prompt_messages( prompt = vp.convert_template(raw_prompt).text else: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs = self._set_context_variable( context=context, parser=parser, prompt_inputs=prompt_inputs ) @@ -207,7 +207,7 @@ def _get_chat_model_prompt_messages( last_message = prompt_messages[-1] if prompt_messages else None if last_message and last_message.role == PromptMessageRole.USER: # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))] for file in files: prompt_message_contents.append(file_manager.to_prompt_message_content(file)) @@ -229,7 +229,10 @@ def _get_chat_model_prompt_messages( return prompt_messages - def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + def _set_context_variable( + self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str] + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#context#" in parser.variable_keys: if context: prompt_inputs["#context#"] = context @@ -238,7 +241,10 @@ def _set_context_variable(self, context: str | None, parser: PromptTemplateParse return prompt_inputs - def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + def _set_query_variable( + self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str] + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#query#" in parser.variable_keys: if query: prompt_inputs["#query#"] = query @@ -254,9 +260,10 @@ def _set_histories_variable( raw_prompt: str, role_prefix: MemoryConfig.RolePrefix, parser: PromptTemplateParser, - prompt_inputs: dict, + prompt_inputs: Mapping[str, str], model_config: ModelConfigWithCredentialsEntity, - ) -> dict: + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#histories#" in parser.variable_keys: if memory: inputs = {"#histories#": "", **prompt_inputs} diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index caa1793ea8c039..09f017a7db0d3a 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -31,7 +31,7 @@ def __init__( self.memory = memory def get_prompt(self) -> list[PromptMessage]: - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] num_system = 0 for prompt_message in self.history_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 87acdb3c49cc01..1f040599be6dac 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -42,7 +42,7 @@ def _calculate_rest_token( ): max_tokens = ( 
model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -59,7 +59,7 @@ def _get_history_messages_from_memory( ai_prefix: Optional[str] = None, ) -> str: """Get memory messages.""" - kwargs = {"max_token_limit": max_token_limit} + kwargs: dict[str, Any] = {"max_token_limit": max_token_limit} if human_prefix: kwargs["human_prefix"] = human_prefix @@ -76,11 +76,15 @@ def _get_history_messages_list_from_memory( self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int ) -> list[PromptMessage]: """Get memory messages.""" - return memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=memory_config.window.size - if ( - memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + return list( + memory.get_history_prompt_messages( + max_token_limit=max_token_limit, + message_limit=memory_config.window.size + if ( + memory_config.window.enabled + and memory_config.window.size is not None + and memory_config.window.size > 0 + ) + else None, ) - else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 93dd92f188a9c6..e75877de9b695c 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,7 +1,8 @@ import enum import json import os -from typing import TYPE_CHECKING, Optional +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -41,7 +42,7 @@ def value_of(cls, value: str) -> "ModelMode": raise ValueError(f"invalid mode value {value}") -prompt_file_contents = {} +prompt_file_contents: dict[str, Any] = {} class SimplePromptTransform(PromptTransform): @@ -53,9 +54,9 @@ def get_prompt( self, app_mode: AppMode, prompt_template_entity: PromptTemplateEntity, - inputs: dict, + inputs: Mapping[str, str], query: str, - files: list["File"], + files: Sequence["File"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, @@ -66,7 +67,7 @@ def get_prompt( if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( app_mode=app_mode, - pre_prompt=prompt_template_entity.simple_prompt_template, + pre_prompt=prompt_template_entity.simple_prompt_template or "", inputs=inputs, query=query, files=files, @@ -77,7 +78,7 @@ def get_prompt( else: prompt_messages, stops = self._get_completion_model_prompt_messages( app_mode=app_mode, - pre_prompt=prompt_template_entity.simple_prompt_template, + pre_prompt=prompt_template_entity.simple_prompt_template or "", inputs=inputs, query=query, files=files, @@ -171,11 +172,11 @@ def _get_chat_model_prompt_messages( inputs: dict, query: str, context: Optional[str], - files: list["File"], + files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] # get prompt prompt, _ = self.get_prompt_str_and_rules( @@ -216,7 +217,7 @@ def _get_completion_model_prompt_messages( inputs: dict, query: str, context: Optional[str], - files: 
list["File"], + files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -263,7 +264,7 @@ def _get_completion_model_prompt_messages( return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=prompt)) @@ -288,7 +289,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: - return prompt_file_contents[prompt_file_name] + return cast(dict, prompt_file_contents[prompt_file_name]) # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") @@ -301,7 +302,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content - return content + return cast(dict, content) def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index aa175153bc633f..2f4e65146131be 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import cast +from typing import Any, cast from core.model_runtime.entities import ( AssistantPromptMessage, @@ -72,7 +72,7 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque } ) else: - text = prompt_message.content + text = cast(str, prompt_message.content) prompt = {"role": role, "text": text, "files": files} @@ -99,9 +99,9 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque } ) else: - text = prompt_message.content + text = cast(str, prompt_message.content) - params = { + params: dict[str, Any] = { "role": "user", "text": text, } diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 0fd08c5d3c1a3e..8e40674bc193e0 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}") WITH_VARIABLE_TMPL_REGEX = re.compile( @@ -28,7 +29,7 @@ def extract(self) -> list: # Regular expression to match the template rules return re.findall(self.regex, self.template) - def format(self, inputs: dict, remove_template_variables: bool = True) -> str: + def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str: def replacer(match): key = match.group(1) value = inputs.get(key, match.group(0)) # return original matched string if key not found diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3a1fe300dfd311..010abd12d275cd 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,7 +1,7 @@ import json from collections import defaultdict from json import JSONDecodeError -from typing import Optional +from typing import Optional, 
cast

from sqlalchemy.exc import IntegrityError @@ -15,6 +15,7 @@ ModelLoadBalancingConfiguration, ModelSettings, QuotaConfiguration, + QuotaUnit, SystemConfiguration, ) from core.helper import encrypter @@ -116,8 +117,8 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: for provider_entity in provider_entities: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET), data=provider_entity, name_func=lambda x: x.provider, ): @@ -490,12 +491,13 @@ def _init_trial_provider_records( # Init trial provider records if not exists if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: + # FIXME: ignore the type error; only TrialHostingQuota has a limit, need to change the logic provider_record = Provider( tenant_id=tenant_id, provider_name=provider_name, provider_type=ProviderType.SYSTEM.value, quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=quota.quota_limit, + quota_limit=quota.quota_limit, # type: ignore quota_used=0, is_valid=True, ) @@ -589,7 +591,9 @@ def _to_custom_configuration( if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa + provider_credentials.get(variable) or "", # type: ignore + self.decoding_rsa_key, + self.decoding_cipher_rsa, ) except ValueError: pass @@ -671,13 +675,9 @@ def _to_system_configuration( # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if ( - provider_entity.provider not in hosting_configuration.provider_map - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled - ): - return SystemConfiguration(enabled=False) - provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) + if provider_hosting_configuration is None or not provider_hosting_configuration.enabled: + return SystemConfiguration(enabled=False) # Convert provider_records to dict quota_type_to_provider_records_dict = {} @@ -688,14 +688,13 @@ def _to_system_configuration( quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( provider_record ) - quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: if provider_quota.quota_type not in quota_type_to_provider_records_dict: if provider_quota.quota_type == ProviderQuotaType.FREE: quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_used=0, quota_limit=0, is_valid=False, ) @@ -708,7 +707,7 @@ quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, is_valid=provider_record.quota_limit > provider_record.quota_used @@ -725,12 +724,12 @@ current_using_credentials = provider_hosting_configuration.credentials if current_quota_type == ProviderQuotaType.FREE: - provider_record = 
quota_type_to_provider_records_dict.get(current_quota_type) + provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type) - if provider_record: + if provider_record_quota_free: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, - identity_id=provider_record.id, + identity_id=provider_record_quota_free.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) @@ -763,7 +762,7 @@ def _to_system_configuration( except ValueError: pass - current_using_credentials = provider_credentials + current_using_credentials = provider_credentials or {} # cache provider credentials provider_credentials_cache.set(credentials=current_using_credentials) @@ -842,7 +841,7 @@ def _to_model_settings( else [] ) - model_settings = [] + model_settings: list[ModelSettings] = [] if not provider_model_settings: return model_settings diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a0153c1e58a1a8..95a2316f1da4dd 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -32,8 +32,11 @@ def create(self, texts: list[Document], **kwargs) -> BaseKeyword: keywords = keyword_table_handler.extract_keywords( text.page_content, self._config.max_keywords_per_chunk ) - self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) + if text.metadata is not None: + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, text.metadata["doc_id"], list(keywords) + ) self._save_dataset_keyword_table(keyword_table) @@ -58,20 +61,26 @@ def add_texts(self, texts: list[Document], **kwargs): keywords = keyword_table_handler.extract_keywords( text.page_content, self._config.max_keywords_per_chunk ) - self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) + if text.metadata is not None: + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, text.metadata["doc_id"], list(keywords) + ) self._save_dataset_keyword_table(keyword_table) def text_exists(self, id: str) -> bool: keyword_table = self._get_dataset_keyword_table() + if keyword_table is None: + return False return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() - keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) + if keyword_table is not None: + keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) @@ -80,7 +89,7 @@ def search(self, query: str, **kwargs: Any) -> list[Document]: k = kwargs.get("top_k", 4) - sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) + sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) documents = [] for chunk_index in sorted_chunk_indices: @@ -137,7 +146,7 @@ def _get_dataset_keyword_table(self) -> Optional[dict]: if dataset_keyword_table: keyword_table_dict = 
dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict["__data__"]["table"] + return dict(keyword_table_dict["__data__"]["table"]) else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( @@ -188,8 +197,8 @@ def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): # go through text chunks in order of most matching keywords chunk_indices_count: dict[str, int] = defaultdict(int) - keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] - for keyword in keywords: + keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] + for keyword in keywords_list: for node_id in keyword_table[keyword]: chunk_indices_count[node_id] += 1 @@ -215,7 +224,7 @@ def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list def create_segment_keywords(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() self._update_segment_keywords(self.dataset.id, node_id, keywords) - keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) def multi_create_segment_keywords(self, pre_segment_data_list: list): @@ -226,17 +235,19 @@ def multi_create_segment_keywords(self, pre_segment_data_list: list): if pre_segment_data["keywords"]: segment.keywords = pre_segment_data["keywords"] keyword_table = self._add_text_to_keyword_table( - keyword_table, segment.index_node_id, pre_segment_data["keywords"] + keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] ) else: keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, segment.index_node_id, list(keywords) + ) self._save_dataset_keyword_table(keyword_table) def update_segment_keywords_index(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() - keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index ec809cf325306e..8b17e8dc0a3762 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -4,7 +4,7 @@ class JiebaKeywordTableHandler: def __init__(self): - import jieba.analyse + import jieba.analyse # type: ignore from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS @@ -12,7 +12,7 @@ def __init__(self): def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" - import jieba + import jieba # type: ignore keywords = jieba.analyse.extract_tags( sentence=text, diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index be00687abd5025..b261b40b728692 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ 
b/api/core/rag/datasource/keyword/keyword_base.py @@ -37,6 +37,8 @@ def search(self, query: str, **kwargs: Any) -> list[Document]: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: @@ -45,4 +47,4 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata["doc_id"] for text in texts] + return [text.metadata["doc_id"] for text in texts if text.metadata] diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 18f8d4e8392302..34343ad60ea4c1 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,6 +6,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -31,7 +32,7 @@ def retrieve( top_k: int, score_threshold: Optional[float] = 0.0, reranking_model: Optional[dict] = None, - reranking_mode: Optional[str] = "reranking_model", + reranking_mode: str = "reranking_model", weights: Optional[dict] = None, ): if not query: @@ -42,15 +43,15 @@ def retrieve( if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] - all_documents = [] - threads = [] - exceptions = [] + all_documents: list[Document] = [] + threads: list[threading.Thread] = [] + exceptions: list[str] = [] # retrieval_model source with keyword if retrieval_method == "keyword_search": keyword_thread = threading.Thread( target=RetrievalService.keyword_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "top_k": top_k, @@ -65,7 +66,7 @@ def retrieve( embedding_thread = threading.Thread( target=RetrievalService.embedding_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "top_k": top_k, @@ -84,7 +85,7 @@ def retrieve( full_text_index_thread = threading.Thread( target=RetrievalService.full_text_index_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "retrieval_method": retrieval_method, @@ -124,7 +125,7 @@ def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model if not dataset: return [] all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - dataset.tenant_id, dataset_id, query, external_retrieval_model + dataset.tenant_id, dataset_id, query, external_retrieval_model or {} ) return all_documents @@ -135,6 +136,8 @@ def keyword_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") keyword = Keyword(dataset=dataset) @@ -159,6 +162,8 @@ def embedding_search( with flask_app.app_context(): try: dataset 
= db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") vector = Vector(dataset=dataset) @@ -209,6 +214,8 @@ def full_text_index_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") vector_processor = Vector( dataset=dataset, diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 09104ae4223443..603d3fdbcdf1ab 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -17,12 +17,19 @@ class AnalyticdbVector(BaseVector): def __init__( - self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig + self, + collection_name: str, + api_config: AnalyticdbVectorOpenAPIConfig | None, + sql_config: AnalyticdbVectorBySqlConfig | None, ): super().__init__(collection_name) if api_config is not None: - self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config) + self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI( + collection_name, api_config + ) else: + if sql_config is None: + raise ValueError("Either api_config or sql_config must be provided") self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config) def get_type(self) -> str: @@ -33,8 +40,8 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.analyticdb_vector._create_collection_if_not_exists(dimension) self.analyticdb_vector.add_texts(texts, embeddings) - def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - self.analyticdb_vector.add_texts(texts, embeddings) + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + self.analyticdb_vector.add_texts(documents, embeddings) def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) @@ -68,13 +75,13 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings if dify_config.ANALYTICDB_HOST is None: # implemented through OpenAPI apiConfig = AnalyticdbVectorOpenAPIConfig( - access_key_id=dify_config.ANALYTICDB_KEY_ID, - access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, - region_id=dify_config.ANALYTICDB_REGION_ID, - instance_id=dify_config.ANALYTICDB_INSTANCE_ID, - account=dify_config.ANALYTICDB_ACCOUNT, - account_password=dify_config.ANALYTICDB_PASSWORD, - namespace=dify_config.ANALYTICDB_NAMESPACE, + access_key_id=dify_config.ANALYTICDB_KEY_ID or "", + access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "", + region_id=dify_config.ANALYTICDB_REGION_ID or "", + instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "", + account=dify_config.ANALYTICDB_ACCOUNT or "", + account_password=dify_config.ANALYTICDB_PASSWORD or "", + namespace=dify_config.ANALYTICDB_NAMESPACE or "", namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, ) sqlConfig = None @@ -83,11 +90,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings sqlConfig = AnalyticdbVectorBySqlConfig( host=dify_config.ANALYTICDB_HOST, port=dify_config.ANALYTICDB_PORT, - account=dify_config.ANALYTICDB_ACCOUNT, - account_password=dify_config.ANALYTICDB_PASSWORD, + account=dify_config.ANALYTICDB_ACCOUNT or "", + account_password=dify_config.ANALYTICDB_PASSWORD or "", 
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION, max_connection=dify_config.ANALYTICDB_MAX_CONNECTION, - namespace=dify_config.ANALYTICDB_NAMESPACE, + namespace=dify_config.ANALYTICDB_NAMESPACE or "", ) apiConfig = None return AnalyticdbVector( diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 05e0ebc54f7c4c..095752ea8eaa42 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, model_validator @@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): account: str account_password: str namespace: str = "dify" - namespace_password: str = (None,) + namespace_password: Optional[str] = None metrics: str = "cosine" read_timeout: int = 60000 @@ -55,8 +55,8 @@ def to_analyticdb_client_params(self): class AnalyticdbVectorOpenAPI: def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig): try: - from alibabacloud_gpdb20160503.client import Client - from alibabacloud_tea_openapi import models as open_api_models + from alibabacloud_gpdb20160503.client import Client # type: ignore + from alibabacloud_tea_openapi import models as open_api_models # type: ignore except: raise ImportError(_import_err_msg) self._collection_name = collection_name.lower() @@ -77,7 +77,7 @@ def _initialize(self) -> None: redis_client.set(database_exist_cache_key, 1, ex=3600) def _initialize_vector_database(self) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, @@ -89,7 +89,7 @@ def _initialize_vector_database(self) -> None: def _create_namespace_if_not_exists(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - from Tea.exceptions import TeaException + from Tea.exceptions import TeaException # type: ignore try: request = gpdb_20160503_models.DescribeNamespaceRequest( @@ -159,17 +159,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): - metadata = { - "ref_doc_id": doc.metadata["doc_id"], - "page_content": doc.page_content, - "metadata_": json.dumps(doc.metadata), - } - rows.append( - gpdb_20160503_models.UpsertCollectionDataRequestRows( - vector=embedding, - metadata=metadata, + if doc.metadata is not None: + metadata = { + "ref_doc_id": doc.metadata["doc_id"], + "page_content": doc.page_content, + "metadata_": json.dumps(doc.metadata), + } + rows.append( + gpdb_20160503_models.UpsertCollectionDataRequestRows( + vector=embedding, + metadata=metadata, + ) ) - ) request = gpdb_20160503_models.UpsertCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -258,7 +259,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc metadata=metadata, ) documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents def search_by_full_text(self, query: 
str, **kwargs: Any) -> list[Document]: @@ -290,7 +291,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: metadata=metadata, ) documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index e474db5cb21971..4d8f7929413cf2 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -3,8 +3,8 @@ from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from core.rag.models.document import Document @@ -75,6 +75,7 @@ def _create_connection_pool(self): @contextmanager def _get_cursor(self): + assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() try: @@ -156,16 +157,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s)); """ for i, doc in enumerate(documents): - values.append( - ( - id_prefix + str(i), - doc.metadata.get("doc_id", str(uuid.uuid4())), - embeddings[i], - doc.page_content, - json.dumps(doc.metadata), - doc.page_content, + if doc.metadata is not None: + values.append( + ( + id_prefix + str(i), + doc.metadata.get("doc_id", str(uuid.uuid4())), + embeddings[i], + doc.page_content, + json.dumps(doc.metadata), + doc.page_content, + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_batch(cur, sql, values) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index eb78e8aa698b9b..85596ad20e099a 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -5,13 +5,13 @@ import numpy as np from pydantic import BaseModel, model_validator -from pymochow import MochowClient -from pymochow.auth.bce_credentials import BceCredentials -from pymochow.configuration import Configuration -from pymochow.exception import ServerError -from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState -from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex -from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row +from pymochow import MochowClient # type: ignore +from pymochow.auth.bce_credentials import BceCredentials # type: ignore +from pymochow.configuration import Configuration # type: ignore +from pymochow.exception import ServerError # type: ignore +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore +from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore +from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -75,7 +75,7 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): texts = 
[doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] + metadatas = [doc.metadata for doc in documents if doc.metadata is not None] total_count = len(documents) batch_size = 1000 @@ -84,6 +84,8 @@ for start in range(0, total_count, batch_size): end = min(start + batch_size, total_count) rows = [] + assert len(metadatas) == total_count, "metadatas length should be equal to total_count" + # FIXME: is this assert necessary? for i in range(start, end, 1): row = Row( id=metadatas[i].get("doc_id", str(uuid.uuid4())), @@ -136,7 +138,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # baidu vector database doesn't support bm25 search on current version return [] - def _get_search_res(self, res, score_threshold): + def _get_search_res(self, res, score_threshold) -> list[Document]: docs = [] for row in res.rows: row_data = row.get("row", {}) @@ -276,11 +278,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return BaiduVector( collection_name=collection_name, config=BaiduConfig( - endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT, + endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "", connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS, - account=dify_config.BAIDU_VECTOR_DB_ACCOUNT, - api_key=dify_config.BAIDU_VECTOR_DB_API_KEY, - database=dify_config.BAIDU_VECTOR_DB_DATABASE, + account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "", + api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "", + database=dify_config.BAIDU_VECTOR_DB_DATABASE or "", shard=dify_config.BAIDU_VECTOR_DB_SHARD, replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, ), diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index a9e1486edd25f1..0eab01b507dc94 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -71,11 +71,13 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [d.metadata for d in documents] collection = self._client.get_or_create_collection(self._collection_name) - collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) + # FIXME: chromadb uses numpy arrays; fix the type error later + collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {"$eq": value}}) + # FIXME: fix the type error later + collection.delete(where={key: {"$eq": value}}) # type: ignore def delete(self): self._client.delete_collection(self._collection_name) @@ -94,15 +96,19 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) - ids: list[str] = results["ids"][0] - documents: list[str] = results["documents"][0] - metadatas: dict[str, Any] = results["metadatas"][0] - distances: list[float] = results["distances"][0] + # Check if results contain data + if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]: + return [] + + ids = results["ids"][0] + documents = results["documents"][0] + metadatas = 
results["metadatas"][0] + distances = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] - metadata = metadatas[index] + metadata = dict(metadatas[index]) if distance >= score_threshold: metadata["score"] = distance doc = Document( @@ -111,7 +117,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -133,7 +139,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return ChromaVector( collection_name=collection_name, config=ChromaConfig( - host=dify_config.CHROMA_HOST, + host=dify_config.CHROMA_HOST or "", port=dify_config.CHROMA_PORT, tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index d26726e86438bd..68a9952789e5b6 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -5,14 +5,14 @@ from datetime import timedelta from typing import Any -from couchbase import search -from couchbase.auth import PasswordAuthenticator -from couchbase.cluster import Cluster -from couchbase.management.search import SearchIndex +from couchbase import search # type: ignore +from couchbase.auth import PasswordAuthenticator # type: ignore +from couchbase.cluster import Cluster # type: ignore +from couchbase.management.search import SearchIndex # type: ignore # needed for options -- cluster, timeout, SQL++ (N1QL) query, etc. 
-from couchbase.options import ClusterOptions, SearchOptions -from couchbase.vector_search import VectorQuery, VectorSearch +from couchbase.options import ClusterOptions, SearchOptions # type: ignore +from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore from flask import current_app from pydantic import BaseModel, model_validator @@ -231,7 +231,7 @@ def text_exists(self, id: str) -> bool: # Pass the id as a parameter to the query result = self._cluster.query(query, named_parameters={"doc_id": id}).execute() for row in result: - return row["count"] > 0 + return bool(row["count"] > 0) return False # Return False if no rows are returned def delete_by_ids(self, ids: list[str]) -> None: @@ -369,10 +369,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return CouchbaseVector( collection_name=collection_name, config=CouchbaseConfig( - connection_string=config.get("COUCHBASE_CONNECTION_STRING"), - user=config.get("COUCHBASE_USER"), - password=config.get("COUCHBASE_PASSWORD"), - bucket_name=config.get("COUCHBASE_BUCKET_NAME"), - scope_name=config.get("COUCHBASE_SCOPE_NAME"), + connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""), + user=config.get("COUCHBASE_USER", ""), + password=config.get("COUCHBASE_PASSWORD", ""), + bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""), + scope_name=config.get("COUCHBASE_SCOPE_NAME", ""), ), ) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index b08811a02181d2..8661828dc2aa52 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional +from typing import Any, Optional, cast from urllib.parse import urlparse import requests @@ -70,7 +70,7 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: def _get_version(self) -> str: info = self._client.info() - return info["version"]["number"] + return cast(str, info["version"]["number"]) def _check_version(self): if self._version < "8.0.0": @@ -135,7 +135,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for doc, score in docs_and_scores: score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: - doc.metadata["score"] = score + if doc.metadata is not None: + doc.metadata["score"] = score docs.append(doc) return docs @@ -156,12 +157,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return docs def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, + embeddings: list[list[float]], + metadatas: Optional[list[dict[Any, Any]]] = None, + index_params: Optional[dict] = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): @@ -208,10 +212,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return ElasticSearchVector( index_name=collection_name, config=ElasticSearchConfig( - host=config.get("ELASTICSEARCH_HOST"), 
- port=config.get("ELASTICSEARCH_PORT"), - username=config.get("ELASTICSEARCH_USERNAME"), - password=config.get("ELASTICSEARCH_PASSWORD"), + host=config.get("ELASTICSEARCH_HOST", "localhost"), + port=config.get("ELASTICSEARCH_PORT", 9200), + username=config.get("ELASTICSEARCH_USERNAME", ""), + password=config.get("ELASTICSEARCH_PASSWORD", ""), ), attributes=[], ) diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 8646e52cf493ca..d7a14207e9375a 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -42,7 +42,7 @@ def validate_config(cls, values: dict) -> dict: return values def to_opensearch_params(self) -> dict[str, Any]: - params = {"hosts": self.hosts} + params: dict[str, Any] = {"hosts": self.hosts} if self.username and self.password: params["http_auth"] = (self.username, self.password) return params @@ -53,7 +53,7 @@ def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using self._routing = None self._routing_field = None if using_ugc: - routing_value: str = kwargs.get("routing_value") + routing_value: str | None = kwargs.get("routing_value") if routing_value is None: raise ValueError("UGC index should init vector with valid 'routing_value' parameter value") self._routing = routing_value.lower() @@ -87,14 +87,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** "_id": uuids[i], } } - action_values = { + action_values: dict[str, Any] = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, } if self._using_ugc: action_header["index"]["routing"] = self._routing - action_values[self._routing_field] = self._routing + if self._routing_field is not None: + action_values[self._routing_field] = self._routing actions.append(action_header) actions.append(action_values) response = self._client.bulk(actions) @@ -105,7 +106,9 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** self.refresh() def get_ids_by_metadata_field(self, key: str, value: str): - query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}} + query: dict[str, Any] = { + "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}} + } if self._using_ugc: query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}}) response = self._client.search(index=self._collection_name, body=query) @@ -191,7 +194,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for doc, score in docs_and_scores: score_threshold = kwargs.get("score_threshold", 0.0) or 0.0 if score > score_threshold: - doc.metadata["score"] = score + if doc.metadata is not None: + doc.metadata["score"] = score docs.append(doc) return docs @@ -366,6 +370,7 @@ def default_text_search_query( routing_field: Optional[str] = None, **kwargs, ) -> dict: + query_clause: dict[str, Any] = {} if routing is not None: query_clause = { "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]} @@ -386,7 +391,7 @@ def default_text_search_query( else: must = [query_clause] - boolean_query = {"must": must} + boolean_query: dict[str, Any] = {"must": must} if must_not: if not isinstance(must_not, list): @@ -426,7 +431,7 @@ def 
default_vector_search_query( filter_type = "post_filter" if filter_type is None else filter_type if not isinstance(filters, list): raise RuntimeError(f"unexpected filter with {type(filters)}") - final_ext = {"lvector": {}} + final_ext: dict[str, Any] = {"lvector": {}} if min_score != "0.0": final_ext["lvector"]["min_score"] = min_score if ef_search: @@ -438,7 +443,7 @@ def default_vector_search_query( if client_refactor: final_ext["lvector"]["client_refactor"] = client_refactor - search_query = { + search_query: dict[str, Any] = { "size": k, "_source": True, # force return '_source' "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}}, @@ -446,8 +451,8 @@ def default_vector_search_query( if filters is not None: # when using filter, transform filter from List[Dict] to Dict as valid format - filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] - search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict + filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] + search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict if filter_type: final_ext["lvector"]["filter_type"] = filter_type @@ -459,17 +464,19 @@ def default_vector_search_query( class LindormVectorStoreFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: lindorm_config = LindormVectorStoreConfig( - hosts=dify_config.LINDORM_URL, + hosts=dify_config.LINDORM_URL or "", username=dify_config.LINDORM_USERNAME, password=dify_config.LINDORM_PASSWORD, using_ugc=dify_config.USING_UGC_INDEX, ) using_ugc = dify_config.USING_UGC_INDEX + if using_ugc is None: + raise ValueError("USING_UGC_INDEX is not set") routing_value = None if dataset.index_struct: # if an existed record's index_struct_dict doesn't contain using_ugc field, # it actually stores in the normal index format - stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False) + stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False) using_ugc = stored_in_ugc if stored_in_ugc: dimension = dataset.index_struct_dict["dimension"] diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 5a263d6e78c3bd..9b029ffc193cc0 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -3,8 +3,8 @@ from typing import Any, Optional from pydantic import BaseModel, model_validator -from pymilvus import MilvusClient, MilvusException -from pymilvus.milvus_client import IndexParams +from pymilvus import MilvusClient, MilvusException # type: ignore +from pymilvus.milvus_client import IndexParams # type: ignore from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -54,14 +54,14 @@ def __init__(self, collection_name: str, config: MilvusConfig): self._client_config = config self._client = self._init_client(config) self._consistency_level = "Session" - self._fields = [] + self._fields: list[str] = [] def get_type(self) -> str: return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, 
embeddings) @@ -161,8 +161,8 @@ def create_collection( return # Grab the existing collection if it exists if not self._client.has_collection(self._collection_name): - from pymilvus import CollectionSchema, DataType, FieldSchema - from pymilvus.orm.types import infer_dtype_bydata + from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore + from pymilvus.orm.types import infer_dtype_bydata # type: ignore # Determine embedding dim dim = len(embeddings[0]) @@ -217,10 +217,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return MilvusVector( collection_name=collection_name, config=MilvusConfig( - uri=dify_config.MILVUS_URI, - token=dify_config.MILVUS_TOKEN, - user=dify_config.MILVUS_USER, - password=dify_config.MILVUS_PASSWORD, - database=dify_config.MILVUS_DATABASE, + uri=dify_config.MILVUS_URI or "", + token=dify_config.MILVUS_TOKEN or "", + user=dify_config.MILVUS_USER or "", + password=dify_config.MILVUS_PASSWORD or "", + database=dify_config.MILVUS_DATABASE or "", ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b7b6b803ad20af..e63e1f522b3812 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -74,15 +74,16 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** columns = ["id", "text", "vector", "metadata"] values = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - row = ( - doc_id, - self.escape_str(doc.page_content), - embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {}, - ) - values.append(str(row)) - ids.append(doc_id) + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + row = ( + doc_id, + self.escape_str(doc.page_content), + embeddings[i], + json.dumps(doc.metadata) if doc.metadata else {}, + ) + values.append(str(row)) + ids.append(doc_id) sql = f""" INSERT INTO {self._config.database}.{self._collection_name} ({",".join(columns)}) VALUES {",".join(values)} diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index c44338d42a591a..957c799a60cbfe 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, model_validator -from pyobvector import VECTOR, ObVecClient +from pyobvector import VECTOR, ObVecClient # type: ignore from sqlalchemy import JSON, Column, String, func from sqlalchemy.dialects.mysql import LONGTEXT @@ -131,7 +131,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def text_exists(self, id: str) -> bool: cur = self._client.get(table_name=self._collection_name, id=id) - return cur.rowcount != 0 + return bool(cur.rowcount != 0) def delete_by_ids(self, ids: list[str]) -> None: self._client.delete(table_name=self._collection_name, ids=ids) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 7a976d7c3c8955..72a15022052f0a 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -66,7 +66,7 @@ def get_type(self) -> str: return VectorType.OPENSEARCH def create(self, texts: list[Document], embeddings: 
list[list[float]], **kwargs): - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings) @@ -244,7 +244,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( - host=dify_config.OPENSEARCH_HOST, + host=dify_config.OPENSEARCH_HOST or "localhost", port=dify_config.OPENSEARCH_PORT, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 74608f1e1a3b05..dfff3563c3bb28 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from typing import Any -import jieba.posseg as pseg +import jieba.posseg as pseg # type: ignore import numpy import oracledb from pydantic import BaseModel, model_validator @@ -88,12 +88,11 @@ def input_type_handler(self, cursor, value, arraysize): def numpy_converter_out(self, value): if value.typecode == "b": - dtype = numpy.int8 + return numpy.array(value, copy=False, dtype=numpy.int8) elif value.typecode == "f": - dtype = numpy.float32 + return numpy.array(value, copy=False, dtype=numpy.float32) else: - dtype = numpy.float64 - return numpy.array(value, copy=False, dtype=dtype) + return numpy.array(value, copy=False, dtype=numpy.float64) def output_type_handler(self, cursor, metadata): if metadata.type_code is oracledb.DB_TYPE_VECTOR: @@ -135,17 +134,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - # array.array("f", embeddings[i]), - numpy.array(embeddings[i]), + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + # array.array("f", embeddings[i]), + numpy.array(embeddings[i]), + ) ) - ) # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: cur.executemany( @@ -201,8 +201,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # lazy import - import nltk - from nltk.corpus import stopwords + import nltk # type: ignore + from nltk.corpus import stopwords # type: ignore top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later @@ -285,10 +285,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return OracleVector( collection_name=collection_name, config=OracleVectorConfig( - host=dify_config.ORACLE_HOST, + host=dify_config.ORACLE_HOST or "localhost", port=dify_config.ORACLE_PORT, - user=dify_config.ORACLE_USER, - password=dify_config.ORACLE_PASSWORD, - database=dify_config.ORACLE_DATABASE, + user=dify_config.ORACLE_USER or "system", + password=dify_config.ORACLE_PASSWORD or "oracle", + database=dify_config.ORACLE_DATABASE or "orcl", ), ) diff --git 
a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 7cbbdcc81f6039..221bc68d68a6f7 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -4,7 +4,7 @@ from uuid import UUID, uuid4 from numpy import ndarray -from pgvecto_rs.sqlalchemy import VECTOR +from pgvecto_rs.sqlalchemy import VECTOR # type: ignore from pydantic import BaseModel, model_validator from sqlalchemy import Float, String, create_engine, insert, select, text from sqlalchemy import text as sql_text @@ -58,7 +58,7 @@ def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) session.commit() - self._fields = [] + self._fields: list[str] = [] class _Table(CollectionORM): __tablename__ = collection_name @@ -222,11 +222,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return PGVectoRS( collection_name=collection_name, config=PgvectoRSConfig( - host=dify_config.PGVECTO_RS_HOST, - port=dify_config.PGVECTO_RS_PORT, - user=dify_config.PGVECTO_RS_USER, - password=dify_config.PGVECTO_RS_PASSWORD, - database=dify_config.PGVECTO_RS_DATABASE, + host=dify_config.PGVECTO_RS_HOST or "localhost", + port=dify_config.PGVECTO_RS_PORT or 5432, + user=dify_config.PGVECTO_RS_USER or "postgres", + password=dify_config.PGVECTO_RS_PASSWORD or "", + database=dify_config.PGVECTO_RS_DATABASE or "postgres", ), dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 40a9cdd136b404..271281ca7e939f 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -3,8 +3,8 @@ from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -98,16 +98,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - embeddings[i], + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + embeddings[i], + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_values( cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values @@ -216,11 +217,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return PGVector( collection_name=collection_name, config=PGVectorConfig( - host=dify_config.PGVECTOR_HOST, + host=dify_config.PGVECTOR_HOST or "localhost", port=dify_config.PGVECTOR_PORT, - user=dify_config.PGVECTOR_USER, - password=dify_config.PGVECTOR_PASSWORD, - database=dify_config.PGVECTOR_DATABASE, + user=dify_config.PGVECTOR_USER or "postgres", + password=dify_config.PGVECTOR_PASSWORD or "", + database=dify_config.PGVECTOR_DATABASE or "postgres", min_connection=dify_config.PGVECTOR_MIN_CONNECTION, max_connection=dify_config.PGVECTOR_MAX_CONNECTION, ), diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py 
b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 3811458e02957c..6e94cb69db309d 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -51,6 +51,8 @@ def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): + if not self.root_path: + raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) return {"path": path} @@ -149,9 +151,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] - added_ids = [] - for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + # Filter out None values from metadatas list to match expected type + filtered_metadatas = [m for m in metadatas if m is not None] + for batch_ids, points in self._generate_rest_batches( + texts, embeddings, filtered_metadatas, uuids, 64, self._group_id + ): self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) @@ -194,7 +199,7 @@ def _generate_rest_batches( batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", # Ensure group_id is never None Field.GROUP_KEY.value, ), ) @@ -337,18 +342,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -432,9 +439,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( - endpoint=dify_config.QDRANT_URL, + endpoint=dify_config.QDRANT_URL or "", api_key=dify_config.QDRANT_API_KEY, - root_path=current_app.config.root_path, + root_path=str(current_app.config.root_path), timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index f373dcfeabef92..a3a20448ff7a0a 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -3,7 +3,7 @@ from typing import Any, Optional from pydantic import BaseModel, model_validator -from sqlalchemy import Column, Sequence, String, Table, create_engine, insert +from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session @@ -58,14 +58,14 @@ def __init__(self, collection_name: str, 
config: RelytConfig, group_id: str): f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" ) self.client = create_engine(self._url) - self._fields = [] + self._fields: list[str] = [] self._group_id = group_id def get_type(self) -> str: return VectorType.RELYT - def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = {} + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None: + index_params: dict[str, Any] = {} metadatas = [d.metadata for d in texts] self.create_collection(len(embeddings[0])) self.embedding_dimension = len(embeddings[0]) @@ -107,10 +107,10 @@ def create_collection(self, dimension: int): redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): - from pgvecto_rs.sqlalchemy import VECTOR + from pgvecto_rs.sqlalchemy import VECTOR # type: ignore ids = [str(uuid.uuid1()) for _ in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] for metadata in metadatas: metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] @@ -242,10 +242,6 @@ def similarity_search_with_score_by_vector( filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided - try: - from sqlalchemy.engine import Row - except ImportError: - raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: @@ -275,7 +271,7 @@ def similarity_search_with_score_by_vector( # Execute the query and fetch the results with self.client.connect() as conn: - results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall() + results = conn.execute(sql_text(sql_query), params).fetchall() documents_with_scores = [ ( @@ -307,11 +303,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return RelytVector( collection_name=collection_name, config=RelytConfig( - host=dify_config.RELYT_HOST, + host=dify_config.RELYT_HOST or "localhost", port=dify_config.RELYT_PORT, - user=dify_config.RELYT_USER, - password=dify_config.RELYT_PASSWORD, - database=dify_config.RELYT_DATABASE, + user=dify_config.RELYT_USER or "", + password=dify_config.RELYT_PASSWORD or "", + database=dify_config.RELYT_DATABASE or "default", ), group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index f971a9c5eb1696..c15f4b229f81c3 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -2,10 +2,10 @@ from typing import Any, Optional from pydantic import BaseModel -from tcvectordb import VectorDBClient -from tcvectordb.model import document, enum -from tcvectordb.model import index as vdb_index -from tcvectordb.model.document import Filter +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model import document, enum # type: ignore +from tcvectordb.model import index as vdb_index # type: ignore +from tcvectordb.model.document import Filter # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -25,8 +25,8 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = (1,) - replicas: int = (2,) + shard: int 
= 1 + replicas: int = 2 def to_tencent_params(self): return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} @@ -120,15 +120,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [doc.metadata for doc in documents] total_count = len(embeddings) docs = [] - for id in range(0, total_count): + for i in range(0, total_count): if metadatas is None: continue - metadata = json.dumps(metadatas[id]) + metadata = metadatas[i] or {} doc = document.Document( - id=metadatas[id]["doc_id"], - vector=embeddings[id], - text=texts[id], - metadata=metadata, + id=metadata.get("doc_id"), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadata), ) docs.append(doc) self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout) @@ -159,8 +159,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] - def _get_search_res(self, res, score_threshold): - docs = [] + def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]: + docs: list[Document] = [] if res is None or len(res) == 0: return docs @@ -193,7 +193,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return TencentVector( collection_name=collection_name, config=TencentConfig( - url=dify_config.TENCENT_VECTOR_DB_URL, + url=dify_config.TENCENT_VECTOR_DB_URL or "", api_key=dify_config.TENCENT_VECTOR_DB_API_KEY, timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT, username=dify_config.TENCENT_VECTOR_DB_USERNAME, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index cfd47aac5ba05b..19c5579a688f5a 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -54,7 +54,10 @@ def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): - path = os.path.join(self.root_path, path) + if self.root_path: + path = os.path.join(self.root_path, path) + else: + raise ValueError("root_path is required") return {"path": path} else: @@ -157,7 +160,7 @@ def create_collection(self, collection_name: str, vector_size: int): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] added_ids = [] for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): @@ -203,7 +206,7 @@ def _generate_rest_batches( batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", Field.GROUP_KEY.value, ), ) @@ -334,18 +337,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = kwargs.get("score_threshold") or 0.0 if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + 
page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -427,12 +432,12 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings else: new_cluster = TidbService.create_tidb_serverless_cluster( - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + dify_config.TIDB_PROJECT_ID or "", + dify_config.TIDB_API_URL or "", + dify_config.TIDB_IAM_API_URL or "", + dify_config.TIDB_PUBLIC_KEY or "", + dify_config.TIDB_PRIVATE_KEY or "", + dify_config.TIDB_REGION or "", ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], @@ -464,9 +469,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL, + endpoint=dify_config.TIDB_ON_QDRANT_URL or "", api_key=TIDB_ON_QDRANT_API_KEY, - root_path=config.root_path, + root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 8dd5922ad0171d..0a48c79511bf26 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -146,7 +146,7 @@ def batch_update_tidb_serverless_cluster_status( iam_url: str, public_key: str, private_key: str, - ) -> list[dict]: + ): """ Update the status of a new TiDB Serverless cluster. :param project_id: The project ID of the TiDB Cloud project (required). @@ -159,7 +159,6 @@ def batch_update_tidb_serverless_cluster_status( :return: The response from the API. 
""" - clusters = [] tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} cluster_ids = [item.cluster_id for item in tidb_serverless_list] params = {"clusterIds": cluster_ids, "view": "BASIC"} @@ -169,7 +168,6 @@ def batch_update_tidb_serverless_cluster_status( if response.status_code == 200: response_data = response.json() - cluster_infos = [] for item in response_data["clusters"]: state = item["state"] userPrefix = item["userPrefix"] @@ -236,16 +234,17 @@ def batch_create_tidb_serverless_cluster( cluster_infos = [] for item in response_data["clusters"]: cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" - password = redis_client.get(cache_key) - if not password: + cached_password = redis_client.get(cache_key) + if not cached_password: continue cluster_info = { "cluster_id": item["clusterId"], "cluster_name": item["displayName"], "account": "root", - "password": password.decode("utf-8"), + "password": cached_password.decode("utf-8"), } cluster_infos.append(cluster_info) return cluster_infos else: response.raise_for_status() + return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 39ab6ea71e9485..be3a417390e802 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -49,7 +49,7 @@ def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: - from tidb_vector.sqlalchemy import VectorType + from tidb_vector.sqlalchemy import VectorType # type: ignore return Table( self._collection_name, @@ -241,11 +241,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return TiDBVector( collection_name=collection_name, config=TiDBVectorConfig( - host=dify_config.TIDB_VECTOR_HOST, - port=dify_config.TIDB_VECTOR_PORT, - user=dify_config.TIDB_VECTOR_USER, - password=dify_config.TIDB_VECTOR_PASSWORD, - database=dify_config.TIDB_VECTOR_DATABASE, + host=dify_config.TIDB_VECTOR_HOST or "", + port=dify_config.TIDB_VECTOR_PORT or 0, + user=dify_config.TIDB_VECTOR_USER or "", + password=dify_config.TIDB_VECTOR_PASSWORD or "", + database=dify_config.TIDB_VECTOR_DATABASE or "", program_name=dify_config.APPLICATION_NAME, ), ) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 22e191340d3a47..edfce2edd896ee 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -51,15 +51,16 @@ def delete(self) -> None: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): - doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if text.metadata and "doc_id" in text.metadata: + doc_id = text.metadata["doc_id"] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata["doc_id"] for text in texts] + return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 6d2e04fc020ab5..523fa80f124b0c 100644 --- 
a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -193,10 +193,13 @@ def _get_embeddings(self) -> Embeddings: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if doc_id: + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 4f927f28995613..9de8761a91ca68 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -2,7 +2,7 @@ from typing import Any from pydantic import BaseModel -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Data, DistanceType, Field, @@ -121,11 +121,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for i, page_content in enumerate(page_contents): metadata = {} if metadatas is not None: - for key, val in metadatas[i].items(): + for key, val in (metadatas[i] or {}).items(): metadata[key] = val + # FIXME: fix the type of metadata later doc = Data( { - vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], + vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, vdb_Field.CONTENT_KEY.value: page_content, vdb_Field.METADATA_KEY.value: json.dumps(metadata), @@ -178,7 +179,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(results, score_threshold) - def _get_search_res(self, results, score_threshold): + def _get_search_res(self, results, score_threshold) -> list[Document]: if len(results) == 0: return [] @@ -191,7 +192,7 @@ def _get_search_res(self, results, score_threshold): metadata["score"] = result.score doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 649cfbfea8253c..68d043a19f171f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -3,7 +3,7 @@ from typing import Any, Optional import requests -import weaviate +import weaviate # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -107,7 +107,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for i, text in enumerate(texts): data_properties = {Field.TEXT_KEY.value: text} if metadatas is not None: - for key, val in metadatas[i].items(): + # metadata maybe None + for key, val in (metadatas[i] or {}).items(): data_properties[key] = self._json_serializable(val) batch.add_data_object( @@ -208,10 +209,11 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc score_threshold = 
float(kwargs.get("score_threshold") or 0.0) # check score threshold if score > score_threshold: - doc.metadata["score"] = score - docs.append(doc) + if doc.metadata is not None: + doc.metadata["score"] = score + docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -275,7 +277,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( - endpoint=dify_config.WEAVIATE_ENDPOINT, + endpoint=dify_config.WEAVIATE_ENDPOINT or "", api_key=dify_config.WEAVIATE_API_KEY, batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 319a2612c7ecb8..35becaa0c7bea7 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -83,6 +83,9 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> if not isinstance(doc, Document): raise ValueError("doc must be a Document") + if doc.metadata is None: + raise ValueError("doc.metadata must be a dict") + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it @@ -179,10 +182,10 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: if document_segment is None: return None + data: Optional[str] = document_segment.index_node_hash + return data - return document_segment.index_node_hash - - def get_document_segment(self, doc_id: str) -> DocumentSegment: + def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: document_segment = ( db.session.query(DocumentSegment) .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 8ddda7e9832d97..a2c8737da79198 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import Optional, cast +from typing import Any, Optional, cast import numpy as np from sqlalchemy.exc import IntegrityError @@ -27,7 +27,7 @@ def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" # use doc embedding cache or store if not exists - text_embeddings = [None for _ in range(len(texts))] + text_embeddings: list[Any] = [None for _ in range(len(texts))] embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) @@ -64,7 +64,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: for vector in embedding_result.embeddings: try: - normalized_embedding = (vector / np.linalg.norm(vector)).tolist() + # FIXME: type ignore for numpy here + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan if np.isnan(normalized_embedding).any(): # for issue #11827 float values are not json compliant @@ -77,8 +78,8 @@ def embed_documents(self, texts: list[str]) -> 
list[list[float]]: logging.exception("Failed transform embedding") cache_embeddings = [] try: - for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): - text_embeddings[i] = embedding + for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + text_embeddings[i] = n_embedding hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: embedding_cache = Embedding( @@ -86,7 +87,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: hash=hash, provider_name=self._model_instance.provider, ) - embedding_cache.set_embedding(embedding) + embedding_cache.set_embedding(n_embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) db.session.commit() @@ -115,7 +116,8 @@ def embed_query(self, text: str) -> list[float]: ) embedding_results = embedding_result.embeddings[0] - embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + # FIXME: type ignore for numpy here + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore if np.isnan(embedding_results).any(): raise ValueError("Normalized embedding is nan please try again") except Exception as ex: diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 3692b5d19dfb65..7c00c668dd49a3 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -14,7 +14,7 @@ class NotionInfo(BaseModel): notion_workspace_id: str notion_obj_id: str notion_page_type: str - document: Document = None + document: Optional[Document] = None tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index fc331657195454..c444105bb59443 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import Optional +from typing import Optional, cast import pandas as pd from openpyxl import load_workbook @@ -47,7 +47,7 @@ def extract(self) -> list[Document]: for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): cell = sheet.cell( - row=index + 2, column=col_index + 1 + row=cast(int, index) + 2, column=col_index + 1 ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" @@ -60,8 +60,8 @@ def extract(self) -> list[Document]: elif file_extension == ".xls": excel_file = pd.ExcelFile(self._file_path, engine="xlrd") - for sheet_name in excel_file.sheet_names: - df = excel_file.parse(sheet_name=sheet_name) + for excel_sheet_name in excel_file.sheet_names: + df = excel_file.parse(sheet_name=excel_sheet_name) df.dropna(how="all", inplace=True) for _, row in df.iterrows(): diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 69659e31080da6..a473b3dfa78a90 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -10,6 +10,7 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.excel_extractor import ExcelExtractor +from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor from 
core.rag.extractor.html_extractor import HtmlExtractor from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor @@ -66,9 +67,13 @@ def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Docume filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = "." + re.search(r"\.(\w+)$", filename).group(1) - - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + match = re.search(r"\.(\w+)$", filename) + if match: + suffix = "." + match.group(1) + else: + suffix = "" + # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore Path(file_path).write_bytes(response.content) extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: @@ -89,15 +94,20 @@ def extract( if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: + assert extract_setting.upload_file is not None, "upload_file is required" upload_file: UploadFile = extract_setting.upload_file suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY + assert unstructured_api_url is not None, "unstructured_api_url is required" + assert unstructured_api_key is not None, "unstructured_api_key is required" + extractor: Optional[BaseExtractor] = None if etl_type == "Unstructured": if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) @@ -156,6 +166,7 @@ def extract( extractor = TextExtractor(file_path, autodetect_encoding=True) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.NOTION.value: + assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( notion_workspace_id=extract_setting.notion_info.notion_workspace_id, notion_obj_id=extract_setting.notion_info.notion_obj_id, @@ -165,6 +176,7 @@ def extract( ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: + assert extract_setting.website_info is not None, "website_info is required" if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 17c2087a0ab575..8ae4579c7cf93f 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -1,5 +1,6 @@ import json import time +from typing import cast import requests @@ -20,9 +21,9 @@ def scrape_url(self, url, params=None) -> dict: json_data.update(params) response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: - response = response.json() - if response["success"] == True: - data = response["data"] + response_data = 
response.json() + if response_data["success"] == True: + data = response_data["data"] return { "title": data.get("metadata").get("title"), "description": data.get("metadata").get("description"), @@ -30,7 +31,7 @@ def scrape_url(self, url, params=None) -> dict: "markdown": data.get("markdown"), } else: - raise Exception(f'Failed to scrape URL. Error: {response["error"]}') + raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}') elif response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") @@ -46,9 +47,11 @@ def crawl_url(self, url, params=None) -> str: response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: job_id = response.json().get("jobId") - return job_id + return cast(str, job_id) else: self._handle_error(response, "start crawl job") + # FIXME: unreachable code for mypy + return "" # unreachable def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() @@ -64,9 +67,9 @@ def check_crawl_status(self, job_id) -> dict: for item in data: if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - "title": item.get("metadata").get("title"), - "description": item.get("metadata").get("description"), - "source_url": item.get("metadata").get("sourceURL"), + "title": item.get("metadata", {}).get("title"), + "description": item.get("metadata", {}).get("description"), + "source_url": item.get("metadata", {}).get("sourceURL"), "markdown": item.get("markdown"), } url_data_list.append(url_data) @@ -92,6 +95,8 @@ def check_crawl_status(self, job_id) -> dict: else: self._handle_error(response, "check crawl status") + # FIXME: unreachable code for mypy + return {} # unreachable def _prepare_headers(self): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index 560c2d1d84b04e..350b522347b09d 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -23,6 +23,7 @@ def extract(self) -> list[Document]: return [Document(page_content=self._load_as_text())] def _load_as_text(self) -> str: + text: str = "" with open(self._file_path, "rb") as fp: soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 87a4ce08bf3f89..fdc2e46d141d07 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any, Optional, cast import requests @@ -78,6 +78,7 @@ def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" + assert self._notion_access_token is not None, "Notion access token is required" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), headers={ @@ -96,6 +97,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] for result in 
data["results"]: properties = result["properties"] data = {} + value: Any for property_name, property_value in properties.items(): type = property_value["type"] if type == "multi_select": @@ -130,6 +132,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) @@ -184,6 +187,7 @@ def _get_notion_block_data(self, page_id: str) -> list[str]: def _read_block(self, block_id: str, num_tabs: int = 0) -> str: """Read a block.""" + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) @@ -242,6 +246,7 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: def _read_table_rows(self, block_id: str) -> str: """Read table rows.""" + assert self._notion_access_token is not None, "Notion access token is required" done = False result_lines_arr = [] start_cursor = None @@ -296,7 +301,7 @@ def _read_table_rows(self, block_id: str) -> str: result_lines = "\n".join(result_lines_arr) return result_lines - def update_last_edited_time(self, document_model: DocumentModel): + def update_last_edited_time(self, document_model: Optional[DocumentModel]): if not document_model: return @@ -309,6 +314,7 @@ def update_last_edited_time(self, document_model: DocumentModel): db.session.commit() def get_notion_last_edited_time(self) -> str: + assert self._notion_access_token is not None, "Notion access token is required" obj_id = self._notion_obj_id page_type = self._notion_page_type if page_type == "database": @@ -330,7 +336,7 @@ def get_notion_last_edited_time(self) -> str: ) data = res.json() - return data["last_edited_time"] + return cast(str, data["last_edited_time"]) @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: @@ -349,4 +355,4 @@ def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: f"and notion workspace {notion_workspace_id}" ) - return data_source_binding.access_token + return cast(str, data_source_binding.access_token) diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 57cb9610ba267e..89a7061c26accc 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" from collections.abc import Iterator -from typing import Optional +from typing import Optional, cast from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -27,7 +27,7 @@ def extract(self) -> list[Document]: plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode("utf-8") + text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -53,7 +53,7 @@ def load( def parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" - import pypdfium2 + import pypdfium2 # type: ignore with blob.as_bytes_io() as file_path: pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) diff --git 
a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index bd669bbad36873..9647dedfff8516 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,7 +1,7 @@ import base64 import logging -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 35220b558afab9..80c29157aaf529 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -30,6 +30,9 @@ def extract(self) -> list[Document]: if self._api_url: from unstructured.partition.api import partition_via_api + if self._api_key is None: + raise ValueError("api_key is required") + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: from unstructured.partition.epub import partition_epub diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index 0fdcd58b2e569b..e504d4bc23014c 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -27,9 +27,11 @@ def extract(self) -> list[Document]: elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: raise NotImplementedError("Unstructured API Url is not configured") - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number + if page is None: + continue text = element.text if page in text_by_page: text_by_page[page] += "\n" + text diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index ab41290fbc4537..cefe72b29052a1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -29,14 +29,15 @@ def extract(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path) - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number text = element.text - if page in text_by_page: - text_by_page[page] += "\n" + text - else: - text_by_page[page] = text + if page is not None: + if page in text_by_page: + text_by_page[page] += "\n" + text + else: + text_by_page[page] = text combined_texts = list(text_by_page.values()) documents = [] diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0c38a9c0762130..c3161bc812cb73 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -89,6 +89,8 @@ def _extract_images_from_docx(self, doc, image_folder): response = ssrf_proxy.get(url) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) + if image_ext is None: + continue file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." 
+ image_ext mime_type, _ = mimetypes.guess_type(file_key) @@ -97,6 +99,8 @@ def _extract_images_from_docx(self, doc, image_folder): continue else: image_ext = rel.target_ref.split(".")[-1] + if image_ext is None: + continue # user uuid as file name file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext @@ -226,6 +230,8 @@ def parse_docx(self, docx_path, image_folder): if x_child is None: continue if x.tag.endswith("instrText"): + if x.text is None: + continue for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index be857bd12215fd..7e5efdc66ed533 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -49,6 +49,7 @@ def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optiona """ Get the NodeParser object according to the processing rule. """ + character_splitter: TextSplitter if processing_rule["mode"] == "custom": # The user-defined segmentation rule rules = processing_rule["rules"] diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index 9b855ece2c3512..c5ba6295f32f84 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -9,7 +9,7 @@ class IndexProcessorFactory: """IndexProcessorInit.""" - def __init__(self, index_type: str): + def __init__(self, index_type: str | None): self._index_type = index_type def init_index_processor(self) -> BaseIndexProcessor: diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index a631f953ce2191..c66fa54d503e9f 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -27,12 +27,13 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. 
splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule=kwargs.get("process_rule", {}), + embedding_model_instance=kwargs.get("embedding_model_instance"), ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {})) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) @@ -41,8 +42,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 320f0157a10049..20fd16e8f39b65 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -32,15 +32,16 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule=kwargs.get("process_rule") or {}, + embedding_model_instance=kwargs.get("embedding_model_instance"), ) # Split the text documents into nodes. 
- all_documents = [] - all_qa_documents = [] + all_documents: list[Document] = [] + all_qa_documents: list[Document] = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {}) document.page_content = document_text # parse document to nodes @@ -50,8 +51,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content document_node.page_content = remove_leading_symbols(page_content) @@ -64,7 +66,7 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: document_format_thread = threading.Thread( target=self._format_qa_document, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "tenant_id": kwargs.get("tenant_id"), "document_node": doc, "all_qa_documents": all_qa_documents, @@ -148,11 +150,12 @@ def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, a qa_documents = [] for result in document_qa_list: qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy()) - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash + if qa_document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 6ae432a526b169..ac7a3f8bb857e4 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -30,7 +30,11 @@ def run( doc_ids = set() unique_documents = [] for document in documents: - if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): doc_ids.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -54,7 +58,8 @@ def run( metadata=documents[result.index].metadata, provider=documents[result.index].provider, ) - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) return rerank_documents diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 4719be012f99cc..cbc96037bf2cc0 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -39,7 +39,7 @@ def run( unique_documents = [] doc_ids = set() for document in documents: - if document.metadata["doc_id"] not in 
doc_ids: + if document.metadata is not None and document.metadata["doc_id"] not in doc_ids: doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) @@ -56,10 +56,11 @@ def run( ) if score_threshold and score < score_threshold: continue - document.metadata["score"] = score - rerank_documents.append(document) + if document.metadata is not None: + document.metadata["score"] = score + rerank_documents.append(document) - rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True) + rerank_documents.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -76,8 +77,9 @@ def _calculate_keyword_score(self, query: str, documents: list[Document]) -> lis for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata["keywords"] = document_keywords - documents_keywords.append(document_keywords) + if document.metadata is not None: + document.metadata["keywords"] = document_keywords + documents_keywords.append(document_keywords) # Counter query keywords(TF) query_keyword_counts = Counter(query_keywords) @@ -162,7 +164,7 @@ def _calculate_cosine( query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if "score" in document.metadata: + if document.metadata and "score" in document.metadata: query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 7a5bf39fa63f48..a265f36671b04b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1,7 +1,7 @@ import math import threading from collections import Counter -from typing import Optional, cast +from typing import Any, Optional, cast from flask import Flask, current_app @@ -34,7 +34,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -140,12 +140,12 @@ def retrieve( user_from, available_datasets, query, - retrieve_config.top_k, - retrieve_config.score_threshold, - retrieve_config.rerank_mode, + retrieve_config.top_k or 0, + retrieve_config.score_threshold or 0, + retrieve_config.rerank_mode or "reranking_model", retrieve_config.reranking_model, retrieve_config.weights, - retrieve_config.reranking_enabled, + retrieve_config.reranking_enabled or True, message_id, ) @@ -300,10 +300,11 @@ def single_retrieve( metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name results.append(document) else: retrieval_model_config = 
dataset.retrieval_model or default_retrieval_model @@ -325,7 +326,7 @@ def single_retrieve( score_threshold = 0.0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: - score_threshold = retrieval_model_config.get("score_threshold") + score_threshold = retrieval_model_config.get("score_threshold", 0.0) with measure_time() as timer: results = RetrievalService.retrieve( @@ -358,14 +359,14 @@ def multiple_retrieve( score_threshold: float, reranking_mode: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, + weights: Optional[dict[str, Any]] = None, reranking_enable: bool = True, message_id: Optional[str] = None, ): if not available_datasets: return [] threads = [] - all_documents = [] + all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets @@ -392,15 +393,18 @@ def multiple_retrieve( "The configured knowledge base list have different embedding model, please set reranking model." ) if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE: - weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider - weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + if weights is not None: + weights["vector_setting"]["embedding_provider_name"] = available_datasets[ + 0 + ].embedding_model_provider + weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model for dataset in available_datasets: index_type = dataset.indexing_technique retrieval_thread = threading.Thread( target=self._retriever, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset.id, "query": query, "top_k": top_k, @@ -439,21 +443,22 @@ def _on_retrieval_end( """Handle retrieval end.""" dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + # if 'dataset_id' in document.metadata: + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = ( + trace_manager: Optional[TraceQueueManager] = ( self.application_generate_entity.trace_manager if self.application_generate_entity else None ) if trace_manager: @@ -504,10 +509,11 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - 
document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name all_documents.append(document) else: # get retrieval model , if the model is not setting , using default @@ -607,19 +613,20 @@ def to_dataset_retriever_tool( tools.append(tool) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: - tool = DatasetMultiRetrieverTool.from_dataset( - dataset_ids=[dataset.id for dataset in available_datasets], - tenant_id=tenant_id, - top_k=retrieve_config.top_k or 2, - score_threshold=retrieve_config.score_threshold, - hit_callbacks=[hit_callback], - return_resource=return_resource, - retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), - reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), - ) + if retrieve_config.reranking_model is not None: + tool = DatasetMultiRetrieverTool.from_dataset( + dataset_ids=[dataset.id for dataset in available_datasets], + tenant_id=tenant_id, + top_k=retrieve_config.top_k or 2, + score_threshold=retrieve_config.score_threshold, + hit_callbacks=[hit_callback], + return_resource=return_resource, + retriever_from=invoke_from.to_source(), + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), + ) - tools.append(tool) + tools.append(tool) return tools @@ -635,10 +642,11 @@ def calculate_keyword_score(self, query: str, documents: list[Document], top_k: query_keywords = keyword_table_handler.extract_keywords(query, None) documents_keywords = [] for document in documents: - # get the document keywords - document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata["keywords"] = document_keywords - documents_keywords.append(document_keywords) + if document.metadata is not None: + # get the document keywords + document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) + document.metadata["keywords"] = document_keywords + documents_keywords.append(document_keywords) # Counter query keywords(TF) query_keyword_counts = Counter(query_keywords) @@ -696,8 +704,9 @@ def cosine_similarity(vec1, vec2): for document, score in zip(documents, similarities): # format document - document.metadata["score"] = score - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + if document.metadata is not None: + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return documents[:top_k] if top_k else documents def calculate_vector_score( @@ -705,10 +714,12 @@ def calculate_vector_score( ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold is None or document.metadata["score"] >= score_threshold: + if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold): filter_documents.append(document) if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], 
reverse=True) + filter_documents = sorted( + filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True + ) return filter_documents[:top_k] if top_k else filter_documents diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 06147fe7b56544..b008d0df9c2f0e 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,7 +1,8 @@ -from typing import Union +from typing import Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage @@ -27,11 +28,14 @@ def invoke( SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage(content=query), ] - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - tools=dataset_tools, - stream=False, - model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, + result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + tools=dataset_tools, + stream=False, + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, + ), ) if result.message.tool_calls: # get retrieval model config diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 68fab0c127a253..05e8d043dfe741 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,9 +1,9 @@ from collections.abc import Generator, Sequence -from typing import Union +from typing import Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate @@ -92,6 +92,7 @@ def _react_invoke( suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: + prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] if model_config.mode == "chat": prompt = self.create_chat_prompt( query=query, @@ -149,12 +150,15 @@ def _invoke_llm( :param stop: stop :return: """ - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=completion_param, - stop=stop, - stream=True, - user=user_id, + invoke_result = cast( + Generator[LLMResult, None, None], + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=completion_param, + stop=stop, + stream=True, + user=user_id, + ), ) # handle invoke result @@ -172,7 +176,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage :return: """ model = None - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] full_text = "" usage = None for result in invoke_result: diff --git 
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 53032b34d570c7..3376bd7f75dd96 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -26,8 +26,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     def from_encoder(
         cls: type[TS],
         embedding_model_instance: Optional[ModelInstance],
-        allowed_special: Union[Literal[all], Set[str]] = set(),
-        disallowed_special: Union[Literal[all], Collection[str]] = "all",
+        allowed_special: Union[Literal["all"], Set[str]] = set(),  # noqa: UP037
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",  # noqa: UP037
         **kwargs: Any,
     ):
         def _token_encoder(text: str) -> int:
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 7dd62f8de18a15..4bfa541fd454ad 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -92,7 +92,7 @@ def split_documents(self, documents: Iterable[Document]) -> list[Document]:
         texts, metadatas = [], []
         for doc in documents:
             texts.append(doc.page_content)
-            metadatas.append(doc.metadata)
+            metadatas.append(doc.metadata or {})
         return self.create_documents(texts, metadatas=metadatas)

     def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
@@ -143,7 +143,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int
     def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
         """Text splitter that uses HuggingFace tokenizer to count length."""
         try:
-            from transformers import PreTrainedTokenizerBase
+            from transformers import PreTrainedTokenizerBase  # type: ignore

             if not isinstance(tokenizer, PreTrainedTokenizerBase):
                 raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")
diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py
index ddb1481276df67..975c374cae8356 100644
--- a/api/core/tools/entities/api_entities.py
+++ b/api/core/tools/entities/api_entities.py
@@ -14,7 +14,7 @@ class UserTool(BaseModel):
     label: I18nObject  # label
     description: I18nObject
     parameters: Optional[list[ToolParameter]] = None
-    labels: list[str] = None
+    labels: list[str] | None = None


 UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]
diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py
index 0c15b2a3711f11..7c365dc69d3b39 100644
--- a/api/core/tools/entities/tool_bundle.py
+++ b/api/core/tools/entities/tool_bundle.py
@@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel):
     # summary
     summary: Optional[str] = None
     # operation_id
-    operation_id: str = None
+    operation_id: str | None = None
     # parameters
     parameters: Optional[list[ToolParameter]] = None
     # author
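The `api_entities.py` and `tool_bundle.py` changes fix the same declaration bug: a field annotated `str` (or `list[str]`) with a default of `None`. mypy rejects that even where pydantic used to tolerate it; widening the annotation is the fix. A minimal sketch (hypothetical model, assuming pydantic and Python 3.10+ for the `|` syntax):

    from typing import Optional

    from pydantic import BaseModel

    class Bundle(BaseModel):
        # `operation_id: str = None` assigns None to a str field -- a type error.
        # Widening the annotation makes the None default legal and explicit:
        operation_id: str | None = None
        summary: Optional[str] = None  # Optional[str] is the equivalent spelling

    print(Bundle().operation_id)                         # None
    print(Bundle(operation_id="getUser").operation_id)   # getUser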
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 4fc383f91baeba..260e4e457f083e 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -244,18 +244,19 @@ def get_simple_instance(
         """
         # convert options to ToolParameterOption
         if options:
-            options = [
+            options_tool_parametor = [
                 ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
             ]
         return cls(
             name=name,
             label=I18nObject(en_US="", zh_Hans=""),
             human_description=I18nObject(en_US="", zh_Hans=""),
+            placeholder=None,
             type=type,
             form=cls.ToolParameterForm.LLM,
             llm_description=llm_description,
             required=required,
-            options=options,
+            options=options_tool_parametor,
         )
@@ -331,7 +332,7 @@ def to_dict(self) -> dict:
             "default": self.default,
             "options": self.options,
             "help": self.help.to_dict() if self.help else None,
-            "label": self.label.to_dict(),
+            "label": self.label.to_dict() if self.label else None,
             "url": self.url,
             "placeholder": self.placeholder.to_dict() if self.placeholder else None,
         }
@@ -374,7 +375,10 @@ def __init__(self, **data: Any):
                 pool[index] = ToolRuntimeImageVariable(**variable)
         super().__init__(**data)

-    def dict(self) -> dict:
+    def dict(self) -> dict:  # type: ignore
+        """
+        FIXME: just ignore the type check for now
+        """
         return {
             "conversation_id": self.conversation_id,
             "user_id": self.user_id,
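Renaming the converted list (`options` to `options_tool_parametor`) is the standard way out of mypy's incompatible-reassignment error: a name's type should not change mid-function. A sketch of the general fix (hypothetical data):

    def build(options: list[str]) -> list[tuple[str, str]]:
        # Rebinding `options` to a list of tuples would give one name two types
        # depending on control flow, which mypy reports as an incompatible
        # assignment. A fresh name keeps both types stable:
        option_pairs = [(opt, opt.upper()) for opt in options]
        return option_pairs

    print(build(["red", "blue"]))  # [('red', 'RED'), ('blue', 'BLUE')]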
diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py
index d99314e33a3204..f451edbf2ee969 100644
--- a/api/core/tools/provider/api_tool_provider.py
+++ b/api/core/tools/provider/api_tool_provider.py
@@ -1,9 +1,14 @@
+from typing import Optional
+
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ToolCredentialsOption,
+    ToolDescription,
+    ToolIdentity,
     ToolProviderCredentials,
+    ToolProviderIdentity,
     ToolProviderType,
 )
 from core.tools.provider.tool_provider import ToolProviderController
@@ -64,21 +69,18 @@ def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "Ap
             pass
         else:
             raise ValueError(f"invalid auth type {auth_type}")
-
-        user_name = db_provider.user.name if db_provider.user_id else ""
-
+        user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
         return ApiToolProviderController(
-            **{
-                "identity": {
-                    "author": user_name,
-                    "name": db_provider.name,
-                    "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
-                    "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
-                    "icon": db_provider.icon,
-                },
-                "credentials_schema": credentials_schema,
-                "provider_id": db_provider.id or "",
-            }
+            identity=ToolProviderIdentity(
+                author=user_name,
+                name=db_provider.name,
+                label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
+                description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
+                icon=db_provider.icon,
+            ),
+            credentials_schema=credentials_schema,
+            provider_id=db_provider.id or "",
+            tools=None,
         )

     @property
@@ -93,24 +95,22 @@ def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
         :return: the tool
         """
         return ApiTool(
-            **{
-                "api_bundle": tool_bundle,
-                "identity": {
-                    "author": tool_bundle.author,
-                    "name": tool_bundle.operation_id,
-                    "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
-                    "icon": self.identity.icon,
-                    "provider": self.provider_id,
-                },
-                "description": {
-                    "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
-                    "llm": tool_bundle.summary or "",
-                },
-                "parameters": tool_bundle.parameters or [],
-            }
+            api_bundle=tool_bundle,
+            identity=ToolIdentity(
+                author=tool_bundle.author,
+                name=tool_bundle.operation_id or "",
+                label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
+                icon=self.identity.icon if self.identity else None,
+                provider=self.provider_id,
+            ),
+            description=ToolDescription(
+                human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
+                llm=tool_bundle.summary or "",
+            ),
+            parameters=tool_bundle.parameters or [],
         )

-    def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
+    def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:
         """
         load bundled tools
@@ -121,7 +121,7 @@ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:

         return self.tools

-    def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
+    def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
         """
         fetch tools from database
@@ -131,6 +131,8 @@ def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
         """
         if self.tools is not None:
             return self.tools
+        if self.identity is None:
+            return None

         tools: list[Tool] = []
@@ -151,7 +153,7 @@ def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
         self.tools = tools
         return tools

-    def get_tool(self, tool_name: str) -> ApiTool:
+    def get_tool(self, tool_name: str) -> Tool:
         """
         get tool by name
@@ -161,7 +163,9 @@ def get_tool(self, tool_name: str) -> Tool:
         if self.tools is None:
             self.get_tools()

-        for tool in self.tools:
+        for tool in self.tools or []:
+            if tool.identity is None:
+                continue
             if tool.identity.name == tool_name:
                 return tool
diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py
index 582ad636b1953a..fc29920acd40dc 100644
--- a/api/core/tools/provider/app_tool_provider.py
+++ b/api/core/tools/provider/app_tool_provider.py
@@ -1,9 +1,10 @@
 import logging
-from typing import Any
+from typing import Any, Optional

 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
 from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.tool.api_tool import ApiTool
 from core.tools.tool.tool import Tool
 from extensions.ext_database import db
 from models.model import App, AppModelConfig
@@ -20,10 +21,10 @@ def provider_type(self) -> ToolProviderType:
     def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
         pass

-    def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
+    def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
         pass

-    def get_tools(self, user_id: str) -> list[Tool]:
+    def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
         db_tools: list[PublishedAppTool] = (
             db.session.query(PublishedAppTool)
             .filter(
@@ -38,7 +39,7 @@ def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
         tools: list[Tool] = []

         for db_tool in db_tools:
-            tool = {
+            tool: dict[str, Any] = {
                 "identity": {
                     "author": db_tool.author,
                     "name": db_tool.tool_name,
@@ -52,7 +53,7 @@ def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
                 "parameters": [],
             }
             # get app from db
-            app: App = db_tool.app
+            app: Optional[App] = db_tool.app

             if not app:
                 logger.error(f"app {db_tool.app_id} not found")
@@ -79,6 +80,7 @@ def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
                             type=ToolParameter.ToolParameterType.STRING,
                             required=required,
                             default=default,
+                            placeholder=I18nObject(en_US="", zh_Hans=""),
                         )
                     )
                 elif form_type == "select":
@@ -92,6 +94,7 @@ def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
                             type=ToolParameter.ToolParameterType.SELECT,
                             required=required,
                             default=default,
+                            placeholder=I18nObject(en_US="", zh_Hans=""),
                             options=[
                                 ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
                                 for option in options
@@ -99,5 +102,5 @@ def get_tools(self, user_id: str) -> list[Tool]:
                         )
                     )

-            tools.append(Tool(**tool))
+            tools.append(ApiTool(**tool))
         return tools
diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py
index 5c10f72fdaed01..99a062f8c366aa 100644
--- a/api/core/tools/provider/builtin/_positions.py
+++ b/api/core/tools/provider/builtin/_positions.py
@@ -5,7 +5,7 @@


 class BuiltinToolProviderSort:
-    _position = {}
+    _position: dict[str, int] = {}

     @classmethod
     def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
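`_position = {}` gives mypy an empty literal with nothing to infer, hence the explicit `dict[str, int]` annotation. A sketch of the same idea; `ClassVar` is a stricter spelling for class-level state (an assumption here — the patch itself annotates the dict directly):

    from typing import ClassVar

    class ProviderSort:
        # Without an annotation, mypy reports "need type annotation" for the
        # empty literal; with one, wrongly-typed inserts are caught statically.
        _position: ClassVar[dict[str, int]] = {}

        @classmethod
        def weight(cls, name: str) -> int:
            return cls._position.get(name, 99)

    ProviderSort._position["search"] = 1
    print(ProviderSort.weight("search"), ProviderSort.weight("other"))  # 1 99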
diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py
index 38123f125ae974..cf10f5d2556edd 100644
--- a/api/core/tools/provider/builtin/aippt/tools/aippt.py
+++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py
@@ -4,7 +4,7 @@
 from json import loads as json_loads
 from threading import Lock
 from time import sleep, time
-from typing import Any
+from typing import Any, Union

 from httpx import get, post
 from requests import get as requests_get
@@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter:
     """

     _api_base_url = URL("https://co.aippt.cn/api")
-    _api_token_cache = {}
-    _style_cache = {}
+    _api_token_cache: dict[str, dict[str, Union[str, float]]] = {}
+    _style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {}

-    _api_token_cache_lock = Lock()
-    _style_cache_lock = Lock()
+    _api_token_cache_lock: Lock = Lock()
+    _style_cache_lock: Lock = Lock()

-    _task = {}
+    _task: dict[str, Any] = {}
     _task_type_map = {
         "auto": 1,
         "markdown": 7,
     }
-    _tool: BuiltinTool
+    _tool: BuiltinTool | None

-    def __init__(self, tool: BuiltinTool = None):
+    def __init__(self, tool: BuiltinTool | None = None):
         self._tool = tool

-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         """
         Invokes the AIPPT generate tool with the given user ID and tool parameters.
@@ -68,8 +70,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
         )

         # get suit
-        color: str = tool_parameters.get("color")
-        style: str = tool_parameters.get("style")
+        color: str = tool_parameters.get("color", "")
+        style: str = tool_parameters.get("style", "")

         if color == "__default__":
             color_id = ""
@@ -226,7 +228,7 @@ def _generate_content(self, task_id: str, model: str, user_id: str) -> str:

         return ""

-    def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
+    def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]:
         """
         Generate a ppt
@@ -362,7 +364,9 @@ def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str:
         ).decode("utf-8")

     @classmethod
-    def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
+    def _get_styles(
+        cls, credentials: dict[str, str], user_id: str
+    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
         """
         Get styles
         """
@@ -415,7 +419,7 @@ def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[di

         return colors, styles

-    def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
+    def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
         """
         Get styles
@@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool):
     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)

-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)

     def get_runtime_parameters(self) -> list[ToolParameter]:
diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
index 2d65ba2d6f4389..8bd16050ecf0a6 100644
--- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
+++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
@@ -1,7 +1,7 @@
 import logging
 from typing import Any, Optional

-import arxiv
+import arxiv  # type: ignore
 from pydantic import BaseModel, Field

 from core.tools.entities.tool_entities import ToolInvokeMessage
diff --git a/api/core/tools/provider/builtin/audio/tools/tts.py b/api/core/tools/provider/builtin/audio/tools/tts.py
index f83a64d041faab..8a33ac405bd4c3 100644
--- a/api/core/tools/provider/builtin/audio/tools/tts.py
+++ b/api/core/tools/provider/builtin/audio/tools/tts.py
@@ -11,19 +11,21 @@ class TTSTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
-        provider, model = tool_parameters.get("model").split("#")
-        voice = tool_parameters.get(f"voice#{provider}#{model}")
+        provider, model = tool_parameters.get("model", "").split("#")
+        voice = tool_parameters.get(f"voice#{provider}#{model}", "")
         model_manager = ModelManager()
+        if not self.runtime:
+            raise ValueError("Runtime is required")
         model_instance = model_manager.get_model_instance(
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
             provider=provider,
             model_type=ModelType.TTS,
             model=model,
         )
         tts = model_instance.invoke_tts(
-            content_text=tool_parameters.get("text"),
+            content_text=tool_parameters.get("text", ""),
             user=user_id,
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
             voice=voice,
         )
         buffer = io.BytesIO()
@@ -41,8 +43,11 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInv
         ]

     def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
+        if not self.runtime:
+            raise ValueError("Runtime is required")
         model_provider_service = ModelProviderService()
-        models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
+        tid: str = self.runtime.tenant_id or ""
+        models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
         items = []
         for provider_model in models:
             provider = provider_model.provider
@@ -62,6 +67,8 @@ def get_runtime_parameters(self) -> list[ToolParameter]:
                 ToolParameter(
                     name=f"voice#{provider}#{model}",
                     label=I18nObject(en_US=f"Voice of {model}({provider})"),
+                    human_description=I18nObject(en_US=f"Select a voice for {model} model"),
+                    placeholder=I18nObject(en_US="Select a voice"),
                     type=ToolParameter.ToolParameterType.SELECT,
                     form=ToolParameter.ToolParameterForm.FORM,
                     options=[
@@ -83,6 +90,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]:
                 type=ToolParameter.ToolParameterType.SELECT,
                 form=ToolParameter.ToolParameterForm.FORM,
                 required=True,
+                placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
                 options=options,
             ),
         )
diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py
index a04f5c0fe9f1af..b224ff5258c879 100644
--- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py
+++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py
@@ -2,8 +2,8 @@
 import logging
 from typing import Any, Union

-import boto3
-from botocore.exceptions import BotoCoreError
+import boto3  # type: ignore
+from botocore.exceptions import BotoCoreError  # type: ignore
 from pydantic import BaseModel, Field

 from core.tools.entities.tool_entities import ToolInvokeMessage
diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py
index 989608122185c8..b6d16d2759c30e 100644
--- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py
+++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py
@@ -1,7 +1,7 @@
 import json
 from typing import Any, Union

-import boto3
+import boto3  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py
index f43f3b6fe05694..01bc596346c231 100644
--- a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py
+++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py
@@ -2,7 +2,7 @@
 import logging
 from typing import Any, Union

-import boto3
+import boto3  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
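The `# type: ignore` comments on `boto3` and friends suppress mypy's missing-stubs errors at each import site. Per-module configuration is the usual project-wide alternative; the config shown below is a sketch of that option, not something this patch adds:

    # Inline suppression, as used throughout this patch:
    import boto3  # type: ignore

    # Project-wide alternative in mypy.ini / setup.cfg:
    #
    #   [mypy-boto3.*]
    #   ignore_missing_imports = True
    #
    # or in pyproject.toml:
    #
    #   [[tool.mypy.overrides]]
    #   module = "boto3.*"
    #   ignore_missing_imports = true

    client = boto3.client("s3")  # the library stays dynamically typed either way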
diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
index bffcd058b509bf..715b1ddeddcae5 100644
--- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
+++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
@@ -2,7 +2,7 @@
 import operator
 from typing import Any, Union

-import boto3
+import boto3  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
@@ -10,8 +10,8 @@

 class SageMakerReRankTool(BuiltinTool):
     sagemaker_client: Any = None
-    sagemaker_endpoint: str = None
-    topk: int = None
+    sagemaker_endpoint: str | None = None
+    topk: int | None = None

     def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
         inputs = [query_input] * len(docs)
diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
index 1fafe09b4d96bf..55cff89798a4eb 100644
--- a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
+++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
@@ -2,7 +2,7 @@
 from enum import Enum
 from typing import Any, Optional, Union

-import boto3
+import boto3  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
@@ -17,7 +17,7 @@ class TTSModelType(Enum):

 class SageMakerTTSTool(BuiltinTool):
     sagemaker_client: Any = None
-    sagemaker_endpoint: str = None
+    sagemaker_endpoint: str | None = None
     s3_client: Any = None
     comprehend_client: Any = None
diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py
index 7f69e833cb9046..a60062ca66abbf 100644
--- a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py
+++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py
@@ -1,6 +1,6 @@
 from typing import Any, Union

-from zhipuai import ZhipuAI
+from zhipuai import ZhipuAI  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py
index a521f1c28a41b6..3e24b74d2598a7 100644
--- a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py
+++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py
@@ -1,7 +1,7 @@
 from typing import Any, Union

 import httpx
-from zhipuai import ZhipuAI
+from zhipuai import ZhipuAI  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py
index 12b4173fa40270..9aa781709a726c 100644
--- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py
+++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py
@@ -1,7 +1,7 @@
 import random
 from typing import Any, Union

-from zhipuai import ZhipuAI
+from zhipuai import ZhipuAI  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py
index c959496735e747..d58b42b82029ce 100644
--- a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py
+++ b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py
@@ -7,18 +7,22 @@

 class SearchRecordsTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+        if not self.runtime or not self.runtime.credentials:
+            raise ValueError("Runtime is not set")
         app_id = self.runtime.credentials.get("app_id")
         app_secret = self.runtime.credentials.get("app_secret")
+        if not app_id or not app_secret:
+            raise ValueError("app_id and app_secret are required")
         client = FeishuRequest(app_id, app_secret)

-        app_token = tool_parameters.get("app_token")
-        table_id = tool_parameters.get("table_id")
-        table_name = tool_parameters.get("table_name")
-        view_id = tool_parameters.get("view_id")
-        field_names = tool_parameters.get("field_names")
-        sort = tool_parameters.get("sort")
-        filters = tool_parameters.get("filter")
-        page_token = tool_parameters.get("page_token")
+        app_token = tool_parameters.get("app_token", "")
+        table_id = tool_parameters.get("table_id", "")
+        table_name = tool_parameters.get("table_name", "")
+        view_id = tool_parameters.get("view_id", "")
+        field_names = tool_parameters.get("field_names", "")
+        sort = tool_parameters.get("sort", "")
+        filters = tool_parameters.get("filter", "")
+        page_token = tool_parameters.get("page_token", "")
         automatic_fields = tool_parameters.get("automatic_fields", False)
         user_id_type = tool_parameters.get("user_id_type", "open_id")
         page_size = tool_parameters.get("page_size", 20)
diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py
index a7b036387500b0..31cf8e18d85b8d 100644
--- a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py
+++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py
@@ -7,14 +7,18 @@

 class UpdateRecordsTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+        if not self.runtime or not self.runtime.credentials:
+            raise ValueError("Runtime is not set")
         app_id = self.runtime.credentials.get("app_id")
         app_secret = self.runtime.credentials.get("app_secret")
+        if not app_id or not app_secret:
+            raise ValueError("app_id and app_secret are required")
         client = FeishuRequest(app_id, app_secret)

-        app_token = tool_parameters.get("app_token")
-        table_id = tool_parameters.get("table_id")
-        table_name = tool_parameters.get("table_name")
-        records = tool_parameters.get("records")
+        app_token = tool_parameters.get("app_token", "")
+        table_id = tool_parameters.get("table_id", "")
+        table_name = tool_parameters.get("table_name", "")
+        records = tool_parameters.get("records", "")
         user_id_type = tool_parameters.get("user_id_type", "open_id")

         res = client.update_records(app_token, table_id, table_name, records, user_id_type)
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py
index 8f83aea5abbe3d..80287feca176e1 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py
@@ -7,12 +7,16 @@

 class AddEventAttendeesTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+        if not self.runtime or not self.runtime.credentials:
+            raise ValueError("Runtime is not set")
         app_id = self.runtime.credentials.get("app_id")
         app_secret = self.runtime.credentials.get("app_secret")
+        if not app_id or not app_secret:
+            raise ValueError("app_id and app_secret are required")
         client = FeishuRequest(app_id, app_secret)

-        event_id = tool_parameters.get("event_id")
-        attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email")
+        event_id = tool_parameters.get("event_id", "")
+        attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "")
         need_notification = tool_parameters.get("need_notification", True)

         res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification)
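The `.get(key)` to `.get(key, "")` changes across the Feishu tools are type-driven: without a default, `dict.get` returns `str | None`, so every downstream string operation fails type checking. A sketch:

    params: dict[str, str] = {"table_id": "tbl123"}

    a = params.get("table_id")      # inferred as str | None: a.split(...) is an error
    b = params.get("table_id", "")  # inferred as str: methods are safe to call

    print(b.split("l"))              # ['tb', '123']
    print(a is not None and a == b)  # True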
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py
index 144889692f9055..02e9b445219ac8 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py
@@ -7,11 +7,15 @@

 class DeleteEventTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+        if not self.runtime or not self.runtime.credentials:
+            raise ValueError("Runtime is not set")
         app_id = self.runtime.credentials.get("app_id")
         app_secret = self.runtime.credentials.get("app_secret")
+        if not app_id or not app_secret:
+            raise ValueError("app_id and app_secret are required")
         client = FeishuRequest(app_id, app_secret)

-        event_id = tool_parameters.get("event_id")
+        event_id = tool_parameters.get("event_id", "")
         need_notification = tool_parameters.get("need_notification", True)

         res = client.delete_event(event_id, need_notification)
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py
index a2cd5a8b17d0af..4dafe4b3baf0cd 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py
@@ -7,8 +7,12 @@

 class GetPrimaryCalendarTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+        if not self.runtime or not self.runtime.credentials:
+            raise ValueError("Runtime is not set")
         app_id = self.runtime.credentials.get("app_id")
         app_secret = self.runtime.credentials.get("app_secret")
+        if not app_id or not app_secret:
+            raise ValueError("app_id and app_secret are required")
         client = FeishuRequest(app_id, app_secret)

         user_id_type = tool_parameters.get("user_id_type", "open_id")
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py
index 8815b4c9c871cd..2e8ca968b3cc42 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py
@@ -7,14 +7,18 @@

 class ListEventsTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+        if not self.runtime or not self.runtime.credentials:
+            raise ValueError("Runtime is not set")
         app_id = self.runtime.credentials.get("app_id")
         app_secret = self.runtime.credentials.get("app_secret")
+        if not app_id or not app_secret:
+            raise ValueError("app_id and app_secret are required")
         client = FeishuRequest(app_id, app_secret)

-        start_time = tool_parameters.get("start_time")
-        end_time = tool_parameters.get("end_time")
-        page_token = tool_parameters.get("page_token")
-        page_size = tool_parameters.get("page_size")
+        start_time = tool_parameters.get("start_time", "")
+        end_time = tool_parameters.get("end_time", "")
+        page_token = tool_parameters.get("page_token", "")
+        page_size = tool_parameters.get("page_size", 50)

         res = client.list_events(start_time, end_time, page_token, page_size)
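The repeated runtime guards in these tools are type narrowing, not just defensive coding: after the `raise`, mypy knows `self.runtime` is not None and accepts the attribute accesses that follow. A condensed sketch (hypothetical `Runtime`/`Tool` shapes, not the project's classes):

    from typing import Optional

    class Runtime:
        def __init__(self, credentials: dict[str, str]):
            self.credentials = credentials

    class Tool:
        def __init__(self, runtime: Optional[Runtime] = None):
            self.runtime = runtime

        def invoke(self) -> str:
            # Raising on the None branch narrows self.runtime to Runtime, so
            # .credentials type-checks below without a cast.
            if not self.runtime or not self.runtime.credentials:
                raise ValueError("Runtime is not set")
            app_id = self.runtime.credentials.get("app_id")
            if not app_id:
                raise ValueError("app_id is required")
            return app_id

    print(Tool(Runtime({"app_id": "cli_123"})).invoke())  # cli_123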
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py
index 85bcb1d3f63847..b20eb6c31828e4 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py
@@ -7,16 +7,20 @@

 class UpdateEventTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+        if not self.runtime or not self.runtime.credentials:
+            raise ValueError("Runtime is not set")
         app_id = self.runtime.credentials.get("app_id")
         app_secret = self.runtime.credentials.get("app_secret")
+        if not app_id or not app_secret:
+            raise ValueError("app_id and app_secret are required")
         client = FeishuRequest(app_id, app_secret)

-        event_id = tool_parameters.get("event_id")
-        summary = tool_parameters.get("summary")
-        description = tool_parameters.get("description")
+        event_id = tool_parameters.get("event_id", "")
+        summary = tool_parameters.get("summary", "")
+        description = tool_parameters.get("description", "")
         need_notification = tool_parameters.get("need_notification", True)
-        start_time = tool_parameters.get("start_time")
-        end_time = tool_parameters.get("end_time")
+        start_time = tool_parameters.get("start_time", "")
+        end_time = tool_parameters.get("end_time", "")
         auto_record = tool_parameters.get("auto_record", False)

         res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record)
diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py
index 090a0828e89bbf..1533f594172878 100644
--- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py
+++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py
@@ -7,13 +7,17 @@

 class CreateDocumentTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+        if not self.runtime or not self.runtime.credentials:
+            raise ValueError("Runtime is not set")
         app_id = self.runtime.credentials.get("app_id")
         app_secret = self.runtime.credentials.get("app_secret")
+        if not app_id or not app_secret:
+            raise ValueError("app_id and app_secret are required")
         client = FeishuRequest(app_id, app_secret)

-        title = tool_parameters.get("title")
-        content = tool_parameters.get("content")
-        folder_token = tool_parameters.get("folder_token")
+        title = tool_parameters.get("title", "")
+        content = tool_parameters.get("content", "")
+        folder_token = tool_parameters.get("folder_token", "")

         res = client.create_document(title, content, folder_token)
         return self.create_json_message(res)
tool_parameters.get("page_token", "") user_id_type = tool_parameters.get("user_id_type", "open_id") page_size = tool_parameters.get("page_size", 500) diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index fcab3d71a93cf9..06f6cacd5d6126 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 793c74e5f9df51..e825329a6d8f61 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index f91432ee77f488..193017ba9a7c53 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index 383825c2d0b259..feca0d8a7c2783 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index 0c5b5e41cbe1e1..d3a497d1cd5c54 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -1,7 +1,7 @@ import logging from typing import Any, Union -import numexpr as ne +import numexpr as ne # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py index db4adfd4ad4629..6473c509e1f4c2 100644 --- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -1,4 +1,4 @@ -from novita_client import ( +from novita_client import ( # type: ignore Txt2ImgV3Embedding, Txt2ImgV3HiresFix, Txt2ImgV3LoRA, diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py 
diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
index 0b4f2edff3607f..097b234bd50640 100644
--- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
+++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
@@ -2,7 +2,7 @@
 from copy import deepcopy
 from typing import Any, Union

-from novita_client import (
+from novita_client import (  # type: ignore
     NovitaClient,
 )
diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py
index 9c61eab9f95784..297a27abba667a 100644
--- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py
+++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py
@@ -2,7 +2,7 @@
 from copy import deepcopy
 from typing import Any, Union

-from novita_client import (
+from novita_client import (  # type: ignore
     NovitaClient,
 )
diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py
index 165e93956eff38..704e0015d961a3 100644
--- a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py
+++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py
@@ -13,7 +13,7 @@

 with warnings.catch_warnings():
     warnings.simplefilter("ignore")
-    from pydub import AudioSegment
+    from pydub import AudioSegment  # type: ignore


 class PodcastAudioGeneratorTool(BuiltinTool):
diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
index d8ca20bde6ffc9..4a47c4211f4fd4 100644
--- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
+++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
@@ -2,10 +2,10 @@
 import logging
 from typing import Any, Union

-from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
-from qrcode.image.base import BaseImage
-from qrcode.image.pure import PyPNGImage
-from qrcode.main import QRCode
+from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q  # type: ignore
+from qrcode.image.base import BaseImage  # type: ignore
+from qrcode.image.pure import PyPNGImage  # type: ignore
+from qrcode.main import QRCode  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/transcript/tools/transcript.py b/api/core/tools/provider/builtin/transcript/tools/transcript.py
index 27f700efbd6936..ac7565d9eef5b8 100644
--- a/api/core/tools/provider/builtin/transcript/tools/transcript.py
+++ b/api/core/tools/provider/builtin/transcript/tools/transcript.py
@@ -1,7 +1,7 @@
 from typing import Any, Union
 from urllib.parse import parse_qs, urlparse

-from youtube_transcript_api import YouTubeTranscriptApi
+from youtube_transcript_api import YouTubeTranscriptApi  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py
index 5ee839baa56f02..98a108f4ec7e93 100644
--- a/api/core/tools/provider/builtin/twilio/tools/send_message.py
+++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py
@@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel):
     def set_validator(cls, values: dict) -> dict:
         """Validate that api key and python package exists in environment."""
         try:
-            from twilio.rest import Client
+            from twilio.rest import Client  # type: ignore
         except ImportError:
             raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.")
         account_sid = values.get("account_sid")
diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py
index b1d100aad93dba..649e03d185121c 100644
--- a/api/core/tools/provider/builtin/twilio/twilio.py
+++ b/api/core/tools/provider/builtin/twilio/twilio.py
@@ -1,7 +1,7 @@
 from typing import Any

-from twilio.base.exceptions import TwilioRestException
-from twilio.rest import Client
+from twilio.base.exceptions import TwilioRestException  # type: ignore
+from twilio.rest import Client  # type: ignore

 from core.tools.errors import ToolProviderCredentialValidationError
 from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py
index 1c7cb39c92b40b..a6afd2dddfc63a 100644
--- a/api/core/tools/provider/builtin/vanna/tools/vanna.py
+++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py
@@ -1,6 +1,6 @@
 from typing import Any, Union

-from vanna.remote import VannaDefault
+from vanna.remote import VannaDefault  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.errors import ToolProviderCredentialValidationError
@@ -14,6 +14,9 @@ def _invoke(
         """
         invoke tools
         """
+        # Ensure runtime and credentials
+        if not self.runtime or not self.runtime.credentials:
+            raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
         api_key = self.runtime.credentials.get("api_key", None)
         if not api_key:
             raise ToolProviderCredentialValidationError("Please input api key")
diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
index cb88e9519a4346..edb96e722f7f33 100644
--- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
+++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
@@ -1,6 +1,6 @@
 from typing import Any, Optional, Union

-import wikipedia
+import wikipedia  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py
index f044fbe5404b0a..95a65ba22fc8af 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py
@@ -3,7 +3,7 @@

 import pandas as pd
 from requests.exceptions import HTTPError, ReadTimeout
-from yfinance import download
+from yfinance import download  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py
index ff820430f9f366..c9ae0c4ca7fcc6 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/news.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/news.py
@@ -1,6 +1,6 @@
 from typing import Any, Union

-import yfinance
+import yfinance  # type: ignore
 from requests.exceptions import HTTPError, ReadTimeout

 from core.tools.entities.tool_entities import ToolInvokeMessage
diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py
index dfc7e460473c33..74d0d25addf04b 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py
@@ -1,7 +1,7 @@
 from typing import Any, Union

 from requests.exceptions import HTTPError, ReadTimeout
-from yfinance import Ticker
+from yfinance import Ticker  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py
index 95dec2eac9a752..a24fe89679b29b 100644
--- a/api/core/tools/provider/builtin/youtube/tools/videos.py
+++ b/api/core/tools/provider/builtin/youtube/tools/videos.py
@@ -1,7 +1,7 @@
 from datetime import datetime
 from typing import Any, Union

-from googleapiclient.discovery import build
+from googleapiclient.discovery import build  # type: ignore

 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py
index 955a0add3b4513..61de75ac5e2ccd 100644
--- a/api/core/tools/provider/builtin_tool_provider.py
+++ b/api/core/tools/provider/builtin_tool_provider.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from os import listdir, path
-from typing import Any
+from typing import Any, Optional

 from core.helper.module_import_helper import load_single_subclass_from_source
 from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
@@ -50,6 +50,8 @@ def _get_builtin_tools(self) -> list[Tool]:
         """
         if self.tools:
             return self.tools
+        if not self.identity:
+            return []

         provider = self.identity.name
         tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
@@ -86,7 +88,7 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:

         return self.credentials_schema.copy()

-    def get_tools(self) -> list[Tool]:
+    def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
         """
         returns a list of tools that the provider can provide
@@ -94,11 +96,14 @@
         """
         return self._get_builtin_tools()

-    def get_tool(self, tool_name: str) -> Tool:
+    def get_tool(self, tool_name: str) -> Optional[Tool]:
         """
         returns the tool that the provider can provide
         """
-        return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
+        tools = self.get_tools()
+        if tools is None:
+            raise ValueError("tools not found")
+        return next((t for t in tools if t.identity and t.identity.name == tool_name), None)

     def get_parameters(self, tool_name: str) -> list[ToolParameter]:
         """
@@ -107,10 +112,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]:
         :param tool_name: the name of the tool, defined in `get_tools`
         :return: list of parameters
         """
-        tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
+        tools = self.get_tools()
+        if tools is None:
+            raise ToolNotFoundError(f"tool {tool_name} not found")
+        tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
         if tool is None:
             raise ToolNotFoundError(f"tool {tool_name} not found")
-        return tool.parameters
+        return tool.parameters or []

     @property
     def need_credentials(self) -> bool:
@@ -144,6 +152,8 @@ def _get_tool_labels(self) -> list[ToolLabelEnum]:
         """
         returns the labels of the provider
         """
+        if self.identity is None:
+            return []
         return self.identity.tags or []

     def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
@@ -159,56 +169,56 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic
         for parameter in tool_parameters_schema:
             tool_parameters_need_to_validate[parameter.name] = parameter

-        for parameter in tool_parameters:
-            if parameter not in tool_parameters_need_to_validate:
-                raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
+        for parameter_name in tool_parameters:
+            if parameter_name not in tool_parameters_need_to_validate:
+                raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}")

             # check type
-            parameter_schema = tool_parameters_need_to_validate[parameter]
+            parameter_schema = tool_parameters_need_to_validate[parameter_name]
             if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
-                if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f"parameter {parameter} should be string")
+                if not isinstance(tool_parameters[parameter_name], str):
+                    raise ToolParameterValidationError(f"parameter {parameter_name} should be string")

             elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
-                if not isinstance(tool_parameters[parameter], int | float):
-                    raise ToolParameterValidationError(f"parameter {parameter} should be number")
+                if not isinstance(tool_parameters[parameter_name], int | float):
+                    raise ToolParameterValidationError(f"parameter {parameter_name} should be number")

-                if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
+                if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min:
                     raise ToolParameterValidationError(
-                        f"parameter {parameter} should be greater than {parameter_schema.min}"
+                        f"parameter {parameter_name} should be greater than {parameter_schema.min}"
                     )

-                if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
+                if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max:
                     raise ToolParameterValidationError(
-                        f"parameter {parameter} should be less than {parameter_schema.max}"
+                        f"parameter {parameter_name} should be less than {parameter_schema.max}"
                     )

             elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
-                if not isinstance(tool_parameters[parameter], bool):
-                    raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
+                if not isinstance(tool_parameters[parameter_name], bool):
+                    raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean")

             elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
-                if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f"parameter {parameter} should be string")
+                if not isinstance(tool_parameters[parameter_name], str):
+                    raise ToolParameterValidationError(f"parameter {parameter_name} should be string")

                 options = parameter_schema.options
                 if not isinstance(options, list):
-                    raise ToolParameterValidationError(f"parameter {parameter} options should be list")
+                    raise ToolParameterValidationError(f"parameter {parameter_name} options should be list")

-                if tool_parameters[parameter] not in [x.value for x in options]:
-                    raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
+                if tool_parameters[parameter_name] not in [x.value for x in options]:
+                    raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}")

-            tool_parameters_need_to_validate.pop(parameter)
+            tool_parameters_need_to_validate.pop(parameter_name)

-        for parameter in tool_parameters_need_to_validate:
-            parameter_schema = tool_parameters_need_to_validate[parameter]
+        for parameter_name in tool_parameters_need_to_validate:
+            parameter_schema = tool_parameters_need_to_validate[parameter_name]
             if parameter_schema.required:
-                raise ToolParameterValidationError(f"parameter {parameter} is required")
+                raise ToolParameterValidationError(f"parameter {parameter_name} is required")

             # the parameter is not set currently, set the default value if needed
             if parameter_schema.default is not None:
                 default_value = parameter_schema.type.cast_value(parameter_schema.default)
-                tool_parameters[parameter] = default_value
+                tool_parameters[parameter_name] = default_value

     def validate_credentials(self, credentials: dict[str, Any]) -> None:
         """
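The lookup rewrite replaces `filter` plus a lambda with a generator expression that carries its own `identity` guard; the lambda version fails type checking because `x.identity` may be None. A sketch with hypothetical minimal classes:

    from typing import Optional

    class Identity:
        def __init__(self, name: str):
            self.name = name

    class Tool:
        def __init__(self, identity: Optional[Identity]):
            self.identity = identity

    def get_tool(tools: list[Tool], name: str) -> Optional[Tool]:
        # The `t.identity and` guard satisfies mypy and skips half-built tools
        # at runtime; next() falls back to None when nothing matches.
        return next((t for t in tools if t.identity and t.identity.name == name), None)

    tools = [Tool(None), Tool(Identity("search"))]
    found = get_tool(tools, "search")
    print(found.identity.name if found and found.identity else "missing")  # search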
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + if tool_parameters[parameter_name] not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}") - tool_parameters_need_to_validate.pop(parameter) + tool_parameters_need_to_validate.pop(parameter_name) - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] + for parameter_name in tool_parameters_need_to_validate: + parameter_schema = tool_parameters_need_to_validate[parameter_name] if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") + raise ToolParameterValidationError(f"parameter {parameter_name} is required") # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: default_value = parameter_schema.type.cast_value(parameter_schema.default) - tool_parameters[parameter] = default_value + tool_parameters[parameter_name] = default_value def validate_credentials(self, credentials: dict[str, Any]) -> None: """ diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index bc05a11562b717..e35207e4f06404 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -24,10 +24,12 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: :return: the credentials schema """ + if self.credentials_schema is None: + return {} return self.credentials_schema.copy() @abstractmethod - def get_tools(self) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ returns a list of tools that the provider can provide @@ -36,7 +38,7 @@ def get_tools(self) -> list[Tool]: pass @abstractmethod - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ returns a tool that the provider can provide @@ -51,10 +53,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]: :param tool_name: the name of the tool, defined in `get_tools` :return: list of parameters """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None) if tool is None: raise ToolNotFoundError(f"tool {tool_name} not found") - return tool.parameters + return tool.parameters or [] @property def provider_type(self) -> ToolProviderType: @@ -78,55 +83,55 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + for tool_parameter in tool_parameters: + if tool_parameter not in tool_parameters_need_to_validate: + raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}") # check type - parameter_schema = tool_parameters_need_to_validate[parameter] + parameter_schema = tool_parameters_need_to_validate[tool_parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise 
-                    raise ToolParameterValidationError(f"parameter {parameter} should be string")
+                if not isinstance(tool_parameters[tool_parameter], str):
+                    raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")

             elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
-                if not isinstance(tool_parameters[parameter], int | float):
-                    raise ToolParameterValidationError(f"parameter {parameter} should be number")
+                if not isinstance(tool_parameters[tool_parameter], int | float):
+                    raise ToolParameterValidationError(f"parameter {tool_parameter} should be number")

-                if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
+                if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min:
                     raise ToolParameterValidationError(
-                        f"parameter {parameter} should be greater than {parameter_schema.min}"
+                        f"parameter {tool_parameter} should be greater than {parameter_schema.min}"
                     )

-                if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
+                if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max:
                     raise ToolParameterValidationError(
-                        f"parameter {parameter} should be less than {parameter_schema.max}"
+                        f"parameter {tool_parameter} should be less than {parameter_schema.max}"
                     )

             elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
-                if not isinstance(tool_parameters[parameter], bool):
-                    raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
+                if not isinstance(tool_parameters[tool_parameter], bool):
+                    raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean")

             elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
-                if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f"parameter {parameter} should be string")
+                if not isinstance(tool_parameters[tool_parameter], str):
+                    raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")

                 options = parameter_schema.options
                 if not isinstance(options, list):
-                    raise ToolParameterValidationError(f"parameter {parameter} options should be list")
+                    raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list")

-                if tool_parameters[parameter] not in [x.value for x in options]:
-                    raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
+                if tool_parameters[tool_parameter] not in [x.value for x in options]:
+                    raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}")

-            tool_parameters_need_to_validate.pop(parameter)
+            tool_parameters_need_to_validate.pop(tool_parameter)

-        for parameter in tool_parameters_need_to_validate:
-            parameter_schema = tool_parameters_need_to_validate[parameter]
+        for tool_parameter_validate in tool_parameters_need_to_validate:
+            parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate]
             if parameter_schema.required:
-                raise ToolParameterValidationError(f"parameter {parameter} is required")
+                raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required")

             # the parameter is not set currently, set the default value if needed
             if parameter_schema.default is not None:
-                tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
+                tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default)

     def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
         """
@@ -144,6 +149,8 @@ def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
 
         for credential_name in credentials:
             if credential_name not in credentials_need_to_validate:
+                if self.identity is None:
+                    raise ValueError("identity is not set")
                 raise ToolProviderCredentialValidationError(
                     f"credential {credential_name} not found in provider {self.identity.name}"
                 )
diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py
index 5656dd09ab8c94..17fe2e20cf282e 100644
--- a/api/core/tools/provider/workflow_tool_provider.py
+++ b/api/core/tools/provider/workflow_tool_provider.py
@@ -11,6 +11,7 @@
     ToolProviderType,
 )
 from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.tool.tool import Tool
 from core.tools.tool.workflow_tool import WorkflowTool
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
 from extensions.ext_database import db
@@ -116,6 +117,7 @@ def fetch_workflow_variable(variable_name: str):
                             llm_description=parameter.description,
                             required=variable.required,
                             options=options,
+                            placeholder=I18nObject(en_US="", zh_Hans=""),
                         )
                     )
                 elif features.file_upload:
@@ -128,6 +130,7 @@ def fetch_workflow_variable(variable_name: str):
                             llm_description=parameter.description,
                             required=False,
                             form=parameter.form,
+                            placeholder=I18nObject(en_US="", zh_Hans=""),
                         )
                     )
                 else:
@@ -157,7 +160,7 @@ def fetch_workflow_variable(variable_name: str):
             label=db_provider.label,
         )
 
-    def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
+    def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
         """
         fetch tools from database
 
@@ -168,7 +171,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
         if self.tools is not None:
             return self.tools
 
-        db_providers: WorkflowToolProvider = (
+        db_providers: Optional[WorkflowToolProvider] = (
             db.session.query(WorkflowToolProvider)
             .filter(
                 WorkflowToolProvider.tenant_id == tenant_id,
@@ -179,12 +182,14 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
 
         if not db_providers:
             return []
+        if not db_providers.app:
+            raise ValueError("app not found")
 
         self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
 
         return self.tools
 
-    def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
+    def get_tool(self, tool_name: str) -> Optional[Tool]:
         """
         get tool by name
 
@@ -195,6 +200,8 @@ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
             return None
 
         for tool in self.tools:
+            if tool.identity is None:
+                continue
             if tool.identity.name == tool_name:
                 return tool
 
diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py
index 48aac75dbb4115..9a00450290a660 100644
--- a/api/core/tools/tool/api_tool.py
+++ b/api/core/tools/tool/api_tool.py
@@ -32,11 +32,13 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
         :param meta: the meta data of a tool call processing, tenant_id is required
         :return: the new tool
         """
+        if self.api_bundle is None:
+            raise ValueError("api_bundle is required")
         return self.__class__(
             identity=self.identity.model_copy() if self.identity else None,
             parameters=self.parameters.copy() if self.parameters else None,
             description=self.description.model_copy() if self.description else None,
-            api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
+            api_bundle=self.api_bundle.model_copy(),
             runtime=Tool.Runtime(**runtime),
         )
 
@@ -61,6 +63,8 @@ def tool_provider_type(self) -> ToolProviderType:
 
     def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
         headers = {}
+        if self.runtime is None:
+            raise ValueError("runtime is required")
         credentials = self.runtime.credentials or {}
 
         if "auth_type" not in credentials:
@@ -88,7 +92,7 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
 
             headers[api_key_header] = credentials["api_key_value"]
 
-        needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
+        needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
         for parameter in needed_parameters:
             if parameter.required and parameter.name not in parameters:
                 raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
@@ -137,7 +141,8 @@ def do_http_request(
         params = {}
         path_params = {}
-        body = {}
+        # FIXME: body should be a dict[str, Any] but it changed a lot in this function
+        body: Any = {}
         cookies = {}
         files = []
 
@@ -198,7 +203,7 @@ def do_http_request(
             body = body
 
         if method in {"get", "head", "post", "put", "delete", "patch"}:
-            response = getattr(ssrf_proxy, method)(
+            response: httpx.Response = getattr(ssrf_proxy, method)(
                 url,
                 params=params,
                 headers=headers,
@@ -288,6 +293,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
         """
         invoke http request
         """
+        response: httpx.Response | str = ""
         # assemble request
         headers = self.assembling_request(tool_parameters)
 
diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py
index e2a81ed0a36edd..adda4297f38e8a 100644
--- a/api/core/tools/tool/builtin_tool.py
+++ b/api/core/tools/tool/builtin_tool.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, cast
 
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@@ -32,9 +32,12 @@ def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop:
         :return: the model result
         """
         # invoke model
+        if self.runtime is None or self.identity is None:
+            raise ValueError("runtime and identity are required")
+
         return ModelInvocationUtils.invoke(
             user_id=user_id,
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
             tool_type="builtin",
             tool_name=self.identity.name,
             prompt_messages=prompt_messages,
@@ -50,8 +53,11 @@ def get_max_tokens(self) -> int:
         :param model_config: the model config
         :return: the max tokens
         """
+        if self.runtime is None:
+            raise ValueError("runtime is required")
+
         return ModelInvocationUtils.get_max_llm_context_tokens(
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
         )
 
     def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@@ -61,7 +67,12 @@ def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
         :param prompt_messages: the prompt messages
         :return: the tokens
         """
-        return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
+        if self.runtime is None:
+            raise ValueError("runtime is required")
+
+        return ModelInvocationUtils.calculate_tokens(
+            tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
+        )
 
     def summary(self, user_id: str, content: str) -> str:
         max_tokens = self.get_max_tokens()
@@ -81,7 +92,7 @@ def summarize(content: str) -> str:
                 stop=[],
             )
 
-            return summary.message.content
+            return cast(str, summary.message.content)
 
         lines = content.split("\n")
         new_lines = []
@@ -102,16 +113,16 @@ def summarize(content: str) -> str:
 
         # merge lines into messages with max tokens
        messages: list[str] = []
-        for i in new_lines:
+        for j in new_lines:
             if len(messages) == 0:
-                messages.append(i)
+                messages.append(j)
             else:
-                if len(messages[-1]) + len(i) < max_tokens * 0.5:
-                    messages[-1] += i
-                if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
-                    messages.append(i)
+                if len(messages[-1]) + len(j) < max_tokens * 0.5:
+                    messages[-1] += j
+                elif get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
+                    messages.append(j)
                 else:
-                    messages[-1] += i
+                    messages[-1] += j
 
         summaries = []
         for i in range(len(messages)):
diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
index ab7b40a2536db8..a4afea4b9df429 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
@@ -1,4 +1,5 @@
 import threading
+from typing import Any
 
 from flask import Flask, current_app
 from pydantic import BaseModel, Field
@@ -7,13 +8,14 @@
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document as RagDocument
 from core.rag.rerank.rerank_model import RerankModelRunner
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 
-default_retrieval_model = {
+default_retrieval_model: dict[str, Any] = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
     "reranking_enable": False,
     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -44,12 +46,12 @@ def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
 
     def _run(self, query: str) -> str:
         threads = []
-        all_documents = []
+        all_documents: list[RagDocument] = []
         for dataset_id in self.dataset_ids:
             retrieval_thread = threading.Thread(
                 target=self._retriever,
                 kwargs={
-                    "flask_app": current_app._get_current_object(),
+                    "flask_app": current_app._get_current_object(),  # type: ignore
                     "dataset_id": dataset_id,
                     "query": query,
                     "all_documents": all_documents,
@@ -77,11 +79,11 @@ def _run(self, query: str) -> str:
 
         document_score_list = {}
         for item in all_documents:
-            if item.metadata.get("score"):
+            if item.metadata and item.metadata.get("score"):
                 document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
 
         document_context_list = []
-        index_node_ids = [document.metadata["doc_id"] for document in all_documents]
+        index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
         segments = DocumentSegment.query.filter(
             DocumentSegment.dataset_id.in_(self.dataset_ids),
             DocumentSegment.completed_at.isnot(None),
@@ -139,6 +141,7 @@ def _run(self, query: str) -> str:
                     hit_callback.return_retriever_resource_info(context_list)
 
             return str("\n".join(document_context_list))
+        return ""
 
     def _retriever(
         self,
diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py
index dad8c773579099..a4d2de3b1c8ef3 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py
+++ 
b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Any, Optional -from msal_extensions.persistence import ABC +from msal_extensions.persistence import ABC # type: ignore from pydantic import BaseModel, ConfigDict from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 987f94a35046e9..b382016473055d 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService @@ -69,25 +71,27 @@ def _run(self, query: str) -> str: metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset.id - document.metadata["dataset_name"] = dataset.name - results.append(document) + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset.id + document.metadata["dataset_name"] = dataset.name + results.append(document) # deal with external documents context_list = [] for position, item in enumerate(results, start=1): - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } + if item.metadata is not None: + source = { + "position": position, + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) @@ -95,7 +99,7 @@ def _run(self, query: str) -> str: return str("\n".join([item.page_content for item in results])) else: # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model + retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -113,11 +117,11 @@ def _run(self, query: str) -> str: score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) + reranking_model=retrieval_model.get("reranking_model") if retrieval_model["reranking_enable"] else None, reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + weights=retrieval_model.get("weights"), ) else: documents = [] @@ -127,7 +131,7 @@ def _run(self, 
query: str) -> str: document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if item.metadata.get("score"): + if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in documents] @@ -155,20 +159,21 @@ def _run(self, query: str) -> str: context_list = [] resource_number = 1 for segment in sorted_segments: - context = {} - document = Document.query.filter( + document_segment = Document.query.filter( Document.id == segment.document_id, Document.enabled == True, Document.archived == False, ).first() - if dataset and document: + if not document_segment: + continue + if dataset and document_segment: source = { "position": resource_number, "dataset_id": dataset.id, "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, + "document_id": document_segment.id, + "document_name": document_segment.name, + "data_source_type": document_segment.data_source_type, "segment_id": segment.id, "retriever_from": self.retriever_from, "score": document_score_list.get(segment.index_node_id, None), diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 3c9295c493c470..2d7e193e152645 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool): def get_dataset_tools( tenant_id: str, dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, + retrieve_config: Optional[DatasetRetrieveConfigEntity], return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, @@ -51,6 +51,8 @@ def get_dataset_tools( invoke_from=invoke_from, hit_callback=hit_callback, ) + if retrieval_tools is None: + return [] # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode @@ -83,6 +85,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: llm_description="Query for the dataset to be used to retrieve the dataset.", required=True, default="", + placeholder=I18nObject(en_US="", zh_Hans=""), ), ] @@ -102,7 +105,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe return self.create_text_message(text=result) - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: """ validate the credentials for dataset retriever tool """ diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 8d4045038171a6..55f94d7619635b 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -91,7 +91,7 @@ def tool_provider_type(self) -> ToolProviderType: :return: the tool provider type """ - def load_variables(self, variables: ToolRuntimeVariablePool): + def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None: """ load variables from database @@ -105,6 +105,8 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None: """ if not self.variables: return + if 
self.identity is None: + return self.variables.set_file(self.identity.name, variable_name, image_key) @@ -114,6 +116,8 @@ def set_text_variable(self, variable_name: str, text: str) -> None: """ if not self.variables: return + if self.identity is None: + return self.variables.set_text(self.identity.name, variable_name, text) @@ -200,7 +204,11 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]: def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]: # update tool_parameters # TODO: Fix type error. + if self.runtime is None: + return [] if self.runtime.runtime_parameters: + # Convert Mapping to dict before updating + tool_parameters = dict(tool_parameters) tool_parameters.update(self.runtime.runtime_parameters) # try parse tool parameters into the correct type @@ -221,7 +229,7 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> Transform tool parameters type """ # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials - result = deepcopy(tool_parameters) + result: dict[str, Any] = deepcopy(dict(tool_parameters)) for parameter in self.parameters or []: if parameter.name in tool_parameters: result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) @@ -234,12 +242,15 @@ def _invoke( ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: """ validate the credentials :param credentials: the credentials :param parameters: the parameters + :param format_only: only return the formatted """ pass diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 33b4ad021a5e7f..edff4a2d07cca2 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -68,20 +68,20 @@ def _invoke( if data.get("error"): raise Exception(data.get("error")) - result = [] + r = [] outputs = data.get("outputs") if outputs == None: outputs = {} else: - outputs, files = self._extract_files(outputs) - for file in files: - result.append(self.create_file_message(file)) + outputs, extracted_files = self._extract_files(outputs) + for f in extracted_files: + r.append(self.create_file_message(f)) - result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) - result.append(self.create_json_message(outputs)) + r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) + r.append(self.create_json_message(outputs)) - return result + return r def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index f92b43608ed935..425a892527daa4 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -3,7 +3,7 @@ from copy import deepcopy from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from yarl import URL @@ -46,7 +46,7 @@ def agent_invoke( invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, trace_manager: Optional[TraceQueueManager] = None, - ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: + ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]: """ Agent invokes the tool with the 
given arguments. """ @@ -69,6 +69,8 @@ def agent_invoke( raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool + if tool.identity is None: + raise ValueError("tool identity is not set") try: # hit the callback handler agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) @@ -163,6 +165,8 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke """ Invoke the tool with the given arguments. """ + if tool.identity is None: + raise ValueError("tool identity is not set") started_at = datetime.now(UTC) meta = ToolInvokeMeta( time_cost=0.0, @@ -171,7 +175,7 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke "tool_name": tool.identity.name, "tool_provider": tool.identity.provider, "tool_provider_type": tool.tool_provider_type().value, - "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_parameters": deepcopy(tool.runtime.runtime_parameters) if tool.runtime else {}, "tool_icon": tool.identity.icon, }, ) @@ -194,9 +198,9 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: - result += response.message + result += str(response.message) if response.message is not None else "" elif response.type == ToolInvokeMessage.MessageType.LINK: - result += f"result link: {response.message}. please tell user to check it." + result += f"result link: {response.message!r}. please tell user to check it." elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, you do not need to create it," @@ -205,7 +209,7 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str elif response.type == ToolInvokeMessage.MessageType.JSON: result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." else: - result += f"tool response: {response.message}." + result += f"tool response: {response.message!r}." 
return result @@ -223,7 +227,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis mimetype = response.meta.get("mime_type") else: try: - url = URL(response.message) + url = URL(cast(str, response.message)) extension = url.suffix guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: @@ -237,7 +241,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result.append( ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "image/jpeg"), - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) @@ -245,7 +249,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result.append( ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "octet/stream"), - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) @@ -257,7 +261,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream", - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 2a5a2944ef8471..e53985951b0627 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -84,13 +84,17 @@ def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[ if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): raise ValueError("Unsupported tool type") - provider_ids = [controller.provider_id for controller in tool_providers] + provider_ids = [ + controller.provider_id + for controller in tool_providers + if isinstance(controller, (ApiToolProviderController, WorkflowToolProviderController)) + ] labels: list[ToolLabelBinding] = ( db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() ) - tool_labels = {label.tool_id: [] for label in labels} + tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ac333162b6bb1c..5b2173a4d0ad69 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,7 +4,7 @@ from collections.abc import Generator from os import listdir, path from threading import Lock, Thread -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from configs import dify_config from core.agent.entities import AgentToolEntity @@ -15,15 +15,18 @@ from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter -from core.tools.errors import ToolProviderNotFoundError +from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController from core.tools.tool.api_tool import ApiTool from 
core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager +from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -33,9 +36,9 @@ class ToolManager: _builtin_provider_lock = Lock() - _builtin_providers = {} + _builtin_providers: dict[str, BuiltinToolProviderController] = {} _builtin_providers_loaded = False - _builtin_tools_labels = {} + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: @@ -55,7 +58,7 @@ def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: return cls._builtin_providers[provider] @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: + def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]: """ get the builtin tool @@ -66,13 +69,15 @@ def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) + if tool is None: + raise ToolNotFoundError(f"tool {tool_name} not found") return tool @classmethod def get_tool( cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None - ) -> Union[BuiltinTool, ApiTool]: + ) -> Union[BuiltinTool, ApiTool, Tool]: """ get the tool @@ -103,7 +108,7 @@ def get_tool_runtime( tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - ) -> Union[BuiltinTool, ApiTool]: + ) -> Union[BuiltinTool, ApiTool, Tool]: """ get the tool runtime @@ -113,6 +118,7 @@ def get_tool_runtime( :return: the tool """ + controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController] if provider_type == "builtin": builtin_tool = cls.get_builtin_tool(provider_id, tool_name) @@ -129,7 +135,7 @@ def get_tool_runtime( ) # get credentials - builtin_provider: BuiltinToolProvider = ( + builtin_provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -177,7 +183,7 @@ def get_tool_runtime( } ) elif provider_type == "workflow": - workflow_provider = ( + workflow_provider: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() @@ -187,8 +193,13 @@ def get_tool_runtime( raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + controller_tools: Optional[list[Tool]] = controller.get_tools( + user_id="", tenant_id=workflow_provider.tenant_id + ) + if controller_tools is None or len(controller_tools) == 0: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + return controller_tools[0].fork_tool_runtime( runtime={ "tenant_id": tenant_id, "credentials": {}, @@ -215,7 +226,7 @@ def _init_runtime_parameter(cls, 
parameter_rule: ToolParameter, parameters: dict if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: # check if tool_parameter_config in options - options = [x.value for x in parameter_rule.options] + options = [x.value for x in parameter_rule.options or []] if parameter_value is not None and parameter_value not in options: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" @@ -267,6 +278,8 @@ def get_agent_tool_runtime( identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -312,6 +325,9 @@ def get_workflow_tool_runtime( if runtime_parameters: runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") + tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -326,6 +342,8 @@ def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ # get provider provider_controller = cls.get_builtin_provider(provider) + if provider_controller.identity is None: + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") absolute_path = path.join( path.dirname(path.realpath(__file__)), @@ -381,11 +399,15 @@ def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, Non ), parent_type=BuiltinToolProviderController, ) - provider: BuiltinToolProviderController = provider_class() - cls._builtin_providers[provider.identity.name] = provider - for tool in provider.get_tools(): + provider_controller: BuiltinToolProviderController = provider_class() + if provider_controller.identity is None: + continue + cls._builtin_providers[provider_controller.identity.name] = provider_controller + for tool in provider_controller.get_tools() or []: + if tool.identity is None: + continue cls._builtin_tools_labels[tool.identity.name] = tool.identity.label - yield provider + yield provider_controller except Exception as e: logger.exception(f"load builtin provider {provider}") @@ -449,9 +471,11 @@ def user_list_providers( # append builtin providers for provider in builtin_providers: # handle include, exclude + if provider.identity is None: + continue if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), data=provider, name_func=lambda x: x.identity.name, ): @@ -472,7 +496,7 @@ def user_list_providers( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() ) - api_provider_controllers = [ + api_provider_controllers: list[dict[str, Any]] = [ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} for provider in db_api_providers ] @@ -495,7 +519,7 @@ def user_list_providers( db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() ) - workflow_provider_controllers = [] + workflow_provider_controllers: list[WorkflowToolProviderController] = [] for provider in workflow_providers: try: 
                    workflow_provider_controllers.append(
@@ -505,7 +529,9 @@ def user_list_providers(
                 # app has been deleted
                 pass
 
-        labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
+        labels = ToolLabelManager.get_tools_labels(
+            [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
+        )
 
         for provider_controller in workflow_provider_controllers:
             user_provider = ToolTransformService.workflow_provider_to_user_provider(
@@ -527,7 +553,7 @@ def get_api_provider_controller(
 
         :return: the provider controller, the credentials
         """
-        provider: ApiToolProvider = (
+        provider: Optional[ApiToolProvider] = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.id == provider_id,
@@ -556,7 +582,7 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
         get tool provider
         """
         provider_name = provider
-        provider: ApiToolProvider = (
+        provider_tool: Optional[ApiToolProvider] = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -565,17 +591,18 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
             .first()
         )
 
-        if provider is None:
+        if provider_tool is None:
             raise ValueError(f"you have not added provider {provider_name}")
 
         try:
-            credentials = json.loads(provider.credentials_str) or {}
+            credentials = json.loads(provider_tool.credentials_str) or {}
         except:
             credentials = {}
 
         # package tool provider controller
         controller = ApiToolProviderController.from_db(
-            provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE
+            provider_tool,
+            ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
         )
 
         # init tool configuration
         tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
@@ -584,25 +611,28 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
         masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
 
         try:
-            icon = json.loads(provider.icon)
+            icon = json.loads(provider_tool.icon)
         except:
            icon = {"background": "#252525", "content": "😁"}
 
         # add tool labels
         labels = ToolLabelManager.get_tool_labels(controller)
 
-        return jsonable_encoder(
-            {
-                "schema_type": provider.schema_type,
-                "schema": provider.schema,
-                "tools": provider.tools,
-                "icon": icon,
-                "description": provider.description,
-                "credentials": masked_credentials,
-                "privacy_policy": provider.privacy_policy,
-                "custom_disclaimer": provider.custom_disclaimer,
-                "labels": labels,
-            }
+        return cast(
+            dict,
+            jsonable_encoder(
+                {
+                    "schema_type": provider_tool.schema_type,
+                    "schema": provider_tool.schema,
+                    "tools": provider_tool.tools,
+                    "icon": icon,
+                    "description": provider_tool.description,
+                    "credentials": masked_credentials,
+                    "privacy_policy": provider_tool.privacy_policy,
+                    "custom_disclaimer": provider_tool.custom_disclaimer,
+                    "labels": labels,
+                }
+            ),
         )
 
     @classmethod
@@ -617,6 +647,7 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) ->
         """
         provider_type = provider_type
         provider_id = provider_id
+        provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None
         if provider_type == "builtin":
             return (
                 dify_config.CONSOLE_API_URL
@@ -626,16 +657,21 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) ->
             )
         elif provider_type == "api":
             try:
-                provider: ApiToolProvider = (
+                provider = (
                     db.session.query(ApiToolProvider)
                     .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
                     .first()
                 )
-                return json.loads(provider.icon)
+                if provider is None:
+                    raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
+                icon = json.loads(provider.icon)
+                if isinstance(icon, (str, dict)):
+                    return icon
+                return {"background": "#252525", "content": "😁"}
             except:
                 return {"background": "#252525", "content": "😁"}
         elif provider_type == "workflow":
-            provider: WorkflowToolProvider = (
+            provider = (
                 db.session.query(WorkflowToolProvider)
                 .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
                 .first()
@@ -643,7 +679,13 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) ->
             if provider is None:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
 
-            return json.loads(provider.icon)
+            try:
+                icon = json.loads(provider.icon)
+                if isinstance(icon, (str, dict)):
+                    return icon
+                return {"background": "#252525", "content": "😁"}
+            except:
+                return {"background": "#252525", "content": "😁"}
         else:
             raise ValueError(f"provider type {provider_type} not found")
diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py
index 8b5e27f5382ee7..d7720928644701 100644
--- a/api/core/tools/utils/configuration.py
+++ b/api/core/tools/utils/configuration.py
@@ -72,9 +72,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str
 
         return a deep copy of credentials with decrypted values
         """
+        identity_id = ""
+        if self.provider_controller.identity:
+            identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
+
         cache = ToolProviderCredentialsCache(
             tenant_id=self.tenant_id,
-            identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
+            identity_id=identity_id,
             cache_type=ToolProviderCredentialsCacheType.PROVIDER,
         )
         cached_credentials = cache.get()
@@ -95,9 +99,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str
         return credentials
 
     def delete_tool_credentials_cache(self):
+        identity_id = ""
+        if self.provider_controller.identity:
+            identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
+
         cache = ToolProviderCredentialsCache(
             tenant_id=self.tenant_id,
-            identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
+            identity_id=identity_id,
             cache_type=ToolProviderCredentialsCacheType.PROVIDER,
         )
         cache.delete()
@@ -199,6 +207,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
 
         return a deep copy of parameters with decrypted values
         """
+        if self.tool_runtime is None or self.tool_runtime.identity is None:
+            raise ValueError("tool_runtime is required")
+
         cache = ToolParameterCache(
             tenant_id=self.tenant_id,
             provider=f"{self.provider_type}.{self.provider_name}",
@@ -232,6 +243,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
         return parameters
 
     def delete_tool_parameters_cache(self):
+        if self.tool_runtime is None or self.tool_runtime.identity is None:
+            raise ValueError("tool_runtime is required")
+
         cache = ToolParameterCache(
             tenant_id=self.tenant_id,
             provider=f"{self.provider_type}.{self.provider_name}",
diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py
index ea28037df03720..ecf60045aa8dc5 100644
--- 
a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional, cast import httpx @@ -101,7 +101,7 @@ def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: """ url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" payload = {"app_id": app_id, "app_secret": app_secret} - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def create_document(self, title: str, content: str, folder_token: str) -> dict: @@ -126,15 +126,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: "content": content, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str: @@ -155,9 +156,9 @@ def get_document_content(self, document_id: str, mode: str = "markdown", lang: s "lang": lang, } url = f"{self.API_BASE_URL}/document/get_document_content" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data").get("content") + return cast(str, res.get("data", {}).get("content")) return "" def list_document_blocks( @@ -173,9 +174,10 @@ def list_document_blocks( "page_token": page_token, } url = f"{self.API_BASE_URL}/document/list_document_blocks" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: @@ -191,9 +193,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: @@ -203,7 +206,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def get_chat_messages( @@ -227,9 +230,10 @@ def get_chat_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = 
res.get("data", {}) + return data return res def get_thread_messages( @@ -245,9 +249,10 @@ def get_thread_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: @@ -260,9 +265,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti "completed_at": completed_time, "description": description, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_task( @@ -278,9 +284,10 @@ def update_task( "completed_time": completed_time, "description": description, } - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_task(self, task_guid: str) -> dict: @@ -289,7 +296,7 @@ def delete_task(self, task_guid: str) -> dict: payload = { "task_guid": task_guid, } - res = self._send_request(url, method="DELETE", payload=payload) + res: dict = self._send_request(url, method="DELETE", payload=payload) return res def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: @@ -300,7 +307,7 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s "member_phone_or_email": member_phone_or_email, "member_role": member_role, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: @@ -312,9 +319,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: @@ -322,9 +330,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: params = { "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_event( @@ -347,9 +356,10 @@ def create_event( "auto_record": auto_record, "attendee_ability": attendee_ability, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_event( @@ -363,7 +373,7 @@ def update_event( auto_record: bool, ) -> dict: url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" - payload = {} + payload: dict[str, Any] = {} if summary: payload["summary"] = summary if description: @@ -376,7 +386,7 @@ def update_event( payload["need_notification"] = need_notification if auto_record: payload["auto_record"] = 
auto_record - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) return res def delete_event(self, event_id: str, need_notification: bool = True) -> dict: @@ -384,7 +394,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict: params = { "need_notification": need_notification, } - res = self._send_request(url, method="DELETE", params=params) + res: dict = self._send_request(url, method="DELETE", params=params) return res def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: @@ -395,9 +405,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_events( @@ -418,9 +429,10 @@ def search_events( "user_id_type": user_id_type, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: @@ -431,9 +443,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_ "attendee_phone_or_email": attendee_phone_or_email, "need_notification": need_notification, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_spreadsheet( @@ -447,9 +460,10 @@ def create_spreadsheet( "title": title, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_spreadsheet( @@ -463,9 +477,10 @@ def get_spreadsheet( "spreadsheet_token": spreadsheet_token, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_spreadsheet_sheets( @@ -477,9 +492,10 @@ def list_spreadsheet_sheets( params = { "spreadsheet_token": spreadsheet_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_rows( @@ -499,9 +515,10 @@ def add_rows( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_cols( @@ -521,9 +538,10 @@ def add_cols( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_rows( @@ -545,9 +563,10 @@ def read_rows( 
"num_rows": num_rows, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_cols( @@ -569,9 +588,10 @@ def read_cols( "num_cols": num_cols, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_table( @@ -593,9 +613,10 @@ def read_table( "query": query, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_base( @@ -609,9 +630,10 @@ def create_base( "name": name, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_records( @@ -633,9 +655,10 @@ def add_records( payload = { "records": convert_add_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_records( @@ -657,9 +680,10 @@ def update_records( payload = { "records": convert_update_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_records( @@ -686,9 +710,10 @@ def delete_records( payload = { "records": record_id_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_record( @@ -740,7 +765,7 @@ def search_record( except json.JSONDecodeError: raise ValueError("The input string is not valid JSON") - payload = {} + payload: dict[str, Any] = {} if view_id: payload["view_id"] = view_id @@ -752,10 +777,11 @@ def search_record( payload["filter"] = filter_dict if automatic_fields: payload["automatic_fields"] = automatic_fields - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_base_info( @@ -767,9 +793,10 @@ def get_base_info( params = { "app_token": app_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_table( @@ -797,9 +824,10 @@ def create_table( } if default_view_name: payload["default_view_name"] = default_view_name - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: 
dict = res.get("data", {}) + return data return res def delete_tables( @@ -834,9 +862,10 @@ def delete_tables( "table_names": table_name_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_tables( @@ -852,9 +881,10 @@ def list_tables( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_records( @@ -882,7 +912,8 @@ def read_records( "record_ids": record_id_list, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params, payload=payload) + res: dict = self._send_request(url, method="GET", params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res diff --git a/api/core/tools/utils/lark_api_utils.py b/api/core/tools/utils/lark_api_utils.py index 30cb0cb141d9a6..de394a39bf5a00 100644 --- a/api/core/tools/utils/lark_api_utils.py +++ b/api/core/tools/utils/lark_api_utils.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional, cast import httpx @@ -62,12 +62,10 @@ def convert_update_records(self, json_str): def tenant_access_token(self) -> str: feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" if redis_client.exists(feishu_tenant_access_token): - return redis_client.get(feishu_tenant_access_token).decode() - res = self.get_tenant_access_token(self.app_id, self.app_secret) + return str(redis_client.get(feishu_tenant_access_token).decode()) + res: dict[str, str] = self.get_tenant_access_token(self.app_id, self.app_secret) redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) - if "tenant_access_token" in res: - return res.get("tenant_access_token") - return "" + return res.get("tenant_access_token", "") def _send_request( self, @@ -91,7 +89,7 @@ def _send_request( def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" payload = {"app_id": app_id, "app_secret": app_secret} - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def create_document(self, title: str, content: str, folder_token: str) -> dict: @@ -101,15 +99,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: "content": content, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict: @@ -119,9 +118,9 @@ def get_document_content(self, document_id: str, mode: 
str = "markdown", lang: s "lang": lang, } url = f"{self.API_BASE_URL}/document/get_document_content" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data").get("content") + return cast(dict, res.get("data", {}).get("content")) return "" def list_document_blocks( @@ -134,9 +133,10 @@ def list_document_blocks( "page_token": page_token, } url = f"{self.API_BASE_URL}/document/list_document_blocks" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: @@ -149,9 +149,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: @@ -161,7 +162,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def get_chat_messages( @@ -182,9 +183,10 @@ def get_chat_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_thread_messages( @@ -197,9 +199,10 @@ def get_thread_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: @@ -211,9 +214,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti "completed_at": completed_time, "description": description, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_task( @@ -228,9 +232,10 @@ def update_task( "completed_time": completed_time, "description": description, } - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_task(self, task_guid: str) -> dict: @@ -238,9 +243,10 @@ def delete_task(self, task_guid: str) -> dict: payload = { "task_guid": task_guid, } - res = self._send_request(url, method="DELETE", payload=payload) + res: dict = self._send_request(url, method="DELETE", 
payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: @@ -250,9 +256,10 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s "member_phone_or_email": member_phone_or_email, "member_role": member_role, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: @@ -263,9 +270,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: @@ -273,9 +281,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: params = { "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_event( @@ -298,9 +307,10 @@ def create_event( "auto_record": auto_record, "attendee_ability": attendee_ability, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_event( @@ -314,7 +324,7 @@ def update_event( auto_record: bool, ) -> dict: url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" - payload = {} + payload: dict[str, Any] = {} if summary: payload["summary"] = summary if description: @@ -327,7 +337,7 @@ def update_event( payload["need_notification"] = need_notification if auto_record: payload["auto_record"] = auto_record - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) return res def delete_event(self, event_id: str, need_notification: bool = True) -> dict: @@ -335,7 +345,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict: params = { "need_notification": need_notification, } - res = self._send_request(url, method="DELETE", params=params) + res: dict = self._send_request(url, method="DELETE", params=params) return res def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: @@ -346,9 +356,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_events( @@ -369,9 +380,10 @@ def search_events( "user_id_type": user_id_type, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data 
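[Editor's note] The same three-line pattern repeats through all of these wrappers, and it is worth spelling out why it satisfies mypy: res.get("data") alone is Optional, so the hunks supply a {} default and bind the result to an annotated local before returning it. A self-contained sketch of the pattern:

    from typing import Any

    def unwrap_data(res: dict[str, Any]) -> dict[str, Any]:
        # res.get("data") by itself is Optional for mypy; the {} default plus
        # an annotated local gives the function a non-Optional dict to return
        if "data" in res:
            data: dict[str, Any] = res.get("data", {})
            return data
        return res

    print(unwrap_data({"data": {"ok": True}}))  # {'ok': True}
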
return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: @@ -381,9 +393,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_ "attendee_phone_or_email": attendee_phone_or_email, "need_notification": need_notification, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_spreadsheet( @@ -396,9 +409,10 @@ def create_spreadsheet( "title": title, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_spreadsheet( @@ -411,9 +425,10 @@ def get_spreadsheet( "spreadsheet_token": spreadsheet_token, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_spreadsheet_sheets( @@ -424,9 +439,10 @@ def list_spreadsheet_sheets( params = { "spreadsheet_token": spreadsheet_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_rows( @@ -445,9 +461,10 @@ def add_rows( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_cols( @@ -466,9 +483,10 @@ def add_cols( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_rows( @@ -489,9 +507,10 @@ def read_rows( "num_rows": num_rows, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_cols( @@ -512,9 +531,10 @@ def read_cols( "num_cols": num_cols, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_table( @@ -535,9 +555,10 @@ def read_table( "query": query, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_base( @@ -550,9 +571,10 @@ def create_base( "name": name, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_records( @@ -573,9 +595,10 @@ def add_records( payload = { 
"records": self.convert_add_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_records( @@ -596,9 +619,10 @@ def update_records( payload = { "records": self.convert_update_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_records( @@ -624,9 +648,10 @@ def delete_records( payload = { "records": record_id_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_record( @@ -678,7 +703,7 @@ def search_record( except json.JSONDecodeError: raise ValueError("The input string is not valid JSON") - payload = {} + payload: dict[str, Any] = {} if view_id: payload["view_id"] = view_id @@ -690,9 +715,10 @@ def search_record( payload["filter"] = filter_dict if automatic_fields: payload["automatic_fields"] = automatic_fields - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_base_info( @@ -703,9 +729,10 @@ def get_base_info( params = { "app_token": app_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_table( @@ -732,9 +759,10 @@ def create_table( } if default_view_name: payload["default_view_name"] = default_view_name - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_tables( @@ -767,9 +795,10 @@ def delete_tables( "table_ids": table_id_list, "table_names": table_name_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_tables( @@ -784,9 +813,10 @@ def list_tables( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_records( @@ -814,7 +844,8 @@ def read_records( "record_ids": record_id_list, "user_id_type": user_id_type, } - res = self._send_request(url, method="POST", params=params, payload=payload) + res: dict = self._send_request(url, method="POST", params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index e30c903a4b1146..3509f1e6e59f77 100644 --- 
a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -90,12 +90,12 @@ def transform_tool_invoke_messages( ) elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None - file = message.meta.get("file") - if isinstance(file, File): - if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) - if file.type == FileType.IMAGE: + file_meta = message.meta.get("file") + if isinstance(file_meta, File): + if file_meta.transfer_method == FileTransferMethod.TOOL_FILE: + assert file_meta.related_id is not None + url = cls.get_tool_file_url(tool_file_id=file_meta.related_id, extension=file_meta.extension) + if file_meta.type == FileType.IMAGE: result.append( ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 4e226810d6ac90..3689dcc9e5ebfd 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -5,7 +5,7 @@ """ import json -from typing import cast +from typing import Optional, cast from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult @@ -51,7 +51,7 @@ def get_max_llm_context_tokens( if not schema: raise InvokeModelError("No model schema found") - max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 @@ -133,14 +133,17 @@ def invoke( db.session.commit() try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=[], - stop=[], - stream=False, - user=user_id, - callbacks=[], + response: LLMResult = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], + ), ) except InvokeRateLimitError as e: raise InvokeModelError(f"Invoke rate limit error: {e}") diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index ae44b1b99d447a..f1dc1123b9935f 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -6,7 +6,7 @@ from typing import Optional from requests import get -from yaml import YAMLError, safe_load +from yaml import YAMLError, safe_load # type: ignore from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle @@ -64,6 +64,9 @@ def parse_openapi_to_tool_bundle( default=parameter["schema"]["default"] if "schema" in parameter and "default" in parameter["schema"] else None, + placeholder=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), ) # check if there is a type @@ -108,6 +111,9 @@ def parse_openapi_to_tool_bundle( form=ToolParameter.ToolParameterForm.LLM, llm_description=property.get("description", ""), default=property.get("default", None), + placeholder=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), ) # check if there is a type @@ -158,9 +164,9 @@ def parse_openapi_to_tool_bundle( return bundles @staticmethod - def 
_get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: parameter = parameter or {} - typ = None + typ: Optional[str] = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE @@ -175,6 +181,8 @@ def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType return ToolParameter.ToolParameterType.BOOLEAN elif typ == "string": return ToolParameter.ToolParameterType.STRING + else: + return None @staticmethod def parse_openapi_yaml_to_tool_bundle( @@ -236,7 +244,8 @@ def parse_swagger_to_openapi(swagger: dict, extra_info: Optional[dict], warning: if ("summary" not in operation or len(operation["summary"]) == 0) and ( "description" not in operation or len(operation["description"]) == 0 ): - warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + if warning is not None: + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." openapi["paths"][path][method] = { "operationId": operation["operationId"], diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 3aae31e93a1304..d42fd99fce5e80 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -9,13 +9,13 @@ import unicodedata from contextlib import contextmanager from pathlib import Path -from typing import Optional +from typing import Any, Literal, Optional, cast from urllib.parse import unquote import chardet -import cloudscraper -from bs4 import BeautifulSoup, CData, Comment, NavigableString -from regex import regex +import cloudscraper # type: ignore +from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore +from regex import regex # type: ignore from core.helper import ssrf_proxy from core.rag.extractor import extract_processor @@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: return "Unsupported content-type [{}] of URL.".format(main_content_type) if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: - return ExtractProcessor.load_from_url(url, return_text=True) + return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: @@ -125,7 +125,7 @@ def extract_using_readabilipy(html): os.unlink(article_json_path) os.unlink(html_path) - article_json = { + article_json: dict[str, Any] = { "title": None, "byline": None, "date": None, @@ -300,7 +300,7 @@ def strip_control_characters(text): def normalize_unicode(text): """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" - normal_form = "NFKC" + normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" text = unicodedata.normalize(normal_form, text) return text @@ -332,6 +332,7 @@ def add_content_digest(element): def content_digest(element): + digest: Any if is_text(element): # Hash trimmed_string = element.string.strip() diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index d92bfb9b90a9aa..08a112cfdb2b91 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -7,7 +7,7 @@ class WorkflowToolConfigurationUtils: @classmethod - def check_parameter_configurations(cls, configurations: Mapping[str, Any]): + def check_parameter_configurations(cls, 
configurations: list[Mapping[str, Any]]): for configuration in configurations: WorkflowToolParameterConfiguration.model_validate(configuration) @@ -27,7 +27,7 @@ def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[Vari @classmethod def check_is_synced( cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] - ) -> None: + ) -> bool: """ check is synced diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 42c7f85bc6daeb..ee7ca11e056625 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any -import yaml +import yaml # type: ignore from yaml import YAMLError logger = logging.getLogger(__name__) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 973e420961bb46..c32815b24d02ed 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import cast from uuid import uuid4 from pydantic import Field @@ -78,7 +79,7 @@ class SecretVariable(StringVariable): @property def log(self) -> str: - return encrypter.obfuscated_token(self.value) + return cast(str, encrypter.obfuscated_token(self.value)) class NoneVariable(NoneSegment, Variable): diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index ed737e7316973c..b9c6b35ad3476a 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -33,7 +33,7 @@ class WorkflowLoggingCallback(WorkflowCallback): def __init__(self) -> None: - self.current_node_id = None + self.current_node_id: Optional[str] = None def on_event(self, event: GraphEngineEvent) -> None: if isinstance(event, GraphRunStartedEvent): diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index ca01dcd7d8d4a8..ae5f117bf9b121 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -36,9 +36,9 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[dict[str, Any]] = None # process data + process_data: Optional[Mapping[str, Any]] = None # process data outputs: Optional[Mapping[str, Any]] = None # node outputs - metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index bc3a15bd004ace..b8470aecbd83a2 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -5,7 +5,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): - def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState): """ Check if the condition can be executed diff --git 
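[Editor's note] Two recurring fixes in the hunks above deserve a note: attributes initialized to None need an Optional annotation up front, and read-only model fields are loosened from dict to Mapping. A minimal sketch of both, using illustrative class names rather than the real pydantic models:

    from collections.abc import Mapping
    from typing import Any, Optional

    class LoggingCallback:
        def __init__(self) -> None:
            # a bare `= None` would be inferred as type None; the annotation
            # allows assigning a node id string later
            self.current_node_id: Optional[str] = None

    def read_outputs(outputs: Optional[Mapping[str, Any]]) -> int:
        # Mapping advertises read-only access, so mypy flags accidental writes
        return len(outputs) if outputs is not None else 0

    cb = LoggingCallback()
    cb.current_node_id = "node-1"
    print(read_outputs({"text": "done"}))  # 1
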
a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 800dd136afb57f..b3bcc3b2ccc309 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,4 +1,5 @@ import uuid +from collections import defaultdict from collections.abc import Mapping from typing import Any, Optional, cast @@ -310,26 +311,17 @@ def _recursively_add_parallels( parallel = None if len(target_node_edges) > 1: # fetch all node ids in current parallels - parallel_branch_node_ids = {} - condition_edge_mappings = {} + parallel_branch_node_ids = defaultdict(list) + condition_edge_mappings = defaultdict(list) for graph_edge in target_node_edges: if graph_edge.run_condition is None: - if "default" not in parallel_branch_node_ids: - parallel_branch_node_ids["default"] = [] - parallel_branch_node_ids["default"].append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash - if condition_hash not in condition_edge_mappings: - condition_edge_mappings[condition_hash] = [] - condition_edge_mappings[condition_hash].append(graph_edge) for condition_hash, graph_edges in condition_edge_mappings.items(): if len(graph_edges) > 1: - if condition_hash not in parallel_branch_node_ids: - parallel_branch_node_ids[condition_hash] = [] - for graph_edge in graph_edges: parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) @@ -418,7 +410,7 @@ def _recursively_add_parallels( if condition_edge_mappings: for condition_hash, graph_edges in condition_edge_mappings.items(): for graph_edge in graph_edges: - current_parallel: GraphParallel | None = cls._get_current_parallel( + current_parallel = cls._get_current_parallel( parallel_mapping=parallel_mapping, graph_edge=graph_edge, parallel=condition_parallels.get(condition_hash), diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 854036b2c13212..db1e01f14fda59 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -40,6 +40,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor @@ -66,7 +67,7 @@ def __init__( self.max_submit_count = max_submit_count self.submit_count = 0 - def submit(self, fn, *args, **kwargs): + def submit(self, fn, /, *args, **kwargs): self.submit_count += 1 self.check_is_full() @@ -140,7 +141,8 @@ def __init__( def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event yield GraphRunStartedEvent() - handle_exceptions = [] + handle_exceptions: list[str] = [] + stream_processor: StreamProcessor try: if self.init_params.workflow_type == WorkflowType.CHAT: @@ -168,7 +170,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: self.graph_runtime_state.outputs = ( - item.route_node_state.node_run_result.outputs + dict(item.route_node_state.node_run_result.outputs) if item.route_node_state.node_run_result and item.route_node_state.node_run_result.outputs else {} @@ -350,7 +352,7 @@ def 
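[Editor's note] The graph.py hunk above swaps manual "if key not in mapping" bookkeeping for collections.defaultdict, which both shortens the code and gives mypy a concrete value type. The pattern in isolation:

    from collections import defaultdict

    edges = [("a", 1), ("a", 2), ("b", 3)]
    mapping: defaultdict[str, list[int]] = defaultdict(list)
    for node_id, edge in edges:
        # no membership check needed: a missing key starts as an empty list
        mapping[node_id].append(edge)
    print(dict(mapping))  # {'a': [1, 2], 'b': [3]}
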
_run( if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results - condition_edge_mappings = {} + condition_edge_mappings: dict[str, list[GraphEdge]] = {} for edge in edge_mappings: if edge.run_condition: run_condition_hash = edge.run_condition.hash @@ -364,6 +366,9 @@ def _run( continue edge = cast(GraphEdge, sub_edge_mappings[0]) + if edge.run_condition is None: + logger.warning(f"Edge {edge.target_node_id} run condition is None") + continue result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -387,11 +392,11 @@ def _run( handle_exceptions=handle_exceptions, ) - for item in parallel_generator: - if isinstance(item, str): - final_node_id = item + for parallel_result in parallel_generator: + if isinstance(parallel_result, str): + final_node_id = parallel_result else: - yield item + yield parallel_result break @@ -413,11 +418,11 @@ def _run( handle_exceptions=handle_exceptions, ) - for item in parallel_generator: - if isinstance(item, str): - final_node_id = item + for generated_item in parallel_generator: + if isinstance(generated_item, str): + final_node_id = generated_item else: - yield item + yield generated_item if not final_node_id: break @@ -653,7 +658,7 @@ def _run_node( parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - error=run_result.error, + error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, ) @@ -732,20 +737,20 @@ def _run_node( variable_value=variable_value, ) - # add parallel info to run result metadata - if parallel_id and parallel_start_node_id: - if not run_result.metadata: - run_result.metadata = {} + # When setting metadata, convert to dict first + if not run_result.metadata: + run_result.metadata = {} - run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id - run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = ( - parallel_start_node_id - ) + if parallel_id and parallel_start_node_id: + metadata_dict = dict(run_result.metadata) + metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id if parent_parallel_id and parent_parallel_start_node_id: - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( parent_parallel_start_node_id ) + run_result.metadata = metadata_dict yield NodeRunSucceededEvent( id=node_instance.id, @@ -869,8 +874,8 @@ def _handle_continue_on_error( variable_pool.add([node_instance.node_id, "error_message"], error_result.error) variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) # add error message to handle_exceptions - handle_exceptions.append(error_result.error) - node_error_args = { + handle_exceptions.append(error_result.error or "") + node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": error_result.error, "inputs": error_result.inputs, diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index ed033e7f283961..40213bd151f7af 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py 
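[Editor's note] The metadata change above follows from NodeRunResult.metadata now being a read-only Mapping: the engine copies it to a plain dict, updates the copy, and assigns the copy back. A small sketch of that copy-update-assign step, with an illustrative key:

    from collections.abc import Mapping
    from typing import Any

    def with_parallel_info(metadata: Mapping[str, Any], parallel_id: str) -> Mapping[str, Any]:
        # a Mapping cannot be mutated in place under mypy, so copy to a plain
        # dict, update it, and hand the copy back, as the hunk above does
        metadata_dict = dict(metadata)
        metadata_dict["parallel_id"] = parallel_id
        return metadata_dict

    print(with_parallel_info({"iteration_id": "n1"}, "p1"))
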
+++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -63,7 +63,7 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat self._remove_unreachable_nodes(event) # generate stream outputs - yield from self._generate_stream_outputs_when_node_finished(event) + yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event)) else: yield event @@ -130,7 +130,7 @@ def _generate_stream_outputs_when_node_finished( node_type=event.node_type, node_data=event.node_data, chunk_content=text, - from_variable_selector=value_selector, + from_variable_selector=list(value_selector), route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index d785397e130565..8ffb487ec108f8 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -3,7 +3,7 @@ from collections.abc import Generator from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.graph import Graph logger = logging.getLogger(__name__) @@ -19,7 +19,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: raise NotImplementedError - def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: + def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None: finished_node_id = event.route_node_state.node_id if finished_node_id not in self.rest_node_ids: return @@ -32,8 +32,8 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: return if run_result.edge_source_handle: - reachable_node_ids = [] - unreachable_first_node_ids = [] + reachable_node_ids: list[str] = [] + unreachable_first_node_ids: list[str] = [] if finished_node_id not in self.graph.edge_mapping: logger.warning(f"node {finished_node_id} has no edge mapping") return diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 529fd7be74e9a1..6bf8899f5d698b 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -38,7 +38,8 @@ def _parse_json(value: str) -> Any: @staticmethod def _validate_array(value: Any, element_type: DefaultValueType) -> bool: """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) + # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore @staticmethod def _convert_number(value: str) -> float: @@ -84,7 +85,7 @@ def validate_value_type(self) -> "DefaultValue": }, } - validator = type_validators.get(self.type) + validator: dict[str, Any] = type_validators.get(self.type, {}) if not validator: if self.type == DefaultValueType.ARRAY_FILES: # Handle files type diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 
4e371ca43645a5..2f82bf8c382b55 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -125,7 +125,7 @@ def _transform_result( if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") - transformed_result = {} + transformed_result: dict[str, Any] = {} if output_schema is None: # validate output thought instance type for output_name, output_value in result.items(): diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index e78183baf12389..a4540358883210 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] - children: Optional[dict[str, "Output"]] = None + children: Optional[dict[str, "CodeNodeData.Output"]] = None class Dependency(BaseModel): name: str diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 6d82dbe6d70da3..0b1dc611c59da2 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -4,6 +4,7 @@ import logging import os import tempfile +from typing import cast import docx import pandas as pd @@ -159,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str: """Extract the content from yaml file""" try: yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) + return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) except (UnicodeDecodeError, yaml.YAMLError) as e: raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e @@ -229,9 +230,9 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError("Missing URL for remote file") response = ssrf_proxy.get(file.remote_url) response.raise_for_status() - return response.content + return cast(bytes, response.content) else: - return file_manager.download(file) + return cast(bytes, file_manager.download(file)) except Exception as e: raise FileDownloadError(f"Error downloading file: {str(e)}") from e diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 0db1ba9f09d36e..b3678a82b73959 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -67,7 +67,7 @@ def extract_stream_variable_selector_from_node_data( and node_type == NodeType.LLM.value and variable_selector.value_selector[1] == "text" ): - value_selectors.append(variable_selector.value_selector) + value_selectors.append(list(variable_selector.value_selector)) return value_selectors @@ -119,8 +119,7 @@ def _recursive_fetch_end_dependencies( current_node_id: str, end_node_id: str, node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], - # type: ignore[name-defined] + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] end_dependencies: dict[str, list[str]], ) -> None: """ diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 1aecf863ac5fb9..a770eb951f6c8c 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py 
+++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -23,7 +23,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.route_position[end_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} self.has_output = False - self.output_node_ids = set() + self.output_node_ids: set[str] = set() def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index 137b47655102af..9fea3fbda3141f 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -42,6 +42,6 @@ class RunRetryEvent(BaseModel): class SingleStepRetryEvent(NodeRunResult): """Single step retry event""" - status: str = WorkflowNodeExecutionStatus.RETRY.value + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY elapsed_time: float = Field(..., description="elapsed time") diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 575db15d365efb..cdfdc6e6d51b77 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -107,9 +107,9 @@ def _init_params(self): if not (key := key.strip()): continue - value = value[0].strip() if value else "" + value_str = value[0].strip() if value else "" result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text) + (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) ) self.params = result @@ -182,9 +182,10 @@ def _init_body(self): self.variable_pool.convert_template(item.key).text: item.file for item in filter(lambda item: item.type == "file", data) } + files: dict[str, Any] = {} files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} files = {k: v for k, v in files.items() if v is not None} - files = {k: variable.value for k, variable in files.items()} + files = {k: variable.value for k, variable in files.items() if variable is not None} files = { k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream") for k, v in files.items() @@ -258,7 +259,8 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: response = getattr(ssrf_proxy, self.method)(**request_args) except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: raise HttpRequestNodeError(str(e)) - return response + # FIXME: fix type ignore, this maybe httpx type issue + return response # type: ignore def invoke(self) -> Response: # assemble headers @@ -300,37 +302,37 @@ def to_log(self): continue raw += f"{k}: {v}\r\n" - body = "" + body_string = "" if self.files: for k, v in self.files.items(): - body += f"--{boundary}\r\n" - body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' - body += f"{v[1]}\r\n" - body += f"--{boundary}--\r\n" + body_string += f"--{boundary}\r\n" + body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' + body_string += f"{v[1]}\r\n" + body_string += f"--{boundary}--\r\n" elif self.node_data.body: if self.content: if isinstance(self.content, str): - body = self.content + body_string = self.content elif isinstance(self.content, bytes): - body = self.content.decode("utf-8", errors="replace") + body_string = self.content.decode("utf-8", errors="replace") elif self.data and self.node_data.body.type == 
"x-www-form-urlencoded": - body = urlencode(self.data) + body_string = urlencode(self.data) elif self.data and self.node_data.body.type == "form-data": for key, value in self.data.items(): - body += f"--{boundary}\r\n" - body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - body += f"{value}\r\n" - body += f"--{boundary}--\r\n" + body_string += f"--{boundary}\r\n" + body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + body_string += f"{value}\r\n" + body_string += f"--{boundary}--\r\n" elif self.json: - body = json.dumps(self.json) + body_string = json.dumps(self.json) elif self.node_data.body.type == "raw-text": if len(self.node_data.body.data) != 1: raise RequestBodyError("raw-text body type should have exactly one item") - body = self.node_data.body.data[0].value - if body: - raw += f"Content-Length: {len(body)}\r\n" + body_string = self.node_data.body.data[0].value + if body_string: + raw += f"Content-Length: {len(body_string)}\r\n" raw += "\r\n" # Empty line between headers and body - raw += body + raw += body_string return raw diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index ebed690f6f3ffb..861119f26cb088 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,7 +1,7 @@ import logging import mimetypes from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod @@ -36,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): _node_type = NodeType.HTTP_REQUEST @classmethod - def get_default_config(cls, filters: dict | None = None) -> dict: + def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { "type": "http-request", "config": { @@ -160,8 +160,8 @@ def _extract_variable_selector_to_variable_mapping( ) mapping = {} - for selector in selectors: - mapping[node_id + "." + selector.variable] = selector.value_selector + for selector_iter in selectors: + mapping[node_id + "." 
+ selector_iter.variable] = selector_iter.value_selector return mapping diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 6a89cbfad61684..f1289558fffa82 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -361,13 +361,16 @@ def _handle_event_metadata( metadata = event.route_node_state.node_run_result.metadata if not metadata: metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - if self.node_data.is_parallel: - metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id - else: - metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + metadata = { + **metadata, + NodeRunMetadataKey.ITERATION_ID: self.node_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID + if self.node_data.is_parallel + else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id + if self.node_data.is_parallel + else iter_run_index, + } event.route_node_state.node_run_result.metadata = metadata return event diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 4f9e415f4b83a3..bfd93c074dd6d5 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -147,6 +147,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + if node_data.multiple_retrieval_config is None: + raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": if node_data.multiple_retrieval_config.reranking_model: reranking_model = { @@ -157,6 +159,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: reranking_model = None weights = None elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None vector_setting = node_data.multiple_retrieval_config.weights.vector_setting weights = { @@ -180,7 +184,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: available_datasets=available_datasets, query=query, top_k=node_data.multiple_retrieval_config.top_k, - score_threshold=node_data.multiple_retrieval_config.score_threshold, + score_threshold=node_data.multiple_retrieval_config.score_threshold + if node_data.multiple_retrieval_config.score_threshold is not None + else 0.0, reranking_mode=node_data.multiple_retrieval_config.reranking_mode, reranking_model=reranking_model, weights=weights, @@ -205,7 +211,7 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} + document_score_list: dict[str, float] = {} # deal with dify documents if dify_documents: document_score_list = {} @@ -260,7 +266,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( - retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, 
reverse=True + retrieval_resource_list, + key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + reverse=True, ) position = 1 for item in retrieval_resource_list: @@ -295,6 +303,8 @@ def _fetch_model_config( :param node_data: node data :return: """ + if node_data.single_retrieval_config is None: + raise ValueError("single_retrieval_config is required") model_name = node_data.single_retrieval_config.model.name provider_name = node_data.single_retrieval_config.model.provider diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 79066cece4f93c..432c57294ecbe9 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Literal, Union +from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): _node_type = NodeType.LIST_OPERATOR def _run(self): - inputs = {} - process_data = {} - outputs = {} + inputs: dict[str, list] = {} + process_data: dict[str, list] = {} + outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) if variable is None: @@ -93,6 +93,8 @@ def _run(self): def _apply_filter( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + filter_func: Callable[[Any], bool] + result: list[Any] = [] for condition in self.node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): @@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: + extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) @@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str raise InvalidKeyError(f"Invalid key: {key}") -def _contains(value: str): +def _contains(value: str) -> Callable[[str], bool]: return lambda x: value in x -def _startswith(value: str): +def _startswith(value: str) -> Callable[[str], bool]: return lambda x: x.startswith(value) -def _endswith(value: str): +def _endswith(value: str) -> Callable[[str], bool]: return lambda x: x.endswith(value) -def _is(value: str): +def _is(value: str) -> Callable[[str], bool]: return lambda x: x is value -def _in(value: str | Sequence[str]): +def _in(value: str | Sequence[str]) -> Callable[[str], bool]: return lambda x: x in value -def _eq(value: int | float): +def _eq(value: int | float) -> Callable[[int | float], bool]: return lambda x: x == value -def _ne(value: int | float): +def _ne(value: int | float) -> Callable[[int | float], bool]: return lambda x: x != value -def _lt(value: int | float): +def _lt(value: int | float) -> Callable[[int | float], bool]: return lambda x: x < value -def _le(value: int | float): +def _le(value: int | float) -> Callable[[int | float], bool]: return lambda x: x <= value -def _gt(value: int | float): +def _gt(value: int | float) -> Callable[[int | 
float], bool]: return lambda x: x > value -def _ge(value: int | float): +def _ge(value: int | float) -> Callable[[int | float], bool]: return lambda x: x >= value @@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): + extract_func: Callable[[File], Any] if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}: extract_func = _get_file_extract_string_func(key=order_by) return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 55fac45576c821..6909b30c9e82ca 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]: - node_inputs = None + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: + node_inputs: Optional[dict[str, Any]] = None process_data = None try: @@ -196,7 +196,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] error_type=type(e).__name__, ) ) - return except Exception as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -206,7 +205,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] process_data=process_data, ) ) - return outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -302,7 +300,7 @@ def _transform_chat_messages( return messages def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: - variables = {} + variables: dict[str, Any] = {} if not node_data.prompt_config: return variables @@ -319,7 +317,7 @@ def parse_dict(input_dict: Mapping[str, Any]) -> str: """ # check if it's a context structure if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: - return input_dict["content"] + return str(input_dict["content"]) # else, parse the dict try: @@ -557,7 +555,8 @@ def _fetch_prompt_messages( variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: - prompt_messages = [] + # FIXME: fix the type error cause prompt_messages is type quick a few times + prompt_messages: list[Any] = [] if isinstance(prompt_template, list): # For chat model @@ -783,7 +782,7 @@ def _extract_variable_selector_to_variable_mapping( else: raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - variable_mapping = {} + variable_mapping: dict[str, Any] = {} for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector @@ -981,7 +980,7 @@ def _handle_memory_chat_mode( memory_config: MemoryConfig | None, model_config: ModelConfigWithCredentialsEntity, ) -> Sequence[PromptMessage]: - memory_messages = [] + memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model if memory and memory_config: rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 6fdff966026b63..a366c287c2ac56 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -14,8 
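[Editor's note] The list_operator hunks annotate each predicate factory with an explicit Callable return type, so the declared `filter_func: Callable[[Any], bool]` locals accept them without a cast. Taking _ge from the diff above, usage looks like:

    from collections.abc import Callable

    def _ge(value: int | float) -> Callable[[int | float], bool]:
        # annotating the returned lambda gives call sites a concrete type
        return lambda x: x >= value

    at_least_three = _ge(3)
    print([n for n in (1, 2, 3, 4) if at_least_three(n)])  # [3, 4]
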
+14,8 @@ class LoopNode(BaseNode[LoopNodeData]): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self) -> LoopState: - return super()._run() + def _run(self) -> LoopState: # type: ignore + return super()._run() # type: ignore @classmethod def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: @@ -28,7 +28,7 @@ def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: # TODO waiting for implementation return [ - Condition( + Condition( # type: ignore variable_selector=[node_id, "index"], comparison_operator="≤", value_type="value_selector", diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index a001b44dc7dfee..369eb13b04e8c4 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -25,7 +25,7 @@ def validate_name(cls, value) -> str: raise ValueError("Parameter name is required") if value in {"__reason", "__is_success"}: raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return value + return str(value) class ParameterExtractorNodeData(BaseNodeData): @@ -52,7 +52,7 @@ def get_parameter_json_schema(self) -> dict: :return: parameter json schema """ - parameters = {"type": "object", "properties": {}, "required": []} + parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: parameter_schema: dict[str, Any] = {"description": parameter.description} diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index c8c854a43b3269..9c88047f2c8e57 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode): Parameter Extractor Node. """ - _node_data_cls = ParameterExtractorNodeData + # FIXME: figure out why here is different from super class + _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR _model_instance: Optional[ModelInstance] = None @@ -253,6 +254,9 @@ def _invoke( # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + if text is None: + text = "" + return text, usage, tool_call def _generate_function_call_prompt( @@ -605,9 +609,10 @@ def extract_json(text): json_str = extract_json(result[idx:]) if json_str: try: - return json.loads(json_str) + return cast(dict, json.loads(json_str)) except Exception: pass + return None def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: """ @@ -616,13 +621,13 @@ def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCal if not tool_call or not tool_call.function.arguments: return None - return json.loads(tool_call.function.arguments) + return cast(dict, json.loads(tool_call.function.arguments)) def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: """ Generate default result. 
""" - result = {} + result: dict[str, Any] = {} for parameter in data.parameters: if parameter.type == "number": result[parameter.name] = 0 @@ -772,7 +777,7 @@ def _extract_variable_selector_to_variable_mapping( *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, + node_data: ParameterExtractorNodeData, # type: ignore ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -781,6 +786,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ + # FIXME: fix the type error later variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index e603add1704544..6c3155ac9a54e3 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,3 +1,5 @@ +from typing import Any + FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. @@ -35,7 +37,7 @@ """ # noqa: E501 -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ +FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ { "user": { "query": "What is the weather today in SF?", diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 31f8368d590ea9..0ec44eefacf52f 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -34,12 +34,9 @@ QUESTION_CLASSIFIER_USER_PROMPT_3, ) -if TYPE_CHECKING: - from core.file import File - class QuestionClassifierNode(LLMNode): - _node_data_cls = QuestionClassifierNodeData + _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER def _run(self): @@ -61,7 +58,7 @@ def _run(self): node_data.instruction = node_data.instruction or "" node_data.instruction = variable_pool.convert_template(node_data.instruction).text - files: Sequence[File] = ( + files = ( self._fetch_files( selector=node_data.vision.configs.variable_selector, ) @@ -168,7 +165,7 @@ def _extract_variable_selector_to_variable_mapping( *, graph_config: Mapping[str, Any], node_id: str, - node_data: QuestionClassifierNodeData, + node_data: Any, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -177,6 +174,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ + node_data = cast(QuestionClassifierNodeData, node_data) variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 983fa7e623177a..01d07e494944b4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ 
-9,7 +9,6 @@ from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine -from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -46,6 +45,8 @@ def _run(self) -> NodeRunResult: # get tool runtime try: + from core.tools.tool_manager import ToolManager + tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) @@ -142,7 +143,7 @@ def _generate_parameters( """ tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} - result = {} + result: dict[str, Any] = {} for parameter_name in node_data.tool_parameters: parameter = tool_parameters_dictionary.get(parameter_name) if not parameter: @@ -264,9 +265,9 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> """ return "\n".join( [ - f"{message.message}" + str(message.message) if message.type == ToolInvokeMessage.MessageType.TEXT - else f"Link: {message.message}" + else f"Link: {str(message.message)}" for message in tool_response if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK} ] diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 8eb4bd5c2da573..9acc76f326eec9 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -36,6 +36,8 @@ def _run(self) -> NodeRunResult: case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) + if income_value is None: + raise VariableOperatorNodeError("income value not found") updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index d73c7442029225..0c4aae827c0a0f 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, cast from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID @@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): def _run(self) -> NodeRunResult: inputs = self.node_data.model_dump() - process_data = {} + process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variables: list[Variable] = [] @@ -119,7 +119,7 @@ def _run(self) -> NodeRunResult: else: conversation_id = conversation_id.value common_helpers.update_conversation_variable( - conversation_id=conversation_id, + conversation_id=cast(str, conversation_id), variable=variable, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 811e40c11e5407..b14c6fafbd9fdc 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -129,11 +129,11 @@ def single_step_run( :return: """ # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: + workflow_graph = workflow.graph_dict + if not workflow_graph: raise ValueError("workflow graph not found") - nodes = 
graph.get("nodes") + nodes = workflow_graph.get("nodes") if not nodes: raise ValueError("nodes not found in workflow graph") @@ -196,7 +196,8 @@ def single_step_run( @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: - return WorkflowEntry._handle_special_values(value) + result = WorkflowEntry._handle_special_values(value) + return result if isinstance(result, Mapping) or result is None else dict(result) @staticmethod def _handle_special_values(value: Any) -> Any: @@ -208,10 +209,10 @@ def _handle_special_values(value: Any) -> Any: res[k] = WorkflowEntry._handle_special_values(v) return res if isinstance(value, list): - res = [] + res_list = [] for item in value: - res.append(WorkflowEntry._handle_special_values(item)) - return res + res_list.append(WorkflowEntry._handle_special_values(item)) + return res_list if isinstance(value, File): return value.to_dict() return value diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 24fa013697994c..8a677f6b6fc017 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,7 +14,7 @@ @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get("document_ids") + document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() for document_id in document_ids: diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 1515661b2d45b8..5e7caf8cbed71e 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -8,18 +8,19 @@ def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender account = kwargs.get("account") - site = Site( - app_id=app.id, - title=app.name, - icon_type=app.icon_type, - icon=app.icon, - icon_background=app.icon_background, - default_language=account.interface_language, - customize_token_strategy="not_allow", - code=Site.generate_code(16), - created_by=app.created_by, - updated_by=app.updated_by, - ) + if account is not None: + site = Site( + app_id=app.id, + title=app.name, + icon_type=app.icon_type, + icon=app.icon, + icon_background=app.icon_background, + default_language=account.interface_language, + customize_token_strategy="not_allow", + code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, + ) - db.session.add(site) - db.session.commit() + db.session.add(site) + db.session.commit() diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py index 843a2320968ced..1ed37efba0b3be 100644 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -44,7 +44,7 @@ def handle(sender, **kwargs): else: used_quota = 1 - if used_quota is not None: + if used_quota is not None and system_configuration.current_quota_type is not None: db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == model_config.provider, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 
9c5955c8c5a1a5..f89fae24a56378 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,7 +8,10 @@ @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + synced_draft_workflow = kwargs.get("synced_draft_workflow") + if synced_draft_workflow is None: + return + for node_data in synced_draft_workflow.graph_dict.get("nodes", []): if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index de7c0f4dfeb74f..408ed31096d2a0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -8,16 +8,18 @@ def handle(sender, **kwargs): app = sender app_model_config = kwargs.get("app_model_config") + if app_model_config is None: + return dataset_ids = get_dataset_ids_from_model_config(app_model_config) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[int] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[int] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -37,8 +39,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: - dataset_ids = set() +def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[int]: + dataset_ids: set[int] = set() if not app_model_config: return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 453395e8d7dc1c..7a31c82f6adbc2 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -17,11 +17,11 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_workflow(published_workflow) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[int] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[int] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -41,8 +41,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: - dataset_ids = set() +def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[int]: + dataset_ids: set[int] = set() graph = published_workflow.graph_dict if not graph: return dataset_ids @@ -60,7 +60,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: for node in knowledge_retrieval_nodes: try: node_data = 
KnowledgeRetrievalNodeData(**node.get("data", {})) - dataset_ids.update(node_data.dataset_ids) + dataset_ids.update(int(dataset_id) for dataset_id in node_data.dataset_ids) except Exception as e: continue diff --git a/api/extensions/__init__.py b/api/extensions/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py index de1cdfeb984e86..b7d412d68deda1 100644 --- a/api/extensions/ext_app_metrics.py +++ b/api/extensions/ext_app_metrics.py @@ -54,12 +54,14 @@ def pool_stat(): from extensions.ext_database import db engine = db.engine + # TODO: Fix the type error + # FIXME: maybe it's a SQLAlchemy issue return { "pid": os.getpid(), - "pool_size": engine.pool.size(), - "checked_in_connections": engine.pool.checkedin(), - "checked_out_connections": engine.pool.checkedout(), - "overflow_connections": engine.pool.overflow(), - "connection_timeout": engine.pool.timeout(), - "recycle_time": db.engine.pool._recycle, + "pool_size": engine.pool.size(), # type: ignore + "checked_in_connections": engine.pool.checkedin(), # type: ignore + "checked_out_connections": engine.pool.checkedout(), # type: ignore + "overflow_connections": engine.pool.overflow(), # type: ignore + "connection_timeout": engine.pool.timeout(), # type: ignore + "recycle_time": db.engine.pool._recycle, # type: ignore } diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 9dbc4b93d46266..30f216ff95612b 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,8 +1,8 @@ from datetime import timedelta import pytz -from celery import Celery, Task -from celery.schedules import crontab +from celery import Celery, Task # type: ignore +from celery.schedules import crontab # type: ignore from configs import dify_config from dify_app import DifyApp @@ -47,7 +47,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: worker_log_format=dify_config.LOG_FORMAT, worker_task_log_format=dify_config.LOG_FORMAT, worker_hijack_root_logger=False, - timezone=pytz.timezone(dify_config.LOG_TZ), + timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), ) if dify_config.BROKER_USE_SSL: diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 9c3a663af417ae..26ff6427bef1cc 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -7,7 +7,7 @@ def is_enabled() -> bool: def init_app(app: DifyApp): - from flask_compress import Compress + from flask_compress import Compress # type: ignore compress = Compress() compress.init_app(app) diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 9fc29b4eb17212..e1c459e8c17fd0 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -11,7 +11,7 @@ def init_app(app: DifyApp): - log_handlers = [] + log_handlers: list[logging.Handler] = [] log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) @@ -49,7 +49,8 @@ def time_converter(seconds): return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() for handler in logging.root.handlers: - handler.formatter.converter = time_converter + if handler.formatter: + handler.formatter.converter = time_converter def get_request_id(): diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index b2955307144d67..10fb89eb7370ee 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,6 +1,6 @@ import json -import flask_login +import flask_login # type: ignore
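+# Note: the bare "# type: ignore" on imports above and throughout this patch
+# silences mypy for third-party packages that ship no type stubs (flask_login
+# here). A possible alternative, assuming mypy is configured via pyproject.toml,
+# would be a per-package override that keeps the import lines clean:
+#
+#   [[tool.mypy.overrides]]
+#   module = "flask_login.*"
+#   ignore_missing_imports = true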
from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import Unauthorized diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 468aedd47ea90b..9240ebe7fcba73 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -26,7 +26,7 @@ def init_app(self, app: Flask): match mail_type: case "resend": - import resend + import resend # type: ignore api_key = dify_config.RESEND_API_KEY if not api_key: @@ -48,9 +48,9 @@ def init_app(self, app: Flask): self._client = SMTPClient( server=dify_config.SMTP_SERVER, port=dify_config.SMTP_PORT, - username=dify_config.SMTP_USERNAME, - password=dify_config.SMTP_PASSWORD, - _from=dify_config.MAIL_DEFAULT_SEND_FROM, + username=dify_config.SMTP_USERNAME or "", + password=dify_config.SMTP_PASSWORD or "", + _from=dify_config.MAIL_DEFAULT_SEND_FROM or "", use_tls=dify_config.SMTP_USE_TLS, opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, ) diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index 6d8f35c30d9c65..5f862181fa8540 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -2,7 +2,7 @@ def init_app(app: DifyApp): - import flask_migrate + import flask_migrate # type: ignore from extensions.ext_database import db diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index 3b895ac95b5029..514e0658257293 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -6,4 +6,4 @@ def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix - app.wsgi_app = ProxyFix(app.wsgi_app) + app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 3ec8ae6e1dc14e..3a74aace6a34cf 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -6,7 +6,7 @@ def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import openai import sentry_sdk - from langfuse import parse_error + from langfuse import parse_error # type: ignore from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 42422263c4dd03..588bdb2d2717e0 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable, Generator -from typing import Union +from typing import Literal, Union, overload from flask import Flask @@ -79,6 +79,12 @@ def save(self, filename, data): logger.exception(f"Failed to save file {filename}") raise e + @overload + def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... + + @overload + def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... 
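+    # A sketch of what the two overloads above buy callers (assumptions: mypy or
+    # a similar checker, and `storage` as an instance of this class): the
+    # Literal[...] value of `stream` selects the matching stub, so
+    # `storage.load("a.txt")` is inferred as bytes while
+    # `storage.load("a.txt", stream=True)` is inferred as Generator. Only the
+    # implementation below exists at runtime; `stream` is keyword-only, which is
+    # what lets the Literal overloads discriminate.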
+ def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: try: if stream: diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 58c917dbd386bc..00bf5d4f93ae3b 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,7 +1,7 @@ import posixpath from collections.abc import Generator -import oss2 as aliyun_s3 +import oss2 as aliyun_s3 # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -33,7 +33,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - data = obj.read() + data: bytes = obj.read() return data def load_stream(self, filename: str) -> Generator: @@ -41,14 +41,14 @@ def load_stream(self, filename: str) -> Generator: while chunk := obj.read(4096): yield chunk - def download(self, filename, target_filepath): + def download(self, filename: str, target_filepath): self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) - def exists(self, filename): + def exists(self, filename: str): return self.client.object_exists(self.__wrapper_folder_filename(filename)) - def delete(self, filename): + def delete(self, filename: str): self.client.delete_object(self.__wrapper_folder_filename(filename)) - def __wrapper_folder_filename(self, filename) -> str: + def __wrapper_folder_filename(self, filename: str) -> str: return posixpath.join(self.folder, filename) if self.folder else filename diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index ce36c2e7deeeda..7b6b2eedd62bf2 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,9 +1,9 @@ import logging from collections.abc import Generator -import boto3 -from botocore.client import Config -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.client import Config # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -53,7 +53,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index b26caa8671b6df..2f8532f4f8f653 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -27,7 +27,7 @@ def load_once(self, filename: str) -> bytes: client = self._sync_client() blob = client.get_container_client(container=self.bucket_name) blob = blob.get_blob_client(blob=filename) - data = blob.download_blob().readall() + data: bytes = blob.download_blob().readall() return data def load_stream(self, filename: str) -> Generator: @@ -63,11 +63,11 @@ def _sync_client(self): sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( - account_name=self.account_name, - account_key=self.account_key, + account_name=self.account_name or "", + account_key=self.account_key or "", 
resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) - return BlobServiceClient(account_url=self.account_url, credential=sas_token) + return BlobServiceClient(account_url=self.account_url or "", credential=sas_token) diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index e0d2140e91272c..b94efa08be7613 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -2,9 +2,9 @@ import hashlib from collections.abc import Generator -from baidubce.auth.bce_credentials import BceCredentials -from baidubce.bce_client_configuration import BceClientConfiguration -from baidubce.services.bos.bos_client import BosClient +from baidubce.auth.bce_credentials import BceCredentials # type: ignore +from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore +from baidubce.services.bos.bos_client import BosClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -36,7 +36,8 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: response = self.client.get_object(bucket_name=self.bucket_name, key=filename) - return response.data.read() + data: bytes = response.data.read() + return data def load_stream(self, filename: str) -> Generator: response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 26b662d2f04daf..705639f42e716f 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,7 +3,7 @@ import json from collections.abc import Generator -from google.cloud import storage as google_cloud_storage +from google.cloud import storage as google_cloud_storage # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -35,7 +35,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - data = blob.download_as_bytes() + data: bytes = blob.download_as_bytes() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 20be70ef83dd7a..07f1d199701be4 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from obs import ObsClient +from obs import ObsClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -23,7 +23,7 @@ def save(self, filename, data): self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) def load_once(self, filename: str) -> bytes: - data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 
e671eff059ba21..b78fc94dae7843 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -3,7 +3,7 @@ from collections.abc import Generator from pathlib import Path -import opendal +import opendal # type: ignore[import] from dotenv import dotenv_values from extensions.storage.base_storage import BaseStorage @@ -18,7 +18,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str if key.startswith(config_prefix): kwargs[key[len(config_prefix) :].lower()] = value - file_env_vars = dotenv_values(env_file_path) + file_env_vars: dict = dotenv_values(env_file_path) or {} for key, value in file_env_vars.items(): if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: kwargs[key[len(config_prefix) :].lower()] = value @@ -48,7 +48,7 @@ def load_once(self, filename: str) -> bytes: if not self.exists(filename): raise FileNotFoundError("File not found") - content = self.op.read(path=filename) + content: bytes = self.op.read(path=filename) logger.debug(f"file {filename} loaded") return content @@ -75,7 +75,7 @@ def exists(self, filename: str) -> bool: # error handler here when opendal python-binding has a exists method, we should use it # more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs try: - res = self.op.stat(path=filename).mode.is_file() + res: bool = self.op.stat(path=filename).mode.is_file() logger.debug(f"file {filename} checked") return res except Exception: diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index b59f83b8de90bf..82829f7fd50d65 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,7 +1,7 @@ from collections.abc import Generator -import boto3 -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -27,7 +27,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 9f7c69a9ae6312..711c3f72117c86 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -32,7 +32,7 @@ def save(self, filename, data): self.client.storage.from_(self.bucket_name).upload(filename, data) def load_once(self, filename: str) -> bytes: - content = self.client.storage.from_(self.bucket_name).download(filename) + content: bytes = self.client.storage.from_(self.bucket_name).download(filename) return content def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 13a6c9239c2d1e..9cdd3e67f75aab 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from qcloud_cos import CosConfig, CosS3Client +from qcloud_cos import CosConfig, CosS3Client # type: ignore from configs import dify_config from 
extensions.storage.base_storage import BaseStorage @@ -25,7 +25,7 @@ def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index de82be04ea87b7..55fe6545ec3d2d 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -import tos +import tos # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -24,6 +24,8 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: data = self.client.get_object(bucket=self.bucket_name, key=filename).read() + if not isinstance(data, bytes): + raise TypeError("Expected bytes, got {}".format(type(data).__name__)) return data def load_stream(self, filename: str) -> Generator: diff --git a/api/factories/__init__.py b/api/factories/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 13034f5cf5688b..856cf62e3ed243 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -64,7 +64,7 @@ def build_from_mapping( if not build_func: raise ValueError(f"Invalid file transfer method: {transfer_method}") - file = build_func( + file: File = build_func( mapping=mapping, tenant_id=tenant_id, transfer_method=transfer_method, @@ -72,7 +72,7 @@ def build_from_mapping( if config and not _is_file_valid_with_config( input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension, + file_extension=file.extension or "", file_transfer_method=file.transfer_method, config=config, ): @@ -281,6 +281,7 @@ def _get_file_type_by_extension(extension: str) -> FileType | None: return FileType.AUDIO elif extension in DOCUMENT_EXTENSIONS: return FileType.DOCUMENT + return None def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 16a578728aa16e..bbca8448ec0662 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, cast from uuid import uuid4 from configs import dify_config @@ -84,6 +84,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError("missing value type") if (value := mapping.get("value")) is None: raise VariableError("missing value") + # FIXME: using Any here, fix it later + result: Any match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -109,7 +111,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return result + return cast(Variable, result) def build_segment(value: Any, /) -> Segment: @@ -164,10 +166,13 @@ def segment_to_variable( raise 
UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=selector, + return cast( + Variable, + variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=selector, + ), ) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 379dcc6d16fe56..1c58b3a2579087 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index a85d4a34dbe7b1..d40407bfcc6193 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index abb27fdad17d63..73800eab853cd3 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 6a9e347b1e04b4..c54554a6de8405 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 983e50e73ceb9f..c6385efb5a3cf1 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 071071376fe6c8..608672121e2b50 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 533e3a0837b815..a74e6f54fb3858 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index a83ec7bc97adee..2b2ac6243f4da5 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.dataset_fields import dataset_fields from libs.helper import TimestampField diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 99e529f9d1c076..aefa0b27580ca7 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields 
+from flask_restful import fields # type: ignore simple_end_user_fields = { "id": fields.String, diff --git a/api/fields/external_dataset_fields.py b/api/fields/external_dataset_fields.py index 2281460fe22146..9cc4e14a0575d7 100644 --- a/api/fields/external_dataset_fields.py +++ b/api/fields/external_dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index afaacc0568ea0c..f896c15f0fec70 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index f36e80f8d493d5..aaafcab8ab6ba0 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index e0b3e340f67b8c..16f265b9bb6d07 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 1cf8e408d13d32..0c854c640c3f98 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 5f6e7884a69c5e..0571faab08c134 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.conversation_fields import message_file_fields from libs.helper import TimestampField diff --git a/api/fields/raws.py b/api/fields/raws.py index 15ec16ab13e4a8..493d4b6cce7d31 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.file import File diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2dd4cb45be409b..4413af31607897 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 9af4fc57dd061c..986cd725f70910 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,3 +1,3 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index a53b54624915c2..c45b33597b3978 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from 
fields.member_fields import simple_account_fields diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 0d860d6f406502..bd093d4063bc2e 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.helper import encrypter from core.variables import SecretVariable, SegmentType, Variable diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 74fdf8bd97b23a..ef59c57ec37957 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 179617ac0a6588..922d2d9cd33324 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,8 +1,9 @@ import re import sys +from typing import Any from flask import current_app, got_request_exception -from flask_restful import Api, http_status_message +from flask_restful import Api, http_status_message # type: ignore from werkzeug.datastructures import Headers from werkzeug.exceptions import HTTPException @@ -84,7 +85,7 @@ def handle_error(self, e): # record the exception in the logs when we have a server error of status code: 500 if status_code and status_code >= 500: - exc_info = sys.exc_info() + exc_info: Any = sys.exc_info() if exc_info[1] is None: exc_info = None current_app.log_exception(exc_info) @@ -100,7 +101,7 @@ def handle_error(self, e): resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) elif status_code == 400: if isinstance(data.get("message"), dict): - param_key, param_value = list(data.get("message").items())[0] + param_key, param_value = list(data.get("message", {}).items())[0] data = {"code": "invalid_param", "message": param_value, "params": param_key} else: if "code" not in data: diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 83f9c74e339e17..2dae87e1710bf6 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -23,7 +23,7 @@ import Crypto.Hash.SHA1 import Crypto.Util.number -import gmpy2 +import gmpy2 # type: ignore from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes @@ -191,12 +191,12 @@ def decrypt(self, ciphertext): # Step 3g one_pos = hLen + db[hLen:].find(b"\x01") lHash1 = db[:hLen] - invalid = bord(y) | int(one_pos < hLen) + invalid = bord(y) | int(one_pos < hLen) # type: ignore hash_compare = strxor(lHash1, lHash) for x in hash_compare: - invalid |= bord(x) + invalid |= bord(x) # type: ignore for x in db[hLen:one_pos]: - invalid |= bord(x) + invalid |= bord(x) # type: ignore if invalid != 0: raise ValueError("Incorrect decryption.") # Step 4 diff --git a/api/libs/helper.py b/api/libs/helper.py index 91b1d1fe173d6f..eaa4efdb714355 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -13,7 +13,7 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields +from flask_restful import fields # type: ignore from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator @@ -248,13 +248,13 @@ def get_token_data(cls, token: str, 
token_type: str) -> Optional[dict[str, Any]] if token_data_json is None: logging.warning(f"{token_type} token {token} not found with key {key}") return None - token_data = json.loads(token_data_json) + token_data: Optional[dict[str, Any]] = json.loads(token_data_json) return token_data @classmethod def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: key = cls._get_account_token_key(account_id, token_type) - current_token = redis_client.get(key) + current_token: Optional[str] = redis_client.get(key) return current_token @classmethod diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 267af611f5e8cb..9ab53b6294db93 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -10,6 +10,7 @@ def parse_json_markdown(json_string: str) -> dict: ends = ["```", "``", "`", "}"] end_index = -1 start_index = 0 + parsed: dict = {} for s in starts: start_index = json_string.find(s) if start_index != -1: diff --git a/api/libs/login.py b/api/libs/login.py index 0ea191a185785d..5395534a6df93a 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,8 +1,9 @@ from functools import wraps +from typing import Any from flask import current_app, g, has_request_context, request -from flask_login import user_logged_in -from flask_login.config import EXEMPT_METHODS +from flask_login import user_logged_in # type: ignore +from flask_login.config import EXEMPT_METHODS # type: ignore from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy @@ -12,7 +13,7 @@ #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user -current_user = LocalProxy(lambda: _get_user()) +current_user: Any = LocalProxy(lambda: _get_user()) def login_required(func): @@ -79,12 +80,12 @@ def decorated_view(*args, **kwargs): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif not current_user.is_authenticated: - return current_app.login_manager.unauthorized() + return current_app.login_manager.unauthorized() # type: ignore # flask 1.x compatibility # current_app.ensure_sync is only available in Flask >= 2.0 @@ -98,7 +99,7 @@ def decorated_view(*args, **kwargs): def _get_user(): if has_request_context(): if "_login_user" not in g: - current_app.login_manager._load_user() + current_app.login_manager._load_user() # type: ignore return g._login_user diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 6b6919de24f90f..df75b550195298 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -77,9 +77,9 @@ def get_raw_user_info(self, token: str): email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() - primary_email = next((email for email in email_info if email["primary"] == True), None) + primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) - return {**user_info, "email": primary_email["email"]} + return {**user_info, "email": primary_email.get("email", "")} def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: email = raw_info.get("email") @@ -130,4 +130,4 @@ def get_raw_user_info(self, token: str): 
return response.json() def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) + return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 1d39abd8fa7886..0c872a0066d127 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,8 +1,9 @@ import datetime import urllib.parse +from typing import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from extensions.ext_database import db from models.source import DataSourceOauthBinding @@ -226,7 +227,7 @@ def notion_page_search(self, access_token: str): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "page", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } @@ -281,7 +282,7 @@ def notion_database_search(self, access_token: str): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "database", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } diff --git a/api/libs/threadings_utils.py b/api/libs/threadings_utils.py index d356def418ab1d..e4d63fd3142ce2 100644 --- a/api/libs/threadings_utils.py +++ b/api/libs/threadings_utils.py @@ -9,8 +9,8 @@ def apply_gevent_threading_patch(): :return: """ if not dify_config.DEBUG: - from gevent import monkey - from grpc.experimental import gevent as grpc_gevent + from gevent import monkey # type: ignore + from grpc.experimental import gevent as grpc_gevent # type: ignore # gevent monkey.patch_all() diff --git a/api/models/account.py b/api/models/account.py index a8602d10a97308..88c96da1a149d5 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,7 +1,7 @@ import enum import json -from flask_login import UserMixin +from flask_login import UserMixin # type: ignore from sqlalchemy import func from .engine import db @@ -16,7 +16,7 @@ class AccountStatus(enum.StrEnum): CLOSED = "closed" -class Account(UserMixin, db.Model): +class Account(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) @@ -43,7 +43,8 @@ def is_password_set(self): @property def current_tenant(self): - return self._current_tenant + # FIXME: fix the type error later, because this type is important and mistyping it may cause bugs + return self._current_tenant # type: ignore @current_tenant.setter def current_tenant(self, value: "Tenant"): @@ -52,7 +53,8 @@ def current_tenant(self, value: "Tenant"): if ta: tenant.current_role = ta.role else: - tenant = None + # FIXME: fix the type error later, because this type is important and mistyping it may cause bugs + tenant = None # type: ignore self._current_tenant = tenant @property @@ -89,7 +91,7 @@ def get_status(self) -> AccountStatus: return AccountStatus(status_str) @classmethod - def get_by_openid(cls, provider: str, open_id: str) -> db.Model: + def get_by_openid(cls, provider: str, open_id: str): account_integrate = ( db.session.query(AccountIntegrate) .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) @@ -134,7 +136,7 @@ class TenantAccountRole(enum.StrEnum): @staticmethod def is_valid_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.OWNER, TenantAccountRole.ADMIN, 
TenantAccountRole.EDITOR, @@ -144,15 +146,15 @@ def is_valid_role(role: str) -> bool: @staticmethod def is_privileged_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} @staticmethod def is_admin_role(role: str) -> bool: - return role and role == TenantAccountRole.ADMIN + return role == TenantAccountRole.ADMIN @staticmethod def is_non_owner_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, @@ -161,11 +163,11 @@ def is_non_owner_role(role: str) -> bool: @staticmethod def is_editing_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} @staticmethod def is_dataset_edit_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, @@ -173,7 +175,7 @@ def is_dataset_edit_role(role: str) -> bool: } -class Tenant(db.Model): +class Tenant(db.Model): # type: ignore[name-defined] __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) @@ -209,7 +211,7 @@ class TenantAccountJoinRole(enum.Enum): DATASET_OPERATOR = "dataset_operator" -class TenantAccountJoin(db.Model): +class TenantAccountJoin(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_account_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -228,7 +230,7 @@ class TenantAccountJoin(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class AccountIntegrate(db.Model): +class AccountIntegrate(db.Model): # type: ignore[name-defined] __tablename__ = "account_integrates" __table_args__ = ( db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -245,7 +247,7 @@ class AccountIntegrate(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class InvitationCode(db.Model): +class InvitationCode(db.Model): # type: ignore[name-defined] __tablename__ = "invitation_codes" __table_args__ = ( db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index fbffe7a3b2ee9d..6b6d808710afc0 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -13,7 +13,7 @@ class APIBasedExtensionPoint(enum.Enum): APP_MODERATION_OUTPUT = "app.moderation.output" -class APIBasedExtension(db.Model): +class APIBasedExtension(db.Model): # type: ignore[name-defined] __tablename__ = "api_based_extensions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), diff --git a/api/models/dataset.py b/api/models/dataset.py index 7279e8d5b3394a..b9b41dcf475bb1 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,6 +9,7 @@ import re import time from json import JSONDecodeError +from typing import Any, cast from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB @@ -29,7 +30,7 @@ class DatasetPermissionEnum(enum.StrEnum): PARTIAL_TEAM = "partial_members" -class Dataset(db.Model): +class Dataset(db.Model): # type: ignore[name-defined] __tablename__ = "datasets" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_pkey"), @@ -200,7 
+201,7 @@ def gen_collection_name_by_id(dataset_id: str) -> str: return f"Vector_index_{normalized_dataset_id}_Node" -class DatasetProcessRule(db.Model): +class DatasetProcessRule(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_process_rules" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), @@ -216,7 +217,7 @@ class DatasetProcessRule(db.Model): MODES = ["automatic", "custom"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES = { + AUTOMATIC_RULES: dict[str, Any] = { "pre_processing_rules": [ {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, @@ -242,7 +243,7 @@ def rules_dict(self): return None -class Document(db.Model): +class Document(db.Model): # type: ignore[name-defined] __tablename__ = "documents" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_pkey"), @@ -492,7 +493,7 @@ def from_dict(cls, data: dict): ) -class DocumentSegment(db.Model): +class DocumentSegment(db.Model): # type: ignore[name-defined] __tablename__ = "document_segments" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_segment_pkey"), @@ -604,7 +605,7 @@ def get_sign_content(self): return text -class AppDatasetJoin(db.Model): +class AppDatasetJoin(db.Model): # type: ignore[name-defined] __tablename__ = "app_dataset_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), @@ -621,7 +622,7 @@ def app(self): return db.session.get(App, self.app_id) -class DatasetQuery(db.Model): +class DatasetQuery(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_queries" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), @@ -638,7 +639,7 @@ class DatasetQuery(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class DatasetKeywordTable(db.Model): +class DatasetKeywordTable(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_keyword_tables" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), @@ -683,7 +684,7 @@ def object_hook(self, dct): return None -class Embedding(db.Model): +class Embedding(db.Model): # type: ignore[name-defined] __tablename__ = "embeddings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="embedding_pkey"), @@ -704,10 +705,10 @@ def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) def get_embedding(self) -> list[float]: - return pickle.loads(self.embedding) + return cast(list[float], pickle.loads(self.embedding)) -class DatasetCollectionBinding(db.Model): +class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_collection_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), @@ -722,7 +723,7 @@ class DatasetCollectionBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TidbAuthBinding(db.Model): +class TidbAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "tidb_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), @@ -742,7 +743,7 @@ class TidbAuthBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Whitelist(db.Model): +class Whitelist(db.Model): # type: ignore[name-defined] __tablename__ = "whitelists" 
diff --git a/api/models/model.py b/api/models/model.py
index 1417298c79c0a2..2a593f08298199 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -4,11 +4,11 @@
 from collections.abc import Mapping
 from datetime import datetime
 from enum import Enum, StrEnum
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
 
 import sqlalchemy as sa
 from flask import request
-from flask_login import UserMixin
+from flask_login import UserMixin  # type: ignore
 from sqlalchemy import Float, func, text
 from sqlalchemy.orm import Mapped, mapped_column
 
@@ -28,7 +28,7 @@
     from .workflow import Workflow
 
 
-class DifySetup(db.Model):
+class DifySetup(db.Model):  # type: ignore[name-defined]
     __tablename__ = "dify_setups"
     __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@@ -63,7 +63,7 @@ class IconType(Enum):
     EMOJI = "emoji"
 
 
-class App(db.Model):
+class App(db.Model):  # type: ignore[name-defined]
     __tablename__ = "apps"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
@@ -86,7 +86,7 @@ class App(db.Model):
     is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
     is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
     tracing = db.Column(db.Text, nullable=True)
-    max_active_requests = db.Column(db.Integer, nullable=True)
+    max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True)
     created_by = db.Column(StringUUID, nullable=True)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by = db.Column(StringUUID, nullable=True)
@@ -154,7 +154,7 @@ def mode_compatible_with_agent(self) -> str:
         if self.mode == AppMode.CHAT.value and self.is_agent:
             return AppMode.AGENT_CHAT.value
 
-        return self.mode
+        return str(self.mode)
 
     @property
     def deleted_tools(self) -> list:
@@ -219,7 +219,7 @@ def tags(self):
         return tags or []
 
 
-class AppModelConfig(db.Model):
+class AppModelConfig(db.Model):  # type: ignore[name-defined]
     __tablename__ = "app_model_configs"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
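`max_active_requests` also switches from a plain `db.Column` to the SQLAlchemy 2.0 annotated style, which is how the column's Python-side type (`Optional[int]`) becomes visible to mypy. A minimal sketch with a standalone declarative base (the `DemoApp` model is illustrative, not Dify's `App`):

```python
from typing import Optional

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class DemoApp(Base):
    __tablename__ = "demo_apps"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Mapped[Optional[int]] gives mypy the nullable int type; with a plain
    # untyped Column the attribute would effectively be Any.
    max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True)
```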
@@ -322,7 +322,7 @@ def external_data_tools_list(self) -> list[dict]:
         return json.loads(self.external_data_tools) if self.external_data_tools else []
 
     @property
-    def user_input_form_list(self) -> dict:
+    def user_input_form_list(self) -> list[dict]:
         return json.loads(self.user_input_form) if self.user_input_form else []
 
     @property
@@ -344,7 +344,7 @@ def completion_prompt_config_dict(self) -> dict:
     @property
     def dataset_configs_dict(self) -> dict:
         if self.dataset_configs:
-            dataset_configs = json.loads(self.dataset_configs)
+            dataset_configs: dict = json.loads(self.dataset_configs)
             if "retrieval_model" not in dataset_configs:
                 return {"retrieval_model": "single"}
             else:
@@ -466,7 +466,7 @@ def copy(self):
         return new_app_model_config
 
 
-class RecommendedApp(db.Model):
+class RecommendedApp(db.Model):  # type: ignore[name-defined]
     __tablename__ = "recommended_apps"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
@@ -494,7 +494,7 @@ def app(self):
         return app
 
 
-class InstalledApp(db.Model):
+class InstalledApp(db.Model):  # type: ignore[name-defined]
     __tablename__ = "installed_apps"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
@@ -523,7 +523,7 @@ def tenant(self):
         return tenant
 
 
-class Conversation(db.Model):
+class Conversation(db.Model):  # type: ignore[name-defined]
     __tablename__ = "conversations"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@@ -602,6 +602,8 @@ def inputs(self, value: Mapping[str, Any]):
     @property
     def model_config(self):
         model_config = {}
+        app_model_config: Optional[AppModelConfig] = None
+
         if self.mode == AppMode.ADVANCED_CHAT.value:
             if self.override_model_configs:
                 override_model_configs = json.loads(self.override_model_configs)
@@ -613,6 +615,7 @@ def model_config(self):
                 if "model" in override_model_configs:
                     app_model_config = AppModelConfig()
                     app_model_config = app_model_config.from_model_config_dict(override_model_configs)
+                    assert app_model_config is not None, "app model config not found"
                     model_config = app_model_config.to_dict()
                 else:
                     model_config["configs"] = override_model_configs
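`Conversation.model_config` shows how the patch handles a variable that is only conditionally assigned: declare it as `Optional[...] = None` up front, then `assert` it is not `None` immediately before use so mypy narrows the type on that path. A minimal sketch of the same shape (the config class is a stand-in):

```python
from typing import Any, Optional


class StubModelConfig:
    def to_dict(self) -> dict[str, Any]:
        return {"model": "some-model"}


def resolve_config(override: dict[str, Any]) -> dict[str, Any]:
    app_model_config: Optional[StubModelConfig] = None

    if "model" in override:
        app_model_config = StubModelConfig()
        # the assert both documents the invariant and narrows
        # Optional[StubModelConfig] to StubModelConfig for mypy
        assert app_model_config is not None, "app model config not found"
        return app_model_config.to_dict()

    return {"configs": override}


print(resolve_config({"model": "x"}))
```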
@@ -755,7 +758,7 @@ def in_debug_mode(self):
         return self.override_model_configs is not None
 
 
-class Message(db.Model):
+class Message(db.Model):  # type: ignore[name-defined]
     __tablename__ = "messages"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="message_pkey"),
@@ -995,7 +998,7 @@ def message_files(self):
         if not current_app:
             raise ValueError(f"App {self.app_id} not found")
 
-        files: list[File] = []
+        files = []
         for message_file in message_files:
             if message_file.transfer_method == "local_file":
                 if message_file.upload_file_id is None:
@@ -1102,7 +1105,7 @@ def from_dict(cls, data: dict):
         )
 
 
-class MessageFeedback(db.Model):
+class MessageFeedback(db.Model):  # type: ignore[name-defined]
     __tablename__ = "message_feedbacks"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@@ -1129,7 +1132,7 @@ def from_account(self):
         return account
 
 
-class MessageFile(db.Model):
+class MessageFile(db.Model):  # type: ignore[name-defined]
     __tablename__ = "message_files"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@@ -1170,7 +1173,7 @@ def __init__(
     created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-class MessageAnnotation(db.Model):
+class MessageAnnotation(db.Model):  # type: ignore[name-defined]
     __tablename__ = "message_annotations"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@@ -1201,7 +1204,7 @@ def annotation_create_account(self):
         return account
 
 
-class AppAnnotationHitHistory(db.Model):
+class AppAnnotationHitHistory(db.Model):  # type: ignore[name-defined]
     __tablename__ = "app_annotation_hit_histories"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1239,7 +1242,7 @@ def annotation_create_account(self):
         return account
 
 
-class AppAnnotationSetting(db.Model):
+class AppAnnotationSetting(db.Model):  # type: ignore[name-defined]
     __tablename__ = "app_annotation_settings"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@@ -1287,7 +1290,7 @@ def collection_binding_detail(self):
         return collection_binding_detail
 
 
-class OperationLog(db.Model):
+class OperationLog(db.Model):  # type: ignore[name-defined]
     __tablename__ = "operation_logs"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
@@ -1304,7 +1307,7 @@ class OperationLog(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-class EndUser(UserMixin, db.Model):
+class EndUser(UserMixin, db.Model):  # type: ignore[name-defined]
     __tablename__ = "end_users"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="end_user_pkey"),
@@ -1324,7 +1327,7 @@ class EndUser(UserMixin, db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-class Site(db.Model):
+class Site(db.Model):  # type: ignore[name-defined]
     __tablename__ = "sites"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="site_pkey"),
@@ -1381,7 +1384,7 @@ def app_base_url(self):
         return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
 
 
-class ApiToken(db.Model):
+class ApiToken(db.Model):  # type: ignore[name-defined]
     __tablename__ = "api_tokens"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="api_token_pkey"),
@@ -1408,7 +1411,7 @@ def generate_api_key(prefix, n):
         return result
 
 
-class UploadFile(db.Model):
+class UploadFile(db.Model):  # type: ignore[name-defined]
     __tablename__ = "upload_files"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
@@ -1470,7 +1473,7 @@ def __init__(
         self.source_url = source_url
 
 
-class ApiRequest(db.Model):
+class ApiRequest(db.Model):  # type: ignore[name-defined]
     __tablename__ = "api_requests"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="api_request_pkey"),
@@ -1487,7 +1490,7 @@ class ApiRequest(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-class MessageChain(db.Model):
+class MessageChain(db.Model):  # type: ignore[name-defined]
     __tablename__ = "message_chains"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
@@ -1502,7 +1505,7 @@ class MessageChain(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
 
 
-class MessageAgentThought(db.Model):
+class MessageAgentThought(db.Model):  # type: ignore[name-defined]
     __tablename__ = "message_agent_thoughts"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@@ -1542,7 +1545,7 @@ class MessageAgentThought(db.Model):
     @property
     def files(self) -> list:
         if self.message_files:
-            return json.loads(self.message_files)
+            return cast(list[Any], json.loads(self.message_files))
         else:
             return []
 
@@ -1554,7 +1557,7 @@ def tools(self) -> list[str]:
     def tool_labels(self) -> dict:
         try:
             if self.tool_labels_str:
-                return json.loads(self.tool_labels_str)
+                return cast(dict, json.loads(self.tool_labels_str))
             else:
                 return {}
         except Exception as e:
@@ -1564,7 +1567,7 @@ def tool_meta(self) -> dict:
         try:
             if self.tool_meta_str:
-                return json.loads(self.tool_meta_str)
+                return cast(dict, json.loads(self.tool_meta_str))
             else:
                 return {}
         except Exception as e:
@@ -1612,9 +1615,11 @@ def tool_outputs_dict(self) -> dict:
         except Exception as e:
             if self.observation:
                 return dict.fromkeys(tools, self.observation)
+            else:
+                return {}
 
 
-class DatasetRetrieverResource(db.Model):
+class DatasetRetrieverResource(db.Model):  # type: ignore[name-defined]
     __tablename__ = "dataset_retriever_resources"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
@@ -1641,7 +1646,7 @@ class DatasetRetrieverResource(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
 
 
-class Tag(db.Model):
+class Tag(db.Model):  # type: ignore[name-defined]
     __tablename__ = "tags"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tag_pkey"),
@@ -1659,7 +1664,7 @@ class Tag(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-class TagBinding(db.Model):
+class TagBinding(db.Model):  # type: ignore[name-defined]
     __tablename__ = "tag_bindings"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
@@ -1675,7 +1680,7 @@ class TagBinding(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
 
-class TraceAppConfig(db.Model):
+class TraceAppConfig(db.Model):  # type: ignore[name-defined]
     __tablename__ = "trace_app_config"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
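The `tool_outputs_dict` hunk adds an explicit `return {}` to a branch that could previously fall off the end of the function and implicitly return `None`, which mypy rejects for a function annotated `-> dict`. A minimal sketch of the rule:

```python
import json


def outputs_dict(raw: str, observation: str) -> dict:
    try:
        return dict(json.loads(raw))
    except Exception:
        if observation:
            return {"observation": observation}
        else:
            # without this branch the function implicitly returns None,
            # which mypy reports as a missing return for a -> dict function
            return {}


print(outputs_dict("not-json", ""))
```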
""" @@ -114,7 +114,7 @@ class ProviderModel(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TenantDefaultModel(db.Model): +class TenantDefaultModel(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_default_models" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), @@ -130,7 +130,7 @@ class TenantDefaultModel(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TenantPreferredModelProvider(db.Model): +class TenantPreferredModelProvider(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), @@ -145,7 +145,7 @@ class TenantPreferredModelProvider(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ProviderOrder(db.Model): +class ProviderOrder(db.Model): # type: ignore[name-defined] __tablename__ = "provider_orders" __table_args__ = ( db.PrimaryKeyConstraint("id", name="provider_order_pkey"), @@ -170,7 +170,7 @@ class ProviderOrder(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ProviderModelSetting(db.Model): +class ProviderModelSetting(db.Model): # type: ignore[name-defined] """ Provider model settings for record the model enabled status and load balancing status. """ @@ -192,7 +192,7 @@ class ProviderModelSetting(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class LoadBalancingModelConfig(db.Model): +class LoadBalancingModelConfig(db.Model): # type: ignore[name-defined] """ Configurations for load balancing models. 
""" diff --git a/api/models/source.py b/api/models/source.py index 114db8e1100e5d..881cfaac7d3998 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -7,7 +7,7 @@ from .types import StringUUID -class DataSourceOauthBinding(db.Model): +class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_oauth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="source_binding_pkey"), @@ -25,7 +25,7 @@ class DataSourceOauthBinding(db.Model): disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) -class DataSourceApiKeyAuthBinding(db.Model): +class DataSourceApiKeyAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), diff --git a/api/models/task.py b/api/models/task.py index 27571e24746fe7..0db1c632299fcb 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,11 +1,11 @@ from datetime import UTC, datetime -from celery import states +from celery import states # type: ignore from .engine import db -class CeleryTask(db.Model): +class CeleryTask(db.Model): # type: ignore[name-defined] """Task result/status.""" __tablename__ = "celery_taskmeta" @@ -29,7 +29,7 @@ class CeleryTask(db.Model): queue = db.Column(db.String(155), nullable=True) -class CeleryTaskSet(db.Model): +class CeleryTaskSet(db.Model): # type: ignore[name-defined] """TaskSet result.""" __tablename__ = "celery_tasksetmeta" diff --git a/api/models/tools.py b/api/models/tools.py index e90ab669c66f1e..4151a2e9f636a0 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -14,7 +14,7 @@ from .types import StringUUID -class BuiltinToolProvider(db.Model): +class BuiltinToolProvider(db.Model): # type: ignore[name-defined] """ This table stores the tool provider information for built-in tools for each tenant. """ @@ -41,10 +41,10 @@ class BuiltinToolProvider(db.Model): @property def credentials(self) -> dict: - return json.loads(self.encrypted_credentials) + return dict(json.loads(self.encrypted_credentials)) -class PublishedAppTool(db.Model): +class PublishedAppTool(db.Model): # type: ignore[name-defined] """ The table stores the apps published as a tool for each person. """ @@ -86,7 +86,7 @@ def app(self): return db.session.query(App).filter(App.id == self.app_id).first() -class ApiToolProvider(db.Model): +class ApiToolProvider(db.Model): # type: ignore[name-defined] """ The table stores the api providers. """ @@ -133,7 +133,7 @@ def tools(self) -> list[ApiToolBundle]: @property def credentials(self) -> dict: - return json.loads(self.credentials_str) + return dict(json.loads(self.credentials_str)) @property def user(self) -> Account | None: @@ -144,7 +144,7 @@ def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() -class ToolLabelBinding(db.Model): +class ToolLabelBinding(db.Model): # type: ignore[name-defined] """ The table stores the labels for tools. """ @@ -164,7 +164,7 @@ class ToolLabelBinding(db.Model): label_name = db.Column(db.String(40), nullable=False) -class WorkflowToolProvider(db.Model): +class WorkflowToolProvider(db.Model): # type: ignore[name-defined] """ The table stores the workflow providers. 
""" @@ -218,7 +218,7 @@ def app(self) -> App | None: return db.session.query(App).filter(App.id == self.app_id).first() -class ToolModelInvoke(db.Model): +class ToolModelInvoke(db.Model): # type: ignore[name-defined] """ store the invoke logs from tool invoke """ @@ -255,7 +255,7 @@ class ToolModelInvoke(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ToolConversationVariables(db.Model): +class ToolConversationVariables(db.Model): # type: ignore[name-defined] """ store the conversation variables from tool invoke """ @@ -283,10 +283,10 @@ class ToolConversationVariables(db.Model): @property def variables(self) -> dict: - return json.loads(self.variables_str) + return dict(json.loads(self.variables_str)) -class ToolFile(db.Model): +class ToolFile(db.Model): # type: ignore[name-defined] __tablename__ = "tool_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_file_pkey"), diff --git a/api/models/web.py b/api/models/web.py index 028a768519d99a..864428fe0931b6 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -6,7 +6,7 @@ from .types import StringUUID -class SavedMessage(db.Model): +class SavedMessage(db.Model): # type: ignore[name-defined] __tablename__ = "saved_messages" __table_args__ = ( db.PrimaryKeyConstraint("id", name="saved_message_pkey"), @@ -25,7 +25,7 @@ def message(self): return db.session.query(Message).filter(Message.id == self.message_id).first() -class PinnedConversation(db.Model): +class PinnedConversation(db.Model): # type: ignore[name-defined] __tablename__ = "pinned_conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), diff --git a/api/models/workflow.py b/api/models/workflow.py index d5be949bf44f2a..880e044d073a67 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import sqlalchemy as sa from sqlalchemy import func @@ -20,6 +20,9 @@ from .engine import db from .types import StringUUID +if TYPE_CHECKING: + from models.model import AppMode, Message + class WorkflowType(Enum): """ @@ -56,7 +59,7 @@ def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT -class Workflow(db.Model): +class Workflow(db.Model): # type: ignore[name-defined] """ Workflow, for `Workflow App` and `Chat App workflow mode`. 
@@ -182,7 +185,7 @@ def features(self, value: str) -> None: self._features = value @property - def features_dict(self) -> Mapping[str, Any]: + def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} def user_input_form(self, to_old_structure: bool = False) -> list: @@ -199,7 +202,7 @@ def user_input_form(self, to_old_structure: bool = False) -> list: return [] # get user_input_form from start node - variables = start_node.get("data", {}).get("variables", []) + variables: list[Any] = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] @@ -344,7 +347,7 @@ def value_of(cls, value: str) -> "WorkflowRunStatus": raise ValueError(f"invalid workflow run status value {value}") -class WorkflowRun(db.Model): +class WorkflowRun(db.Model): # type: ignore[name-defined] """ Workflow Run @@ -546,7 +549,7 @@ def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": raise ValueError(f"invalid workflow node execution status value {value}") -class WorkflowNodeExecution(db.Model): +class WorkflowNodeExecution(db.Model): # type: ignore[name-defined] """ Workflow Node Execution @@ -712,7 +715,7 @@ def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": raise ValueError(f"invalid workflow app log created from value {value}") -class WorkflowAppLog(db.Model): +class WorkflowAppLog(db.Model): # type: ignore[name-defined] """ Workflow App execution log, excluding workflow debugging records. @@ -774,7 +777,7 @@ def created_by_end_user(self): return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None -class ConversationVariable(db.Model): +class ConversationVariable(db.Model): # type: ignore[name-defined] __tablename__ = "workflow_conversation_variables" id: Mapped[str] = db.Column(StringUUID, primary_key=True) diff --git a/api/mypy.ini b/api/mypy.ini new file mode 100644 index 00000000000000..2c754f9fcd7c63 --- /dev/null +++ b/api/mypy.ini @@ -0,0 +1,10 @@ +[mypy] +warn_return_any = True +warn_unused_configs = True +check_untyped_defs = True +exclude = (?x)( + core/tools/provider/builtin/ + | core/model_runtime/model_providers/ + | tests/ + | migrations/ + ) \ No newline at end of file diff --git a/api/poetry.lock b/api/poetry.lock index 35fda9b36fa42a..b42eb22dd40b8a 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -5643,6 +5643,58 @@ files = [ {file = "multitasking-0.0.11.tar.gz", hash = "sha256:4d6bc3cc65f9b2dca72fb5a787850a88dae8f620c2b36ae9b55248e51bcd6026"}, ] +[[package]] +name = "mypy" +version = "1.13.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + 
{file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -6537,6 +6589,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-stubs" +version = "2.2.3.241126" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +files = [ + {file = "pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = "sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267"}, + {file = "pandas_stubs-2.2.3.241126.tar.gz", hash = "sha256:cf819383c6d9ae7d4dabf34cd47e1e45525bb2f312e6ad2939c2c204cb708acd"}, +] + +[package.dependencies] +numpy = ">=1.23.5" +types-pytz = ">=2022.1.1" + [[package]] name = "pathos" version = "0.3.3" @@ -9255,13 +9322,13 @@ sqlcipher = ["sqlcipher3_binary"] [[package]] name = "sqlparse" -version = "0.5.2" +version = "0.5.3" description = "A non-validating SQL parser." optional = false python-versions = ">=3.8" files = [ - {file = "sqlparse-0.5.2-py3-none-any.whl", hash = "sha256:e99bc85c78160918c3e1d9230834ab8d80fc06c59d03f8db2618f65f65dda55e"}, - {file = "sqlparse-0.5.2.tar.gz", hash = "sha256:9e37b35e16d1cc652a2545f0997c1deb23ea28fa1f3eefe609eee3063c3b105f"}, + {file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"}, + {file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"}, ] [package.extras] @@ -9847,6 +9914,17 @@ rich = ">=10.11.0" shellingham = ">=1.3.0" typing-extensions = ">=3.7.4.3" +[[package]] +name = "types-pytz" +version = "2024.2.0.20241003" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pytz-2024.2.0.20241003.tar.gz", hash = "sha256:575dc38f385a922a212bac00a7d6d2e16e141132a3c955078f4a4fd13ed6cb44"}, + {file = "types_pytz-2024.2.0.20241003-py3-none-any.whl", hash = "sha256:3e22df1336c0c6ad1d29163c8fda82736909eb977281cb823c57f8bae07118b7"}, +] + [[package]] name = "types-requests" version = "2.32.0.20241016" @@ -10313,82 +10391,82 @@ ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic [[package]] name = "watchfiles" -version = "1.0.0" +version = "1.0.3" description = "Simple, modern and high performance file watching and code reload in python." 
optional = false python-versions = ">=3.9" files = [ - {file = "watchfiles-1.0.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1d19df28f99d6a81730658fbeb3ade8565ff687f95acb59665f11502b441be5f"}, - {file = "watchfiles-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:28babb38cf2da8e170b706c4b84aa7e4528a6fa4f3ee55d7a0866456a1662041"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ab123135b2f42517f04e720526d41448667ae8249e651385afb5cda31fedc0"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:13a4f9ee0cd25682679eea5c14fc629e2eaa79aab74d963bc4e21f43b8ea1877"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e1d9284cc84de7855fcf83472e51d32daf6f6cecd094160192628bc3fee1b78"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ee5edc939f53466b329bbf2e58333a5461e6c7b50c980fa6117439e2c18b42d"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dccfc70480087567720e4e36ec381bba1ed68d7e5f368fe40c93b3b1eba0105"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c83a6d33a9eda0af6a7470240d1af487807adc269704fe76a4972dd982d16236"}, - {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:905f69aad276639eff3893759a07d44ea99560e67a1cf46ff389cd62f88872a2"}, - {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:09551237645d6bff3972592f2aa5424df9290e7a2e15d63c5f47c48cde585935"}, - {file = "watchfiles-1.0.0-cp310-none-win32.whl", hash = "sha256:d2b39aa8edd9e5f56f99a2a2740a251dc58515398e9ed5a4b3e5ff2827060755"}, - {file = "watchfiles-1.0.0-cp310-none-win_amd64.whl", hash = "sha256:2de52b499e1ab037f1a87cb8ebcb04a819bf087b1015a4cf6dcf8af3c2a2613e"}, - {file = "watchfiles-1.0.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fbd0ab7a9943bbddb87cbc2bf2f09317e74c77dc55b1f5657f81d04666c25269"}, - {file = "watchfiles-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:774ef36b16b7198669ce655d4f75b4c3d370e7f1cbdfb997fb10ee98717e2058"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b4fb98100267e6a5ebaff6aaa5d20aea20240584647470be39fe4823012ac96"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0fc3bf0effa2d8075b70badfdd7fb839d7aa9cea650d17886982840d71fdeabf"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:648e2b6db53eca6ef31245805cd528a16f56fa4cc15aeec97795eaf713c11435"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa13d604fcb9417ae5f2e3de676e66aa97427d888e83662ad205bed35a313176"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:936f362e7ff28311b16f0b97ec51e8f2cc451763a3264640c6ed40fb252d1ee4"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:245fab124b9faf58430da547512d91734858df13f2ddd48ecfa5e493455ffccb"}, - {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4ff9c7e84e8b644a8f985c42bcc81457240316f900fc72769aaedec9d088055a"}, - {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:9c9a8d8fd97defe935ef8dd53d562e68942ad65067cd1c54d6ed8a088b1d931d"}, - {file = "watchfiles-1.0.0-cp311-none-win32.whl", hash = "sha256:a0abf173975eb9dd17bb14c191ee79999e650997cc644562f91df06060610e62"}, - {file = "watchfiles-1.0.0-cp311-none-win_amd64.whl", hash = "sha256:2a825ba4b32c214e3855b536eb1a1f7b006511d8e64b8215aac06eb680642d84"}, - {file = "watchfiles-1.0.0-cp311-none-win_arm64.whl", hash = "sha256:a5a7a06cfc65e34fd0a765a7623c5ba14707a0870703888e51d3d67107589817"}, - {file = "watchfiles-1.0.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:28fb64b5843d94e2c2483f7b024a1280662a44409bedee8f2f51439767e2d107"}, - {file = "watchfiles-1.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e3750434c83b61abb3163b49c64b04180b85b4dabb29a294513faec57f2ffdb7"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bedf84835069f51c7b026b3ca04e2e747ea8ed0a77c72006172c72d28c9f69fc"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:90004553be36427c3d06ec75b804233f8f816374165d5225b93abd94ba6e7234"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b46e15c34d4e401e976d6949ad3a74d244600d5c4b88c827a3fdf18691a46359"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:487d15927f1b0bd24e7df921913399bb1ab94424c386bea8b267754d698f8f0e"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ff236d7a3f4b0a42f699a22fc374ba526bc55048a70cbb299661158e1bb5e1f"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c01446626574561756067f00b37e6b09c8622b0fc1e9fdbc7cbcea328d4e514"}, - {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b551c465a59596f3d08170bd7e1c532c7260dd90ed8135778038e13c5d48aa81"}, - {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e1ed613ee107269f66c2df631ec0fc8efddacface85314d392a4131abe299f00"}, - {file = "watchfiles-1.0.0-cp312-none-win32.whl", hash = "sha256:5f75cd42e7e2254117cf37ff0e68c5b3f36c14543756b2da621408349bd9ca7c"}, - {file = "watchfiles-1.0.0-cp312-none-win_amd64.whl", hash = "sha256:cf517701a4a872417f4e02a136e929537743461f9ec6cdb8184d9a04f4843545"}, - {file = "watchfiles-1.0.0-cp312-none-win_arm64.whl", hash = "sha256:8a2127cd68950787ee36753e6d401c8ea368f73beaeb8e54df5516a06d1ecd82"}, - {file = "watchfiles-1.0.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:95de85c254f7fe8cbdf104731f7f87f7f73ae229493bebca3722583160e6b152"}, - {file = "watchfiles-1.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:533a7cbfe700e09780bb31c06189e39c65f06c7f447326fee707fd02f9a6e945"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2218e78e2c6c07b1634a550095ac2a429026b2d5cbcd49a594f893f2bb8c936"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9122b8fdadc5b341315d255ab51d04893f417df4e6c1743b0aac8bf34e96e025"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9272fdbc0e9870dac3b505bce1466d386b4d8d6d2bacf405e603108d50446940"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a3b33c3aefe9067ebd87846806cd5fc0b017ab70d628aaff077ab9abf4d06b3"}, - {file = 
"watchfiles-1.0.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bc338ce9f8846543d428260fa0f9a716626963148edc937d71055d01d81e1525"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ac778a460ea22d63c7e6fb0bc0f5b16780ff0b128f7f06e57aaec63bd339285"}, - {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:53ae447f06f8f29f5ab40140f19abdab822387a7c426a369eb42184b021e97eb"}, - {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1f73c2147a453315d672c1ad907abe6d40324e34a185b51e15624bc793f93cc6"}, - {file = "watchfiles-1.0.0-cp313-none-win32.whl", hash = "sha256:eba98901a2eab909dbd79681190b9049acc650f6111fde1845484a4450761e98"}, - {file = "watchfiles-1.0.0-cp313-none-win_amd64.whl", hash = "sha256:d562a6114ddafb09c33246c6ace7effa71ca4b6a2324a47f4b09b6445ea78941"}, - {file = "watchfiles-1.0.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3d94fd83ed54266d789f287472269c0def9120a2022674990bd24ad989ebd7a0"}, - {file = "watchfiles-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48051d1c504448b2fcda71c5e6e3610ae45de6a0b8f5a43b961f250be4bdf5a8"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29cf884ad4285d23453c702ed03d689f9c0e865e3c85d20846d800d4787de00f"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d3572d4c34c4e9c33d25b3da47d9570d5122f8433b9ac6519dca49c2740d23cd"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c2696611182c85eb0e755b62b456f48debff484b7306b56f05478b843ca8ece"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:550109001920a993a4383b57229c717fa73627d2a4e8fcb7ed33c7f1cddb0c85"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b555a93c15bd2c71081922be746291d776d47521a00703163e5fbe6d2a402399"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:947ccba18a38b85c366dafeac8df2f6176342d5992ca240a9d62588b214d731f"}, - {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ffd98a299b0a74d1b704ef0ed959efb753e656a4e0425c14e46ae4c3cbdd2919"}, - {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f8c4f3a1210ed099a99e6a710df4ff2f8069411059ffe30fa5f9467ebed1256b"}, - {file = "watchfiles-1.0.0-cp39-none-win32.whl", hash = "sha256:1e176b6b4119b3f369b2b4e003d53a226295ee862c0962e3afd5a1c15680b4e3"}, - {file = "watchfiles-1.0.0-cp39-none-win_amd64.whl", hash = "sha256:2d9c0518fabf4a3f373b0a94bb9e4ea7a1df18dec45e26a4d182aa8918dee855"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f159ac795785cde4899e0afa539f4c723fb5dd336ce5605bc909d34edd00b79b"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c3d258d78341d5d54c0c804a5b7faa66cd30ba50b2756a7161db07ce15363b8d"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bbd0311588c2de7f9ea5cf3922ccacfd0ec0c1922870a2be503cc7df1ca8be7"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a13ac46b545a7d0d50f7641eefe47d1597e7d1783a5d89e09d080e6dff44b0"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:b2bca898c1dc073912d3db7fa6926cc08be9575add9e84872de2c99c688bac4e"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:06d828fe2adc4ac8a64b875ca908b892a3603d596d43e18f7948f3fef5fc671c"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:074c7618cd6c807dc4eaa0982b4a9d3f8051cd0b72793511848fd64630174b17"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95dc785bc284552d044e561b8f4fe26d01ab5ca40d35852a6572d542adfeb4bc"}, - {file = "watchfiles-1.0.0.tar.gz", hash = "sha256:37566c844c9ce3b5deb964fe1a23378e575e74b114618d211fbda8f59d7b5dab"}, + {file = "watchfiles-1.0.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1da46bb1eefb5a37a8fb6fd52ad5d14822d67c498d99bda8754222396164ae42"}, + {file = "watchfiles-1.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2b961b86cd3973f5822826017cad7f5a75795168cb645c3a6b30c349094e02e3"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34e87c7b3464d02af87f1059fedda5484e43b153ef519e4085fe1a03dd94801e"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d9dd2b89a16cf7ab9c1170b5863e68de6bf83db51544875b25a5f05a7269e678"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b4691234d31686dca133c920f94e478b548a8e7c750f28dbbc2e4333e0d3da9"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:90b0fe1fcea9bd6e3084b44875e179b4adcc4057a3b81402658d0eb58c98edf8"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b90651b4cf9e158d01faa0833b073e2e37719264bcee3eac49fc3c74e7d304b"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2e9fe695ff151b42ab06501820f40d01310fbd58ba24da8923ace79cf6d702d"}, + {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62691f1c0894b001c7cde1195c03b7801aaa794a837bd6eef24da87d1542838d"}, + {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:275c1b0e942d335fccb6014d79267d1b9fa45b5ac0639c297f1e856f2f532552"}, + {file = "watchfiles-1.0.3-cp310-cp310-win32.whl", hash = "sha256:06ce08549e49ba69ccc36fc5659a3d0ff4e3a07d542b895b8a9013fcab46c2dc"}, + {file = "watchfiles-1.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:f280b02827adc9d87f764972fbeb701cf5611f80b619c20568e1982a277d6146"}, + {file = "watchfiles-1.0.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ffe709b1d0bc2e9921257569675674cafb3a5f8af689ab9f3f2b3f88775b960f"}, + {file = "watchfiles-1.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:418c5ce332f74939ff60691e5293e27c206c8164ce2b8ce0d9abf013003fb7fe"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f492d2907263d6d0d52f897a68647195bc093dafed14508a8d6817973586b6b"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48c9f3bc90c556a854f4cab6a79c16974099ccfa3e3e150673d82d47a4bc92c9"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75d3bcfa90454dba8df12adc86b13b6d85fda97d90e708efc036c2760cc6ba44"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:5691340f259b8f76b45fb31b98e594d46c36d1dc8285efa7975f7f50230c9093"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e263cc718545b7f897baeac1f00299ab6fabe3e18caaacacb0edf6d5f35513c"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c6cf7709ed3e55704cc06f6e835bf43c03bc8e3cb8ff946bf69a2e0a78d9d77"}, + {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:703aa5e50e465be901e0e0f9d5739add15e696d8c26c53bc6fc00eb65d7b9469"}, + {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bfcae6aecd9e0cb425f5145afee871465b98b75862e038d42fe91fd753ddd780"}, + {file = "watchfiles-1.0.3-cp311-cp311-win32.whl", hash = "sha256:6a76494d2c5311584f22416c5a87c1e2cb954ff9b5f0988027bc4ef2a8a67181"}, + {file = "watchfiles-1.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:cf745cbfad6389c0e331786e5fe9ae3f06e9d9c2ce2432378e1267954793975c"}, + {file = "watchfiles-1.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:2dcc3f60c445f8ce14156854a072ceb36b83807ed803d37fdea2a50e898635d6"}, + {file = "watchfiles-1.0.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:93436ed550e429da007fbafb723e0769f25bae178fbb287a94cb4ccdf42d3af3"}, + {file = "watchfiles-1.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c18f3502ad0737813c7dad70e3e1cc966cc147fbaeef47a09463bbffe70b0a00"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a5bc3ca468bb58a2ef50441f953e1f77b9a61bd1b8c347c8223403dc9b4ac9a"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0d1ec043f02ca04bf21b1b32cab155ce90c651aaf5540db8eb8ad7f7e645cba8"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f58d3bfafecf3d81c15d99fc0ecf4319e80ac712c77cf0ce2661c8cf8bf84066"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1df924ba82ae9e77340101c28d56cbaff2c991bd6fe8444a545d24075abb0a87"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:632a52dcaee44792d0965c17bdfe5dc0edad5b86d6a29e53d6ad4bf92dc0ff49"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bf4b459d94a0387617a1b499f314aa04d8a64b7a0747d15d425b8c8b151da0"}, + {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca94c85911601b097d53caeeec30201736ad69a93f30d15672b967558df02885"}, + {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65ab1fb635476f6170b07e8e21db0424de94877e4b76b7feabfe11f9a5fc12b5"}, + {file = "watchfiles-1.0.3-cp312-cp312-win32.whl", hash = "sha256:49bc1bc26abf4f32e132652f4b3bfeec77d8f8f62f57652703ef127e85a3e38d"}, + {file = "watchfiles-1.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:48681c86f2cb08348631fed788a116c89c787fdf1e6381c5febafd782f6c3b44"}, + {file = "watchfiles-1.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:9e080cf917b35b20c889225a13f290f2716748362f6071b859b60b8847a6aa43"}, + {file = "watchfiles-1.0.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e153a690b7255c5ced17895394b4f109d5dcc2a4f35cb809374da50f0e5c456a"}, + {file = "watchfiles-1.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ac1be85fe43b4bf9a251978ce5c3bb30e1ada9784290441f5423a28633a958a7"}, + {file = 
"watchfiles-1.0.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec98e31e1844eac860e70d9247db9d75440fc8f5f679c37d01914568d18721"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0179252846be03fa97d4d5f8233d1c620ef004855f0717712ae1c558f1974a16"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:995c374e86fa82126c03c5b4630c4e312327ecfe27761accb25b5e1d7ab50ec8"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29b9cb35b7f290db1c31fb2fdf8fc6d3730cfa4bca4b49761083307f441cac5a"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f8dc09ae69af50bead60783180f656ad96bd33ffbf6e7a6fce900f6d53b08f1"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:489b80812f52a8d8c7b0d10f0d956db0efed25df2821c7a934f6143f76938bd6"}, + {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:228e2247de583475d4cebf6b9af5dc9918abb99d1ef5ee737155bb39fb33f3c0"}, + {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1550be1a5cb3be08a3fb84636eaafa9b7119b70c71b0bed48726fd1d5aa9b868"}, + {file = "watchfiles-1.0.3-cp313-cp313-win32.whl", hash = "sha256:16db2d7e12f94818cbf16d4c8938e4d8aaecee23826344addfaaa671a1527b07"}, + {file = "watchfiles-1.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:160eff7d1267d7b025e983ca8460e8cc67b328284967cbe29c05f3c3163711a3"}, + {file = "watchfiles-1.0.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c05b021f7b5aa333124f2a64d56e4cb9963b6efdf44e8d819152237bbd93ba15"}, + {file = "watchfiles-1.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:310505ad305e30cb6c5f55945858cdbe0eb297fc57378f29bacceb534ac34199"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ddff3f8b9fa24a60527c137c852d0d9a7da2a02cf2151650029fdc97c852c974"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46e86ed457c3486080a72bc837300dd200e18d08183f12b6ca63475ab64ed651"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f79fe7993e230a12172ce7d7c7db061f046f672f2b946431c81aff8f60b2758b"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea2b51c5f38bad812da2ec0cd7eec09d25f521a8b6b6843cbccedd9a1d8a5c15"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fe4e740ea94978b2b2ab308cbf9270a246bcbb44401f77cc8740348cbaeac3d"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9af037d3df7188ae21dc1c7624501f2f90d81be6550904e07869d8d0e6766655"}, + {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:52bb50a4c4ca2a689fdba84ba8ecc6a4e6210f03b6af93181bb61c4ec3abaf86"}, + {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c14a07bdb475eb696f85c715dbd0f037918ccbb5248290448488a0b4ef201aad"}, + {file = "watchfiles-1.0.3-cp39-cp39-win32.whl", hash = "sha256:be37f9b1f8934cd9e7eccfcb5612af9fb728fecbe16248b082b709a9d1b348bf"}, + {file = "watchfiles-1.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:ef9ec8068cf23458dbf36a08e0c16f0a2df04b42a8827619646637be1769300a"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:84fac88278f42d61c519a6c75fb5296fd56710b05bbdcc74bdf85db409a03780"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c68be72b1666d93b266714f2d4092d78dc53bd11cf91ed5a3c16527587a52e29"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889a37e2acf43c377b5124166bece139b4c731b61492ab22e64d371cce0e6e80"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca05cacf2e5c4a97d02a2878a24020daca21dbb8823b023b978210a75c79098"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:8af4b582d5fc1b8465d1d2483e5e7b880cc1a4e99f6ff65c23d64d070867ac58"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:127de3883bdb29dbd3b21f63126bb8fa6e773b74eaef46521025a9ce390e1073"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:713f67132346bdcb4c12df185c30cf04bdf4bf6ea3acbc3ace0912cab6b7cb8c"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abd85de513eb83f5ec153a802348e7a5baa4588b818043848247e3e8986094e8"}, + {file = "watchfiles-1.0.3.tar.gz", hash = "sha256:f3ff7da165c99a5412fe5dd2304dd2dbaaaa5da718aad942dcb3a178eaa70c56"}, ] [package.dependencies] @@ -11095,4 +11173,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "14476bf95504a4df4b8d5a5c6608c6aa3dae7499d27d1e41ef39d761cc7c693d" +content-hash = "f4accd01805cbf080c4c5295f97a06c8e4faec7365d2c43d0435e56b46461732" diff --git a/api/pyproject.toml b/api/pyproject.toml index da9eabecf55ccf..28e0305406a18b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -60,6 +60,7 @@ oci = "~2.135.1" openai = "~1.52.0" openpyxl = "~3.1.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } +pandas-stubs = "~2.2.3.241009" psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" pydantic = "~2.9.2" @@ -84,6 +85,7 @@ tencentcloud-sdk-python-hunyuan = "~3.0.1158" tiktoken = "~0.8.0" tokenizers = "~0.15.0" transformers = "~4.35.0" +types-pytz = "~2024.2.0.20241003" unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } validators = "0.21.0" volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} @@ -173,6 +175,7 @@ optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" faker = "~32.1.0" +mypy = "~1.13.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 97e5c77e95361a..48bdc872f41e5c 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -32,8 +32,9 @@ def clean_messages(): while True: try: # Main query with join and filter + # FIXME:for mypy no paginate method error messages = ( - db.session.query(Message) + db.session.query(Message) # type: ignore .filter(Message.created_at < plan_sandbox_clean_message_day) .order_by(Message.created_at.desc()) .limit(100) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index e12be649e4d02d..f66b3c47979435 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -52,8 +52,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id 
== document_subquery_new.c.dataset_id)
+                Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
                 .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
                 .filter(
                     Dataset.created_at < plan_sandbox_clean_day,
@@ -120,8 +119,7 @@ def clean_unused_datasets_task():
             # Main query with join and filter
             datasets = (
-                db.session.query(Dataset)
-                .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
+                Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
                 .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
                 .filter(
                     Dataset.created_at < plan_pro_clean_day,
diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py
index a20b500308a4d6..1c985461c6aa2e 100644
--- a/api/schedule/create_tidb_serverless_task.py
+++ b/api/schedule/create_tidb_serverless_task.py
@@ -36,14 +36,15 @@ def create_tidb_serverless_task():
 def create_clusters(batch_size):
     try:
+        # TODO: maybe we can set the default value for the following parameters in the config file
         new_clusters = TidbService.batch_create_tidb_serverless_cluster(
-            batch_size,
-            dify_config.TIDB_PROJECT_ID,
-            dify_config.TIDB_API_URL,
-            dify_config.TIDB_IAM_API_URL,
-            dify_config.TIDB_PUBLIC_KEY,
-            dify_config.TIDB_PRIVATE_KEY,
-            dify_config.TIDB_REGION,
+            batch_size=batch_size,
+            project_id=dify_config.TIDB_PROJECT_ID or "",
+            api_url=dify_config.TIDB_API_URL or "",
+            iam_url=dify_config.TIDB_IAM_API_URL or "",
+            public_key=dify_config.TIDB_PUBLIC_KEY or "",
+            private_key=dify_config.TIDB_PRIVATE_KEY or "",
+            region=dify_config.TIDB_REGION or "",
         )
         for new_cluster in new_clusters:
             tidb_auth_binding = TidbAuthBinding(
diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py
index b2d8746f9ca8f4..11a39e60ee4ce5 100644
--- a/api/schedule/update_tidb_serverless_status_task.py
+++ b/api/schedule/update_tidb_serverless_status_task.py
@@ -36,13 +36,14 @@ def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
         # batch 20
         for i in range(0, len(tidb_serverless_list), 20):
             items = tidb_serverless_list[i : i + 20]
+            # TODO: maybe we can set the default value for the following parameters in the config file
             TidbService.batch_update_tidb_serverless_cluster_status(
-                items,
-                dify_config.TIDB_PROJECT_ID,
-                dify_config.TIDB_API_URL,
-                dify_config.TIDB_IAM_API_URL,
-                dify_config.TIDB_PUBLIC_KEY,
-                dify_config.TIDB_PRIVATE_KEY,
+                tidb_serverless_list=items,
+                project_id=dify_config.TIDB_PROJECT_ID or "",
+                api_url=dify_config.TIDB_API_URL or "",
+                iam_url=dify_config.TIDB_IAM_API_URL or "",
+                public_key=dify_config.TIDB_PUBLIC_KEY or "",
+                private_key=dify_config.TIDB_PRIVATE_KEY or "",
             )
     except Exception as e:
         click.echo(click.style(f"Error: {e}", fg="red"))
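Both TiDB scheduler calls switch to keyword arguments and give each `Optional[str]` setting an `or ""` fallback, since the service signatures expect plain `str`. Note the trade-off: an unset key now reaches the service as an empty string instead of failing fast. A minimal sketch of the call-site shape (names are illustrative):

```python
from typing import Optional

# stand-ins for dify_config values, which are typed Optional[str]
TIDB_PROJECT_ID: Optional[str] = None
TIDB_API_URL: Optional[str] = "https://example.invalid/api"


def batch_create(batch_size: int, project_id: str, api_url: str) -> None:
    print(f"create {batch_size} clusters in {project_id!r} via {api_url}")


# keyword arguments make the long parameter list self-documenting, and
# `or ""` narrows Optional[str] to str for mypy (at the cost of silently
# passing "" when the setting is missing)
batch_create(
    batch_size=10,
    project_id=TIDB_PROJECT_ID or "",
    api_url=TIDB_API_URL or "",
)
```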
get_account_jwt_token(account: Account) -> str: @@ -132,7 +132,7 @@ def get_account_jwt_token(account: Account) -> str: "sub": "Console API Passport", } - token = PassportService().issue(payload) + token: str = PassportService().issue(payload) return token @staticmethod @@ -164,7 +164,7 @@ def authenticate(email: str, password: str, invite_token: Optional[str] = None) db.session.commit() - return account + return cast(Account, account) @staticmethod def update_account_password(account, password, new_password): @@ -347,6 +347,8 @@ def send_reset_password_email( language: Optional[str] = "en-US", ): account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError @@ -377,6 +379,8 @@ def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: def send_email_code_login_email( cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" ): + if email is None: + raise ValueError("Email must be provided.") if cls.email_code_login_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError @@ -669,7 +673,7 @@ def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoi @staticmethod def get_tenant_count() -> int: """Get tenant count""" - return db.session.query(func.count(Tenant.id)).scalar() + return cast(int, db.session.query(func.count(Tenant.id)).scalar()) @staticmethod def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: @@ -733,10 +737,10 @@ def dissolve_tenant(tenant: Tenant, operator: Account) -> None: db.session.commit() @staticmethod - def get_custom_config(tenant_id: str) -> None: - tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404() + def get_custom_config(tenant_id: str) -> dict: + tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() - return tenant.custom_config_dict + return cast(dict, tenant.custom_config_dict) class RegisterService: @@ -807,7 +811,7 @@ def register( account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(UTC).replace(tzinfo=None) - if open_id is not None or provider is not None: + if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) if FeatureService.get_system_features().is_allow_create_workspace: @@ -828,10 +832,11 @@ def register( @classmethod def invite_new_member( - cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() + assert inviter is not None, "Inviter must be provided." 
if not account: TenantService.check_member_permission(tenant, inviter, None, "add") @@ -894,7 +899,9 @@ def revoke_token(cls, workspace_id: str, email: str, token: str): redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]: + def get_invitation_if_token_valid( + cls, workspace_id: Optional[str], email: str, token: str + ) -> Optional[dict[str, Any]]: invitation_data = cls._get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -953,7 +960,7 @@ def _get_invitation_by_token( if not data: return None - invitation = json.loads(data) + invitation: dict = json.loads(data) return invitation diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d2cd7bea67c5b6..6dc1affa11d036 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -48,6 +48,8 @@ def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> return cls.get_chat_prompt( copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt ) + # default return empty dict + return {} @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: @@ -91,3 +93,5 @@ def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) - return cls.get_chat_prompt( copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt ) + # default return empty dict + return {} diff --git a/api/services/agent_service.py b/api/services/agent_service.py index c8819535f11a39..b02f762ad267b8 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,5 +1,7 @@ +from typing import Optional + import pytz -from flask_login import current_user +from flask_login import current_user # type: ignore from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.tools.tool_manager import ToolManager @@ -14,7 +16,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - """ Service to get agent logs """ - conversation: Conversation = ( + conversation: Optional[Conversation] = ( db.session.query(Conversation) .filter( Conversation.id == conversation_id, @@ -26,7 +28,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = ( + message: Optional[Message] = ( db.session.query(Message) .filter( Message.id == message_id, @@ -72,7 +74,10 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) - agent_tools = agent_config.tools + if not agent_config: + return result + + agent_tools = agent_config.tools or [] def find_agent_tool(tool_name: str): for agent_tool in agent_tools: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index f45c21cb18f5e3..a946405c955cec 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,8 +1,9 @@ import datetime import uuid +from typing import cast import pandas as pd -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import or_ from werkzeug.datastructures import FileStorage from 
werkzeug.exceptions import NotFound @@ -71,7 +72,7 @@ def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> Messa app_id, annotation_setting.collection_binding_id, ) - return annotation + return cast(MessageAnnotation, annotation) @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: @@ -124,8 +125,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo raise NotFound("App not found") if keyword: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .filter( or_( MessageAnnotation.question.ilike("%{}%".format(keyword)), @@ -137,8 +137,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo ) else: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) @@ -327,8 +326,7 @@ def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, lim raise NotFound("Annotation not found") annotation_hit_histories = ( - db.session.query(AppAnnotationHitHistory) - .filter( + AppAnnotationHitHistory.query.filter( AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.annotation_id == annotation_id, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 7c1a175988071e..b191fa2397fa9e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -1,7 +1,7 @@ import logging import uuid from enum import StrEnum -from typing import Optional +from typing import Optional, cast from uuid import uuid4 import yaml @@ -103,7 +103,7 @@ def import_app( raise ValueError(f"Invalid import_mode: {import_mode}") # Get YAML content - content = "" + content: bytes | str = b"" if mode == ImportMode.YAML_URL: if not yaml_url: return Import( @@ -136,7 +136,7 @@ def import_app( ) try: - content = content.decode("utf-8") + content = cast(bytes, content).decode("utf-8") except UnicodeDecodeError as e: return Import( id=import_id, @@ -362,6 +362,9 @@ def _create_or_update_app( app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id else: + if account.current_tenant_id is None: + raise ValueError("Current tenant is not set") + # Create new app app = App() app.id = str(uuid4()) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 9def7d15e928d4..51aef7ccab9a0c 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -118,7 +118,7 @@ def generate( @staticmethod def _get_max_active_requests(app_model: App) -> int: max_active_requests = app_model.max_active_requests - if app_model.max_active_requests is None: + if max_active_requests is None: max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) return max_active_requests @@ -150,7 +150,7 @@ def generate_more_like_this( message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[dict, Generator]: + ) -> Union[Mapping, Generator]: """ Generate more like this :param app_model: app model diff --git a/api/services/app_service.py b/api/services/app_service.py index 8d8ba735ecfa71..41c15bbf0a330b 100644 --- a/api/services/app_service.py +++ 
b/api/services/app_service.py @@ -1,9 +1,9 @@ import json import logging from datetime import UTC, datetime -from typing import cast +from typing import Optional, cast -from flask_login import current_user +from flask_login import current_user # type: ignore from flask_sqlalchemy.pagination import Pagination from configs import dify_config @@ -83,7 +83,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -100,6 +100,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for model {model_instance.model}") default_model_dict = { "provider": model_instance.provider, @@ -109,7 +111,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: } else: provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) default_model_config["model"]["provider"] = provider default_model_config["model"]["name"] = model @@ -314,7 +316,7 @@ def get_app_meta(self, app_model: App) -> dict: """ app_mode = AppMode.value_of(app_model.mode) - meta = {"tool_icons": {}} + meta: dict = {"tool_icons": {}} if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow @@ -336,7 +338,7 @@ def get_app_meta(self, app_model: App) -> dict: } ) else: - app_model_config: AppModelConfig = app_model.app_model_config + app_model_config: Optional[AppModelConfig] = app_model.app_model_config if not app_model_config: return meta @@ -352,16 +354,18 @@ def get_app_meta(self, app_model: App) -> dict: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get("provider_type") - provider_id = tool.get("provider_id") - tool_name = tool.get("tool_name") + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + tool_name = tool.get("tool_name", "") if provider_type == "builtin": meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: ApiToolProvider = ( + provider: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() ) + if provider is None: + raise ValueError(f"provider not found for tool {tool_name}") meta["tool_icons"][tool_name] = json.loads(provider.icon) except: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 7a0cd5725b2a96..973110f5156523 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -110,6 +110,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): voices = model_instance.get_tts_voices() if voices: voice = voices[0].get("value") + if not voice: + raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") @@ -121,6 +123,8 @@ def invoke_tts(text_content: 
str, app_model, voice: Optional[str] = None): if message_id: message = db.session.query(Message).filter(Message.id == message_id).first() + if message is None: + return None if message.answer == "" and message.status == "normal": return None @@ -130,6 +134,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): return Response(stream_with_context(response), content_type="audio/mpeg") return response else: + if not text: + raise ValueError("Text is required") response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): return Response(stream_with_context(response), content_type="audio/mpeg") diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index afc491398f25f3..50e4edff140346 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -11,8 +11,8 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) - self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") + self.api_key = credentials.get("config", {}).get("api_key", None) + self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev") if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index de898a1f94b763..6100e9afc8f278 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -11,7 +11,7 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index de898a1f94b763..6100e9afc8f278 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -11,7 +11,7 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index edc51682179cc5..d98018648839a9 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,4 +1,5 @@ import os +from typing import Optional import httpx from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed @@ -58,11 +59,14 @@ def _send_request(cls, method, endpoint, json=None, params=None): def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = ( + join: Optional[TenantAccountJoin] = ( db.session.query(TenantAccountJoin) .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() ) + if not join: + raise ValueError("Tenant account join not found") + if not TenantAccountRole.is_privileged_role(join.role): raise ValueError("Only team owner or team admin can perform this action") diff --git 
a/api/services/conversation_service.py b/api/services/conversation_service.py index 456dc3ebebaa28..6485cbf37d5b7f 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -72,8 +72,7 @@ def pagination_by_last_id( sort_direction=sort_direction, reference_conversation=current_page_last_conversation, ) - count_stmt = stmt.where(rest_filter_condition) - count_stmt = select(func.count()).select_from(count_stmt.subquery()) + count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery()) rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4e99c73ad4787a..d2d8a718d55c8a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import uuid from typing import Any, Optional -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -186,8 +186,9 @@ def create_empty_dataset( return dataset @staticmethod - def get_dataset(dataset_id) -> Dataset: - return Dataset.query.filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> Optional[Dataset]: + dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() + return dataset @staticmethod def check_dataset_model_setting(dataset): @@ -228,6 +229,8 @@ def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, @staticmethod def update_dataset(dataset_id, data, user): dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise ValueError("Dataset not found") DatasetService.check_dataset_permission(dataset, user) if dataset.provider == "external": @@ -371,7 +374,13 @@ def check_dataset_permission(dataset, user): raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod - def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): + def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): + if not dataset: + raise ValueError("Dataset not found") + + if not user: + raise ValueError("User not found") + if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: raise NoPermissionError("You do not have permission to access this dataset.") @@ -765,6 +774,11 @@ def save_document_with_dataset_id( rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) + else: + logging.warning( + f"Invalid process rule mode: {process_rule['mode']}, cannot find dataset process rule" + ) + return db.session.add(dataset_process_rule) db.session.commit() lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) @@ -1009,9 +1023,10 @@ def update_document_with_dataset_id( rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) - db.session.add(dataset_process_rule) - db.session.commit() - document.dataset_process_rule_id = dataset_process_rule.id + if dataset_process_rule is not None: + db.session.add(dataset_process_rule) + db.session.commit() + document.dataset_process_rule_id = dataset_process_rule.id # update document data source if document_data.get("data_source"): file_name = "" @@ -1554,7 +1569,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.word_count = len(content) if document.doc_form == "qa_model": segment.answer = segment_update_entity.answer - segment.word_count +=
len(segment_update_entity.answer) + segment.word_count += len(segment_update_entity.answer or "") word_count_change = segment.word_count - word_count_change if segment_update_entity.keywords: segment.keywords = segment_update_entity.keywords @@ -1569,7 +1584,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document db.session.add(document) # update segment index task if segment_update_entity.enabled: - VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset) + keywords = segment_update_entity.keywords or [] + VectorService.create_segments_vector([keywords], [segment], dataset) else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -1601,7 +1617,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.disabled_by = None if document.doc_form == "qa_model": segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer) + segment.word_count += len(segment_update_entity.answer or "") word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1619,8 +1635,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.status = "error" segment.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() - return segment + new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() + return new_segment @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): @@ -1680,6 +1696,8 @@ def get_dataset_collection_binding_by_id_and_type( .order_by(DatasetCollectionBinding.created_at) .first() ) + if not dataset_collection_binding: + raise ValueError("Dataset collection binding not found") return dataset_collection_binding diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 92098f06cca538..3c3f9704440342 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -8,8 +8,8 @@ class EnterpriseRequest: secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") proxies = { - "http": None, - "https": None, + "http": "", + "https": "", } @classmethod diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index c519f0b0e51b68..334d009ee5f79f 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -4,7 +4,11 @@ from pydantic import BaseModel, ConfigDict from configs import dify_config -from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.entities.model_entities import ( + ModelWithProviderEntity, + ProviderModelWithStatusEntity, + SimpleModelProviderEntity, +) from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType @@ -148,7 +152,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): Model with provider entity. 
""" - provider: SimpleProviderEntityResponse + provider: SimpleModelProviderEntity def __init__(self, model: ModelWithProviderEntity) -> None: super().__init__(**model.model_dump()) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 7be20301a74b78..898624066bef7e 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,7 +1,7 @@ import json from copy import deepcopy from datetime import UTC, datetime -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast import httpx import validators @@ -45,7 +45,10 @@ def validate_api_list(cls, api_settings: dict): @staticmethod def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: - ExternalDatasetService.check_endpoint_and_api_key(args.get("settings")) + settings = args.get("settings") + if settings is None: + raise ValueError("settings is required") + ExternalDatasetService.check_endpoint_and_api_key(settings) external_knowledge_api = ExternalKnowledgeApis( tenant_id=tenant_id, created_by=user_id, @@ -86,11 +89,16 @@ def check_endpoint_and_api_key(settings: dict): @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first() + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + return external_knowledge_api @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( id=external_knowledge_api_id, tenant_id=tenant_id ).first() if external_knowledge_api is None: @@ -127,7 +135,7 @@ def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bo @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( dataset_id=dataset_id, tenant_id=tenant_id ).first() if not external_knowledge_binding: @@ -163,8 +171,9 @@ def process_external_api( "follow_redirects": True, } - response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs) - + response: httpx.Response = getattr(ssrf_proxy, settings.request_method)( + data=json.dumps(settings.params), files=files, **kwargs + ) return response @staticmethod @@ -265,15 +274,15 @@ def fetch_external_knowledge_retrieval( "knowledge_id": external_knowledge_binding.external_knowledge_id, } - external_knowledge_api_setting = { - "url": f"{settings.get('endpoint')}/retrieval", - "request_method": "post", - "headers": headers, - "params": request_params, - } response = ExternalDatasetService.process_external_api( - ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None + ExternalKnowledgeApiSetting( + url=f"{settings.get('endpoint')}/retrieval", + request_method="post", + headers=headers, + params=request_params, + ), + None, ) if response.status_code == 200: - return 
response.json().get("records", []) + return cast(list[Any], response.json().get("records", [])) return [] diff --git a/api/services/file_service.py b/api/services/file_service.py index b12b95ca13558c..d417e81734c8af 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -3,7 +3,7 @@ import uuid from typing import Any, Literal, Union -from flask_login import current_user +from flask_login import current_user # type: ignore from werkzeug.exceptions import NotFound from configs import dify_config @@ -61,14 +61,14 @@ def upload_file( # end_user current_tenant_id = user.tenant_id - file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension + file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension # save file to storage storage.save(file_key, content) # save file to db upload_file = UploadFile( - tenant_id=current_tenant_id, + tenant_id=current_tenant_id or "", storage_type=dify_config.STORAGE_TYPE, key=file_key, name=filename, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 7957b4dc82dfd4..41b4e1ec46374a 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,5 +1,6 @@ import logging import time +from typing import Any from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document @@ -24,7 +25,7 @@ def retrieve( dataset: Dataset, query: str, account: Account, - retrieval_model: dict, + retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, limit: int = 10, ) -> dict: @@ -68,7 +69,7 @@ def retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_retrieve_response(dataset, query, all_documents)) @classmethod def external_retrieve( @@ -102,13 +103,16 @@ def external_retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_external_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): records = [] for document in documents: + if document.metadata is None: + continue + index_node_id = document.metadata["doc_id"] segment = ( @@ -140,7 +144,7 @@ def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list } @classmethod - def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list): + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]: records = [] if dataset.provider == "external": for document in documents: @@ -152,11 +156,10 @@ def compact_external_retrieve_response(cls, dataset: Dataset, query: str, docume } records.append(record) return { - "query": { - "content": query, - }, + "query": {"content": query}, "records": records, } + return {"query": {"content": query}, "records": []} @classmethod def hit_testing_args_check(cls, args): diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19bc42be..8df1a6ba144d4e 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,4 +1,4 @@ -import boto3 +import boto3 # type: ignore from configs import dify_config diff --git a/api/services/message_service.py b/api/services/message_service.py index be2922f4c58e76..c4447a84da5e09 100644 --- 
a/api/services/message_service.py +++ b/api/services/message_service.py @@ -157,7 +157,7 @@ def create_feedback( user: Optional[Union[Account, EndUser]], rating: Optional[str], content: Optional[str], - ) -> MessageFeedback: + ): if not user: raise ValueError("user cannot be None") @@ -264,6 +264,8 @@ def get_suggested_questions_after_answer( ) app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) + if not app_model_config: + raise ValueError("did not find app model config") suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict if suggested_questions_after_answer.get("enabled", False) is False: @@ -285,7 +287,7 @@ def get_suggested_questions_after_answer( ) with measure_time() as timer: - questions = LLMGenerator.generate_suggested_questions_after_answer( + questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer( tenant_id=app_model.tenant_id, histories=histories ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index b20bda87551ca9..bacd3a8ec3d04f 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -2,7 +2,7 @@ import json import logging from json import JSONDecodeError -from typing import Optional +from typing import Optional, Union from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration @@ -88,11 +88,11 @@ def get_load_balancing_configs( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) # Get provider model setting provider_model_setting = provider_configuration.get_provider_model_setting( - model_type=model_type, + model_type=model_type_enum, model=model, ) @@ -106,7 +106,7 @@ def get_load_balancing_configs( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) .order_by(LoadBalancingModelConfig.created_at) @@ -124,7 +124,7 @@ def get_load_balancing_configs( if not inherit_config_exists: # Initialize the inherit configuration - inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type) + inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum) # prepend the inherit configuration load_balancing_configs.insert(0, inherit_config) @@ -148,7 +148,7 @@ def get_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=model, - model_type=model_type, + model_type=model_type_enum, config_id=load_balancing_config.id, ) @@ -214,7 +214,7 @@ def get_load_balancing_config( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) # Get load balancing configurations load_balancing_model_config = ( @@ -222,7 +222,7 @@ def get_load_balancing_config( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == 
model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -300,7 +300,7 @@ def update_load_balancing_configs( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) if not isinstance(configs, list): raise ValueError("Invalid load balancing configs") @@ -310,7 +310,7 @@ def update_load_balancing_configs( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) .all() @@ -359,7 +359,7 @@ def update_load_balancing_configs( credentials = self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, @@ -395,7 +395,7 @@ def update_load_balancing_configs( credentials = self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, validate=False, @@ -405,7 +405,7 @@ def update_load_balancing_configs( load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, provider_name=provider_configuration.provider.provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type_enum.to_origin_model_type(), model_name=model, name=name, encrypted_config=json.dumps(credentials), @@ -450,7 +450,7 @@ def validate_load_balancing_credentials( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) load_balancing_model_config = None if config_id: @@ -460,7 +460,7 @@ def validate_load_balancing_credentials( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -474,7 +474,7 @@ def validate_load_balancing_credentials( self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, @@ -547,19 +547,14 @@ def _custom_credentials_validate( def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> ModelCredentialSchema | ProviderCredentialSchema: - """ - Get form schemas. 
- :param provider_configuration: provider configuration - :return: - """ - # Get credential form schemas from model credential schema or provider credential schema + ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + """Get form schemas.""" if provider_configuration.provider.model_credential_schema: - credential_schema = provider_configuration.provider.model_credential_schema + return provider_configuration.provider.model_credential_schema + elif provider_configuration.provider.provider_credential_schema: + return provider_configuration.provider.provider_credential_schema else: - credential_schema = provider_configuration.provider.provider_credential_schema - - return credential_schema + raise ValueError("No credential schema found") def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: """ diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 384a072b371fdd..b10c5ad2d616e9 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -7,7 +7,7 @@ import requests from flask import current_app -from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity +from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -100,23 +100,15 @@ def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWit ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: + def get_provider_credentials(self, tenant_id: str, provider: str): """ get provider credentials. - - :param tenant_id: - :param provider: - :return: """ - # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Get provider custom credentials from workspace return provider_configuration.get_custom_credentials(obfuscated=True) def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: @@ -176,7 +168,7 @@ def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: # Remove custom provider credentials. provider_configuration.delete_custom_credentials() - def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict: + def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str): """ get model credentials. 
@@ -287,7 +279,7 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider - provider_models = {} + provider_models: dict[str, list[ModelWithProviderEntity]] = {} for model in models: if model.provider.provider not in provider_models: provider_models[model.provider.provider] = [] @@ -362,7 +354,7 @@ def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) - return [] # Call get_parameter_rules method of model instance to get model parameter rules - return model_type_instance.get_parameter_rules(model=model, credentials=credentials) + return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials)) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -422,6 +414,7 @@ def get_model_provider_icon( """ provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() + file_name: str | None = None if icon_type.lower() == "icon_small": if not provider_schema.icon_small: @@ -439,6 +432,8 @@ def get_model_provider_icon( file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US + if not file_name: + return None, None root_path = current_app.root_path provider_instance_path = os.path.dirname( @@ -524,7 +519,7 @@ def disable_model(self, tenant_id: str, provider: str, model: str, model_type: s def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") + api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "") api_url = api_base_url + "/api/v1/providers/apply" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} @@ -545,7 +540,7 @@ def free_quota_submit(self, tenant_id: str, provider: str): def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") + api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "") api_url = api_base_url + "/api/v1/providers/qualification-verify" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index dfb21e767fc9b9..082afeed89a5e4 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.moderation.factory import ModerationFactory, ModerationOutputsResult from extensions.ext_database import db from models.model import App, AppModelConfig @@ -5,7 +7,7 @@ class ModerationService: def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: - app_model_config: AppModelConfig = None + app_model_config: Optional[AppModelConfig] = None app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 1160a1f2751d74..fc1e08518b1945 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database 
import db from models.model import App, TraceAppConfig @@ -12,7 +14,7 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -22,7 +24,10 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): return None # decrypt_token and obfuscated_token - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) @@ -73,8 +78,9 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["other_keys"], ) - default_config_instance = config_class(**tracing_config) - for key in other_keys: + # FIXME: ignore type error + default_config_instance = config_class(**tracing_config) # type: ignore + for key in other_keys: # type: ignore if key in tracing_config and tracing_config[key] == "": tracing_config[key] = getattr(default_config_instance, key, None) @@ -92,7 +98,7 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c project_url = None # check if trace config already exists - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -102,7 +108,10 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c return None # get tenant id - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) if project_url: tracing_config["project_url"] = project_url @@ -139,7 +148,10 @@ def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c return None # get tenant id - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config( tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config ) diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index 4704d533a950ed..523aebeed52a4e 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -41,7 +41,7 @@ def _get_builtin_data(cls) -> dict: Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") ) - return cls.builtin_data + return cls.builtin_data or {} @classmethod def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: @@ -50,8 +50,8 @@ def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: :param language: language :return: """ - builtin_data = 
cls._get_builtin_data() - return builtin_data.get("recommended_apps", {}).get(language) + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + return builtin_data.get("recommended_apps", {}).get(language, {}) @classmethod def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: @@ -60,5 +60,5 @@ def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict :param app_id: App ID :return: """ - builtin_data = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() return builtin_data.get("app_details", {}).get(app_id) diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index b0607a21323acb..80e1aefc01da85 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -47,8 +47,8 @@ def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optiona response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None - - return response.json() + data: dict = response.json() + return data @classmethod def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: @@ -63,7 +63,7 @@ def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") - result = response.json() + result: dict = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 4660316fcfcf71..54c58455155c03 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -33,5 +33,5 @@ def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() - result = retrieval_instance.get_recommend_app_detail(app_id) + result: dict = retrieval_instance.get_recommend_app_detail(app_id) return result diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 9fe3cecce7546d..4cb8700117e6f3 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -13,6 +13,8 @@ class SavedMessageService: def pagination_by_last_id( cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") saved_messages = ( db.session.query(SavedMessage) .filter( @@ -31,6 +33,8 @@ def pagination_by_last_id( @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + if not user: + return saved_message = ( db.session.query(SavedMessage) .filter( @@ -59,6 +63,8 @@ def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_i @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + if not user: + return saved_message = ( db.session.query(SavedMessage) .filter( diff --git a/api/services/tag_service.py b/api/services/tag_service.py index a374bdcf002bef..9600601633cddb 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,7 +1,7 @@ import uuid from typing import Optional -from flask_login import current_user +from flask_login 
import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -21,7 +21,7 @@ def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = Non if keyword: query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) query = query.group_by(Tag.id) - results = query.order_by(Tag.created_at.desc()).all() + results: list = query.order_by(Tag.created_at.desc()).all() return results @staticmethod diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 78a80f70ab6b00..0e3bd3a7b83c68 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -1,6 +1,7 @@ import json import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from httpx import get @@ -28,12 +29,12 @@ class ApiToolManageService: @staticmethod - def parser_api_schema(schema: str) -> list[ApiToolBundle]: + def parser_api_schema(schema: str) -> Mapping[str, Any]: """ parse api schema to tool bundle """ try: - warnings = {} + warnings: dict[str, str] = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: @@ -68,13 +69,16 @@ def parser_api_schema(schema: str) -> list[ApiToolBundle]: ), ] - return jsonable_encoder( - { - "schema_type": schema_type, - "parameters_schema": tool_bundles, - "credentials_schema": credentials_schema, - "warning": warnings, - } + return cast( + Mapping, + jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ), ) except Exception as e: raise ValueError(f"invalid schema: {str(e)}") @@ -129,7 +133,7 @@ def create_api_tool_provider( raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle - extra_info = {} + extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) @@ -262,9 +266,8 @@ def update_api_tool_provider( if provider is None: raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info = {} + extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) @@ -416,7 +419,7 @@ def test_api_tool_preview( provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime( + runtime_tool = tool.fork_tool_runtime( runtime={ "credentials": credentials, "tenant_id": tenant_id, @@ -454,7 +457,7 @@ def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) - for tool in tools: + for tool in tools or []: user_provider.tools.append( ToolTransformService.tool_to_user_tool( tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index fada881fdeb741..21adbb0074724e 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -50,8 +50,8 @@ def list_builtin_tool_provider_tools(user_id: str, 
tenant_id: str, provider: str credentials = builtin_provider.credentials credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) - result = [] - for tool in tools: + result: list[UserTool] = [] + for tool in tools or []: result.append( ToolTransformService.tool_to_user_tool( tool=tool, @@ -217,6 +217,8 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: name_func=lambda x: x.identity.name, ): continue + if provider_controller.identity is None: + continue # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( @@ -229,7 +231,7 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: ToolTransformService.repack_provider(user_builtin_provider) tools = provider_controller.get_tools() - for tool in tools: + for tool in tools or []: user_builtin_provider.tools.append( ToolTransformService.tool_to_user_tool( tenant_id=tenant_id, diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index a4aa870dc80352..b501554bcd091d 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, Union +from typing import Optional, Union, cast from configs import dify_config from core.tools.entities.api_entities import UserTool, UserToolProvider @@ -35,7 +35,7 @@ def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str return url_prefix + "builtin/" + provider_name + "/icon" elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: - return json.loads(icon) + return cast(dict, json.loads(icon)) except: return {"background": "#252525", "content": "\ud83d\ude01"} @@ -53,8 +53,11 @@ def repack_provider(provider: Union[dict, UserToolProvider]): provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): - provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon + provider.icon = cast( + str, + ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon + ), ) @staticmethod @@ -66,6 +69,9 @@ def builtin_provider_to_user_provider( """ convert provider controller to user provider """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + result = UserToolProvider( id=provider_controller.identity.name, author=provider_controller.identity.author, @@ -93,7 +99,8 @@ def builtin_provider_to_user_provider( # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) + assert result.masked_credentials is not None, "masked credentials is None" + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type)) # check if the provider need credentials if not provider_controller.need_credentials: @@ -149,6 +156,9 @@ def workflow_provider_to_user_provider( """ convert provider controller to user provider """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + return UserToolProvider( id=provider_controller.provider_id, author=provider_controller.identity.author, 
@@ -180,6 +190,8 @@ def api_provider_to_user_provider( convert provider controller to user provider """ username = "Anonymous" + if db_provider.user is None: + raise ValueError(f"user is None for api provider {db_provider.id}") try: username = db_provider.user.name except Exception as e: @@ -256,19 +268,25 @@ def tool_to_user_tool( if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) + if tool.identity is None: + raise ValueError("tool identity is None") + return UserTool( author=tool.identity.author, name=tool.identity.name, label=tool.identity.label, - description=tool.description.human, + description=I18nObject( + en_US=tool.description.human if tool.description else "", + zh_Hans=tool.description.human if tool.description else "", + ), parameters=current_parameters, labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, - name=tool.operation_id, - label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + name=tool.operation_id or "", + label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""), description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, labels=labels, diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 318107bebb5eb6..69430de432b143 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -6,8 +6,10 @@ from sqlalchemy import or_ from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.entities.api_entities import UserToolProvider +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from extensions.ext_database import db @@ -32,7 +34,7 @@ def create_workflow_tool( label: str, icon: dict, description: str, - parameters: Mapping[str, Any], + parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: @@ -97,7 +99,7 @@ def update_workflow_tool( label: str, icon: dict, description: str, - parameters: list[dict], + parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: @@ -131,7 +133,7 @@ def update_workflow_tool( if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider = ( + workflow_tool_provider: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -140,14 +142,14 @@ def update_workflow_tool( if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App = ( + app: Optional[App] = ( db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() ) if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") - workflow: Workflow = app.workflow + workflow: Optional[Workflow] = app.workflow if workflow is None: raise ValueError(f"Workflow not 
found for app {workflow_tool_provider.app_id}") @@ -193,7 +195,7 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo # skip deleted tools pass - labels = ToolLabelManager.get_tools_labels(tools) + labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) result = [] @@ -202,10 +204,11 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + continue user_tool_provider.tools = [ - ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) - ) + ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, [])) ] result.append(user_tool_provider) @@ -236,7 +239,7 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -245,13 +248,19 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too if db_tool is None: raise ValueError(f"Tool {workflow_tool_id} not found") - workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + workflow_app: Optional[App] = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") + return { "name": db_tool.name, "label": db_tool.label, @@ -261,9 +270,9 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) ), - "synced": workflow_app.workflow.version == db_tool.version, + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, "privacy_policy": db_tool.privacy_policy, } @@ -276,7 +285,7 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .first() @@ -285,12 +294,17 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ if db_tool is None: raise ValueError(f"Tool {workflow_app_id} not found") - workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + workflow_app: Optional[App] = ( + 
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_app_id} not found") return { "name": db_tool.name, @@ -301,14 +315,14 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) ), - "synced": workflow_app.workflow.version == db_tool.version, + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, "privacy_policy": db_tool.privacy_policy, } @classmethod - def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: + def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]: """ List workflow tool provider tools. :param user_id: the user id @@ -316,7 +330,7 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -326,9 +340,8 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") - return [ - ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) - ) - ] + return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))] diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 508fe20970a703..f698ed3084bdac 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -26,6 +26,8 @@ def pagination_by_last_id( pinned: Optional[bool] = None, sort_by="-updated_at", ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") include_ids = None exclude_ids = None if pinned is not None and user: @@ -59,6 +61,8 @@ def pagination_by_last_id( @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return pinned_conversation = ( db.session.query(PinnedConversation) .filter( @@ -89,6 +93,8 @@ def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return pinned_conversation = ( db.session.query(PinnedConversation) .filter( diff --git a/api/services/website_service.py b/api/services/website_service.py 
index 230f5d78152f39..1ad7d0399d6edf 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,8 +1,9 @@ import datetime import json +from typing import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from core.helper import encrypter from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp @@ -23,9 +24,9 @@ def document_create_args_validate(cls, args: dict): @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get("provider") + provider = args.get("provider", "") url = args.get("url") - options = args.get("options") + options = args.get("options", "") credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) if provider == "firecrawl": # decrypt api_key @@ -164,16 +165,18 @@ def get_crawl_status(cls, job_id: str, provider: str) -> dict: return crawl_status_data @classmethod - def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: + def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) # decrypt api_key api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + # FIXME: data is redefined too many times here; use Any to ease type checking and fix it later + data: Any if provider == "firecrawl": file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): - data = storage.load_once(file_key) - if data: - data = json.loads(data.decode("utf-8")) + d = storage.load_once(file_key) + if d: + data = json.loads(d.decode("utf-8")) else: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) @@ -183,22 +186,17 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str if data: for item in data: if item.get("source_url") == url: - return item + return dict(item) return None elif provider == "jinareader": - file_key = "website_files/" + job_id + ".txt" - if storage.exists(file_key): - data = storage.load_once(file_key) - if data: - data = json.loads(data.decode("utf-8")) - elif not job_id: + if not job_id: response = requests.get( f"https://r.jina.ai/{url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) if response.json().get("code") != 200: raise ValueError("Failed to crawl") - return response.json().get("data") + return dict(response.json().get("data", {})) else: api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) response = requests.post( @@ -218,12 +216,13 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str data = response.json().get("data", {}) for item in data.get("processed", {}).values(): if item.get("data", {}).get("url") == url: - return item.get("data", {}) + return dict(item.get("data", {})) + return None else: raise ValueError("Invalid provider") @classmethod - def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) if provider == "firecrawl": # decrypt api_key diff --git
a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 90b5cc48362f3b..2b0d57bdfdeda3 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional from core.app.app_config.entities import ( DatasetEntity, @@ -101,7 +101,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph = {"nodes": [], "edges": []} + graph: dict[str, Any] = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -118,7 +118,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: graph["nodes"].append(start_node) # convert to http request node - external_data_variable_node_mapping = {} + external_data_variable_node_mapping: dict[str, str] = {} if app_config.external_data_variables: http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( app_model=app_model, @@ -199,15 +199,16 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: return workflow def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: - app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: + app_mode_enum = AppMode.value_of(app_model.mode) + app_config: EasyUIBasedAppConfig + if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: app_model.mode = AppMode.AGENT_CHAT.value app_config = AgentChatAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) - elif app_mode == AppMode.CHAT: + elif app_mode_enum == AppMode.CHAT: app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) - elif app_mode == AppMode.COMPLETION: + elif app_mode_enum == AppMode.COMPLETION: app_config = CompletionAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) @@ -302,7 +303,7 @@ def _convert_to_http_request_node( nodes.append(http_request_node) # append code node for response body parsing - code_node = { + code_node: dict[str, Any] = { "id": f"code_{index}", "position": None, "data": { @@ -401,6 +402,7 @@ def _convert_to_llm_node( ) role_prefix = None + prompts: Any = None # Chat Model if model_config.mode == LLMMode.CHAT.value: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index d8ee323908a844..4343596a236f5f 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom @@ -92,7 +94,7 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: + def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 84768d5af053e4..ea8192edde35cc 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,7 @@ 
import time from collections.abc import Sequence from datetime import UTC, datetime -from typing import Optional, cast +from typing import Any, Optional, cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -242,7 +242,7 @@ def run_draft_workflow_node( raise ValueError("Node run failed with no run result") # single step debug mode error handling return if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: - node_error_args = { + node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, "inputs": node_run_result.inputs, @@ -338,7 +338,7 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow - new_app = workflow_converter.convert_to_workflow( + new_app: App = workflow_converter.convert_to_workflow( app_model=app_model, account=account, name=args.get("name", "Default Name"), diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 8fcb12b1cb9664..7637b31454e556 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,4 @@ -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from extensions.ext_database import db @@ -29,6 +29,7 @@ def get_tenant_info(cls, tenant: Tenant): .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) .first() ) + assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo diff --git a/api/tasks/__init__.py b/api/tasks/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 09be6612160471..50bb2b6e634fba 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 25c55bcfafe11c..aab21a44109975 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fa7e5ac9190f3c..06162b02d60f8b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git 
a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index f0f6b32b06c78c..a6a598ce4b6bca 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from models.dataset import Dataset diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index a2f49135139b08..26bf1c7c9fa32e 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 0bdcd0eccd7f72..b42af0c7faf67e 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index b685d84d07ad28..8c675feaa6e06f 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index dcb7009e44b938..26ae9f8736d79a 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -4,7 +4,7 @@ import uuid import click -from celery import shared_task +from celery import shared_task # type: ignore from sqlalchemy import func from core.indexing_runner import IndexingRunner @@ -58,12 +58,13 @@ def batch_create_segment_to_index_task( model=dataset.embedding_model, ) word_count_change = 0 + segments_to_insert: list[str] = [] # Explicitly type hint the list as List[str] for segment in content: - content = segment["content"] + content_str = segment["content"] doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) + segment_hash = helper.generate_text_hash(content_str) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0 + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0 max_position = ( db.session.query(func.max(DocumentSegment.position)) .filter(DocumentSegment.document_id == dataset_document.id) @@ -90,6 +91,7 @@ def batch_create_segment_to_index_task( word_count_change += segment_document.word_count db.session.add(segment_document) document_segments.append(segment_document) + segments_to_insert.append(str(segment)) # Cast to string if needed # update document word count 
dataset_document.word_count += word_count_change db.session.add(dataset_document) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index a555fb28746697..d9278c03793877 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -71,6 +71,8 @@ def clean_dataset_task( image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: storage.delete(image_file.key) except Exception: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 4d328643bfa165..3e80dd13771802 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,7 +3,7 @@ from typing import Optional import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -44,6 +44,8 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: storage.delete(image_file.key) except Exception: diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 75d9e031306381..f5d6406d9cc04f 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 315b01f157bf13..dfa053a43cbc61 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -4,7 +4,7 @@ from typing import Optional import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index cfc54920e23caa..b025509aebe674 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index c3e0ea5d9fbb77..45a612c74550cd 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import 
shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 15e1e50076e8c9..f30a1cc7acfd6c 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 18316913932874..ac4e81f95d127e 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 734dd2478a9847..21b571b6cb5bd4 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 1a52a6636b1d17..5f1e9a892f54e3 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index f4c3dbd2e2860c..6db2620eb6eef0 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -26,6 +26,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") # check document limit features = FeatureService.get_features(dataset.tenant_id) diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 12639db9392677..2f6eb7b82a0633 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index d78fc2b8915520..5dc935548f90b8 100644 --- a/api/tasks/mail_email_code_login.py +++ 
b/api/tasks/mail_email_code_login.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index c7dfb9bf6063ff..3094527fd40945 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from configs import dify_config diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 8596ca07cfcee3..d5be94431b6221 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 34c62dc9237fc0..bb3b9e17ead6d2 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,7 +1,7 @@ import json import logging -from celery import shared_task +from celery import shared_task # type: ignore from flask import current_app from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 934eb7430c90c3..b603d689ba9d8e 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 66f78636ecca60..c3910e2be3a499 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -3,7 +3,7 @@ from collections.abc import Callable import click -from celery import shared_task +from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 1909eaf3418517..4ba6d1a83e32ae 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 73471fd6e77c9b..485caa5152ea78 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -22,10 +22,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): Usage: retry_document_indexing_task.delay(dataset_id, 
document_id) """ - documents = [] + documents: list[Document] = [] start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("Dataset not found") + for document_id in document_ids: retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit @@ -55,29 +58,31 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): document = ( db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() ) + if not document: + logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + return try: - if document: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids) - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() - db.session.add(document) + for segment in segments: + db.session.delete(segment) db.session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(retry_indexing_cache_key) + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.utcnow() + db.session.add(document) + db.session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(retry_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 1d2a338c831764..5d6b069cf44919 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -25,6 +25,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") sync_indexing_cache_key = "document_{}_is_sync".format(document_id) # check document limit @@ -52,29 +54,31 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + if not document: + logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + return try: - if document: - # clean old data - 
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids) - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() - db.session.add(document) + for segment in segments: + db.session.delete(segment) db.session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(sync_indexing_cache_key) + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.utcnow() + db.session.add(document) + db.session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(sync_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) diff --git a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py index 64f2884c4b828c..57fba317638de8 100644 --- a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py +++ b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py @@ -1,6 +1,6 @@ from typing import Any -import toml +import toml # type: ignore def load_api_poetry_configs() -> dict[str, Any]: @@ -38,7 +38,7 @@ def test_group_dependencies_version_operator(): ) -def test_duplicated_dependency_crossing_groups(): +def test_duplicated_dependency_crossing_groups() -> None: all_dependency_names: list[str] = [] for dependencies in load_all_dependency_groups().values(): dependency_names = list(dependencies.keys()) diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py index 6371694694653e..5e3ee6bedc7ebb 100644 --- a/api/tests/integration_tests/controllers/test_controllers.py +++ b/api/tests/integration_tests/controllers/test_controllers.py @@ -1,6 +1,6 @@ from unittest.mock import patch -from app_fixture import app, mock_user +from app_fixture import mock_user # type: ignore def test_post_requires_login(app): diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index 5ea86baa83dd4b..b90f8b444477d5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,7 +1,7 @@ from collections.abc import Generator from unittest.mock import MagicMock -import google.generativeai.types.generation_types as generation_config_types +import google.generativeai.types.generation_types as generation_config_types # type: ignore import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm @@ -45,7 +45,7 @@ def generate_content_sync() -> GenerateContentResponse: return GenerateContentResponse(done=True, iterator=None, 
result=glm.GenerateContentResponse({}), chunks=[]) @staticmethod - def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: + def generate_content_stream() -> MockGoogleResponseClass: return MockGoogleResponseClass() def generate_content( diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index 97038ef5963e87..4de52514408a06 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -2,7 +2,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from huggingface_hub import InferenceClient +from huggingface_hub import InferenceClient # type: ignore from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 9ee76c935c9873..77c7e7f5e4089c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -3,15 +3,15 @@ from typing import Any, Literal, Optional, Union from _pytest.monkeypatch import MonkeyPatch -from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import ( +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.inference._text_generation import ( # type: ignore Details, StreamDetails, TextGenerationResponse, TextGenerationStreamResponse, Token, ) -from huggingface_hub.utils import BadRequestError +from huggingface_hub.utils import BadRequestError # type: ignore class MockHuggingfaceChatClass: diff --git a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py index 6a25398cbf069a..4e00660a29162f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py @@ -6,7 +6,7 @@ # import monkeypatch from _pytest.monkeypatch import MonkeyPatch -from nomic import embed +from nomic import embed # type: ignore def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict: diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 794f4b0585632e..e2abaa52b939a6 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -6,14 +6,14 @@ from _pytest.monkeypatch import MonkeyPatch from requests import Response from requests.sessions import Session -from xinference_client.client.restful.restful_client import ( +from xinference_client.client.restful.restful_client import ( # type: ignore Client, RESTfulChatModelHandle, RESTfulEmbeddingModelHandle, RESTfulGenerateModelHandle, RESTfulRerankModelHandle, ) -from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage +from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage # type: ignore class MockXinferenceClass: diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py index 2dcfb92c63fee2..d37fcf897fc3a8 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py +++ 
b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py @@ -1,6 +1,6 @@ import os -import dashscope +import dashscope # type: ignore import pytest from core.model_runtime.entities.rerank_entities import RerankResult diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index 83f4d70ce9ac2f..2860739f0e30b3 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -1,5 +1,5 @@ from flask import Flask, request -from flask_restful import Api, Resource +from flask_restful import Api, Resource # type: ignore app = Flask(__name__) api = Api(app) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 0ea61369c0304e..4af35a8befcaf8 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -4,11 +4,11 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from pymochow import MochowClient -from pymochow.model.database import Database -from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState -from pymochow.model.schema import HNSWParams, VectorIndex -from pymochow.model.table import Table +from pymochow import MochowClient # type: ignore +from pymochow.model.database import Database # type: ignore +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore +from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore +from pymochow.model.table import Table # type: ignore from requests.adapters import HTTPAdapter diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 61d6ed16560c09..68a1e290adc120 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -4,12 +4,12 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from requests.adapters import HTTPAdapter -from tcvectordb import VectorDBClient -from tcvectordb.model.database import Collection, Database -from tcvectordb.model.document import Document, Filter -from tcvectordb.model.enum import ReadConsistency -from tcvectordb.model.index import Index -from xinference_client.types import Embedding +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model.database import Collection, Database # type: ignore +from tcvectordb.model.document import Document, Filter # type: ignore +from tcvectordb.model.enum import ReadConsistency # type: ignore +from tcvectordb.model.index import Index # type: ignore +from xinference_client.types import Embedding # type: ignore class MockTcvectordbClass: diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index 0f40337feba6ee..3ad72e55501f58 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -4,7 +4,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Collection, Data, DistanceType, diff --git a/api/tests/unit_tests/oss/__mock/aliyun_oss.py b/api/tests/unit_tests/oss/__mock/aliyun_oss.py index 27e1c0ad85029b..4f6d8a2f54a4fd 100644 --- a/api/tests/unit_tests/oss/__mock/aliyun_oss.py 
+++ b/api/tests/unit_tests/oss/__mock/aliyun_oss.py @@ -4,8 +4,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from oss2 import Bucket -from oss2.models import GetObjectResult, PutObjectResult +from oss2 import Bucket # type: ignore +from oss2.models import GetObjectResult, PutObjectResult # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py index 5189b68e87132a..c77c5b08f37d15 100644 --- a/api/tests/unit_tests/oss/__mock/tencent_cos.py +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -3,8 +3,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from qcloud_cos import CosS3Client -from qcloud_cos.streambody import StreamBody +from qcloud_cos import CosS3Client # type: ignore +from qcloud_cos.streambody import StreamBody # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py index 649d93a20261d3..88df59f91c3071 100644 --- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -4,8 +4,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from tos import TosClientV2 -from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput +from tos import TosClientV2 # type: ignore +from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py index 65d31352bd3437..380134bc46d02e 100644 --- a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py +++ b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from oss2 import Auth +from oss2 import Auth # type: ignore from extensions.storage.aliyun_oss_storage import AliyunOssStorage from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index 303f0493bda42f..d289751800633a 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from qcloud_cos import CosConfig +from qcloud_cos import CosConfig # type: ignore from extensions.storage.tencent_cos_storage import TencentCosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 5afbc9e8b4cb18..04988e85d85881 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,5 +1,5 @@ import pytest -from tos import TosClientV2 +from tos import TosClientV2 # type: ignore from extensions.storage.volcengine_tos_storage import VolcengineTosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 95b93651d57f80..8d645487278a5f 100644 
--- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -1,7 +1,7 @@ from textwrap import dedent import pytest -from yaml import YAMLError +from yaml import YAMLError # type: ignore from core.tools.utils.yaml_utils import load_yaml_file diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index e6644883018769..ee1b5c57e1d1d0 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -160,7 +160,10 @@ def get_result(self, workflow_run_id): class KnowledgeBaseClient(DifyClient): def __init__( - self, api_key, base_url: str = "https://api.dify.ai/v1", dataset_id: str = None + self, + api_key, + base_url: str = "https://api.dify.ai/v1", + dataset_id: str | None = None, ): """ Construct a KnowledgeBaseClient object. @@ -187,7 +190,9 @@ def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): "GET", f"/datasets?page={page}&limit={page_size}", **kwargs ) - def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs): + def create_document_by_text( + self, name, text, extra_params: dict | None = None, **kwargs + ): """ Create a document by text. @@ -225,7 +230,7 @@ def create_document_by_text(self, name, text, extra_params: dict = None, **kwarg return self._send_request("POST", url, json=data, **kwargs) def update_document_by_text( - self, document_id, name, text, extra_params: dict = None, **kwargs + self, document_id, name, text, extra_params: dict | None = None, **kwargs ): """ Update a document by text. @@ -262,7 +267,7 @@ def update_document_by_text( return self._send_request("POST", url, json=data, **kwargs) def create_document_by_file( - self, file_path, original_document_id=None, extra_params: dict = None + self, file_path, original_document_id=None, extra_params: dict | None = None ): """ Create a document by file. @@ -304,7 +309,7 @@ def create_document_by_file( ) def update_document_by_file( - self, document_id, file_path, extra_params: dict = None + self, document_id, file_path, extra_params: dict | None = None ): """ Update a document by file. @@ -372,7 +377,11 @@ def delete_document(self, document_id): return self._send_request("DELETE", url) def list_documents( - self, page: int = None, page_size: int = None, keyword: str = None, **kwargs + self, + page: int | None = None, + page_size: int | None = None, + keyword: str | None = None, + **kwargs, ): """ Get a list of documents in this dataset. @@ -402,7 +411,11 @@ def add_segments(self, document_id, segments, **kwargs): return self._send_request("POST", url, json=data, **kwargs) def query_segments( - self, document_id, keyword: str = None, status: str = None, **kwargs + self, + document_id, + keyword: str | None = None, + status: str | None = None, + **kwargs, ): """ Query segments in this document. 
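The typing fixes in the patch above reduce to three recurring mypy patterns: annotate empty containers at the point of creation, narrow Optional query results with an explicit guard before use, and cast() values that arrive as Any (e.g. from json.loads or jsonable_encoder). A minimal, self-contained sketch of the same patterns follows; every name in it is illustrative and not taken from the Dify codebase:

import json
from typing import Any, Optional, cast


def first_match(rows: list[dict[str, Any]], row_id: str) -> Optional[dict[str, Any]]:
    # Stand-in for a db.session.query(...).first() call, which mypy types
    # as Optional because no matching row may exist.
    return next((row for row in rows if row.get("id") == row_id), None)


def tool_names(rows: list[dict[str, Any]]) -> list[str]:
    names: list[str] = []  # empty containers get an explicit annotation
    for row in rows:
        # json.loads returns Any; cast() records the shape expected downstream.
        payload = cast(dict[str, Any], json.loads(row["payload"]))
        names.append(str(payload.get("name", "")))
    return names


def require_tool(rows: list[dict[str, Any]], row_id: str) -> dict[str, Any]:
    tool = first_match(rows, row_id)
    if tool is None:  # narrow Optional instead of assuming the row exists
        raise ValueError(f"tool {row_id} not found")
    return tool


if __name__ == "__main__":
    data = [{"id": "t1", "payload": json.dumps({"name": "demo"})}]
    print(tool_names(data), require_tool(data, "t1")["id"])

The patch itself uses both narrowing styles seen here: explicit ValueError guards in service code, and assert only where a missing value would signal a programming error; either form keeps the None check visible to mypy.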
From cdaef30cc9e0ae5d944b95603842f41fe257b9d0 Mon Sep 17 00:00:00 2001 From: TinsFox Date: Tue, 24 Dec 2024 19:13:24 +0800 Subject: [PATCH 03/65] refactor: replace div with button for better accessibility (#12046) --- web/app/(commonLayout)/apps/NewAppCard.tsx | 23 +++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/web/app/(commonLayout)/apps/NewAppCard.tsx b/web/app/(commonLayout)/apps/NewAppCard.tsx index d353cf239431ff..a90af4ea85caf4 100644 --- a/web/app/(commonLayout)/apps/NewAppCard.tsx +++ b/web/app/(commonLayout)/apps/NewAppCard.tsx @@ -18,7 +18,6 @@ export type CreateAppCardProps = { onSuccess?: () => void } -// eslint-disable-next-line react/display-name const CreateAppCard = forwardRef(({ className, onSuccess }, ref) => { const { t } = useTranslation() const { onPlanInfoChanged } = useProviderContext() @@ -44,24 +43,22 @@ const CreateAppCard = forwardRef(({ classNam >
{t('app.createApp')}
-
setShowNewAppModal(true)}> +
-
setShowNewAppTemplateDialog(true)}> + +
-
-
setShowCreateFromDSLModal(true)} - > -
+ +
+
+ setShowNewAppModal(false)} @@ -108,4 +105,6 @@ const CreateAppCard = forwardRef(({ classNam ) }) +CreateAppCard.displayName = 'CreateAppCard' export default CreateAppCard +export { CreateAppCard } From 49bc602fb237183de94cae25761033bd768e0109 Mon Sep 17 00:00:00 2001 From: eux Date: Tue, 24 Dec 2024 21:58:05 +0800 Subject: [PATCH 04/65] fix: --name option for the create-tenant command does not take effect (#11993) --- api/commands.py | 9 +++++++-- api/services/account_service.py | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/api/commands.py b/api/commands.py index ad7ad972f3fd01..59dfce68e0c92f 100644 --- a/api/commands.py +++ b/api/commands.py @@ -561,8 +561,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str new_password = secrets.token_urlsafe(16) # register account - account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) - + account = RegisterService.register( + email=email, + name=account_name, + password=new_password, + language=language, + create_workspace_required=False, + ) TenantService.create_owner_tenant_if_not_exist(account, name) click.echo( diff --git a/api/services/account_service.py b/api/services/account_service.py index 91075ec46b16bf..2d37db391c899c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -797,6 +797,7 @@ def register( language: Optional[str] = None, status: Optional[AccountStatus] = None, is_setup: Optional[bool] = False, + create_workspace_required: Optional[bool] = True, ) -> Account: db.session.begin_nested() """Register account""" @@ -814,7 +815,7 @@ def register( if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if FeatureService.get_system_features().is_allow_create_workspace: + if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant From 0ea6a926c5e7d213b06c3825b29cbd4563f47ced Mon Sep 17 00:00:00 2001 From: yihong Date: Tue, 24 Dec 2024 23:14:32 +0800 Subject: [PATCH 05/65] fix: tool can not run (#12054) Signed-off-by: yihong0618 --- api/models/tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/models/tools.py b/api/models/tools.py index 4151a2e9f636a0..13a112ee83b513 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional import sqlalchemy as sa from sqlalchemy import ForeignKey, func @@ -282,8 +282,8 @@ class ToolConversationVariables(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def variables(self) -> dict: - return dict(json.loads(self.variables_str)) + def variables(self) -> Any: + return json.loads(self.variables_str) class ToolFile(db.Model): # type: ignore[name-defined] From 7a24c957bdb3a1477e5d8e2128af25fcb91d6627 Mon Sep 17 00:00:00 2001 From: yihong Date: Tue, 24 Dec 2024 23:14:51 +0800 Subject: [PATCH 06/65] fix: i18n error (#12052) Signed-off-by: yihong0618 --- api/core/tools/entities/tool_entities.py | 8 +++++--- api/services/tools/tools_transform_service.py | 5 +---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/api/core/tools/entities/tool_entities.py 
b/api/core/tools/entities/tool_entities.py index 260e4e457f083e..c87a90c03a6f7e 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -243,9 +243,11 @@ def get_simple_instance( :param options: the options of the parameter """ # convert options to ToolParameterOption + # FIXME fix the type error if options: - options_tool_parametor = [ - ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + options = [ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) # type: ignore + for option in options # type: ignore ] return cls( name=name, @@ -256,7 +258,7 @@ def get_simple_instance( form=cls.ToolParameterForm.LLM, llm_description=llm_description, required=required, - options=options_tool_parametor, + options=options, # type: ignore ) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b501554bcd091d..6e3a45be0da1c4 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -275,10 +275,7 @@ def tool_to_user_tool( author=tool.identity.author, name=tool.identity.name, label=tool.identity.label, - description=I18nObject( - en_US=tool.description.human if tool.description else "", - zh_Hans=tool.description.human if tool.description else "", - ), + description=tool.description.human if tool.description else "", # type: ignore parameters=current_parameters, labels=labels, ) From 7da4fb68da065385b78db63ee51183615f74b351 Mon Sep 17 00:00:00 2001 From: yihong Date: Wed, 25 Dec 2024 08:42:52 +0800 Subject: [PATCH 07/65] fix: can not find model bug (#12051) Signed-off-by: yihong0618 --- api/core/entities/provider_configuration.py | 2 +- api/services/entities/model_provider_entities.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 2e27b362d3092c..bff5a0ec9c6be7 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -872,7 +872,7 @@ def _get_system_provider_models( # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] for model in provider_models: - if model.model_type == ModelType.LLM and m.model not in restrict_model_names: + if model.model_type == ModelType.LLM and model.model not in restrict_model_names: model.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: model.status = ModelStatus.QUOTA_EXCEEDED diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 334d009ee5f79f..f1417c6cb94b80 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -7,7 +7,6 @@ from core.entities.model_entities import ( ModelWithProviderEntity, ProviderModelWithStatusEntity, - SimpleModelProviderEntity, ) from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject @@ -152,7 +151,8 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): Model with provider entity. 
""" - provider: SimpleModelProviderEntity + # FIXME type error ignore here + provider: SimpleProviderEntityResponse # type: ignore def __init__(self, model: ModelWithProviderEntity) -> None: super().__init__(**model.model_dump()) From 1d3f218662527db991e2b520e961ec3fa07603df Mon Sep 17 00:00:00 2001 From: yihong Date: Wed, 25 Dec 2024 10:57:52 +0800 Subject: [PATCH 08/65] fix: like failed close #12057 (#12058) Signed-off-by: yihong0618 --- api/controllers/console/explore/message.py | 2 +- api/controllers/service_api/app/message.py | 2 +- api/controllers/web/message.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index c3488de29929c9..690297048eb55c 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -69,7 +69,7 @@ def post(self, installed_app, message_id): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) + MessageService.create_feedback(app_model, message_id, current_user, args.get("rating"), args.get("content")) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 522c7509b9849d..bed89a99a58683 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -108,7 +108,7 @@ def post(self, app_model: App, end_user: EndUser, message_id): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) + MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content")) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 0f47e643708570..b636e6be620ec6 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -108,7 +108,7 @@ def post(self, app_model, end_user, message_id): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) + MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content")) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") From 3ea54e9d2574446664a137280875ea3ec04ba7a7 Mon Sep 17 00:00:00 2001 From: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Date: Wed, 25 Dec 2024 12:00:45 +0900 Subject: [PATCH 09/65] =?UTF-8?q?fix:=20update=20S3=20and=20Azure=20config?= =?UTF-8?q?uration=20typos=20in=20.env.example=20and=20corr=E2=80=A6=20(#1?= =?UTF-8?q?2055)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/.env.example | 6 +++--- api/.ruff.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/.env.example b/api/.env.example index 071a200e680278..cc3e868717e2fb 100644 --- a/api/.env.example +++ b/api/.env.example @@ -65,7 +65,7 @@ OPENDAL_FS_ROOT=storage # S3 Storage configuration S3_USE_AWS_MANAGED_IAM=false -S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com +S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com S3_BUCKET_NAME=your-bucket-name S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key @@ -74,7 +74,7 @@ 
S3_REGION=your-region # Azure Blob Storage configuration AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_KEY=your-account-key -AZURE_BLOB_CONTAINER_NAME=yout-container-name +AZURE_BLOB_CONTAINER_NAME=your-container-name AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net # Aliyun oss Storage configuration @@ -88,7 +88,7 @@ ALIYUN_OSS_REGION=your-region ALIYUN_OSS_PATH=your-path # Google Storage configuration -GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name +GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string # Tencent COS Storage configuration diff --git a/api/.ruff.toml b/api/.ruff.toml index 26a1b977a9f6ac..f30275a943d806 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -67,7 +67,7 @@ ignore = [ "SIM105", # suppressible-exception "SIM107", # return-in-try-except-finally "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # eumerate-for-loop + "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false ] From c98d91e44d75cf03f395eee521e5af9a36a45ad8 Mon Sep 17 00:00:00 2001 From: jiangbo721 <365065261@qq.com> Date: Wed, 25 Dec 2024 13:29:43 +0800 Subject: [PATCH 10/65] fix: o1 model error, use max_completion_tokens instead of max_tokens. (#12037) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 刘江波 --- .../model_providers/azure_openai/llm/llm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index c5d7a83a4ee69f..03818741f65875 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -113,7 +113,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - if "o1" in model: + if model.startswith("o1"): client.chat.completions.create( messages=[{"role": "user", "content": "ping"}], model=model, @@ -311,7 +311,10 @@ def _chat_generate( prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) block_as_stream = False - if "o1" in model: + if model.startswith("o1"): + if "max_tokens" in model_parameters: + model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] + del model_parameters["max_tokens"] if stream: block_as_stream = True stream = False @@ -404,7 +407,7 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp ] ) - if "o1" in model: + if model.startswith("o1"): system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) if system_message_count > 0: new_prompt_messages = [] From b281a80150139d419e43df3ee08286aa4c4f6513 Mon Sep 17 00:00:00 2001 From: marvin-season <64943287+marvin-season@users.noreply.github.com> Date: Wed, 25 Dec 2024 13:30:51 +0800 Subject: [PATCH 11/65] fix: zoom in/out click (#12056) Co-authored-by: marvin --- .../components/workflow/operator/zoom-in-out.tsx | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/web/app/components/workflow/operator/zoom-in-out.tsx b/web/app/components/workflow/operator/zoom-in-out.tsx index 6c4bed3751088f..90b5b46256300c 100644 --- a/web/app/components/workflow/operator/zoom-in-out.tsx +++ b/web/app/components/workflow/operator/zoom-in-out.tsx @@ -129,7 +129,7 @@ const ZoomInOut: FC 
= () => { crossAxis: -2, }} > - +
{ shortcuts={['ctrl', '-']} >
{ + if (zoom <= 0.25) + return + e.stopPropagation() zoomOut() }} @@ -153,14 +156,17 @@ const ZoomInOut: FC = () => {
-
{parseFloat(`${zoom * 100}`).toFixed(0)}%
+
{parseFloat(`${zoom * 100}`).toFixed(0)}%
= 2 ? 'cursor-not-allowed' : 'cursor-pointer hover:bg-black/5'}`} onClick={(e) => { + if (zoom >= 2) + return + e.stopPropagation() zoomIn() }} From 83ea931e3cfc4467003b93949f86281052325902 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 25 Dec 2024 16:24:52 +0800 Subject: [PATCH 12/65] refactor: optimize database usage (#12071) Signed-off-by: -LAN- --- .../advanced_chat/generate_task_pipeline.py | 352 +++++++++--------- .../app/apps/message_based_app_generator.py | 1 - .../apps/workflow/generate_task_pipeline.py | 192 +++++----- .../based_generate_task_pipeline.py | 36 +- .../easy_ui_based_generate_task_pipeline.py | 145 ++++---- .../app/task_pipeline/message_cycle_manage.py | 4 +- .../task_pipeline/workflow_cycle_manage.py | 182 +++++---- api/core/ops/ops_trace_manager.py | 186 ++++----- api/core/ops/utils.py | 2 +- api/models/account.py | 3 +- api/models/model.py | 10 +- api/models/workflow.py | 48 +-- 12 files changed, 587 insertions(+), 574 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 1073a0f2e4f706..691d178ba2aefa 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -5,6 +5,9 @@ from threading import Thread from typing import Any, Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -79,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity - _workflow: Workflow - _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _conversation_name_generate_thread: Optional[Thread] = None @@ -96,32 +97,35 @@ def __init__( stream: bool, dialogue_count: int, ) -> None: - """ - Initialize AdvancedChatAppGenerateTaskPipeline. 
- :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - :param dialogue_count: dialogue count - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id + raise NotImplementedError(f"User type not supported: {type(user)}") + + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) - self._workflow = workflow - self._conversation = conversation - self._message = message self._workflow_system_variables = { SystemVariableKey.QUERY: message.query, SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_id, + SystemVariableKey.USER_ID: self._user_id, SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, @@ -139,13 +143,9 @@ def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStrea Process generate task pipeline. 
:return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + conversation_id=self._conversation_id, query=self._application_generate_entity.query ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -171,12 +171,12 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] return ChatbotAppBlockingResponse( task_id=stream_response.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=self._task_state.answer, - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -194,9 +194,9 @@ def _to_stream_response( """ for stream_response in generator: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -214,7 +214,7 @@ def _wrapper_process_stream_response( tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -274,26 +274,33 @@ def _process_stream_response( if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() - - self._refetch_message() - self._message.workflow_run_id = workflow_run.id - - db.session.commit() - db.session.refresh(self._message) - db.session.close() - - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine) as session: + # init workflow run + workflow_run = self._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + message = self._get_message(session=session) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + message.workflow_run_id = workflow_run.id + session.commit() + + workflow_start_resp = self._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + yield workflow_start_resp elif isinstance( event, QueueNodeRetryEvent, @@ -304,28 +311,28 @@ def _process_stream_response( workflow_run=workflow_run, event=event ) - response = self._workflow_node_retry_to_stream_response( + node_retry_resp = 
self._workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if node_retry_resp: + yield node_retry_resp elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: raise ValueError("workflow run not initialized.") workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - response_start = self._workflow_node_start_to_stream_response( + node_start_resp = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response_start: - yield response_start + if node_start_resp: + yield node_start_resp elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) @@ -333,25 +340,24 @@ def _process_stream_response( if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - response_finish = self._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response_finish: - yield response_finish + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - response_finish = self._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - - if response: - yield response + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: @@ -395,20 +401,24 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("workflow run not initialized.") - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_success( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + yield workflow_finish_resp self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: @@ -417,21 +427,25 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run 
= self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_partial_success( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + yield workflow_finish_resp self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: @@ -440,71 +454,73 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break - elif isinstance(event, QueueStopEvent): - if workflow_run and graph_runtime_state: + with Session(db.engine) as session: workflow_run = self._handle_workflow_run_failed( + session=session, workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.STOPPED, - error=event.get_stop_reason(), - conversation_id=self._conversation.id, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation_id, trace_manager=trace_manager, + exceptions_count=event.exceptions_count, ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - - # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + err = self._handle_error(event=err_event, session=session, message_id=self._message_id) + session.commit() + yield workflow_finish_resp + yield self._error_to_stream_response(err) + break + elif isinstance(event, QueueStopEvent): + if workflow_run and graph_runtime_state: + with Session(db.engine) as session: + workflow_run = 
self._handle_workflow_run_failed( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) + + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + # Save message + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() + yield workflow_finish_resp yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -521,7 +537,7 @@ def _process_stream_response( self._task_state.answer += delta_text yield self._message_to_stream_response( - answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation @@ -536,7 +552,9 @@ def _process_stream_response( yield self._message_replace_to_stream_response(answer=output_moderation_answer) # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + with Session(db.engine) as session: + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() yield self._message_end_to_stream_response() else: @@ -549,54 +567,46 @@ def _process_stream_response( if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: - self._refetch_message() - - self._message.answer = self._task_state.answer - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = ( + def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: + message = self._get_message(session=session) + message.answer = self._task_state.answer + message.provider_response_latency = time.perf_counter() - self._start_at + message.message_metadata = ( 
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) message_files = [ MessageFile( - message_id=self._message.id, + message_id=message.id, type=file["type"], transfer_method=file["transfer_method"], url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], created_by_role=CreatedByRole.ACCOUNT - if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else CreatedByRole.END_USER, - created_by=self._message.from_account_id or self._message.from_end_user_id or "", + created_by=message.from_account_id or message.from_end_user_id or "", ) for file in self._recorded_files ] - db.session.add_all(message_files) + session.add_all(message_files) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.total_price = usage.total_price - self._message.currency = usage.currency - + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.total_price = usage.total_price + message.currency = usage.currency self._task_state.metadata["usage"] = jsonable_encoder(usage) else: self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) - - db.session.commit() - message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -613,7 +623,7 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, - id=self._message.id, + id=self._message_id, files=self._recorded_files, metadata=extras.get("metadata", {}), ) @@ -641,11 +651,9 @@ def _handle_output_moderation_chunk(self, text: str) -> bool: return False - def _refetch_message(self) -> None: - """ - Refetch message. 
- :return: - """ - message = db.session.query(Message).filter(Message.id == self._message.id).first() - if message: - self._message = message + def _get_message(self, *, session: Session): + stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(stmt) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + return message diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c2e35faf89ba15..dcd9463b8abd0f 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -70,7 +70,6 @@ def _handle_response( queue_manager=queue_manager, conversation=conversation, message=message, - user=user, stream=stream, ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index c47b38f5600f4d..574596d4f5a77c 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -3,6 +3,8 @@ from collections.abc import Generator from typing import Any, Optional, Union +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager @@ -50,6 +52,7 @@ from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole from models.model import EndUser from models.workflow import ( Workflow, @@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _workflow: Workflow - _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] @@ -83,25 +84,27 @@ def __init__( user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. 
- :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param user: user - :param stream: is streamed - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id + raise ValueError(f"Invalid user type: {type(user)}") + + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict - self._workflow = workflow self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id, + SystemVariableKey.USER_ID: self._user_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, @@ -115,10 +118,6 @@ def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStr Process generate task pipeline. :return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) @@ -185,7 +184,7 @@ def _wrapper_process_stream_response( tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -242,18 +241,26 @@ def _process_stream_response( if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event) + err = self._handle_error(event=event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine) as session: + # init workflow run + workflow_run = self._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + start_resp = self._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + yield start_resp elif isinstance( event, QueueNodeRetryEvent, @@ -350,22 +357,28 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=None, - trace_manager=trace_manager, - ) 
- - # save workflow app log - self._save_workflow_app_log(workflow_run) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_success( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) + + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + session.commit() + yield workflow_finish_resp elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: raise ValueError("workflow run not initialized.") @@ -373,49 +386,58 @@ def _process_stream_response( if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) - - # save workflow app log - self._save_workflow_app_log(workflow_run) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_partial_success( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) + + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): if not workflow_run: raise ValueError("workflow run not initialized.") if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), - conversation_id=None, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, - ) - - # save workflow app log - self._save_workflow_app_log(workflow_run) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine) as session: + workflow_run = 
self._handle_workflow_run_failed( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) + + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + yield workflow_finish_resp elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -435,7 +457,7 @@ def _process_stream_response( if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: + def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: """ Save workflow app log. :return: @@ -457,12 +479,10 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.created_from = created_from.value - workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" - workflow_app_log.created_by = self._user.id + workflow_app_log.created_by_role = self._created_by_role + workflow_app_log.created_by = self._user_id - db.session.add(workflow_app_log) - db.session.commit() - db.session.close() + session.add(workflow_app_log) def _text_chunk_to_stream_response( self, text: str, from_variable_selector: Optional[list[str]] = None diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 03a81353d02625..e363a7f64244d3 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,9 @@ import logging import time -from typing import Optional, Union +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ( @@ -17,9 +20,7 @@ from core.errors.error import QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser, Message +from models.model import Message logger = logging.getLogger(__name__) @@ -36,7 +37,6 @@ def __init__( self, application_generate_entity: AppGenerateEntity, queue_manager: AppQueueManager, - user: Union[Account, EndUser], stream: bool, ) -> None: """ @@ -48,18 +48,11 @@ def __init__( """ self._application_generate_entity = application_generate_entity self._queue_manager = queue_manager - self._user = user self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent, message: 
Optional[Message] = None): - """ - Handle error event. - :param event: event - :param message: message - :return: - """ + def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): logger.debug("error: %s", event.error) e = event.error err: Exception @@ -71,16 +64,17 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non else: err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) - if message: - refetch_message = db.session.query(Message).filter(Message.id == message.id).first() - - if refetch_message: - err_desc = self._error_to_desc(err) - refetch_message.status = "error" - refetch_message.error = err_desc + if not message_id or not session: + return err - db.session.commit() + stmt = select(Message).where(Message.id == message_id) + message = session.scalar(stmt) + if not message: + return err + err_desc = self._error_to_desc(err) + message.status = "error" + message.error = err_desc return err def _error_to_desc(self, e: Exception) -> str: diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index b9f8e7ca560ce7..c84f8ba3e450cc 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -5,6 +5,9 @@ from threading import Thread from typing import Optional, Union, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -55,8 +58,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db -from models.account import Account -from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought +from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -77,23 +79,21 @@ def __init__( queue_manager: AppQueueManager, conversation: Conversation, message: Message, - user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) self._model_config = application_generate_entity.model_conf self._app_config = application_generate_entity.app_config - self._conversation = conversation - self._message = message + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) self._task_state = EasyUITaskState( llm_result=LLMResult( @@ -113,18 +113,10 @@ def process( CompletionAppBlockingResponse, Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: - """ - Process generate task pipeline. 
- :return: - """ - db.session.refresh(self._conversation) - db.session.refresh(self._message) - db.session.close() - if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query or "" + conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -148,15 +140,15 @@ def _to_blocking_response( if self._task_state.metadata: extras["metadata"] = self._task_state.metadata response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] - if self._conversation.mode == AppMode.COMPLETION.value: + if self._conversation_mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + message_id=self._message_id, answer=cast(str, self._task_state.llm_result.message.content), - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -164,12 +156,12 @@ def _to_blocking_response( response = ChatbotAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=cast(str, self._task_state.llm_result.message.content), - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -190,15 +182,15 @@ def _to_stream_response( for stream_response in generator: if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): yield CompletionAppStreamResponse( - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -265,7 +257,9 @@ def _process_stream_response( event = message.event if isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): @@ -283,10 +277,12 @@ def _process_stream_response( self._task_state.llm_result.message.content = output_moderation_answer yield self._message_replace_to_stream_response(answer=output_moderation_answer) - # Save message - self._save_message(trace_manager) - - yield self._message_end_to_stream_response() + with Session(db.engine) as session: + # Save message + self._save_message(session=session, trace_manager=trace_manager) + session.commit() + 
message_end_resp = self._message_end_to_stream_response() + yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): @@ -320,9 +316,15 @@ def _process_stream_response( self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(cast(str, delta_text), self._message.id) + yield self._message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) else: - yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) + yield self._agent_message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -334,7 +336,7 @@ def _process_stream_response( if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: + def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. :return: @@ -342,53 +344,46 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No llm_result = self._task_state.llm_result usage = llm_result.usage - message = db.session.query(Message).filter(Message.id == self._message.id).first() + message_stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(message_stmt) if not message: - raise Exception(f"Message {self._message.id} not found") - self._message = message - conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + raise ValueError(f"message {self._message_id} not found") + conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id) + conversation = session.scalar(conversation_stmt) if not conversation: - raise Exception(f"Conversation {self._conversation.id} not found") - self._conversation = conversation + raise ValueError(f"Conversation {self._conversation_id} not found") - self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages ) - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = ( + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer = ( PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) if llm_result.message.content else "" ) - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.total_price = usage.total_price - self._message.currency = usage.currency - self._message.message_metadata = ( + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = 
usage.completion_price_unit + message.provider_response_latency = time.perf_counter() - self._start_at + message.total_price = usage.total_price + message.currency = usage.currency + message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) - db.session.commit() - if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id ) ) message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} - and hasattr(self._application_generate_entity, "conversation_id") - and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -434,7 +429,7 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, - id=self._message.id, + id=self._message_id, metadata=extras.get("metadata", {}), ) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 007543f6d0d1f2..15f2c25c66a3d2 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -36,7 +36,7 @@ class MessageCycleManage: ] _task_state: Union[EasyUITaskState, WorkflowTaskState] - def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: + def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ Generate conversation name. 
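
One detail that recurs throughout this patch is the bare * in the rewritten signatures (_save_message above, _generate_conversation_name here), which makes every parameter keyword-only. Any call site missed during the conversation-to-conversation_id rename then fails immediately with a TypeError instead of silently binding a positional argument to the wrong parameter. A small, hypothetical illustration:

    def generate_conversation_name(*, conversation_id: str, query: str) -> str:
        return f"{conversation_id}: {query[:20]}"

    generate_conversation_name(conversation_id="c-1", query="hi")  # ok

    try:
        generate_conversation_name("c-1", "hi")  # positional call
    except TypeError as exc:
        print(exc)  # takes 0 positional arguments but 2 were given
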
:param conversation: conversation @@ -56,7 +56,7 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) -> target=self._generate_conversation_name_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore - "conversation_id": conversation.id, + "conversation_id": conversation_id, "query": query, }, ) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index f581e564f224ce..2692008c6653d6 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast from uuid import uuid4 +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -63,27 +64,34 @@ class WorkflowCycleManage: _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] _task_state: WorkflowTaskState _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - def _handle_workflow_run_start(self) -> WorkflowRun: - max_sequence = ( - db.session.query(db.func.max(WorkflowRun.sequence_number)) - .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) - .filter(WorkflowRun.app_id == self._workflow.app_id) - .scalar() - or 0 + def _handle_workflow_run_start( + self, + *, + session: Session, + workflow_id: str, + user_id: str, + created_by_role: CreatedByRole, + ) -> WorkflowRun: + workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) + workflow = session.scalar(workflow_stmt) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_id}") + + max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where( + WorkflowRun.tenant_id == workflow.tenant_id, + WorkflowRun.app_id == workflow.app_id, ) + max_sequence = session.scalar(max_sequence_stmt) or 0 new_sequence_number = max_sequence + 1 inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): if key.value == "conversation": continue - inputs[f"sys.{key.value}"] = value triggered_from = ( @@ -96,33 +104,32 @@ def _handle_workflow_run_start(self) -> WorkflowRun: inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = WorkflowRun() - system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] - workflow_run.id = system_id or str(uuid4()) - workflow_run.tenant_id = self._workflow.tenant_id - workflow_run.app_id = self._workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = self._workflow.id - workflow_run.type = self._workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = self._workflow.version - workflow_run.graph = self._workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = ( - CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER - ) - workflow_run.created_by = self._user.id - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) - - session.add(workflow_run) - session.commit() + workflow_run_id = 
str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) + + workflow_run = WorkflowRun() + workflow_run.id = workflow_run_id + workflow_run.tenant_id = workflow.tenant_id + workflow_run.app_id = workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = workflow.id + workflow_run.type = workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = workflow.version + workflow_run.graph = workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.created_by_role = created_by_role + workflow_run.created_by = user_id + workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + + session.add(workflow_run) return workflow_run def _handle_workflow_run_success( self, + *, + session: Session, workflow_run: WorkflowRun, start_at: float, total_tokens: int, @@ -141,7 +148,7 @@ def _handle_workflow_run_success( :param conversation_id: conversation id :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) outputs = WorkflowEntry.handle_special_values(outputs) @@ -152,9 +159,6 @@ def _handle_workflow_run_success( workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - db.session.refresh(workflow_run) - if trace_manager: trace_manager.add_trace_task( TraceTask( @@ -165,12 +169,12 @@ def _handle_workflow_run_success( ) ) - db.session.close() - return workflow_run def _handle_workflow_run_partial_success( self, + *, + session: Session, workflow_run: WorkflowRun, start_at: float, total_tokens: int, @@ -190,7 +194,7 @@ def _handle_workflow_run_partial_success( :param conversation_id: conversation id :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) @@ -201,8 +205,6 @@ def _handle_workflow_run_partial_success( workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - db.session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -214,12 +216,12 @@ def _handle_workflow_run_partial_success( ) ) - db.session.close() - return workflow_run def _handle_workflow_run_failed( self, + *, + session: Session, workflow_run: WorkflowRun, start_at: float, total_tokens: int, @@ -240,7 +242,7 @@ def _handle_workflow_run_failed( :param error: error message :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) workflow_run.status = status.value workflow_run.error = error @@ -249,21 +251,18 @@ def _handle_workflow_run_failed( workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - running_workflow_node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == 
WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, - ) - .all() + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, ) + running_workflow_node_executions = session.scalars(stmt).all() + for workflow_node_execution in running_workflow_node_executions: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error @@ -271,13 +270,6 @@ def _handle_workflow_run_failed( workflow_node_execution.elapsed_time = ( workflow_node_execution.finished_at - workflow_node_execution.created_at ).total_seconds() - db.session.commit() - - db.session.close() - - # with Session(db.engine, expire_on_commit=False) as session: - # session.add(workflow_run) - # session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -485,14 +477,14 @@ def _handle_workflow_node_execution_retried( ################################################# def _workflow_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowStartStreamResponse: - """ - Workflow start to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return WorkflowStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -506,36 +498,32 @@ def _workflow_start_to_stream_response( ) def _workflow_finish_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowFinishStreamResponse: - """ - Workflow finish to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ - # Attach WorkflowRun to an active session so "created_by_role" can be accessed. 
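
The lookup style introduced just below, querying Account or EndUser by workflow_run.created_by through the session handed down from the caller, replaces the created_by_account / created_by_end_user properties deleted from the WorkflowRun model later in this patch. Those properties required the instance to be attached to an active session, which is exactly the merge/refresh dance this hunk removes. A runnable sketch of the explicit-lookup shape, trimmed to the account branch and using stand-in classes:

    from enum import Enum

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class CreatedByRole(str, Enum):  # stand-in for the real enum
        ACCOUNT = "account"
        END_USER = "end_user"

    class Base(DeclarativeBase):
        pass

    class Account(Base):  # trimmed stand-in
        __tablename__ = "accounts"
        id: Mapped[str] = mapped_column(String(36), primary_key=True)
        name: Mapped[str] = mapped_column(String(255))
        email: Mapped[str] = mapped_column(String(255))

    def resolve_created_by(session: Session, role: CreatedByRole, created_by: str) -> dict | None:
        # Explicit query through the caller's session; nothing is resolved
        # via a relationship on a possibly-detached instance.
        if role == CreatedByRole.ACCOUNT:
            account = session.scalar(select(Account).where(Account.id == created_by))
            return {"id": account.id, "name": account.name, "email": account.email} if account else None
        raise NotImplementedError(f"unknown created_by_role: {role}")

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(Account(id="a-1", name="Ada", email="ada@example.com"))
        print(resolve_created_by(session, CreatedByRole.ACCOUNT, "a-1"))
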
- workflow_run = db.session.merge(workflow_run) - - # Refresh to ensure any expired attributes are fully loaded - db.session.refresh(workflow_run) - created_by = None - if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: - created_by_account = workflow_run.created_by_account - if created_by_account: + if workflow_run.created_by_role == CreatedByRole.ACCOUNT: + stmt = select(Account).where(Account.id == workflow_run.created_by) + account = session.scalar(stmt) + if account: created_by = { - "id": created_by_account.id, - "name": created_by_account.name, - "email": created_by_account.email, + "id": account.id, + "name": account.name, + "email": account.email, } - else: - created_by_end_user = workflow_run.created_by_end_user - if created_by_end_user: + elif workflow_run.created_by_role == CreatedByRole.END_USER: + stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) + end_user = session.scalar(stmt) + if end_user: created_by = { - "id": created_by_end_user.id, - "user": created_by_end_user.session_id, + "id": end_user.id, + "user": end_user.session_id, } + else: + raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") return WorkflowFinishStreamResponse( task_id=task_id, @@ -895,14 +883,14 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any return None - def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run :param workflow_run_id: workflow run id :return: """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - + stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalar(stmt) if not workflow_run: raise WorkflowRunNotFoundError(workflow_run_id) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index f538eaef5bd570..691cb8d400964c 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -9,6 +9,8 @@ from uuid import UUID, uuid4 from flask import current_app +from sqlalchemy import select +from sqlalchemy.orm import Session from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( @@ -329,15 +331,15 @@ def __init__( ): self.trace_type = trace_type self.message_id = message_id - self.workflow_run = workflow_run + self.workflow_run_id = workflow_run.id if workflow_run else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer - self.kwargs = kwargs self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") - self.app_id = None + self.kwargs = kwargs + def execute(self): return self.preprocess() @@ -345,19 +347,23 @@ def preprocess(self): preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - self.workflow_run, self.conversation_id, self.user_id + workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + ), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: 
self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs ), TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs + ), + TraceTaskName.TOOL_TRACE: lambda: self.tool_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( - self.conversation_id, self.timer, **self.kwargs + conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), } @@ -367,86 +373,100 @@ def preprocess(self): def conversation_trace(self, **kwargs): return kwargs - def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): - if not workflow_run: - raise ValueError("Workflow run not found") - - db.session.merge(workflow_run) - db.session.refresh(workflow_run) - - workflow_id = workflow_run.workflow_id - tenant_id = workflow_run.tenant_id - workflow_run_id = workflow_run.id - workflow_run_elapsed_time = workflow_run.elapsed_time - workflow_run_status = workflow_run.status - workflow_run_inputs = workflow_run.inputs_dict - workflow_run_outputs = workflow_run.outputs_dict - workflow_run_version = workflow_run.version - error = workflow_run.error or "" - - total_tokens = workflow_run.total_tokens - - file_list = workflow_run_inputs.get("sys.file") or [] - query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" - - # get workflow_app_log_id - workflow_app_log_data = ( - db.session.query(WorkflowAppLog) - .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) - .first() - ) - workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None - # get message_id - message_data = ( - db.session.query(Message.id) - .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) - .first() - ) - message_id = str(message_data.id) if message_data else None - - metadata = { - "workflow_id": workflow_id, - "conversation_id": conversation_id, - "workflow_run_id": workflow_run_id, - "tenant_id": tenant_id, - "elapsed_time": workflow_run_elapsed_time, - "status": workflow_run_status, - "version": workflow_run_version, - "total_tokens": total_tokens, - "file_list": file_list, - "triggered_form": workflow_run.triggered_from, - "user_id": user_id, - } + def workflow_trace( + self, + *, + workflow_run_id: str | None, + conversation_id: str | None, + user_id: str | None, + ): + if not workflow_run_id: + return {} - workflow_trace_info = WorkflowTraceInfo( - workflow_data=workflow_run.to_dict(), - conversation_id=conversation_id, - workflow_id=workflow_id, - tenant_id=tenant_id, - workflow_run_id=workflow_run_id, - workflow_run_elapsed_time=workflow_run_elapsed_time, - workflow_run_status=workflow_run_status, - workflow_run_inputs=workflow_run_inputs, - workflow_run_outputs=workflow_run_outputs, - workflow_run_version=workflow_run_version, - error=error, - total_tokens=total_tokens, - file_list=file_list, - query=query, - metadata=metadata, - workflow_app_log_id=workflow_app_log_id, - message_id=message_id, - start_time=workflow_run.created_at, - end_time=workflow_run.finished_at, - ) + with Session(db.engine) 
as session: + workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalars(workflow_run_stmt).first() + if not workflow_run: + raise ValueError("Workflow run not found") + + workflow_id = workflow_run.workflow_id + tenant_id = workflow_run.tenant_id + workflow_run_id = workflow_run.id + workflow_run_elapsed_time = workflow_run.elapsed_time + workflow_run_status = workflow_run.status + workflow_run_inputs = workflow_run.inputs_dict + workflow_run_outputs = workflow_run.outputs_dict + workflow_run_version = workflow_run.version + error = workflow_run.error or "" + + total_tokens = workflow_run.total_tokens + + file_list = workflow_run_inputs.get("sys.file") or [] + query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" + + # get workflow_app_log_id + workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( + WorkflowAppLog.tenant_id == tenant_id, + WorkflowAppLog.app_id == workflow_run.app_id, + WorkflowAppLog.workflow_run_id == workflow_run.id, + ) + workflow_app_log_id = session.scalar(workflow_app_log_data_stmt) + # get message_id + message_id = None + if conversation_id: + message_data_stmt = select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_run_id, + ) + message_id = session.scalar(message_data_stmt) + + metadata = { + "workflow_id": workflow_id, + "conversation_id": conversation_id, + "workflow_run_id": workflow_run_id, + "tenant_id": tenant_id, + "elapsed_time": workflow_run_elapsed_time, + "status": workflow_run_status, + "version": workflow_run_version, + "total_tokens": total_tokens, + "file_list": file_list, + "triggered_form": workflow_run.triggered_from, + "user_id": user_id, + } + workflow_trace_info = WorkflowTraceInfo( + workflow_data=workflow_run.to_dict(), + conversation_id=conversation_id, + workflow_id=workflow_id, + tenant_id=tenant_id, + workflow_run_id=workflow_run_id, + workflow_run_elapsed_time=workflow_run_elapsed_time, + workflow_run_status=workflow_run_status, + workflow_run_inputs=workflow_run_inputs, + workflow_run_outputs=workflow_run_outputs, + workflow_run_version=workflow_run_version, + error=error, + total_tokens=total_tokens, + file_list=file_list, + query=query, + metadata=metadata, + workflow_app_log_id=workflow_app_log_id, + message_id=message_id, + start_time=workflow_run.created_at, + end_time=workflow_run.finished_at, + ) return workflow_trace_info - def message_trace(self, message_id): + def message_trace(self, message_id: str | None): + if not message_id: + return {} message_data = get_message_data(message_id) if not message_data: return {} - conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first() + conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) + conversation_mode = db.session.scalars(conversation_mode_stmt).all() + if not conversation_mode or len(conversation_mode) == 0: + return {} conversation_mode = conversation_mode[0] created_at = message_data.created_at inputs = message_data.message diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 998eba9ea9791b..8b06df1930595d 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -18,7 +18,7 @@ def filter_none_values(data: dict): return new_data -def get_message_data(message_id): +def get_message_data(message_id: str): return db.session.query(Message).filter(Message.id == message_id).first() diff --git 
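
Worth spelling out why TraceTask now stores workflow_run_id rather than the WorkflowRun instance (a few hunks above): trace tasks execute on a worker thread, and an ORM object created in the request thread is detached there, hence the old merge/refresh calls. Passing only the primary key and re-fetching in a thread-local session sidesteps that entirely. A runnable sketch of the hand-off; StaticPool is needed only so this toy in-memory SQLite database is shared across threads:

    import threading

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
    from sqlalchemy.pool import StaticPool

    class Base(DeclarativeBase):
        pass

    class WorkflowRun(Base):  # trimmed stand-in for the real model
        __tablename__ = "workflow_runs"
        id: Mapped[str] = mapped_column(String(36), primary_key=True)
        status: Mapped[str] = mapped_column(String(255), default="running")

    engine = create_engine(
        "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
    )
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(WorkflowRun(id="run-1"))
        session.commit()

    def trace_worker(workflow_run_id: str) -> None:
        # Only the primary key crossed the thread boundary; re-fetch here.
        with Session(engine) as session:
            run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_run_id))
            if run is None:
                raise ValueError("Workflow run not found")
            print(run.id, run.status)

    worker = threading.Thread(target=trace_worker, args=("run-1",))
    worker.start()
    worker.join()
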
a/api/models/account.py b/api/models/account.py index 88c96da1a149d5..35a28df7505943 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -3,6 +3,7 @@ from flask_login import UserMixin # type: ignore from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column from .engine import db from .types import StringUUID @@ -20,7 +21,7 @@ class Account(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) diff --git a/api/models/model.py b/api/models/model.py index 2a593f08298199..d2d4d5853fd2b9 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -530,13 +530,13 @@ class Conversation(db.Model): # type: ignore[name-defined] db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) model_id = db.Column(db.String(255), nullable=True) - mode = db.Column(db.String(255), nullable=False) + mode: Mapped[str] = mapped_column(db.String(255)) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) @@ -770,7 +770,7 @@ class Message(db.Model): # type: ignore[name-defined] db.Index("message_created_at_idx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) @@ -797,7 +797,7 @@ class Message(db.Model): # type: ignore[name-defined] from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -1322,7 +1322,7 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined] external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - session_id = db.Column(db.String(255), nullable=False) + session_id: Mapped[str] = mapped_column() created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git 
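
The model changes that follow all apply one recipe: replace untyped db.Column attributes with SQLAlchemy 2.0 Mapped[...] annotations plus mapped_column(), which is what lets mypy see real attribute types (str, int, datetime, Optional[...]) instead of Any. A small illustration with a hypothetical stub model:

    from datetime import datetime

    from sqlalchemy import String, func
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class MessageStub(Base):  # illustrative only; the real model is far larger
        __tablename__ = "message_stub"
        id: Mapped[str] = mapped_column(String(36), primary_key=True)
        created_at: Mapped[datetime] = mapped_column(server_default=func.current_timestamp())

    def age_seconds(message: MessageStub) -> float:
        # mypy now knows created_at is a datetime, so .timestamp() type-checks;
        # on a plain db.Column the attribute was effectively Any.
        return datetime.now().timestamp() - message.created_at.timestamp()

In this style the column type (DateTime here) is derived from the annotation, and nullability from whether the annotation is Optional, which is why several mapped_column() calls in the diff carry no arguments at all.
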
a/api/models/workflow.py b/api/models/workflow.py index 880e044d073a67..78a7f8169fe634 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -392,40 +392,28 @@ class WorkflowRun(db.Model): # type: ignore[name-defined] db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - sequence_number = db.Column(db.Integer, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + sequence_number: Mapped[int] = mapped_column() + workflow_id: Mapped[str] = mapped_column(StringUUID) + type: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + version: Mapped[str] = mapped_column(db.String(255)) + graph: Mapped[str] = mapped_column(db.Text) + inputs: Mapped[str] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error = db.Column(db.Text) + error: Mapped[str] = mapped_column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) - created_by_role = db.Column(db.String(255), nullable=False) # account, end_user + created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) exceptions_count = db.Column(db.Integer, server_default=db.text("0")) - @property - def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None - - @property - def created_by_end_user(self): - from models.model import EndUser - - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None - @property def graph_dict(self): return json.loads(self.graph) if self.graph else {} @@ -750,11 +738,11 @@ class WorkflowAppLog(db.Model): # type: ignore[name-defined] db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) workflow_id = 
db.Column(StringUUID, nullable=False) - workflow_run_id = db.Column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) From 1885d3df9968d778664bacdde4ea7a4d9448070f Mon Sep 17 00:00:00 2001 From: Cemre Mengu Date: Wed, 25 Dec 2024 11:31:01 +0300 Subject: [PATCH 13/65] fix: unquote urls in docker-compose.yaml (#12072) Signed-off-by: -LAN- Co-authored-by: -LAN- --- docker/docker-compose.yaml | 56 +++++++++++++++++----------------- docker/generate_docker_compose | 2 +- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 7122f4a6d0f768..e65ca45858118f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -15,15 +15,15 @@ x-shared-env: &shared-api-worker-env LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5} - LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"} + LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S} LOG_TZ: ${LOG_TZ:-UTC} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} INIT_PASSWORD: ${INIT_PASSWORD:-} DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION} - CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-"https://updates.dify.ai"} - OPENAI_API_BASE: ${OPENAI_API_BASE:-"https://api.openai.com/v1"} + CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-https://updates.dify.ai} + OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1} MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} @@ -69,7 +69,7 @@ x-shared-env: &shared-api-worker-env REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} - CELERY_BROKER_URL: ${CELERY_BROKER_URL:-"redis://:difyai123456@redis:6379/1"} + CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} @@ -88,13 +88,13 @@ x-shared-env: &shared-api-worker-env AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} - AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-"https://.blob.core.windows.net"} + AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-https://.blob.core.windows.net} GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-your-bucket-name} GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-your-google-service-account-json-base64-string} ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-your-bucket-name} ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-your-access-key} ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-your-secret-key} - ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-"https://oss-ap-southeast-1-internal.aliyuncs.com"} + ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-https://oss-ap-southeast-1-internal.aliyuncs.com} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: 
${ALIYUN_OSS_PATH:-your-path} @@ -103,7 +103,7 @@ x-shared-env: &shared-api-worker-env TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region} TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme} - OCI_ENDPOINT: ${OCI_ENDPOINT:-"https://objectstorage.us-ashburn-1.oraclecloud.com"} + OCI_ENDPOINT: ${OCI_ENDPOINT:-https://objectstorage.us-ashburn-1.oraclecloud.com} OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name} OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key} OCI_SECRET_KEY: ${OCI_SECRET_KEY:-your-secret-key} @@ -125,14 +125,14 @@ x-shared-env: &shared-api-worker-env SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key} SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} - WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-"http://weaviate:8080"} + WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} - QDRANT_URL: ${QDRANT_URL:-"http://qdrant:6333"} + QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333} QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} - MILVUS_URI: ${MILVUS_URI:-"http://127.0.0.1:19530"} + MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-root} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} @@ -142,7 +142,7 @@ x-shared-env: &shared-api-worker-env MYSCALE_PASSWORD: ${MYSCALE_PASSWORD:-} MYSCALE_DATABASE: ${MYSCALE_DATABASE:-dify} MYSCALE_FTS_PARAMS: ${MYSCALE_FTS_PARAMS:-} - COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-"couchbase://couchbase-server"} + COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-couchbase://couchbase-server} COUCHBASE_USER: ${COUCHBASE_USER:-Administrator} COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} @@ -176,15 +176,15 @@ x-shared-env: &shared-api-worker-env TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-} TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-} TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify} - TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-"http://127.0.0.1"} + TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1} TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify} TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20} TIDB_ON_QDRANT_GRPC_ENABLED: ${TIDB_ON_QDRANT_GRPC_ENABLED:-false} TIDB_ON_QDRANT_GRPC_PORT: ${TIDB_ON_QDRANT_GRPC_PORT:-6334} TIDB_PUBLIC_KEY: ${TIDB_PUBLIC_KEY:-dify} TIDB_PRIVATE_KEY: ${TIDB_PRIVATE_KEY:-dify} - TIDB_API_URL: ${TIDB_API_URL:-"http://127.0.0.1"} - TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-"http://127.0.0.1"} + TIDB_API_URL: ${TIDB_API_URL:-http://127.0.0.1} + TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-http://127.0.0.1} TIDB_REGION: ${TIDB_REGION:-regions/aws-us-east-1} TIDB_PROJECT_ID: ${TIDB_PROJECT_ID:-dify} TIDB_SPEND_LIMIT: ${TIDB_SPEND_LIMIT:-100} @@ -209,7 +209,7 @@ x-shared-env: &shared-api-worker-env OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} - TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-"http://127.0.0.1"} + TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30} TENCENT_VECTOR_DB_USERNAME: 
${TENCENT_VECTOR_DB_USERNAME:-dify} @@ -221,7 +221,7 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} KIBANA_PORT: ${KIBANA_PORT:-5601} - BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-"http://127.0.0.1:5287"} + BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify} @@ -235,7 +235,7 @@ x-shared-env: &shared-api-worker-env VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http} VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30} VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30} - LINDORM_URL: ${LINDORM_URL:-"http://lindorm:30070"} + LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm} OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase} @@ -245,7 +245,7 @@ x-shared-env: &shared-api-worker-env OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test} OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} - UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-"https://xxx-vector.upstash.io"} + UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io} UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} @@ -270,7 +270,7 @@ x-shared-env: &shared-api-worker-env NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-} MAIL_TYPE: ${MAIL_TYPE:-resend} MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-} - RESEND_API_URL: ${RESEND_API_URL:-"https://api.resend.com"} + RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} SMTP_SERVER: ${SMTP_SERVER:-} SMTP_PORT: ${SMTP_PORT:-465} @@ -281,7 +281,7 @@ x-shared-env: &shared-api-worker-env INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} - CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-"http://sandbox:8194"} + CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} @@ -303,8 +303,8 @@ x-shared-env: &shared-api-worker-env WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} - SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-"http://ssrf_proxy:3128"} - SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-"http://ssrf_proxy:3128"} + SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} + SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} PGUSER: ${PGUSER:-${DB_USERNAME}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} @@ -314,8 +314,8 @@ x-shared-env: &shared-api-worker-env SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release} SANDBOX_WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15} 
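
The unquoting throughout this file fixes a real pitfall: in docker-compose, the default in ${VAR:-default} is substituted verbatim, so quotes written inside the braces are not shell quoting, they become literal characters in the resulting value. A toy interpolator (a deliberate simplification of compose's actual rules) makes the failure mode visible:

    import re

    def interpolate(template: str, env: dict[str, str]) -> str:
        # Toy ${VAR:-default} substitution: the default is taken verbatim,
        # so quote characters inside the braces survive into the value.
        def repl(match: re.Match) -> str:
            var, default = match.group(1), match.group(2)
            return env.get(var, default)
        return re.sub(r"\$\{(\w+):-([^}]*)\}", repl, template)

    print(interpolate('${RESEND_API_URL:-"https://api.resend.com"}', {}))
    # -> "https://api.resend.com"   (the stray quotes reach the application)
    print(interpolate("${RESEND_API_URL:-https://api.resend.com}", {}))
    # -> https://api.resend.com

The same reasoning explains the generate_docker_compose change below: the script stops wrapping defaults in quotes when it regenerates this block.
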
SANDBOX_ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true} - SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-"http://ssrf_proxy:3128"} - SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-"http://ssrf_proxy:3128"} + SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} + SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} WEAVIATE_PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate} WEAVIATE_QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25} @@ -338,8 +338,8 @@ x-shared-env: &shared-api-worker-env ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000} MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin} MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin} - ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-"etcd:2379"} - MINIO_ADDRESS: ${MINIO_ADDRESS:-"minio:9000"} + ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} + MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} MILVUS_AUTHORIZATION_ENABLED: ${MILVUS_AUTHORIZATION_ENABLED:-true} PGVECTOR_PGUSER: ${PGVECTOR_PGUSER:-postgres} PGVECTOR_POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} @@ -360,7 +360,7 @@ x-shared-env: &shared-api-worker-env NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-"TLSv1.1 TLSv1.2 TLSv1.3"} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -374,7 +374,7 @@ x-shared-env: &shared-api-worker-env SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} SSRF_REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194} SSRF_SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox} - COMPOSE_PROFILES: ${COMPOSE_PROFILES:-"${VECTOR_STORE:-weaviate}"} + COMPOSE_PROFILES: ${COMPOSE_PROFILES:-${VECTOR_STORE:-weaviate}} EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80} EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443} POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-} diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose index 54b6d55217f8ba..dc4460f96cf9be 100755 --- a/docker/generate_docker_compose +++ b/docker/generate_docker_compose @@ -43,7 +43,7 @@ def generate_shared_env_block(env_vars, anchor_name="shared-api-worker-env"): else: # If default value contains special characters, wrap it in quotes if re.search(r"[:\s]", default): - default = f'"{default}"' + default = f"{default}" lines.append(f" {key}: ${{{key}:-{default}}}") return "\n".join(lines) From 39ace9bdee5426f09197d7d8ab0bd347db6fc1c1 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 25 Dec 2024 16:34:38 +0800 Subject: [PATCH 14/65] =?UTF-8?q?fix(app=5Fgenerator):=20improve=20error?= =?UTF-8?q?=20handling=20for=20closed=20file=20I/O=20operat=E2=80=A6=20(#1?= =?UTF-8?q?2073)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: -LAN- --- api/core/app/apps/advanced_chat/app_generator.py | 2 +- api/core/app/apps/message_based_app_generator.py | 2 +- api/core/app/apps/workflow/app_generator.py | 2 +- api/core/tools/tool_engine.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index a18b40712b7ce6..b006de23699c1b 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ 
b/api/core/app/apps/advanced_chat/app_generator.py @@ -383,7 +383,7 @@ def _handle_advanced_chat_response( try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}") diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index dcd9463b8abd0f..4e3aa840ceac8a 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -76,7 +76,7 @@ def _handle_response( try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception(f"Failed to handle response, conversation_id: {conversation.id}") diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 1d5f21b9e0cc07..42bc17277fd7c5 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -309,7 +309,7 @@ def _handle_response( try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception( diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 425a892527daa4..f7a8ed63f401d5 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -113,7 +113,7 @@ def agent_invoke( error_response = f"tool invoke error: {e}" agent_tool_callback.on_tool_error(e) except ToolEngineInvokeError as e: - meta = e.args[0] + meta = e.meta error_response = f"tool invoke error: {meta.error}" agent_tool_callback.on_tool_error(e) return error_response, [], meta From 2b2263a349326e21da813640435e8fcd4c9ce536 Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Wed, 25 Dec 2024 18:17:15 +0800 Subject: [PATCH 15/65] Feat/parent child retrieval (#12086) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: yihong0618 Signed-off-by: -LAN- Co-authored-by: AkaraChen Co-authored-by: nite-knite Co-authored-by: Joel Co-authored-by: Warren Chen Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: yihong Co-authored-by: -LAN- Co-authored-by: KVOJJJin Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: Charlie.Wei Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: huayaoyue6 Co-authored-by: kurokobo Co-authored-by: Matsuda Co-authored-by: shirochan Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: Huỳnh Gia Bôi Co-authored-by: Julian Huynh Co-authored-by: Hash Brown Co-authored-by: 非法操作 Co-authored-by: Kazuki Takamatsu Co-authored-by: Trey Dong <1346650911@qq.com> Co-authored-by: VoidIsVoid <343750470@qq.com> Co-authored-by: 
Gimling Co-authored-by: xiandan-erizo Co-authored-by: Muneyuki Noguchi Co-authored-by: zhaobingshuang <1475195565@qq.com> Co-authored-by: zhaobs Co-authored-by: suzuki.sh Co-authored-by: Yingchun Lai Co-authored-by: huanshare Co-authored-by: huanshare Co-authored-by: orangeclk Co-authored-by: 문정현 <120004247+JungHyunMoon@users.noreply.github.com> Co-authored-by: barabicu Co-authored-by: Wei Mingzhi Co-authored-by: Paul van Oorschot <20116814+pvoo@users.noreply.github.com> Co-authored-by: zkyTech Co-authored-by: zhangkunyuan Co-authored-by: Tommy <34446820+Asterovim@users.noreply.github.com> Co-authored-by: zxhlyh Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Novice Lee Co-authored-by: Novice Lee Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Jiang <65766008+AlwaysBluer@users.noreply.github.com> Co-authored-by: jiangzhijie Co-authored-by: Joe <79627742+ZhouhaoJiang@users.noreply.github.com> Co-authored-by: Alok Shrivastwa Co-authored-by: Alok Shrivastwa Co-authored-by: JasonVV Co-authored-by: Hiroshi Fujita Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: NFish Co-authored-by: Junyan Qin <1010553892@qq.com> Co-authored-by: IWAI, Masaharu Co-authored-by: IWAI, Masaharu Co-authored-by: Bowen Liang Co-authored-by: luckylhb90 Co-authored-by: hobo.l Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> --- api/poetry.lock | 13 +- .../[datasetId]/layout.tsx | 156 +- .../[datasetId]/settings/page.tsx | 6 +- .../[datasetId]/style.module.css | 9 - web/app/(commonLayout)/datasets/Container.tsx | 17 +- .../(commonLayout)/datasets/DatasetCard.tsx | 10 +- .../components/app-sidebar/dataset-info.tsx | 45 + web/app/components/app-sidebar/index.tsx | 30 +- .../dataset-config/settings-modal/index.tsx | 16 +- .../components/base/app-icon/style.module.css | 19 + .../base/auto-height-textarea/common.tsx | 2 + web/app/components/base/badge.tsx | 6 +- .../components/base/checkbox/assets/mixed.svg | 5 + .../components/base/checkbox/index.module.css | 10 + web/app/components/base/checkbox/index.tsx | 5 +- web/app/components/base/divider/index.tsx | 2 +- .../components/base/divider/with-label.tsx | 23 + web/app/components/base/drawer/index.tsx | 4 +- .../base/file-uploader/file-type-icon.tsx | 7 +- .../icons/assets/public/knowledge/chunk.svg | 13 + .../assets/public/knowledge/collapse.svg | 9 + .../assets/public/knowledge/general-type.svg | 5 + .../knowledge/layout-right-2-line-mod.svg | 5 + .../public/knowledge/parent-child-type.svg | 7 + .../assets/public/knowledge/selection-mod.svg | 13 + .../icons/src/public/knowledge/Chunk.json | 116 ++ .../base/icons/src/public/knowledge/Chunk.tsx | 16 + .../icons/src/public/knowledge/Collapse.json | 62 + .../icons/src/public/knowledge/Collapse.tsx | 16 + .../src/public/knowledge/GeneralType.json | 38 + .../src/public/knowledge/GeneralType.tsx | 16 + .../public/knowledge/LayoutRight2LineMod.json | 36 + .../public/knowledge/LayoutRight2LineMod.tsx | 16 + .../src/public/knowledge/ParentChildType.json | 56 + .../src/public/knowledge/ParentChildType.tsx | 16 + .../src/public/knowledge/SelectionMod.json | 116 ++ .../src/public/knowledge/SelectionMod.tsx | 16 + .../base/icons/src/public/knowledge/index.ts | 6 + .../base/icons/src/vender/features/index.ts | 2 +- .../components/base/input-number/index.tsx | 86 + .../base/linked-apps-panel/index.tsx | 62 + web/app/components/base/pagination/index.tsx | 2 +- 
 web/app/components/base/param-item/index.tsx | 36 +-
 web/app/components/base/radio-card/index.tsx | 23 +-
 .../components/base/retry-button/index.tsx | 85 -
 .../base/retry-button/style.module.css | 4 -
 .../base/simple-pie-chart/index.tsx | 7 +-
 web/app/components/base/skeleton/index.tsx | 13 +-
 web/app/components/base/switch/index.tsx | 3 +
 web/app/components/base/tag-input/index.tsx | 51 +-
 web/app/components/base/toast/index.tsx | 19 +-
 web/app/components/base/tooltip/index.tsx | 4 +-
 .../billing/priority-label/index.tsx | 19 +-
 web/app/components/datasets/chunk.tsx | 54 +
 .../datasets/common/chunking-mode-label.tsx | 29 +
 .../datasets/common/document-file-icon.tsx | 40 +
 .../common/document-picker/document-list.tsx | 42 +
 .../datasets/common/document-picker/index.tsx | 118 ++
 .../preview-document-picker.tsx | 82 +
 .../auto-disabled-document.tsx | 38 +
 .../index-failed.tsx | 69 +
 .../status-with-action.tsx | 65 +
 .../index.tsx | 27 +-
 .../common/retrieval-method-config/index.tsx | 89 +-
 .../common/retrieval-method-info/index.tsx | 20 +-
 .../common/retrieval-param-config/index.tsx | 38 +-
 .../datasets/create/assets/family-mod.svg | 6 +
 .../create/assets/file-list-3-fill.svg | 5 +
 .../datasets/create/assets/gold.svg | 4 +
 .../datasets/create/assets/note-mod.svg | 5 +
 .../create/assets/option-card-effect-blue.svg | 12 +
 .../assets/option-card-effect-orange.svg | 12 +
 .../assets/option-card-effect-purple.svg | 12 +
 .../create/assets/pattern-recognition-mod.svg | 12 +
 .../datasets/create/assets/piggy-bank-mod.svg | 7 +
 .../create/assets/progress-indicator.svg | 8 +
 .../datasets/create/assets/rerank.svg | 13 +
 .../datasets/create/assets/research-mod.svg | 6 +
 .../datasets/create/assets/selection-mod.svg | 12 +
 .../create/assets/setting-gear-mod.svg | 4 +
 .../create/embedding-process/index.module.css | 52 +-
 .../create/embedding-process/index.tsx | 192 ++-
 .../create/file-preview/index.module.css | 3 +-
 .../datasets/create/file-preview/index.tsx | 4 +-
 .../create/file-uploader/index.module.css | 67 +-
 .../datasets/create/file-uploader/index.tsx | 82 +-
 web/app/components/datasets/create/icons.ts | 16 +
 web/app/components/datasets/create/index.tsx | 58 +-
 .../create/notion-page-preview/index.tsx | 4 +-
 .../datasets/create/step-one/index.module.css | 38 +-
 .../datasets/create/step-one/index.tsx | 287 +--
 .../datasets/create/step-three/index.tsx | 53 +-
 .../datasets/create/step-two/index.module.css | 38 +-
 .../datasets/create/step-two/index.tsx | 1535 +++++++++--------
 .../datasets/create/step-two/inputs.tsx | 77 +
 .../create/step-two/language-select/index.tsx | 33 +-
 .../datasets/create/step-two/option-card.tsx | 98 ++
 .../datasets/create/stepper/index.tsx | 27 +
 .../datasets/create/stepper/step.tsx | 46 +
 .../datasets/create/top-bar/index.tsx | 41 +
 .../create/website/base/error-message.tsx | 2 +-
 .../create/website/jina-reader/index.tsx | 1 -
 .../datasets/create/website/preview.tsx | 4 +-
 .../detail/batch-modal/csv-downloader.tsx | 14 +-
 .../documents/detail/batch-modal/index.tsx | 4 +-
 .../detail/completed/InfiniteVirtualList.tsx | 98 --
 .../detail/completed/SegmentCard.tsx | 20 +-
 .../detail/completed/child-segment-detail.tsx | 134 ++
 .../detail/completed/child-segment-list.tsx | 195 +++
 .../completed/common/action-buttons.tsx | 86 +
 .../detail/completed/common/add-another.tsx | 32 +
 .../detail/completed/common/batch-action.tsx | 103 ++
 .../detail/completed/common/chunk-content.tsx | 192 +++
 .../documents/detail/completed/common/dot.tsx | 11 +
 .../detail/completed/common/empty.tsx | 78 +
 .../completed/common/full-screen-drawer.tsx | 35 +
 .../detail/completed/common/keywords.tsx | 47 +
 .../completed/common/regeneration-modal.tsx | 131 ++
 .../completed/common/segment-index-tag.tsx | 40 +
 .../documents/detail/completed/common/tag.tsx | 15 +
 .../detail/completed/display-toggle.tsx | 40 +
 .../documents/detail/completed/index.tsx | 873 ++++++----
 .../detail/completed/new-child-segment.tsx | 175 ++
 .../detail/completed/segment-card.tsx | 280 +++
 .../detail/completed/segment-detail.tsx | 190 ++
 .../detail/completed/segment-list.tsx | 116 ++
 .../skeleton/full-doc-list-skeleton.tsx | 25 +
 .../skeleton/general-list-skeleton.tsx | 74 +
 .../skeleton/paragraph-list-skeleton.tsx | 76 +
 .../skeleton/parent-chunk-card-skeleton.tsx | 45 +
 .../detail/completed/status-item.tsx | 22 +
 .../detail/completed/style.module.css | 13 +-
 .../documents/detail/embedding/index.tsx | 302 ++--
 .../detail/embedding/skeleton/index.tsx | 66 +
 .../datasets/documents/detail/index.tsx | 219 ++-
 .../documents/detail/metadata/index.tsx | 18 +-
 .../detail/metadata/style.module.css | 13 +-
 .../documents/detail/new-segment-modal.tsx | 156 --
 .../datasets/documents/detail/new-segment.tsx | 208 +++
 .../documents/detail/segment-add/index.tsx | 114 +-
 .../documents/detail/settings/index.tsx | 40 +-
 .../documents/detail/style.module.css | 10 +-
 .../components/datasets/documents/index.tsx | 77 +-
 .../components/datasets/documents/list.tsx | 440 +++--
 .../datasets/documents/style.module.css | 17 +-
 .../formatted-text/flavours/edit-slice.tsx | 115 ++
 .../formatted-text/flavours/preview-slice.tsx | 56 +
 .../formatted-text/flavours/shared.tsx | 60 +
 .../datasets/formatted-text/flavours/type.ts | 5 +
 .../datasets/formatted-text/formatted.tsx | 12 +
 .../components/child-chunks-item.tsx | 30 +
 .../components/chunk-detail-modal.tsx | 89 +
 .../hit-testing/components/result-item.tsx | 121 ++
 .../datasets/hit-testing/components/score.tsx | 25 +
 .../datasets/hit-testing/hit-detail.tsx | 68 -
 .../components/datasets/hit-testing/index.tsx | 144 +-
 .../hit-testing/modify-retrieval-modal.tsx | 8 +-
 .../datasets/hit-testing/style.module.css | 36 +-
 .../datasets/hit-testing/textarea.tsx | 59 +-
 .../utils/extension-to-file-type.ts | 31 +
 web/app/components/datasets/loading.tsx | 0
 .../components/datasets/preview/container.tsx | 29 +
 .../components/datasets/preview/header.tsx | 23 +
 web/app/components/datasets/preview/index.tsx | 0
 .../datasets/settings/form/index.tsx | 130 +-
 .../index-method-radio/index.module.css | 54 -
 .../settings/index-method-radio/index.tsx | 105 +-
 .../model-selector/model-trigger.tsx | 16 +-
 web/app/components/header/indicator/index.tsx | 18 +-
 web/context/dataset-detail.ts | 13 +-
 web/hooks/use-metadata.ts | 4 +-
 web/i18n/en-US/common.ts | 8 +-
 web/i18n/en-US/dataset-creation.ts | 50 +-
 web/i18n/en-US/dataset-documents.ts | 66 +-
 web/i18n/en-US/dataset-hit-testing.ts | 14 +-
 web/i18n/en-US/dataset-settings.ts | 9 +-
 web/i18n/en-US/dataset.ts | 22 +-
 web/i18n/zh-Hans/common.ts | 8 +-
 web/i18n/zh-Hans/dataset-creation.ts | 29 +-
 web/i18n/zh-Hans/dataset-documents.ts | 54 +-
 web/i18n/zh-Hans/dataset-hit-testing.ts | 10 +-
 web/i18n/zh-Hans/dataset-settings.ts | 9 +-
 web/i18n/zh-Hans/dataset.ts | 22 +-
 web/i18n/zh-Hant/dataset-creation.ts | 3 +-
 web/models/datasets.ts | 111 +-
 web/package.json | 2 +
 web/public/screenshots/Light/Agent.png | Bin 0 -> 36209 bytes
 web/public/screenshots/Light/Agent@2x.png | Bin 0 -> 103245 bytes
 web/public/screenshots/Light/Agent@3x.png | Bin 0 -> 209674 bytes
 web/public/screenshots/Light/ChatFlow.png | Bin 0 -> 28423 bytes
 web/public/screenshots/Light/ChatFlow@2x.png | Bin 0 -> 81229 bytes
 web/public/screenshots/Light/ChatFlow@3x.png | Bin 0 -> 160820 bytes
 web/public/screenshots/Light/Chatbot.png | Bin 0 -> 31633 bytes
 web/public/screenshots/Light/Chatbot@2x.png | Bin 0 -> 84515 bytes
 web/public/screenshots/Light/Chatbot@3x.png | Bin 0 -> 142013 bytes
 web/public/screenshots/Light/Chatflow.png | Bin 0 -> 28423 bytes
 web/public/screenshots/Light/Chatflow@2x.png | Bin 0 -> 81229 bytes
 web/public/screenshots/Light/Chatflow@3x.png | Bin 0 -> 160820 bytes
 .../screenshots/Light/TextGenerator.png | Bin 0 -> 26627 bytes
 .../screenshots/Light/TextGenerator@2x.png | Bin 0 -> 63818 bytes
 .../screenshots/Light/TextGenerator@3x.png | Bin 0 -> 122391 bytes
 web/public/screenshots/Light/Workflow.png | Bin 0 -> 22110 bytes
 web/public/screenshots/Light/Workflow@2x.png | Bin 0 -> 62688 bytes
 web/public/screenshots/Light/Workflow@3x.png | Bin 0 -> 147073 bytes
 web/service/datasets.ts | 71 -
 web/service/knowledge/use-create-dataset.ts | 223 +++
 web/service/knowledge/use-dateset.ts | 0
 web/service/knowledge/use-document.ts | 124 ++
 web/service/knowledge/use-hit-testing.ts | 0
 web/service/knowledge/use-import.ts | 0
 web/service/knowledge/use-segment.ts | 169 ++
 web/tailwind.config.js | 23 +-
 web/themes/manual-dark.css | 8 +
 web/themes/manual-light.css | 8 +
 web/utils/time.ts | 12 +
 web/yarn.lock | 10 +
 216 files changed, 9038 insertions(+), 3088 deletions(-)
 create mode 100644 web/app/components/app-sidebar/dataset-info.tsx
 create mode 100644 web/app/components/base/app-icon/style.module.css
 create mode 100644 web/app/components/base/checkbox/assets/mixed.svg
 create mode 100644 web/app/components/base/checkbox/index.module.css
 create mode 100644 web/app/components/base/divider/with-label.tsx
 create mode 100644 web/app/components/base/icons/assets/public/knowledge/chunk.svg
 create mode 100644 web/app/components/base/icons/assets/public/knowledge/collapse.svg
 create mode 100644 web/app/components/base/icons/assets/public/knowledge/general-type.svg
 create mode 100644 web/app/components/base/icons/assets/public/knowledge/layout-right-2-line-mod.svg
 create mode 100644 web/app/components/base/icons/assets/public/knowledge/parent-child-type.svg
 create mode 100644 web/app/components/base/icons/assets/public/knowledge/selection-mod.svg
 create mode 100644 web/app/components/base/icons/src/public/knowledge/Chunk.json
 create mode 100644 web/app/components/base/icons/src/public/knowledge/Chunk.tsx
 create mode 100644 web/app/components/base/icons/src/public/knowledge/Collapse.json
 create mode 100644 web/app/components/base/icons/src/public/knowledge/Collapse.tsx
 create mode 100644 web/app/components/base/icons/src/public/knowledge/GeneralType.json
 create mode 100644 web/app/components/base/icons/src/public/knowledge/GeneralType.tsx
 create mode 100644 web/app/components/base/icons/src/public/knowledge/LayoutRight2LineMod.json
 create mode 100644 web/app/components/base/icons/src/public/knowledge/LayoutRight2LineMod.tsx
 create mode 100644 web/app/components/base/icons/src/public/knowledge/ParentChildType.json
 create mode 100644 web/app/components/base/icons/src/public/knowledge/ParentChildType.tsx
 create mode 100644 web/app/components/base/icons/src/public/knowledge/SelectionMod.json
 create mode 100644 web/app/components/base/icons/src/public/knowledge/SelectionMod.tsx
 create mode 100644 web/app/components/base/icons/src/public/knowledge/index.ts
 create mode 100644 web/app/components/base/input-number/index.tsx
 create mode 100644 web/app/components/base/linked-apps-panel/index.tsx
 delete mode 100644 web/app/components/base/retry-button/index.tsx
 delete mode 100644 web/app/components/base/retry-button/style.module.css
 create mode 100644 web/app/components/datasets/chunk.tsx
 create mode 100644 web/app/components/datasets/common/chunking-mode-label.tsx
 create mode 100644 web/app/components/datasets/common/document-file-icon.tsx
 create mode 100644 web/app/components/datasets/common/document-picker/document-list.tsx
 create mode 100644 web/app/components/datasets/common/document-picker/index.tsx
 create mode 100644 web/app/components/datasets/common/document-picker/preview-document-picker.tsx
 create mode 100644 web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx
 create mode 100644 web/app/components/datasets/common/document-status-with-action/index-failed.tsx
 create mode 100644 web/app/components/datasets/common/document-status-with-action/status-with-action.tsx
 create mode 100644 web/app/components/datasets/create/assets/family-mod.svg
 create mode 100644 web/app/components/datasets/create/assets/file-list-3-fill.svg
 create mode 100644 web/app/components/datasets/create/assets/gold.svg
 create mode 100644 web/app/components/datasets/create/assets/note-mod.svg
 create mode 100644 web/app/components/datasets/create/assets/option-card-effect-blue.svg
 create mode 100644 web/app/components/datasets/create/assets/option-card-effect-orange.svg
 create mode 100644 web/app/components/datasets/create/assets/option-card-effect-purple.svg
 create mode 100644 web/app/components/datasets/create/assets/pattern-recognition-mod.svg
 create mode 100644 web/app/components/datasets/create/assets/piggy-bank-mod.svg
 create mode 100644 web/app/components/datasets/create/assets/progress-indicator.svg
 create mode 100644 web/app/components/datasets/create/assets/rerank.svg
 create mode 100644 web/app/components/datasets/create/assets/research-mod.svg
 create mode 100644 web/app/components/datasets/create/assets/selection-mod.svg
 create mode 100644 web/app/components/datasets/create/assets/setting-gear-mod.svg
 create mode 100644 web/app/components/datasets/create/icons.ts
 create mode 100644 web/app/components/datasets/create/step-two/inputs.tsx
 create mode 100644 web/app/components/datasets/create/step-two/option-card.tsx
 create mode 100644 web/app/components/datasets/create/stepper/index.tsx
 create mode 100644 web/app/components/datasets/create/stepper/step.tsx
 create mode 100644 web/app/components/datasets/create/top-bar/index.tsx
 delete mode 100644 web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/child-segment-list.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/add-another.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/batch-action.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/dot.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/empty.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/full-screen-drawer.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/keywords.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/regeneration-modal.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/segment-index-tag.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/common/tag.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/display-toggle.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/new-child-segment.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/segment-card.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/segment-detail.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/segment-list.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/skeleton/full-doc-list-skeleton.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/skeleton/general-list-skeleton.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/skeleton/paragraph-list-skeleton.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/skeleton/parent-chunk-card-skeleton.tsx
 create mode 100644 web/app/components/datasets/documents/detail/completed/status-item.tsx
 create mode 100644 web/app/components/datasets/documents/detail/embedding/skeleton/index.tsx
 delete mode 100644 web/app/components/datasets/documents/detail/new-segment-modal.tsx
 create mode 100644 web/app/components/datasets/documents/detail/new-segment.tsx
 create mode 100644 web/app/components/datasets/formatted-text/flavours/edit-slice.tsx
 create mode 100644 web/app/components/datasets/formatted-text/flavours/preview-slice.tsx
 create mode 100644 web/app/components/datasets/formatted-text/flavours/shared.tsx
 create mode 100644 web/app/components/datasets/formatted-text/flavours/type.ts
 create mode 100644 web/app/components/datasets/formatted-text/formatted.tsx
 create mode 100644 web/app/components/datasets/hit-testing/components/child-chunks-item.tsx
 create mode 100644 web/app/components/datasets/hit-testing/components/chunk-detail-modal.tsx
 create mode 100644 web/app/components/datasets/hit-testing/components/result-item.tsx
 create mode 100644 web/app/components/datasets/hit-testing/components/score.tsx
 delete mode 100644 web/app/components/datasets/hit-testing/hit-detail.tsx
 create mode 100644 web/app/components/datasets/hit-testing/utils/extension-to-file-type.ts
 create mode 100644 web/app/components/datasets/loading.tsx
 create mode 100644 web/app/components/datasets/preview/container.tsx
 create mode 100644 web/app/components/datasets/preview/header.tsx
 create mode 100644 web/app/components/datasets/preview/index.tsx
 delete mode 100644 web/app/components/datasets/settings/index-method-radio/index.module.css
 create mode 100644 web/public/screenshots/Light/Agent.png
 create mode 100644 web/public/screenshots/Light/Agent@2x.png
 create mode 100644 web/public/screenshots/Light/Agent@3x.png
 create mode 100644 web/public/screenshots/Light/ChatFlow.png
 create mode 100644 web/public/screenshots/Light/ChatFlow@2x.png
 create mode 100644 web/public/screenshots/Light/ChatFlow@3x.png
 create mode 100644 web/public/screenshots/Light/Chatbot.png
 create mode 100644 web/public/screenshots/Light/Chatbot@2x.png
 create mode 100644
web/public/screenshots/Light/Chatbot@3x.png create mode 100644 web/public/screenshots/Light/Chatflow.png create mode 100644 web/public/screenshots/Light/Chatflow@2x.png create mode 100644 web/public/screenshots/Light/Chatflow@3x.png create mode 100644 web/public/screenshots/Light/TextGenerator.png create mode 100644 web/public/screenshots/Light/TextGenerator@2x.png create mode 100644 web/public/screenshots/Light/TextGenerator@3x.png create mode 100644 web/public/screenshots/Light/Workflow.png create mode 100644 web/public/screenshots/Light/Workflow@2x.png create mode 100644 web/public/screenshots/Light/Workflow@3x.png create mode 100644 web/service/knowledge/use-create-dataset.ts create mode 100644 web/service/knowledge/use-dateset.ts create mode 100644 web/service/knowledge/use-document.ts create mode 100644 web/service/knowledge/use-hit-testing.ts create mode 100644 web/service/knowledge/use-import.ts create mode 100644 web/service/knowledge/use-segment.ts create mode 100644 web/utils/time.ts diff --git a/api/poetry.lock b/api/poetry.lock index b42eb22dd40b8a..b2d22a887db2f6 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,15 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. + +[[package]] +name = "aiofiles" +version = "24.1.0" +description = "File support for asyncio." +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, + {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, +] [[package]] name = "aiofiles" diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index b416659a6a1cfa..a6fb116fa8a50f 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -7,85 +7,36 @@ import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import { Cog8ToothIcon, - // CommandLineIcon, - Squares2X2Icon, - // eslint-disable-next-line sort-imports - PuzzlePieceIcon, DocumentTextIcon, PaperClipIcon, - QuestionMarkCircleIcon, } from '@heroicons/react/24/outline' import { Cog8ToothIcon as Cog8ToothSolidIcon, // CommandLineIcon as CommandLineSolidIcon, DocumentTextIcon as DocumentTextSolidIcon, } from '@heroicons/react/24/solid' -import Link from 'next/link' +import { RiApps2AddLine, RiInformation2Line } from '@remixicon/react' import s from './style.module.css' import classNames from '@/utils/classnames' import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets' -import type { RelatedApp, RelatedAppResponse } from '@/models/datasets' +import type { RelatedAppResponse } from '@/models/datasets' import AppSideBar from '@/app/components/app-sidebar' -import Divider from '@/app/components/base/divider' -import AppIcon from '@/app/components/base/app-icon' import Loading from '@/app/components/base/loading' -import FloatPopoverContainer from '@/app/components/base/float-popover-container' import DatasetDetailContext from '@/context/dataset-detail' import { DataSourceType } from '@/models/datasets' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { LanguagesSupported } from 
'@/i18n/language' import { useStore } from '@/app/components/app/store' -import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' -import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' import { getLocaleOnClient } from '@/i18n' import { useAppContext } from '@/context/app-context' +import Tooltip from '@/app/components/base/tooltip' +import LinkedAppsPanel from '@/app/components/base/linked-apps-panel' export type IAppDetailLayoutProps = { children: React.ReactNode params: { datasetId: string } } -type ILikedItemProps = { - type?: 'plugin' | 'app' - appStatus?: boolean - detail: RelatedApp - isMobile: boolean -} - -const LikedItem = ({ - type = 'app', - detail, - isMobile, -}: ILikedItemProps) => { - return ( - -
- - {type === 'app' && ( - - {detail.mode === 'advanced-chat' && ( - - )} - {detail.mode === 'agent-chat' && ( - - )} - {detail.mode === 'chat' && ( - - )} - {detail.mode === 'completion' && ( - - )} - {detail.mode === 'workflow' && ( - - )} - - )} -
- {!isMobile &&
{detail?.name || '--'}
} - - ) -} - const TargetIcon = ({ className }: SVGProps) => { return @@ -117,65 +68,80 @@ const BookOpenIcon = ({ className }: SVGProps) => { type IExtraInfoProps = { isMobile: boolean relatedApps?: RelatedAppResponse + expand: boolean } -const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => { +const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => { const locale = getLocaleOnClient() const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile) const { t } = useTranslation() + const hasRelatedApps = relatedApps?.data && relatedApps?.data?.length > 0 + const relatedAppsTotal = relatedApps?.data?.length || 0 + useEffect(() => { setShowTips(!isMobile) }, [isMobile, setShowTips]) - return
- - {(relatedApps?.data && relatedApps?.data?.length > 0) && ( + return
+ {hasRelatedApps && ( <> - {!isMobile &&
{relatedApps?.total || '--'} {t('common.datasetMenus.relatedApp')}
} + {!isMobile && ( + + } + > +
+ {relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')} + +
+
+ )} + {isMobile &&
- {relatedApps?.total || '--'} + {relatedAppsTotal || '--'}
} - {relatedApps?.data?.map((item, index) => ())} )} - {!relatedApps?.data?.length && ( - - + {!hasRelatedApps && !expand && ( + +
+ +
+
{t('common.datasetMenus.emptyTip')}
+ + + {t('common.datasetMenus.viewDoc')} +
} > -
-
-
- -
-
- -
-
-
{t('common.datasetMenus.emptyTip')}
- - - {t('common.datasetMenus.viewDoc')} - +
+ {t('common.datasetMenus.noRelatedApp')} +
- + )}
} @@ -235,7 +201,7 @@ const DatasetDetailLayout: FC = (props) => { }, [isMobile, setAppSiderbarExpand]) if (!datasetRes && !error) - return + return return (
@@ -246,7 +212,7 @@ const DatasetDetailLayout: FC = (props) => { desc={datasetRes?.description || '--'} isExternal={datasetRes?.provider === 'external'} navigation={navigation} - extraInfo={!isCurrentWorkspaceDatasetOperator ? mode => : undefined} + extraInfo={!isCurrentWorkspaceDatasetOperator ? mode => : undefined} iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 'notion' : 'dataset'} />} = (props) => { dataset: datasetRes, mutateDatasetRes: () => mutateDatasetRes(), }}> -
{children}
+
{children}
) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index df314ddafe9d15..3a65f1d30ff7ef 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -7,10 +7,10 @@ const Settings = async () => { const { t } = await translate(locale, 'dataset-settings') return ( -
+
-
{t('title')}
-
{t('desc')}
+
{t('title')}
+
{t('desc')}
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css index 0ee64b4fcd0f82..516b124809b9be 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css @@ -1,12 +1,3 @@ -.itemWrapper { - @apply flex items-center w-full h-10 rounded-lg hover:bg-gray-50 cursor-pointer; -} -.appInfo { - @apply truncate text-gray-700 text-sm font-normal; -} -.iconWrapper { - @apply relative w-6 h-6 rounded-lg; -} .statusPoint { @apply flex justify-center items-center absolute -right-0.5 -bottom-0.5 w-2.5 h-2.5 bg-white rounded; } diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx index a30521d9988a5e..a0edb1cd61ba3c 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/Container.tsx @@ -17,7 +17,6 @@ import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' import Button from '@/app/components/base/button' import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' -import SearchInput from '@/app/components/base/search-input' // Services import { fetchDatasetApiBaseUrl } from '@/service/datasets' @@ -29,6 +28,7 @@ import { useAppContext } from '@/context/app-context' import { useExternalApiPanel } from '@/context/external-api-panel-context' // eslint-disable-next-line import/order import { useQuery } from '@tanstack/react-query' +import Input from '@/app/components/base/input' const Container = () => { const { t } = useTranslation() @@ -81,17 +81,24 @@ const Container = () => { }, [currentWorkspace, router]) return ( -
-
+
+
setActiveTab(newActiveTab)} options={options} /> {activeTab === 'dataset' && ( -
+
- + handleKeywordsChange(e.target.value)} + onClear={() => handleKeywordsChange('')} + />
+ +
+
+} diff --git a/web/app/components/base/linked-apps-panel/index.tsx b/web/app/components/base/linked-apps-panel/index.tsx new file mode 100644 index 00000000000000..4320cb0fc6389f --- /dev/null +++ b/web/app/components/base/linked-apps-panel/index.tsx @@ -0,0 +1,62 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import Link from 'next/link' +import { useTranslation } from 'react-i18next' +import { RiArrowRightUpLine } from '@remixicon/react' +import cn from '@/utils/classnames' +import AppIcon from '@/app/components/base/app-icon' +import type { RelatedApp } from '@/models/datasets' + +type ILikedItemProps = { + appStatus?: boolean + detail: RelatedApp + isMobile: boolean +} + +const appTypeMap = { + 'chat': 'Chatbot', + 'completion': 'Completion', + 'agent-chat': 'Agent', + 'advanced-chat': 'Chatflow', + 'workflow': 'Workflow', +} + +const LikedItem = ({ + detail, + isMobile, +}: ILikedItemProps) => { + return ( + +
+
+ +
+ {!isMobile &&
{detail?.name || '--'}
} +
+
{appTypeMap[detail.mode]}
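(Note, not part of the patch.) `appTypeMap[detail.mode]` only type-checks cleanly if `detail.mode` is narrowed to the map's keys. A minimal sketch of a stricter typing, assuming the five modes listed above are exhaustive; the `AppMode` union below is this note's invention, not something the patch defines:

    type AppMode = 'chat' | 'completion' | 'agent-chat' | 'advanced-chat' | 'workflow'
    // Record<AppMode, string> turns a missing or misspelled mode into a
    // compile-time error instead of rendering `undefined` in the badge.
    const appTypeMap: Record<AppMode, string> = {
      'chat': 'Chatbot',
      'completion': 'Completion',
      'agent-chat': 'Agent',
      'advanced-chat': 'Chatflow',
      'workflow': 'Workflow',
    }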
+ + + ) +} + +type Props = { + relatedApps: RelatedApp[] + isMobile: boolean +} + +const LinkedAppsPanel: FC = ({ + relatedApps, + isMobile, +}) => { + const { t } = useTranslation() + return ( +
+
{relatedApps.length || '--'} {t('common.datasetMenus.relatedApp')}
+ {relatedApps.map((item, index) => ( + + ))} +
+ ) +} +export default React.memo(LinkedAppsPanel) diff --git a/web/app/components/base/pagination/index.tsx b/web/app/components/base/pagination/index.tsx index b64c712425178a..c0cc9f86ec2852 100644 --- a/web/app/components/base/pagination/index.tsx +++ b/web/app/components/base/pagination/index.tsx @@ -8,7 +8,7 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import cn from '@/utils/classnames' -type Props = { +export type Props = { className?: string current: number onChange: (cur: number) => void diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 49acc8148455e5..68c980ad095b76 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { FC } from 'react' +import { InputNumber } from '../input-number' import Tooltip from '@/app/components/base/tooltip' import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' @@ -23,39 +24,44 @@ type Props = { const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, min = 0, max, value, enable, onChange, hasSwitch, onSwitchChange }) => { return (
-
-
+
+
{hasSwitch && ( { onSwitchChange?.(id, val) }} /> )} - {name} + {name} {!noTooltip && ( {tip}
} /> )} -
-
-
-
- { - const value = parseFloat(e.target.value) - if (value < min || value > max) - return - - onChange(id, value) - }} /> +
+
+ { + onChange(id, value) + }} + className='w-[72px]' + />
-
+
= ({ onChosen = () => { }, chosenConfig, chosenConfigWrapClassName, + className, }) => { return (
-
-
+
+
{icon}
-
{title}
-
{description}
+
{title}
+
{description}
{!noRadio && ( -
+
= ({ )}
{((isChosen && chosenConfig) || noRadio) && ( -
- {chosenConfig} +
+
+
+ {chosenConfig} +
)}
diff --git a/web/app/components/base/retry-button/index.tsx b/web/app/components/base/retry-button/index.tsx deleted file mode 100644 index 689827af7b3afa..00000000000000 --- a/web/app/components/base/retry-button/index.tsx +++ /dev/null @@ -1,85 +0,0 @@ -'use client' -import type { FC } from 'react' -import React, { useEffect, useReducer } from 'react' -import { useTranslation } from 'react-i18next' -import useSWR from 'swr' -import s from './style.module.css' -import classNames from '@/utils/classnames' -import Divider from '@/app/components/base/divider' -import { getErrorDocs, retryErrorDocs } from '@/service/datasets' -import type { IndexingStatusResponse } from '@/models/datasets' - -const WarningIcon = () => - - - - -type Props = { - datasetId: string -} -type IIndexState = { - value: string -} -type ActionType = 'retry' | 'success' | 'error' - -type IAction = { - type: ActionType -} -const indexStateReducer = (state: IIndexState, action: IAction) => { - const actionMap = { - retry: 'retry', - success: 'success', - error: 'error', - } - - return { - ...state, - value: actionMap[action.type] || state.value, - } -} - -const RetryButton: FC = ({ datasetId }) => { - const { t } = useTranslation() - const [indexState, dispatch] = useReducer(indexStateReducer, { value: 'success' }) - const { data: errorDocs } = useSWR({ datasetId }, getErrorDocs) - - const onRetryErrorDocs = async () => { - dispatch({ type: 'retry' }) - const document_ids = errorDocs?.data.map((doc: IndexingStatusResponse) => doc.id) || [] - const res = await retryErrorDocs({ datasetId, document_ids }) - if (res.result === 'success') - dispatch({ type: 'success' }) - else - dispatch({ type: 'error' }) - } - - useEffect(() => { - if (errorDocs?.total === 0) - dispatch({ type: 'success' }) - else - dispatch({ type: 'error' }) - }, [errorDocs?.total]) - - if (indexState.value === 'success') - return null - - return ( -
- - - {errorDocs?.total} {t('dataset.docsFailedNotice')} - - - - {t('dataset.retry')} - -
- ) -} -export default RetryButton diff --git a/web/app/components/base/retry-button/style.module.css b/web/app/components/base/retry-button/style.module.css deleted file mode 100644 index 99a0947576d9c2..00000000000000 --- a/web/app/components/base/retry-button/style.module.css +++ /dev/null @@ -1,4 +0,0 @@ -.retryBtn { - @apply inline-flex justify-center items-center content-center h-9 leading-5 rounded-lg px-4 py-2 text-base; - @apply border-solid border border-gray-200 text-gray-500 hover:bg-white hover:shadow-sm hover:border-gray-300; -} diff --git a/web/app/components/base/simple-pie-chart/index.tsx b/web/app/components/base/simple-pie-chart/index.tsx index 7de539cbb1ac45..4b987ab42dec7d 100644 --- a/web/app/components/base/simple-pie-chart/index.tsx +++ b/web/app/components/base/simple-pie-chart/index.tsx @@ -10,10 +10,11 @@ export type SimplePieChartProps = { fill?: string stroke?: string size?: number + animationDuration?: number className?: string } -const SimplePieChart = ({ percentage = 80, fill = '#fdb022', stroke = '#f79009', size = 12, className }: SimplePieChartProps) => { +const SimplePieChart = ({ percentage = 80, fill = '#fdb022', stroke = '#f79009', size = 12, animationDuration, className }: SimplePieChartProps) => { const option: EChartsOption = useMemo(() => ({ series: [ { @@ -34,7 +35,7 @@ const SimplePieChart = ({ percentage = 80, fill = '#fdb022', stroke = '#f79009', { type: 'pie', radius: '83%', - animationDuration: 600, + animationDuration: animationDuration ?? 600, data: [ { value: percentage, itemStyle: { color: fill } }, { value: 100 - percentage, itemStyle: { color: '#fff' } }, @@ -48,7 +49,7 @@ const SimplePieChart = ({ percentage = 80, fill = '#fdb022', stroke = '#f79009', cursor: 'default', }, ], - }), [stroke, fill, percentage]) + }), [stroke, fill, percentage, animationDuration]) return ( -export const SkeletonContanier: FC = (props) => { +export const SkeletonContainer: FC = (props) => { const { className, children, ...rest } = props return (
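(Note, not part of the patch.) In the simple-pie-chart hunk above, the new prop is applied as `animationDuration ?? 600`. Nullish coalescing matters here: it preserves an explicit `0` (animation off), which `animationDuration || 600` would silently replace with the default. A usage sketch with an invented percentage:

    // Renders the progress pie with animation disabled.
    <SimplePieChart percentage={45} animationDuration={0} />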
@@ -30,11 +30,14 @@ export const SkeletonRectangle: FC = (props) => { ) } -export const SkeletonPoint: FC = () => -
·
- +export const SkeletonPoint: FC = (props) => { + const { className, ...rest } = props + return ( +
·
+ ) +} /** Usage - * + * * * * diff --git a/web/app/components/base/switch/index.tsx b/web/app/components/base/switch/index.tsx index f61c6f46fff0a7..8bf32b1311b158 100644 --- a/web/app/components/base/switch/index.tsx +++ b/web/app/components/base/switch/index.tsx @@ -64,4 +64,7 @@ const Switch = ({ onChange, size = 'md', defaultValue = false, disabled = false, ) } + +Switch.displayName = 'Switch' + export default React.memo(Switch) diff --git a/web/app/components/base/tag-input/index.tsx b/web/app/components/base/tag-input/index.tsx index b26d0c6438c067..ec6c1cee342665 100644 --- a/web/app/components/base/tag-input/index.tsx +++ b/web/app/components/base/tag-input/index.tsx @@ -3,8 +3,8 @@ import type { ChangeEvent, FC, KeyboardEvent } from 'react' import { } from 'use-context-selector' import { useTranslation } from 'react-i18next' import AutosizeInput from 'react-18-input-autosize' +import { RiAddLine, RiCloseLine } from '@remixicon/react' import cn from '@/utils/classnames' -import { X } from '@/app/components/base/icons/src/vender/line/general' import { useToastContext } from '@/app/components/base/toast' type TagInputProps = { @@ -75,14 +75,14 @@ const TagInput: FC = ({ (items || []).map((item, index) => (
+ className={cn('flex items-center mr-1 mt-1 pl-1.5 pr-1 py-1 system-xs-regular text-text-secondary border border-divider-deep bg-components-badge-white-to-dark rounded-md')} + > {item} { !disableRemove && ( - handleRemove(index)} - /> +
handleRemove(index)}> + +
) }
@@ -90,24 +90,27 @@ const TagInput: FC = ({ } { !disableAdd && ( - setFocused(true)} - onBlur={handleBlur} - value={value} - onChange={(e: ChangeEvent) => { - setValue(e.target.value) - }} - onKeyDown={handleKeyDown} - placeholder={t(placeholder || (isSpecialMode ? 'common.model.params.stop_sequencesPlaceholder' : 'datasetDocuments.segment.addKeyWord'))} - /> +
+ {!isSpecialMode && !focused && } + setFocused(true)} + onBlur={handleBlur} + value={value} + onChange={(e: ChangeEvent) => { + setValue(e.target.value) + }} + onKeyDown={handleKeyDown} + placeholder={t(placeholder || (isSpecialMode ? 'common.model.params.stop_sequencesPlaceholder' : 'datasetDocuments.segment.addKeyWord'))} + /> +
) }
diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx index b9a6de9fe5ac00..ba7d8af518e51a 100644 --- a/web/app/components/base/toast/index.tsx +++ b/web/app/components/base/toast/index.tsx @@ -21,6 +21,7 @@ export type IToastProps = { children?: ReactNode onClose?: () => void className?: string + customComponent?: ReactNode } type IToastContext = { notify: (props: IToastProps) => void @@ -35,6 +36,7 @@ const Toast = ({ message, children, className, + customComponent, }: IToastProps) => { const { close } = useToastContext() // sometimes message is react node array. Not handle it. @@ -49,8 +51,7 @@ const Toast = ({ 'top-0', 'right-0', )}> - -
-
{message}
+
+
+
{message}
+ {customComponent} +
{children &&
{children}
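(Note, not part of the patch.) The toast changes in this hunk thread an optional `customComponent` ReactNode from `Toast.notify` through to the rendered body, next to the message. A hedged call sketch; the option names are from the diff, the sample content is invented:

    Toast.notify({
      type: 'warning',
      message: 'Some documents failed to index',  // invented sample copy
      customComponent: <RetryAllButton />,        // any ReactNode; this component is hypothetical
    })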
}
- +
@@ -117,7 +121,8 @@ Toast.notify = ({ message, duration, className, -}: Pick) => { + customComponent, +}: Pick) => { const defaultDuring = (type === 'success' || type === 'info') ? 3000 : 6000 if (typeof window === 'object') { const holder = document.createElement('div') @@ -133,7 +138,7 @@ Toast.notify = ({ } }, }}> - + , ) document.body.appendChild(holder) diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index 8ec3cd8c7ab4b0..65b5a99077ca14 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -14,6 +14,7 @@ export type TooltipProps = { popupContent?: React.ReactNode children?: React.ReactNode popupClassName?: string + noDecoration?: boolean offset?: OffsetOptions needsDelay?: boolean asChild?: boolean @@ -27,6 +28,7 @@ const Tooltip: FC = ({ popupContent, children, popupClassName, + noDecoration, offset, asChild = true, needsDelay = false, @@ -96,7 +98,7 @@ const Tooltip: FC = ({ > {popupContent && (
triggerMethod === 'hover' && setHoverPopup()} diff --git a/web/app/components/billing/priority-label/index.tsx b/web/app/components/billing/priority-label/index.tsx index 36338cf4a8e767..6ecac4a79ea504 100644 --- a/web/app/components/billing/priority-label/index.tsx +++ b/web/app/components/billing/priority-label/index.tsx @@ -4,6 +4,7 @@ import { DocumentProcessingPriority, Plan, } from '../type' +import cn from '@/utils/classnames' import { useProviderContext } from '@/context/provider-context' import { ZapFast, @@ -11,7 +12,11 @@ import { } from '@/app/components/base/icons/src/vender/solid/general' import Tooltip from '@/app/components/base/tooltip' -const PriorityLabel = () => { +type PriorityLabelProps = { + className?: string +} + +const PriorityLabel = ({ className }: PriorityLabelProps) => { const { t } = useTranslation() const { plan } = useProviderContext() @@ -37,18 +42,18 @@ const PriorityLabel = () => { }
}> - + { plan.type === Plan.professional && ( - + ) } { (plan.type === Plan.team || plan.type === Plan.enterprise) && ( - + ) } {t(`billing.plansCommon.priority.${priority}`)} diff --git a/web/app/components/datasets/chunk.tsx b/web/app/components/datasets/chunk.tsx new file mode 100644 index 00000000000000..bf2835dbdbe183 --- /dev/null +++ b/web/app/components/datasets/chunk.tsx @@ -0,0 +1,54 @@ +import type { FC, PropsWithChildren } from 'react' +import { SelectionMod } from '../base/icons/src/public/knowledge' +import type { QA } from '@/models/datasets' + +export type ChunkLabelProps = { + label: string + characterCount: number +} + +export const ChunkLabel: FC = (props) => { + const { label, characterCount } = props + return
+ +

+ {label} + + + · + + + {`${characterCount} characters`} +

+
+} + +export type ChunkContainerProps = ChunkLabelProps & PropsWithChildren + +export const ChunkContainer: FC = (props) => { + const { label, characterCount, children } = props + return
+ +
+ {children} +
+
+} + +export type QAPreviewProps = { + qa: QA +} + +export const QAPreview: FC = (props) => { + const { qa } = props + return
+
+ +

{qa.question}

+
+
+ +

{qa.answer}

+
+
+} diff --git a/web/app/components/datasets/common/chunking-mode-label.tsx b/web/app/components/datasets/common/chunking-mode-label.tsx new file mode 100644 index 00000000000000..7c6e924009dd7b --- /dev/null +++ b/web/app/components/datasets/common/chunking-mode-label.tsx @@ -0,0 +1,29 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import Badge from '@/app/components/base/badge' +import { GeneralType, ParentChildType } from '@/app/components/base/icons/src/public/knowledge' + +type Props = { + isGeneralMode: boolean + isQAMode: boolean +} + +const ChunkingModeLabel: FC = ({ + isGeneralMode, + isQAMode, +}) => { + const { t } = useTranslation() + const TypeIcon = isGeneralMode ? GeneralType : ParentChildType + + return ( + +
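(Note, not part of the patch.) A hedged usage sketch of the helpers defined in chunk.tsx above; the component and prop names come from this file, the sample values are invented:

    <ChunkContainer label='Chunk-01' characterCount={980}>
      The chunk body is rendered as children under the label row.
    </ChunkContainer>

    // QAPreview takes a QA object and prints the question and answer rows,
    // matching the `qa.question` / `qa.answer` reads above.
    <QAPreview qa={{ question: 'What is a parent chunk?', answer: 'The larger unit a child chunk belongs to.' }} />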
+ + {isGeneralMode ? `${t('dataset.chunkingMode.general')}${isQAMode ? ' · QA' : ''}` : t('dataset.chunkingMode.parentChild')} +
+
+ ) +} +export default React.memo(ChunkingModeLabel) diff --git a/web/app/components/datasets/common/document-file-icon.tsx b/web/app/components/datasets/common/document-file-icon.tsx new file mode 100644 index 00000000000000..5842cbbc7c3fca --- /dev/null +++ b/web/app/components/datasets/common/document-file-icon.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import FileTypeIcon from '../../base/file-uploader/file-type-icon' +import type { FileAppearanceType } from '@/app/components/base/file-uploader/types' +import { FileAppearanceTypeEnum } from '@/app/components/base/file-uploader/types' + +const extendToFileTypeMap: { [key: string]: FileAppearanceType } = { + pdf: FileAppearanceTypeEnum.pdf, + json: FileAppearanceTypeEnum.document, + html: FileAppearanceTypeEnum.document, + txt: FileAppearanceTypeEnum.document, + markdown: FileAppearanceTypeEnum.markdown, + md: FileAppearanceTypeEnum.markdown, + xlsx: FileAppearanceTypeEnum.excel, + xls: FileAppearanceTypeEnum.excel, + csv: FileAppearanceTypeEnum.excel, + doc: FileAppearanceTypeEnum.word, + docx: FileAppearanceTypeEnum.word, +} + +type Props = { + extension?: string + name?: string + size?: 'sm' | 'lg' | 'md' + className?: string +} + +const DocumentFileIcon: FC = ({ + extension, + name, + size = 'md', + className, +}) => { + const localExtension = extension?.toLowerCase() || name?.split('.')?.pop()?.toLowerCase() + return ( + + ) +} +export default React.memo(DocumentFileIcon) diff --git a/web/app/components/datasets/common/document-picker/document-list.tsx b/web/app/components/datasets/common/document-picker/document-list.tsx new file mode 100644 index 00000000000000..3e320d75073416 --- /dev/null +++ b/web/app/components/datasets/common/document-picker/document-list.tsx @@ -0,0 +1,42 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import FileIcon from '../document-file-icon' +import cn from '@/utils/classnames' +import type { DocumentItem } from '@/models/datasets' + +type Props = { + className?: string + list: DocumentItem[] + onChange: (value: DocumentItem) => void +} + +const DocumentList: FC = ({ + className, + list, + onChange, +}) => { + const handleChange = useCallback((item: DocumentItem) => { + return () => onChange(item) + }, [onChange]) + + return ( +
+ {list.map((item) => { + const { id, name, extension } = item + return ( +
+ +
{name}
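(Note, not part of the patch.) `handleChange` above is curried: it runs once per row during render and returns the actual click handler. Each row effectively receives:

    // handleChange(item) evaluates to this zero-argument handler:
    const onRowClick = () => onChange(item)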
+
+ ) + })} +
+ ) +} + +export default React.memo(DocumentList) diff --git a/web/app/components/datasets/common/document-picker/index.tsx b/web/app/components/datasets/common/document-picker/index.tsx new file mode 100644 index 00000000000000..30690fca007cef --- /dev/null +++ b/web/app/components/datasets/common/document-picker/index.tsx @@ -0,0 +1,118 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useState } from 'react' +import { useBoolean } from 'ahooks' +import { RiArrowDownSLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import FileIcon from '../document-file-icon' +import DocumentList from './document-list' +import type { DocumentItem, ParentMode, SimpleDocumentDetail } from '@/models/datasets' +import { ProcessMode } from '@/models/datasets' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import cn from '@/utils/classnames' +import SearchInput from '@/app/components/base/search-input' +import { GeneralType, ParentChildType } from '@/app/components/base/icons/src/public/knowledge' +import { useDocumentList } from '@/service/knowledge/use-document' +import Loading from '@/app/components/base/loading' + +type Props = { + datasetId: string + value: { + name?: string + extension?: string + processMode?: ProcessMode + parentMode?: ParentMode + } + onChange: (value: SimpleDocumentDetail) => void +} + +const DocumentPicker: FC = ({ + datasetId, + value, + onChange, +}) => { + const { t } = useTranslation() + const { + name, + extension, + processMode, + parentMode, + } = value + const [query, setQuery] = useState('') + + const { data } = useDocumentList({ + datasetId, + query: { + keyword: query, + page: 1, + limit: 20, + }, + }) + const documentsList = data?.data + const isParentChild = processMode === ProcessMode.parentChild + const TypeIcon = isParentChild ? ParentChildType : GeneralType + + const [open, { + set: setOpen, + toggle: togglePopup, + }] = useBoolean(false) + const ArrowIcon = RiArrowDownSLine + + const handleChange = useCallback(({ id }: DocumentItem) => { + onChange(documentsList?.find(item => item.id === id) as SimpleDocumentDetail) + setOpen(false) + }, [documentsList, onChange, setOpen]) + + return ( + + +
+ +
+
+ {name || '--'} + +
+
+ + + {isParentChild ? t('dataset.chunkingMode.parentChild') : t('dataset.chunkingMode.general')} + {isParentChild && ` · ${!parentMode ? '--' : parentMode === 'paragraph' ? t('dataset.parentMode.paragraph') : t('dataset.parentMode.fullDoc')}`} + +
+
+
+
+ +
+ + {documentsList + ? ( + ({ + id: d.id, + name: d.name, + extension: d.data_source_detail_dict?.upload_file?.extension || '', + }))} + onChange={handleChange} + /> + ) + : (
+ +
)} +
+ +
+
+ ) +} +export default React.memo(DocumentPicker) diff --git a/web/app/components/datasets/common/document-picker/preview-document-picker.tsx b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx new file mode 100644 index 00000000000000..2a35b75471cde8 --- /dev/null +++ b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx @@ -0,0 +1,82 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useBoolean } from 'ahooks' +import { RiArrowDownSLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import FileIcon from '../document-file-icon' +import DocumentList from './document-list' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import cn from '@/utils/classnames' +import Loading from '@/app/components/base/loading' +import type { DocumentItem } from '@/models/datasets' + +type Props = { + className?: string + value: DocumentItem + files: DocumentItem[] + onChange: (value: DocumentItem) => void +} + +const PreviewDocumentPicker: FC = ({ + className, + value, + files, + onChange, +}) => { + const { t } = useTranslation() + const { name, extension } = value + + const [open, { + set: setOpen, + toggle: togglePopup, + }] = useBoolean(false) + const ArrowIcon = RiArrowDownSLine + + const handleChange = useCallback((item: DocumentItem) => { + onChange(item) + setOpen(false) + }, [onChange, setOpen]) + + return ( + + +
+ +
+
+ {name || '--'} + +
+
+
+
+ +
+ {files?.length > 1 &&
{t('dataset.preprocessDocument', { num: files.length })}
} + {files?.length > 0 + ? ( + + ) + : (
+ +
)} +
+ +
+
+ ) +} +export default React.memo(PreviewDocumentPicker) diff --git a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx new file mode 100644 index 00000000000000..b687c004e5c5cd --- /dev/null +++ b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx @@ -0,0 +1,38 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import StatusWithAction from './status-with-action' +import { useAutoDisabledDocuments, useDocumentEnable, useInvalidDisabledDocument } from '@/service/knowledge/use-document' +import Toast from '@/app/components/base/toast' +type Props = { + datasetId: string +} + +const AutoDisabledDocument: FC = ({ + datasetId, +}) => { + const { t } = useTranslation() + const { data, isLoading } = useAutoDisabledDocuments(datasetId) + const invalidDisabledDocument = useInvalidDisabledDocument() + const documentIds = data?.document_ids + const hasDisabledDocument = documentIds && documentIds.length > 0 + const { mutateAsync: enableDocument } = useDocumentEnable() + const handleEnableDocuments = useCallback(async () => { + await enableDocument({ datasetId, documentIds }) + invalidDisabledDocument() + Toast.notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + }, []) + if (!hasDisabledDocument || isLoading) + return null + + return ( + + ) +} +export default React.memo(AutoDisabledDocument) diff --git a/web/app/components/datasets/common/document-status-with-action/index-failed.tsx b/web/app/components/datasets/common/document-status-with-action/index-failed.tsx new file mode 100644 index 00000000000000..37311768b95756 --- /dev/null +++ b/web/app/components/datasets/common/document-status-with-action/index-failed.tsx @@ -0,0 +1,69 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useReducer } from 'react' +import { useTranslation } from 'react-i18next' +import useSWR from 'swr' +import StatusWithAction from './status-with-action' +import { getErrorDocs, retryErrorDocs } from '@/service/datasets' +import type { IndexingStatusResponse } from '@/models/datasets' + +type Props = { + datasetId: string +} +type IIndexState = { + value: string +} +type ActionType = 'retry' | 'success' | 'error' + +type IAction = { + type: ActionType +} +const indexStateReducer = (state: IIndexState, action: IAction) => { + const actionMap = { + retry: 'retry', + success: 'success', + error: 'error', + } + + return { + ...state, + value: actionMap[action.type] || state.value, + } +} + +const RetryButton: FC = ({ datasetId }) => { + const { t } = useTranslation() + const [indexState, dispatch] = useReducer(indexStateReducer, { value: 'success' }) + const { data: errorDocs, isLoading } = useSWR({ datasetId }, getErrorDocs) + + const onRetryErrorDocs = async () => { + dispatch({ type: 'retry' }) + const document_ids = errorDocs?.data.map((doc: IndexingStatusResponse) => doc.id) || [] + const res = await retryErrorDocs({ datasetId, document_ids }) + if (res.result === 'success') + dispatch({ type: 'success' }) + else + dispatch({ type: 'error' }) + } + + useEffect(() => { + if (errorDocs?.total === 0) + dispatch({ type: 'success' }) + else + dispatch({ type: 'error' }) + }, [errorDocs?.total]) + + if (isLoading || indexState.value === 'success') + return null + + return ( + { }} + /> + ) +} +export default 
RetryButton diff --git a/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx new file mode 100644 index 00000000000000..a8da9bf6cc2a8b --- /dev/null +++ b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx @@ -0,0 +1,65 @@ +'use client' +import { RiAlertFill, RiCheckboxCircleFill, RiErrorWarningFill, RiInformation2Fill } from '@remixicon/react' +import type { FC } from 'react' +import React from 'react' +import cn from '@/utils/classnames' +import Divider from '@/app/components/base/divider' + +type Status = 'success' | 'error' | 'warning' | 'info' +type Props = { + type?: Status + description: string + actionText: string + onAction: () => void + disabled?: boolean +} + +const IconMap = { + success: { + Icon: RiCheckboxCircleFill, + color: 'text-text-success', + }, + error: { + Icon: RiErrorWarningFill, + color: 'text-text-destructive', + }, + warning: { + Icon: RiAlertFill, + color: 'text-text-warning-secondary', + }, + info: { + Icon: RiInformation2Fill, + color: 'text-text-accent', + }, +} + +const getIcon = (type: Status) => { + return IconMap[type] +} + +const StatusAction: FC = ({ + type = 'info', + description, + actionText, + onAction, + disabled, +}) => { + const { Icon, color } = getIcon(type) + return ( +
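(Note, not part of the patch.) index-failed.tsx above is the old retry-button logic re-hosted on the shared StatusWithAction shell that follows. This copy of the patch dropped the JSX attributes, so the call below is a hedged reconstruction from StatusWithAction's prop types, not a verbatim quote:

    <StatusWithAction
      type='warning'
      description={`${errorDocs?.total} ${t('dataset.docsFailedNotice')}`}
      actionText={t('dataset.retry')}
      onAction={indexState.value === 'error' ? onRetryErrorDocs : () => { }}
    />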
+
+
+ +
{description}
+ +
{actionText}
+
+
+ ) +} +export default React.memo(StatusAction) diff --git a/web/app/components/datasets/common/economical-retrieval-method-config/index.tsx b/web/app/components/datasets/common/economical-retrieval-method-config/index.tsx index f3da67b92cc5e9..9236858ae4c906 100644 --- a/web/app/components/datasets/common/economical-retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/economical-retrieval-method-config/index.tsx @@ -2,10 +2,11 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' +import Image from 'next/image' import RetrievalParamConfig from '../retrieval-param-config' +import { OptionCard } from '../../create/step-two/option-card' +import { retrievalIcon } from '../../create/icons' import { RETRIEVE_METHOD } from '@/types/app' -import RadioCard from '@/app/components/base/radio-card' -import { HighPriority } from '@/app/components/base/icons/src/vender/solid/arrows' import type { RetrievalConfig } from '@/types/app' type Props = { @@ -21,19 +22,17 @@ const EconomicalRetrievalMethodConfig: FC = ({ return (
- } + } title={t('dataset.retrieval.invertedIndex.title')} - description={t('dataset.retrieval.invertedIndex.description')} - noRadio - chosenConfig={ - - } - /> + description={t('dataset.retrieval.invertedIndex.description')} isActive + activeHeaderClassName='bg-dataset-option-card-purple-gradient' + > + +
) } diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index 20d93568addbb7..9ab157571b5b66 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -2,12 +2,13 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' +import Image from 'next/image' import RetrievalParamConfig from '../retrieval-param-config' +import { OptionCard } from '../../create/step-two/option-card' +import Effect from '../../create/assets/option-card-effect-purple.svg' +import { retrievalIcon } from '../../create/icons' import type { RetrievalConfig } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app' -import RadioCard from '@/app/components/base/radio-card' -import { PatternRecognition, Semantic } from '@/app/components/base/icons/src/vender/solid/development' -import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files' import { useProviderContext } from '@/context/provider-context' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -16,6 +17,7 @@ import { RerankingModeEnum, WeightedScoreEnum, } from '@/models/datasets' +import Badge from '@/app/components/base/badge' type Props = { value: RetrievalConfig @@ -56,67 +58,72 @@ const RetrievalMethodConfig: FC = ({ return (
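(Note, not part of the patch.) The pattern in this hunk and in the retrieval-method-config hunk below: the old `RadioCard` + `chosenConfig` pairing gives way to the new `OptionCard` from create/step-two/option-card.tsx, which takes the parameter form as `children` and marks selection via `isActive` / `activeHeaderClassName`. A hedged sketch of the shape; the icon JSX is a reconstruction, since this copy of the patch dropped it:

    <OptionCard
      icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
      title={t('dataset.retrieval.invertedIndex.title')}
      description={t('dataset.retrieval.invertedIndex.description')}
      isActive
      activeHeaderClassName='bg-dataset-option-card-purple-gradient'
    >
      <RetrievalParamConfig
        type={RETRIEVE_METHOD.invertedIndex}
        value={value}
        onChange={onChange}
      />
    </OptionCard>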
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( - } + } title={t('dataset.retrieval.semantic_search.title')} description={t('dataset.retrieval.semantic_search.description')} - isChosen={value.search_method === RETRIEVE_METHOD.semantic} - onChosen={() => onChange({ + isActive={ + value.search_method === RETRIEVE_METHOD.semantic + } + onSwitched={() => onChange({ ...value, search_method: RETRIEVE_METHOD.semantic, })} - chosenConfig={ - - } - /> + effectImg={Effect.src} + activeHeaderClassName='bg-dataset-option-card-purple-gradient' + > + + )} {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( - } + } title={t('dataset.retrieval.full_text_search.title')} description={t('dataset.retrieval.full_text_search.description')} - isChosen={value.search_method === RETRIEVE_METHOD.fullText} - onChosen={() => onChange({ + isActive={ + value.search_method === RETRIEVE_METHOD.fullText + } + onSwitched={() => onChange({ ...value, search_method: RETRIEVE_METHOD.fullText, })} - chosenConfig={ - - } - /> + effectImg={Effect.src} + activeHeaderClassName='bg-dataset-option-card-purple-gradient' + > + + )} {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( - } + } title={
{t('dataset.retrieval.hybrid_search.title')}
-
{t('dataset.retrieval.hybrid_search.recommend')}
+
} - description={t('dataset.retrieval.hybrid_search.description')} - isChosen={value.search_method === RETRIEVE_METHOD.hybrid} - onChosen={() => onChange({ + description={t('dataset.retrieval.hybrid_search.description')} isActive={ + value.search_method === RETRIEVE_METHOD.hybrid + } + onSwitched={() => onChange({ ...value, search_method: RETRIEVE_METHOD.hybrid, reranking_enable: true, })} - chosenConfig={ - - } - /> + effectImg={Effect.src} + activeHeaderClassName='bg-dataset-option-card-purple-gradient' + > + +
)}
) diff --git a/web/app/components/datasets/common/retrieval-method-info/index.tsx b/web/app/components/datasets/common/retrieval-method-info/index.tsx index 7d9b999c53941f..fc3020d4a965a2 100644 --- a/web/app/components/datasets/common/retrieval-method-info/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-info/index.tsx @@ -2,12 +2,11 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' +import Image from 'next/image' +import { retrievalIcon } from '../../create/icons' import type { RetrievalConfig } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app' import RadioCard from '@/app/components/base/radio-card' -import { HighPriority } from '@/app/components/base/icons/src/vender/solid/arrows' -import { PatternRecognition, Semantic } from '@/app/components/base/icons/src/vender/solid/development' -import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files' type Props = { value: RetrievalConfig @@ -15,11 +14,12 @@ type Props = { export const getIcon = (type: RETRIEVE_METHOD) => { return ({ - [RETRIEVE_METHOD.semantic]: Semantic, - [RETRIEVE_METHOD.fullText]: FileSearch02, - [RETRIEVE_METHOD.hybrid]: PatternRecognition, - [RETRIEVE_METHOD.invertedIndex]: HighPriority, - })[type] || FileSearch02 + [RETRIEVE_METHOD.semantic]: retrievalIcon.vector, + [RETRIEVE_METHOD.fullText]: retrievalIcon.fullText, + [RETRIEVE_METHOD.hybrid]: retrievalIcon.hybrid, + [RETRIEVE_METHOD.invertedIndex]: retrievalIcon.vector, + [RETRIEVE_METHOD.keywordSearch]: retrievalIcon.vector, + })[type] || retrievalIcon.vector } const EconomicalRetrievalMethodConfig: FC = ({ @@ -28,11 +28,11 @@ const EconomicalRetrievalMethodConfig: FC = ({ }) => { const { t } = useTranslation() const type = value.search_method - const Icon = getIcon(type) + const icon = return (
} + icon={icon} title={t(`dataset.retrieval.${type}.title`)} description={t(`dataset.retrieval.${type}.description`)} noRadio diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 9d48d56a8dc511..5136ac1659159d 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -3,6 +3,9 @@ import type { FC } from 'react' import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' +import Image from 'next/image' +import ProgressIndicator from '../../create/assets/progress-indicator.svg' +import Reranking from '../../create/assets/rerank.svg' import cn from '@/utils/classnames' import TopKItem from '@/app/components/base/param-item/top-k-item' import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' @@ -20,6 +23,7 @@ import { } from '@/models/datasets' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' import Toast from '@/app/components/base/toast' +import RadioCard from '@/app/components/base/radio-card' type Props = { type: RETRIEVE_METHOD @@ -116,7 +120,7 @@ const RetrievalParamConfig: FC = ({
{!isEconomical && !isHybridSearch && (
-
+
{canToggleRerankModalEnable && (
= ({
)}
- {t('common.modelProvider.rerankModel.key')} + {t('common.modelProvider.rerankModel.key')} {t('common.modelProvider.rerankModel.tip')}
@@ -163,7 +167,7 @@ const RetrievalParamConfig: FC = ({ )} { !isHybridSearch && ( -
+
= ({ { isHybridSearch && ( <> -
+
{ rerankingModeOptions.map(option => ( -
handleChangeRerankMode(option.value)} - > -
{option.label}
- {option.tips}
} - triggerClassName='ml-0.5 w-3.5 h-3.5' - /> -
+ isChosen={value.reranking_mode === option.value} + onChosen={() => handleChangeRerankMode(option.value)} + icon={} + title={option.label} + description={option.tips} + className='flex-1' + /> )) }
[This patch adds 14 new SVG assets under web/app/components/datasets/create/assets/: family-mod.svg, file-list-3-fill.svg, gold.svg, note-mod.svg, option-card-effect-blue.svg, option-card-effect-orange.svg, option-card-effect-purple.svg, pattern-recognition-mod.svg, piggy-bank-mod.svg, progress-indicator.svg, rerank.svg, research-mod.svg, selection-mod.svg, setting-gear-mod.svg. The SVG markup itself did not survive extraction; only the file names are recoverable, so the per-file diffs are omitted here.]
diff --git a/web/app/components/datasets/create/embedding-process/index.module.css b/web/app/components/datasets/create/embedding-process/index.module.css index 1ebb006b543ac2..f2ab4d85a27210 100644 --- a/web/app/components/datasets/create/embedding-process/index.module.css +++ b/web/app/components/datasets/create/embedding-process/index.module.css @@ -14,24 +14,7 @@ border-radius: 6px; overflow: hidden; } -.sourceItem.error { - background: #FEE4E2; -} -.sourceItem.success { - background: #D1FADF; -} -.progressbar { - position: absolute; - top: 0; - left: 0; - height: 100%; - background-color: #B2CCFF; -} -.sourceItem .info { - display: flex; - align-items: center; - z-index: 1; -} + .sourceItem .info .name { font-weight: 500; font-size: 12px; @@ -55,13 +38,6 @@ color: #05603A; } - -.cost { - @apply flex justify-between items-center text-xs text-gray-700; -} -.embeddingStatus { - @apply flex items-center justify-between text-gray-900 font-medium text-sm mr-2; -} .commonIcon { @apply w-3 h-3 mr-1 inline-block align-middle; } @@ -81,35 +57,33 @@ @apply text-xs font-medium; } -.fileIcon { - @apply w-4 h-4 mr-1 bg-center bg-no-repeat; +.unknownFileIcon { background-image: url(../assets/unknown.svg); - background-size: 16px; } -.fileIcon.csv { +.csv { background-image: url(../assets/csv.svg); } -.fileIcon.docx { +.docx { background-image: url(../assets/docx.svg); } -.fileIcon.xlsx, -.fileIcon.xls { +.xlsx, +.xls { background-image: url(../assets/xlsx.svg); } -.fileIcon.pdf { +.pdf { background-image: url(../assets/pdf.svg); } -.fileIcon.html, -.fileIcon.htm { +.html, +.htm { background-image: url(../assets/html.svg); } -.fileIcon.md, -.fileIcon.markdown { +.md, +.markdown { background-image: url(../assets/md.svg); } -.fileIcon.txt { +.txt { background-image: url(../assets/txt.svg); } -.fileIcon.json { +.json { background-image: url(../assets/json.svg); } diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index 7786582085c16d..201333ffce4cbe 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -6,32 +6,44 @@ import { useTranslation } from 'react-i18next' import { omit } from 'lodash-es' import { ArrowRightIcon } from '@heroicons/react/24/solid' import { + RiCheckboxCircleFill, RiErrorWarningFill, + RiLoader2Fill, + RiTerminalBoxLine, } from '@remixicon/react' -import s from './index.module.css' +import Image
from 'next/image' +import { indexMethodIcon, retrievalIcon } from '../icons' +import { IndexingType } from '../step-two' +import DocumentFileIcon from '../../common/document-file-icon' import cn from '@/utils/classnames' import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata' import Button from '@/app/components/base/button' import type { FullDocumentDetail, IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchProcessRule } from '@/service/datasets' -import { DataSourceType } from '@/models/datasets' +import { DataSourceType, ProcessMode } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import PriorityLabel from '@/app/components/billing/priority-label' import { Plan } from '@/app/components/billing/type' import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general' import UpgradeBtn from '@/app/components/billing/upgrade-btn' import { useProviderContext } from '@/context/provider-context' -import Tooltip from '@/app/components/base/tooltip' import { sleep } from '@/utils' +import { RETRIEVE_METHOD } from '@/types/app' +import Tooltip from '@/app/components/base/tooltip' type Props = { datasetId: string batchId: string documents?: FullDocumentDetail[] indexingType?: string + retrievalMethod?: string } -const RuleDetail: FC<{ sourceData?: ProcessRuleResponse }> = ({ sourceData }) => { +const RuleDetail: FC<{ + sourceData?: ProcessRuleResponse + indexingType?: string + retrievalMethod?: string +}> = ({ sourceData, indexingType, retrievalMethod }) => { const { t } = useTranslation() const segmentationRuleMap = { @@ -51,29 +63,47 @@ const RuleDetail: FC<{ sourceData?: ProcessRuleResponse }> = ({ sourceData }) => return t('datasetCreation.stepTwo.removeStopwords') } + const isNumber = (value: unknown) => { + return typeof value === 'number' + } + const getValue = useCallback((field: string) => { let value: string | number | undefined = '-' + const maxTokens = isNumber(sourceData?.rules?.segmentation?.max_tokens) + ? sourceData.rules.segmentation.max_tokens + : value + const childMaxTokens = isNumber(sourceData?.rules?.subchunk_segmentation?.max_tokens) + ? sourceData.rules.subchunk_segmentation.max_tokens + : value switch (field) { case 'mode': - value = sourceData?.mode === 'automatic' ? (t('datasetDocuments.embedding.automatic') as string) : (t('datasetDocuments.embedding.custom') as string) + value = !sourceData?.mode + ? value + : sourceData.mode === ProcessMode.general + ? (t('datasetDocuments.embedding.custom') as string) + : `${t('datasetDocuments.embedding.hierarchical')} · ${sourceData?.rules?.parent_mode === 'paragraph' + ? t('dataset.parentMode.paragraph') + : t('dataset.parentMode.fullDoc')}` break case 'segmentLength': - value = sourceData?.rules?.segmentation?.max_tokens + value = !sourceData?.mode + ? value + : sourceData.mode === ProcessMode.general + ? maxTokens + : `${t('datasetDocuments.embedding.parentMaxTokens')} ${maxTokens}; ${t('datasetDocuments.embedding.childMaxTokens')} ${childMaxTokens}` break default: - value = sourceData?.mode === 'automatic' - ? (t('datasetDocuments.embedding.automatic') as string) - // eslint-disable-next-line array-callback-return - : sourceData?.rules?.pre_processing_rules?.map((rule) => { - if (rule.enabled) - return getRuleName(rule.id) - }).filter(Boolean).join(';') + value = !sourceData?.mode + ? 
value + : sourceData?.rules?.pre_processing_rules?.filter(rule => + rule.enabled).map(rule => getRuleName(rule.id)).join(',') break } return value + // eslint-disable-next-line react-hooks/exhaustive-deps }, [sourceData]) - return
+ return
{Object.keys(segmentationRuleMap).map((field) => { return = ({ sourceData }) => displayedValue={String(getValue(field))} /> })} + + } + /> + + } + />
} -const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], indexingType }) => { +const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], indexingType, retrievalMethod }) => { const { t } = useTranslation() const { enableBilling, plan } = useProviderContext() @@ -127,6 +190,7 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index } useEffect(() => { + setIsStopQuery(false) startQueryStatus() return () => { stopQueryStatus() @@ -146,6 +210,9 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index const navToDocumentList = () => { router.push(`/datasets/${datasetId}/documents`) } + const navToApiDocs = () => { + router.push('/datasets?category=api') + } const isEmbedding = useMemo(() => { return indexingStatusBatchDetail.some(indexingStatusDetail => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || '')) @@ -177,13 +244,17 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index return doc?.data_source_info.notion_page_icon } - const isSourceEmbedding = (detail: IndexingStatusResponse) => ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'].includes(detail.indexing_status || '') + const isSourceEmbedding = (detail: IndexingStatusResponse) => + ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'].includes(detail.indexing_status || '') return ( <> -
-
- {isEmbedding && t('datasetDocuments.embedding.processing')} +
+
+ {isEmbedding &&
+ + {t('datasetDocuments.embedding.processing')} +
} {isEmbeddingCompleted && t('datasetDocuments.embedding.completed')}
@@ -200,69 +271,80 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index
) } -
+
{indexingStatusBatchDetail.map(indexingStatusDetail => (
{isSourceEmbedding(indexingStatusDetail) && ( -
+
)} -
+
{getSourceType(indexingStatusDetail.id) === DataSourceType.FILE && ( -
+ //
+ )} {getSourceType(indexingStatusDetail.id) === DataSourceType.NOTION && ( )} -
{getSourceName(indexingStatusDetail.id)}
- { - enableBilling && ( - - ) - } -
-
+
+
+ {getSourceName(indexingStatusDetail.id)} +
+ { + enableBilling && ( + + ) + } +
{isSourceEmbedding(indexingStatusDetail) && ( -
{`${getSourcePercent(indexingStatusDetail)}%`}
+
{`${getSourcePercent(indexingStatusDetail)}%`}
)} - {indexingStatusDetail.indexing_status === 'error' && indexingStatusDetail.error && ( + {indexingStatusDetail.indexing_status === 'error' && ( - {indexingStatusDetail.error} -
- )} + popupClassName='px-4 py-[14px] max-w-60 text-sm leading-4 text-text-secondary border-[0.5px] border-components-panel-border rounded-xl' + offset={4} + popupContent={indexingStatusDetail.error} > -
- Error - -
+ + + )} - {indexingStatusDetail.indexing_status === 'error' && !indexingStatusDetail.error && ( -
- Error -
- )} {indexingStatusDetail.indexing_status === 'completed' && ( -
100%
+ )}
))}
- -
+
+ +
+
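The embedding-process change above adds `setIsStopQuery(false)` at the top of the effect, so re-mounting the component restarts status polling instead of inheriting a stale stop flag. A minimal standalone sketch of that polling pattern, assuming the response shape `{ data: IndexingStatusResponse[] }` and a fixed 2.5 s delay (both assumptions; the real component also tracks the stop flag in a ref and re-renders on each response):

```ts
// Hypothetical condensed version of the EmbeddingProcess polling loop.
// fetchIndexingStatusBatch and sleep are the imports shown in the diff above;
// the call signature and response shape used here are assumptions.
import { fetchIndexingStatusBatch } from '@/service/datasets'
import type { IndexingStatusResponse } from '@/models/datasets'
import { sleep } from '@/utils'

const ACTIVE_STATES = ['waiting', 'parsing', 'cleaning', 'splitting', 'indexing']

async function pollIndexingStatus(
  datasetId: string,
  batchId: string,
  onUpdate: (details: IndexingStatusResponse[]) => void,
  isStopped: () => boolean, // stands in for the component's isStopQuery flag
): Promise<void> {
  while (!isStopped()) {
    const { data } = await fetchIndexingStatusBatch({ datasetId, batchId })
    onUpdate(data)
    // Keep polling only while at least one document is still in flight.
    if (!data.some(d => ACTIVE_STATES.includes(d.indexing_status)))
      return
    await sleep(2500) // assumed interval
  }
}
```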
diff --git a/web/app/components/datasets/create/file-preview/index.module.css b/web/app/components/datasets/create/file-preview/index.module.css index d87522e6d0bd4c..929002e1e2985c 100644 --- a/web/app/components/datasets/create/file-preview/index.module.css +++ b/web/app/components/datasets/create/file-preview/index.module.css @@ -1,6 +1,6 @@ .filePreview { @apply flex flex-col border-l border-gray-200 shrink-0; - width: 528px; + width: 100%; background-color: #fcfcfd; } @@ -48,5 +48,6 @@ } .fileContent { white-space: pre-line; + word-break: break-all; } \ No newline at end of file diff --git a/web/app/components/datasets/create/file-preview/index.tsx b/web/app/components/datasets/create/file-preview/index.tsx index e20af64386c509..cb1f1d6908c4ef 100644 --- a/web/app/components/datasets/create/file-preview/index.tsx +++ b/web/app/components/datasets/create/file-preview/index.tsx @@ -44,7 +44,7 @@ const FilePreview = ({ }, [file]) return ( -
+
{t('datasetCreation.stepOne.filePreview')} @@ -59,7 +59,7 @@ const FilePreview = ({
{loading &&
} {!loading && ( -
{previewContent}
+
{previewContent}
)}
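The file-uploader rewrite below keeps the existing validation flow (size limit from `fetchFileUploadConfig`, allowed extensions from `fetchSupportFileTypes`) while swapping the CSS-module styling for Tailwind classes. A condensed sketch of that kind of pre-upload check; the helper, its config shape, and the error strings are illustrative, not the component's actual `isValid` implementation:

```ts
// Illustrative pre-upload validation in the spirit of the uploader's checks.
// Field names below are assumptions standing in for the fetched config values.
type UploadConfig = {
  fileSizeLimitMB: number // e.g. derived from fetchFileUploadConfig
  allowedExtensions: string[] // e.g. derived from fetchSupportFileTypes
}

function validateFile(file: File, config: UploadConfig): string | null {
  const ext = file.name.split('.').pop()?.toLowerCase() ?? ''
  if (!config.allowedExtensions.includes(ext))
    return `Unsupported file type: .${ext}`
  if (file.size > config.fileSizeLimitMB * 1024 * 1024)
    return `File exceeds the ${config.fileSizeLimitMB}MB limit`
  return null // valid
}
```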
diff --git a/web/app/components/datasets/create/file-uploader/index.module.css b/web/app/components/datasets/create/file-uploader/index.module.css index bf5b7dcaf5b9b7..7d29f2ef9c2b51 100644 --- a/web/app/components/datasets/create/file-uploader/index.module.css +++ b/web/app/components/datasets/create/file-uploader/index.module.css @@ -1,68 +1,3 @@ -.fileUploader { - @apply mb-6; -} - -.fileUploader .title { - @apply mb-2; - font-weight: 500; - font-size: 16px; - line-height: 24px; - color: #344054; -} - -.fileUploader .tip { - font-weight: 400; - font-size: 12px; - line-height: 18px; - color: #667085; -} - -.uploader { - @apply relative box-border flex justify-center items-center mb-2 p-3; - flex-direction: column; - max-width: 640px; - min-height: 80px; - background: #F9FAFB; - border: 1px dashed #EAECF0; - border-radius: 12px; - font-weight: 400; - font-size: 14px; - line-height: 20px; - color: #667085; -} - -.uploader.dragging { - background: #F5F8FF; - border: 1px dashed #B2CCFF; -} - -.uploader .draggingCover { - position: absolute; - top: 0; - left: 0; - width: 100%; - height: 100%; -} - -.uploader .uploadIcon { - content: ''; - display: block; - margin-right: 8px; - width: 24px; - height: 24px; - background: center no-repeat url(../assets/upload-cloud-01.svg); - background-size: contain; -} - -.uploader .browse { - @apply pl-1 cursor-pointer; - color: #155eef; -} - -.fileList { - @apply space-y-2; -} - .file { @apply box-border relative flex items-center justify-between; padding: 8px 12px 8px 8px; @@ -193,4 +128,4 @@ .file:hover .actionWrapper .remove { display: block; -} \ No newline at end of file +} diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index adb4bed0d167ed..e42a24cfef52d3 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -3,10 +3,12 @@ import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import useSWR from 'swr' -import s from './index.module.css' +import { RiDeleteBinLine, RiUploadCloud2Line } from '@remixicon/react' +import DocumentFileIcon from '../../common/document-file-icon' import cn from '@/utils/classnames' import type { CustomFile as File, FileItem } from '@/models/datasets' import { ToastContext } from '@/app/components/base/toast' +import SimplePieChart from '@/app/components/base/simple-pie-chart' import { upload } from '@/service/base' import { fetchFileUploadConfig } from '@/service/common' @@ -14,6 +16,8 @@ import { fetchSupportFileTypes } from '@/service/datasets' import I18n from '@/context/i18n' import { LanguagesSupported } from '@/i18n/language' import { IS_CE_EDITION } from '@/config' +import { useAppContext } from '@/context/app-context' +import { Theme } from '@/types/app' const FILES_NUMBER_LIMIT = 20 @@ -222,6 +226,9 @@ const FileUploader = ({ initialUpload(files.filter(isValid)) }, [isValid, initialUpload]) + const { theme } = useAppContext() + const chartColor = useMemo(() => theme === Theme.dark ? '#5289ff' : '#296dff', [theme]) + useEffect(() => { dropRef.current?.addEventListener('dragenter', handleDragEnter) dropRef.current?.addEventListener('dragover', handleDragOver) @@ -236,12 +243,12 @@ const FileUploader = ({ }, [handleDrop]) return ( -
+
{!hideUpload && ( )} -
{t('datasetCreation.stepOne.uploader.title')}
+
{t('datasetCreation.stepOne.uploader.title')}
+ {!hideUpload && ( +
+
+ -
-
- {t('datasetCreation.stepOne.uploader.button')} - + {supportTypes.length > 0 && ( + + )}
-
{t('datasetCreation.stepOne.uploader.tip', { +
{t('datasetCreation.stepOne.uploader.tip', { size: fileUploadConfig.file_size_limit, supportTypes: supportTypesShowNames, })}
- {dragging &&
} + {dragging &&
}
)} -
+
+ {fileList.map((fileItem, index) => (
fileItem.file?.id && onPreview(fileItem.file)} className={cn( - s.file, - fileItem.progress < 100 && s.uploading, + 'flex items-center h-12 max-w-[640px] bg-components-panel-on-panel-item-bg text-xs leading-3 text-text-tertiary border border-components-panel-border rounded-lg shadow-xs', + // 'border-state-destructive-border bg-state-destructive-hover', )} > - {fileItem.progress < 100 && ( -
- )} -
-
-
{fileItem.file.name}
-
{getFileSize(fileItem.file.size)}
+
+ +
+
+
+
{fileItem.file.name}
+
+
+ {getFileType(fileItem.file)} + · + {getFileSize(fileItem.file.size)} + {/* · + 10k characters */} +
-
+
+ {/* + + */} {(fileItem.progress < 100 && fileItem.progress >= 0) && ( -
{`${fileItem.progress}%`}
- )} - {fileItem.progress === 100 && ( -
{ - e.stopPropagation() - removeFile(fileItem.fileID) - }} /> + //
{`${fileItem.progress}%`}
+ )} + { + e.stopPropagation() + removeFile(fileItem.fileID) + }}> + +
))} diff --git a/web/app/components/datasets/create/icons.ts b/web/app/components/datasets/create/icons.ts new file mode 100644 index 00000000000000..80c4b6c944778b --- /dev/null +++ b/web/app/components/datasets/create/icons.ts @@ -0,0 +1,16 @@ +import GoldIcon from './assets/gold.svg' +import Piggybank from './assets/piggy-bank-mod.svg' +import Selection from './assets/selection-mod.svg' +import Research from './assets/research-mod.svg' +import PatternRecognition from './assets/pattern-recognition-mod.svg' + +export const indexMethodIcon = { + high_quality: GoldIcon, + economical: Piggybank, +} + +export const retrievalIcon = { + vector: Selection, + fullText: Research, + hybrid: PatternRecognition, +} diff --git a/web/app/components/datasets/create/index.tsx b/web/app/components/datasets/create/index.tsx index 98098445c7695c..9556b9fad5780e 100644 --- a/web/app/components/datasets/create/index.tsx +++ b/web/app/components/datasets/create/index.tsx @@ -3,10 +3,10 @@ import React, { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '../../base/app-unavailable' import { ModelTypeEnum } from '../../header/account-setting/model-provider-page/declarations' -import StepsNavBar from './steps-nav-bar' import StepOne from './step-one' import StepTwo from './step-two' import StepThree from './step-three' +import { Topbar } from './top-bar' import { DataSourceType } from '@/models/datasets' import type { CrawlOptions, CrawlResultItem, DataSet, FileItem, createDocumentResponse } from '@/models/datasets' import { fetchDataSource } from '@/service/common' @@ -36,6 +36,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { const [dataSourceType, setDataSourceType] = useState(DataSourceType.FILE) const [step, setStep] = useState(1) const [indexingTypeCache, setIndexTypeCache] = useState('') + const [retrievalMethodCache, setRetrievalMethodCache] = useState('') const [fileList, setFiles] = useState([]) const [result, setResult] = useState() const [hasError, setHasError] = useState(false) @@ -80,6 +81,9 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { const updateResultCache = (res?: createDocumentResponse) => { setResult(res) } + const updateRetrievalMethodCache = (method: string) => { + setRetrievalMethodCache(method) + } const nextStep = useCallback(() => { setStep(step + 1) @@ -118,33 +122,29 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { return return ( -
-
- -
-
-
- setShowAccountSettingModal({ payload: 'data-source' })} - datasetId={datasetId} - dataSourceType={dataSourceType} - dataSourceTypeDisable={!!detail?.data_source_type} - changeType={setDataSourceType} - files={fileList} - updateFile={updateFile} - updateFileList={updateFileList} - notionPages={notionPages} - updateNotionPages={updateNotionPages} - onStepChange={nextStep} - websitePages={websitePages} - updateWebsitePages={setWebsitePages} - onWebsiteCrawlProviderChange={setWebsiteCrawlProvider} - onWebsiteCrawlJobIdChange={setWebsiteCrawlJobId} - crawlOptions={crawlOptions} - onCrawlOptionsChange={setCrawlOptions} - /> -
+
+ +
+ {step === 1 && setShowAccountSettingModal({ payload: 'data-source' })} + datasetId={datasetId} + dataSourceType={dataSourceType} + dataSourceTypeDisable={!!detail?.data_source_type} + changeType={setDataSourceType} + files={fileList} + updateFile={updateFile} + updateFileList={updateFileList} + notionPages={notionPages} + updateNotionPages={updateNotionPages} + onStepChange={nextStep} + websitePages={websitePages} + updateWebsitePages={setWebsitePages} + onWebsiteCrawlProviderChange={setWebsiteCrawlProvider} + onWebsiteCrawlJobIdChange={setWebsiteCrawlJobId} + crawlOptions={crawlOptions} + onCrawlOptionsChange={setCrawlOptions} + />} {(step === 2 && (!datasetId || (datasetId && !!detail))) && setShowAccountSettingModal({ payload: 'provider' })} @@ -158,6 +158,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { websiteCrawlJobId={websiteCrawlJobId} onStepChange={changeStep} updateIndexingTypeCache={updateIndexingTypeCache} + updateRetrievalMethodCache={updateRetrievalMethodCache} updateResultCache={updateResultCache} crawlOptions={crawlOptions} />} @@ -165,6 +166,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { datasetId={datasetId} datasetName={detail?.name} indexingType={detail?.indexing_technique || indexingTypeCache} + retrievalMethod={detail?.retrieval_model_dict?.search_method || retrievalMethodCache} creationCache={result} />}
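The new `icons.ts` module added earlier in this patch maps indexing techniques (`high_quality`, `economical`) and retrieval methods (`vector`, `fullText`, `hybrid`) to the freshly added SVG assets, and components such as `RuleDetail` render them through `next/image`. A sketch of how such a lookup is consumed; the wrapper component and its className are illustrative:

```tsx
// Hypothetical consumer of the icon maps exported by icons.ts.
import Image from 'next/image'
import { indexMethodIcon, retrievalIcon } from './icons'

export function IndexMethodIcon({ method }: { method: keyof typeof indexMethodIcon }) {
  // Static SVG imports carry intrinsic dimensions, so no width/height props are needed.
  return <Image className='size-4' src={indexMethodIcon[method]} alt={method} />
}

// e.g. <IndexMethodIcon method='high_quality' /> renders gold.svg;
// retrievalIcon works the same way with keys vector | fullText | hybrid.
```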
diff --git a/web/app/components/datasets/create/notion-page-preview/index.tsx b/web/app/components/datasets/create/notion-page-preview/index.tsx index 8225e56f0400e4..f658f213e85f17 100644 --- a/web/app/components/datasets/create/notion-page-preview/index.tsx +++ b/web/app/components/datasets/create/notion-page-preview/index.tsx @@ -44,7 +44,7 @@ const NotionPagePreview = ({ }, [currentPage]) return ( -
+
{t('datasetCreation.stepOne.pagePreview')} @@ -64,7 +64,7 @@ const NotionPagePreview = ({
{loading &&
} {!loading && ( -
{previewContent}
+
{previewContent}
)}
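The step-one change below also fixes the dependency list of the `nextDisabled` memo by adding `isShowVectorSpaceFull`, so the button state no longer goes stale when the vector-space check flips. A plausible reconstruction under stated assumptions (the earlier guards are partly elided in the diff):

```ts
import { useMemo } from 'react'
import type { FileItem } from '@/models/datasets'

// Sketch of step-one's nextDisabled memo after the fix; the file guards are illustrative.
function useNextDisabled(files: FileItem[], isShowVectorSpaceFull: boolean) {
  return useMemo(() => {
    if (!files.length)
      return true
    if (files.some(f => !f.file?.id)) // assumed: still uploading
      return true
    if (isShowVectorSpaceFull)
      return true
    // The patch also drops the final `return false`, so the memo now yields
    // undefined in the enabled case, which remains falsy for a `disabled` prop.
  }, [files, isShowVectorSpaceFull])
}
```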
diff --git a/web/app/components/datasets/create/step-one/index.module.css b/web/app/components/datasets/create/step-one/index.module.css index 4e3cf67cd6c65b..bb8dd9b895c9b5 100644 --- a/web/app/components/datasets/create/step-one/index.module.css +++ b/web/app/components/datasets/create/step-one/index.module.css @@ -2,21 +2,19 @@ position: sticky; top: 0; left: 0; - padding: 42px 64px 12px; + padding: 42px 64px 12px 0; font-weight: 600; font-size: 18px; line-height: 28px; - color: #101828; } .form { position: relative; padding: 12px 64px; - background-color: #fff; } .dataSourceItem { - @apply box-border relative shrink-0 flex items-center mr-3 p-3 h-14 bg-white rounded-xl cursor-pointer; + @apply box-border relative grow shrink-0 flex items-center p-3 h-14 bg-white rounded-xl cursor-pointer; border: 0.5px solid #EAECF0; box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); font-weight: 500; @@ -24,27 +22,32 @@ line-height: 20px; color: #101828; } + .dataSourceItem:hover { background-color: #f5f8ff; border: 0.5px solid #B2CCFF; box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); } + .dataSourceItem.active { background-color: #f5f8ff; border: 1.5px solid #528BFF; box-shadow: 0px 1px 3px rgba(16, 24, 40, 0.1), 0px 1px 2px rgba(16, 24, 40, 0.06); } + .dataSourceItem.disabled { background-color: #f9fafb; border: 0.5px solid #EAECF0; box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); cursor: default; } + .dataSourceItem.disabled:hover { background-color: #f9fafb; border: 0.5px solid #EAECF0; box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); } + .comingTag { @apply flex justify-center items-center bg-white; position: absolute; @@ -59,6 +62,7 @@ line-height: 18px; color: #444CE7; } + .datasetIcon { @apply flex mr-2 w-8 h-8 rounded-lg bg-center bg-no-repeat; background-color: #F5FAFF; @@ -66,15 +70,18 @@ background-size: 16px; border: 0.5px solid #D1E9FF; } + .dataSourceItem:active .datasetIcon, .dataSourceItem:hover .datasetIcon { background-color: #F5F8FF; border: 0.5px solid #E0EAFF; } + .datasetIcon.notion { background-image: url(../assets/notion.svg); background-size: 20px; } + .datasetIcon.web { background-image: url(../assets/web.svg); } @@ -90,29 +97,12 @@ background-color: #eaecf0; } -.OtherCreationOption { - @apply flex items-center cursor-pointer; - font-weight: 500; - font-size: 13px; - line-height: 18px; - color: #155EEF; -} -.OtherCreationOption::before { - content: ''; - display: block; - margin-right: 4px; - width: 16px; - height: 16px; - background: center no-repeat url(../assets/folder-plus.svg); - background-size: contain; -} - .notionConnectionTip { display: flex; flex-direction: column; align-items: flex-start; padding: 24px; - max-width: 640px; + width: 640px; background: #F9FAFB; border-radius: 16px; } @@ -138,6 +128,7 @@ line-height: 24px; color: #374151; } + .notionConnectionTip .title::after { content: ''; position: absolute; @@ -148,6 +139,7 @@ background: center no-repeat url(../assets/Icon-3-dots.svg); background-size: contain; } + .notionConnectionTip .tip { margin-bottom: 20px; font-style: normal; @@ -155,4 +147,4 @@ font-size: 13px; line-height: 18px; color: #6B7280; -} +} \ No newline at end of file diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx index 643932e9ae21d5..2cca003b397207 100644 --- a/web/app/components/datasets/create/step-one/index.tsx +++ b/web/app/components/datasets/create/step-one/index.tsx @@ -1,6 +1,7 @@ 'use client' import React, { 
useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' +import { RiArrowRightLine, RiFolder6Line } from '@remixicon/react' import FilePreview from '../file-preview' import FileUploader from '../file-uploader' import NotionPagePreview from '../notion-page-preview' @@ -17,6 +18,7 @@ import { NotionPageSelector } from '@/app/components/base/notion-page-selector' import { useDatasetDetailContext } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' import VectorSpaceFull from '@/app/components/billing/vector-space-full' +import classNames from '@/utils/classnames' type IStepOneProps = { datasetId?: string @@ -120,143 +122,174 @@ const StepOne = ({ return true if (isShowVectorSpaceFull) return true - return false - }, [files]) + }, [files, isShowVectorSpaceFull]) + return (
-
- { - shouldShowDataSourceTypeList && ( -
{t('datasetCreation.steps.one')}
- ) - } -
- { - shouldShowDataSourceTypeList && ( -
-
{ - if (dataSourceTypeDisable) - return - changeType(DataSourceType.FILE) - hideFilePreview() - hideNotionPagePreview() - }} - > - - {t('datasetCreation.stepOne.dataSourceType.file')} -
-
{ - if (dataSourceTypeDisable) - return - changeType(DataSourceType.NOTION) - hideFilePreview() - hideNotionPagePreview() - }} - > - - {t('datasetCreation.stepOne.dataSourceType.notion')} +
+
+
+ { + shouldShowDataSourceTypeList && ( +
{t('datasetCreation.steps.one')}
+ ) + } + { + shouldShowDataSourceTypeList && ( +
+
{ + if (dataSourceTypeDisable) + return + changeType(DataSourceType.FILE) + hideFilePreview() + hideNotionPagePreview() + }} + > + + {t('datasetCreation.stepOne.dataSourceType.file')} +
+
{ + if (dataSourceTypeDisable) + return + changeType(DataSourceType.NOTION) + hideFilePreview() + hideNotionPagePreview() + }} + > + + {t('datasetCreation.stepOne.dataSourceType.notion')} +
+
changeType(DataSourceType.WEB)} + > + + {t('datasetCreation.stepOne.dataSourceType.web')} +
-
changeType(DataSourceType.WEB)} - > - - {t('datasetCreation.stepOne.dataSourceType.web')} + ) + } + {dataSourceType === DataSourceType.FILE && ( + <> + + {isShowVectorSpaceFull && ( +
+ +
+ )} +
+ {/* */} +
-
- ) - } - {dataSourceType === DataSourceType.FILE && ( - <> - - {isShowVectorSpaceFull && ( -
- + + )} + {dataSourceType === DataSourceType.NOTION && ( + <> + {!hasConnection && } + {hasConnection && ( + <> +
+ page.page_id)} + onSelect={updateNotionPages} + onPreview={updateCurrentPage} + /> +
+ {isShowVectorSpaceFull && ( +
+ +
+ )} +
+ {/* */} + +
+ + )} + + )} + {dataSourceType === DataSourceType.WEB && ( + <> +
+
- )} - - - )} - {dataSourceType === DataSourceType.NOTION && ( - <> - {!hasConnection && } - {hasConnection && ( - <> -
- page.page_id)} - onSelect={updateNotionPages} - onPreview={updateCurrentPage} - /> + {isShowVectorSpaceFull && ( +
+
- {isShowVectorSpaceFull && ( -
- -
- )} - - - )} - - )} - {dataSourceType === DataSourceType.WEB && ( - <> -
- -
- {isShowVectorSpaceFull && ( -
- + )} +
+ {/* */} +
- )} - - - )} - {!datasetId && ( - <> -
-
{t('datasetCreation.stepOne.emptyDatasetCreation')}
- - )} + + )} + {!datasetId && ( + <> +
+ + + {t('datasetCreation.stepOne.emptyDatasetCreation')} + + + )} +
+
-
- {currentFile && } - {currentNotionPage && } - {currentWebsite && } +
+ {currentFile && } + {currentNotionPage && } + {currentWebsite && } +
) } diff --git a/web/app/components/datasets/create/step-three/index.tsx b/web/app/components/datasets/create/step-three/index.tsx index 804a196ed5ac44..8d979616d13b21 100644 --- a/web/app/components/datasets/create/step-three/index.tsx +++ b/web/app/components/datasets/create/step-three/index.tsx @@ -1,45 +1,51 @@ 'use client' import React from 'react' import { useTranslation } from 'react-i18next' +import { RiBookOpenLine } from '@remixicon/react' import EmbeddingProcess from '../embedding-process' -import s from './index.module.css' -import cn from '@/utils/classnames' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import type { FullDocumentDetail, createDocumentResponse } from '@/models/datasets' +import AppIcon from '@/app/components/base/app-icon' type StepThreeProps = { datasetId?: string datasetName?: string indexingType?: string + retrievalMethod?: string creationCache?: createDocumentResponse } -const StepThree = ({ datasetId, datasetName, indexingType, creationCache }: StepThreeProps) => { +const StepThree = ({ datasetId, datasetName, indexingType, creationCache, retrievalMethod }: StepThreeProps) => { const { t } = useTranslation() const media = useBreakpoints() const isMobile = media === MediaType.mobile return ( -
-
-
+
+
+
{!datasetId && ( <> -
-
{t('datasetCreation.stepThree.creationTitle')}
-
{t('datasetCreation.stepThree.creationContent')}
-
{t('datasetCreation.stepThree.label')}
-
{datasetName || creationCache?.dataset?.name}
+
+
{t('datasetCreation.stepThree.creationTitle')}
+
{t('datasetCreation.stepThree.creationContent')}
+
+ +
+
{t('datasetCreation.stepThree.label')}
+
{datasetName || creationCache?.dataset?.name}
+
+
-
+
)} {datasetId && ( -
-
{t('datasetCreation.stepThree.additionTitle')}
-
{`${t('datasetCreation.stepThree.additionP1')} ${datasetName || creationCache?.dataset?.name} ${t('datasetCreation.stepThree.additionP2')}`}
+
+
{t('datasetCreation.stepThree.additionTitle')}
+
{`${t('datasetCreation.stepThree.additionP1')} ${datasetName || creationCache?.dataset?.name} ${t('datasetCreation.stepThree.additionP2')}`}
)}
- {!isMobile &&
-
- -
{t('datasetCreation.stepThree.sideTipTitle')}
-
{t('datasetCreation.stepThree.sideTipContent')}
+ {!isMobile && ( +
+
+
+ +
+
{t('datasetCreation.stepThree.sideTipTitle')}
+
{t('datasetCreation.stepThree.sideTipContent')}
+
-
} + )}
) } diff --git a/web/app/components/datasets/create/step-two/index.module.css b/web/app/components/datasets/create/step-two/index.module.css index f89d6d67ea7088..178cbeba857dad 100644 --- a/web/app/components/datasets/create/step-two/index.module.css +++ b/web/app/components/datasets/create/step-two/index.module.css @@ -13,18 +13,6 @@ z-index: 10; } -.form { - @apply px-16 pb-8; -} - -.form .label { - @apply pt-6 pb-2 flex items-center; - font-weight: 500; - font-size: 16px; - line-height: 24px; - color: #344054; -} - .segmentationItem { min-height: 68px; } @@ -75,6 +63,10 @@ cursor: pointer; } +.disabled { + cursor: not-allowed !important; +} + .indexItem.disabled:hover { background-color: #fcfcfd; border-color: #f2f4f7; @@ -87,8 +79,7 @@ } .radioItem { - @apply relative mb-2 rounded-xl border border-gray-100 cursor-pointer; - background-color: #fcfcfd; + @apply relative mb-2 rounded-xl border border-components-option-card-option-border cursor-pointer bg-components-option-card-option-bg; } .radioItem.segmentationItem.custom { @@ -146,7 +137,7 @@ } .typeIcon.economical { - background-image: url(../assets/piggy-bank-01.svg); + background-image: url(../assets/piggy-bank-mod.svg); } .radioItem .radio { @@ -247,7 +238,7 @@ } .ruleItem { - @apply flex items-center; + @apply flex items-center py-1.5; } .formFooter { @@ -394,19 +385,6 @@ max-width: 524px; } -.previewHeader { - position: sticky; - top: 0; - left: 0; - padding-top: 42px; - background-color: #fff; - font-weight: 600; - font-size: 18px; - line-height: 28px; - color: #101828; - z-index: 10; -} - /* * `fixed` must under `previewHeader` because of style override would not work */ @@ -432,4 +410,4 @@ font-size: 12px; line-height: 18px; } -} \ No newline at end of file +} diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index f915c68fef0345..0d7202967a5e9f 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -1,65 +1,80 @@ 'use client' -import React, { useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react' +import type { FC, PropsWithChildren } from 'react' +import React, { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { useBoolean } from 'ahooks' -import { XMarkIcon } from '@heroicons/react/20/solid' -import { RocketLaunchIcon } from '@heroicons/react/24/outline' import { - RiCloseLine, + RiAlertFill, + RiArrowLeftLine, + RiSearchEyeLine, } from '@remixicon/react' import Link from 'next/link' -import { groupBy } from 'lodash-es' -import PreviewItem, { PreviewType } from './preview-item' -import LanguageSelect from './language-select' +import Image from 'next/image' +import { useHover } from 'ahooks' +import SettingCog from '../assets/setting-gear-mod.svg' +import OrangeEffect from '../assets/option-card-effect-orange.svg' +import FamilyMod from '../assets/family-mod.svg' +import Note from '../assets/note-mod.svg' +import FileList from '../assets/file-list-3-fill.svg' +import { indexMethodIcon } from '../icons' +import { PreviewContainer } from '../../preview/container' +import { ChunkContainer, QAPreview } from '../../chunk' +import { PreviewHeader } from '../../preview/header' +import { FormattedText } from '../../formatted-text/formatted' +import { PreviewSlice } from '../../formatted-text/flavours/preview-slice' +import PreviewDocumentPicker from 
'../../common/document-picker/preview-document-picker' import s from './index.module.css' import unescape from './unescape' import escape from './escape' +import { OptionCard } from './option-card' +import LanguageSelect from './language-select' +import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs' import cn from '@/utils/classnames' -import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' -import { - createDocument, - createFirstDocument, - fetchFileIndexingEstimate as didFetchFileIndexingEstimate, - fetchDefaultProcessRule, -} from '@/service/datasets' +import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' + import Button from '@/app/components/base/button' -import Input from '@/app/components/base/input' -import Loading from '@/app/components/base/loading' import FloatRightContainer from '@/app/components/base/float-right-container' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import { type RetrievalConfig } from '@/types/app' import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import Toast from '@/app/components/base/toast' -import { formatNumber } from '@/utils/format' import type { NotionPage } from '@/models/common' import { DataSourceProvider } from '@/models/common' -import { DataSourceType, DocForm } from '@/models/datasets' -import NotionIcon from '@/app/components/base/notion-icon' -import Switch from '@/app/components/base/switch' -import { MessageChatSquare } from '@/app/components/base/icons/src/public/common' +import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets' import { useDatasetDetailContext } from '@/context/dataset-detail' import I18n from '@/context/i18n' -import { IS_CE_EDITION } from '@/config' import { RETRIEVE_METHOD } from '@/types/app' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' -import Tooltip from '@/app/components/base/tooltip' import { useDefaultModel, useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { LanguagesSupported } from '@/i18n/language' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { Globe01 } from '@/app/components/base/icons/src/vender/line/mapsAndTravel' +import Checkbox from '@/app/components/base/checkbox' +import RadioCard from '@/app/components/base/radio-card' +import { IS_CE_EDITION } from '@/config' +import Divider from '@/app/components/base/divider' +import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset' +import 
Badge from '@/app/components/base/badge' +import { SkeletonContainer, SkeletonPoint, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton' +import Tooltip from '@/app/components/base/tooltip' +import CustomDialog from '@/app/components/base/dialog' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' + +const TextLabel: FC = (props) => { + return +} -type ValueOf = T[keyof T] type StepTwoProps = { isSetting?: boolean documentDetail?: FullDocumentDetail isAPIKeySet: boolean onSetting: () => void datasetId?: string - indexingType?: ValueOf + indexingType?: IndexingType + retrievalMethod?: string dataSourceType: DataSourceType files: CustomFile[] notionPages?: NotionPage[] @@ -69,21 +84,48 @@ type StepTwoProps = { websiteCrawlJobId?: string onStepChange?: (delta: number) => void updateIndexingTypeCache?: (type: string) => void + updateRetrievalMethodCache?: (method: string) => void updateResultCache?: (res: createDocumentResponse) => void onSave?: () => void onCancel?: () => void } -enum SegmentType { +export enum SegmentType { AUTO = 'automatic', CUSTOM = 'custom', } -enum IndexingType { +export enum IndexingType { QUALIFIED = 'high_quality', ECONOMICAL = 'economy', } const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n' +const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500 +const DEFAULT_OVERLAP = 50 + +type ParentChildConfig = { + chunkForContext: ParentMode + parent: { + delimiter: string + maxLength: number + } + child: { + delimiter: string + maxLength: number + } +} + +const defaultParentChildConfig: ParentChildConfig = { + chunkForContext: 'paragraph', + parent: { + delimiter: '\\n\\n', + maxLength: 500, + }, + child: { + delimiter: '\\n', + maxLength: 200, + }, +} const StepTwo = ({ isSetting, @@ -104,6 +146,7 @@ const StepTwo = ({ updateResultCache, onSave, onCancel, + updateRetrievalMethodCache, }: StepTwoProps) => { const { t } = useTranslation() const { locale } = useContext(I18n) @@ -111,66 +154,166 @@ const StepTwo = ({ const isMobile = media === MediaType.mobile const { dataset: currentDataset, mutateDatasetRes } = useDatasetDetailContext() + + const isInUpload = Boolean(currentDataset) + const isUploadInEmptyDataset = isInUpload && !currentDataset?.doc_form + const isNotUploadInEmptyDataset = !isUploadInEmptyDataset + const isInInit = !isInUpload && !isSetting + const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type) const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type - const scrollRef = useRef(null) - const [scrolled, setScrolled] = useState(false) - const previewScrollRef = useRef(null) - const [previewScrolled, setPreviewScrolled] = useState(false) - const [segmentationType, setSegmentationType] = useState(SegmentType.AUTO) + const [segmentationType, setSegmentationType] = useState(SegmentType.CUSTOM) const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER) - const setSegmentIdentifier = useCallback((value: string) => { - doSetSegmentIdentifier(value ? escape(value) : DEFAULT_SEGMENT_IDENTIFIER) + const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => { + doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? 
'' : DEFAULT_SEGMENT_IDENTIFIER)) }, []) - const [maxChunkLength, setMaxChunkLength] = useState(4000) // default chunk length + const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000) - const [overlap, setOverlap] = useState(50) + const [overlap, setOverlap] = useState(DEFAULT_OVERLAP) const [rules, setRules] = useState([]) const [defaultConfig, setDefaultConfig] = useState() const hasSetIndexType = !!indexingType - const [indexType, setIndexType] = useState>( + const [indexType, setIndexType] = useState( (indexingType || isAPIKeySet) ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL, ) - const [isLanguageSelectDisabled, setIsLanguageSelectDisabled] = useState(false) - const [docForm, setDocForm] = useState( - (datasetId && documentDetail) ? documentDetail.doc_form : DocForm.TEXT, + + const [previewFile, setPreviewFile] = useState( + (datasetId && documentDetail) + ? documentDetail.file + : files[0], + ) + const [previewNotionPage, setPreviewNotionPage] = useState( + (datasetId && documentDetail) + ? documentDetail.notion_page + : notionPages[0], + ) + + const [previewWebsitePage, setPreviewWebsitePage] = useState( + (datasetId && documentDetail) + ? documentDetail.website_page + : websitePages[0], + ) + + // QA Related + const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false) + const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false) + const [docForm, setDocForm] = useState( + (datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text, ) + const handleChangeDocform = (value: ChunkingMode) => { + if (value === ChunkingMode.qa && indexType === IndexingType.ECONOMICAL) { + setIsQAConfirmDialogOpen(true) + return + } + if (value === ChunkingMode.parentChild && indexType === IndexingType.ECONOMICAL) + setIndexType(IndexingType.QUALIFIED) + setDocForm(value) + // eslint-disable-next-line @typescript-eslint/no-use-before-define + currentEstimateMutation.reset() + } + const [docLanguage, setDocLanguage] = useState( (datasetId && documentDetail) ? documentDetail.doc_language : (locale !== LanguagesSupported[1] ? 'English' : 'Chinese'), ) - const [QATipHide, setQATipHide] = useState(false) - const [previewSwitched, setPreviewSwitched] = useState(false) - const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean() - const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState(null) - const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState(null) - const fileIndexingEstimate = (() => { - return segmentationType === SegmentType.AUTO ? 
automaticFileIndexingEstimate : customFileIndexingEstimate - })() - const [isCreating, setIsCreating] = useState(false) + const [parentChildConfig, setParentChildConfig] = useState(defaultParentChildConfig) - const scrollHandle = (e: Event) => { - if ((e.target as HTMLDivElement).scrollTop > 0) - setScrolled(true) + const getIndexing_technique = () => indexingType || indexType + const currentDocForm = currentDataset?.doc_form || docForm - else - setScrolled(false) + const getProcessRule = (): ProcessRule => { + if (currentDocForm === ChunkingMode.parentChild) { + return { + rules: { + pre_processing_rules: rules, + segmentation: { + separator: unescape( + parentChildConfig.parent.delimiter, + ), + max_tokens: parentChildConfig.parent.maxLength, + }, + parent_mode: parentChildConfig.chunkForContext, + subchunk_segmentation: { + separator: unescape(parentChildConfig.child.delimiter), + max_tokens: parentChildConfig.child.maxLength, + }, + }, + mode: 'hierarchical', + } as ProcessRule + } + return { + rules: { + pre_processing_rules: rules, + segmentation: { + separator: unescape(segmentIdentifier), + max_tokens: maxChunkLength, + chunk_overlap: overlap, + }, + }, // api will check this. It will be removed after api refactored. + mode: segmentationType, + } as ProcessRule } - const previewScrollHandle = (e: Event) => { - if ((e.target as HTMLDivElement).scrollTop > 0) - setPreviewScrolled(true) + const fileIndexingEstimateQuery = useFetchFileIndexingEstimateForFile({ + docForm: currentDocForm, + docLanguage, + dataSourceType: DataSourceType.FILE, + files: previewFile + ? [files.find(file => file.name === previewFile.name)!] + : files, + indexingTechnique: getIndexing_technique() as any, + processRule: getProcessRule(), + dataset_id: datasetId!, + }) + const notionIndexingEstimateQuery = useFetchFileIndexingEstimateForNotion({ + docForm: currentDocForm, + docLanguage, + dataSourceType: DataSourceType.NOTION, + notionPages: [previewNotionPage], + indexingTechnique: getIndexing_technique() as any, + processRule: getProcessRule(), + dataset_id: datasetId || '', + }) - else - setPreviewScrolled(false) - } - const getFileName = (name: string) => { - const arr = name.split('.') - return arr.slice(0, -1).join('.') - } + const websiteIndexingEstimateQuery = useFetchFileIndexingEstimateForWeb({ + docForm: currentDocForm, + docLanguage, + dataSourceType: DataSourceType.WEB, + websitePages: [previewWebsitePage], + crawlOptions, + websiteCrawlProvider, + websiteCrawlJobId, + indexingTechnique: getIndexing_technique() as any, + processRule: getProcessRule(), + dataset_id: datasetId || '', + }) + + const currentEstimateMutation = dataSourceType === DataSourceType.FILE + ? fileIndexingEstimateQuery + : dataSourceType === DataSourceType.NOTION + ? notionIndexingEstimateQuery + : websiteIndexingEstimateQuery + + const fetchEstimate = useCallback(() => { + if (dataSourceType === DataSourceType.FILE) + fileIndexingEstimateQuery.mutate() + + if (dataSourceType === DataSourceType.NOTION) + notionIndexingEstimateQuery.mutate() + + if (dataSourceType === DataSourceType.WEB) + websiteIndexingEstimateQuery.mutate() + }, [dataSourceType, fileIndexingEstimateQuery, notionIndexingEstimateQuery, websiteIndexingEstimateQuery]) + + const estimate + = dataSourceType === DataSourceType.FILE + ? fileIndexingEstimateQuery.data + : dataSourceType === DataSourceType.NOTION + ? 
notionIndexingEstimateQuery.data + : websiteIndexingEstimateQuery.data const getRuleName = (key: string) => { if (key === 'remove_extra_spaces') @@ -198,128 +341,20 @@ const StepTwo = ({ if (defaultConfig) { setSegmentIdentifier(defaultConfig.segmentation.separator) setMaxChunkLength(defaultConfig.segmentation.max_tokens) - setOverlap(defaultConfig.segmentation.chunk_overlap) + setOverlap(defaultConfig.segmentation.chunk_overlap!) setRules(defaultConfig.pre_processing_rules) } + setParentChildConfig(defaultParentChildConfig) } - const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT, language?: string) => { - // eslint-disable-next-line @typescript-eslint/no-use-before-define - const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm, language)!) - if (segmentationType === SegmentType.CUSTOM) - setCustomFileIndexingEstimate(res) - else - setAutomaticFileIndexingEstimate(res) - } - - const confirmChangeCustomConfig = () => { - if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) { - Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) }) + const updatePreview = () => { + if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) { + Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) return } - setCustomFileIndexingEstimate(null) - setShowPreview() - fetchFileIndexingEstimate() - setPreviewSwitched(false) - } - - const getIndexing_technique = () => indexingType || indexType - - const getProcessRule = () => { - const processRule: ProcessRule = { - rules: {} as any, // api will check this. It will be removed after api refactored. - mode: segmentationType, - } - if (segmentationType === SegmentType.CUSTOM) { - const ruleObj = { - pre_processing_rules: rules, - segmentation: { - separator: unescape(segmentIdentifier), - max_tokens: maxChunkLength, - chunk_overlap: overlap, - }, - } - processRule.rules = ruleObj - } - return processRule - } - - const getNotionInfo = () => { - const workspacesMap = groupBy(notionPages, 'workspace_id') - const workspaces = Object.keys(workspacesMap).map((workspaceId) => { - return { - workspaceId, - pages: workspacesMap[workspaceId], - } - }) - return workspaces.map((workspace) => { - return { - workspace_id: workspace.workspaceId, - pages: workspace.pages.map((page) => { - const { page_id, page_name, page_icon, type } = page - return { - page_id, - page_name, - page_icon, - type, - } - }), - } - }) as NotionInfo[] - } - - const getWebsiteInfo = () => { - return { - provider: websiteCrawlProvider, - job_id: websiteCrawlJobId, - urls: websitePages.map(page => page.source_url), - only_main_content: crawlOptions?.only_main_content, - } + fetchEstimate() } - const getFileIndexingEstimateParams = (docForm: DocForm, language?: string): IndexingEstimateParams | undefined => { - if (dataSourceType === DataSourceType.FILE) { - return { - info_list: { - data_source_type: dataSourceType, - file_info_list: { - file_ids: files.map(file => file.id) as string[], - }, - }, - indexing_technique: getIndexing_technique() as string, - process_rule: getProcessRule(), - doc_form: docForm, - doc_language: language || docLanguage, - dataset_id: datasetId as string, - } - } - if (dataSourceType === DataSourceType.NOTION) { - return { - info_list: { - data_source_type: dataSourceType, - notion_info_list: getNotionInfo(), - }, - indexing_technique: getIndexing_technique() as string, - process_rule: getProcessRule(), - 
doc_form: docForm, - doc_language: language || docLanguage, - dataset_id: datasetId as string, - } - } - if (dataSourceType === DataSourceType.WEB) { - return { - info_list: { - data_source_type: dataSourceType, - website_info_list: getWebsiteInfo(), - }, - indexing_technique: getIndexing_technique() as string, - process_rule: getProcessRule(), - doc_form: docForm, - doc_language: language || docLanguage, - dataset_id: datasetId as string, - } - } - } const { modelList: rerankModelList, defaultModel: rerankDefaultModel, @@ -351,13 +386,14 @@ const StepTwo = ({ if (isSetting) { params = { original_document_id: documentDetail?.id, - doc_form: docForm, + doc_form: currentDocForm, doc_language: docLanguage, process_rule: getProcessRule(), // eslint-disable-next-line @typescript-eslint/no-use-before-define retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page. embedding_model: embeddingModel.model, // Readonly embedding_model_provider: embeddingModel.provider, // Readonly + indexing_technique: getIndexing_technique(), } as CreateDocumentReq } else { // create @@ -377,8 +413,12 @@ const StepTwo = ({ } const postRetrievalConfig = ensureRerankModelSelected({ rerankDefaultModel: rerankDefaultModel!, - // eslint-disable-next-line @typescript-eslint/no-use-before-define - retrievalConfig, + retrievalConfig: { + // eslint-disable-next-line @typescript-eslint/no-use-before-define + ...retrievalConfig, + // eslint-disable-next-line @typescript-eslint/no-use-before-define + reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel, + }, indexMethod: indexMethod as string, }) params = { @@ -390,7 +430,7 @@ const StepTwo = ({ }, indexing_technique: getIndexing_technique(), process_rule: getProcessRule(), - doc_form: docForm, + doc_form: currentDocForm, doc_language: docLanguage, retrieval_model: postRetrievalConfig, @@ -403,29 +443,36 @@ const StepTwo = ({ } } if (dataSourceType === DataSourceType.NOTION) - params.data_source.info_list.notion_info_list = getNotionInfo() + params.data_source.info_list.notion_info_list = getNotionInfo(notionPages) - if (dataSourceType === DataSourceType.WEB) - params.data_source.info_list.website_info_list = getWebsiteInfo() + if (dataSourceType === DataSourceType.WEB) { + params.data_source.info_list.website_info_list = getWebsiteInfo({ + websiteCrawlProvider, + websiteCrawlJobId, + websitePages, + }) + } } return params } - const getRules = async () => { - try { - const res = await fetchDefaultProcessRule({ url: '/datasets/process-rule' }) - const separator = res.rules.segmentation.separator + const fetchDefaultProcessRuleMutation = useFetchDefaultProcessRule({ + onSuccess(data) { + const separator = data.rules.segmentation.separator setSegmentIdentifier(separator) - setMaxChunkLength(res.rules.segmentation.max_tokens) - setLimitMaxChunkLength(res.limits.indexing_max_segmentation_tokens_length) - setOverlap(res.rules.segmentation.chunk_overlap) - setRules(res.rules.pre_processing_rules) - setDefaultConfig(res.rules) - } - catch (err) { - console.log(err) - } - } + setMaxChunkLength(data.rules.segmentation.max_tokens) + setOverlap(data.rules.segmentation.chunk_overlap!) 
+ setRules(data.rules.pre_processing_rules) + setDefaultConfig(data.rules) + setLimitMaxChunkLength(data.limits.indexing_max_segmentation_tokens_length) + }, + onError(error) { + Toast.notify({ + type: 'error', + message: `${error}`, + }) + }, + }) const getRulesFromDetail = () => { if (documentDetail) { @@ -435,7 +482,7 @@ const StepTwo = ({ const overlap = rules.segmentation.chunk_overlap setSegmentIdentifier(separator) setMaxChunkLength(max) - setOverlap(overlap) + setOverlap(overlap!) setRules(rules.pre_processing_rules) setDefaultConfig(rules) } @@ -443,119 +490,81 @@ const StepTwo = ({ const getDefaultMode = () => { if (documentDetail) + // @ts-expect-error fix after api refactored setSegmentationType(documentDetail.dataset_process_rule.mode) } - const createHandle = async () => { - if (isCreating) - return - setIsCreating(true) - try { - let res - const params = getCreationParams() - if (!params) - return false - - setIsCreating(true) - if (!datasetId) { - res = await createFirstDocument({ - body: params as CreateDocumentReq, - }) - updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) - updateResultCache && updateResultCache(res) - } - else { - res = await createDocument({ - datasetId, - body: params as CreateDocumentReq, - }) - updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) - updateResultCache && updateResultCache(res) - } - if (mutateDatasetRes) - mutateDatasetRes() - onStepChange && onStepChange(+1) - isSetting && onSave && onSave() - } - catch (err) { + const createFirstDocumentMutation = useCreateFirstDocument({ + onError(error) { Toast.notify({ type: 'error', - message: `${err}`, + message: `${error}`, }) - } - finally { - setIsCreating(false) - } - } + }, + }) + const createDocumentMutation = useCreateDocument(datasetId!, { + onError(error) { + Toast.notify({ + type: 'error', + message: `${error}`, + }) + }, + }) - const handleSwitch = (state: boolean) => { - if (state) - setDocForm(DocForm.QA) - else - setDocForm(DocForm.TEXT) - } + const isCreating = createFirstDocumentMutation.isPending || createDocumentMutation.isPending - const previewSwitch = async (language?: string) => { - setPreviewSwitched(true) - setIsLanguageSelectDisabled(true) - if (segmentationType === SegmentType.AUTO) - setAutomaticFileIndexingEstimate(null) - else - setCustomFileIndexingEstimate(null) - try { - await fetchFileIndexingEstimate(DocForm.QA, language) + const createHandle = async () => { + const params = getCreationParams() + if (!params) + return false + + if (!datasetId) { + await createFirstDocumentMutation.mutateAsync( + params, + { + onSuccess(data) { + updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) + updateResultCache && updateResultCache(data) + // eslint-disable-next-line @typescript-eslint/no-use-before-define + updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string) + }, + }, + ) } - finally { - setIsLanguageSelectDisabled(false) + else { + await createDocumentMutation.mutateAsync(params, { + onSuccess(data) { + updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) + updateResultCache && updateResultCache(data) + }, + }) } - } - - const handleSelect = (language: string) => { - setDocLanguage(language) - // Switch language, re-cutter - if (docForm === DocForm.QA && previewSwitched) - previewSwitch(language) + if (mutateDatasetRes) + mutateDatasetRes() + onStepChange && onStepChange(+1) + isSetting && onSave && onSave() } const changeToEconomicalType = () => { 
- if (!hasSetIndexType) { + if (docForm !== ChunkingMode.text) + return + + if (!hasSetIndexType) setIndexType(IndexingType.ECONOMICAL) - setDocForm(DocForm.TEXT) - } } useEffect(() => { // fetch rules if (!isSetting) { - getRules() + fetchDefaultProcessRuleMutation.mutate('/datasets/process-rule') } else { getRulesFromDetail() getDefaultMode() } + // eslint-disable-next-line react-hooks/exhaustive-deps }, []) - useEffect(() => { - scrollRef.current?.addEventListener('scroll', scrollHandle) - return () => { - scrollRef.current?.removeEventListener('scroll', scrollHandle) - } - }, []) - - useLayoutEffect(() => { - if (showPreview) { - previewScrollRef.current?.addEventListener('scroll', previewScrollHandle) - return () => { - previewScrollRef.current?.removeEventListener('scroll', previewScrollHandle) - } - } - }, [showPreview]) - - useEffect(() => { - if (indexingType === IndexingType.ECONOMICAL && docForm === DocForm.QA) - setDocForm(DocForm.TEXT) - }, [indexingType, docForm]) - useEffect(() => { // get indexing type by props if (indexingType) @@ -565,20 +574,6 @@ const StepTwo = ({ setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL) }, [isAPIKeySet, indexingType, datasetId]) - useEffect(() => { - if (segmentationType === SegmentType.AUTO) { - setAutomaticFileIndexingEstimate(null) - !isMobile && setShowPreview() - fetchFileIndexingEstimate() - setPreviewSwitched(false) - } - else { - hidePreview() - setCustomFileIndexingEstimate(null) - setPreviewSwitched(false) - } - }, [segmentationType, indexType]) - const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || { search_method: RETRIEVE_METHOD.semantic, reranking_enable: false, @@ -591,433 +586,589 @@ const StepTwo = ({ score_threshold: 0.5, } as RetrievalConfig) + const economyDomRef = useRef(null) + const isHoveringEconomy = useHover(economyDomRef) + return (
-
-
- {t('datasetCreation.steps.two')} - {(isMobile || !showPreview) && ( - - )} -
-
-
{t('datasetCreation.stepTwo.segmentation')}
-
-
setSegmentationType(SegmentType.AUTO)} - > - - -
-
{t('datasetCreation.stepTwo.auto')}
-
{t('datasetCreation.stepTwo.autoDescription')}
-
-
-
setSegmentationType(SegmentType.CUSTOM)} - > - - -
-
{t('datasetCreation.stepTwo.custom')}
-
{t('datasetCreation.stepTwo.customDescription')}
+
+
{t('datasetCreation.stepTwo.segmentation')}
+ {((isInUpload && [ChunkingMode.text, ChunkingMode.qa].includes(currentDataset!.doc_form)) + || isUploadInEmptyDataset + || isInInit) + && } + activeHeaderClassName='bg-dataset-option-card-blue-gradient' + description={t('datasetCreation.stepTwo.generalTip')} + isActive={ + [ChunkingMode.text, ChunkingMode.qa].includes(currentDocForm) + } + onSwitched={() => + handleChangeDocform(ChunkingMode.text) + } + actions={ + <> + + + + } + noHighlight={isInUpload && isNotUploadInEmptyDataset} + > +
+
+ setSegmentIdentifier(e.target.value, true)} + /> + +
- {segmentationType === SegmentType.CUSTOM && ( -
-
-
-
- {t('datasetCreation.stepTwo.separator')} - - {t('datasetCreation.stepTwo.separatorTip')} -
- } - /> -
- setSegmentIdentifier(e.target.value)} - /> -
+
+
+
+ {t('datasetCreation.stepTwo.rules')}
-
-
-
{t('datasetCreation.stepTwo.maxLength')}
- setMaxChunkLength(parseInt(e.target.value.replace(/^0+/, ''), 10))} + +
+
+ {rules.map(rule => ( +
{ + ruleChangeHandle(rule.id) + }}> + +
-
-
-
-
- {t('datasetCreation.stepTwo.overlap')} - - {t('datasetCreation.stepTwo.overlapTip')} -
- } + ))} + {IS_CE_EDITION && <> + +
+
{
+                if (currentDataset?.doc_form)
+                  return
+                if (docForm === ChunkingMode.qa)
+                  handleChangeDocform(ChunkingMode.text)
+                else
+                  handleChangeDocform(ChunkingMode.qa)
+              }}>
+
- setOverlap(parseInt(e.target.value.replace(/^0+/, ''), 10))} + +
-
-
-
-
{t('datasetCreation.stepTwo.rules')}
- {rules.map(rule => ( -
- ruleChangeHandle(rule.id)} className="w-4 h-4 rounded border-gray-300 text-blue-700 focus:ring-blue-700" /> - -
- ))} -
-
-
- - -
+ {currentDocForm === ChunkingMode.qa && ( +
+ + + {t('datasetCreation.stepTwo.QATip')} + +
+ )} + }
- )} +
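These cards all key off the new ChunkingMode enum that replaces DocForm. Its assumed definition — the string values are inferred from the doc_form values the API exchanges, so treat them as illustrative:

// Assumed shape of ChunkingMode (replacing DocForm) in models/datasets.
export enum ChunkingMode {
  text = 'text_model',
  qa = 'qa_model',
  parentChild = 'hierarchical_model',
}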
-
-
{t('datasetCreation.stepTwo.indexMode')}
-
-
- {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.QUALIFIED)) && ( -
{ - if (isAPIKeySet) - setIndexType(IndexingType.QUALIFIED) - }} - > - - {!hasSetIndexType && } -
-
- {t('datasetCreation.stepTwo.qualified')} - {!hasSetIndexType && {t('datasetCreation.stepTwo.recommend')}} -
-
{t('datasetCreation.stepTwo.qualifiedTip')}
+ } + { + ( + (isInUpload && currentDataset!.doc_form === ChunkingMode.parentChild) + || isUploadInEmptyDataset + || isInInit + ) + && } + effectImg={OrangeEffect.src} + activeHeaderClassName='bg-dataset-option-card-orange-gradient' + description={t('datasetCreation.stepTwo.parentChildTip')} + isActive={currentDocForm === ChunkingMode.parentChild} + onSwitched={() => handleChangeDocform(ChunkingMode.parentChild)} + actions={ + <> + + + + } + noHighlight={isInUpload && isNotUploadInEmptyDataset} + > +
+
+
+
+ {t('datasetCreation.stepTwo.parentChunkForContext')}
- {!isAPIKeySet && ( -
- {t('datasetCreation.stepTwo.warning')}  - {t('datasetCreation.stepTwo.click')} + +
+ } + title={t('datasetCreation.stepTwo.paragraph')} + description={t('datasetCreation.stepTwo.paragraphTip')} + isChosen={parentChildConfig.chunkForContext === 'paragraph'} + onChosen={() => setParentChildConfig( + { + ...parentChildConfig, + chunkForContext: 'paragraph', + }, + )} + chosenConfig={ +
+ setParentChildConfig({ + ...parentChildConfig, + parent: { + ...parentChildConfig.parent, + delimiter: e.target.value ? escape(e.target.value) : '', + }, + })} + /> + setParentChildConfig({ + ...parentChildConfig, + parent: { + ...parentChildConfig.parent, + maxLength: value, + }, + })} + />
+ } + /> + } + title={t('datasetCreation.stepTwo.fullDoc')} + description={t('datasetCreation.stepTwo.fullDocTip')} + onChosen={() => setParentChildConfig( + { + ...parentChildConfig, + chunkForContext: 'full-doc', + }, )} -
- )} + isChosen={parentChildConfig.chunkForContext === 'full-doc'} + /> +
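The parent and child delimiter inputs above run their values through an escape/unescape pair so a typed "\n" is stored as a real newline and displayed back as the two-character sequence. A sketch of such helpers, assuming only backslash, newline and tab need translating:

// Sketch only — the helpers shipped with this patch may cover more escapes.
export const escape = (input: string): string =>
  input
    .replaceAll('\\\\', '\u0000') // protect literal backslashes first
    .replaceAll('\\n', '\n')
    .replaceAll('\\t', '\t')
    .replaceAll('\u0000', '\\')

export const unescape = (input: string): string =>
  input
    .replaceAll('\\', '\\\\') // escape backslashes before adding new ones
    .replaceAll('\n', '\\n')
    .replaceAll('\t', '\\t')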
- {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.ECONOMICAL)) && ( -
- - {!hasSetIndexType && } -
-
{t('datasetCreation.stepTwo.economical')}
-
{t('datasetCreation.stepTwo.economicalTip')}
+
+
+
+ {t('datasetCreation.stepTwo.childChunkForRetrieval')}
+
- )} -
- {hasSetIndexType && indexType === IndexingType.ECONOMICAL && ( -
- {t('datasetCreation.stepTwo.indexSettingTip')} - {t('datasetCreation.stepTwo.datasetSettingLink')} -
- )} - {IS_CE_EDITION && indexType === IndexingType.QUALIFIED && ( -
-
-
- -
-
-
{t('datasetCreation.stepTwo.QATitle')}
-
- {t('datasetCreation.stepTwo.QALanguage')} - -
-
-
- -
+
+ setParentChildConfig({ + ...parentChildConfig, + child: { + ...parentChildConfig.child, + delimiter: e.target.value ? escape(e.target.value) : '', + }, + })} + /> + setParentChildConfig({ + ...parentChildConfig, + child: { + ...parentChildConfig.child, + maxLength: value, + }, + })} + />
- {docForm === DocForm.QA && !QATipHide && ( -
- {t('datasetCreation.stepTwo.QATip')} - setQATipHide(true)} /> -
- )} -
- )} - {/* Embedding model */} - {indexType === IndexingType.QUALIFIED && ( -
-
{t('datasetSettings.form.embeddingModel')}
- { - setEmbeddingModel(model) - }} - /> - {!!datasetId && ( -
- {t('datasetCreation.stepTwo.indexSettingTip')} - {t('datasetCreation.stepTwo.datasetSettingLink')} -
- )}
- )} - {/* Retrieval Method Config */} -
- {!datasetId - ? ( -
-
{t('datasetSettings.form.retrievalSetting.title')}
-
- {t('datasetSettings.form.retrievalSetting.learnMore')} - {t('datasetSettings.form.retrievalSetting.longDescription')} -
+
+
+
+ {t('datasetCreation.stepTwo.rules')}
- ) - : ( -
-
{t('datasetSettings.form.retrievalSetting.title')}
-
- )} - -
- { - getIndexing_technique() === IndexingType.QUALIFIED - ? ( - - ) - : ( - - ) - } -
-
- -
-
- {dataSourceType === DataSourceType.FILE && ( - <> -
{t('datasetCreation.stepTwo.fileSource')}
-
- - {getFileName(files[0].name || '')} - {files.length > 1 && ( - - {t('datasetCreation.stepTwo.other')} - {files.length - 1} - {t('datasetCreation.stepTwo.fileUnit')} - - )} -
- - )} - {dataSourceType === DataSourceType.NOTION && ( - <> -
{t('datasetCreation.stepTwo.notionSource')}
-
- +
+
+ {rules.map(rule => ( +
{ + ruleChangeHandle(rule.id) + }}> + - {notionPages[0]?.page_name} - {notionPages.length > 1 && ( - - {t('datasetCreation.stepTwo.other')} - {notionPages.length - 1} - {t('datasetCreation.stepTwo.notionUnit')} - - )} +
- - )} - {dataSourceType === DataSourceType.WEB && ( - <> -
{t('datasetCreation.stepTwo.websiteSource')}
-
- - {websitePages[0].source_url} - {websitePages.length > 1 && ( - - {t('datasetCreation.stepTwo.other')} - {websitePages.length - 1} - {t('datasetCreation.stepTwo.webpageUnit')} - - )} -
- - )} -
-
-
-
{t('datasetCreation.stepTwo.estimateSegment')}
-
- { - fileIndexingEstimate - ? ( -
{formatNumber(fileIndexingEstimate.total_segments)}
- ) - : ( -
{t('datasetCreation.stepTwo.calculating')}
- ) - } + ))}
- {!isSetting - ? ( -
- -
- -
- ) - : ( -
- - + } + +
{t('datasetCreation.stepTwo.indexMode')}
+
+ {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.QUALIFIED)) && ( + + {t('datasetCreation.stepTwo.qualified')} + + {t('datasetCreation.stepTwo.recommend')} + + + {!hasSetIndexType && } + +
} + description={t('datasetCreation.stepTwo.qualifiedTip')} + icon={} + isActive={!hasSetIndexType && indexType === IndexingType.QUALIFIED} + disabled={!isAPIKeySet || hasSetIndexType} + onSwitched={() => { + if (isAPIKeySet) + setIndexType(IndexingType.QUALIFIED) + }} + /> + )} + + {(!hasSetIndexType || (hasSetIndexType && indexingType === IndexingType.ECONOMICAL)) && ( + <> + setIsQAConfirmDialogOpen(false)} className='w-[432px]'> +
+
+ {t('datasetCreation.stepTwo.qaSwitchHighQualityTipTitle')}
+
+ {t('datasetCreation.stepTwo.qaSwitchHighQualityTipContent')}
+
+
+ +
- )} -
+ + + + } + isActive={!hasSetIndexType && indexType === IndexingType.ECONOMICAL} + disabled={!isAPIKeySet || hasSetIndexType || docForm !== ChunkingMode.text} + ref={economyDomRef} + onSwitched={() => { + if (isAPIKeySet && docForm === ChunkingMode.text) + setIndexType(IndexingType.ECONOMICAL) + }} + /> + + +
+                    {
+                      docForm === ChunkingMode.qa
+                        ? t('datasetCreation.stepTwo.notAvailableForQA')
+                        : t('datasetCreation.stepTwo.notAvailableForParentChild')
+                    }
+
+
+ )}
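The not-available tooltip on the disabled economy card is driven by ahooks' useHover on economyDomRef. The pattern in isolation:

import { useRef } from 'react'
import { useHover } from 'ahooks'

// useHover reports whether the pointer is over the ref'd element, so the
// hint renders only while hovering — no mouseenter/mouseleave bookkeeping.
const HoverHint = ({ hint }: { hint: string }) => {
  const ref = useRef<HTMLDivElement>(null)
  const isHovering = useHover(ref)
  return (
    <div ref={ref}>
      {isHovering && <span>{hint}</span>}
    </div>
  )
}

export default HoverHint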
-
- - {showPreview &&
-
-
-
-
{t('datasetCreation.stepTwo.previewTitle')}
- {docForm === DocForm.QA && !previewSwitched && ( - - )} -
-
- -
+ {!hasSetIndexType && indexType === IndexingType.QUALIFIED && ( +
+
+
+
- {docForm === DocForm.QA && !previewSwitched && ( -
- {t('datasetCreation.stepTwo.previewSwitchTipStart')} - {t('datasetCreation.stepTwo.previewSwitchTipEnd')} + {t('datasetCreation.stepTwo.highQualityTip')} +
+ )} + {hasSetIndexType && indexType === IndexingType.ECONOMICAL && ( +
+ {t('datasetCreation.stepTwo.indexSettingTip')} + {t('datasetCreation.stepTwo.datasetSettingLink')} +
+ )} + {/* Embedding model */} + {indexType === IndexingType.QUALIFIED && ( +
+
{t('datasetSettings.form.embeddingModel')}
+ { + setEmbeddingModel(model) + }} + /> + {!!datasetId && ( +
+ {t('datasetCreation.stepTwo.indexSettingTip')} + {t('datasetCreation.stepTwo.datasetSettingLink')}
)}
-
- {previewSwitched && docForm === DocForm.QA && fileIndexingEstimate?.qa_preview && ( - <> - {fileIndexingEstimate?.qa_preview.map((item, index) => ( - - ))} - - )} - {(docForm === DocForm.TEXT || !previewSwitched) && fileIndexingEstimate?.preview && ( - <> - {fileIndexingEstimate?.preview.map((item, index) => ( - - ))} - - )} - {previewSwitched && docForm === DocForm.QA && !fileIndexingEstimate?.qa_preview && ( -
- + )} + + {/* Retrieval Method Config */} +
+ {!datasetId + ? ( +
+
{t('datasetSettings.form.retrievalSetting.title')}
+
+ {t('datasetSettings.form.retrievalSetting.learnMore')} + {t('datasetSettings.form.retrievalSetting.longDescription')} +
- )} - {!previewSwitched && !fileIndexingEstimate?.preview && ( -
- + ) + : ( +
+
{t('datasetSettings.form.retrievalSetting.title')}
)} + +
+ { + getIndexing_technique() === IndexingType.QUALIFIED + ? ( + + ) + : ( + + ) + }
-
} - {!showPreview && ( -
-
- {t('datasetCreation.stepTwo.sideTipTitle')}
- {t('datasetCreation.stepTwo.sideTipP1')}
- {t('datasetCreation.stepTwo.sideTipP2')}
- {t('datasetCreation.stepTwo.sideTipP3')}
- {t('datasetCreation.stepTwo.sideTipP4')}
-
+
+ + {!isSetting + ? ( +
+ + +
+ ) + : ( +
+ + +
+ )} +
+ { }} footer={null}> + +
+ {dataSourceType === DataSourceType.FILE + && >} + onChange={(selected) => { + currentEstimateMutation.reset() + setPreviewFile(selected) + currentEstimateMutation.mutate() + }} + // when it is from setting, it just has one file + value={isSetting ? (files[0]! as Required) : previewFile} + /> + } + {dataSourceType === DataSourceType.NOTION + && ({ + id: page.page_id, + name: page.page_name, + extension: 'md', + })) + } + onChange={(selected) => { + currentEstimateMutation.reset() + const selectedPage = notionPages.find(page => page.page_id === selected.id) + setPreviewNotionPage(selectedPage!) + currentEstimateMutation.mutate() + }} + value={{ + id: previewNotionPage?.page_id || '', + name: previewNotionPage?.page_name || '', + extension: 'md', + }} + /> + } + {dataSourceType === DataSourceType.WEB + && ({ + id: page.source_url, + name: page.title, + extension: 'md', + })) + } + onChange={(selected) => { + currentEstimateMutation.reset() + const selectedPage = websitePages.find(page => page.source_url === selected.id) + setPreviewWebsitePage(selectedPage!) + currentEstimateMutation.mutate() + }} + value={ + { + id: previewWebsitePage?.source_url || '', + name: previewWebsitePage?.title || '', + extension: 'md', + } + } + /> + } + { + currentDocForm !== ChunkingMode.qa + && + } +
+ } + className={cn('flex shrink-0 w-1/2 p-4 pr-0 relative h-full', isMobile && 'w-full max-w-[524px]')} + mainClassName='space-y-6' + > + {currentDocForm === ChunkingMode.qa && estimate?.qa_preview && ( + estimate?.qa_preview.map((item, index) => ( + + + + )) + )} + {currentDocForm === ChunkingMode.text && estimate?.preview && ( + estimate?.preview.map((item, index) => ( + + {item.content} + + )) + )} + {currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && ( + estimate?.preview?.map((item, index) => { + const indexForLabel = index + 1 + return ( + + + {item.child_chunks.map((child, index) => { + const indexForLabel = index + 1 + return ( + + ) + })} + + + ) + }) + )} + {currentEstimateMutation.isIdle && ( +
+
+ {t('datasetCreation.stepTwo.previewChunkTip')}
+
-
- )} + )} + {currentEstimateMutation.isPending && ( +
+ {Array.from({ length: 10 }, (_, i) => ( + + + + + + + + + + + ))} +
+ )} +
) diff --git a/web/app/components/datasets/create/step-two/inputs.tsx b/web/app/components/datasets/create/step-two/inputs.tsx new file mode 100644 index 00000000000000..4231f6242dca20 --- /dev/null +++ b/web/app/components/datasets/create/step-two/inputs.tsx @@ -0,0 +1,77 @@ +import type { FC, PropsWithChildren, ReactNode } from 'react' +import { useTranslation } from 'react-i18next' +import type { InputProps } from '@/app/components/base/input' +import Input from '@/app/components/base/input' +import Tooltip from '@/app/components/base/tooltip' +import type { InputNumberProps } from '@/app/components/base/input-number' +import { InputNumber } from '@/app/components/base/input-number' + +const TextLabel: FC = (props) => { + return +} + +const FormField: FC> = (props) => { + return
+ {props.label} + {props.children} +
+} + +export const DelimiterInput: FC = (props) => { + const { t } = useTranslation() + return + {t('datasetCreation.stepTwo.separator')} + + {props.tooltip || t('datasetCreation.stepTwo.separatorTip')} +
+ } + /> +
}> + + +} + +export const MaxLengthInput: FC = (props) => { + const { t } = useTranslation() + return + {t('datasetCreation.stepTwo.maxLength')} +
}> + + +} + +export const OverlapInput: FC = (props) => { + const { t } = useTranslation() + return + {t('datasetCreation.stepTwo.overlap')} + + {t('datasetCreation.stepTwo.overlapTip')} +
+ } + /> +
}> + + +} diff --git a/web/app/components/datasets/create/step-two/language-select/index.tsx b/web/app/components/datasets/create/step-two/language-select/index.tsx index 41f3e0abb55b6e..9cbf1a40d133fc 100644 --- a/web/app/components/datasets/create/step-two/language-select/index.tsx +++ b/web/app/components/datasets/create/step-two/language-select/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import { RiArrowDownSLine } from '@remixicon/react' +import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' import cn from '@/utils/classnames' import Popover from '@/app/components/base/popover' import { languages } from '@/i18n/language' @@ -22,25 +22,40 @@ const LanguageSelect: FC = ({ manualClose trigger='click' disabled={disabled} + popupClassName='z-20' htmlContent={ -
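Taken together, the three field components defined in inputs.tsx compose like this — the state names and the InputNumber onChange signature are illustrative:

import { useState } from 'react'
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'

// Illustrative composition: DelimiterInput is event-based (it wraps Input),
// while the numeric fields pass values directly (they wrap InputNumber).
const ChunkSettings = () => {
  const [separator, setSeparator] = useState('\\n\\n')
  const [maxLength, setMaxLength] = useState(1024)
  const [overlap, setOverlap] = useState(50)
  return (
    <>
      <DelimiterInput value={separator} onChange={e => setSeparator(e.target.value)} />
      <MaxLengthInput unit='tokens' value={maxLength} onChange={v => setMaxLength(v ?? 0)} />
      <OverlapInput unit='tokens' min={1} value={overlap} onChange={v => setOverlap(v ?? 1)} />
    </>
  )
}

export default ChunkSettings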
+
{languages.filter(language => language.supported).map(({ prompt_name }) => (
onSelect(prompt_name)}>{prompt_name} + className='w-full py-2 px-3 inline-flex items-center justify-between hover:bg-state-base-hover rounded-lg cursor-pointer' + onClick={() => onSelect(prompt_name)} + > + {prompt_name} + {(currentLanguage === prompt_name) && }
))}
} btnElement={ -
- {currentLanguage} - +
+ + {currentLanguage} + +
} - btnClassName={open => cn('!border-0 !px-0 !py-0 !bg-inherit !hover:bg-inherit', open ? 'text-blue-600' : 'text-gray-500')} - className='!w-[120px] h-fit !z-20 !translate-x-0 !left-[-16px]' + btnClassName={() => cn( + '!border-0 rounded-md !px-1.5 !py-1 !mx-1 !bg-components-button-tertiary-bg !hover:bg-components-button-tertiary-bg', + disabled ? 'bg-components-button-tertiary-bg-disabled' : '', + )} + className='!w-[140px] h-fit !z-20 !translate-x-0 !left-1' /> ) } diff --git a/web/app/components/datasets/create/step-two/option-card.tsx b/web/app/components/datasets/create/step-two/option-card.tsx new file mode 100644 index 00000000000000..d0efdaabb1cdb8 --- /dev/null +++ b/web/app/components/datasets/create/step-two/option-card.tsx @@ -0,0 +1,98 @@ +import { type ComponentProps, type FC, type ReactNode, forwardRef } from 'react' +import Image from 'next/image' +import classNames from '@/utils/classnames' + +const TriangleArrow: FC> = props => ( + + + +) + +type OptionCardHeaderProps = { + icon: ReactNode + title: ReactNode + description: string + isActive?: boolean + activeClassName?: string + effectImg?: string +} + +export const OptionCardHeader: FC = (props) => { + const { icon, title, description, isActive, activeClassName, effectImg } = props + return
+
+ {isActive && effectImg && } +
+
+ {icon} +
+
+
+ +
+
{title}
+
{description}
+
+
+} + +type OptionCardProps = { + icon: ReactNode + className?: string + activeHeaderClassName?: string + title: ReactNode + description: string + isActive?: boolean + actions?: ReactNode + effectImg?: string + onSwitched?: () => void + noHighlight?: boolean + disabled?: boolean +} & Omit, 'title' | 'onClick'> + +export const OptionCard: FC = forwardRef((props, ref) => { + const { icon, className, title, description, isActive, children, actions, activeHeaderClassName, style, effectImg, onSwitched, noHighlight, disabled, ...rest } = props + return
{ + if (!isActive && !disabled) + onSwitched?.() + }} + {...rest} + ref={ref} + > + + {/** Body */} + {isActive && (children || actions) &&
+ {children} + {actions &&
+ {actions} +
+ } +
} +
+}) + +OptionCard.displayName = 'OptionCard' diff --git a/web/app/components/datasets/create/stepper/index.tsx b/web/app/components/datasets/create/stepper/index.tsx new file mode 100644 index 00000000000000..317c1a76eecf57 --- /dev/null +++ b/web/app/components/datasets/create/stepper/index.tsx @@ -0,0 +1,27 @@ +import { type FC, Fragment } from 'react' +import type { Step } from './step' +import { StepperStep } from './step' + +export type StepperProps = { + steps: Step[] + activeIndex: number +} + +export const Stepper: FC = (props) => { + const { steps, activeIndex } = props + return
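OptionCard is declared with forwardRef because callers attach DOM-level behaviour to the whole card — the economy index card above receives ref={economyDomRef} for hover tracking. A reduced sketch of the pattern:

import { forwardRef } from 'react'
import type { ReactNode } from 'react'

// Forwarding the ref to the outermost element lets a parent observe the
// card directly (e.g. with ahooks useHover) without an extra wrapper div.
type CardProps = {
  title: ReactNode
  children?: ReactNode
  onClick?: () => void
}

const Card = forwardRef<HTMLDivElement, CardProps>(({ title, children, onClick }, ref) => (
  <div ref={ref} onClick={onClick}>
    <header>{title}</header>
    {children}
  </div>
))

Card.displayName = 'Card'

export default Card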
+ {steps.map((step, index) => { + const isLast = index === steps.length - 1 + return ( + + + {!isLast &&
} + + ) + })} +
+} diff --git a/web/app/components/datasets/create/stepper/step.tsx b/web/app/components/datasets/create/stepper/step.tsx new file mode 100644 index 00000000000000..c230de1a6e748b --- /dev/null +++ b/web/app/components/datasets/create/stepper/step.tsx @@ -0,0 +1,46 @@ +import type { FC } from 'react' +import classNames from '@/utils/classnames' + +export type Step = { + name: string +} + +export type StepperStepProps = Step & { + index: number + activeIndex: number +} + +export const StepperStep: FC = (props) => { + const { name, activeIndex, index } = props + const isActive = index === activeIndex + const isDisabled = activeIndex < index + const label = isActive ? `STEP ${index + 1}` : `${index + 1}` + return
+
+
+ {label} +
+
+
{name}
+
+} diff --git a/web/app/components/datasets/create/top-bar/index.tsx b/web/app/components/datasets/create/top-bar/index.tsx new file mode 100644 index 00000000000000..20ba7158db5ed8 --- /dev/null +++ b/web/app/components/datasets/create/top-bar/index.tsx @@ -0,0 +1,41 @@ +import type { FC } from 'react' +import { RiArrowLeftLine } from '@remixicon/react' +import Link from 'next/link' +import { useTranslation } from 'react-i18next' +import { Stepper, type StepperProps } from '../stepper' +import classNames from '@/utils/classnames' + +export type TopbarProps = Pick & { + className?: string +} + +const STEP_T_MAP: Record = { + 1: 'datasetCreation.steps.one', + 2: 'datasetCreation.steps.two', + 3: 'datasetCreation.steps.three', +} + +export const Topbar: FC = (props) => { + const { className, ...rest } = props + const { t } = useTranslation() + return
+ +
+ +
+
+ {t('datasetCreation.steps.header.creation')}
+
+ ({ + name: t(STEP_T_MAP[i + 1]), + }))} + {...rest} + /> +
+
+} diff --git a/web/app/components/datasets/create/website/base/error-message.tsx b/web/app/components/datasets/create/website/base/error-message.tsx index aa337ec4bf5323..f061c4624e90f2 100644 --- a/web/app/components/datasets/create/website/base/error-message.tsx +++ b/web/app/components/datasets/create/website/base/error-message.tsx @@ -18,7 +18,7 @@ const ErrorMessage: FC = ({ return (
- +
{title}
{errorMsg && ( diff --git a/web/app/components/datasets/create/website/jina-reader/index.tsx b/web/app/components/datasets/create/website/jina-reader/index.tsx index 51d77d712140b7..1c133f935c076b 100644 --- a/web/app/components/datasets/create/website/jina-reader/index.tsx +++ b/web/app/components/datasets/create/website/jina-reader/index.tsx @@ -94,7 +94,6 @@ const JinaReader: FC = ({ const waitForCrawlFinished = useCallback(async (jobId: string) => { try { const res = await checkJinaReaderTaskStatus(jobId) as any - console.log('res', res) if (res.status === 'completed') { return { isError: false, diff --git a/web/app/components/datasets/create/website/preview.tsx b/web/app/components/datasets/create/website/preview.tsx index 65abe83ed771ac..5180a834423014 100644 --- a/web/app/components/datasets/create/website/preview.tsx +++ b/web/app/components/datasets/create/website/preview.tsx @@ -18,7 +18,7 @@ const WebsitePreview = ({ const { t } = useTranslation() return ( -
+
{t('datasetCreation.stepOne.pagePreview')} @@ -32,7 +32,7 @@ const WebsitePreview = ({
{payload.source_url}
-
{payload.markdown}
+
{payload.markdown}
) diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx index 36216aa7c89658..6602244a480a49 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-downloader.tsx @@ -7,7 +7,7 @@ import { import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { Download02 as DownloadIcon } from '@/app/components/base/icons/src/vender/solid/general' -import { DocForm } from '@/models/datasets' +import { ChunkingMode } from '@/models/datasets' import I18n from '@/context/i18n' import { LanguagesSupported } from '@/i18n/language' @@ -32,18 +32,18 @@ const CSV_TEMPLATE_CN = [ ['内容 2'], ] -const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => { +const CSVDownload: FC<{ docForm: ChunkingMode }> = ({ docForm }) => { const { t } = useTranslation() const { locale } = useContext(I18n) const { CSVDownloader, Type } = useCSVDownloader() const getTemplate = () => { if (locale === LanguagesSupported[1]) { - if (docForm === DocForm.QA) + if (docForm === ChunkingMode.qa) return CSV_TEMPLATE_QA_CN return CSV_TEMPLATE_CN } - if (docForm === DocForm.QA) + if (docForm === ChunkingMode.qa) return CSV_TEMPLATE_QA_EN return CSV_TEMPLATE_EN } @@ -52,7 +52,7 @@ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => {
{t('share.generation.csvStructureTitle')}
- {docForm === DocForm.QA && ( + {docForm === ChunkingMode.qa && ( @@ -72,7 +72,7 @@ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => {
)} - {docForm === DocForm.TEXT && ( + {docForm === ChunkingMode.text && ( @@ -97,7 +97,7 @@ const CSVDownload: FC<{ docForm: DocForm }> = ({ docForm }) => { bom={true} data={getTemplate()} > -
+
{t('datasetDocuments.list.batchModal.template')}
diff --git a/web/app/components/datasets/documents/detail/batch-modal/index.tsx b/web/app/components/datasets/documents/detail/batch-modal/index.tsx index 139a364cb40292..c666ba67152988 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/index.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/index.tsx @@ -7,11 +7,11 @@ import CSVUploader from './csv-uploader' import CSVDownloader from './csv-downloader' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import type { DocForm } from '@/models/datasets' +import type { ChunkingMode } from '@/models/datasets' export type IBatchModalProps = { isShow: boolean - docForm: DocForm + docForm: ChunkingMode onCancel: () => void onConfirm: (file: File) => void } diff --git a/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx b/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx deleted file mode 100644 index 7b510bcf21b626..00000000000000 --- a/web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx +++ /dev/null @@ -1,98 +0,0 @@ -import type { CSSProperties, FC } from 'react' -import React from 'react' -import { FixedSizeList as List } from 'react-window' -import InfiniteLoader from 'react-window-infinite-loader' -import SegmentCard from './SegmentCard' -import s from './style.module.css' -import type { SegmentDetailModel } from '@/models/datasets' - -type IInfiniteVirtualListProps = { - hasNextPage?: boolean // Are there more items to load? (This information comes from the most recent API request.) - isNextPageLoading: boolean // Are we currently loading a page of items? (This may be an in-flight flag in your Redux store for example.) - items: Array // Array of items loaded so far. - loadNextPage: () => Promise // Callback function responsible for loading the next page of items. - onClick: (detail: SegmentDetailModel) => void - onChangeSwitch: (segId: string, enabled: boolean) => Promise - onDelete: (segId: string) => Promise - archived?: boolean - embeddingAvailable: boolean -} - -const InfiniteVirtualList: FC = ({ - hasNextPage, - isNextPageLoading, - items, - loadNextPage, - onClick: onClickCard, - onChangeSwitch, - onDelete, - archived, - embeddingAvailable, -}) => { - // If there are more items to be loaded then add an extra row to hold a loading indicator. - const itemCount = hasNextPage ? items.length + 1 : items.length - - // Only load 1 page of items at a time. - // Pass an empty callback to InfiniteLoader in case it asks us to load more than once. - const loadMoreItems = isNextPageLoading ? () => { } : loadNextPage - - // Every row is loaded except for our loading indicator row. - const isItemLoaded = (index: number) => !hasNextPage || index < items.length - - // Render an item or a loading indicator. - const Item = ({ index, style }: { index: number; style: CSSProperties }) => { - let content - if (!isItemLoaded(index)) { - content = ( - <> - {[1, 2, 3].map(v => ( - - ))} - - ) - } - else { - content = items[index].map(segItem => ( - onClickCard(segItem)} - onChangeSwitch={onChangeSwitch} - onDelete={onDelete} - loading={false} - archived={archived} - embeddingAvailable={embeddingAvailable} - /> - )) - } - - return ( -
- {content} -
- ) - } - - return ( - - {({ onItemsRendered, ref }) => ( - - {Item} - - )} - - ) -} -export default InfiniteVirtualList diff --git a/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx b/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx index 5b76acc9360c69..264d62b68a4f32 100644 --- a/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx +++ b/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx @@ -6,9 +6,9 @@ import { RiDeleteBinLine, } from '@remixicon/react' import { StatusItem } from '../../list' -import { DocumentTitle } from '../index' +import style from '../../style.module.css' import s from './style.module.css' -import { SegmentIndexTag } from './index' +import { SegmentIndexTag } from './common/segment-index-tag' import cn from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' import Switch from '@/app/components/base/switch' @@ -31,6 +31,22 @@ const ProgressBar: FC<{ percent: number; loading: boolean }> = ({ percent, loadi ) } +type DocumentTitleProps = { + extension?: string + name?: string + iconCls?: string + textCls?: string + wrapperCls?: string +} + +const DocumentTitle: FC = ({ extension, name, iconCls, textCls, wrapperCls }) => { + const localExtension = extension?.toLowerCase() || name?.split('.')?.pop()?.toLowerCase() + return
+
+ {name || '--'} +
+} + export type UsageScene = 'doc' | 'hitTesting' type ISegmentCardProps = { diff --git a/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx b/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx new file mode 100644 index 00000000000000..085bfddc163415 --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx @@ -0,0 +1,134 @@ +import React, { type FC, useMemo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { + RiCloseLine, + RiExpandDiagonalLine, +} from '@remixicon/react' +import ActionButtons from './common/action-buttons' +import ChunkContent from './common/chunk-content' +import Dot from './common/dot' +import { SegmentIndexTag } from './common/segment-index-tag' +import { useSegmentListContext } from './index' +import type { ChildChunkDetail, ChunkingMode } from '@/models/datasets' +import { useEventEmitterContextContext } from '@/context/event-emitter' +import { formatNumber } from '@/utils/format' +import classNames from '@/utils/classnames' +import Divider from '@/app/components/base/divider' +import { formatTime } from '@/utils/time' + +type IChildSegmentDetailProps = { + chunkId: string + childChunkInfo?: Partial & { id: string } + onUpdate: (segmentId: string, childChunkId: string, content: string) => void + onCancel: () => void + docForm: ChunkingMode +} + +/** + * Show all the contents of the segment + */ +const ChildSegmentDetail: FC = ({ + chunkId, + childChunkInfo, + onUpdate, + onCancel, + docForm, +}) => { + const { t } = useTranslation() + const [content, setContent] = useState(childChunkInfo?.content || '') + const { eventEmitter } = useEventEmitterContextContext() + const [loading, setLoading] = useState(false) + const fullScreen = useSegmentListContext(s => s.fullScreen) + const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen) + + eventEmitter?.useSubscription((v) => { + if (v === 'update-child-segment') + setLoading(true) + if (v === 'update-child-segment-done') + setLoading(false) + }) + + const handleCancel = () => { + onCancel() + setContent(childChunkInfo?.content || '') + } + + const handleSave = () => { + onUpdate(chunkId, childChunkInfo?.id || '', content) + } + + const wordCountText = useMemo(() => { + const count = content.length + return `${formatNumber(count)} ${t('datasetDocuments.segment.characters', { count })}` + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [content.length]) + + const EditTimeText = useMemo(() => { + const timeText = formatTime({ + date: (childChunkInfo?.updated_at ?? 0) * 1000, + dateFormat: 'MM/DD/YYYY h:mm:ss', + }) + return `${t('datasetDocuments.segment.editedAt')} ${timeText}` + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [childChunkInfo?.updated_at]) + + return ( +
+
+
+
{t('datasetDocuments.segment.editChildChunk')}
+
+ + + {wordCountText} + + + {EditTimeText} + +
+
+
+ {fullScreen && ( + <> + + + + )} +
+ +
+
+ +
+
+
+
+
+ setContent(content)} + isEditMode={true} + /> +
+
+ {!fullScreen && ( +
+ +
+ )} +
+ ) +} + +export default React.memo(ChildSegmentDetail) diff --git a/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx b/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx new file mode 100644 index 00000000000000..1615ea98cf045a --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx @@ -0,0 +1,195 @@ +import { type FC, useMemo, useState } from 'react' +import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { EditSlice } from '../../../formatted-text/flavours/edit-slice' +import { useDocumentContext } from '../index' +import { FormattedText } from '../../../formatted-text/formatted' +import Empty from './common/empty' +import FullDocListSkeleton from './skeleton/full-doc-list-skeleton' +import { useSegmentListContext } from './index' +import type { ChildChunkDetail } from '@/models/datasets' +import Input from '@/app/components/base/input' +import classNames from '@/utils/classnames' +import Divider from '@/app/components/base/divider' +import { formatNumber } from '@/utils/format' + +type IChildSegmentCardProps = { + childChunks: ChildChunkDetail[] + parentChunkId: string + handleInputChange?: (value: string) => void + handleAddNewChildChunk?: (parentChunkId: string) => void + enabled: boolean + onDelete?: (segId: string, childChunkId: string) => Promise + onClickSlice?: (childChunk: ChildChunkDetail) => void + total?: number + inputValue?: string + onClearFilter?: () => void + isLoading?: boolean + focused?: boolean +} + +const ChildSegmentList: FC = ({ + childChunks, + parentChunkId, + handleInputChange, + handleAddNewChildChunk, + enabled, + onDelete, + onClickSlice, + total, + inputValue, + onClearFilter, + isLoading, + focused = false, +}) => { + const { t } = useTranslation() + const parentMode = useDocumentContext(s => s.parentMode) + const currChildChunk = useSegmentListContext(s => s.currChildChunk) + + const [collapsed, setCollapsed] = useState(true) + + const toggleCollapse = () => { + setCollapsed(!collapsed) + } + + const isParagraphMode = useMemo(() => { + return parentMode === 'paragraph' + }, [parentMode]) + + const isFullDocMode = useMemo(() => { + return parentMode === 'full-doc' + }, [parentMode]) + + const contentOpacity = useMemo(() => { + return (enabled || focused) ? '' : 'opacity-50 group-hover/card:opacity-100' + }, [enabled, focused]) + + const totalText = useMemo(() => { + const isSearch = inputValue !== '' && isFullDocMode + if (!isSearch) { + const text = isFullDocMode + ? !total + ? '--' + : formatNumber(total) + : formatNumber(childChunks.length) + const count = isFullDocMode + ? text === '--' + ? 0 + : total + : childChunks.length + return `${text} ${t('datasetDocuments.segment.childChunks', { count })}` + } + else { + const text = !total ? '--' : formatNumber(total) + const count = text === '--' ? 0 : total + return `${count} ${t('datasetDocuments.segment.searchResults', { count })}` + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [isFullDocMode, total, childChunks.length, inputValue]) + + return ( +
+ {isFullDocMode ? : null} +
+
{ + event.stopPropagation() + toggleCollapse() + }} + > + { + isParagraphMode + ? collapsed + ? ( + + ) + : () + : null + } + {totalText} + · + +
+ {isFullDocMode + ? handleInputChange?.(e.target.value)} + onClear={() => handleInputChange?.('')} + /> + : null} +
+ {isLoading ? : null} + {((isFullDocMode && !isLoading) || !collapsed) + ?
+ {isParagraphMode && ( +
+ +
+ )} + {childChunks.length > 0 + ? + {childChunks.map((childChunk) => { + const edited = childChunk.updated_at !== childChunk.created_at + const focused = currChildChunk?.childChunkInfo?.id === childChunk.id + return onDelete?.(childChunk.segment_id, childChunk.id)} + labelClassName={focused ? 'bg-state-accent-solid text-text-primary-on-surface' : ''} + labelInnerClassName={'text-[10px] font-semibold align-bottom leading-6'} + contentClassName={classNames('!leading-6', focused ? 'bg-state-accent-hover-alt text-text-primary' : '')} + showDivider={false} + onClick={(e) => { + e.stopPropagation() + onClickSlice?.(childChunk) + }} + offsetOptions={({ rects }) => { + return { + mainAxis: isFullDocMode ? -rects.floating.width : 12 - rects.floating.width, + crossAxis: (20 - rects.floating.height) / 2, + } + }} + /> + })} + + : inputValue !== '' + ?
+ +
+ : null + } +
+ : null} +
+ ) +} + +export default ChildSegmentList diff --git a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx new file mode 100644 index 00000000000000..1238d98a9c5025 --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx @@ -0,0 +1,86 @@ +import React, { type FC, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { useKeyPress } from 'ahooks' +import { useDocumentContext } from '../../index' +import Button from '@/app/components/base/button' +import { getKeyboardKeyCodeBySystem, getKeyboardKeyNameBySystem } from '@/app/components/workflow/utils' + +type IActionButtonsProps = { + handleCancel: () => void + handleSave: () => void + loading: boolean + actionType?: 'edit' | 'add' + handleRegeneration?: () => void + isChildChunk?: boolean +} + +const ActionButtons: FC = ({ + handleCancel, + handleSave, + loading, + actionType = 'edit', + handleRegeneration, + isChildChunk = false, +}) => { + const { t } = useTranslation() + const mode = useDocumentContext(s => s.mode) + const parentMode = useDocumentContext(s => s.parentMode) + + useKeyPress(['esc'], (e) => { + e.preventDefault() + handleCancel() + }) + + useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.s`, (e) => { + e.preventDefault() + if (loading) + return + handleSave() + } + , { exactMatch: true, useCapture: true }) + + const isParentChildParagraphMode = useMemo(() => { + return mode === 'hierarchical' && parentMode === 'paragraph' + }, [mode, parentMode]) + + return ( +
+ + {(isParentChildParagraphMode && actionType === 'edit' && !isChildChunk) + ? + : null + } + +
+ ) +} + +ActionButtons.displayName = 'ActionButtons' + +export default React.memo(ActionButtons) diff --git a/web/app/components/datasets/documents/detail/completed/common/add-another.tsx b/web/app/components/datasets/documents/detail/completed/common/add-another.tsx new file mode 100644 index 00000000000000..444560e55f7f6b --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/common/add-another.tsx @@ -0,0 +1,32 @@ +import React, { type FC } from 'react' +import { useTranslation } from 'react-i18next' +import classNames from '@/utils/classnames' +import Checkbox from '@/app/components/base/checkbox' + +type AddAnotherProps = { + className?: string + isChecked: boolean + onCheck: () => void +} + +const AddAnother: FC = ({ + className, + isChecked, + onCheck, +}) => { + const { t } = useTranslation() + + return ( +
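ActionButtons binds Esc and a platform-aware ctrl/⌘ + S through ahooks' useKeyPress. The two helpers it imports from the workflow utils behave roughly as follows — a sketch; the real implementations may differ in detail:

// Key *code* for useKeyPress filters, e.g. `${getKeyboardKeyCodeBySystem('ctrl')}.s`.
export const getKeyboardKeyCodeBySystem = (key: string): string => {
  const isMac = typeof navigator !== 'undefined' && /mac/i.test(navigator.userAgent)
  if (key === 'ctrl')
    return isMac ? 'meta' : 'ctrl'
  return key
}

// Key *name* for UI hints, e.g. showing ⌘ on macOS and Ctrl elsewhere.
export const getKeyboardKeyNameBySystem = (key: string): string => {
  const isMac = typeof navigator !== 'undefined' && /mac/i.test(navigator.userAgent)
  if (key === 'ctrl')
    return isMac ? '⌘' : 'Ctrl'
  return key
}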
+ + {t('datasetDocuments.segment.addAnother')} +
+ ) +} + +export default React.memo(AddAnother) diff --git a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx new file mode 100644 index 00000000000000..3dd3689b64fb43 --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx @@ -0,0 +1,103 @@ +import React, { type FC } from 'react' +import { RiArchive2Line, RiCheckboxCircleLine, RiCloseCircleLine, RiDeleteBinLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useBoolean } from 'ahooks' +import Divider from '@/app/components/base/divider' +import classNames from '@/utils/classnames' +import Confirm from '@/app/components/base/confirm' + +const i18nPrefix = 'dataset.batchAction' +type IBatchActionProps = { + className?: string + selectedIds: string[] + onBatchEnable: () => void + onBatchDisable: () => void + onBatchDelete: () => Promise + onArchive?: () => void + onCancel: () => void +} + +const BatchAction: FC = ({ + className, + selectedIds, + onBatchEnable, + onBatchDisable, + onArchive, + onBatchDelete, + onCancel, +}) => { + const { t } = useTranslation() + const [isShowDeleteConfirm, { + setTrue: showDeleteConfirm, + setFalse: hideDeleteConfirm, + }] = useBoolean(false) + const [isDeleting, { + setTrue: setIsDeleting, + }] = useBoolean(false) + + const handleBatchDelete = async () => { + setIsDeleting() + await onBatchDelete() + hideDeleteConfirm() + } + return ( +
+
+
+ + {selectedIds.length} + + {t(`${i18nPrefix}.selected`)} +
+ +
+ + +
+
+ + +
+ {onArchive && ( +
+ + +
+ )} +
+ + +
+ + + +
+ { + isShowDeleteConfirm && ( + + ) + } +
+ ) +} + +export default React.memo(BatchAction) diff --git a/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx b/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx new file mode 100644 index 00000000000000..e6403fa12fd0aa --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx @@ -0,0 +1,192 @@ +import React, { useEffect, useRef, useState } from 'react' +import type { ComponentProps, FC } from 'react' +import { useTranslation } from 'react-i18next' +import { ChunkingMode } from '@/models/datasets' +import classNames from '@/utils/classnames' + +type IContentProps = ComponentProps<'textarea'> + +const Textarea: FC = React.memo(({ + value, + placeholder, + className, + disabled, + ...rest +}) => { + return ( +