From c91e8b1737f28c9997ce7bd9c1e8ac4e377ef1b7 Mon Sep 17 00:00:00 2001
From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com>
Date: Tue, 24 Dec 2024 16:26:47 +0800
Subject: [PATCH 01/65] fix: modal bg color (#12042)
---
web/app/components/base/modal/index.tsx | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/web/app/components/base/modal/index.tsx b/web/app/components/base/modal/index.tsx
index 3040cdb00b502a..26cde5fce3dd6a 100644
--- a/web/app/components/base/modal/index.tsx
+++ b/web/app/components/base/modal/index.tsx
@@ -39,7 +39,7 @@ export default function Modal({
leaveFrom="opacity-100"
leaveTo="opacity-0"
>
-
Date: Tue, 24 Dec 2024 18:38:51 +0800
Subject: [PATCH 02/65] feat: mypy for all type check (#10921)
---
.github/workflows/api-tests.yml | 6 +
api/commands.py | 13 +-
api/configs/feature/__init__.py | 14 +-
api/configs/middleware/__init__.py | 4 -
.../remote_settings_sources/apollo/client.py | 5 +-
api/constants/model_template.py | 3 +-
api/controllers/common/fields.py | 2 +-
api/controllers/console/__init__.py | 95 +++++++-
api/controllers/console/admin.py | 2 +-
api/controllers/console/apikey.py | 23 +-
.../console/app/advanced_prompt_template.py | 2 +-
api/controllers/console/app/agent.py | 2 +-
api/controllers/console/app/annotation.py | 6 +-
api/controllers/console/app/app.py | 4 +-
api/controllers/console/app/app_import.py | 4 +-
api/controllers/console/app/audio.py | 2 +-
api/controllers/console/app/completion.py | 4 +-
api/controllers/console/app/conversation.py | 15 +-
.../console/app/conversation_variables.py | 2 +-
api/controllers/console/app/generator.py | 4 +-
api/controllers/console/app/message.py | 6 +-
api/controllers/console/app/model_config.py | 13 +-
api/controllers/console/app/ops_trace.py | 2 +-
api/controllers/console/app/site.py | 6 +-
api/controllers/console/app/statistic.py | 4 +-
api/controllers/console/app/workflow.py | 2 +-
.../console/app/workflow_app_log.py | 4 +-
api/controllers/console/app/workflow_run.py | 4 +-
.../console/app/workflow_statistic.py | 4 +-
api/controllers/console/app/wraps.py | 2 +-
api/controllers/console/auth/activate.py | 6 +-
.../console/auth/data_source_bearer_auth.py | 4 +-
.../console/auth/data_source_oauth.py | 8 +-
.../console/auth/forgot_password.py | 6 +-
api/controllers/console/auth/login.py | 4 +-
api/controllers/console/auth/oauth.py | 7 +-
api/controllers/console/billing/billing.py | 4 +-
.../console/datasets/data_source.py | 4 +-
api/controllers/console/datasets/datasets.py | 6 +-
.../console/datasets/datasets_document.py | 10 +-
.../console/datasets/datasets_segments.py | 4 +-
api/controllers/console/datasets/external.py | 4 +-
.../console/datasets/hit_testing.py | 2 +-
.../console/datasets/hit_testing_base.py | 4 +-
api/controllers/console/datasets/website.py | 2 +-
api/controllers/console/explore/audio.py | 9 +-
api/controllers/console/explore/completion.py | 23 +-
.../console/explore/conversation.py | 32 +--
.../console/explore/installed_app.py | 11 +-
api/controllers/console/explore/message.py | 25 +-
api/controllers/console/explore/parameter.py | 2 +-
.../console/explore/recommended_app.py | 4 +-
.../console/explore/saved_message.py | 6 +-
api/controllers/console/explore/workflow.py | 9 +-
api/controllers/console/explore/wraps.py | 4 +-
api/controllers/console/extension.py | 4 +-
api/controllers/console/feature.py | 4 +-
api/controllers/console/files.py | 9 +-
api/controllers/console/init_validate.py | 2 +-
api/controllers/console/ping.py | 2 +-
api/controllers/console/remote_files.py | 4 +-
api/controllers/console/setup.py | 2 +-
api/controllers/console/tag/tags.py | 6 +-
api/controllers/console/version.py | 2 +-
api/controllers/console/workspace/account.py | 4 +-
.../workspace/load_balancing_config.py | 6 +-
api/controllers/console/workspace/members.py | 31 +--
.../console/workspace/model_providers.py | 9 +-
api/controllers/console/workspace/models.py | 6 +-
.../console/workspace/tool_providers.py | 4 +-
.../console/workspace/workspace.py | 14 +-
api/controllers/console/wraps.py | 6 +-
api/controllers/files/image_preview.py | 2 +-
api/controllers/files/tool_files.py | 2 +-
.../inner_api/workspace/workspace.py | 2 +-
api/controllers/inner_api/wraps.py | 6 +-
api/controllers/service_api/app/app.py | 2 +-
api/controllers/service_api/app/audio.py | 4 +-
api/controllers/service_api/app/completion.py | 2 +-
.../service_api/app/conversation.py | 4 +-
api/controllers/service_api/app/file.py | 2 +-
api/controllers/service_api/app/message.py | 4 +-
api/controllers/service_api/app/workflow.py | 4 +-
.../service_api/dataset/dataset.py | 2 +-
.../service_api/dataset/document.py | 2 +-
.../service_api/dataset/segment.py | 4 +-
api/controllers/service_api/index.py | 2 +-
api/controllers/service_api/wraps.py | 10 +-
api/controllers/web/app.py | 2 +-
api/controllers/web/audio.py | 4 +-
api/controllers/web/completion.py | 2 +-
api/controllers/web/conversation.py | 4 +-
api/controllers/web/feature.py | 2 +-
api/controllers/web/files.py | 4 +-
api/controllers/web/message.py | 4 +-
api/controllers/web/passport.py | 2 +-
api/controllers/web/remote_files.py | 2 +-
api/controllers/web/saved_message.py | 4 +-
api/controllers/web/site.py | 2 +-
api/controllers/web/workflow.py | 2 +-
api/controllers/web/wraps.py | 2 +-
api/core/agent/base_agent_runner.py | 47 ++--
api/core/agent/cot_agent_runner.py | 82 ++++---
api/core/agent/cot_chat_agent_runner.py | 6 +
api/core/agent/cot_completion_agent_runner.py | 21 +-
api/core/agent/entities.py | 2 +-
api/core/agent/fc_agent_runner.py | 46 ++--
.../agent/output_parser/cot_output_parser.py | 10 +-
.../easy_ui_based_app/dataset/manager.py | 2 +
.../easy_ui_based_app/model_config/manager.py | 2 +-
.../features/opening_statement/manager.py | 4 +-
.../app/apps/advanced_chat/app_generator.py | 5 +-
.../app_generator_tts_publisher.py | 23 +-
api/core/app/apps/advanced_chat/app_runner.py | 10 +-
.../advanced_chat/generate_task_pipeline.py | 28 ++-
.../app/apps/agent_chat/app_config_manager.py | 2 +-
api/core/app/apps/agent_chat/app_generator.py | 13 +-
api/core/app/apps/agent_chat/app_runner.py | 19 +-
.../agent_chat/generate_response_converter.py | 14 +-
api/core/app/apps/base_app_queue_manager.py | 2 +-
api/core/app/apps/base_app_runner.py | 45 ++--
api/core/app/apps/chat/app_generator.py | 11 +-
.../apps/chat/generate_response_converter.py | 10 +-
.../app/apps/completion/app_config_manager.py | 2 +-
api/core/app/apps/completion/app_generator.py | 18 +-
api/core/app/apps/completion/app_runner.py | 4 +-
.../completion/generate_response_converter.py | 10 +-
.../app/apps/message_based_app_generator.py | 12 +-
api/core/app/apps/workflow/app_generator.py | 2 +-
.../workflow/generate_response_converter.py | 12 +-
api/core/app/apps/workflow_app_runner.py | 14 +-
api/core/app/entities/app_invoke_entities.py | 4 +-
api/core/app/entities/queue_entities.py | 2 +-
api/core/app/entities/task_entities.py | 14 +-
.../annotation_reply/annotation_reply.py | 2 +-
.../app/features/rate_limiting/rate_limit.py | 2 +-
.../based_generate_task_pipeline.py | 2 +
.../easy_ui_based_generate_task_pipeline.py | 51 ++--
.../app/task_pipeline/message_cycle_manage.py | 2 +-
.../task_pipeline/workflow_cycle_manage.py | 22 +-
.../agent_tool_callback_handler.py | 2 +-
.../index_tool_callback_handler.py | 17 +-
api/core/entities/model_entities.py | 3 +-
api/core/entities/provider_configuration.py | 110 +++++----
.../api_based_extension_requestor.py | 6 +-
api/core/extension/extensible.py | 9 +-
api/core/extension/extension.py | 10 +-
api/core/external_data_tool/api/api.py | 3 +
.../external_data_tool/external_data_fetch.py | 24 +-
api/core/external_data_tool/factory.py | 10 +-
api/core/file/file_manager.py | 5 +-
api/core/file/tool_file_parser.py | 4 +-
.../helper/code_executor/code_executor.py | 16 +-
.../code_executor/jinja2/jinja2_formatter.py | 7 +-
.../code_executor/template_transformer.py | 3 +-
api/core/helper/lru_cache.py | 2 +-
api/core/helper/model_provider_cache.py | 2 +-
api/core/helper/moderation.py | 4 +-
api/core/helper/module_import_helper.py | 13 +-
api/core/helper/tool_parameter_cache.py | 2 +-
api/core/helper/tool_provider_cache.py | 2 +-
api/core/hosting_configuration.py | 14 +-
api/core/indexing_runner.py | 66 ++---
api/core/llm_generator/llm_generator.py | 89 ++++---
api/core/memory/token_buffer_memory.py | 2 +-
api/core/model_manager.py | 149 +++++++-----
.../callbacks/logging_callback.py | 13 +-
.../entities/message_entities.py | 3 +-
.../model_providers/__base/ai_model.py | 13 +-
.../__base/large_language_model.py | 14 +-
.../model_providers/__base/model_provider.py | 3 +-
.../__base/text_embedding_model.py | 6 +-
.../__base/tokenizers/gpt2_tokenzier.py | 4 +-
.../model_providers/__base/tts_model.py | 9 +-
.../azure_openai/speech2text/speech2text.py | 4 +-
.../model_providers/azure_openai/tts/tts.py | 2 +
.../model_providers/bedrock/llm/llm.py | 6 +-
.../model_providers/cohere/rerank/rerank.py | 4 +-
.../model_providers/fireworks/_common.py | 4 +-
.../text_embedding/text_embedding.py | 3 +-
.../model_providers/gitee_ai/_common.py | 2 +-
.../model_providers/gitee_ai/rerank/rerank.py | 4 +-
.../gitee_ai/text_embedding/text_embedding.py | 2 +-
.../model_providers/gitee_ai/tts/tts.py | 8 +-
.../model_providers/google/llm/llm.py | 2 +-
.../huggingface_hub/_common.py | 2 +-
.../huggingface_hub/llm/llm.py | 6 +-
.../text_embedding/text_embedding.py | 2 +-
.../model_providers/hunyuan/llm/llm.py | 12 +-
.../hunyuan/text_embedding/text_embedding.py | 10 +-
.../jina/text_embedding/jina_tokenizer.py | 4 +-
.../minimax/llm/chat_completion.py | 30 +--
.../minimax/llm/chat_completion_pro.py | 26 +-
.../model_providers/minimax/llm/types.py | 4 +-
.../nomic/text_embedding/text_embedding.py | 4 +-
.../model_providers/oci/llm/llm.py | 4 +-
.../oci/text_embedding/text_embedding.py | 2 +-
.../ollama/text_embedding/text_embedding.py | 1 +
.../model_providers/openai/_common.py | 4 +-
.../openai/moderation/moderation.py | 6 +-
.../model_providers/openai/openai.py | 3 +-
.../speech2text/speech2text.py | 1 +
.../text_embedding/text_embedding.py | 1 +
.../openai_api_compatible/tts/tts.py | 1 +
.../openllm/llm/openllm_generate.py | 16 +-
.../text_embedding/text_embedding.py | 7 +-
.../model_providers/replicate/_common.py | 2 +-
.../model_providers/replicate/llm/llm.py | 6 +-
.../text_embedding/text_embedding.py | 6 +-
.../model_providers/sagemaker/llm/llm.py | 8 +-
.../sagemaker/rerank/rerank.py | 3 +-
.../sagemaker/speech2text/speech2text.py | 3 +-
.../text_embedding/text_embedding.py | 3 +-
.../model_providers/sagemaker/tts/tts.py | 2 +-
.../model_providers/siliconflow/llm/llm.py | 2 +-
.../model_providers/spark/llm/llm.py | 4 +-
.../model_providers/togetherai/llm/llm.py | 3 +-
.../model_providers/tongyi/_common.py | 2 +-
.../model_providers/tongyi/llm/llm.py | 6 +-
.../model_providers/tongyi/rerank/rerank.py | 8 +-
.../tongyi/text_embedding/text_embedding.py | 2 +-
.../model_providers/tongyi/tts/tts.py | 8 +-
.../model_providers/upstage/_common.py | 4 +-
.../model_providers/upstage/llm/llm.py | 2 +-
.../upstage/text_embedding/text_embedding.py | 5 +-
.../model_providers/vertex_ai/_common.py | 2 +-
.../model_providers/vertex_ai/llm/llm.py | 2 +-
.../model_providers/vessl_ai/llm/llm.py | 4 +-
.../model_providers/volcengine_maas/client.py | 12 +-
.../volcengine_maas/legacy/errors.py | 3 +-
.../volcengine_maas/llm/llm.py | 2 +-
.../volcengine_maas/llm/models.py | 6 +-
.../model_providers/wenxin/llm/ernie_bot.py | 5 +-
.../wenxin/text_embedding/text_embedding.py | 9 +-
.../model_providers/xinference/llm/llm.py | 2 +-
.../xinference/rerank/rerank.py | 2 +-
.../xinference/speech2text/speech2text.py | 2 +-
.../text_embedding/text_embedding.py | 4 +-
.../model_providers/xinference/tts/tts.py | 7 +-
.../xinference/xinference_helper.py | 12 +-
.../model_providers/yi/llm/llm.py | 2 +-
.../model_providers/zhipuai/llm/llm.py | 6 +-
.../zhipuai/text_embedding/text_embedding.py | 2 +-
.../schema_validators/common_validator.py | 7 +-
api/core/model_runtime/utils/encoders.py | 3 +-
api/core/model_runtime/utils/helper.py | 3 +-
api/core/moderation/api/api.py | 14 +-
api/core/moderation/base.py | 4 +-
api/core/moderation/factory.py | 3 +-
api/core/moderation/input_moderation.py | 8 +-
api/core/moderation/keywords/keywords.py | 6 +-
.../openai_moderation/openai_moderation.py | 4 +
api/core/moderation/output_moderation.py | 2 +-
api/core/ops/entities/trace_entity.py | 5 +-
api/core/ops/langfuse_trace/langfuse_trace.py | 12 +-
.../entities/langsmith_trace_entity.py | 1 -
.../ops/langsmith_trace/langsmith_trace.py | 155 ++++++++++--
api/core/ops/ops_trace_manager.py | 52 ++--
api/core/prompt/advanced_prompt_transform.py | 37 +--
.../prompt/agent_history_prompt_transform.py | 2 +-
api/core/prompt/prompt_transform.py | 22 +-
api/core/prompt/simple_prompt_transform.py | 25 +-
api/core/prompt/utils/prompt_message_util.py | 8 +-
.../prompt/utils/prompt_template_parser.py | 3 +-
api/core/provider_manager.py | 37 ++-
.../rag/datasource/keyword/jieba/jieba.py | 37 ++-
.../jieba/jieba_keyword_table_handler.py | 4 +-
.../rag/datasource/keyword/keyword_base.py | 4 +-
api/core/rag/datasource/retrieval_service.py | 23 +-
.../vdb/analyticdb/analyticdb_vector.py | 35 +--
.../analyticdb/analyticdb_vector_openapi.py | 37 +--
.../vdb/analyticdb/analyticdb_vector_sql.py | 24 +-
.../rag/datasource/vdb/baidu/baidu_vector.py | 28 ++-
.../datasource/vdb/chroma/chroma_vector.py | 24 +-
.../vdb/couchbase/couchbase_vector.py | 24 +-
.../vdb/elasticsearch/elasticsearch_vector.py | 22 +-
.../datasource/vdb/lindorm/lindorm_vector.py | 33 ++-
.../datasource/vdb/milvus/milvus_vector.py | 22 +-
.../datasource/vdb/myscale/myscale_vector.py | 19 +-
.../vdb/oceanbase/oceanbase_vector.py | 4 +-
.../vdb/opensearch/opensearch_vector.py | 4 +-
.../rag/datasource/vdb/oracle/oraclevector.py | 42 ++--
.../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 14 +-
.../rag/datasource/vdb/pgvector/pgvector.py | 31 +--
.../datasource/vdb/qdrant/qdrant_vector.py | 21 +-
.../rag/datasource/vdb/relyt/relyt_vector.py | 26 +-
.../datasource/vdb/tencent/tencent_vector.py | 30 +--
.../tidb_on_qdrant/tidb_on_qdrant_vector.py | 31 ++-
.../vdb/tidb_on_qdrant/tidb_service.py | 11 +-
.../datasource/vdb/tidb_vector/tidb_vector.py | 12 +-
api/core/rag/datasource/vdb/vector_base.py | 11 +-
api/core/rag/datasource/vdb/vector_factory.py | 9 +-
.../vdb/vikingdb/vikingdb_vector.py | 11 +-
.../vdb/weaviate/weaviate_vector.py | 14 +-
api/core/rag/docstore/dataset_docstore.py | 9 +-
api/core/rag/embedding/cached_embedding.py | 16 +-
.../rag/extractor/entity/extract_setting.py | 2 +-
api/core/rag/extractor/excel_extractor.py | 8 +-
api/core/rag/extractor/extract_processor.py | 20 +-
.../rag/extractor/firecrawl/firecrawl_app.py | 21 +-
api/core/rag/extractor/html_extractor.py | 3 +-
api/core/rag/extractor/notion_extractor.py | 14 +-
api/core/rag/extractor/pdf_extractor.py | 6 +-
.../unstructured_eml_extractor.py | 2 +-
.../unstructured_epub_extractor.py | 3 +
.../unstructured_ppt_extractor.py | 4 +-
.../unstructured_pptx_extractor.py | 11 +-
api/core/rag/extractor/word_extractor.py | 6 +
.../index_processor/index_processor_base.py | 1 +
.../index_processor_factory.py | 2 +-
.../processor/paragraph_index_processor.py | 10 +-
.../processor/qa_index_processor.py | 27 +-
api/core/rag/rerank/rerank_model.py | 11 +-
api/core/rag/rerank/weight_rerank.py | 16 +-
api/core/rag/retrieval/dataset_retrieval.py | 111 +++++----
.../multi_dataset_function_call_router.py | 16 +-
.../router/multi_dataset_react_route.py | 22 +-
api/core/rag/splitter/fixed_text_splitter.py | 4 +-
api/core/rag/splitter/text_splitter.py | 4 +-
api/core/tools/entities/api_entities.py | 2 +-
api/core/tools/entities/tool_bundle.py | 2 +-
api/core/tools/entities/tool_entities.py | 12 +-
api/core/tools/provider/api_tool_provider.py | 70 +++---
api/core/tools/provider/app_tool_provider.py | 15 +-
api/core/tools/provider/builtin/_positions.py | 2 +-
.../provider/builtin/aippt/tools/aippt.py | 36 +--
.../builtin/arxiv/tools/arxiv_search.py | 2 +-
.../tools/provider/builtin/audio/tools/tts.py | 20 +-
.../builtin/aws/tools/apply_guardrail.py | 4 +-
.../aws/tools/lambda_translate_utils.py | 2 +-
.../builtin/aws/tools/lambda_yaml_to_json.py | 2 +-
.../aws/tools/sagemaker_text_rerank.py | 6 +-
.../builtin/aws/tools/sagemaker_tts.py | 4 +-
.../builtin/cogview/tools/cogvideo.py | 2 +-
.../builtin/cogview/tools/cogvideo_job.py | 2 +-
.../builtin/cogview/tools/cogview3.py | 2 +-
.../feishu_base/tools/search_records.py | 20 +-
.../feishu_base/tools/update_records.py | 12 +-
.../tools/add_event_attendees.py | 8 +-
.../feishu_calendar/tools/delete_event.py | 6 +-
.../tools/get_primary_calendar.py | 4 +
.../feishu_calendar/tools/list_events.py | 12 +-
.../feishu_calendar/tools/update_event.py | 14 +-
.../feishu_document/tools/create_document.py | 10 +-
.../tools/list_document_blocks.py | 6 +-
.../builtin/json_process/tools/delete.py | 2 +-
.../builtin/json_process/tools/insert.py | 2 +-
.../builtin/json_process/tools/parse.py | 2 +-
.../builtin/json_process/tools/replace.py | 2 +-
.../builtin/maths/tools/eval_expression.py | 2 +-
.../builtin/novitaai/_novita_tool_base.py | 2 +-
.../novitaai/tools/novitaai_createtile.py | 2 +-
.../novitaai/tools/novitaai_txt2img.py | 2 +-
.../tools/podcast_audio_generator.py | 2 +-
.../builtin/qrcode/tools/qrcode_generator.py | 8 +-
.../builtin/transcript/tools/transcript.py | 2 +-
.../builtin/twilio/tools/send_message.py | 2 +-
.../tools/provider/builtin/twilio/twilio.py | 4 +-
.../provider/builtin/vanna/tools/vanna.py | 5 +-
.../wikipedia/tools/wikipedia_search.py | 2 +-
.../provider/builtin/yahoo/tools/analytics.py | 2 +-
.../provider/builtin/yahoo/tools/news.py | 2 +-
.../provider/builtin/yahoo/tools/ticker.py | 2 +-
.../provider/builtin/youtube/tools/videos.py | 2 +-
.../tools/provider/builtin_tool_provider.py | 70 +++---
api/core/tools/provider/tool_provider.py | 63 ++---
.../tools/provider/workflow_tool_provider.py | 13 +-
api/core/tools/tool/api_tool.py | 14 +-
api/core/tools/tool/builtin_tool.py | 35 ++-
.../dataset_multi_retriever_tool.py | 13 +-
.../dataset_retriever_base_tool.py | 2 +-
.../dataset_retriever_tool.py | 57 +++--
api/core/tools/tool/dataset_retriever_tool.py | 11 +-
api/core/tools/tool/tool.py | 17 +-
api/core/tools/tool/workflow_tool.py | 14 +-
api/core/tools/tool_engine.py | 24 +-
api/core/tools/tool_label_manager.py | 8 +-
api/core/tools/tool_manager.py | 126 ++++++----
api/core/tools/utils/configuration.py | 18 +-
api/core/tools/utils/feishu_api_utils.py | 179 ++++++++------
api/core/tools/utils/lark_api_utils.py | 193 +++++++++------
api/core/tools/utils/message_transformer.py | 12 +-
.../tools/utils/model_invocation_utils.py | 23 +-
api/core/tools/utils/parser.py | 17 +-
api/core/tools/utils/web_reader_tool.py | 15 +-
.../utils/workflow_configuration_sync.py | 4 +-
api/core/tools/utils/yaml_utils.py | 2 +-
api/core/variables/variables.py | 3 +-
.../callbacks/workflow_logging_callback.py | 2 +-
api/core/workflow/entities/node_entities.py | 4 +-
.../condition_handlers/condition_handler.py | 2 +-
.../workflow/graph_engine/entities/graph.py | 16 +-
.../workflow/graph_engine/graph_engine.py | 55 +++--
.../nodes/answer/answer_stream_processor.py | 4 +-
.../nodes/answer/base_stream_processor.py | 8 +-
api/core/workflow/nodes/base/entities.py | 5 +-
api/core/workflow/nodes/code/code_node.py | 2 +-
api/core/workflow/nodes/code/entities.py | 2 +-
.../workflow/nodes/document_extractor/node.py | 7 +-
.../nodes/end/end_stream_generate_router.py | 5 +-
.../nodes/end/end_stream_processor.py | 2 +-
api/core/workflow/nodes/event/event.py | 2 +-
.../workflow/nodes/http_request/executor.py | 44 ++--
api/core/workflow/nodes/http_request/node.py | 8 +-
.../nodes/iteration/iteration_node.py | 15 +-
.../knowledge_retrieval_node.py | 16 +-
api/core/workflow/nodes/list_operator/node.py | 34 +--
api/core/workflow/nodes/llm/node.py | 17 +-
api/core/workflow/nodes/loop/loop_node.py | 6 +-
.../nodes/parameter_extractor/entities.py | 4 +-
.../parameter_extractor_node.py | 16 +-
.../nodes/parameter_extractor/prompts.py | 4 +-
.../question_classifier_node.py | 12 +-
api/core/workflow/nodes/tool/tool_node.py | 9 +-
.../nodes/variable_assigner/v1/node.py | 2 +
.../nodes/variable_assigner/v2/node.py | 6 +-
api/core/workflow/workflow_entry.py | 15 +-
.../event_handlers/create_document_index.py | 2 +-
.../create_site_record_when_app_created.py | 29 +--
.../deduct_quota_when_message_created.py | 2 +-
...rameters_cache_when_sync_draft_workflow.py | 5 +-
...aset_join_when_app_model_config_updated.py | 10 +-
...oin_when_app_published_workflow_updated.py | 10 +-
api/extensions/__init__.py | 0
api/extensions/ext_app_metrics.py | 14 +-
api/extensions/ext_celery.py | 6 +-
api/extensions/ext_compress.py | 2 +-
api/extensions/ext_logging.py | 5 +-
api/extensions/ext_login.py | 2 +-
api/extensions/ext_mail.py | 8 +-
api/extensions/ext_migrate.py | 2 +-
api/extensions/ext_proxy_fix.py | 2 +-
api/extensions/ext_sentry.py | 2 +-
api/extensions/ext_storage.py | 8 +-
api/extensions/storage/aliyun_oss_storage.py | 12 +-
api/extensions/storage/aws_s3_storage.py | 8 +-
api/extensions/storage/azure_blob_storage.py | 8 +-
api/extensions/storage/baidu_obs_storage.py | 9 +-
.../storage/google_cloud_storage.py | 4 +-
api/extensions/storage/huawei_obs_storage.py | 4 +-
api/extensions/storage/opendal_storage.py | 8 +-
api/extensions/storage/oracle_oci_storage.py | 6 +-
api/extensions/storage/supabase_storage.py | 2 +-
api/extensions/storage/tencent_cos_storage.py | 4 +-
.../storage/volcengine_tos_storage.py | 4 +-
api/factories/__init__.py | 0
api/factories/file_factory.py | 5 +-
api/factories/variable_factory.py | 21 +-
api/fields/annotation_fields.py | 2 +-
api/fields/api_based_extension_fields.py | 2 +-
api/fields/app_fields.py | 2 +-
api/fields/conversation_fields.py | 2 +-
api/fields/conversation_variable_fields.py | 2 +-
api/fields/data_source_fields.py | 2 +-
api/fields/dataset_fields.py | 2 +-
api/fields/document_fields.py | 2 +-
api/fields/end_user_fields.py | 2 +-
api/fields/external_dataset_fields.py | 2 +-
api/fields/file_fields.py | 2 +-
api/fields/hit_testing_fields.py | 2 +-
api/fields/installed_app_fields.py | 2 +-
api/fields/member_fields.py | 2 +-
api/fields/message_fields.py | 2 +-
api/fields/raws.py | 2 +-
api/fields/segment_fields.py | 2 +-
api/fields/tag_fields.py | 2 +-
api/fields/workflow_app_log_fields.py | 2 +-
api/fields/workflow_fields.py | 2 +-
api/fields/workflow_run_fields.py | 2 +-
api/libs/external_api.py | 7 +-
api/libs/gmpy2_pkcs10aep_cipher.py | 8 +-
api/libs/helper.py | 6 +-
api/libs/json_in_md_parser.py | 1 +
api/libs/login.py | 15 +-
api/libs/oauth.py | 6 +-
api/libs/oauth_data_source.py | 7 +-
api/libs/threadings_utils.py | 4 +-
api/models/account.py | 32 +--
api/models/api_based_extension.py | 2 +-
api/models/dataset.py | 33 +--
api/models/model.py | 73 +++---
api/models/provider.py | 14 +-
api/models/source.py | 4 +-
api/models/task.py | 6 +-
api/models/tools.py | 22 +-
api/models/web.py | 4 +-
api/models/workflow.py | 19 +-
api/mypy.ini | 10 +
api/poetry.lock | 230 ++++++++++++------
api/pyproject.toml | 3 +
api/schedule/clean_messages.py | 3 +-
api/schedule/clean_unused_datasets_task.py | 6 +-
api/schedule/create_tidb_serverless_task.py | 15 +-
.../update_tidb_serverless_status_task.py | 13 +-
api/services/account_service.py | 31 ++-
.../advanced_prompt_template_service.py | 4 +
api/services/agent_service.py | 13 +-
api/services/annotation_service.py | 14 +-
api/services/app_dsl_service.py | 9 +-
api/services/app_generate_service.py | 4 +-
api/services/app_service.py | 24 +-
api/services/audio_service.py | 6 +
api/services/auth/firecrawl/firecrawl.py | 4 +-
api/services/auth/jina.py | 2 +-
api/services/auth/jina/jina.py | 2 +-
api/services/billing_service.py | 6 +-
api/services/conversation_service.py | 3 +-
api/services/dataset_service.py | 42 +++-
api/services/enterprise/base.py | 4 +-
.../entities/model_provider_entities.py | 8 +-
api/services/external_knowledge_service.py | 39 +--
api/services/file_service.py | 6 +-
api/services/hit_testing_service.py | 17 +-
api/services/knowledge_service.py | 2 +-
api/services/message_service.py | 6 +-
api/services/model_load_balancing_service.py | 49 ++--
api/services/model_provider_service.py | 25 +-
api/services/moderation_service.py | 4 +-
api/services/ops_service.py | 26 +-
.../buildin/buildin_retrieval.py | 8 +-
.../recommend_app/remote/remote_retrieval.py | 6 +-
api/services/recommended_app_service.py | 2 +-
api/services/saved_message_service.py | 6 +
api/services/tag_service.py | 4 +-
.../tools/api_tools_manage_service.py | 33 +--
.../tools/builtin_tools_manage_service.py | 8 +-
api/services/tools/tools_transform_service.py | 34 ++-
.../tools/workflow_tools_manage_service.py | 63 +++--
api/services/web_conversation_service.py | 6 +
api/services/website_service.py | 33 ++-
api/services/workflow/workflow_converter.py | 18 +-
api/services/workflow_run_service.py | 4 +-
api/services/workflow_service.py | 6 +-
api/services/workspace_service.py | 3 +-
api/tasks/__init__.py | 0
api/tasks/add_document_to_index_task.py | 2 +-
.../add_annotation_to_index_task.py | 2 +-
.../batch_import_annotations_task.py | 2 +-
.../delete_annotation_index_task.py | 2 +-
.../disable_annotation_reply_task.py | 2 +-
.../enable_annotation_reply_task.py | 2 +-
.../update_annotation_to_index_task.py | 2 +-
.../batch_create_segment_to_index_task.py | 10 +-
api/tasks/clean_dataset_task.py | 4 +-
api/tasks/clean_document_task.py | 4 +-
api/tasks/clean_notion_document_task.py | 2 +-
api/tasks/create_segment_to_index_task.py | 2 +-
api/tasks/deal_dataset_vector_index_task.py | 2 +-
api/tasks/delete_segment_from_index_task.py | 2 +-
api/tasks/disable_segment_from_index_task.py | 2 +-
api/tasks/document_indexing_sync_task.py | 2 +-
api/tasks/document_indexing_task.py | 2 +-
api/tasks/document_indexing_update_task.py | 2 +-
api/tasks/duplicate_document_indexing_task.py | 4 +-
api/tasks/enable_segment_to_index_task.py | 2 +-
api/tasks/mail_email_code_login.py | 2 +-
api/tasks/mail_invite_member_task.py | 2 +-
api/tasks/mail_reset_password_task.py | 2 +-
api/tasks/ops_trace_task.py | 2 +-
api/tasks/recover_document_indexing_task.py | 2 +-
api/tasks/remove_app_and_related_data_task.py | 2 +-
api/tasks/remove_document_from_index_task.py | 2 +-
api/tasks/retry_document_indexing_task.py | 45 ++--
.../sync_website_document_indexing_task.py | 42 ++--
.../dependencies/test_dependencies_sorted.py | 4 +-
.../controllers/test_controllers.py | 2 +-
.../model_runtime/__mock/google.py | 4 +-
.../model_runtime/__mock/huggingface.py | 2 +-
.../model_runtime/__mock/huggingface_chat.py | 6 +-
.../model_runtime/__mock/nomic_embeddings.py | 2 +-
.../model_runtime/__mock/xinference.py | 4 +-
.../model_runtime/tongyi/test_rerank.py | 2 +-
.../tools/__mock_server/openapi_todo.py | 2 +-
.../vdb/__mock/baiduvectordb.py | 10 +-
.../vdb/__mock/tcvectordb.py | 12 +-
.../integration_tests/vdb/__mock/vikingdb.py | 2 +-
api/tests/unit_tests/oss/__mock/aliyun_oss.py | 4 +-
.../unit_tests/oss/__mock/tencent_cos.py | 4 +-
.../unit_tests/oss/__mock/volcengine_tos.py | 4 +-
.../aliyun_oss/aliyun_oss/test_aliyun_oss.py | 2 +-
.../oss/tencent_cos/test_tencent_cos.py | 2 +-
.../oss/volcengine_tos/test_volcengine_tos.py | 2 +-
.../unit_tests/utils/yaml/test_yaml_utils.py | 2 +-
sdks/python-client/dify_client/client.py | 27 +-
584 files changed, 3980 insertions(+), 2831 deletions(-)
create mode 100644 api/extensions/__init__.py
create mode 100644 api/factories/__init__.py
create mode 100644 api/mypy.ini
create mode 100644 api/tasks/__init__.py
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 2cd0b2a7d430de..fd98db24b961b4 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -56,6 +56,12 @@ jobs:
- name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh
+ - name: Run mypy
+ run: |
+ pushd api
+ poetry run python -m mypy --install-types --non-interactive .
+ popd
+
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
diff --git a/api/commands.py b/api/commands.py
index bf013cc77e0627..ad7ad972f3fd01 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -159,8 +159,7 @@ def migrate_annotation_vector_database():
try:
# get apps info
apps = (
- db.session.query(App)
- .filter(App.status == "normal")
+ App.query.filter(App.status == "normal")
.order_by(App.created_at.desc())
.paginate(page=page, per_page=50)
)
@@ -285,8 +284,7 @@ def migrate_knowledge_vector_database():
while True:
try:
datasets = (
- db.session.query(Dataset)
- .filter(Dataset.indexing_technique == "high_quality")
+ Dataset.query.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
)
@@ -450,7 +448,8 @@ def convert_to_agent_apps():
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
app = db.session.query(App).filter(App.id == app_id).first()
- apps.append(app)
+ if app is not None:
+ apps.append(app)
if len(apps) == 0:
break
@@ -621,6 +620,10 @@ def fix_app_site_missing():
try:
app = db.session.query(App).filter(App.id == app_id).first()
+ if not app:
+ print(f"App {app_id} not found")
+ continue
+
tenant = app.tenant
if tenant:
accounts = tenant.get_accounts()
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 73f8a95989baaf..74cdf944865796 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -239,7 +239,6 @@ class HttpConfig(BaseSettings):
)
@computed_field
- @property
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")
@@ -250,7 +249,6 @@ def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
)
@computed_field
- @property
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
@@ -715,27 +713,27 @@ class PositionConfig(BaseSettings):
default="",
)
- @computed_field
+ @property
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]
- @computed_field
+ @property
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}
- @computed_field
+ @property
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}
- @computed_field
+ @property
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]
- @computed_field
+ @property
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}
- @computed_field
+ @property
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 9265a48d9bc53c..f6a44eaa471e62 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -130,7 +130,6 @@ class DatabaseConfig(BaseSettings):
)
@computed_field
- @property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
@@ -168,7 +167,6 @@ def SQLALCHEMY_DATABASE_URI(self) -> str:
)
@computed_field
- @property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
@@ -206,7 +204,6 @@ class CeleryConfig(DatabaseConfig):
)
@computed_field
- @property
def CELERY_RESULT_BACKEND(self) -> str | None:
return (
"db+{}".format(self.SQLALCHEMY_DATABASE_URI)
@@ -214,7 +211,6 @@ def CELERY_RESULT_BACKEND(self) -> str | None:
else self.CELERY_BROKER_URL
)
- @computed_field
@property
def BROKER_USE_SSL(self) -> bool:
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
diff --git a/api/configs/remote_settings_sources/apollo/client.py b/api/configs/remote_settings_sources/apollo/client.py
index d1f6781ed370dd..03c64ea00f0185 100644
--- a/api/configs/remote_settings_sources/apollo/client.py
+++ b/api/configs/remote_settings_sources/apollo/client.py
@@ -4,6 +4,7 @@
import os
import threading
import time
+from collections.abc import Mapping
from pathlib import Path
from .python_3x import http_request, makedirs_wrapper
@@ -255,8 +256,8 @@ def _listener(self):
logger.info("stopped, long_poll")
# add the need for endorsement to the header
- def _sign_headers(self, url):
- headers = {}
+ def _sign_headers(self, url: str) -> Mapping[str, str]:
+ headers: dict[str, str] = {}
if self.secret == "":
return headers
uri = url[len(self.config_url) : len(url)]
diff --git a/api/constants/model_template.py b/api/constants/model_template.py
index 7e1a196356c4e2..c26d8c018610d0 100644
--- a/api/constants/model_template.py
+++ b/api/constants/model_template.py
@@ -1,8 +1,9 @@
import json
+from collections.abc import Mapping
from models.model import AppMode
-default_app_templates = {
+default_app_templates: Mapping[AppMode, Mapping] = {
# workflow default mode
AppMode.WORKFLOW: {
"app": {
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
index 79869916eda062..b1ebc444a51868 100644
--- a/api/controllers/common/fields.py
+++ b/api/controllers/common/fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index f46d5b6b138d59..cb6b0d097b1fc9 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -3,6 +3,25 @@
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi
+from .explore.audio import ChatAudioApi, ChatTextApi
+from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
+from .explore.conversation import (
+ ConversationApi,
+ ConversationListApi,
+ ConversationPinApi,
+ ConversationRenameApi,
+ ConversationUnPinApi,
+)
+from .explore.message import (
+ MessageFeedbackApi,
+ MessageListApi,
+ MessageMoreLikeThisApi,
+ MessageSuggestedQuestionApi,
+)
+from .explore.workflow import (
+ InstalledAppWorkflowRunApi,
+ InstalledAppWorkflowTaskStopApi,
+)
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
@@ -66,15 +85,81 @@
# Import explore controllers
from .explore import (
- audio,
- completion,
- conversation,
installed_app,
- message,
parameter,
recommended_app,
saved_message,
- workflow,
+)
+
+# Explore Audio
+api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
+api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
+
+# Explore Completion
+api.add_resource(
+ CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
+)
+api.add_resource(
+ CompletionStopApi,
+ "/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
+ endpoint="installed_app_stop_completion",
+)
+api.add_resource(
+ ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
+)
+api.add_resource(
+ ChatStopApi,
+ "/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
+ endpoint="installed_app_stop_chat_completion",
+)
+
+# Explore Conversation
+api.add_resource(
+ ConversationRenameApi,
+ "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
+ endpoint="installed_app_conversation_rename",
+)
+api.add_resource(
+ ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
+)
+api.add_resource(
+ ConversationApi,
+ "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
+ endpoint="installed_app_conversation",
+)
+api.add_resource(
+ ConversationPinApi,
+ "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
+ endpoint="installed_app_conversation_pin",
+)
+api.add_resource(
+ ConversationUnPinApi,
+ "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
+ endpoint="installed_app_conversation_unpin",
+)
+
+
+# Explore Message
+api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
+api.add_resource(
+ MessageFeedbackApi,
+ "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
+ endpoint="installed_app_message_feedback",
+)
+api.add_resource(
+ MessageMoreLikeThisApi,
+ "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
+ endpoint="installed_app_more_like_this",
+)
+api.add_resource(
+ MessageSuggestedQuestionApi,
+ "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
+ endpoint="installed_app_suggested_question",
+)
+# Explore Workflow
+api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
+api.add_resource(
+ InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
 )
# Import tag controllers
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index 8c0bf8710d3964..52e0bb6c56bdc2 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -1,7 +1,7 @@
from functools import wraps
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py
index 953770868904d3..ca8ddc32094ac5 100644
--- a/api/controllers/console/apikey.py
+++ b/api/controllers/console/apikey.py
@@ -1,5 +1,7 @@
-import flask_restful
-from flask_login import current_user
+from typing import Any
+
+import flask_restful # type: ignore
+from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal_with
from werkzeug.exceptions import Forbidden
@@ -35,14 +37,15 @@ def _get_resource(resource_id, tenant_id, resource_model):
class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
- resource_type = None
- resource_model = None
- resource_id_field = None
- token_prefix = None
+ resource_type: str | None = None
+ resource_model: Any = None
+ resource_id_field: str | None = None
+ token_prefix: str | None = None
max_keys = 10
@marshal_with(api_key_list)
def get(self, resource_id):
+ assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
@@ -54,6 +57,7 @@ def get(self, resource_id):
@marshal_with(api_key_fields)
def post(self, resource_id):
+ assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
@@ -86,11 +90,12 @@ def post(self, resource_id):
class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required]
- resource_type = None
- resource_model = None
- resource_id_field = None
+ resource_type: str | None = None
+ resource_model: Any = None
+ resource_id_field: str | None = None
def delete(self, resource_id, api_key_id):
+ assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py
index c228743fa53591..8d0c5b84af5e37 100644
--- a/api/controllers/console/app/advanced_prompt_template.py
+++ b/api/controllers/console/app/advanced_prompt_template.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py
index d4334158945e16..920cae0d859354 100644
--- a/api/controllers/console/app/agent.py
+++ b/api/controllers/console/app/agent.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index fd05cbc19bf04f..24f1020c18ec37 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -1,6 +1,6 @@
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
@@ -110,7 +110,7 @@ def get(self, app_id):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
- keyword = request.args.get("keyword", default=None, type=str)
+ keyword = request.args.get("keyword", default="", type=str)
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index da72b704c71bd7..9cd56cef0b7039 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,8 +1,8 @@
import uuid
from typing import cast
-from flask_login import current_user
-from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
index 244dcd75de29bc..7e2888d71c79c8 100644
--- a/api/controllers/console/app/app_import.py
+++ b/api/controllers/console/app/app_import.py
@@ -1,7 +1,7 @@
from typing import cast
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 695b8890e30f5c..9d26af276d2fc3 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -1,7 +1,7 @@
import logging
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import InternalServerError
import services
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index 9896fcaab8ad36..dba41e5c47d24f 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -1,7 +1,7 @@
import logging
-import flask_login
-from flask_restful import Resource, reqparse
+import flask_login # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index a25004be4d16ae..8827f129d99317 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,9 +1,9 @@
from datetime import UTC, datetime
-import pytz
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+import pytz # pip install pytz
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
@@ -77,8 +77,9 @@ def get(self, app_model):
query = query.where(Conversation.created_at < end_datetime_utc)
+ # FIXME: revisit the "# type: ignore" suppressions in this file
if args["annotation_status"] == "annotated":
- query = query.options(joinedload(Conversation.message_annotations)).join(
+ query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
@@ -222,7 +223,7 @@ def get(self, app_model):
query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated":
- query = query.options(joinedload(Conversation.message_annotations)).join(
+ query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
@@ -234,7 +235,7 @@ def get(self, app_model):
if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = (
- query.options(joinedload(Conversation.messages))
+ query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index d49f433ba1f575..c0a20b7160e719 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 9c3cbe4e3e049e..8518d34a8e5af2 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,7 +1,7 @@
import os
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api
from controllers.console.app.error import (
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index b7a4c31a156b80..b5828b6b4b08c4 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -1,8 +1,8 @@
import logging
-from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index a46bc6a8a97606..8ecc8a9db5738d 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -1,8 +1,9 @@
import json
+from typing import cast
from flask import request
-from flask_login import current_user
-from flask_restful import Resource
+from flask_login import current_user # type: ignore
+from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.app.wraps import get_app_model
@@ -26,7 +27,9 @@ def post(self, app_model):
"""Modify app model config"""
# validate config
model_configuration = AppModelConfigService.validate_configuration(
- tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode)
+ tenant_id=current_user.current_tenant_id,
+ config=cast(dict, request.json),
+ app_mode=AppMode.value_of(app_model.mode),
)
new_app_model_config = AppModelConfig(
@@ -38,9 +41,11 @@ def post(self, app_model):
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
- original_app_model_config: AppModelConfig = (
+ original_app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
)
+ if original_app_model_config is None:
+ raise ValueError("Original app model config not found")
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py
index 3f10215e702ac1..dd25af8ebf9312 100644
--- a/api/controllers/console/app/ops_trace.py
+++ b/api/controllers/console/app/ops_trace.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import BadRequest
from controllers.console import api
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 407f6898199bae..db29b95c4140ff 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,7 +1,7 @@
from datetime import UTC, datetime
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language
@@ -50,7 +50,7 @@ def post(self, app_model):
if not current_user.is_editor:
raise Forbidden()
- site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404()
+ site = Site.query.filter(Site.app_id == app_model.id).one_or_404()
for attr_name in [
"title",
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index db5e2824095ca0..3b21108ceaf76b 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -3,8 +3,8 @@
import pytz
from flask import jsonify
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index f228c3ec4a0e07..26a3a022d401a4 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -2,7 +2,7 @@
import logging
from flask import abort, request
-from flask_restful import Resource, marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index 2940556f84ef4e..882c53e4fb9972 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -1,5 +1,5 @@
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 08ab61bbb9c97e..25a99c1e1594ae 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,5 +1,5 @@
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py
index 6c7c73707bb204..097bf7d1888cf5 100644
--- a/api/controllers/console/app/workflow_statistic.py
+++ b/api/controllers/console/app/workflow_statistic.py
@@ -3,8 +3,8 @@
import pytz
from flask import jsonify
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 63edb83079041e..9ad8c158473df9 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -8,7 +8,7 @@
from models import App, AppMode
-def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
+def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index d2aa7c903b046c..c56f551d49be8b 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,14 +1,14 @@
import datetime
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.helper import StrLen, email, extract_remote_ip, timezone
-from models.account import AccountStatus, Tenant
+from models.account import AccountStatus
from services.account_service import AccountService, RegisterService
@@ -27,7 +27,7 @@ def get(self):
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
data = invitation.get("data", {})
- tenant: Tenant = invitation.get("tenant", None)
+ tenant = invitation.get("tenant", None)
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py
index 465c44e9b6dc2f..ea00c2b8c2272c 100644
--- a/api/controllers/console/auth/data_source_bearer_auth.py
+++ b/api/controllers/console/auth/data_source_bearer_auth.py
@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index faca67bb177f10..e911c9a5e5b5ea 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -2,8 +2,8 @@
import requests
from flask import current_app, redirect, request
-from flask_login import current_user
-from flask_restful import Resource
+from flask_login import current_user # type: ignore
+from flask_restful import Resource # type: ignore
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -17,8 +17,8 @@
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(
- client_id=dify_config.NOTION_CLIENT_ID,
- client_secret=dify_config.NOTION_CLIENT_SECRET,
+ client_id=dify_config.NOTION_CLIENT_ID or "",
+ client_secret=dify_config.NOTION_CLIENT_SECRET or "",
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
)
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index fb32bb2b60286d..140b9e145fa9cd 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -2,7 +2,7 @@
import secrets
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from constants.languages import languages
from controllers.console import api
@@ -122,8 +122,8 @@ def post(self):
else:
try:
account = AccountService.create_account_and_tenant(
- email=reset_data.get("email"),
- name=reset_data.get("email"),
+ email=reset_data.get("email", ""),
+ name=reset_data.get("email", ""),
password=password_confirm,
interface_language=languages[0],
)
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index f4463ce9cb3f30..78a80fc8d7e075 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,8 +1,8 @@
from typing import cast
-import flask_login
+import flask_login # type: ignore
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
import services
from constants.languages import languages
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index b9188aa0798ea2..333b24142727f0 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -4,7 +4,7 @@
import requests
from flask import current_app, redirect, request
-from flask_restful import Resource
+from flask_restful import Resource # type: ignore
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@@ -77,7 +77,8 @@ def get(self, provider: str):
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e:
- logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
+ error_text = e.response.text if e.response else str(e)
+ logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}")
return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token):
@@ -129,7 +130,7 @@ def get(self, provider: str):
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
- account = Account.get_by_openid(provider, user_info.id)
+ account: Optional[Account] = Account.get_by_openid(provider, user_info.id)
if not account:
account = Account.query.filter_by(email=user_info.email).first()
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index 4b0c82ae6c90c2..fd7b7bd8cb3ddd 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 278295ca39a696..d7c431b95080da 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -2,8 +2,8 @@
import json
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import NotFound
from controllers.console import api
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 95d4013e3a8f27..f3c3736b25acc5 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,7 +1,7 @@
-import flask_restful
+import flask_restful # type: ignore
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound
import services
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index ad4768f51959ac..ca41e504be7eda 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,12 +1,13 @@
import logging
from argparse import ArgumentTypeError
from datetime import UTC, datetime
+from typing import cast
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, fields, marshal, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
from sqlalchemy import asc, desc
-from transformers.hf_argparser import string_to_bool
+from transformers.hf_argparser import string_to_bool # type: ignore
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -733,8 +734,7 @@ def put(self, dataset_id, document_id):
if not isinstance(doc_metadata, dict):
raise ValueError("doc_metadata must be a dictionary.")
-
- metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
+ metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
document.doc_metadata = {}
if doc_type == "others":
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 6f7ef86d2c3fd3..2d5933ca23609a 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -3,8 +3,8 @@
import pandas as pd
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound
import services
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index bc6e3687c1c99d..48f360dcd179bc 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -1,6 +1,6 @@
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index 495f511275b4b9..18b746f547287c 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restful import Resource # type: ignore
from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 3b4c07686361d0..bd944602c147cb 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,7 +1,7 @@
import logging
-from flask_login import current_user
-from flask_restful import marshal, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py
index 9127c8af455f6c..da995537e74753 100644
--- a/api/controllers/console/datasets/website.py
+++ b/api/controllers/console/datasets/website.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index 9690677f61b1c2..c7f9fec326945f 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -4,7 +4,6 @@
from werkzeug.exceptions import InternalServerError
import services
-from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -67,7 +66,7 @@ def post(self, installed_app):
class ChatTextApi(InstalledAppResource):
def post(self, installed_app):
- from flask_restful import reqparse
+ from flask_restful import reqparse # type: ignore
app_model = installed_app.app
try:
@@ -118,9 +117,3 @@ def post(self, installed_app):
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
-
-
-api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio")
-api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text")
-# api.add_resource(ChatTextApiWithMessageId, '/installed-apps//text-to-audio/message-id',
-# endpoint='installed_app_text_with_message_id')
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 85c43f8101028e..3331ded70f6620 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,12 +1,11 @@
import logging
from datetime import UTC, datetime
-from flask_login import current_user
-from flask_restful import reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import reqparse # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound
import services
-from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -147,21 +146,3 @@ def post(self, installed_app, task_id):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200
-
-
-api.add_resource(
- CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion"
-)
-api.add_resource(
- CompletionStopApi,
- "/installed-apps//completion-messages//stop",
- endpoint="installed_app_stop_completion",
-)
-api.add_resource(
- ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion"
-)
-api.add_resource(
- ChatStopApi,
- "/installed-apps//chat-messages//stop",
- endpoint="installed_app_stop_chat_completion",
-)
diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py
index 5e7a3da017edd7..91916cbc1ed85f 100644
--- a/api/controllers/console/explore/conversation.py
+++ b/api/controllers/console/explore/conversation.py
@@ -1,10 +1,9 @@
-from flask_login import current_user
-from flask_restful import marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_login import current_user # type: ignore
+from flask_restful import marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
-from controllers.console import api
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -118,28 +117,3 @@ def patch(self, installed_app, c_id):
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}
-
-
-api.add_resource(
- ConversationRenameApi,
- "/installed-apps//conversations//name",
- endpoint="installed_app_conversation_rename",
-)
-api.add_resource(
- ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations"
-)
-api.add_resource(
- ConversationApi,
- "/installed-apps//conversations/",
- endpoint="installed_app_conversation",
-)
-api.add_resource(
- ConversationPinApi,
- "/installed-apps//conversations//pin",
- endpoint="installed_app_conversation_pin",
-)
-api.add_resource(
- ConversationUnPinApi,
- "/installed-apps//conversations//unpin",
- endpoint="installed_app_conversation_unpin",
-)
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 3de179164de91d..86550b2bdf44b9 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,8 +1,9 @@
from datetime import UTC, datetime
+from typing import Any
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, inputs, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@@ -34,7 +35,7 @@ def get(self):
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
- installed_apps = [
+ installed_app_list: list[dict[str, Any]] = [
{
"id": installed_app.id,
"app": installed_app.app,
@@ -47,7 +48,7 @@ def get(self):
for installed_app in installed_apps
if installed_app.app is not None
]
- installed_apps.sort(
+ installed_app_list.sort(
key=lambda app: (
-app["is_pinned"],
app["last_used_at"] is None,
@@ -55,7 +56,7 @@ def get(self):
)
)
- return {"installed_apps": installed_apps}
+ return {"installed_apps": installed_app_list}
@login_required
@account_initialization_required
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index 4e11d8005f138b..c3488de29929c9 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -1,12 +1,11 @@
import logging
-from flask_login import current_user
-from flask_restful import marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_login import current_user # type: ignore
+from flask_restful import marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound
import services
-from controllers.console import api
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
CompletionRequestError,
@@ -153,21 +152,3 @@ def get(self, installed_app, message_id):
raise InternalServerError()
return {"data": questions}
-
-
-api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages")
-api.add_resource(
- MessageFeedbackApi,
- "/installed-apps//messages//feedbacks",
- endpoint="installed_app_message_feedback",
-)
-api.add_resource(
- MessageMoreLikeThisApi,
- "/installed-apps//messages//more-like-this",
- endpoint="installed_app_more_like_this",
-)
-api.add_resource(
- MessageSuggestedQuestionApi,
- "/installed-apps//messages//suggested-questions",
- endpoint="installed_app_suggested_question",
-)
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index fee52248a698e0..5bc74d16e784af 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -1,4 +1,4 @@
-from flask_restful import marshal_with
+from flask_restful import marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index ce85f495aacd50..be6b1f5d215fb4 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from constants.languages import languages
from controllers.console import api
diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py
index 0fc963747981e1..9f0c4966457186 100644
--- a/api/controllers/console/explore/saved_message.py
+++ b/api/controllers/console/explore/saved_message.py
@@ -1,6 +1,6 @@
-from flask_login import current_user
-from flask_restful import fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_login import current_user # type: ignore
+from flask_restful import fields, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import NotFound
from controllers.console import api
diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py
index 45f99b1db9fa9e..76d30299cd84a7 100644
--- a/api/controllers/console/explore/workflow.py
+++ b/api/controllers/console/explore/workflow.py
@@ -1,9 +1,8 @@
import logging
-from flask_restful import reqparse
+from flask_restful import reqparse # type: ignore
from werkzeug.exceptions import InternalServerError
-from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@@ -73,9 +72,3 @@ def post(self, installed_app: InstalledApp, task_id: str):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}
-
-
-api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run")
-api.add_resource(
- InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop"
-)
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index 49ea81a8a0f86d..b7ba81fba20f79 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -1,7 +1,7 @@
from functools import wraps
-from flask_login import current_user
-from flask_restful import Resource
+from flask_login import current_user # type: ignore
+from flask_restful import Resource # type: ignore
from werkzeug.exceptions import NotFound
from controllers.console.wraps import account_initialization_required
diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py
index 4ac0aa497e0866..ed6cedb220cf4b 100644
--- a/api/controllers/console/extension.py
+++ b/api/controllers/console/extension.py
@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
from constants import HIDDEN_VALUE
from controllers.console import api
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index 70ab4ff865cb48..da1171412fdb2d 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import Resource
+from flask_login import current_user # type: ignore
+from flask_restful import Resource # type: ignore
from libs.login import login_required
from services.feature_service import FeatureService
diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py
index ca32d29efaa474..8cf754bbd686fd 100644
--- a/api/controllers/console/files.py
+++ b/api/controllers/console/files.py
@@ -1,6 +1,8 @@
+from typing import Literal
+
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal_with
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden
import services
@@ -48,7 +50,8 @@ def get(self):
@cloud_edition_billing_resource_check("documents")
def post(self):
file = request.files["file"]
- source = request.form.get("source")
+ source_str = request.form.get("source")
+ source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
if "file" not in request.files:
raise NoFileUploadedError()
diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py
index ae759bb752a30e..d9ae5cf29fc626 100644
--- a/api/controllers/console/init_validate.py
+++ b/api/controllers/console/init_validate.py
@@ -1,7 +1,7 @@
import os
from flask import session
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from configs import dify_config
from libs.helper import StrLen
diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py
index cd28cc946ee288..2a116112a3227c 100644
--- a/api/controllers/console/ping.py
+++ b/api/controllers/console/ping.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restful import Resource # type: ignore
from controllers.console import api
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py
index b8cf019e4f068d..30afc930a8e980 100644
--- a/api/controllers/console/remote_files.py
+++ b/api/controllers/console/remote_files.py
@@ -2,8 +2,8 @@
from typing import cast
import httpx
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
import services
from controllers.common import helpers
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index e0b728d97739d3..aba6f0aad9ee54 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py
index ccd3293a6266fc..da83f64019161b 100644
--- a/api/controllers/console/tag/tags.py
+++ b/api/controllers/console/tag/tags.py
@@ -1,6 +1,6 @@
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
@@ -23,7 +23,7 @@ class TagListApi(Resource):
@account_initialization_required
@marshal_with(tag_fields)
def get(self):
- tag_type = request.args.get("type", type=str)
+ tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index 7dea8e554edd7a..7773c99944e42c 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -2,7 +2,7 @@
import logging
import requests
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from packaging import version
from configs import dify_config
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index f704783cfff56b..96ed4b7a570256 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -2,8 +2,8 @@
import pytz
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from configs import dify_config
from constants.languages import supported_language
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index d2b2092b75a9ff..7009343d9923da 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
@@ -37,7 +37,7 @@ def post(self, provider: str):
model_load_balancing_service = ModelLoadBalancingService()
result = True
- error = None
+ error = ""
try:
model_load_balancing_service.validate_load_balancing_credentials(
@@ -86,7 +86,7 @@ def post(self, provider: str, config_id: str):
model_load_balancing_service = ModelLoadBalancingService()
result = True
- error = None
+ error = ""
try:
model_load_balancing_service.validate_load_balancing_credentials(
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 38ed2316a58935..1afb41ea87660c 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,7 +1,7 @@
from urllib import parse
-from flask_login import current_user
-from flask_restful import Resource, abort, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, abort, marshal_with, reqparse # type: ignore
import services
from configs import dify_config
@@ -89,19 +89,19 @@ class MemberCancelInviteApi(Resource):
@account_initialization_required
def delete(self, member_id):
member = db.session.query(Account).filter(Account.id == str(member_id)).first()
- if not member:
+ if member is None:
abort(404)
-
- try:
- TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
- except services.errors.account.CannotOperateSelfError as e:
- return {"code": "cannot-operate-self", "message": str(e)}, 400
- except services.errors.account.NoPermissionError as e:
- return {"code": "forbidden", "message": str(e)}, 403
- except services.errors.account.MemberNotInTenantError as e:
- return {"code": "member-not-found", "message": str(e)}, 404
- except Exception as e:
- raise ValueError(str(e))
+ else:
+ try:
+ TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
+ except services.errors.account.CannotOperateSelfError as e:
+ return {"code": "cannot-operate-self", "message": str(e)}, 400
+ except services.errors.account.NoPermissionError as e:
+ return {"code": "forbidden", "message": str(e)}, 403
+ except services.errors.account.MemberNotInTenantError as e:
+ return {"code": "member-not-found", "message": str(e)}, 404
+ except Exception as e:
+ raise ValueError(str(e))
return {"result": "success"}, 204
@@ -122,10 +122,11 @@ def put(self, member_id):
return {"code": "invalid-role", "message": "Invalid role"}, 400
member = db.session.get(Account, str(member_id))
- if not member:
+ if member is None:
abort(404)
try:
+ assert member is not None, "Member not found"
TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user)
except Exception as e:
raise ValueError(str(e))
diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index 0e54126063be75..2d11295b0fdf61 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -1,8 +1,8 @@
import io
from flask import send_file
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
@@ -66,7 +66,7 @@ def post(self, provider: str):
model_provider_service = ModelProviderService()
result = True
- error = None
+ error = ""
try:
model_provider_service.provider_credentials_validate(
@@ -132,7 +132,8 @@ def get(self, provider: str, icon_type: str, lang: str):
icon_type=icon_type,
lang=lang,
)
-
+ if icon is None:
+ raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}")
return send_file(io.BytesIO(icon), mimetype=mimetype)
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index f804285f008120..618262e502ab33 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -1,7 +1,7 @@
import logging
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.console import api
@@ -308,7 +308,7 @@ def post(self, provider: str):
model_provider_service = ModelProviderService()
result = True
- error = None
+ error = ""
try:
model_provider_service.model_credentials_validate(
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 9e62a546997b12..964f3862291a2e 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -1,8 +1,8 @@
import io
from flask import send_file
-from flask_login import current_user
-from flask_restful import Resource, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 76d76f6b58fc3c..0f99bf62e3c251 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -1,8 +1,8 @@
import logging
from flask import request
-from flask_login import current_user
-from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Unauthorized
import services
@@ -82,11 +82,7 @@ def get(self):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
- tenants = (
- db.session.query(Tenant)
- .order_by(Tenant.created_at.desc())
- .paginate(page=args["page"], per_page=args["limit"])
- )
+ tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"])
has_more = False
if len(tenants.items) == args["limit"]:
@@ -151,6 +147,8 @@ def post(self):
raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
+ if new_tenant is None:
+ raise ValueError("Tenant not found")
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
@@ -166,7 +164,7 @@ def post(self):
parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args()
- tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404()
+ tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index d0df296c240686..111db7ccf2da04 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -3,7 +3,7 @@
from functools import wraps
from flask import abort, request
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
@@ -121,8 +121,8 @@ def decorated(*args, **kwargs):
utm_info = request.cookies.get("utm_info")
if utm_info:
- utm_info = json.loads(utm_info)
- OperationService.record_utm(current_user.current_tenant_id, utm_info)
+ utm_info_dict: dict = json.loads(utm_info)
+ OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
except Exception as e:
pass
return view(*args, **kwargs)
diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py
index 6b3ac93cdf3d8f..2357288a50ae36 100644
--- a/api/controllers/files/image_preview.py
+++ b/api/controllers/files/image_preview.py
@@ -1,5 +1,5 @@
from flask import Response, request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import NotFound
import services
diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py
index a298701a2f8b11..cfcce8124761f5 100644
--- a/api/controllers/files/tool_files.py
+++ b/api/controllers/files/tool_files.py
@@ -1,5 +1,5 @@
from flask import Response
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound
from controllers.files import api
diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py
index 99d32af593991f..d7346b13b10a90 100644
--- a/api/controllers/inner_api/workspace/workspace.py
+++ b/api/controllers/inner_api/workspace/workspace.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from controllers.console.wraps import setup_required
from controllers.inner_api import api
diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py
index 51ffe683ff40ad..d4587235f6aef8 100644
--- a/api/controllers/inner_api/wraps.py
+++ b/api/controllers/inner_api/wraps.py
@@ -45,14 +45,14 @@ def decorated(*args, **kwargs):
if " " in user_id:
user_id = user_id.split(" ")[1]
- inner_api_key = request.headers.get("X-Inner-Api-Key")
+ inner_api_key = request.headers.get("X-Inner-Api-Key", "")
data_to_sign = f"DIFY {user_id}"
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
- signature = b64encode(signature.digest()).decode("utf-8")
+ signature_base64 = b64encode(signature.digest()).decode("utf-8")
- if signature != token:
+ if signature_base64 != token:
return view(*args, **kwargs)
kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index ecff7d07e974d9..8388e2045dd34f 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, marshal_with
+from flask_restful import Resource, marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 5db41636471220..e6bcc0bfd25355 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -1,7 +1,7 @@
import logging
from flask import request
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import InternalServerError
import services
@@ -83,7 +83,7 @@ def post(self, app_model: App, end_user: EndUser):
and app_model.workflow
and app_model.workflow.features_dict
):
- text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+ text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index 8d8e356c4cb940..1be54b386bfe8c 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import Resource, reqparse
+from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index 32940cbc29f355..334f2c56206794 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -1,5 +1,5 @@
-from flask_restful import Resource, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py
index b0fd8e65ef97df..27b21b9f505633 100644
--- a/api/controllers/service_api/app/file.py
+++ b/api/controllers/service_api/app/file.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import Resource, marshal_with
+from flask_restful import Resource, marshal_with # type: ignore
import services
from controllers.common.errors import FilenameNotExistsError
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index 599401bc6f1821..522c7509b9849d 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -1,7 +1,7 @@
import logging
-from flask_restful import Resource, fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index 96d1337632826a..c7dd4de3452970 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -1,7 +1,7 @@
import logging
-from flask_restful import Resource, fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index 799fccc228e21d..d6a3beb6b80b9d 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import marshal, reqparse
+from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound
import services.dataset_service
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 5c3fc7b241175a..34afe2837f4ca5 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -1,7 +1,7 @@
import json
from flask import request
-from flask_restful import marshal, reqparse
+from flask_restful import marshal, reqparse # type: ignore
from sqlalchemy import desc
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index e68f6b4dc40a36..34904574a8b88d 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -1,5 +1,5 @@
-from flask_login import current_user
-from flask_restful import marshal, reqparse
+from flask_login import current_user # type: ignore
+from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound
from controllers.service_api import api
diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py
index d24c4597e210fb..75d9141a6d0a3a 100644
--- a/api/controllers/service_api/index.py
+++ b/api/controllers/service_api/index.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restful import Resource # type: ignore
from configs import dify_config
from controllers.service_api import api
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index 2128c4c53f9909..740b92ef8e4faf 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -5,8 +5,8 @@
from typing import Optional
from flask import current_app, request
-from flask_login import user_logged_in
-from flask_restful import Resource
+from flask_login import user_logged_in # type: ignore
+from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden, Unauthorized
@@ -49,6 +49,8 @@ def decorated_view(*args, **kwargs):
raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
+ if tenant is None:
+ raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")
@@ -154,8 +156,8 @@ def decorated(*args, **kwargs):
# Login admin
if account:
account.current_tenant = tenant
- current_app.login_manager._update_request_context_with_user(account)
- user_logged_in.send(current_app._get_current_object(), user=_get_user())
+ current_app.login_manager._update_request_context_with_user(account) # type: ignore
+ user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index cc8255ccf4e748..20e071c834ad5b 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -1,4 +1,4 @@
-from flask_restful import marshal_with
+from flask_restful import marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index e8521307ad357a..97d980d07c13a7 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -65,7 +65,7 @@ def post(self, app_model: App, end_user):
class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
- from flask_restful import reqparse
+ from flask_restful import reqparse # type: ignore
try:
parser = reqparse.RequestParser()
@@ -82,7 +82,7 @@ def post(self, app_model: App, end_user):
and app_model.workflow
and app_model.workflow.features_dict
):
- text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+ text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index 45b890dfc4899d..761771a81a4bb3 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import reqparse
+from flask_restful import reqparse # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py
index fe0d7c74f32cff..28feb1ca47effd 100644
--- a/api/controllers/web/conversation.py
+++ b/api/controllers/web/conversation.py
@@ -1,5 +1,5 @@
-from flask_restful import marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py
index 0563ed22382e6b..ce841a8814972d 100644
--- a/api/controllers/web/feature.py
+++ b/api/controllers/web/feature.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource
+from flask_restful import Resource # type: ignore
from controllers.web import api
from services.feature_service import FeatureService
diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py
index a282fc63a8b056..1d4474015ab648 100644
--- a/api/controllers/web/files.py
+++ b/api/controllers/web/files.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import marshal_with
+from flask_restful import marshal_with # type: ignore
import services
from controllers.common.errors import FilenameNotExistsError
@@ -33,7 +33,7 @@ def post(self, app_model, end_user):
content=file.read(),
mimetype=file.mimetype,
user=end_user,
- source=source,
+ source="datasets" if source == "datasets" else None,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index febaab5328e8b3..0f47e643708570 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -1,7 +1,7 @@
import logging
-from flask_restful import fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import fields, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py
index a01ffd861230a5..4625c1f43dfbd1 100644
--- a/api/controllers/web/passport.py
+++ b/api/controllers/web/passport.py
@@ -1,7 +1,7 @@
import uuid
from flask import request
-from flask_restful import Resource
+from flask_restful import Resource # type: ignore
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web import api
diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index ae68df6bdc4e48..d559ab8e07e736 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -1,7 +1,7 @@
import urllib.parse
import httpx
-from flask_restful import marshal_with, reqparse
+from flask_restful import marshal_with, reqparse # type: ignore
import services
from controllers.common import helpers
diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py
index b0492e6b6f0d31..6a9b8189076c3c 100644
--- a/api/controllers/web/saved_message.py
+++ b/api/controllers/web/saved_message.py
@@ -1,5 +1,5 @@
-from flask_restful import fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
+from flask_restful import fields, marshal_with, reqparse # type: ignore
+from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import NotFound
from controllers.web import api
diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py
index 0564b15ea39855..e68dc7aa4afba5 100644
--- a/api/controllers/web/site.py
+++ b/api/controllers/web/site.py
@@ -1,4 +1,4 @@
-from flask_restful import fields, marshal_with
+from flask_restful import fields, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden
from configs import dify_config
diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py
index 55b0c3e2ab34c5..48d25e720c10c3 100644
--- a/api/controllers/web/workflow.py
+++ b/api/controllers/web/workflow.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import reqparse
+from flask_restful import reqparse # type: ignore
from werkzeug.exceptions import InternalServerError
from controllers.web import api
diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py
index c327c3df18526c..1b4d263bee4401 100644
--- a/api/controllers/web/wraps.py
+++ b/api/controllers/web/wraps.py
@@ -1,7 +1,7 @@
from functools import wraps
from flask import request
-from flask_restful import Resource
+from flask_restful import Resource # type: ignore
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebSSOAuthRequiredError
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index ead293200ea3aa..8d69bdcec2c2ac 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -1,7 +1,6 @@
import json
import logging
import uuid
-from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Optional, Union, cast
@@ -53,6 +52,7 @@
class BaseAgentRunner(AppRunner):
def __init__(
self,
+ *,
tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
@@ -66,7 +66,7 @@ def __init__(
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
- model_instance: ModelInstance | None = None,
+ model_instance: ModelInstance,
) -> None:
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
@@ -117,7 +117,7 @@ def __init__(
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
- self.query = None
+ self.query: Optional[str] = ""
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(
@@ -145,7 +145,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P
message_tool = PromptMessageTool(
name=tool.tool_name,
- description=tool_entity.description.llm,
+ description=tool_entity.description.llm if tool_entity.description else "",
parameters={
"type": "object",
"properties": {},
@@ -167,7 +167,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
- enum = [option.value for option in parameter.options]
+ enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
@@ -187,8 +187,8 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe
convert dataset retriever tool to prompt message tool
"""
prompt_tool = PromptMessageTool(
- name=tool.identity.name,
- description=tool.description.llm,
+ name=tool.identity.name if tool.identity else "unknown",
+ description=tool.description.llm if tool.description else "",
parameters={
"type": "object",
"properties": {},
@@ -210,14 +210,14 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe
return prompt_tool
- def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
+ def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
"""
Init tools
"""
tool_instances = {}
prompt_messages_tools = []
- for tool in self.app_config.agent.tools if self.app_config.agent else []:
+ for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
@@ -234,7 +234,8 @@ def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessage
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
- tool_instances[dataset_tool.identity.name] = dataset_tool
+ if dataset_tool.identity is not None:
+ tool_instances[dataset_tool.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools
@@ -258,7 +259,7 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool)
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
- enum = [option.value for option in parameter.options]
+ enum = [option.value for option in parameter.options] if parameter.options else []
prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
@@ -322,16 +323,21 @@ def save_agent_thought(
tool_name: str,
tool_input: Union[str, dict],
thought: str,
- observation: Union[str, dict],
- tool_invoke_meta: Union[str, dict],
+ observation: Union[str, dict, None],
+ tool_invoke_meta: Union[str, dict, None],
answer: str,
messages_ids: list[str],
- llm_usage: LLMUsage = None,
- ) -> MessageAgentThought:
+ llm_usage: LLMUsage | None = None,
+ ):
"""
Save agent thought
"""
- agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
+ queried_thought = (
+ db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
+ )
+ if not queried_thought:
+ raise ValueError(f"Agent thought {agent_thought.id} not found")
+ agent_thought = queried_thought
if thought is not None:
agent_thought.thought = thought
@@ -404,7 +410,7 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab
"""
convert tool variables to db variables
"""
- db_variables = (
+ queried_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
@@ -412,6 +418,11 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab
.first()
)
+ if not queried_variables:
+ return
+
+ db_variables = queried_variables
+
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
@@ -421,7 +432,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
"""
Organize agent history
"""
- result = []
+ result: list[PromptMessage] = []
# check if there is a system message in the beginning of the conversation
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index d98ba5a3fad846..e936acb6055cb8 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
-from collections.abc import Generator
-from typing import Optional, Union
+from collections.abc import Generator, Mapping
+from typing import Any, Optional
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
@@ -12,6 +12,7 @@
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
+ PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
@@ -26,18 +27,18 @@
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
- _historic_prompt_messages: list[PromptMessage] = None
- _agent_scratchpad: list[AgentScratchpadUnit] = None
- _instruction: str = None
- _query: str = None
- _prompt_messages_tools: list[PromptMessage] = None
+ _historic_prompt_messages: list[PromptMessage] | None = None
+ _agent_scratchpad: list[AgentScratchpadUnit] | None = None
+ _instruction: str = "" # FIXME this must be str for now
+ _query: str | None = None
+ _prompt_messages_tools: list[PromptMessageTool] = []
def run(
self,
message: Message,
query: str,
- inputs: dict[str, str],
- ) -> Union[Generator, LLMResult]:
+ inputs: Mapping[str, str],
+ ) -> Generator:
"""
Run Cot agent application
"""
@@ -57,19 +58,19 @@ def run(
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
- self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
+ self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
iteration_step = 1
- max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
+ max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
# convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True
- llm_usage = {"usage": None}
+ llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
- def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
+ def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
@@ -90,7 +91,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
# the last iteration, remove all tools
self._prompt_messages_tools = []
- message_file_ids = []
+ message_file_ids: list[str] = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
@@ -105,7 +106,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
- chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
+ chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=[],
@@ -115,11 +116,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
callbacks=[],
)
+ if not isinstance(chunks, Generator):
+ raise ValueError("Expected streaming response from LLM")
+
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
- usage_dict = {}
+ usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
@@ -139,25 +143,30 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
- scratchpad.agent_response += json.dumps(chunk.model_dump())
+ if scratchpad.agent_response is not None:
+ scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
- scratchpad.agent_response += chunk
- scratchpad.thought += chunk
+ if scratchpad.agent_response is not None:
+ scratchpad.agent_response += chunk
+ if scratchpad.thought is not None:
+ scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
-
- scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
- self._agent_scratchpad.append(scratchpad)
+ if scratchpad.thought is not None:
+ scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
+ if self._agent_scratchpad is not None:
+ self._agent_scratchpad.append(scratchpad)
# get llm usage
if "usage" in usage_dict:
- increase_usage(llm_usage, usage_dict["usage"])
+ if usage_dict["usage"] is not None:
+ increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
@@ -166,9 +175,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
- thought=scratchpad.thought,
+ thought=scratchpad.thought or "",
observation="",
- answer=scratchpad.agent_response,
+ answer=scratchpad.agent_response or "",
messages_ids=[],
llm_usage=usage_dict["usage"],
)
@@ -209,7 +218,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
- thought=scratchpad.thought,
+ thought=scratchpad.thought or "",
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
@@ -247,8 +256,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
answer=final_answer,
messages_ids=[],
)
-
- self.update_db_variables(self.variables_pool, self.db_variables_pool)
+ if self.variables_pool is not None and self.db_variables_pool is not None:
+ self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -307,8 +316,9 @@ def _handle_invoke_action(
# publish files
for message_file_id, save_as in message_files:
- if save_as:
- self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
+ if save_as is not None and self.variables_pool:
+ # FIXME: the type of save_as is unclear; confirm whether it is always a string
+ self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
# publish message file
self.queue_manager.publish(
@@ -325,7 +335,7 @@ def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
- def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
+ def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
"""
fill in inputs from external data tools
"""
@@ -376,11 +386,13 @@ def _organize_historic_prompt_messages(
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
- current_scratchpad: AgentScratchpadUnit = None
+ current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
+ if not isinstance(message.content, str | None):
+ raise NotImplementedError("expected str type")
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
@@ -399,8 +411,12 @@ def _organize_historic_prompt_messages(
except:
pass
elif isinstance(message, ToolPromptMessage):
- if current_scratchpad:
+ if not current_scratchpad:
+ continue
+ if isinstance(message.content, str):
current_scratchpad.observation = message.content
+ else:
+ raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index d8d047fe91cdbd..6a96c349b2611c 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -19,7 +19,12 @@ def _organize_system_prompt(self) -> SystemPromptMessage:
"""
Organize system prompt
"""
+ if not self.app_config.agent:
+ raise ValueError("Agent configuration is not set")
+
prompt_entity = self.app_config.agent.prompt
+ if not prompt_entity:
+ raise ValueError("Agent prompt configuration is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
@@ -75,6 +80,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
+ assistant_message.content = "" # FIXME: reassigned only so that mypy treats assistant_message.content as str
for unit in agent_scratchpad:
if unit.is_final():
assistant_message.content += f"Final Answer: {unit.agent_response}"
diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py
index 0563090537e62c..3a4d31e047f5ae 100644
--- a/api/core/agent/cot_completion_agent_runner.py
+++ b/api/core/agent/cot_completion_agent_runner.py
@@ -2,7 +2,12 @@
from typing import Optional
from core.agent.cot_agent_runner import CotAgentRunner
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessage,
+ TextPromptMessageContent,
+ UserPromptMessage,
+)
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -11,7 +16,11 @@ def _organize_instruction_prompt(self) -> str:
"""
Organize instruction prompt
"""
+ if self.app_config.agent is None:
+ raise ValueError("Agent configuration is not set")
prompt_entity = self.app_config.agent.prompt
+ if prompt_entity is None:
+ raise ValueError("prompt entity is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
@@ -33,7 +42,13 @@ def _organize_historic_prompt(self, current_session_messages: Optional[list[Prom
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
- historic_prompt += message.content + "\n\n"
+ if isinstance(message.content, str):
+ historic_prompt += message.content + "\n\n"
+ elif isinstance(message.content, list):
+ for content in message.content:
+ if not isinstance(content, TextPromptMessageContent):
+ continue
+ historic_prompt += content.data
return historic_prompt
@@ -50,7 +65,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]:
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ""
- for unit in agent_scratchpad:
+ for unit in agent_scratchpad or []:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
else:
diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py
index 119a88fc7becbf..2ae87dca3f8cbd 100644
--- a/api/core/agent/entities.py
+++ b/api/core/agent/entities.py
@@ -78,5 +78,5 @@ class Strategy(Enum):
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
- tools: list[AgentToolEntity] = None
+ tools: list[AgentToolEntity] | None = None
max_iteration: int = 5
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index cd546dee124147..b862c96072aaa0 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -40,6 +40,8 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul
app_generate_entity = self.application_generate_entity
app_config = self.app_config
+ assert app_config is not None, "app_config is required"
+ assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
@@ -49,7 +51,7 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul
# continue to run until there is not any tool call
function_call_state = True
- llm_usage = {"usage": None}
+ llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
final_answer = ""
# get tracing instance
@@ -75,7 +77,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
# the last iteration, remove all tools
prompt_messages_tools = []
- message_file_ids = []
+ message_file_ids: list[str] = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
@@ -105,7 +107,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
current_llm_usage = None
- if self.stream_tool_call:
+ if self.stream_tool_call and isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
@@ -116,7 +118,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
- tool_calls.extend(self.extract_tool_calls(chunk))
+ tool_calls.extend(self.extract_tool_calls(chunk) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
@@ -131,19 +133,19 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
for content in chunk.delta.message.content:
response += content.data
else:
- response += chunk.delta.message.content
+ response += str(chunk.delta.message.content)
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
- else:
- result: LLMResult = chunks
+ elif not self.stream_tool_call and isinstance(chunks, LLMResult):
+ result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
- tool_calls.extend(self.extract_blocking_tool_calls(result))
+ tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
@@ -162,7 +164,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
for content in result.message.content:
response += content.data
else:
- response += result.message.content
+ response += str(result.message.content)
if not result.message.content:
result.message.content = ""
@@ -181,6 +183,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
usage=result.usage,
),
)
+ else:
+ raise RuntimeError(f"invalid chunks type: {type(chunks)}")
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
@@ -243,7 +247,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
# publish files
for message_file_id, save_as in message_files:
if save_as:
- self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
+ if self.variables_pool:
+ self.variables_pool.set_file(
+ tool_name=tool_call_name, value=message_file_id, name=save_as
+ )
# publish message file
self.queue_manager.publish(
@@ -263,7 +270,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
- content=tool_response["tool_response"],
+ content=str(tool_response["tool_response"]),
tool_call_id=tool_call_id,
name=tool_call_name,
)
@@ -273,9 +280,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
- tool_name=None,
- tool_input=None,
- thought=None,
+ tool_name="",
+ tool_input="",
+ thought="",
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
@@ -283,7 +290,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
- answer=None,
+ answer="",
messages_ids=message_file_ids,
)
self.queue_manager.publish(
@@ -296,7 +303,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
iteration_step += 1
- self.update_db_variables(self.variables_pool, self.db_variables_pool)
+ if self.variables_pool and self.db_variables_pool:
+ self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -389,9 +397,9 @@ def _init_system_message(
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
- return prompt_messages
+ return prompt_messages or []
- def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+ def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
@@ -449,7 +457,7 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
- query_prompt_messages = self._organize_user_query(self.query, [])
+ query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py
index 085bac8601b2da..61fa774ea5f390 100644
--- a/api/core/agent/output_parser/cot_output_parser.py
+++ b/api/core/agent/output_parser/cot_output_parser.py
@@ -38,7 +38,7 @@ def parse_action(json_str):
except:
return json_str or ""
- def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
+ def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
if not code_blocks:
return
@@ -67,15 +67,15 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None,
for response in llm_response:
if response.delta.usage:
usage_dict["usage"] = response.delta.usage
- response = response.delta.message.content
- if not isinstance(response, str):
+ response_content = response.delta.message.content
+ if not isinstance(response_content, str):
continue
# stream
index = 0
- while index < len(response):
+ while index < len(response_content):
steps = 1
- delta = response[index : index + steps]
+ delta = response_content[index : index + steps]
yield_delta = False
if delta == "`":
diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
index b9aae7904f5e7c..646c4badb9f73a 100644
--- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
@@ -66,6 +66,8 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]:
dataset_configs = config.get("dataset_configs")
else:
dataset_configs = {"retrieval_model": "multiple"}
+ if dataset_configs is None:
+ return None
query_variable = config.get("dataset_query_variable")
if dataset_configs["retrieval_model"] == "single":
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
index 5adcf26f1486e8..6426865115126f 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
@@ -94,7 +94,7 @@ def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) ->
config["model"]["completion_params"]
)
- return config, ["model"]
+ return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict) -> dict:
diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py
index b4dacbc409044a..92b4185abf0183 100644
--- a/api/core/app/app_config/features/opening_statement/manager.py
+++ b/api/core/app/app_config/features/opening_statement/manager.py
@@ -7,10 +7,10 @@ def convert(cls, config: dict) -> tuple[str, list]:
:param config: model config args
"""
# opening statement
- opening_statement = config.get("opening_statement")
+ opening_statement = config.get("opening_statement", "")
# suggested questions
- suggested_questions_list = config.get("suggested_questions")
+ suggested_questions_list = config.get("suggested_questions", [])
return opening_statement, suggested_questions_list
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 6200299d21c869..a18b40712b7ce6 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -29,6 +29,7 @@
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
+from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -145,7 +146,7 @@ def generate(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
- files=file_objs,
+ files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
@@ -313,6 +314,8 @@ def _generate_worker(
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
+ if message is None:
+ raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AdvancedChatAppRunner(
diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
index 29709914b7cfb8..a506447671abfb 100644
--- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
+++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
@@ -5,6 +5,7 @@
import re
import threading
from collections.abc import Iterable
+from typing import Optional
from core.app.entities.queue_entities import (
MessageQueueMessage,
@@ -15,6 +16,7 @@
WorkflowQueueMessage,
)
from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.model_runtime.entities.model_entities import ModelType
@@ -71,8 +73,9 @@ def __init__(self, tenant_id: str, voice: str):
if not voice or voice not in values:
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
- self._last_audio_event = None
- self._runtime_thread = threading.Thread(target=self._runtime).start()
+ self._last_audio_event: Optional[AudioTrunk] = None
+ # FIXME: the worker thread is started here without keeping a handle to it, so it can never be joined or stopped explicitly
+ threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
@@ -92,10 +95,21 @@ def _runtime(self):
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
- self.msg_text += message.event.chunk.delta.message.content
+ message_content = message.event.chunk.delta.message.content
+ if not message_content:
+ continue
+ if isinstance(message_content, str):
+ self.msg_text += message_content
+ elif isinstance(message_content, list):
+ for content in message_content:
+ if not isinstance(content, TextPromptMessageContent):
+ continue
+ self.msg_text += content.data
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
+ if message.event.outputs is None:
+ continue
self.msg_text += message.event.outputs.get("output", "")
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
@@ -121,11 +135,10 @@ def check_and_get_audio(self):
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
self.executor.shutdown(wait=False)
- return self.last_message
+ return self._last_audio_event
audio = self._audio_queue.get_nowait()
if audio and audio.status == "finish":
self.executor.shutdown(wait=False)
- self._runtime_thread = None
if audio:
self._last_audio_event = audio
return audio
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index cf0c9d7593429a..6339d798984800 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -109,18 +109,18 @@ def run(self) -> None:
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
- conversation_variables = session.scalars(stmt).all()
- if not conversation_variables:
+ db_conversation_variables = session.scalars(stmt).all()
+ if not db_conversation_variables:
# Create conversation variables if they don't exist.
- conversation_variables = [
+ db_conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
- session.add_all(conversation_variables)
+ session.add_all(db_conversation_variables)
# Convert database entities to variables.
- conversation_variables = [item.to_variable() for item in conversation_variables]
+ conversation_variables = [item.to_variable() for item in db_conversation_variables]
session.commit()
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 635e482ad980ed..1073a0f2e4f706 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -2,6 +2,7 @@
import logging
import time
from collections.abc import Generator, Mapping
+from threading import Thread
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -64,6 +65,7 @@
from models.workflow import (
Workflow,
WorkflowNodeExecution,
+ WorkflowRun,
WorkflowRunStatus,
)
@@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
+ _conversation_name_generate_thread: Optional[Thread] = None
def __init__(
self,
@@ -131,7 +134,7 @@ def __init__(
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
- def process(self):
+ def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
@@ -262,8 +265,8 @@ def _process_stream_response(
:return:
"""
# init fake graph runtime state
- graph_runtime_state = None
- workflow_run = None
+ graph_runtime_state: Optional[GraphRuntimeState] = None
+ workflow_run: Optional[WorkflowRun] = None
for queue_message in self._queue_manager.listen():
event = queue_message.event
@@ -315,14 +318,14 @@ def _process_stream_response(
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
- response = self._workflow_node_start_to_stream_response(
+ response_start = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
- if response:
- yield response
+ if response_start:
+ yield response_start
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
@@ -330,18 +333,18 @@ def _process_stream_response(
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
- response = self._workflow_node_finish_to_stream_response(
+ response_finish = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
- if response:
- yield response
+ if response_finish:
+ yield response_finish
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
- response = self._workflow_node_finish_to_stream_response(
+ response_finish = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -609,7 +612,10 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
- task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
+ task_id=self._application_generate_entity.task_id,
+ id=self._message.id,
+ files=self._recorded_files,
+ metadata=extras.get("metadata", {}),
)
def _handle_output_moderation_chunk(self, text: str) -> bool:
diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py
index 417d23eccfb553..55b6ee510f228c 100644
--- a/api/core/app/apps/agent_chat/app_config_manager.py
+++ b/api/core/app/apps/agent_chat/app_config_manager.py
@@ -61,7 +61,7 @@ def get_app_config(
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
- config_dict = override_config_dict
+ config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentChatAppConfig(
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index b391169e3dbe5c..63e11bdaa27f74 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -23,6 +23,7 @@
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser
+from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -97,7 +98,7 @@ def generate(
# get conversation
conversation = None
if args.get("conversation_id"):
- conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
+ conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -153,7 +154,7 @@ def generate(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
- files=file_objs,
+ files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
@@ -180,7 +181,7 @@ def generate(
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -199,8 +200,8 @@ def generate(
user=user,
stream=streaming,
)
-
- return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
+ # FIXME: Type hinting issue here, ignore it for now, will fix it later
+ return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
def _generate_worker(
self,
@@ -224,6 +225,8 @@ def _generate_worker(
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
+ if message is None:
+ raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AgentChatAppRunner()
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 45b1bf00934d35..ac71f02b6de03d 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -173,6 +173,8 @@ def run(
return
agent_entity = app_config.agent
+ if not agent_entity:
+ raise ValueError("Agent entity not found")
# load tool variables
tool_conversation_variables = self._load_tool_variables(
@@ -200,14 +202,21 @@ def run(
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+ if not model_schema or not model_schema.features:
+ raise ValueError("Model schema not found")
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
- conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
- message = db.session.query(Message).filter(Message.id == message.id).first()
+ conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
+ if conversation_result is None:
+ raise ValueError("Conversation not found")
+ message_result = db.session.query(Message).filter(Message.id == message.id).first()
+ if message_result is None:
+ raise ValueError("Message not found")
db.session.close()
+ runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
@@ -225,12 +234,12 @@ def run(
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
- conversation=conversation,
+ conversation=conversation_result,
app_config=app_config,
model_config=application_generate_entity.model_conf,
config=agent_entity,
queue_manager=queue_manager,
- message=message,
+ message=message_result,
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
@@ -257,7 +266,7 @@ def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: st
"""
load tool variables from database
"""
- tool_variables: ToolConversationVariables = (
+ tool_variables: ToolConversationVariables | None = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py
index 629c309c065458..ce331d904cc826 100644
--- a/api/core/app/apps/agent_chat/generate_response_converter.py
+++ b/api/core/app/apps/agent_chat/generate_response_converter.py
@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
- def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes
return response
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,8 +51,9 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR
return response
@classmethod
- def convert_stream_full_response(
- cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+ def convert_stream_full_response( # type: ignore[override]
+ cls,
+ stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -82,8 +83,9 @@ def convert_stream_full_response(
yield json.dumps(response_chunk)
@classmethod
- def convert_stream_simple_response(
- cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+ def convert_stream_simple_response( # type: ignore[override]
+ cls,
+ stream_response: Generator[ChatbotAppStreamResponse, None, None],
) -> Generator[str, None, None]:
"""
Convert stream simple response.
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index 3725c6e6ddc4fd..1842fc43033ab8 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -50,7 +50,7 @@ def listen(self):
# wait for APP_MAX_EXECUTION_TIME seconds to stop listen
listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
start_time = time.time()
- last_ping_time = 0
+ last_ping_time: int | float = 0
while True:
try:
message = self._q.get(timeout=1)
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index 609fd03f229da8..07a248d77aee86 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -1,5 +1,5 @@
import time
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -36,8 +36,8 @@ def get_pre_calculate_rest_tokens(
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
- inputs: dict[str, str],
- files: list["File"],
+ inputs: Mapping[str, str],
+ files: Sequence["File"],
query: Optional[str] = None,
) -> int:
"""
@@ -64,7 +64,7 @@ def get_pre_calculate_rest_tokens(
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
- or model_config.parameters.get(parameter_rule.use_template)
+ or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
@@ -85,7 +85,7 @@ def get_pre_calculate_rest_tokens(
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
- rest_tokens = model_context_tokens - max_tokens - prompt_tokens
+ rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
@@ -111,7 +111,7 @@ def recalc_llm_max_tokens(
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
- or model_config.parameters.get(parameter_rule.use_template)
+ or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
@@ -136,8 +136,8 @@ def organize_prompt_messages(
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
- inputs: dict[str, str],
- files: list["File"],
+ inputs: Mapping[str, str],
+ files: Sequence["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
@@ -156,6 +156,7 @@ def organize_prompt_messages(
"""
# get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
+ prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
prompt_transform = SimplePromptTransform()
prompt_messages, stop = prompt_transform.get_prompt(
app_mode=AppMode.value_of(app_record.mode),
@@ -171,8 +172,11 @@ def organize_prompt_messages(
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode)
+ prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
+ if not advanced_completion_prompt_template:
+ raise InvokeBadRequestError("Advanced completion prompt template is required.")
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
if advanced_completion_prompt_template.role_prefix:
@@ -181,6 +185,8 @@ def organize_prompt_messages(
assistant=advanced_completion_prompt_template.role_prefix.assistant,
)
else:
+ if not prompt_template_entity.advanced_chat_prompt_template:
+ raise InvokeBadRequestError("Advanced chat prompt template is required.")
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
@@ -246,7 +252,7 @@ def direct_output(
def _handle_invoke_result(
self,
- invoke_result: Union[LLMResult, Generator],
+ invoke_result: Union[LLMResult, Generator[Any, None, None]],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
@@ -259,10 +265,12 @@ def _handle_invoke_result(
:param agent: agent
:return:
"""
- if not stream:
+ if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
- else:
+ elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+ else:
+ raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
@@ -291,8 +299,8 @@ def _handle_invoke_result_stream(
:param agent: agent
:return:
"""
- model = None
- prompt_messages = []
+ model: str = ""
+ prompt_messages: list[PromptMessage] = []
text = ""
usage = None
for result in invoke_result:
@@ -328,13 +336,14 @@ def _handle_invoke_result_stream(
def moderation_for_inputs(
self,
+ *,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
- query: str,
+ query: str | None = None,
message_id: str,
- ) -> tuple[bool, dict, str]:
+ ) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
@@ -350,7 +359,7 @@ def moderation_for_inputs(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
- inputs=inputs,
+ inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
@@ -390,9 +399,9 @@ def fill_in_inputs_from_external_data_tools(
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
- inputs: dict,
+ inputs: Mapping[str, Any],
query: str,
- ) -> dict:
+ ) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index 5b8debaaae6a56..6ed71fcd843083 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -24,6 +24,7 @@
from factories import file_factory
from models.account import Account
from models.model import App, EndUser
+from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -91,7 +92,7 @@ def generate(
# get conversation
conversation = None
if args.get("conversation_id"):
- conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
+ conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -104,7 +105,7 @@ def generate(
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
- tenant_id=app_model.tenant_id, config=args.get("model_config")
+ tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# always enable retriever resource in debugger mode
@@ -146,7 +147,7 @@ def generate(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
- files=file_objs,
+ files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
invoke_from=invoke_from,
@@ -172,7 +173,7 @@ def generate(
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
@@ -216,6 +217,8 @@ def _generate_worker(
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
+ if message is None:
+ raise MessageNotExistsError("Message not exists")
# chatbot app
runner = ChatAppRunner()
diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py
index 0fa7af0a7fa36d..9024c3a98273d1 100644
--- a/api/core/app/apps/chat/generate_response_converter.py
+++ b/api/core/app/apps/chat/generate_response_converter.py
@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
- def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes
return response
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict:
+ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -52,7 +52,8 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR
@classmethod
def convert_stream_full_response(
- cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+ cls,
+ stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -83,7 +84,8 @@ def convert_stream_full_response(
@classmethod
def convert_stream_simple_response(
- cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+ cls,
+ stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py
index 1193c4b7a43632..02e5d475684cdc 100644
--- a/api/core/app/apps/completion/app_config_manager.py
+++ b/api/core/app/apps/completion/app_config_manager.py
@@ -42,7 +42,7 @@ def get_app_config(
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
- config_dict = override_config_dict
+ config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = CompletionAppConfig(
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index 14fd33dd398927..17d0d52497ceee 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -83,8 +83,6 @@ def generate(
query = query.replace("\x00", "")
inputs = args["inputs"]
- extras = {}
-
# get conversation
conversation = None
@@ -99,7 +97,7 @@ def generate(
# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
- tenant_id=app_model.tenant_id, config=args.get("model_config")
+ tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# parse files
@@ -132,11 +130,11 @@ def generate(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
- files=file_objs,
+ files=list(file_objs),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
- extras=extras,
+ extras={},
trace_manager=trace_manager,
)
@@ -157,7 +155,7 @@ def generate(
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
@@ -197,6 +195,8 @@ def _generate_worker(
try:
# get message
message = self._get_message(message_id)
+ if message is None:
+ raise MessageNotExistsError()
# chatbot app
runner = CompletionAppRunner()
@@ -231,7 +231,7 @@ def generate_more_like_this(
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
- ) -> Union[dict, Generator[str, None, None]]:
+ ) -> Union[Mapping[str, Any], Generator[str, None, None]]:
"""
Generate App response.
@@ -293,7 +293,7 @@ def generate_more_like_this(
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
- files=file_objs,
+ files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
@@ -317,7 +317,7 @@ def generate_more_like_this(
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 908d74ff539a5a..41278b75b42bf4 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -76,7 +76,7 @@ def run(
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=inputs,
- query=query,
+ query=query or "",
message_id=message.id,
)
except ModerationError as e:
@@ -122,7 +122,7 @@ def run(
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_conf,
config=dataset_config,
- query=query,
+ query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,
diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py
index 697f0273a5673e..73f38c3d0bcb96 100644
--- a/api/core/app/apps/completion/generate_response_converter.py
+++ b/api/core/app/apps/completion/generate_response_converter.py
@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod
- def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
+ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -36,7 +36,7 @@ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlocking
return response
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict:
+ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -51,7 +51,8 @@ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlocki
@classmethod
def convert_stream_full_response(
- cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+ cls,
+ stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -81,7 +82,8 @@ def convert_stream_full_response(
@classmethod
def convert_stream_simple_response(
- cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+ cls,
+ stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index 95ae798ec1ac74..c2e35faf89ba15 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -2,11 +2,11 @@
import logging
from collections.abc import Generator
from datetime import UTC, datetime
-from typing import Optional, Union
+from typing import Optional, Union, cast
from sqlalchemy import and_
-from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
+from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
@@ -42,7 +42,7 @@ def _handle_response(
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
- AdvancedChatAppGenerateEntity,
+ AgentChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
@@ -144,7 +144,7 @@ def _init_generate_records(
:conversation conversation
:return:
"""
- app_config = application_generate_entity.app_config
+ app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)
# get from source
end_user_id = None
@@ -267,7 +267,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat
except KeyError:
pass
- return introduction
+ return introduction or ""
def _get_conversation(self, conversation_id: str):
"""
@@ -282,7 +282,7 @@ def _get_conversation(self, conversation_id: str):
return conversation
- def _get_message(self, message_id: str) -> Message:
+ def _get_message(self, message_id: str) -> Optional[Message]:
"""
Get message by message id
:param message_id: message id
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index dc4ee9e566a2f3..1d5f21b9e0cc07 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -116,7 +116,7 @@ def generate(
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
- files=system_files,
+ files=list(system_files),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py
index 08d00ee1805aa2..5cdac6ad28fdaa 100644
--- a/api/core/app/apps/workflow/generate_response_converter.py
+++ b/api/core/app/apps/workflow/generate_response_converter.py
@@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
- def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
+ def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
- return blocking_response.to_dict()
+ return dict(blocking_response.to_dict())
@classmethod
- def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:
+ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@@ -36,7 +36,8 @@ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlocking
@classmethod
def convert_stream_full_response(
- cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+ cls,
+ stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream full response.
@@ -65,7 +66,8 @@ def convert_stream_full_response(
@classmethod
def convert_stream_simple_response(
- cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+ cls,
+ stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 885283504b4175..63f516bcc60682 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -24,6 +24,7 @@
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
+from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@@ -190,16 +191,15 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
+ inputs: Mapping[str, Any] | None = {}
+ process_data: Mapping[str, Any] | None = {}
+ outputs: Mapping[str, Any] | None = {}
+ execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
- else:
- inputs = {}
- process_data = {}
- outputs = {}
- execution_metadata = {}
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
@@ -289,7 +289,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
- outputs=event.route_node_state.node_run_result.outputs
+ outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
@@ -349,7 +349,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
- outputs=event.route_node_state.node_run_result.outputs
+ outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index 31c3a996e19286..16dc91bb777a9b 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -5,7 +5,7 @@
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from constants import UUID_NIL
-from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
+from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
@@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel):
task_id: str
# app config
- app_config: AppConfig
+ app_config: Any
file_upload_config: Optional[FileUploadConfig] = None
inputs: Mapping[str, Any]
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index d73c2eb53bfcd7..a93e533ff45d26 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -308,7 +308,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
- execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+ execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None
"""single iteration duration map"""
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index dd088a897816ef..5e845eba2da1d3 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -70,7 +70,7 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str
- def to_dict(self) -> dict:
+ def to_dict(self):
return jsonable_encoder(self)
@@ -474,8 +474,8 @@ class Data(BaseModel):
title: str
created_at: int
extras: dict = {}
- metadata: dict = {}
- inputs: dict = {}
+ metadata: Mapping = {}
+ inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
@@ -526,15 +526,15 @@ class Data(BaseModel):
node_id: str
node_type: str
title: str
- outputs: Optional[dict] = None
+ outputs: Optional[Mapping] = None
created_at: int
extras: Optional[dict] = None
- inputs: Optional[dict] = None
+ inputs: Optional[Mapping] = None
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
elapsed_time: float
total_tokens: int
- execution_metadata: Optional[dict] = None
+ execution_metadata: Optional[Mapping] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
@@ -628,7 +628,7 @@ class AppBlockingResponse(BaseModel):
task_id: str
- def to_dict(self) -> dict:
+ def to_dict(self):
return jsonable_encoder(self)
diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py
index 77b6bb554c65ec..83fd3debad4cf1 100644
--- a/api/core/app/features/annotation_reply/annotation_reply.py
+++ b/api/core/app/features/annotation_reply/annotation_reply.py
@@ -58,7 +58,7 @@ def query(
query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
)
- if documents:
+ if documents and documents[0].metadata:
annotation_id = documents[0].metadata["annotation_id"]
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py
index 8fe1d96b37be0c..dcc2b4e55f6ae1 100644
--- a/api/core/app/features/rate_limiting/rate_limit.py
+++ b/api/core/app/features/rate_limiting/rate_limit.py
@@ -17,7 +17,7 @@ class RateLimit:
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
- _instance_dict = {}
+ _instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py
index 51d610e2cbedc6..03a81353d02625 100644
--- a/api/core/app/task_pipeline/based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -62,6 +62,7 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non
"""
logger.debug("error: %s", event.error)
e = event.error
+ err: Exception
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
@@ -130,6 +131,7 @@ def _init_output_moderation(self) -> Optional[OutputModeration]:
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
)
+ return None
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
"""
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index e26b60c4d3043e..b9f8e7ca560ce7 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -2,6 +2,7 @@
import logging
import time
from collections.abc import Generator
+from threading import Thread
from typing import Optional, Union, cast
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -103,7 +104,7 @@ def __init__(
)
)
- self._conversation_name_generate_thread = None
+ self._conversation_name_generate_thread: Optional[Thread] = None
def process(
self,
@@ -123,7 +124,7 @@ def process(
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
- self._conversation, self._application_generate_entity.query
+ self._conversation, self._application_generate_entity.query or ""
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -146,7 +147,7 @@ def _to_blocking_response(
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
-
+ response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -154,7 +155,7 @@ def _to_blocking_response(
id=self._message.id,
mode=self._conversation.mode,
message_id=self._message.id,
- answer=self._task_state.llm_result.message.content,
+ answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -167,7 +168,7 @@ def _to_blocking_response(
mode=self._conversation.mode,
conversation_id=self._conversation.id,
message_id=self._message.id,
- answer=self._task_state.llm_result.message.content,
+ answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()),
**extras,
),
@@ -252,7 +253,7 @@ def _wrapper_process_stream_response(
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
- self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
+ self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@@ -269,13 +270,14 @@ def _process_stream_response(
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
- self._task_state.llm_result = event.llm_result
+ if event.llm_result:
+ self._task_state.llm_result = event.llm_result
else:
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
- self._task_state.llm_result.message.content
+ cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@@ -292,7 +294,9 @@ def _process_stream_response(
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
- yield self._agent_thought_to_stream_response(event)
+ agent_thought_response = self._agent_thought_to_stream_response(event)
+ if agent_thought_response is not None:
+ yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event)
if response:
@@ -307,16 +311,18 @@ def _process_stream_response(
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# handle output moderation chunk
- should_direct_answer = self._handle_output_moderation_chunk(delta_text)
+ should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
continue
- self._task_state.llm_result.message.content += delta_text
+ current_content = cast(str, self._task_state.llm_result.message.content)
+ current_content += cast(str, delta_text)
+ self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
- yield self._message_to_stream_response(delta_text, self._message.id)
+ yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
else:
- yield self._agent_message_to_stream_response(delta_text, self._message.id)
+ yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
@@ -336,8 +342,14 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No
llm_result = self._task_state.llm_result
usage = llm_result.usage
- self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
- self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
+ message = db.session.query(Message).filter(Message.id == self._message.id).first()
+ if not message:
+ raise Exception(f"Message {self._message.id} not found")
+ self._message = message
+ conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
+ if not conversation:
+ raise Exception(f"Conversation {self._conversation.id} not found")
+ self._conversation = conversation
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
@@ -346,7 +358,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = (
- PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
+ PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)
@@ -374,6 +386,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
+ and hasattr(self._application_generate_entity, "conversation_id")
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
@@ -420,7 +433,9 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse(
- task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
+ task_id=self._application_generate_entity.task_id,
+ id=self._message.id,
+ metadata=extras.get("metadata", {}),
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -440,7 +455,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op
:param event: agent thought event
:return:
"""
- agent_thought: MessageAgentThought = (
+ agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)
diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py
index e818a090ed7d0f..007543f6d0d1f2 100644
--- a/api/core/app/task_pipeline/message_cycle_manage.py
+++ b/api/core/app/task_pipeline/message_cycle_manage.py
@@ -128,7 +128,7 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
- if message_file:
+ if message_file and message_file.url is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index df7dbace0ef818..f581e564f224ce 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -93,7 +93,7 @@ def _handle_workflow_run_start(self) -> WorkflowRun:
)
# handle special values
- inputs = WorkflowEntry.handle_special_values(inputs)
+ inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
with Session(db.engine, expire_on_commit=False) as session:
@@ -192,7 +192,7 @@ def _handle_workflow_run_partial_success(
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
- outputs = WorkflowEntry.handle_special_values(outputs)
+ outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
@@ -500,7 +500,7 @@ def _workflow_start_to_stream_response(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
- inputs=workflow_run.inputs_dict,
+ inputs=dict(workflow_run.inputs_dict or {}),
created_at=int(workflow_run.created_at.timestamp()),
),
)
@@ -545,7 +545,7 @@ def _workflow_finish_to_stream_response(
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
- outputs=workflow_run.outputs_dict,
+ outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
@@ -553,7 +553,7 @@ def _workflow_finish_to_stream_response(
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
- files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
+ files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count,
),
)
@@ -655,7 +655,7 @@ def _workflow_node_retry_to_stream_response(
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
- ) -> Optional[NodeFinishStreamResponse]:
+ ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
@@ -838,7 +838,7 @@ def _workflow_iteration_completed_to_stream_response(
),
)
- def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
+ def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@@ -851,9 +851,11 @@ def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping
# Remove None
files = [file for file in files if file]
# Flatten list
- files = [file for sublist in files for file in sublist]
+ # Flatten the list of sequences into a single list of mappings
+ flattened_files = [file for sublist in files if sublist for file in sublist]
- return files
+ # Convert to tuple to match Sequence type
+ return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
@@ -891,6 +893,8 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any
elif isinstance(value, File):
return value.to_dict()
+ return None
+
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py
index d826edf6a0fc19..effc7eff9179ae 100644
--- a/api/core/callback_handler/agent_tool_callback_handler.py
+++ b/api/core/callback_handler/agent_tool_callback_handler.py
@@ -57,7 +57,7 @@ def on_tool_end(
self,
tool_name: str,
tool_inputs: Mapping[str, Any],
- tool_outputs: Sequence[ToolInvokeMessage],
+ tool_outputs: Sequence[ToolInvokeMessage] | str,
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None,
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index 1481578630f63b..8f8aaa93d6f986 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -40,17 +40,18 @@ def on_query(self, query: str, dataset_id: str) -> None:
def on_tool_end(self, documents: list[Document]) -> None:
"""Handle tool end."""
for document in documents:
- query = db.session.query(DocumentSegment).filter(
- DocumentSegment.index_node_id == document.metadata["doc_id"]
- )
+ if document.metadata is not None:
+ query = db.session.query(DocumentSegment).filter(
+ DocumentSegment.index_node_id == document.metadata["doc_id"]
+ )
- if "dataset_id" in document.metadata:
- query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
+ if "dataset_id" in document.metadata:
+ query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
- # add hit count to document segment
- query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
+ # add hit count to document segment
+ query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
- db.session.commit()
+ db.session.commit()
def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""
diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py
index 9ed5528e43b9b8..5017835565789c 100644
--- a/api/core/entities/model_entities.py
+++ b/api/core/entities/model_entities.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
from enum import Enum
from typing import Optional
@@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
- supported_model_types: list[ModelType]
+ supported_model_types: Sequence[ModelType] = []
class DefaultModelEntity(BaseModel):
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index d1b34db2fe7172..2e27b362d3092c 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -40,7 +40,7 @@
logger = logging.getLogger(__name__)
-original_provider_configurate_methods = {}
+original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
class ProviderConfiguration(BaseModel):
@@ -99,7 +99,8 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional
continue
restrict_models = quota_configuration.restrict_models
-
+ if self.system_configuration.credentials is None:
+ return None
copy_credentials = self.system_configuration.credentials.copy()
if restrict_models:
for restrict_model in restrict_models:
@@ -124,7 +125,7 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional
return credentials
- def get_system_configuration_status(self) -> SystemConfigurationStatus:
+ def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
"""
Get system configuration status.
:return:
@@ -136,6 +137,8 @@ def get_system_configuration_status(self) -> SystemConfigurationStatus:
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
)
+ if current_quota_configuration is None:
+ return None
return (
SystemConfigurationStatus.ACTIVE
@@ -150,7 +153,7 @@ def is_custom_configuration_available(self) -> bool:
"""
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
- def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
+ def get_custom_credentials(self, obfuscated: bool = False):
"""
Get custom credentials.
@@ -172,7 +175,7 @@ def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
else [],
)
- def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
+ def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -324,7 +327,7 @@ def get_custom_model_credentials(
def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
- ) -> tuple[ProviderModel, dict]:
+ ) -> tuple[Optional[ProviderModel], dict]:
"""
Validate custom model credentials.
@@ -740,10 +743,10 @@ def get_provider_models(
if model_type:
model_types.append(model_type)
else:
- model_types = provider_instance.get_provider_schema().supported_model_types
+ model_types = list(provider_instance.get_provider_schema().supported_model_types)
# Group model settings by model type and model
- model_setting_map = defaultdict(dict)
+ model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
for model_setting in self.model_settings:
model_setting_map[model_setting.model_type][model_setting.model] = model_setting
@@ -822,54 +825,57 @@ def _get_system_provider_models(
]:
# only customizable model
for restrict_model in restrict_models:
- copy_credentials = self.system_configuration.credentials.copy()
- if restrict_model.base_model_name:
- copy_credentials["base_model_name"] = restrict_model.base_model_name
-
- try:
- custom_model_schema = provider_instance.get_model_instance(
- restrict_model.model_type
- ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
- except Exception as ex:
- logger.warning(f"get custom model schema failed, {ex}")
- continue
-
- if not custom_model_schema:
- continue
-
- if custom_model_schema.model_type not in model_types:
- continue
-
- status = ModelStatus.ACTIVE
- if (
- custom_model_schema.model_type in model_setting_map
- and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
- ):
- model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
- if model_setting.enabled is False:
- status = ModelStatus.DISABLED
-
- provider_models.append(
- ModelWithProviderEntity(
- model=custom_model_schema.model,
- label=custom_model_schema.label,
- model_type=custom_model_schema.model_type,
- features=custom_model_schema.features,
- fetch_from=FetchFrom.PREDEFINED_MODEL,
- model_properties=custom_model_schema.model_properties,
- deprecated=custom_model_schema.deprecated,
- provider=SimpleModelProviderEntity(self.provider),
- status=status,
+ if self.system_configuration.credentials is not None:
+ copy_credentials = self.system_configuration.credentials.copy()
+ if restrict_model.base_model_name:
+ copy_credentials["base_model_name"] = restrict_model.base_model_name
+
+ try:
+ custom_model_schema = provider_instance.get_model_instance(
+ restrict_model.model_type
+ ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
+ except Exception as ex:
+ logger.warning(f"get custom model schema failed, {ex}")
+ continue
+
+ if not custom_model_schema:
+ continue
+
+ if custom_model_schema.model_type not in model_types:
+ continue
+
+ status = ModelStatus.ACTIVE
+ if (
+ custom_model_schema.model_type in model_setting_map
+ and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
+ ):
+ model_setting = model_setting_map[custom_model_schema.model_type][
+ custom_model_schema.model
+ ]
+ if model_setting.enabled is False:
+ status = ModelStatus.DISABLED
+
+ provider_models.append(
+ ModelWithProviderEntity(
+ model=custom_model_schema.model,
+ label=custom_model_schema.label,
+ model_type=custom_model_schema.model_type,
+ features=custom_model_schema.features,
+ fetch_from=FetchFrom.PREDEFINED_MODEL,
+ model_properties=custom_model_schema.model_properties,
+ deprecated=custom_model_schema.deprecated,
+ provider=SimpleModelProviderEntity(self.provider),
+ status=status,
+ )
)
- )
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
- for m in provider_models:
- if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
- m.status = ModelStatus.NO_PERMISSION
+ for model in provider_models:
+ if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
+ model.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
- m.status = ModelStatus.QUOTA_EXCEEDED
+ model.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
@@ -1043,7 +1049,7 @@ def __iter__(self):
return iter(self.configurations)
def values(self) -> Iterator[ProviderConfiguration]:
- return self.configurations.values()
+ return iter(self.configurations.values())
def get(self, key, default=None):
return self.configurations.get(key, default)
diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py
index 38cebb6b6b1c36..3f4e20ec245302 100644
--- a/api/core/extension/api_based_extension_requestor.py
+++ b/api/core/extension/api_based_extension_requestor.py
@@ -1,3 +1,5 @@
+from typing import cast
+
import requests
from configs import dify_config
@@ -5,7 +7,7 @@
class APIBasedExtensionRequestor:
- timeout: (int, int) = (5, 60)
+ timeout: tuple[int, int] = (5, 60)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None:
@@ -51,4 +53,4 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
"request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
)
- return response.json()
+ return cast(dict, response.json())
diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index 97dbaf2026e790..231743bf2a948c 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -38,8 +38,8 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
@classmethod
def scan_extensions(cls):
- extensions: list[ModuleExtension] = []
- position_map = {}
+ extensions = []
+ position_map: dict[str, int] = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
@@ -58,7 +58,8 @@ def scan_extensions(cls):
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
builtin = False
- position = None
+ # default position is 0 can not be None for sort_to_dict_by_position_map
+ position = 0
if "__builtin__" in file_names:
builtin = True
@@ -89,7 +90,7 @@ def scan_extensions(cls):
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue
- json_data = {}
+ json_data: dict[str, Any] = {}
if not builtin:
if "schema.json" not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py
index 3da170455e3398..9eb9e0306b577f 100644
--- a/api/core/extension/extension.py
+++ b/api/core/extension/extension.py
@@ -1,4 +1,6 @@
-from core.extension.extensible import ExtensionModule, ModuleExtension
+from typing import cast
+
+from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension
from core.external_data_tool.base import ExternalDataTool
from core.moderation.base import Moderation
@@ -10,7 +12,8 @@ class Extension:
def init(self):
for module, module_class in self.module_classes.items():
- self.__module_extensions[module.value] = module_class.scan_extensions()
+ m = cast(Extensible, module_class)
+ self.__module_extensions[module.value] = m.scan_extensions()
def module_extensions(self, module: str) -> list[ModuleExtension]:
module_extensions = self.__module_extensions.get(module)
@@ -35,7 +38,8 @@ def module_extension(self, module: ExtensionModule, extension_name: str) -> Modu
def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
module_extension = self.module_extension(module, extension_name)
- return module_extension.extension_class
+ t: type = module_extension.extension_class
+ return t
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py
index 54ec97a4933a94..9989c8a09013bd 100644
--- a/api/core/external_data_tool/api/api.py
+++ b/api/core/external_data_tool/api/api.py
@@ -48,7 +48,10 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str:
:return: the tool query result
"""
# get params from config
+ if not self.config:
+ raise ValueError("config is required, config: {}".format(self.config))
api_based_extension_id = self.config.get("api_based_extension_id")
+ assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension
api_based_extension = (
diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py
index 84b94e117ff5f9..6a9703a569b308 100644
--- a/api/core/external_data_tool/external_data_fetch.py
+++ b/api/core/external_data_tool/external_data_fetch.py
@@ -1,7 +1,7 @@
-import concurrent
import logging
-from concurrent.futures import ThreadPoolExecutor
-from typing import Optional
+from collections.abc import Mapping
+from concurrent.futures import Future, ThreadPoolExecutor, as_completed
+from typing import Any, Optional
from flask import Flask, current_app
@@ -17,9 +17,9 @@ def fetch(
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
- inputs: dict,
+ inputs: Mapping[str, Any],
query: str,
- ) -> dict:
+ ) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.
@@ -30,13 +30,14 @@ def fetch(
:param query: the query
:return: the filled inputs
"""
- results = {}
+ results: dict[str, Any] = {}
+ inputs = dict(inputs)
with ThreadPoolExecutor() as executor:
futures = {}
for tool in external_data_tools:
- future = executor.submit(
+ future: Future[tuple[str | None, str | None]] = executor.submit(
self._query_external_data_tool,
- current_app._get_current_object(),
+ current_app._get_current_object(), # type: ignore
tenant_id,
app_id,
tool,
@@ -46,9 +47,10 @@ def fetch(
futures[future] = tool
- for future in concurrent.futures.as_completed(futures):
+ for future in as_completed(futures):
tool_variable, result = future.result()
- results[tool_variable] = result
+ if tool_variable is not None:
+ results[tool_variable] = result
inputs.update(results)
return inputs
@@ -59,7 +61,7 @@ def _query_external_data_tool(
tenant_id: str,
app_id: str,
external_data_tool: ExternalDataVariableEntity,
- inputs: dict,
+ inputs: Mapping[str, Any],
query: str,
) -> tuple[Optional[str], Optional[str]]:
"""
diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py
index 28721098594962..245507e17c7032 100644
--- a/api/core/external_data_tool/factory.py
+++ b/api/core/external_data_tool/factory.py
@@ -1,4 +1,5 @@
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional, cast
from core.extension.extensible import ExtensionModule
from extensions.ext_code_based_extension import code_based_extension
@@ -23,9 +24,10 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
- extension_class.validate_config(tenant_id, config)
+ # FIXME mypy issue here, figure out how to fix it
+ extension_class.validate_config(tenant_id, config) # type: ignore
- def query(self, inputs: dict, query: Optional[str] = None) -> str:
+ def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str:
"""
Query the external data tool.
@@ -33,4 +35,4 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str:
:param query: the query of chat app
:return: the tool query result
"""
- return self.__extension_instance.query(inputs, query)
+ return cast(str, self.__extension_instance.query(inputs, query))
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index 15eb351a7ef309..4a50fb85c9cca3 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -1,4 +1,5 @@
import base64
+from collections.abc import Mapping
from configs import dify_config
from core.helper import ssrf_proxy
@@ -55,7 +56,7 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
- prompt_class_map = {
+ prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
@@ -63,7 +64,7 @@ def to_prompt_message_content(
}
try:
- return prompt_class_map[f.type](**params)
+ return prompt_class_map[f.type].model_validate(params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")
diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py
index a17b7be3675ab1..6fa101cf36192b 100644
--- a/api/core/file/tool_file_parser.py
+++ b/api/core/file/tool_file_parser.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
@@ -9,4 +9,4 @@
class ToolFileParser:
@staticmethod
def get_tool_file_manager() -> "ToolFileManager":
- return tool_file_manager["manager"]
+ return cast("ToolFileManager", tool_file_manager["manager"])
diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py
index 584e3e9698a88d..15b501780e766c 100644
--- a/api/core/helper/code_executor/code_executor.py
+++ b/api/core/helper/code_executor/code_executor.py
@@ -38,7 +38,7 @@ class CodeLanguage(StrEnum):
class CodeExecutor:
- dependencies_cache = {}
+ dependencies_cache: dict[str, str] = {}
dependencies_cache_lock = Lock()
code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = {
@@ -103,19 +103,19 @@ def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str:
)
try:
- response = response.json()
+ response_data = response.json()
except:
raise CodeExecutionError("Failed to parse response")
- if (code := response.get("code")) != 0:
- raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
+ if (code := response_data.get("code")) != 0:
+ raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
- response = CodeExecutionResponse(**response)
+ response_code = CodeExecutionResponse(**response_data)
- if response.data.error:
- raise CodeExecutionError(response.data.error)
+ if response_code.data.error:
+ raise CodeExecutionError(response_code.data.error)
- return response.data.stdout or ""
+ return response_code.data.stdout or ""
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py
index db2eb5ebb6b19a..264947b5686d0e 100644
--- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py
+++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py
@@ -1,9 +1,11 @@
+from collections.abc import Mapping
+
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
class Jinja2Formatter:
@classmethod
- def format(cls, template: str, inputs: dict) -> str:
+ def format(cls, template: str, inputs: Mapping[str, str]) -> str:
"""
Format template
:param template: template
@@ -11,5 +13,4 @@ def format(cls, template: str, inputs: dict) -> str:
:return:
"""
result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)
-
- return result["result"]
+ return str(result.get("result", ""))
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index 605719747a7b56..baa792b5bc6c41 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -29,8 +29,7 @@ def extract_result_str_from_response(cls, response: str):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
raise ValueError("Failed to parse result")
- result = result.group(1)
- return result
+ return result.group(1)
@classmethod
def transform_response(cls, response: str) -> Mapping[str, Any]:
diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py
index 518962c1652df7..81501d2e4e23b2 100644
--- a/api/core/helper/lru_cache.py
+++ b/api/core/helper/lru_cache.py
@@ -4,7 +4,7 @@
class LRUCache:
def __init__(self, capacity: int):
- self.cache = OrderedDict()
+ self.cache: OrderedDict[Any, Any] = OrderedDict()
self.capacity = capacity
def get(self, key: Any) -> Any:
diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py
index 5e274f8916869d..35349210bd53ab 100644
--- a/api/core/helper/model_provider_cache.py
+++ b/api/core/helper/model_provider_cache.py
@@ -30,7 +30,7 @@ def get(self) -> Optional[dict]:
except JSONDecodeError:
return None
- return cached_provider_credentials
+ return dict(cached_provider_credentials)
else:
return None
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index da0fd0031cc6dc..543444463b9f1a 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -22,6 +22,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
provider_name = model_config.provider
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
hosting_openai_config = hosting_configuration.provider_map["openai"]
+ assert hosting_openai_config is not None
# 2000 text per chunk
length = 2000
@@ -34,8 +35,9 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
try:
model_type_instance = OpenAIModerationModel()
+ # FIXME, for type hint using assert or raise ValueError is better here?
moderation_result = model_type_instance.invoke(
- model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk
+ model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk
)
if moderation_result is True:
diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py
index 1e2fefce88b632..9a041667e46df5 100644
--- a/api/core/helper/module_import_helper.py
+++ b/api/core/helper/module_import_helper.py
@@ -14,12 +14,13 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
if existed_spec:
spec = existed_spec
if not spec.loader:
- raise Exception(f"Failed to load module {module_name} from {py_file_path}")
+ raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
- spec = importlib.util.spec_from_file_location(module_name, py_file_path)
+ # FIXME: mypy does not support the type of spec.loader
+ spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore
if not spec or not spec.loader:
- raise Exception(f"Failed to load module {module_name} from {py_file_path}")
+ raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader)
@@ -29,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
spec.loader.exec_module(module)
return module
except Exception as e:
- logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'")
+ logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'")
raise e
@@ -57,6 +58,6 @@ def load_single_subclass_from_source(
case 1:
return subclasses[0]
case 0:
- raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}")
+ raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}")
case _:
- raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}")
+ raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}")
diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py
index e848b46c5633ab..3b67b3f84838d3 100644
--- a/api/core/helper/tool_parameter_cache.py
+++ b/api/core/helper/tool_parameter_cache.py
@@ -33,7 +33,7 @@ def get(self) -> Optional[dict]:
except JSONDecodeError:
return None
- return cached_tool_parameter
+ return dict(cached_tool_parameter)
else:
return None
diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py
index 94b02cf98578b1..6de5e704abf4f5 100644
--- a/api/core/helper/tool_provider_cache.py
+++ b/api/core/helper/tool_provider_cache.py
@@ -28,7 +28,7 @@ def get(self) -> Optional[dict]:
except JSONDecodeError:
return None
- return cached_provider_credentials
+ return dict(cached_provider_credentials)
else:
return None
diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py
index b47ba67f2fa64f..f9fb7275f3624f 100644
--- a/api/core/hosting_configuration.py
+++ b/api/core/hosting_configuration.py
@@ -42,7 +42,7 @@ class HostedModerationConfig(BaseModel):
class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {}
- moderation_config: HostedModerationConfig = None
+ moderation_config: Optional[HostedModerationConfig] = None
def init_app(self, app: Flask) -> None:
if dify_config.EDITION != "CLOUD":
@@ -67,7 +67,7 @@ def init_azure_openai() -> HostingProvider:
"base_model_name": "gpt-35-turbo",
}
- quotas = []
+ quotas: list[HostingQuota] = []
hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
@@ -123,7 +123,7 @@ def init_azure_openai() -> HostingProvider:
def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
- quotas = []
+ quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
@@ -157,7 +157,7 @@ def init_openai(self) -> HostingProvider:
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
- quotas = []
+ quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
@@ -187,7 +187,7 @@ def init_anthropic() -> HostingProvider:
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_MINIMAX_ENABLED:
- quotas = [FreeHostingQuota()]
+ quotas: list[HostingQuota] = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
@@ -205,7 +205,7 @@ def init_minimax() -> HostingProvider:
def init_spark() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_SPARK_ENABLED:
- quotas = [FreeHostingQuota()]
+ quotas: list[HostingQuota] = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
@@ -223,7 +223,7 @@ def init_spark() -> HostingProvider:
def init_zhipuai() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
if dify_config.HOSTED_ZHIPUAI_ENABLED:
- quotas = [FreeHostingQuota()]
+ quotas: list[HostingQuota] = [FreeHostingQuota()]
return HostingProvider(
enabled=True,
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 29e161cb747284..1f0a0d0ef1dda4 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -6,10 +6,10 @@
import threading
import time
import uuid
-from typing import Optional, cast
+from typing import Any, Optional, cast
from flask import Flask, current_app
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@@ -62,6 +62,8 @@ def run(self, dataset_documents: list[DatasetDocument]):
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
+ if not processing_rule:
+ raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# extract
@@ -120,6 +122,8 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument):
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
+ if not processing_rule:
+ raise ValueError("no process rule found")
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -254,7 +258,7 @@ def indexing_estimate(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
- preview_texts = []
+ preview_texts: list[str] = []
total_segments = 0
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -285,7 +289,8 @@ def indexing_estimate(
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
try:
- storage.delete(image_file.key)
+ if image_file:
+ storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \
@@ -379,8 +384,9 @@ def _extract(
# replace doc id to document model id
text_docs = cast(list[Document], text_docs)
for text_doc in text_docs:
- text_doc.metadata["document_id"] = dataset_document.id
- text_doc.metadata["dataset_id"] = dataset_document.dataset_id
+ if text_doc.metadata is not None:
+ text_doc.metadata["document_id"] = dataset_document.id
+ text_doc.metadata["dataset_id"] = dataset_document.dataset_id
return text_docs
@@ -400,6 +406,7 @@ def _get_splitter(
"""
Get the NodeParser object according to the processing rule.
"""
+ character_splitter: TextSplitter
if processing_rule.mode == "custom":
# The user-defined segmentation rule
rules = json.loads(processing_rule.rules)
@@ -426,9 +433,10 @@ def _get_splitter(
)
else:
# Automatic segmentation
+ automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"])
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
- chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
- chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
+ chunk_size=automatic_rules["max_tokens"],
+ chunk_overlap=automatic_rules["chunk_overlap"],
separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=embedding_model_instance,
)
@@ -497,8 +505,8 @@ def _split_to_documents(
"""
Split the text documents into nodes.
"""
- all_documents = []
- all_qa_documents = []
+ all_documents: list[Document] = []
+ all_qa_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
@@ -509,10 +517,11 @@ def _split_to_documents(
split_documents = []
for document_node in documents:
if document_node.page_content.strip():
- doc_id = str(uuid.uuid4())
- hash = helper.generate_text_hash(document_node.page_content)
- document_node.metadata["doc_id"] = doc_id
- document_node.metadata["doc_hash"] = hash
+ if document_node.metadata is not None:
+ doc_id = str(uuid.uuid4())
+ hash = helper.generate_text_hash(document_node.page_content)
+ document_node.metadata["doc_id"] = doc_id
+ document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)
@@ -529,7 +538,7 @@ def _split_to_documents(
document_format_thread = threading.Thread(
target=self.format_qa_document,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": tenant_id,
"document_node": doc,
"all_qa_documents": all_qa_documents,
@@ -557,11 +566,12 @@ def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, al
qa_document = Document(
page_content=result["question"], metadata=document_node.metadata.model_copy()
)
- doc_id = str(uuid.uuid4())
- hash = helper.generate_text_hash(result["question"])
- qa_document.metadata["answer"] = result["answer"]
- qa_document.metadata["doc_id"] = doc_id
- qa_document.metadata["doc_hash"] = hash
+ if qa_document.metadata is not None:
+ doc_id = str(uuid.uuid4())
+ hash = helper.generate_text_hash(result["question"])
+ qa_document.metadata["answer"] = result["answer"]
+ qa_document.metadata["doc_id"] = doc_id
+ qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
@@ -575,7 +585,7 @@ def _split_to_documents_for_estimate(
"""
Split the text documents into nodes.
"""
- all_documents = []
+ all_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
@@ -588,11 +598,11 @@ def _split_to_documents_for_estimate(
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
- doc_id = str(uuid.uuid4())
- hash = helper.generate_text_hash(document.page_content)
-
- document.metadata["doc_id"] = doc_id
- document.metadata["doc_hash"] = hash
+ if document.metadata is not None:
+ doc_id = str(uuid.uuid4())
+ hash = helper.generate_text_hash(document.page_content)
+ document.metadata["doc_id"] = doc_id
+ document.metadata["doc_hash"] = hash
split_documents.append(document)
@@ -648,7 +658,7 @@ def _load(
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
- args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+ args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()
if dataset.indexing_technique == "high_quality":
@@ -659,7 +669,7 @@ def _load(
futures.append(
executor.submit(
self._process_chunk,
- current_app._get_current_object(),
+ current_app._get_current_object(), # type: ignore
index_processor,
chunk_documents,
dataset,
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 3a92c8d9d22562..9fe3f68f2a8af5 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -1,7 +1,7 @@
import json
import logging
import re
-from typing import Optional
+from typing import Optional, cast
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -13,6 +13,7 @@
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
+from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -44,10 +45,13 @@ def generate_conversation_name(
prompts = [UserPromptMessage(content=prompt)]
with measure_time() as timer:
- response = model_instance.invoke_llm(
- prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
+ response = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
+ ),
)
- answer = response.message.content
+ answer = cast(str, response.message.content)
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
if cleaned_answer is None:
return ""
@@ -94,11 +98,16 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st
prompt_messages = [UserPromptMessage(content=prompt)]
try:
- response = model_instance.invoke_llm(
- prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False
+ response = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ model_parameters={"max_tokens": 256, "temperature": 0},
+ stream=False,
+ ),
)
- questions = output_parser.parse(response.message.content)
+ questions = output_parser.parse(cast(str, response.message.content))
except InvokeError:
questions = []
except Exception as e:
@@ -138,11 +147,14 @@ def generate_rule_config(
)
try:
- response = model_instance.invoke_llm(
- prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+ response = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+ ),
)
- rule_config["prompt"] = response.message.content
+ rule_config["prompt"] = cast(str, response.message.content)
except InvokeError as e:
error = str(e)
@@ -178,15 +190,18 @@ def generate_rule_config(
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
- provider=model_config.get("provider") if model_config else None,
- model=model_config.get("name") if model_config else None,
+ provider=model_config.get("provider", ""),
+ model=model_config.get("name", ""),
)
try:
try:
# the first step to generate the task prompt
- prompt_content = model_instance.invoke_llm(
- prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+ prompt_content = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+ ),
)
except InvokeError as e:
error = str(e)
@@ -195,8 +210,10 @@ def generate_rule_config(
return rule_config
- rule_config["prompt"] = prompt_content.message.content
+ rule_config["prompt"] = cast(str, prompt_content.message.content)
+ if not isinstance(prompt_content.message.content, str):
+ raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.content,
@@ -216,19 +233,25 @@ def generate_rule_config(
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
- parameter_content = model_instance.invoke_llm(
- prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
+ parameter_content = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
+ ),
)
- rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content)
+ rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
except InvokeError as e:
error = str(e)
error_step = "generate variables"
try:
- statement_content = model_instance.invoke_llm(
- prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
+ statement_content = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
+ ),
)
- rule_config["opening_statement"] = statement_content.message.content
+ rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
@@ -267,19 +290,22 @@ def generate_code(
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
- provider=model_config.get("provider") if model_config else None,
- model=model_config.get("name") if model_config else None,
+ provider=model_config.get("provider", ""),
+ model=model_config.get("name", ""),
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
try:
- response = model_instance.invoke_llm(
- prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+ response = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+ ),
)
- generated_code = response.message.content
+ generated_code = cast(str, response.message.content)
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
@@ -303,9 +329,14 @@ def generate_qa_document(cls, tenant_id: str, query, document_language: str):
prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
- response = model_instance.invoke_llm(
- prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False
+ response = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ model_parameters={"temperature": 0.01, "max_tokens": 2000},
+ stream=False,
+ ),
)
- answer = response.message.content
+ answer = cast(str, response.message.content)
return answer.strip()
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 81d08dc8854f80..003a0c85b1f12e 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -68,7 +68,7 @@ def get_history_prompt_messages(
messages = list(reversed(thread_messages))
- prompt_messages = []
+ prompt_messages: list[PromptMessage] = []
for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 1986688551b601..d1e71148cd6023 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -124,17 +124,20 @@ def invoke_llm(
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
- return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
- model=self.model,
- credentials=self.credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- callbacks=callbacks,
+ return cast(
+ Union[LLMResult, Generator],
+ self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
+ model=self.model,
+ credentials=self.credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ callbacks=callbacks,
+ ),
)
def get_llm_num_tokens(
@@ -151,12 +154,15 @@ def get_llm_num_tokens(
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
- return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
- model=self.model,
- credentials=self.credentials,
- prompt_messages=prompt_messages,
- tools=tools,
+ return cast(
+ int,
+ self._round_robin_invoke(
+ function=self.model_type_instance.get_num_tokens,
+ model=self.model,
+ credentials=self.credentials,
+ prompt_messages=prompt_messages,
+ tools=tools,
+ ),
)
def invoke_text_embedding(
@@ -174,13 +180,16 @@ def invoke_text_embedding(
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
- return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
- model=self.model,
- credentials=self.credentials,
- texts=texts,
- user=user,
- input_type=input_type,
+ return cast(
+ TextEmbeddingResult,
+ self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
+ model=self.model,
+ credentials=self.credentials,
+ texts=texts,
+ user=user,
+ input_type=input_type,
+ ),
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
@@ -194,11 +203,14 @@ def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
- return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
- model=self.model,
- credentials=self.credentials,
- texts=texts,
+ return cast(
+ int,
+ self._round_robin_invoke(
+ function=self.model_type_instance.get_num_tokens,
+ model=self.model,
+ credentials=self.credentials,
+ texts=texts,
+ ),
)
def invoke_rerank(
@@ -223,15 +235,18 @@ def invoke_rerank(
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
- return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
- model=self.model,
- credentials=self.credentials,
- query=query,
- docs=docs,
- score_threshold=score_threshold,
- top_n=top_n,
- user=user,
+ return cast(
+ RerankResult,
+ self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
+ model=self.model,
+ credentials=self.credentials,
+ query=query,
+ docs=docs,
+ score_threshold=score_threshold,
+ top_n=top_n,
+ user=user,
+ ),
)
def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
@@ -246,12 +261,15 @@ def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
- return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
- model=self.model,
- credentials=self.credentials,
- text=text,
- user=user,
+ return cast(
+ bool,
+ self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
+ model=self.model,
+ credentials=self.credentials,
+ text=text,
+ user=user,
+ ),
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
@@ -266,12 +284,15 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
- return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
- model=self.model,
- credentials=self.credentials,
- file=file,
- user=user,
+ return cast(
+ str,
+ self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
+ model=self.model,
+ credentials=self.credentials,
+ file=file,
+ user=user,
+ ),
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
@@ -288,17 +309,20 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
- return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
- model=self.model,
- credentials=self.credentials,
- content_text=content_text,
- user=user,
- tenant_id=tenant_id,
- voice=voice,
+ return cast(
+ Iterable[bytes],
+ self._round_robin_invoke(
+ function=self.model_type_instance.invoke,
+ model=self.model,
+ credentials=self.credentials,
+ content_text=content_text,
+ user=user,
+ tenant_id=tenant_id,
+ voice=voice,
+ ),
)
- def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
+ def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
"""
Round-robin invoke
:param function: function to invoke
@@ -309,7 +333,7 @@ def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
if not self.load_balancing_manager:
return function(*args, **kwargs)
- last_exception = None
+ last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None
while True:
lb_config = self.load_balancing_manager.fetch_next()
if not lb_config:
@@ -463,7 +487,7 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
if real_index > max_index:
real_index = 0
- config = self._load_balancing_configs[real_index]
+ config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index]
if self.in_cooldown(config):
cooldown_load_balancing_configs.append(config)
@@ -507,8 +531,7 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
self._tenant_id, self._provider, self._model_type.value, self._model, config.id
)
- res = redis_client.exists(cooldown_cache_key)
- res = cast(bool, res)
+ res: bool = redis_client.exists(cooldown_cache_key)
return res
@staticmethod
diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py
index 3b6b825244dfdc..1f21a2d3763c4a 100644
--- a/api/core/model_runtime/callbacks/logging_callback.py
+++ b/api/core/model_runtime/callbacks/logging_callback.py
@@ -1,7 +1,8 @@
import json
import logging
import sys
-from typing import Optional
+from collections.abc import Sequence
+from typing import Optional, cast
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -20,7 +21,7 @@ def on_before_invoke(
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -76,7 +77,7 @@ def on_new_chunk(
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@@ -94,7 +95,7 @@ def on_new_chunk(
:param stream: is stream response
:param user: unique user id
"""
- sys.stdout.write(chunk.delta.message.content)
+ sys.stdout.write(cast(str, chunk.delta.message.content))
sys.stdout.flush()
def on_after_invoke(
@@ -106,7 +107,7 @@ def on_after_invoke(
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -147,7 +148,7 @@ def on_invoke_error(
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 0efe46f87d6de9..2f682ceef578dc 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -3,7 +3,7 @@
from enum import Enum, StrEnum
from typing import Optional
-from pydantic import BaseModel, Field, computed_field, field_validator
+from pydantic import BaseModel, Field, field_validator
class PromptMessageRole(Enum):
@@ -89,7 +89,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
- @computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py
index 79a1d28ebe637e..e2b95603379348 100644
--- a/api/core/model_runtime/model_providers/__base/ai_model.py
+++ b/api/core/model_runtime/model_providers/__base/ai_model.py
@@ -1,7 +1,6 @@
import decimal
import os
from abc import ABC, abstractmethod
-from collections.abc import Mapping
from typing import Optional
from pydantic import ConfigDict
@@ -36,7 +35,7 @@ class AIModel(ABC):
model_config = ConfigDict(protected_namespaces=())
@abstractmethod
- def validate_credentials(self, model: str, credentials: Mapping) -> None:
+ def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
@@ -214,7 +213,7 @@ def predefined_models(self) -> list[AIModelEntity]:
return model_schemas
- def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
+ def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
"""
Get model schema by model name and credentials
@@ -236,9 +235,7 @@ def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) ->
return None
- def get_customizable_model_schema_from_credentials(
- self, model: str, credentials: Mapping
- ) -> Optional[AIModelEntity]:
+ def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
@@ -248,7 +245,7 @@ def get_customizable_model_schema_from_credentials(
"""
return self._get_customizable_model_schema(model, credentials)
- def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
+ def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
@@ -301,7 +298,7 @@ def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Op
return schema
- def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 8faeffa872b40f..402a30376b7546 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -2,7 +2,7 @@
import re
import time
from abc import abstractmethod
-from collections.abc import Generator, Mapping, Sequence
+from collections.abc import Generator, Sequence
from typing import Optional, Union
from pydantic import ConfigDict
@@ -48,7 +48,7 @@ def invoke(
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[Sequence[str]] = None,
+ stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
@@ -291,12 +291,12 @@ def _code_block_mode_stream_processor(
content = piece.delta.message.content
piece.delta.message.content = ""
yield piece
- piece = content
+ content_piece = content
else:
yield piece
continue
new_piece: str = ""
- for char in piece:
+ for char in content_piece:
char = str(char)
if state == "normal":
if char == "`":
@@ -350,7 +350,7 @@ def _code_block_mode_stream_processor_with_backtick(
piece.delta.message.content = ""
# Yield a piece with cleared content before processing it to maintain the generator structure
yield piece
- piece = content
+ content_piece = content
else:
# Yield pieces without content directly
yield piece
@@ -360,7 +360,7 @@ def _code_block_mode_stream_processor_with_backtick(
continue
new_piece: str = ""
- for char in piece:
+ for char in content_piece:
if state == "search_start":
if char == "`":
backtick_count += 1
@@ -535,7 +535,7 @@ def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRu
return []
- def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
+ def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
"""
Get model mode
diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py
index 4374093de4ab38..36e3e7bd557163 100644
--- a/api/core/model_runtime/model_providers/__base/model_provider.py
+++ b/api/core/model_runtime/model_providers/__base/model_provider.py
@@ -104,9 +104,10 @@ def get_model_instance(self, model_type: ModelType) -> AIModel:
mod = import_module_from_source(
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
)
+ # FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later
model_class = next(
filter(
- lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
+ lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore
get_subclasses_from_module(mod, AIModel),
),
None,
diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py
index 2d38fba955fb86..33135129082b1d 100644
--- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py
+++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py
@@ -89,7 +89,8 @@ def _get_context_size(self, model: str, credentials: dict) -> int:
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
- return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
+ content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
+ return content_size
return 1000
@@ -104,6 +105,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int:
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
- return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+ max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+ return max_chunks
return 1
diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
index 5fe6dda6ad5d79..6dab0aaf2d41e7 100644
--- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
+++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
@@ -2,9 +2,9 @@
from threading import Lock
from typing import Any
-from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
+from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
-_tokenizer = None
+_tokenizer: Any = None
_lock = Lock()
diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py
index b394ea4e9d22fe..6ce316b137abb4 100644
--- a/api/core/model_runtime/model_providers/__base/tts_model.py
+++ b/api/core/model_runtime/model_providers/__base/tts_model.py
@@ -127,7 +127,8 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str:
if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
raise ValueError("this model does not support audio type")
- return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
+ audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
+ return audio_type
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
"""
@@ -138,8 +139,9 @@ def _get_model_word_limit(self, model: str, credentials: dict) -> int:
if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
raise ValueError("this model does not support word limit")
+ world_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
- return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
+ return world_limit
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
"""
@@ -150,8 +152,9 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
raise ValueError("this model does not support max workers")
+ workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
- return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
+ return workers_limit
@staticmethod
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py
index a2b14cf3dbe6d4..4aa09e61fd3599 100644
--- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py
+++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py
@@ -64,10 +64,12 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) ->
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
+ if not ai_model_entity:
+ return None
return ai_model_entity.entity
@staticmethod
- def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+ def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
index 173b9d250c1743..6d50ba9163984f 100644
--- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
+++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
@@ -114,6 +114,8 @@ def _process_sentence(self, sentence: str, model: str, voice, credentials: dict)
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
+ if not ai_model_entity:
+ return None
return ai_model_entity.entity
@staticmethod
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index 75ed7ad62404cb..29bd673d576fc9 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -6,9 +6,9 @@
from typing import Optional, Union, cast
# 3rd import
-import boto3
-from botocore.config import Config
-from botocore.exceptions import (
+import boto3 # type: ignore
+from botocore.config import Config # type: ignore
+from botocore.exceptions import ( # type: ignore
ClientError,
EndpointConnectionError,
NoRegionError,
diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py
index aba8fedbc097e5..3a0a241f7ea0c0 100644
--- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py
@@ -44,7 +44,7 @@ def _invoke(
:return: rerank result
"""
if len(docs) == 0:
- return RerankResult(model=model, docs=docs)
+ return RerankResult(model=model, docs=[])
# initialize client
client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url"))
@@ -62,7 +62,7 @@ def _invoke(
# format document
rerank_document = RerankDocument(
index=result.index,
- text=result.document.text,
+ text=result.document.text if result.document else "",
score=result.relevance_score,
)
diff --git a/api/core/model_runtime/model_providers/fireworks/_common.py b/api/core/model_runtime/model_providers/fireworks/_common.py
index 378ced3a4019ba..38d0a9dfbcadee 100644
--- a/api/core/model_runtime/model_providers/fireworks/_common.py
+++ b/api/core/model_runtime/model_providers/fireworks/_common.py
@@ -1,5 +1,3 @@
-from collections.abc import Mapping
-
import openai
from core.model_runtime.errors.invoke import (
@@ -13,7 +11,7 @@
class _CommonFireworks:
- def _to_credential_kwargs(self, credentials: Mapping) -> dict:
+ def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance
diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py
index c745a7e978f4be..4c036283893fcc 100644
--- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py
@@ -1,5 +1,4 @@
import time
-from collections.abc import Mapping
from typing import Optional, Union
import numpy as np
@@ -93,7 +92,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
- def validate_credentials(self, model: str, credentials: Mapping) -> None:
+ def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
diff --git a/api/core/model_runtime/model_providers/gitee_ai/_common.py b/api/core/model_runtime/model_providers/gitee_ai/_common.py
index 0750f3b75d0542..ad6600faf7bc15 100644
--- a/api/core/model_runtime/model_providers/gitee_ai/_common.py
+++ b/api/core/model_runtime/model_providers/gitee_ai/_common.py
@@ -1,4 +1,4 @@
-from dashscope.common.error import (
+from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,
diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py
index 832ba927406c4c..737d3d5c931221 100644
--- a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
import httpx
@@ -51,7 +51,7 @@ def _invoke(
base_url = base_url.removesuffix("/")
try:
- body = {"model": model, "query": query, "documents": docs}
+ body: dict[str, Any] = {"model": model, "query": query, "documents": docs}
if top_n is not None:
body["top_n"] = top_n
response = httpx.post(
diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py
index b833c5652c650a..a1fa89c5b34af6 100644
--- a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py
@@ -24,7 +24,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
super().validate_credentials(model, credentials)
@staticmethod
- def _add_custom_parameters(credentials: dict, model: str) -> None:
+ def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
if model is None:
model = "bge-m3"
diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
index 36dcea405d0974..dc91257daf9d4e 100644
--- a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
+++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
import requests
@@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
Model class for OpenAI text2speech model.
"""
+ # FIXME: replace this Any return annotation with a more precise type
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
- ) -> any:
+ ) -> Any:
"""
_invoke text2speech model
@@ -47,7 +48,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
- def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
+ # FIXME: replace this Any return annotation with a more precise type
+ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 7d19ccbb74a011..98273f60a41190 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -7,7 +7,7 @@
from typing import Optional, Union
import google.ai.generativelanguage as glm
-import google.generativeai as genai
+import google.generativeai as genai # type: ignore
import requests
from google.api_core import exceptions
from google.generativeai.types import ContentType, File, GenerateContentResponse
diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py
index 3c4020b6eedf24..d8a09265e21059 100644
--- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py
+++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py
@@ -1,4 +1,4 @@
-from huggingface_hub.utils import BadRequestError, HfHubHTTPError
+from huggingface_hub.utils import BadRequestError, HfHubHTTPError # type: ignore
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
index 9d29237fdde573..cdb4103cd83712 100644
--- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
+++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Optional, Union
-from huggingface_hub import InferenceClient
-from huggingface_hub.hf_api import HfApi
-from huggingface_hub.utils import BadRequestError
+from huggingface_hub import InferenceClient # type: ignore
+from huggingface_hub.hf_api import HfApi # type: ignore
+from huggingface_hub.utils import BadRequestError # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py
index 8278d1e64def89..4ca5379405f4e6 100644
--- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py
@@ -4,7 +4,7 @@
import numpy as np
import requests
-from huggingface_hub import HfApi, InferenceClient
+from huggingface_hub import HfApi, InferenceClient # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py
index 2014de8516bc11..2dd45f065d5e26 100644
--- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py
+++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py
@@ -3,11 +3,11 @@
from collections.abc import Generator
from typing import cast
-from tencentcloud.common import credential
-from tencentcloud.common.exception import TencentCloudSDKException
-from tencentcloud.common.profile.client_profile import ClientProfile
-from tencentcloud.common.profile.http_profile import HttpProfile
-from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+from tencentcloud.common import credential # type: ignore
+from tencentcloud.common.exception import TencentCloudSDKException # type: ignore
+from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore
+from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore
+from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -305,7 +305,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
elif isinstance(message, ToolPromptMessage):
message_text = f"{tool_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
- message_text = content
+ message_text = content if isinstance(content, str) else ""
else:
raise ValueError(f"Got unknown type {message}")
diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py
index b6d857cb37cba0..856cda90d35a22 100644
--- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py
@@ -3,11 +3,11 @@
import time
from typing import Optional
-from tencentcloud.common import credential
-from tencentcloud.common.exception import TencentCloudSDKException
-from tencentcloud.common.profile.client_profile import ClientProfile
-from tencentcloud.common.profile.http_profile import HttpProfile
-from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
+from tencentcloud.common import credential # type: ignore
+from tencentcloud.common.exception import TencentCloudSDKException # type: ignore
+from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore
+from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore
+from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
index d80cbfa83d6425..1fc0f8c028ba92 100644
--- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
+++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
@@ -1,11 +1,11 @@
from os.path import abspath, dirname, join
from threading import Lock
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer # type: ignore
class JinaTokenizer:
- _tokenizer = None
+ _tokenizer: AutoTokenizer | None = None
_lock = Lock()
@classmethod
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
index 88cc0e8e0f32d0..357631b2dba0b9 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
@@ -40,7 +40,7 @@ def generate(
url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"
- extra_kwargs = {}
+ extra_kwargs: dict[str, Any] = {}
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
@@ -117,19 +117,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
"""
handle chat generate response
"""
- response = response.json()
- if "base_resp" in response and response["base_resp"]["status_code"] != 0:
- code = response["base_resp"]["status_code"]
- msg = response["base_resp"]["status_msg"]
+ response_data = response.json()
+ if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
+ code = response_data["base_resp"]["status_code"]
+ msg = response_data["base_resp"]["status_msg"]
self._handle_error(code, msg)
- message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
+ message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
"prompt_tokens": 0,
- "completion_tokens": response["usage"]["total_tokens"],
- "total_tokens": response["usage"]["total_tokens"],
+ "completion_tokens": response_data["usage"]["total_tokens"],
+ "total_tokens": response_data["usage"]["total_tokens"],
}
- message.stop_reason = response["choices"][0]["finish_reason"]
+ message.stop_reason = response_data["choices"][0]["finish_reason"]
return message
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
@@ -139,10 +139,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
for line in response.iter_lines():
if not line:
continue
- line: str = line.decode("utf-8")
- if line.startswith("data: "):
- line = line[6:].strip()
- data = loads(line)
+ line_str: str = line.decode("utf-8")
+ if line_str.startswith("data: "):
+ line_str = line_str[6:].strip()
+ data = loads(line_str)
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]
@@ -162,5 +162,5 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
continue
for choice in choices:
- message = choice["delta"]
- yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value)
+ message_choice = choice["delta"]
+ yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value)
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
index 8b8fdbb6bdf558..284b61829f9729 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
@@ -41,7 +41,7 @@ def generate(
url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"
- extra_kwargs = {}
+ extra_kwargs: dict[str, Any] = {}
if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int:
extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]
@@ -122,19 +122,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
"""
handle chat generate response
"""
- response = response.json()
- if "base_resp" in response and response["base_resp"]["status_code"] != 0:
- code = response["base_resp"]["status_code"]
- msg = response["base_resp"]["status_msg"]
+ response_data = response.json()
+ if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0:
+ code = response_data["base_resp"]["status_code"]
+ msg = response_data["base_resp"]["status_msg"]
self._handle_error(code, msg)
- message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
+ message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value)
message.usage = {
"prompt_tokens": 0,
- "completion_tokens": response["usage"]["total_tokens"],
- "total_tokens": response["usage"]["total_tokens"],
+ "completion_tokens": response_data["usage"]["total_tokens"],
+ "total_tokens": response_data["usage"]["total_tokens"],
}
- message.stop_reason = response["choices"][0]["finish_reason"]
+ message.stop_reason = response_data["choices"][0]["finish_reason"]
return message
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
@@ -144,10 +144,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
for line in response.iter_lines():
if not line:
continue
- line: str = line.decode("utf-8")
- if line.startswith("data: "):
- line = line[6:].strip()
- data = loads(line)
+ line_str: str = line.decode("utf-8")
+ if line_str.startswith("data: "):
+ line_str = line_str[6:].strip()
+ data = loads(line_str)
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
code = data["base_resp"]["status_code"]
diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py
index 88ebe5e2e00e7a..c248db374a2504 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/types.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/types.py
@@ -11,9 +11,9 @@ class Role(Enum):
role: str = Role.USER.value
content: str
- usage: dict[str, int] = None
+ usage: dict[str, int] | None = None
stop_reason: str = ""
- function_call: dict[str, Any] = None
+ function_call: dict[str, Any] | None = None
def to_dict(self) -> dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py
index 56a707333c40e9..8a4c19d4d8f71b 100644
--- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py
@@ -2,8 +2,8 @@
from functools import wraps
from typing import Optional
-from nomic import embed
-from nomic import login as nomic_login
+from nomic import embed # type: ignore
+from nomic import login as nomic_login # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py
index 1e1fc5b3ea89aa..9f676573fc2ece 100644
--- a/api/core/model_runtime/model_providers/oci/llm/llm.py
+++ b/api/core/model_runtime/model_providers/oci/llm/llm.py
@@ -5,8 +5,8 @@
from collections.abc import Generator
from typing import Optional, Union
-import oci
-from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse
+import oci # type: ignore
+from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse # type: ignore
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py
index 50fa63768c241b..5a428c9fed0466 100644
--- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py
@@ -4,7 +4,7 @@
from typing import Optional
import numpy as np
-import oci
+import oci # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
index 83c4facc8db76c..3543fe58bb68d2 100644
--- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
@@ -61,6 +61,7 @@ def _invoke(
headers = {"Content-Type": "application/json"}
endpoint_url = credentials.get("base_url")
+ assert endpoint_url is not None, "Base URL is required for Ollama API"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py
index 2181bb4f08fd8f..ac2b3e6881c740 100644
--- a/api/core/model_runtime/model_providers/openai/_common.py
+++ b/api/core/model_runtime/model_providers/openai/_common.py
@@ -1,5 +1,3 @@
-from collections.abc import Mapping
-
import openai
from httpx import Timeout
@@ -14,7 +12,7 @@
class _CommonOpenAI:
- def _to_credential_kwargs(self, credentials: Mapping) -> dict:
+ def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance
diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py
index 619044d808cdf6..227e4b0c152a05 100644
--- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py
+++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py
@@ -93,7 +93,8 @@ def _get_max_characters_per_chunk(self, model: str, credentials: dict) -> int:
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties:
- return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
+ max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
+ return max_characters_per_chunk
return 2000
@@ -108,6 +109,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int:
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
- return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+ max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+ return max_chunks
return 1
diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py
index aa6f38ce9fae5a..c546441af61d9b 100644
--- a/api/core/model_runtime/model_providers/openai/openai.py
+++ b/api/core/model_runtime/model_providers/openai/openai.py
@@ -1,5 +1,4 @@
import logging
-from collections.abc import Mapping
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -9,7 +8,7 @@
class OpenAIProvider(ModelProvider):
- def validate_provider_credentials(self, credentials: Mapping) -> None:
+ def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py
index a490537e51a6ad..74229a089aa45e 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py
@@ -33,6 +33,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get("endpoint_url")
+ assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
endpoint_url = urljoin(endpoint_url, "audio/transcriptions")
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
index 9da8f55d0a7ed9..b4d6c6c6ca9942 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
@@ -55,6 +55,7 @@ def _invoke(
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get("endpoint_url")
+ assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py
index 8239c625f7ada8..53e895b0ecb376 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py
@@ -44,6 +44,7 @@ def _invoke(
# Construct endpoint URL
endpoint_url = credentials.get("endpoint_url")
+ assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
endpoint_url = urljoin(endpoint_url, "audio/speech")
diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
index 2789a9250a1d35..e9509b544d9f4e 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
@@ -1,7 +1,7 @@
from collections.abc import Generator
from enum import Enum
from json import dumps, loads
-from typing import Any, Union
+from typing import Any, Optional, Union
from requests import Response, post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@@ -20,7 +20,7 @@ class Role(Enum):
role: str = Role.USER.value
content: str
- usage: dict[str, int] = None
+ usage: Optional[dict[str, int]] = None
stop_reason: str = ""
def to_dict(self) -> dict[str, Any]:
@@ -165,17 +165,17 @@ def _handle_chat_stream_generate_response(
if not line:
continue
- line: str = line.decode("utf-8")
- if line.startswith("data: "):
- line = line[6:].strip()
+ line_str: str = line.decode("utf-8")
+ if line_str.startswith("data: "):
+ line_str = line_str[6:].strip()
- if line == "[DONE]":
+ if line_str == "[DONE]":
return
try:
- data = loads(line)
+ data = loads(line_str)
except Exception as e:
- raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}")
+ raise InternalServerError(f"Failed to convert response to json: {e} with text: {line_str}")
output = data["outputs"]
diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py
index 7bbd31e87c595d..40ea4dc0118026 100644
--- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py
@@ -53,14 +53,16 @@ def _invoke(
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
-
+ endpoint_url: Optional[str]
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
endpoint_url = "https://cloud.perfxlab.cn/v1/"
else:
endpoint_url = credentials.get("endpoint_url")
+ assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
+ assert isinstance(endpoint_url, str)
endpoint_url = urljoin(endpoint_url, "embeddings")
extra_model_kwargs = {}
@@ -142,13 +144,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
+ endpoint_url: Optional[str]
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
endpoint_url = "https://cloud.perfxlab.cn/v1/"
else:
endpoint_url = credentials.get("endpoint_url")
+ assert endpoint_url is not None, "endpoint_url is required in credentials"
if not endpoint_url.endswith("/"):
endpoint_url += "/"
+ assert isinstance(endpoint_url, str)
endpoint_url = urljoin(endpoint_url, "embeddings")
payload = {"input": "ping", "model": model}
diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py
index 915f6e0eefcd08..3e2cf2adb306db 100644
--- a/api/core/model_runtime/model_providers/replicate/_common.py
+++ b/api/core/model_runtime/model_providers/replicate/_common.py
@@ -1,4 +1,4 @@
-from replicate.exceptions import ModelError, ReplicateError
+from replicate.exceptions import ModelError, ReplicateError # type: ignore
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py
index 3641b35dc02a39..1e7858100b0429 100644
--- a/api/core/model_runtime/model_providers/replicate/llm/llm.py
+++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py
@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Optional, Union
-from replicate import Client as ReplicateClient
-from replicate.exceptions import ReplicateError
-from replicate.prediction import Prediction
+from replicate import Client as ReplicateClient # type: ignore
+from replicate.exceptions import ReplicateError # type: ignore
+from replicate.prediction import Prediction # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
index 41759fe07d0cac..aaf825388a9043 100644
--- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
@@ -2,11 +2,11 @@
import time
from typing import Optional
-from replicate import Client as ReplicateClient
+from replicate import Client as ReplicateClient # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -86,7 +86,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
- model_properties={"context_size": 4096, "max_chunks": 1},
+ model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096, ModelPropertyKey.MAX_CHUNKS: 1},
)
return entity
diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
index 5ff00f008eb621..b8c979b1f53ce9 100644
--- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
+++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
@@ -4,7 +4,7 @@
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast
-import boto3
+import boto3 # type: ignore
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -83,7 +83,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
sagemaker_session: Any = None
predictor: Any = None
- sagemaker_endpoint: str = None
+ sagemaker_endpoint: str | None = None
def _handle_chat_generate_response(
self,
@@ -209,8 +209,8 @@ def _invoke(
:param user: unique user id
:return: full response or stream response chunk generator result
"""
- from sagemaker import Predictor, serializers
- from sagemaker.session import Session
+ from sagemaker import Predictor, serializers # type: ignore
+ from sagemaker.session import Session # type: ignore
if not self.sagemaker_session:
access_key = credentials.get("aws_access_key_id")
diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py
index df797bae265825..7daab6d8653d33 100644
--- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py
@@ -3,7 +3,7 @@
import operator
from typing import Any, Optional
-import boto3
+import boto3 # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -114,6 +114,7 @@ def _invoke(
except Exception as e:
logger.exception(f"Failed to invoke rerank model, model: {model}")
+ raise InvokeError(f"Failed to invoke rerank model, model: {model}, error: {str(e)}")
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py
index 2d50e9c7b4c28a..a6aca130456063 100644
--- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py
+++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py
@@ -2,7 +2,7 @@
import logging
from typing import IO, Any, Optional
-import boto3
+import boto3 # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -67,6 +67,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional
s3_prefix = "dify/speech2text/"
sagemaker_endpoint = credentials.get("sagemaker_endpoint")
bucket = credentials.get("audio_s3_cache_bucket")
+ assert bucket is not None, "audio_s3_cache_bucket is required in credentials"
s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
payload = {"audio_s3_presign_uri": s3_presign_url}
diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
index ef4ddcd6a72847..e7eccd997d11c1 100644
--- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
@@ -4,7 +4,7 @@
import time
from typing import Any, Optional
-import boto3
+import boto3 # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@@ -118,6 +118,7 @@ def _invoke(
except Exception as e:
logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}")
+ raise InvokeError(str(e))
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py
index 6a5946453be07f..62231c518deef1 100644
--- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py
+++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py
@@ -5,7 +5,7 @@
from enum import Enum
from typing import Any, Optional
-import boto3
+import boto3 # type: ignore
import requests
from core.model_runtime.entities.common_entities import I18nObject
diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py
index e3a323a4965bc7..f61e8b82e4db99 100644
--- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py
+++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py
@@ -43,7 +43,7 @@ def _add_custom_parameters(cls, credentials: dict) -> None:
credentials["mode"] = "chat"
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"
- def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py
index 1181ba699af886..cb6f28b6c27fa9 100644
--- a/api/core/model_runtime/model_providers/spark/llm/llm.py
+++ b/api/core/model_runtime/model_providers/spark/llm/llm.py
@@ -1,6 +1,6 @@
import threading
from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -270,7 +270,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
- message_text = content
+ message_text = cast(str, content)
else:
raise ValueError(f"Got unknown type {message}")
diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
index b96d43979ef54a..03eac194235e83 100644
--- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
@@ -12,6 +12,7 @@
AIModelEntity,
DefaultParameterName,
FetchFrom,
+ ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
@@ -67,7 +68,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
REPETITION_PENALTY = "repetition_penalty"
TOP_K = "top_k"
- features = []
+ features: list[ModelFeature] = []
entity = AIModelEntity(
model=model,
diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py
index 8a50c7aa05f38c..bb68319555007f 100644
--- a/api/core/model_runtime/model_providers/tongyi/_common.py
+++ b/api/core/model_runtime/model_providers/tongyi/_common.py
@@ -1,4 +1,4 @@
-from dashscope.common.error import (
+from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
index 0c1f6518811aa8..61ebd45ed64a6d 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py
+++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
@@ -7,9 +7,9 @@
from pathlib import Path
from typing import Optional, Union, cast
-from dashscope import Generation, MultiModalConversation, get_tokenizer
-from dashscope.api_entities.dashscope_response import GenerationResponse
-from dashscope.common.error import (
+from dashscope import Generation, MultiModalConversation, get_tokenizer # type: ignore
+from dashscope.api_entities.dashscope_response import GenerationResponse # type: ignore
+from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,
diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py
index a5ce9ead6ee3be..ed682cb0f3c1e4 100644
--- a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py
@@ -1,7 +1,7 @@
from typing import Optional
-import dashscope
-from dashscope.common.error import (
+import dashscope # type: ignore
+from dashscope.common.error import ( # type: ignore
AuthenticationError,
InvalidParameter,
RequestFailure,
@@ -51,7 +51,7 @@ def _invoke(
:return: rerank result
"""
if len(docs) == 0:
- return RerankResult(model=model, docs=docs)
+ return RerankResult(model=model, docs=[])
# initialize client
dashscope.api_key = credentials["dashscope_api_key"]
@@ -64,7 +64,7 @@ def _invoke(
return_documents=True,
)
- rerank_documents = []
+ rerank_documents: list[RerankDocument] = []
if not response.output:
return RerankResult(model=model, docs=rerank_documents)
for _, result in enumerate(response.output.results):
diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py
index 2ef7f3f5774481..8c53be413002a9 100644
--- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py
@@ -1,7 +1,7 @@
import time
from typing import Optional
-import dashscope
+import dashscope # type: ignore
import numpy as np
from core.entities.embedding_type import EmbeddingInputType
diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py
index ca3b9fbc1c3c00..a654e2d760d7c4 100644
--- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py
+++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py
@@ -2,10 +2,10 @@
from queue import Queue
from typing import Any, Optional
-import dashscope
-from dashscope import SpeechSynthesizer
-from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
-from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult
+import dashscope # type: ignore
+from dashscope import SpeechSynthesizer # type: ignore
+from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse # type: ignore
+from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult # type: ignore
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py
index 47ebaccd84ab8a..f6609bba77129b 100644
--- a/api/core/model_runtime/model_providers/upstage/_common.py
+++ b/api/core/model_runtime/model_providers/upstage/_common.py
@@ -1,5 +1,3 @@
-from collections.abc import Mapping
-
import openai
from httpx import Timeout
@@ -14,7 +12,7 @@
class _CommonUpstage:
- def _to_credential_kwargs(self, credentials: Mapping) -> dict:
+ def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance
diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py
index a18ee906248a49..2bf6796ca5cf45 100644
--- a/api/core/model_runtime/model_providers/upstage/llm/llm.py
+++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py
@@ -6,7 +6,7 @@
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer # type: ignore
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py
index 5b340e53bbc543..87693eca768dfd 100644
--- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py
@@ -1,11 +1,10 @@
import base64
import time
-from collections.abc import Mapping
from typing import Union
import numpy as np
from openai import OpenAI
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
@@ -132,7 +131,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int
return total_num_tokens
- def validate_credentials(self, model: str, credentials: Mapping) -> None:
+ def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py
index 8f7c859e3803c0..4e3df7574e9ce8 100644
--- a/api/core/model_runtime/model_providers/vertex_ai/_common.py
+++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py
@@ -12,4 +12,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
:return: Invoke error mapping
"""
- pass
+ raise NotImplementedError
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
index c50e0f794616b3..85be34f3f0fe7f 100644
--- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -6,7 +6,7 @@
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union, cast
-import google.auth.transport.requests
+import google.auth.transport.requests # type: ignore
import requests
from anthropic import AnthropicVertex, Stream
from anthropic.types import (
diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py
index 034c066ab5f071..782e4fd6232a3b 100644
--- a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py
@@ -17,14 +17,12 @@
class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
- features = []
-
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
- features=features,
+ features=[],
model_properties={
ModelPropertyKey.MODE: credentials.get("mode"),
},
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py
index 1cffd902c7a25d..a8a015167e3227 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/client.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py
@@ -1,8 +1,8 @@
from collections.abc import Generator
from typing import Optional, cast
-from volcenginesdkarkruntime import Ark
-from volcenginesdkarkruntime.types.chat import (
+from volcenginesdkarkruntime import Ark # type: ignore
+from volcenginesdkarkruntime.types.chat import ( # type: ignore
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
@@ -15,10 +15,10 @@
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
-from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
-from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
-from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
-from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
+from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL # type: ignore
+from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function # type: ignore
+from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse # type: ignore
+from volcenginesdkarkruntime.types.shared_params import FunctionDefinition # type: ignore
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py
index 91dbe21a616195..aa837b8318873d 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py
@@ -152,5 +152,6 @@ class ServiceNotOpenError(MaasError):
def wrap_error(e: MaasError) -> Exception:
if ErrorCodeMap.get(e.code):
- return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
+ # FIXME: mypy type error, try to fix it instead of using type: ignore
+ return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) # type: ignore
return e
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
index 9e19b7dedaa5a7..f0b2b101b7be9d 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py
@@ -2,7 +2,7 @@
from collections.abc import Generator
from typing import Optional
-from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
+from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
index cf3cf23cfb9cef..7c37368086e0e6 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py
@@ -1,3 +1,5 @@
+from typing import Any
+
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode
@@ -102,7 +104,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None):
- req_params = {}
+ req_params: dict[str, Any] = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:
@@ -130,7 +132,7 @@ def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str]
def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None):
- req_params = {}
+ req_params: dict[str, Any] = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
index 07b970f8104c8f..d2899795696aa4 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
@@ -1,7 +1,7 @@
from collections.abc import Generator
from enum import Enum
from json import dumps, loads
-from typing import Any, Union
+from typing import Any, Optional, Union
from requests import Response, post
@@ -22,7 +22,7 @@ class Role(Enum):
role: str = Role.USER.value
content: str
- usage: dict[str, int] = None
+ usage: Optional[dict[str, int]] = None
stop_reason: str = ""
def to_dict(self) -> dict[str, Any]:
@@ -135,6 +135,7 @@ def _build_function_calling_request_body(
"""
TODO: implement function calling
"""
+ raise NotImplementedError("Function calling is not supported yet.")
def _build_chat_request_body(
self,
diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py
index 19135deb27380d..816b3b98c4b8c5 100644
--- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py
@@ -1,6 +1,5 @@
import time
from abc import abstractmethod
-from collections.abc import Mapping
from json import dumps
from typing import Any, Optional
@@ -23,12 +22,12 @@
class TextEmbedding:
@abstractmethod
- def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
+ def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
raise NotImplementedError
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
- def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
+ def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
access_token = self._get_access_token()
url = f"{self.api_bases[model]}?access_token={access_token}"
body = self._build_embed_request_body(model, texts, user)
@@ -50,7 +49,7 @@ def _build_embed_request_body(self, model: str, texts: list[str], user: str) ->
}
return body
- def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
+ def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
data = response.json()
if "error_code" in data:
code = data["error_code"]
@@ -147,7 +146,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int
return total_num_tokens
- def validate_credentials(self, model: str, credentials: Mapping) -> None:
+ def validate_credentials(self, model: str, credentials: dict) -> None:
api_key = credentials["api_key"]
secret_key = credentials["secret_key"]
try:
diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py
index 8d86d6937d8ac9..7db1203641cad2 100644
--- a/api/core/model_runtime/model_providers/xinference/llm/llm.py
+++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py
@@ -17,7 +17,7 @@
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion
-from xinference_client.client.restful.restful_client import (
+from xinference_client.client.restful.restful_client import ( # type: ignore
Client,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
index efaf114854b5c1..078ec0537a37f4 100644
--- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py
@@ -1,6 +1,6 @@
from typing import Optional
-from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle
+from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py
index 3d7aefeb6dd89a..5f330ece1a5750 100644
--- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py
+++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py
@@ -1,6 +1,6 @@
from typing import IO, Optional
-from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
+from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
index e51e6a941c5413..9054aabab2dd05 100644
--- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
@@ -1,7 +1,7 @@
import time
from typing import Optional
-from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
+from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@@ -134,7 +134,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
- raise InvokeAuthorizationError(e)
+ raise InvokeAuthorizationError(str(e))
if not isinstance(handle, RESTfulEmbeddingModelHandle):
raise InvokeBadRequestError(
diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py
index ad7b64efb5d2e7..8aa39d4de0d2cb 100644
--- a/api/core/model_runtime/model_providers/xinference/tts/tts.py
+++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py
@@ -1,7 +1,7 @@
import concurrent.futures
from typing import Any, Optional
-from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle
+from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle # type: ignore
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
@@ -74,11 +74,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
credentials["server_url"] = credentials["server_url"].removesuffix("/")
+ api_key = credentials.get("api_key")
+ if api_key is None:
+ raise CredentialsValidateFailedError("api_key is required")
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials["server_url"],
model_uid=credentials["model_uid"],
- api_key=credentials.get("api_key"),
+ api_key=api_key,
)
if "text-to-audio" not in extra_param.model_ability:
diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
index baa3ccbe8adbc0..b51423f4eda2e6 100644
--- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py
+++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
@@ -1,6 +1,6 @@
from threading import Lock
from time import time
-from typing import Optional
+from typing import Any, Optional
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
@@ -39,13 +39,15 @@ def __init__(
self.model_family = model_family
-cache = {}
+cache: dict[str, dict[str, Any]] = {}
cache_lock = Lock()
class XinferenceHelper:
@staticmethod
- def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
+ def get_xinference_extra_parameter(
+ server_url: str, model_uid: str, api_key: str | None
+ ) -> XinferenceModelExtraParameter:
XinferenceHelper._clean_cache()
with cache_lock:
if model_uid not in cache:
@@ -66,7 +68,9 @@ def _clean_cache() -> None:
pass
@staticmethod
- def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
+ def _get_xinference_extra_parameter(
+ server_url: str, model_uid: str, api_key: str | None
+ ) -> XinferenceModelExtraParameter:
"""
get xinference model extra parameter like model_format and model_handle_type
"""
diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py
index 0642e72ed500e1..f5b61e207635bc 100644
--- a/api/core/model_runtime/model_providers/yi/llm/llm.py
+++ b/api/core/model_runtime/model_providers/yi/llm/llm.py
@@ -136,7 +136,7 @@ def _add_custom_parameters(credentials: dict) -> None:
parsed_url = urlparse(credentials["endpoint_url"])
credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"
- def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
index 59861507e45cd6..eef86cc52c36e8 100644
--- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -1,9 +1,9 @@
from collections.abc import Generator
from typing import Optional, Union
-from zhipuai import ZhipuAI
-from zhipuai.types.chat.chat_completion import Completion
-from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from zhipuai import ZhipuAI # type: ignore
+from zhipuai.types.chat.chat_completion import Completion # type: ignore
+from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk # type: ignore
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
index 2428284ba9a8ff..a700304db7b6f3 100644
--- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
@@ -1,7 +1,7 @@
import time
from typing import Optional
-from zhipuai import ZhipuAI
+from zhipuai import ZhipuAI # type: ignore
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py
index 029ec1a581b2e9..8cc8adfc3656ea 100644
--- a/api/core/model_runtime/schema_validators/common_validator.py
+++ b/api/core/model_runtime/schema_validators/common_validator.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Union, cast
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType
@@ -38,7 +38,7 @@ def _validate_and_filter_credential_form_schemas(
def _validate_credential_form_schema(
self, credential_form_schema: CredentialFormSchema, credentials: dict
- ) -> Optional[str]:
+ ) -> Union[str, bool, None]:
"""
Validate credential form schema
@@ -47,6 +47,7 @@ def _validate_credential_form_schema(
:return: validated credential form schema value
"""
# If the variable does not exist in credentials
+ value: Union[str, bool, None] = None
if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
# If required is True, an exception is thrown
if credential_form_schema.required:
@@ -61,7 +62,7 @@ def _validate_credential_form_schema(
return None
# Get the value corresponding to the variable from credentials
- value = credentials[credential_form_schema.variable]
+ value = cast(str, credentials[credential_form_schema.variable])
# If max_length=0, no validation is performed
if credential_form_schema.max_length:
diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py
index ec1bad5698f2eb..03e350627140cf 100644
--- a/api/core/model_runtime/utils/encoders.py
+++ b/api/core/model_runtime/utils/encoders.py
@@ -129,7 +129,8 @@ def jsonable_encoder(
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
- obj_dict = dataclasses.asdict(obj)
+ # FIXME: mypy error, try to fix it instead of using type: ignore
+ obj_dict = dataclasses.asdict(obj) # type: ignore
return jsonable_encoder(
obj_dict,
by_alias=by_alias,
diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py
index 2067092d80f582..5e8a723ec7c510 100644
--- a/api/core/model_runtime/utils/helper.py
+++ b/api/core/model_runtime/utils/helper.py
@@ -4,6 +4,7 @@
def dump_model(model: BaseModel) -> dict:
if hasattr(pydantic, "model_dump"):
- return pydantic.model_dump(model)
+ # FIXME mypy error, try to fix it instead of using type: ignore
+ return pydantic.model_dump(model) # type: ignore
else:
return model.model_dump()
diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py
index 094ad7863603dc..c65a3885fd1eb9 100644
--- a/api/core/moderation/api/api.py
+++ b/api/core/moderation/api/api.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from pydantic import BaseModel
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
@@ -43,6 +45,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None:
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
+ if self.config is None:
+ raise ValueError("The config is not set.")
if self.config["inputs_config"]["enabled"]:
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
@@ -57,6 +61,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
+ if self.config is None:
+ raise ValueError("The config is not set.")
if self.config["outputs_config"]["enabled"]:
params = ModerationOutputParams(app_id=self.app_id, text=text)
@@ -69,14 +75,18 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
- extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
+ if self.config is None:
+ raise ValueError("The config is not set.")
+ extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
+ if not extension:
+ raise ValueError("API-based Extension not found. Please check it again.")
requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
result = requestor.request(extension_point, params)
return result
@staticmethod
- def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
+ def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py
index 60898d5547ae3b..d8c392d0970e19 100644
--- a/api/core/moderation/base.py
+++ b/api/core/moderation/base.py
@@ -100,14 +100,14 @@ def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_re
if not inputs_config.get("preset_response"):
raise ValueError("inputs_config.preset_response is required")
- if len(inputs_config.get("preset_response")) > 100:
+ if len(inputs_config.get("preset_response", 0)) > 100:
raise ValueError("inputs_config.preset_response must be less than 100 characters")
if outputs_config_enabled:
if not outputs_config.get("preset_response"):
raise ValueError("outputs_config.preset_response is required")
- if len(outputs_config.get("preset_response")) > 100:
+ if len(outputs_config.get("preset_response", 0)) > 100:
raise ValueError("outputs_config.preset_response must be less than 100 characters")
diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py
index 96bf2ab54b41eb..0ad4438c143870 100644
--- a/api/core/moderation/factory.py
+++ b/api/core/moderation/factory.py
@@ -22,7 +22,8 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
"""
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
- extension_class.validate_config(tenant_id, config)
+ # FIXME: mypy error, try to fix it instead of using type: ignore
+ extension_class.validate_config(tenant_id, config) # type: ignore
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py
index 46d3963bd07f5a..3ac33966cb14bf 100644
--- a/api/core/moderation/input_moderation.py
+++ b/api/core/moderation/input_moderation.py
@@ -1,5 +1,6 @@
import logging
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional
from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationError
@@ -17,11 +18,11 @@ def check(
app_id: str,
tenant_id: str,
app_config: AppConfig,
- inputs: dict,
+ inputs: Mapping[str, Any],
query: str,
message_id: str,
trace_manager: Optional[TraceQueueManager] = None,
- ) -> tuple[bool, dict, str]:
+ ) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
@@ -33,6 +34,7 @@ def check(
:param trace_manager: trace manager
:return:
"""
+ inputs = dict(inputs)
if not app_config.sensitive_word_avoidance:
return False, inputs, query
diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py
index 00b3c56c03602d..9dd2665c3bf3d3 100644
--- a/api/core/moderation/keywords/keywords.py
+++ b/api/core/moderation/keywords/keywords.py
@@ -21,7 +21,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None:
if not config.get("keywords"):
raise ValueError("keywords is required")
- if len(config.get("keywords")) > 10000:
+ if len(config.get("keywords", [])) > 10000:
raise ValueError("keywords length must be less than 10000")
keywords_row_len = config["keywords"].split("\n")
@@ -31,6 +31,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None:
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
+ if self.config is None:
+ raise ValueError("The config is not set.")
if self.config["inputs_config"]["enabled"]:
preset_response = self.config["inputs_config"]["preset_response"]
@@ -50,6 +52,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
+ if self.config is None:
+ raise ValueError("The config is not set.")
if self.config["outputs_config"]["enabled"]:
# Filter out empty values
diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py
index 6465de23b9a2de..d64f17b383e0b5 100644
--- a/api/core/moderation/openai_moderation/openai_moderation.py
+++ b/api/core/moderation/openai_moderation/openai_moderation.py
@@ -20,6 +20,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None:
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
+ if self.config is None:
+ raise ValueError("The config is not set.")
if self.config["inputs_config"]["enabled"]:
preset_response = self.config["inputs_config"]["preset_response"]
@@ -35,6 +37,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
+ if self.config is None:
+ raise ValueError("The config is not set.")
if self.config["outputs_config"]["enabled"]:
flagged = self._is_violated({"text": text})
diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py
index 4635bd9c251851..e595be126c7824 100644
--- a/api/core/moderation/output_moderation.py
+++ b/api/core/moderation/output_moderation.py
@@ -70,7 +70,7 @@ def start_thread(self) -> threading.Thread:
thread = threading.Thread(
target=self.worker,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
},
)
diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py
index 71ff03b6ef5160..f0e34c0cd71241 100644
--- a/api/core/ops/entities/trace_entity.py
+++ b/api/core/ops/entities/trace_entity.py
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union
@@ -38,8 +39,8 @@ class WorkflowTraceInfo(BaseTraceInfo):
workflow_run_id: str
workflow_run_elapsed_time: Union[int, float]
workflow_run_status: str
- workflow_run_inputs: dict[str, Any]
- workflow_run_outputs: dict[str, Any]
+ workflow_run_inputs: Mapping[str, Any]
+ workflow_run_outputs: Mapping[str, Any]
workflow_run_version: str
error: Optional[str] = None
total_tokens: int
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index 29fdebd8feaeb8..b9ba068b19936d 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -77,8 +77,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
id=trace_id,
user_id=user_id,
name=name,
- input=trace_info.workflow_run_inputs,
- output=trace_info.workflow_run_outputs,
+ input=dict(trace_info.workflow_run_inputs),
+ output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["message", "workflow"],
@@ -87,8 +87,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
workflow_span_data = LangfuseSpan(
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
- input=trace_info.workflow_run_inputs,
- output=trace_info.workflow_run_outputs,
+ input=dict(trace_info.workflow_run_inputs),
+ output=dict(trace_info.workflow_run_outputs),
trace_id=trace_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
@@ -102,8 +102,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
id=trace_id,
user_id=user_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
- input=trace_info.workflow_run_inputs,
- output=trace_info.workflow_run_outputs,
+ input=dict(trace_info.workflow_run_inputs),
+ output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["workflow"],
diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
index 99221d669b3193..348b7ba5012b6b 100644
--- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
+++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
@@ -49,7 +49,6 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
- dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
@field_validator("inputs", "outputs")
@classmethod
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index 672843e5a8f986..4ffd888bddf8a3 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -3,6 +3,7 @@
import os
import uuid
from datetime import datetime, timedelta
+from typing import Optional, cast
from langsmith import Client
from langsmith.schemas import RunBase
@@ -63,6 +64,8 @@ def trace(self, trace_info: BaseTraceInfo):
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_run_id
+ if trace_info.start_time is None:
+ trace_info.start_time = datetime.now()
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
@@ -78,8 +81,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
message_run = LangSmithRunModel(
id=trace_info.message_id,
name=TraceTaskName.MESSAGE_TRACE.value,
- inputs=trace_info.workflow_run_inputs,
- outputs=trace_info.workflow_run_outputs,
+ inputs=dict(trace_info.workflow_run_inputs),
+ outputs=dict(trace_info.workflow_run_outputs),
run_type=LangSmithRunType.chain,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
@@ -90,6 +93,15 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
error=trace_info.error,
trace_id=trace_id,
dotted_order=message_dotted_order,
+ file_list=[],
+ serialized=None,
+ parent_run_id=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
)
self.add_run(message_run)
@@ -98,11 +110,11 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
- inputs=trace_info.workflow_run_inputs,
+ inputs=dict(trace_info.workflow_run_inputs),
run_type=LangSmithRunType.tool,
start_time=trace_info.workflow_data.created_at,
end_time=trace_info.workflow_data.finished_at,
- outputs=trace_info.workflow_run_outputs,
+ outputs=dict(trace_info.workflow_run_outputs),
extra={
"metadata": metadata,
},
@@ -111,6 +123,13 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
parent_run_id=trace_info.message_id or None,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
)
self.add_run(langsmith_run)
@@ -211,25 +230,35 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
id=node_execution_id,
trace_id=trace_id,
dotted_order=node_dotted_order,
+ error="",
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
)
self.add_run(langsmith_run)
def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
- file_list = trace_info.file_list
- message_file_data: MessageFile = trace_info.message_file_data
+ file_list = cast(list[str], trace_info.file_list) or []
+ message_file_data: Optional[MessageFile] = trace_info.message_file_data
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
metadata = trace_info.metadata
message_data = trace_info.message_data
+ if message_data is None:
+ return
message_id = message_data.id
user_id = message_data.from_account_id
metadata["user_id"] = user_id
if message_data.from_end_user_id:
- end_user_data: EndUser = (
+ end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
@@ -247,12 +276,20 @@ def message_trace(self, trace_info: MessageTraceInfo):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
outputs=message_data.answer,
- extra={
- "metadata": metadata,
- },
+ extra={"metadata": metadata},
tags=["message", str(trace_info.conversation_mode)],
error=trace_info.error,
file_list=file_list,
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
+ trace_id=None,
+ dotted_order=None,
+ parent_run_id=None,
)
self.add_run(message_run)
@@ -267,17 +304,27 @@ def message_trace(self, trace_info: MessageTraceInfo):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
outputs=message_data.answer,
- extra={
- "metadata": metadata,
- },
+ extra={"metadata": metadata},
parent_run_id=message_id,
tags=["llm", str(trace_info.conversation_mode)],
error=trace_info.error,
file_list=file_list,
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
+ trace_id=None,
+ dotted_order=None,
+ id=str(uuid.uuid4()),
)
self.add_run(llm_run)
def moderation_trace(self, trace_info: ModerationTraceInfo):
+ if trace_info.message_data is None:
+ return
langsmith_run = LangSmithRunModel(
name=TraceTaskName.MODERATION_TRACE.value,
inputs=trace_info.inputs,
@@ -288,48 +335,82 @@ def moderation_trace(self, trace_info: ModerationTraceInfo):
"inputs": trace_info.inputs,
},
run_type=LangSmithRunType.tool,
- extra={
- "metadata": trace_info.metadata,
- },
+ extra={"metadata": trace_info.metadata},
tags=["moderation"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or trace_info.message_data.created_at,
end_time=trace_info.end_time or trace_info.message_data.updated_at,
+ id=str(uuid.uuid4()),
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
+ trace_id=None,
+ dotted_order=None,
+ error="",
+ file_list=[],
)
self.add_run(langsmith_run)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
+ if message_data is None:
+ return
suggested_question_run = LangSmithRunModel(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
run_type=LangSmithRunType.tool,
- extra={
- "metadata": trace_info.metadata,
- },
+ extra={"metadata": trace_info.metadata},
tags=["suggested_question"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or message_data.created_at,
end_time=trace_info.end_time or message_data.updated_at,
+ id=str(uuid.uuid4()),
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
+ trace_id=None,
+ dotted_order=None,
+ error="",
+ file_list=[],
)
self.add_run(suggested_question_run)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
+ if trace_info.message_data is None:
+ return
dataset_retrieval_run = LangSmithRunModel(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
run_type=LangSmithRunType.retriever,
- extra={
- "metadata": trace_info.metadata,
- },
+ extra={"metadata": trace_info.metadata},
tags=["dataset_retrieval"],
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time or trace_info.message_data.created_at,
end_time=trace_info.end_time or trace_info.message_data.updated_at,
+ id=str(uuid.uuid4()),
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
+ trace_id=None,
+ dotted_order=None,
+ error="",
+ file_list=[],
)
self.add_run(dataset_retrieval_run)
@@ -347,7 +428,18 @@ def tool_trace(self, trace_info: ToolTraceInfo):
parent_run_id=trace_info.message_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
- file_list=[trace_info.file_url],
+ file_list=[cast(str, trace_info.file_url)],
+ id=str(uuid.uuid4()),
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
+ trace_id=None,
+ dotted_order=None,
+ error=trace_info.error or "",
)
self.add_run(tool_run)
@@ -358,12 +450,23 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
inputs=trace_info.inputs,
outputs=trace_info.outputs,
run_type=LangSmithRunType.tool,
- extra={
- "metadata": trace_info.metadata,
- },
+ extra={"metadata": trace_info.metadata},
tags=["generate_name"],
start_time=trace_info.start_time or datetime.now(),
end_time=trace_info.end_time or datetime.now(),
+ id=str(uuid.uuid4()),
+ serialized=None,
+ events=[],
+ session_id=None,
+ session_name=None,
+ reference_example_id=None,
+ input_attachments={},
+ output_attachments={},
+ trace_id=None,
+ dotted_order=None,
+ error="",
+ file_list=[],
+ parent_run_id=None,
)
self.add_run(name_run)
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index 4f41b6ed97047f..f538eaef5bd570 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -33,11 +33,11 @@
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
-from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
+from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
-provider_config_map = {
+provider_config_map: dict[str, dict[str, Any]] = {
TracingProviderEnum.LANGFUSE.value: {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
@@ -145,7 +145,7 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
:param tracing_provider: tracing provider
:return:
"""
- trace_config_data: TraceAppConfig = (
+ trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@@ -155,7 +155,11 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
return None
# decrypt_token
- tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
+ app = db.session.query(App).filter(App.id == app_id).first()
+ if not app:
+ raise ValueError("App not found")
+
+ tenant_id = app.tenant_id
decrypt_tracing_config = cls.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)
@@ -178,7 +182,7 @@ def get_ops_trace_instance(
if app_id is None:
return None
- app: App = db.session.query(App).filter(App.id == app_id).first()
+ app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
if app is None:
return None
@@ -209,8 +213,12 @@ def get_ops_trace_instance(
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).filter(Message.id == message_id).first()
+ if not message_data:
+ return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
+ if not conversation_data:
+ return None
if conversation_data.app_model_config_id:
app_model_config = (
@@ -236,7 +244,9 @@ def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider:
if tracing_provider not in provider_config_map and tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
- app_config: App = db.session.query(App).filter(App.id == app_id).first()
+ app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+ if not app_config:
+ raise ValueError("App not found")
app_config.tracing = json.dumps(
{
"enabled": enabled,
@@ -252,7 +262,9 @@ def get_app_tracing_config(cls, app_id: str):
:param app_id: app id
:return:
"""
- app: App = db.session.query(App).filter(App.id == app_id).first()
+ app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+ if not app:
+ raise ValueError("App not found")
if not app.tracing:
return {"enabled": False, "tracing_provider": None}
app_trace_config = json.loads(app.tracing)
@@ -483,6 +495,8 @@ def message_trace(self, message_id):
def moderation_trace(self, message_id, timer, **kwargs):
moderation_result = kwargs.get("moderation_result")
+ if not moderation_result:
+ return {}
inputs = kwargs.get("inputs")
message_data = get_message_data(message_id)
if not message_data:
@@ -518,7 +532,7 @@ def moderation_trace(self, message_id, timer, **kwargs):
return moderation_trace_info
def suggested_question_trace(self, message_id, timer, **kwargs):
- suggested_question = kwargs.get("suggested_question")
+ suggested_question = kwargs.get("suggested_question", [])
message_data = get_message_data(message_id)
if not message_data:
return {}
@@ -586,7 +600,7 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs):
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
message_id=message_id,
inputs=message_data.query or message_data.inputs,
- documents=[doc.model_dump() for doc in documents],
+ documents=[doc.model_dump() for doc in documents] if documents else [],
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
@@ -596,9 +610,9 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs):
return dataset_retrieval_trace_info
def tool_trace(self, message_id, timer, **kwargs):
- tool_name = kwargs.get("tool_name")
- tool_inputs = kwargs.get("tool_inputs")
- tool_outputs = kwargs.get("tool_outputs")
+ tool_name = kwargs.get("tool_name", "")
+ tool_inputs = kwargs.get("tool_inputs", {})
+ tool_outputs = kwargs.get("tool_outputs", {})
message_data = get_message_data(message_id)
if not message_data:
return {}
@@ -608,7 +622,7 @@ def tool_trace(self, message_id, timer, **kwargs):
tool_parameters = {}
created_time = message_data.created_at
end_time = message_data.updated_at
- agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts
+ agent_thoughts = message_data.agent_thoughts
for agent_thought in agent_thoughts:
if tool_name in agent_thought.tools:
created_time = agent_thought.created_at
@@ -672,6 +686,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs):
generate_conversation_name = kwargs.get("generate_conversation_name")
inputs = kwargs.get("inputs")
tenant_id = kwargs.get("tenant_id")
+ if not tenant_id:
+ return {}
start_time = timer.get("start")
end_time = timer.get("end")
@@ -693,8 +709,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs):
return generate_name_trace_info
-trace_manager_timer = None
-trace_manager_queue = queue.Queue()
+trace_manager_timer: Optional[threading.Timer] = None
+trace_manager_queue: queue.Queue = queue.Queue()
trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
@@ -706,7 +722,7 @@ def __init__(self, app_id=None, user_id=None):
self.app_id = app_id
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
- self.flask_app = current_app._get_current_object()
+ self.flask_app = current_app._get_current_object() # type: ignore
if trace_manager_timer is None:
self.start_timer()
@@ -723,7 +739,7 @@ def add_trace_task(self, trace_task: TraceTask):
def collect_tasks(self):
global trace_manager_queue
- tasks = []
+ tasks: list[TraceTask] = []
while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
task = trace_manager_queue.get_nowait()
tasks.append(task)
@@ -749,6 +765,8 @@ def start_timer(self):
def send_to_celery(self, tasks: list[TraceTask]):
with self.flask_app.app_context():
for task in tasks:
+ if task.app_id is None:
+ continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 0f3f8249661bf0..87c7a79fb01201 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -1,5 +1,5 @@
-from collections.abc import Sequence
-from typing import Optional
+from collections.abc import Mapping, Sequence
+from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
@@ -39,7 +39,7 @@ def get_prompt(
self,
*,
prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
- inputs: dict[str, str],
+ inputs: Mapping[str, str],
query: str,
files: Sequence[File],
context: Optional[str],
@@ -77,7 +77,7 @@ def get_prompt(
def _get_completion_model_prompt_messages(
self,
prompt_template: CompletionModelPromptTemplate,
- inputs: dict,
+ inputs: Mapping[str, str],
query: Optional[str],
files: Sequence[File],
context: Optional[str],
@@ -90,15 +90,15 @@ def _get_completion_model_prompt_messages(
"""
raw_prompt = prompt_template.text
- prompt_messages = []
+ prompt_messages: list[PromptMessage] = []
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
- prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+ prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
- if memory and memory_config:
+ if memory and memory_config and memory_config.role_prefix:
role_prefix = memory_config.role_prefix
prompt_inputs = self._set_histories_variable(
memory=memory,
@@ -135,7 +135,7 @@ def _get_completion_model_prompt_messages(
def _get_chat_model_prompt_messages(
self,
prompt_template: list[ChatModelMessage],
- inputs: dict,
+ inputs: Mapping[str, str],
query: Optional[str],
files: Sequence[File],
context: Optional[str],
@@ -146,7 +146,7 @@ def _get_chat_model_prompt_messages(
"""
Get chat model prompt messages.
"""
- prompt_messages = []
+ prompt_messages: list[PromptMessage] = []
for prompt_item in prompt_template:
raw_prompt = prompt_item.text
@@ -160,7 +160,7 @@ def _get_chat_model_prompt_messages(
prompt = vp.convert_template(raw_prompt).text
else:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
- prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+ prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(
context=context, parser=parser, prompt_inputs=prompt_inputs
)
@@ -207,7 +207,7 @@ def _get_chat_model_prompt_messages(
last_message = prompt_messages[-1] if prompt_messages else None
if last_message and last_message.role == PromptMessageRole.USER:
# get last user message content and add files
- prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
+ prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
for file in files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
@@ -229,7 +229,10 @@ def _get_chat_model_prompt_messages(
return prompt_messages
- def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+ def _set_context_variable(
+ self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+ ) -> Mapping[str, str]:
+ prompt_inputs = dict(prompt_inputs)
if "#context#" in parser.variable_keys:
if context:
prompt_inputs["#context#"] = context
@@ -238,7 +241,10 @@ def _set_context_variable(self, context: str | None, parser: PromptTemplateParse
return prompt_inputs
- def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+ def _set_query_variable(
+ self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+ ) -> Mapping[str, str]:
+ prompt_inputs = dict(prompt_inputs)
if "#query#" in parser.variable_keys:
if query:
prompt_inputs["#query#"] = query
@@ -254,9 +260,10 @@ def _set_histories_variable(
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,
parser: PromptTemplateParser,
- prompt_inputs: dict,
+ prompt_inputs: Mapping[str, str],
model_config: ModelConfigWithCredentialsEntity,
- ) -> dict:
+ ) -> Mapping[str, str]:
+ prompt_inputs = dict(prompt_inputs)
if "#histories#" in parser.variable_keys:
if memory:
inputs = {"#histories#": "", **prompt_inputs}
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
index caa1793ea8c039..09f017a7db0d3a 100644
--- a/api/core/prompt/agent_history_prompt_transform.py
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -31,7 +31,7 @@ def __init__(
self.memory = memory
def get_prompt(self) -> list[PromptMessage]:
- prompt_messages = []
+ prompt_messages: list[PromptMessage] = []
num_system = 0
for prompt_message in self.history_messages:
if isinstance(prompt_message, SystemPromptMessage):
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 87acdb3c49cc01..1f040599be6dac 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -42,7 +42,7 @@ def _calculate_rest_token(
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
- or model_config.parameters.get(parameter_rule.use_template)
+ or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@@ -59,7 +59,7 @@ def _get_history_messages_from_memory(
ai_prefix: Optional[str] = None,
) -> str:
"""Get memory messages."""
- kwargs = {"max_token_limit": max_token_limit}
+ kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
if human_prefix:
kwargs["human_prefix"] = human_prefix
@@ -76,11 +76,15 @@ def _get_history_messages_list_from_memory(
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
) -> list[PromptMessage]:
"""Get memory messages."""
- return memory.get_history_prompt_messages(
- max_token_limit=max_token_limit,
- message_limit=memory_config.window.size
- if (
- memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0
+ return list(
+ memory.get_history_prompt_messages(
+ max_token_limit=max_token_limit,
+ message_limit=memory_config.window.size
+ if (
+ memory_config.window.enabled
+ and memory_config.window.size is not None
+ and memory_config.window.size > 0
+ )
+ else None,
)
- else None,
)
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index 93dd92f188a9c6..e75877de9b695c 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -1,7 +1,8 @@
import enum
import json
import os
-from typing import TYPE_CHECKING, Optional
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -41,7 +42,7 @@ def value_of(cls, value: str) -> "ModelMode":
raise ValueError(f"invalid mode value {value}")
-prompt_file_contents = {}
+prompt_file_contents: dict[str, Any] = {}
class SimplePromptTransform(PromptTransform):
@@ -53,9 +54,9 @@ def get_prompt(
self,
app_mode: AppMode,
prompt_template_entity: PromptTemplateEntity,
- inputs: dict,
+ inputs: Mapping[str, str],
query: str,
- files: list["File"],
+ files: Sequence["File"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
@@ -66,7 +67,7 @@ def get_prompt(
if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode,
- pre_prompt=prompt_template_entity.simple_prompt_template,
+ pre_prompt=prompt_template_entity.simple_prompt_template or "",
inputs=inputs,
query=query,
files=files,
@@ -77,7 +78,7 @@ def get_prompt(
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
app_mode=app_mode,
- pre_prompt=prompt_template_entity.simple_prompt_template,
+ pre_prompt=prompt_template_entity.simple_prompt_template or "",
inputs=inputs,
query=query,
files=files,
@@ -171,11 +172,11 @@ def _get_chat_model_prompt_messages(
inputs: dict,
query: str,
context: Optional[str],
- files: list["File"],
+ files: Sequence["File"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
- prompt_messages = []
+ prompt_messages: list[PromptMessage] = []
# get prompt
prompt, _ = self.get_prompt_str_and_rules(
@@ -216,7 +217,7 @@ def _get_completion_model_prompt_messages(
inputs: dict,
query: str,
context: Optional[str],
- files: list["File"],
+ files: Sequence["File"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -263,7 +264,7 @@ def _get_completion_model_prompt_messages(
return [self.get_last_user_message(prompt, files)], stops
- def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
+ def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
if files:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
@@ -288,7 +289,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict
# Check if the prompt file is already loaded
if prompt_file_name in prompt_file_contents:
- return prompt_file_contents[prompt_file_name]
+ return cast(dict, prompt_file_contents[prompt_file_name])
# Get the absolute path of the subdirectory
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
@@ -301,7 +302,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict
# Store the content of the prompt file
prompt_file_contents[prompt_file_name] = content
- return content
+ return cast(dict, content)
def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
# baichuan
diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py
index aa175153bc633f..2f4e65146131be 100644
--- a/api/core/prompt/utils/prompt_message_util.py
+++ b/api/core/prompt/utils/prompt_message_util.py
@@ -1,5 +1,5 @@
from collections.abc import Sequence
-from typing import cast
+from typing import Any, cast
from core.model_runtime.entities import (
AssistantPromptMessage,
@@ -72,7 +72,7 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque
}
)
else:
- text = prompt_message.content
+ text = cast(str, prompt_message.content)
prompt = {"role": role, "text": text, "files": files}
@@ -99,9 +99,9 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque
}
)
else:
- text = prompt_message.content
+ text = cast(str, prompt_message.content)
- params = {
+ params: dict[str, Any] = {
"role": "user",
"text": text,
}
diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py
index 0fd08c5d3c1a3e..8e40674bc193e0 100644
--- a/api/core/prompt/utils/prompt_template_parser.py
+++ b/api/core/prompt/utils/prompt_template_parser.py
@@ -1,4 +1,5 @@
import re
+from collections.abc import Mapping
REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
WITH_VARIABLE_TMPL_REGEX = re.compile(
@@ -28,7 +29,7 @@ def extract(self) -> list:
# Regular expression to match the template rules
return re.findall(self.regex, self.template)
- def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
+ def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str:
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 3a1fe300dfd311..010abd12d275cd 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -1,7 +1,7 @@
import json
from collections import defaultdict
from json import JSONDecodeError
-from typing import Optional
+from typing import Optional, cast
from sqlalchemy.exc import IntegrityError
@@ -15,6 +15,7 @@
ModelLoadBalancingConfiguration,
ModelSettings,
QuotaConfiguration,
+ QuotaUnit,
SystemConfiguration,
)
from core.helper import encrypter
@@ -116,8 +117,8 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
- include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
- exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
+ include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
+ exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
data=provider_entity,
name_func=lambda x: x.provider,
):
@@ -490,12 +491,13 @@ def _init_trial_provider_records(
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try:
+ # FIXME: ignore the type error; only TrialHostingQuota has a limit, need to change the logic
provider_record = Provider(
tenant_id=tenant_id,
provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
- quota_limit=quota.quota_limit,
+ quota_limit=quota.quota_limit, # type: ignore
quota_used=0,
is_valid=True,
)
@@ -589,7 +591,9 @@ def _to_custom_configuration(
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
- provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa
+ provider_credentials.get(variable) or "", # type: ignore
+ self.decoding_rsa_key,
+ self.decoding_cipher_rsa,
)
except ValueError:
pass
@@ -671,13 +675,9 @@ def _to_system_configuration(
# Get hosting configuration
hosting_configuration = ext_hosting_provider.hosting_configuration
- if (
- provider_entity.provider not in hosting_configuration.provider_map
- or not hosting_configuration.provider_map.get(provider_entity.provider).enabled
- ):
- return SystemConfiguration(enabled=False)
-
provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider)
+ if provider_hosting_configuration is None or not provider_hosting_configuration.enabled:
+ return SystemConfiguration(enabled=False)
# Convert provider_records to dict
quota_type_to_provider_records_dict = {}
@@ -688,14 +688,13 @@ def _to_system_configuration(
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
provider_record
)
-
quota_configurations = []
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
- quota_unit=provider_hosting_configuration.quota_unit,
+ quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=0,
quota_limit=0,
is_valid=False,
@@ -708,7 +707,7 @@ def _to_system_configuration(
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
- quota_unit=provider_hosting_configuration.quota_unit,
+ quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
@@ -725,12 +724,12 @@ def _to_system_configuration(
current_using_credentials = provider_hosting_configuration.credentials
if current_quota_type == ProviderQuotaType.FREE:
- provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
+ provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type)
- if provider_record:
+ if provider_record_quota_free:
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
- identity_id=provider_record.id,
+ identity_id=provider_record_quota_free.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
@@ -763,7 +762,7 @@ def _to_system_configuration(
except ValueError:
pass
- current_using_credentials = provider_credentials
+ current_using_credentials = provider_credentials or {}
# cache provider credentials
provider_credentials_cache.set(credentials=current_using_credentials)
@@ -842,7 +841,7 @@ def _to_model_settings(
else []
)
- model_settings = []
+ model_settings: list[ModelSettings] = []
if not provider_model_settings:
return model_settings
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index a0153c1e58a1a8..95a2316f1da4dd 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -32,8 +32,11 @@ def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
- self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
- keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
+ if text.metadata is not None:
+ self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+ keyword_table = self._add_text_to_keyword_table(
+ keyword_table or {}, text.metadata["doc_id"], list(keywords)
+ )
self._save_dataset_keyword_table(keyword_table)
@@ -58,20 +61,26 @@ def add_texts(self, texts: list[Document], **kwargs):
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
- self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
- keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
+ if text.metadata is not None:
+ self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+ keyword_table = self._add_text_to_keyword_table(
+ keyword_table or {}, text.metadata["doc_id"], list(keywords)
+ )
self._save_dataset_keyword_table(keyword_table)
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
+ if keyword_table is None:
+ return False
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
- keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+ if keyword_table is not None:
+ keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
@@ -80,7 +89,7 @@ def search(self, query: str, **kwargs: Any) -> list[Document]:
k = kwargs.get("top_k", 4)
- sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
+ sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
@@ -137,7 +146,7 @@ def _get_dataset_keyword_table(self) -> Optional[dict]:
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
- return keyword_table_dict["__data__"]["table"]
+ return dict(keyword_table_dict["__data__"]["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(
@@ -188,8 +197,8 @@ def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
# go through text chunks in order of most matching keywords
chunk_indices_count: dict[str, int] = defaultdict(int)
- keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
- for keyword in keywords:
+ keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
+ for keyword in keywords_list:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
@@ -215,7 +224,7 @@ def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list
def create_segment_keywords(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(self.dataset.id, node_id, keywords)
- keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
+ keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
@@ -226,17 +235,19 @@ def multi_create_segment_keywords(self, pre_segment_data_list: list):
if pre_segment_data["keywords"]:
segment.keywords = pre_segment_data["keywords"]
keyword_table = self._add_text_to_keyword_table(
- keyword_table, segment.index_node_id, pre_segment_data["keywords"]
+ keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
- keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
+ keyword_table = self._add_text_to_keyword_table(
+ keyword_table or {}, segment.index_node_id, list(keywords)
+ )
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
- keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
+ keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
index ec809cf325306e..8b17e8dc0a3762 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -4,7 +4,7 @@
class JiebaKeywordTableHandler:
def __init__(self):
- import jieba.analyse
+ import jieba.analyse # type: ignore
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
@@ -12,7 +12,7 @@ def __init__(self):
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
- import jieba
+ import jieba # type: ignore
keywords = jieba.analyse.extract_tags(
sentence=text,
diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py
index be00687abd5025..b261b40b728692 100644
--- a/api/core/rag/datasource/keyword/keyword_base.py
+++ b/api/core/rag/datasource/keyword/keyword_base.py
@@ -37,6 +37,8 @@ def search(self, query: str, **kwargs: Any) -> list[Document]:
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
+ if text.metadata is None:
+ continue
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
@@ -45,4 +47,4 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
- return [text.metadata["doc_id"] for text in texts]
+ return [text.metadata["doc_id"] for text in texts if text.metadata]
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 18f8d4e8392302..34343ad60ea4c1 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -6,6 +6,7 @@
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@@ -31,7 +32,7 @@ def retrieve(
top_k: int,
score_threshold: Optional[float] = 0.0,
reranking_model: Optional[dict] = None,
- reranking_mode: Optional[str] = "reranking_model",
+ reranking_mode: str = "reranking_model",
weights: Optional[dict] = None,
):
if not query:
@@ -42,15 +43,15 @@ def retrieve(
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
- all_documents = []
- threads = []
- exceptions = []
+ all_documents: list[Document] = []
+ threads: list[threading.Thread] = []
+ exceptions: list[str] = []
# retrieval_model source with keyword
if retrieval_method == "keyword_search":
keyword_thread = threading.Thread(
target=RetrievalService.keyword_search,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
@@ -65,7 +66,7 @@ def retrieve(
embedding_thread = threading.Thread(
target=RetrievalService.embedding_search,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
@@ -84,7 +85,7 @@ def retrieve(
full_text_index_thread = threading.Thread(
target=RetrievalService.full_text_index_search,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"retrieval_method": retrieval_method,
@@ -124,7 +125,7 @@ def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model
if not dataset:
return []
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
- dataset.tenant_id, dataset_id, query, external_retrieval_model
+ dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
)
return all_documents
@@ -135,6 +136,8 @@ def keyword_search(
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise ValueError("dataset not found")
keyword = Keyword(dataset=dataset)
@@ -159,6 +162,8 @@ def embedding_search(
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise ValueError("dataset not found")
vector = Vector(dataset=dataset)
@@ -209,6 +214,8 @@ def full_text_index_search(
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise ValueError("dataset not found")
vector_processor = Vector(
dataset=dataset,
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
index 09104ae4223443..603d3fdbcdf1ab 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
@@ -17,12 +17,19 @@
class AnalyticdbVector(BaseVector):
def __init__(
- self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
+ self,
+ collection_name: str,
+ api_config: AnalyticdbVectorOpenAPIConfig | None,
+ sql_config: AnalyticdbVectorBySqlConfig | None,
):
super().__init__(collection_name)
if api_config is not None:
- self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
+ self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
+ collection_name, api_config
+ )
else:
+ if sql_config is None:
+ raise ValueError("Either api_config or sql_config must be provided")
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
def get_type(self) -> str:
@@ -33,8 +40,8 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs)
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
- def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
- self.analyticdb_vector.add_texts(texts, embeddings)
+ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+ self.analyticdb_vector.add_texts(documents, embeddings)
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)
@@ -68,13 +75,13 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
- access_key_id=dify_config.ANALYTICDB_KEY_ID,
- access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
- region_id=dify_config.ANALYTICDB_REGION_ID,
- instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
- account=dify_config.ANALYTICDB_ACCOUNT,
- account_password=dify_config.ANALYTICDB_PASSWORD,
- namespace=dify_config.ANALYTICDB_NAMESPACE,
+ access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
+ access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
+ region_id=dify_config.ANALYTICDB_REGION_ID or "",
+ instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
+ account=dify_config.ANALYTICDB_ACCOUNT or "",
+ account_password=dify_config.ANALYTICDB_PASSWORD or "",
+ namespace=dify_config.ANALYTICDB_NAMESPACE or "",
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
)
sqlConfig = None
@@ -83,11 +90,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
- account=dify_config.ANALYTICDB_ACCOUNT,
- account_password=dify_config.ANALYTICDB_PASSWORD,
+ account=dify_config.ANALYTICDB_ACCOUNT or "",
+ account_password=dify_config.ANALYTICDB_PASSWORD or "",
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
- namespace=dify_config.ANALYTICDB_NAMESPACE,
+ namespace=dify_config.ANALYTICDB_NAMESPACE or "",
)
apiConfig = None
return AnalyticdbVector(
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
index 05e0ebc54f7c4c..095752ea8eaa42 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
@@ -1,5 +1,5 @@
import json
-from typing import Any
+from typing import Any, Optional
from pydantic import BaseModel, model_validator
@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
account: str
account_password: str
namespace: str = "dify"
- namespace_password: str = (None,)
+ namespace_password: Optional[str] = None
metrics: str = "cosine"
read_timeout: int = 60000
@@ -55,8 +55,8 @@ def to_analyticdb_client_params(self):
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
- from alibabacloud_gpdb20160503.client import Client
- from alibabacloud_tea_openapi import models as open_api_models
+ from alibabacloud_gpdb20160503.client import Client # type: ignore
+ from alibabacloud_tea_openapi import models as open_api_models # type: ignore
except:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
@@ -77,7 +77,7 @@ def _initialize(self) -> None:
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
- from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+ from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
@@ -89,7 +89,7 @@ def _initialize_vector_database(self) -> None:
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
- from Tea.exceptions import TeaException
+ from Tea.exceptions import TeaException # type: ignore
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
@@ -159,17 +159,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
- metadata = {
- "ref_doc_id": doc.metadata["doc_id"],
- "page_content": doc.page_content,
- "metadata_": json.dumps(doc.metadata),
- }
- rows.append(
- gpdb_20160503_models.UpsertCollectionDataRequestRows(
- vector=embedding,
- metadata=metadata,
+ if doc.metadata is not None:
+ metadata = {
+ "ref_doc_id": doc.metadata["doc_id"],
+ "page_content": doc.page_content,
+ "metadata_": json.dumps(doc.metadata),
+ }
+ rows.append(
+ gpdb_20160503_models.UpsertCollectionDataRequestRows(
+ vector=embedding,
+ metadata=metadata,
+ )
)
- )
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -258,7 +259,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
metadata=metadata,
)
documents.append(doc)
- documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+ documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -290,7 +291,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
metadata=metadata,
)
documents.append(doc)
- documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+ documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def delete(self) -> None:
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
index e474db5cb21971..4d8f7929413cf2 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
@@ -3,8 +3,8 @@
from contextlib import contextmanager
from typing import Any
-import psycopg2.extras
-import psycopg2.pool
+import psycopg2.extras # type: ignore
+import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
@@ -75,6 +75,7 @@ def _create_connection_pool(self):
@contextmanager
def _get_cursor(self):
+ assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
try:
@@ -156,16 +157,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
- values.append(
- (
- id_prefix + str(i),
- doc.metadata.get("doc_id", str(uuid.uuid4())),
- embeddings[i],
- doc.page_content,
- json.dumps(doc.metadata),
- doc.page_content,
+ if doc.metadata is not None:
+ values.append(
+ (
+ id_prefix + str(i),
+ doc.metadata.get("doc_id", str(uuid.uuid4())),
+ embeddings[i],
+ doc.page_content,
+ json.dumps(doc.metadata),
+ doc.page_content,
+ )
)
- )
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)
diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py
index eb78e8aa698b9b..85596ad20e099a 100644
--- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py
+++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py
@@ -5,13 +5,13 @@
import numpy as np
from pydantic import BaseModel, model_validator
-from pymochow import MochowClient
-from pymochow.auth.bce_credentials import BceCredentials
-from pymochow.configuration import Configuration
-from pymochow.exception import ServerError
-from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
-from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
-from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
+from pymochow import MochowClient # type: ignore
+from pymochow.auth.bce_credentials import BceCredentials # type: ignore
+from pymochow.configuration import Configuration # type: ignore
+from pymochow.exception import ServerError # type: ignore
+from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
+from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
+from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -75,7 +75,7 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
- metadatas = [doc.metadata for doc in documents]
+ metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
total_count = len(documents)
batch_size = 1000
@@ -84,6 +84,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
+ assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
+ # FIXME do you need this assert?
for i in range(start, end, 1):
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
@@ -136,7 +138,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# baidu vector database doesn't support bm25 search on current version
return []
- def _get_search_res(self, res, score_threshold):
+ def _get_search_res(self, res, score_threshold) -> list[Document]:
docs = []
for row in res.rows:
row_data = row.get("row", {})
@@ -276,11 +278,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return BaiduVector(
collection_name=collection_name,
config=BaiduConfig(
- endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
+ endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
- account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
- api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
- database=dify_config.BAIDU_VECTOR_DB_DATABASE,
+ account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
+ api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
+ database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
),
diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py
index a9e1486edd25f1..0eab01b507dc94 100644
--- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py
+++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py
@@ -71,11 +71,13 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
metadatas = [d.metadata for d in documents]
collection = self._client.get_or_create_collection(self._collection_name)
- collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
+ # FIXME: chromadb uses numpy arrays here; fix the type error later
+ collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
- collection.delete(where={key: {"$eq": value}})
+ # FIXME: fix the type error later
+ collection.delete(where={key: {"$eq": value}}) # type: ignore
def delete(self):
self._client.delete_collection(self._collection_name)
@@ -94,15 +96,19 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
- ids: list[str] = results["ids"][0]
- documents: list[str] = results["documents"][0]
- metadatas: dict[str, Any] = results["metadatas"][0]
- distances: list[float] = results["distances"][0]
+ # Check if results contain data
+ if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
+ return []
+
+ ids = results["ids"][0]
+ documents = results["documents"][0]
+ metadatas = results["metadatas"][0]
+ distances = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
- metadata = metadatas[index]
+ metadata = dict(metadatas[index])
if distance >= score_threshold:
metadata["score"] = distance
doc = Document(
@@ -111,7 +117,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
)
docs.append(doc)
# Sort the documents by score in descending order
- docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
+ docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -133,7 +139,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return ChromaVector(
collection_name=collection_name,
config=ChromaConfig(
- host=dify_config.CHROMA_HOST,
+ host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
index d26726e86438bd..68a9952789e5b6 100644
--- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
+++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
@@ -5,14 +5,14 @@
from datetime import timedelta
from typing import Any
-from couchbase import search
-from couchbase.auth import PasswordAuthenticator
-from couchbase.cluster import Cluster
-from couchbase.management.search import SearchIndex
+from couchbase import search # type: ignore
+from couchbase.auth import PasswordAuthenticator # type: ignore
+from couchbase.cluster import Cluster # type: ignore
+from couchbase.management.search import SearchIndex # type: ignore
# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
-from couchbase.options import ClusterOptions, SearchOptions
-from couchbase.vector_search import VectorQuery, VectorSearch
+from couchbase.options import ClusterOptions, SearchOptions # type: ignore
+from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore
from flask import current_app
from pydantic import BaseModel, model_validator
@@ -231,7 +231,7 @@ def text_exists(self, id: str) -> bool:
# Pass the id as a parameter to the query
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
for row in result:
- return row["count"] > 0
+ return bool(row["count"] > 0)
return False # Return False if no rows are returned
def delete_by_ids(self, ids: list[str]) -> None:
@@ -369,10 +369,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return CouchbaseVector(
collection_name=collection_name,
config=CouchbaseConfig(
- connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
- user=config.get("COUCHBASE_USER"),
- password=config.get("COUCHBASE_PASSWORD"),
- bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
- scope_name=config.get("COUCHBASE_SCOPE_NAME"),
+ connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""),
+ user=config.get("COUCHBASE_USER", ""),
+ password=config.get("COUCHBASE_PASSWORD", ""),
+ bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""),
+ scope_name=config.get("COUCHBASE_SCOPE_NAME", ""),
),
)
diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
index b08811a02181d2..8661828dc2aa52 100644
--- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -1,7 +1,7 @@
import json
import logging
import math
-from typing import Any, Optional
+from typing import Any, Optional, cast
from urllib.parse import urlparse
import requests
@@ -70,7 +70,7 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
def _get_version(self) -> str:
info = self._client.info()
- return info["version"]["number"]
+ return cast(str, info["version"]["number"])
def _check_version(self):
if self._version < "8.0.0":
@@ -135,7 +135,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
- doc.metadata["score"] = score
+ if doc.metadata is not None:
+ doc.metadata["score"] = score
docs.append(doc)
return docs
@@ -156,12 +157,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
- metadatas = [d.metadata for d in texts]
+ metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
- self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+ self,
+ embeddings: list[list[float]],
+ metadatas: Optional[list[dict[Any, Any]]] = None,
+ index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
@@ -208,10 +212,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
- host=config.get("ELASTICSEARCH_HOST"),
- port=config.get("ELASTICSEARCH_PORT"),
- username=config.get("ELASTICSEARCH_USERNAME"),
- password=config.get("ELASTICSEARCH_PASSWORD"),
+ host=config.get("ELASTICSEARCH_HOST", "localhost"),
+ port=config.get("ELASTICSEARCH_PORT", 9200),
+ username=config.get("ELASTICSEARCH_USERNAME", ""),
+ password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)
diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
index 8646e52cf493ca..d7a14207e9375a 100644
--- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
+++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
@@ -42,7 +42,7 @@ def validate_config(cls, values: dict) -> dict:
return values
def to_opensearch_params(self) -> dict[str, Any]:
- params = {"hosts": self.hosts}
+ params: dict[str, Any] = {"hosts": self.hosts}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
@@ -53,7 +53,7 @@ def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using
self._routing = None
self._routing_field = None
if using_ugc:
- routing_value: str = kwargs.get("routing_value")
+ routing_value: str | None = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
@@ -87,14 +87,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
"_id": uuids[i],
}
}
- action_values = {
+ action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
- action_values[self._routing_field] = self._routing
+ if self._routing_field is not None:
+ action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
@@ -105,7 +106,9 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
self.refresh()
def get_ids_by_metadata_field(self, key: str, value: str):
- query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
+ query: dict[str, Any] = {
+ "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
+ }
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
@@ -191,7 +194,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
- doc.metadata["score"] = score
+ if doc.metadata is not None:
+ doc.metadata["score"] = score
docs.append(doc)
return docs
@@ -366,6 +370,7 @@ def default_text_search_query(
routing_field: Optional[str] = None,
**kwargs,
) -> dict:
+ query_clause: dict[str, Any] = {}
if routing is not None:
query_clause = {
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
@@ -386,7 +391,7 @@ def default_text_search_query(
else:
must = [query_clause]
- boolean_query = {"must": must}
+ boolean_query: dict[str, Any] = {"must": must}
if must_not:
if not isinstance(must_not, list):
@@ -426,7 +431,7 @@ def default_vector_search_query(
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
- final_ext = {"lvector": {}}
+ final_ext: dict[str, Any] = {"lvector": {}}
if min_score != "0.0":
final_ext["lvector"]["min_score"] = min_score
if ef_search:
@@ -438,7 +443,7 @@ def default_vector_search_query(
if client_refactor:
final_ext["lvector"]["client_refactor"] = client_refactor
- search_query = {
+ search_query: dict[str, Any] = {
"size": k,
"_source": True, # force return '_source'
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
@@ -446,8 +451,8 @@ def default_vector_search_query(
if filters is not None:
# when using filter, transform filter from List[Dict] to Dict as valid format
- filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
- search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict
+ filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
+ search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict
if filter_type:
final_ext["lvector"]["filter_type"] = filter_type
@@ -459,17 +464,19 @@ def default_vector_search_query(
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
lindorm_config = LindormVectorStoreConfig(
- hosts=dify_config.LINDORM_URL,
+ hosts=dify_config.LINDORM_URL or "",
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
)
using_ugc = dify_config.USING_UGC_INDEX
+ if using_ugc is None:
+ raise ValueError("USING_UGC_INDEX is not set")
routing_value = None
if dataset.index_struct:
# if an existed record's index_struct_dict doesn't contain using_ugc field,
# it actually stores in the normal index format
- stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False)
+ stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
using_ugc = stored_in_ugc
if stored_in_ugc:
dimension = dataset.index_struct_dict["dimension"]
diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index 5a263d6e78c3bd..9b029ffc193cc0 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -3,8 +3,8 @@
from typing import Any, Optional
from pydantic import BaseModel, model_validator
-from pymilvus import MilvusClient, MilvusException
-from pymilvus.milvus_client import IndexParams
+from pymilvus import MilvusClient, MilvusException # type: ignore
+from pymilvus.milvus_client import IndexParams # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -54,14 +54,14 @@ def __init__(self, collection_name: str, config: MilvusConfig):
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
- self._fields = []
+ self._fields: list[str] = []
def get_type(self) -> str:
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
- metadatas = [d.metadata for d in texts]
+ metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@@ -161,8 +161,8 @@ def create_collection(
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
- from pymilvus import CollectionSchema, DataType, FieldSchema
- from pymilvus.orm.types import infer_dtype_bydata
+ from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
+ from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
dim = len(embeddings[0])
@@ -217,10 +217,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
- uri=dify_config.MILVUS_URI,
- token=dify_config.MILVUS_TOKEN,
- user=dify_config.MILVUS_USER,
- password=dify_config.MILVUS_PASSWORD,
- database=dify_config.MILVUS_DATABASE,
+ uri=dify_config.MILVUS_URI or "",
+ token=dify_config.MILVUS_TOKEN or "",
+ user=dify_config.MILVUS_USER or "",
+ password=dify_config.MILVUS_PASSWORD or "",
+ database=dify_config.MILVUS_DATABASE or "",
),
)
diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py
index b7b6b803ad20af..e63e1f522b3812 100644
--- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py
+++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py
@@ -74,15 +74,16 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
columns = ["id", "text", "vector", "metadata"]
values = []
for i, doc in enumerate(documents):
- doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
- row = (
- doc_id,
- self.escape_str(doc.page_content),
- embeddings[i],
- json.dumps(doc.metadata) if doc.metadata else {},
- )
- values.append(str(row))
- ids.append(doc_id)
+ if doc.metadata is not None:
+ doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+ row = (
+ doc_id,
+ self.escape_str(doc.page_content),
+ embeddings[i],
+ json.dumps(doc.metadata) if doc.metadata else {},
+ )
+ values.append(str(row))
+ ids.append(doc_id)
sql = f"""
INSERT INTO {self._config.database}.{self._collection_name}
({",".join(columns)}) VALUES {",".join(values)}
diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
index c44338d42a591a..957c799a60cbfe 100644
--- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
+++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
@@ -4,7 +4,7 @@
from typing import Any
from pydantic import BaseModel, model_validator
-from pyobvector import VECTOR, ObVecClient
+from pyobvector import VECTOR, ObVecClient # type: ignore
from sqlalchemy import JSON, Column, String, func
from sqlalchemy.dialects.mysql import LONGTEXT
@@ -131,7 +131,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, id=id)
- return cur.rowcount != 0
+ return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None:
self._client.delete(table_name=self._collection_name, ids=ids)
diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
index 7a976d7c3c8955..72a15022052f0a 100644
--- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
+++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
@@ -66,7 +66,7 @@ def get_type(self) -> str:
return VectorType.OPENSEARCH
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
- metadatas = [d.metadata for d in texts]
+ metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings)
@@ -244,7 +244,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
open_search_config = OpenSearchConfig(
- host=dify_config.OPENSEARCH_HOST,
+ host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index 74608f1e1a3b05..dfff3563c3bb28 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -5,7 +5,7 @@
from contextlib import contextmanager
from typing import Any
-import jieba.posseg as pseg
+import jieba.posseg as pseg # type: ignore
import numpy
import oracledb
from pydantic import BaseModel, model_validator
@@ -88,12 +88,11 @@ def input_type_handler(self, cursor, value, arraysize):
def numpy_converter_out(self, value):
if value.typecode == "b":
- dtype = numpy.int8
+ return numpy.array(value, copy=False, dtype=numpy.int8)
elif value.typecode == "f":
- dtype = numpy.float32
+ return numpy.array(value, copy=False, dtype=numpy.float32)
else:
- dtype = numpy.float64
- return numpy.array(value, copy=False, dtype=dtype)
+ return numpy.array(value, copy=False, dtype=numpy.float64)
def output_type_handler(self, cursor, metadata):
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
@@ -135,17 +134,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
values = []
pks = []
for i, doc in enumerate(documents):
- doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
- pks.append(doc_id)
- values.append(
- (
- doc_id,
- doc.page_content,
- json.dumps(doc.metadata),
- # array.array("f", embeddings[i]),
- numpy.array(embeddings[i]),
+ if doc.metadata is not None:
+ doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+ pks.append(doc_id)
+ values.append(
+ (
+ doc_id,
+ doc.page_content,
+ json.dumps(doc.metadata),
+ # array.array("f", embeddings[i]),
+ numpy.array(embeddings[i]),
+ )
)
- )
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(
@@ -201,8 +201,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# lazy import
- import nltk
- from nltk.corpus import stopwords
+ import nltk # type: ignore
+ from nltk.corpus import stopwords # type: ignore
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
@@ -285,10 +285,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return OracleVector(
collection_name=collection_name,
config=OracleVectorConfig(
- host=dify_config.ORACLE_HOST,
+ host=dify_config.ORACLE_HOST or "localhost",
port=dify_config.ORACLE_PORT,
- user=dify_config.ORACLE_USER,
- password=dify_config.ORACLE_PASSWORD,
- database=dify_config.ORACLE_DATABASE,
+ user=dify_config.ORACLE_USER or "system",
+ password=dify_config.ORACLE_PASSWORD or "oracle",
+ database=dify_config.ORACLE_DATABASE or "orcl",
),
)
diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
index 7cbbdcc81f6039..221bc68d68a6f7 100644
--- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
+++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
@@ -4,7 +4,7 @@
from uuid import UUID, uuid4
from numpy import ndarray
-from pgvecto_rs.sqlalchemy import VECTOR
+from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
from pydantic import BaseModel, model_validator
from sqlalchemy import Float, String, create_engine, insert, select, text
from sqlalchemy import text as sql_text
@@ -58,7 +58,7 @@ def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
- self._fields = []
+ self._fields: list[str] = []
class _Table(CollectionORM):
__tablename__ = collection_name
@@ -222,11 +222,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
- host=dify_config.PGVECTO_RS_HOST,
- port=dify_config.PGVECTO_RS_PORT,
- user=dify_config.PGVECTO_RS_USER,
- password=dify_config.PGVECTO_RS_PASSWORD,
- database=dify_config.PGVECTO_RS_DATABASE,
+ host=dify_config.PGVECTO_RS_HOST or "localhost",
+ port=dify_config.PGVECTO_RS_PORT or 5432,
+ user=dify_config.PGVECTO_RS_USER or "postgres",
+ password=dify_config.PGVECTO_RS_PASSWORD or "",
+ database=dify_config.PGVECTO_RS_DATABASE or "postgres",
),
dim=dim,
)
diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py
index 40a9cdd136b404..271281ca7e939f 100644
--- a/api/core/rag/datasource/vdb/pgvector/pgvector.py
+++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -3,8 +3,8 @@
from contextlib import contextmanager
from typing import Any
-import psycopg2.extras
-import psycopg2.pool
+import psycopg2.extras # type: ignore
+import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -98,16 +98,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
values = []
pks = []
for i, doc in enumerate(documents):
- doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
- pks.append(doc_id)
- values.append(
- (
- doc_id,
- doc.page_content,
- json.dumps(doc.metadata),
- embeddings[i],
+ if doc.metadata is not None:
+ doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+ pks.append(doc_id)
+ values.append(
+ (
+ doc_id,
+ doc.page_content,
+ json.dumps(doc.metadata),
+ embeddings[i],
+ )
)
- )
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
@@ -216,11 +217,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
- host=dify_config.PGVECTOR_HOST,
+ host=dify_config.PGVECTOR_HOST or "localhost",
port=dify_config.PGVECTOR_PORT,
- user=dify_config.PGVECTOR_USER,
- password=dify_config.PGVECTOR_PASSWORD,
- database=dify_config.PGVECTOR_DATABASE,
+ user=dify_config.PGVECTOR_USER or "postgres",
+ password=dify_config.PGVECTOR_PASSWORD or "",
+ database=dify_config.PGVECTOR_DATABASE or "postgres",
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
),
diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index 3811458e02957c..6e94cb69db309d 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -51,6 +51,8 @@ def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
+ if not self.root_path:
+ raise ValueError("Root path is not set")
path = os.path.join(self.root_path, path)
return {"path": path}
@@ -149,9 +151,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
-
added_ids = []
- for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
+ # Filter out None values from metadatas list to match expected type
+ filtered_metadatas = [m for m in metadatas if m is not None]
+ for batch_ids, points in self._generate_rest_batches(
+ texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
+ ):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
@@ -194,7 +199,7 @@ def _generate_rest_batches(
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
- group_id,
+ group_id or "", # Ensure group_id is never None
Field.GROUP_KEY.value,
),
)
@@ -337,18 +342,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
)
docs = []
for result in results:
+ if result.payload is None:
+ continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
- page_content=result.payload.get(Field.CONTENT_KEY.value),
+ page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
- docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
+ docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -432,9 +439,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
- endpoint=dify_config.QDRANT_URL,
+ endpoint=dify_config.QDRANT_URL or "",
api_key=dify_config.QDRANT_API_KEY,
- root_path=current_app.config.root_path,
+ root_path=str(current_app.config.root_path),
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
index f373dcfeabef92..a3a20448ff7a0a 100644
--- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py
+++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
@@ -3,7 +3,7 @@
from typing import Any, Optional
from pydantic import BaseModel, model_validator
-from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
+from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
@@ -58,14 +58,14 @@ def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
- self._fields = []
+ self._fields: list[str] = []
self._group_id = group_id
def get_type(self) -> str:
return VectorType.RELYT
- def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
- index_params = {}
+ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
+ index_params: dict[str, Any] = {}
metadatas = [d.metadata for d in texts]
self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0])
@@ -107,10 +107,10 @@ def create_collection(self, dimension: int):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
- from pgvecto_rs.sqlalchemy import VECTOR
+ from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
ids = [str(uuid.uuid1()) for _ in documents]
- metadatas = [d.metadata for d in documents]
+ metadatas = [d.metadata for d in documents if d.metadata is not None]
for metadata in metadatas:
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
@@ -242,10 +242,6 @@ def similarity_search_with_score_by_vector(
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
- try:
- from sqlalchemy.engine import Row
- except ImportError:
- raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.")
filter_condition = ""
if filter is not None:
@@ -275,7 +271,7 @@ def similarity_search_with_score_by_vector(
# Execute the query and fetch the results
with self.client.connect() as conn:
- results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
+ results = conn.execute(sql_text(sql_query), params).fetchall()
documents_with_scores = [
(
@@ -307,11 +303,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
- host=dify_config.RELYT_HOST,
+ host=dify_config.RELYT_HOST or "localhost",
port=dify_config.RELYT_PORT,
- user=dify_config.RELYT_USER,
- password=dify_config.RELYT_PASSWORD,
- database=dify_config.RELYT_DATABASE,
+ user=dify_config.RELYT_USER or "",
+ password=dify_config.RELYT_PASSWORD or "",
+ database=dify_config.RELYT_DATABASE or "default",
),
group_id=dataset.id,
)
diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
index f971a9c5eb1696..c15f4b229f81c3 100644
--- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py
+++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
@@ -2,10 +2,10 @@
from typing import Any, Optional
from pydantic import BaseModel
-from tcvectordb import VectorDBClient
-from tcvectordb.model import document, enum
-from tcvectordb.model import index as vdb_index
-from tcvectordb.model.document import Filter
+from tcvectordb import VectorDBClient # type: ignore
+from tcvectordb.model import document, enum # type: ignore
+from tcvectordb.model import index as vdb_index # type: ignore
+from tcvectordb.model.document import Filter # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -25,8 +25,8 @@ class TencentConfig(BaseModel):
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
- shard: int = (1,)
- replicas: int = (2,)
+ shard: int = 1
+ replicas: int = 2
def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
@@ -120,15 +120,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
metadatas = [doc.metadata for doc in documents]
total_count = len(embeddings)
docs = []
- for id in range(0, total_count):
+ for i in range(0, total_count):
if metadatas is None:
continue
- metadata = json.dumps(metadatas[id])
+ metadata = metadatas[i] or {}
doc = document.Document(
- id=metadatas[id]["doc_id"],
- vector=embeddings[id],
- text=texts[id],
- metadata=metadata,
+ id=metadata.get("doc_id"),
+ vector=embeddings[i],
+ text=texts[i],
+ metadata=json.dumps(metadata),
)
docs.append(doc)
self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
@@ -159,8 +159,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
- def _get_search_res(self, res, score_threshold):
- docs = []
+ def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
+ docs: list[Document] = []
if res is None or len(res) == 0:
return docs
@@ -193,7 +193,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return TencentVector(
collection_name=collection_name,
config=TencentConfig(
- url=dify_config.TENCENT_VECTOR_DB_URL,
+ url=dify_config.TENCENT_VECTOR_DB_URL or "",
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
username=dify_config.TENCENT_VECTOR_DB_USERNAME,
diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
index cfd47aac5ba05b..19c5579a688f5a 100644
--- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
@@ -54,7 +54,10 @@ def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
- path = os.path.join(self.root_path, path)
+ if self.root_path:
+ path = os.path.join(self.root_path, path)
+ else:
+ raise ValueError("root_path is required")
return {"path": path}
else:
@@ -157,7 +160,7 @@ def create_collection(self, collection_name: str, vector_size: int):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
- metadatas = [d.metadata for d in documents]
+ metadatas = [d.metadata for d in documents if d.metadata is not None]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
@@ -203,7 +206,7 @@ def _generate_rest_batches(
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
- group_id,
+ group_id or "",
Field.GROUP_KEY.value,
),
)
@@ -334,18 +337,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
)
docs = []
for result in results:
+ if result.payload is None:
+ continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
- page_content=result.payload.get(Field.CONTENT_KEY.value),
+ page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
- docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
+ docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -427,12 +432,12 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
- dify_config.TIDB_PROJECT_ID,
- dify_config.TIDB_API_URL,
- dify_config.TIDB_IAM_API_URL,
- dify_config.TIDB_PUBLIC_KEY,
- dify_config.TIDB_PRIVATE_KEY,
- dify_config.TIDB_REGION,
+ dify_config.TIDB_PROJECT_ID or "",
+ dify_config.TIDB_API_URL or "",
+ dify_config.TIDB_IAM_API_URL or "",
+ dify_config.TIDB_PUBLIC_KEY or "",
+ dify_config.TIDB_PRIVATE_KEY or "",
+ dify_config.TIDB_REGION or "",
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
@@ -464,9 +469,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
- endpoint=dify_config.TIDB_ON_QDRANT_URL,
+ endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
api_key=TIDB_ON_QDRANT_API_KEY,
- root_path=config.root_path,
+ root_path=str(config.root_path),
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
index 8dd5922ad0171d..0a48c79511bf26 100644
--- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
+++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
@@ -146,7 +146,7 @@ def batch_update_tidb_serverless_cluster_status(
iam_url: str,
public_key: str,
private_key: str,
- ) -> list[dict]:
+ ):
"""
Update the status of a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
@@ -159,7 +159,6 @@ def batch_update_tidb_serverless_cluster_status(
:return: The response from the API.
"""
- clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
@@ -169,7 +168,6 @@ def batch_update_tidb_serverless_cluster_status(
if response.status_code == 200:
response_data = response.json()
- cluster_infos = []
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
@@ -236,16 +234,17 @@ def batch_create_tidb_serverless_cluster(
cluster_infos = []
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
- password = redis_client.get(cache_key)
- if not password:
+ cached_password = redis_client.get(cache_key)
+ if not cached_password:
continue
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
- "password": password.decode("utf-8"),
+ "password": cached_password.decode("utf-8"),
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
response.raise_for_status()
+ return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
index 39ab6ea71e9485..be3a417390e802 100644
--- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
+++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
@@ -49,7 +49,7 @@ def get_type(self) -> str:
return VectorType.TIDB_VECTOR
def _table(self, dim: int) -> Table:
- from tidb_vector.sqlalchemy import VectorType
+ from tidb_vector.sqlalchemy import VectorType # type: ignore
return Table(
self._collection_name,
@@ -241,11 +241,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return TiDBVector(
collection_name=collection_name,
config=TiDBVectorConfig(
- host=dify_config.TIDB_VECTOR_HOST,
- port=dify_config.TIDB_VECTOR_PORT,
- user=dify_config.TIDB_VECTOR_USER,
- password=dify_config.TIDB_VECTOR_PASSWORD,
- database=dify_config.TIDB_VECTOR_DATABASE,
+ host=dify_config.TIDB_VECTOR_HOST or "",
+ port=dify_config.TIDB_VECTOR_PORT or 0,
+ user=dify_config.TIDB_VECTOR_USER or "",
+ password=dify_config.TIDB_VECTOR_PASSWORD or "",
+ database=dify_config.TIDB_VECTOR_DATABASE or "",
program_name=dify_config.APPLICATION_NAME,
),
)
diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py
index 22e191340d3a47..edfce2edd896ee 100644
--- a/api/core/rag/datasource/vdb/vector_base.py
+++ b/api/core/rag/datasource/vdb/vector_base.py
@@ -51,15 +51,16 @@ def delete(self) -> None:
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
- doc_id = text.metadata["doc_id"]
- exists_duplicate_node = self.text_exists(doc_id)
- if exists_duplicate_node:
- texts.remove(text)
+ if text.metadata and "doc_id" in text.metadata:
+ doc_id = text.metadata["doc_id"]
+ exists_duplicate_node = self.text_exists(doc_id)
+ if exists_duplicate_node:
+ texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
- return [text.metadata["doc_id"] for text in texts]
+ return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
@property
def collection_name(self):
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index 6d2e04fc020ab5..523fa80f124b0c 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -193,10 +193,13 @@ def _get_embeddings(self) -> Embeddings:
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
+ if text.metadata is None:
+ continue
doc_id = text.metadata["doc_id"]
- exists_duplicate_node = self.text_exists(doc_id)
- if exists_duplicate_node:
- texts.remove(text)
+ if doc_id:
+ exists_duplicate_node = self.text_exists(doc_id)
+ if exists_duplicate_node:
+ texts.remove(text)
return texts
diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
index 4f927f28995613..9de8761a91ca68 100644
--- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
+++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
@@ -2,7 +2,7 @@
from typing import Any
from pydantic import BaseModel
-from volcengine.viking_db import (
+from volcengine.viking_db import ( # type: ignore
Data,
DistanceType,
Field,
@@ -121,11 +121,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
for i, page_content in enumerate(page_contents):
metadata = {}
if metadatas is not None:
- for key, val in metadatas[i].items():
+ for key, val in (metadatas[i] or {}).items():
metadata[key] = val
+ # FIXME: fix the type of metadata later
doc = Data(
{
- vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
+ vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
@@ -178,7 +179,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)
- def _get_search_res(self, results, score_threshold):
+ def _get_search_res(self, results, score_threshold) -> list[Document]:
if len(results) == 0:
return []
@@ -191,7 +192,7 @@ def _get_search_res(self, results, score_threshold):
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
- docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
+ docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index 649cfbfea8253c..68d043a19f171f 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -3,7 +3,7 @@
from typing import Any, Optional
import requests
-import weaviate
+import weaviate # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -107,7 +107,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
if metadatas is not None:
- for key, val in metadatas[i].items():
+ # metadata maybe None
+ for key, val in (metadatas[i] or {}).items():
data_properties[key] = self._json_serializable(val)
batch.add_data_object(
@@ -208,10 +209,11 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
- doc.metadata["score"] = score
- docs.append(doc)
+ if doc.metadata is not None:
+ doc.metadata["score"] = score
+ docs.append(doc)
# Sort the documents by score in descending order
- docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
+ docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -275,7 +277,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
- endpoint=dify_config.WEAVIATE_ENDPOINT,
+ endpoint=dify_config.WEAVIATE_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 319a2612c7ecb8..35becaa0c7bea7 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -83,6 +83,9 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True) ->
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
+ if doc.metadata is None:
+ raise ValueError("doc.metadata must be a dict")
+
segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"])
# NOTE: doc could already exist in the store, but we overwrite it
@@ -179,10 +182,10 @@ def get_document_hash(self, doc_id: str) -> Optional[str]:
if document_segment is None:
return None
+ data: Optional[str] = document_segment.index_node_hash
+ return data
- return document_segment.index_node_hash
-
- def get_document_segment(self, doc_id: str) -> DocumentSegment:
+ def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index 8ddda7e9832d97..a2c8737da79198 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -1,6 +1,6 @@
import base64
import logging
-from typing import Optional, cast
+from typing import Any, Optional, cast
import numpy as np
from sqlalchemy.exc import IntegrityError
@@ -27,7 +27,7 @@ def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) ->
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
# use doc embedding cache or store if not exists
- text_embeddings = [None for _ in range(len(texts))]
+ text_embeddings: list[Any] = [None for _ in range(len(texts))]
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
@@ -64,7 +64,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
for vector in embedding_result.embeddings:
try:
- normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+ # FIXME: type ignore for numpy here
+ normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
@@ -77,8 +78,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
logging.exception("Failed transform embedding")
cache_embeddings = []
try:
- for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
- text_embeddings[i] = embedding
+ for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
+ text_embeddings[i] = n_embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
@@ -86,7 +87,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
hash=hash,
provider_name=self._model_instance.provider,
)
- embedding_cache.set_embedding(embedding)
+ embedding_cache.set_embedding(n_embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
@@ -115,7 +116,8 @@ def embed_query(self, text: str) -> list[float]:
)
embedding_results = embedding_result.embeddings[0]
- embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
+ # FIXME: type ignore for numpy here
+ embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py
index 3692b5d19dfb65..7c00c668dd49a3 100644
--- a/api/core/rag/extractor/entity/extract_setting.py
+++ b/api/core/rag/extractor/entity/extract_setting.py
@@ -14,7 +14,7 @@ class NotionInfo(BaseModel):
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
- document: Document = None
+ document: Optional[Document] = None
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py
index fc331657195454..c444105bb59443 100644
--- a/api/core/rag/extractor/excel_extractor.py
+++ b/api/core/rag/extractor/excel_extractor.py
@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
import os
-from typing import Optional
+from typing import Optional, cast
import pandas as pd
from openpyxl import load_workbook
@@ -47,7 +47,7 @@ def extract(self) -> list[Document]:
for col_index, (k, v) in enumerate(row.items()):
if pd.notna(v):
cell = sheet.cell(
- row=index + 2, column=col_index + 1
+ row=cast(int, index) + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
@@ -60,8 +60,8 @@ def extract(self) -> list[Document]:
elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
- for sheet_name in excel_file.sheet_names:
- df = excel_file.parse(sheet_name=sheet_name)
+ for excel_sheet_name in excel_file.sheet_names:
+ df = excel_file.parse(sheet_name=excel_sheet_name)
df.dropna(how="all", inplace=True)
for _, row in df.iterrows():
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index 69659e31080da6..a473b3dfa78a90 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -10,6 +10,7 @@
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
+from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
@@ -66,9 +67,13 @@ def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Docume
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
- suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
-
- file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+ match = re.search(r"\.(\w+)$", filename)
+ if match:
+ suffix = "." + match.group(1)
+ else:
+ suffix = ""
+ # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
+ file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text:
@@ -89,15 +94,20 @@ def extract(
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
+ assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
- file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+ # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
+ file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
etl_type = dify_config.ETL_TYPE
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
+ assert unstructured_api_url is not None, "unstructured_api_url is required"
+ assert unstructured_api_key is not None, "unstructured_api_key is required"
+ extractor: Optional[BaseExtractor] = None
if etl_type == "Unstructured":
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
@@ -156,6 +166,7 @@ def extract(
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
+ assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
@@ -165,6 +176,7 @@ def extract(
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
+ assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py
index 17c2087a0ab575..8ae4579c7cf93f 100644
--- a/api/core/rag/extractor/firecrawl/firecrawl_app.py
+++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py
@@ -1,5 +1,6 @@
import json
import time
+from typing import cast
import requests
@@ -20,9 +21,9 @@ def scrape_url(self, url, params=None) -> dict:
json_data.update(params)
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
if response.status_code == 200:
- response = response.json()
- if response["success"] == True:
- data = response["data"]
+ response_data = response.json()
+ if response_data["success"] == True:
+ data = response_data["data"]
return {
"title": data.get("metadata").get("title"),
"description": data.get("metadata").get("description"),
@@ -30,7 +31,7 @@ def scrape_url(self, url, params=None) -> dict:
"markdown": data.get("markdown"),
}
else:
- raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
+ raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
@@ -46,9 +47,11 @@ def crawl_url(self, url, params=None) -> str:
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
if response.status_code == 200:
job_id = response.json().get("jobId")
- return job_id
+ return cast(str, job_id)
else:
self._handle_error(response, "start crawl job")
+ # FIXME: unreachable code for mypy
+ return "" # unreachable
def check_crawl_status(self, job_id) -> dict:
headers = self._prepare_headers()
@@ -64,9 +67,9 @@ def check_crawl_status(self, job_id) -> dict:
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = {
- "title": item.get("metadata").get("title"),
- "description": item.get("metadata").get("description"),
- "source_url": item.get("metadata").get("sourceURL"),
+ "title": item.get("metadata", {}).get("title"),
+ "description": item.get("metadata", {}).get("description"),
+ "source_url": item.get("metadata", {}).get("sourceURL"),
"markdown": item.get("markdown"),
}
url_data_list.append(url_data)
@@ -92,6 +95,8 @@ def check_crawl_status(self, job_id) -> dict:
else:
self._handle_error(response, "check crawl status")
+ # FIXME: unreachable code for mypy
+ return {} # unreachable
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py
index 560c2d1d84b04e..350b522347b09d 100644
--- a/api/core/rag/extractor/html_extractor.py
+++ b/api/core/rag/extractor/html_extractor.py
@@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@@ -23,6 +23,7 @@ def extract(self) -> list[Document]:
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
+ text: str = ""
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, "html.parser")
text = soup.get_text()
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index 87a4ce08bf3f89..fdc2e46d141d07 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -1,6 +1,6 @@
import json
import logging
-from typing import Any, Optional
+from typing import Any, Optional, cast
import requests
@@ -78,6 +78,7 @@ def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) ->
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
"""Get all the pages from a Notion database."""
+ assert self._notion_access_token is not None, "Notion access token is required"
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
@@ -96,6 +97,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any]
for result in data["results"]:
properties = result["properties"]
data = {}
+ value: Any
for property_name, property_value in properties.items():
type = property_value["type"]
if type == "multi_select":
@@ -130,6 +132,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any]
return [Document(page_content="\n".join(database_content))]
def _get_notion_block_data(self, page_id: str) -> list[str]:
+ assert self._notion_access_token is not None, "Notion access token is required"
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
@@ -184,6 +187,7 @@ def _get_notion_block_data(self, page_id: str) -> list[str]:
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
+ assert self._notion_access_token is not None, "Notion access token is required"
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
@@ -242,6 +246,7 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
def _read_table_rows(self, block_id: str) -> str:
"""Read table rows."""
+ assert self._notion_access_token is not None, "Notion access token is required"
done = False
result_lines_arr = []
start_cursor = None
@@ -296,7 +301,7 @@ def _read_table_rows(self, block_id: str) -> str:
result_lines = "\n".join(result_lines_arr)
return result_lines
- def update_last_edited_time(self, document_model: DocumentModel):
+ def update_last_edited_time(self, document_model: Optional[DocumentModel]):
if not document_model:
return
@@ -309,6 +314,7 @@ def update_last_edited_time(self, document_model: DocumentModel):
db.session.commit()
def get_notion_last_edited_time(self) -> str:
+ assert self._notion_access_token is not None, "Notion access token is required"
obj_id = self._notion_obj_id
page_type = self._notion_page_type
if page_type == "database":
@@ -330,7 +336,7 @@ def get_notion_last_edited_time(self) -> str:
)
data = res.json()
- return data["last_edited_time"]
+ return cast(str, data["last_edited_time"])
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
@@ -349,4 +355,4 @@ def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
f"and notion workspace {notion_workspace_id}"
)
- return data_source_binding.access_token
+ return cast(str, data_source_binding.access_token)
diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py
index 57cb9610ba267e..89a7061c26accc 100644
--- a/api/core/rag/extractor/pdf_extractor.py
+++ b/api/core/rag/extractor/pdf_extractor.py
@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
from collections.abc import Iterator
-from typing import Optional
+from typing import Optional, cast
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ def extract(self) -> list[Document]:
plaintext_file_exists = False
if self._file_cache_key:
try:
- text = storage.load(self._file_cache_key).decode("utf-8")
+ text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
@@ -53,7 +53,7 @@ def load(
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
- import pypdfium2
+ import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py
index bd669bbad36873..9647dedfff8516 100644
--- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py
+++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py
@@ -1,7 +1,7 @@
import base64
import logging
-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py
index 35220b558afab9..80c29157aaf529 100644
--- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py
+++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py
@@ -30,6 +30,9 @@ def extract(self) -> list[Document]:
if self._api_url:
from unstructured.partition.api import partition_via_api
+ if self._api_key is None:
+ raise ValueError("api_key is required")
+
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.epub import partition_epub
diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py
index 0fdcd58b2e569b..e504d4bc23014c 100644
--- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py
+++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py
@@ -27,9 +27,11 @@ def extract(self) -> list[Document]:
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
raise NotImplementedError("Unstructured API Url is not configured")
- text_by_page = {}
+ text_by_page: dict[int, str] = {}
for element in elements:
page = element.metadata.page_number
+ if page is None:
+ continue
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text
diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py
index ab41290fbc4537..cefe72b29052a1 100644
--- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py
+++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py
@@ -29,14 +29,15 @@ def extract(self) -> list[Document]:
from unstructured.partition.pptx import partition_pptx
elements = partition_pptx(filename=self._file_path)
- text_by_page = {}
+ text_by_page: dict[int, str] = {}
for element in elements:
page = element.metadata.page_number
text = element.text
- if page in text_by_page:
- text_by_page[page] += "\n" + text
- else:
- text_by_page[page] = text
+ if page is not None:
+ if page in text_by_page:
+ text_by_page[page] += "\n" + text
+ else:
+ text_by_page[page] = text
combined_texts = list(text_by_page.values())
documents = []
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 0c38a9c0762130..c3161bc812cb73 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -89,6 +89,8 @@ def _extract_images_from_docx(self, doc, image_folder):
response = ssrf_proxy.get(url)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
+ if image_ext is None:
+ continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
@@ -97,6 +99,8 @@ def _extract_images_from_docx(self, doc, image_folder):
continue
else:
image_ext = rel.target_ref.split(".")[-1]
+ if image_ext is None:
+ continue
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
@@ -226,6 +230,8 @@ def parse_docx(self, docx_path, image_folder):
if x_child is None:
continue
if x.tag.endswith("instrText"):
+ if x.text is None:
+ continue
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception as e:
diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py
index be857bd12215fd..7e5efdc66ed533 100644
--- a/api/core/rag/index_processor/index_processor_base.py
+++ b/api/core/rag/index_processor/index_processor_base.py
@@ -49,6 +49,7 @@ def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optiona
"""
Get the NodeParser object according to the processing rule.
"""
+ character_splitter: TextSplitter
if processing_rule["mode"] == "custom":
# The user-defined segmentation rule
rules = processing_rule["rules"]
diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py
index 9b855ece2c3512..c5ba6295f32f84 100644
--- a/api/core/rag/index_processor/index_processor_factory.py
+++ b/api/core/rag/index_processor/index_processor_factory.py
@@ -9,7 +9,7 @@
class IndexProcessorFactory:
"""IndexProcessorInit."""
- def __init__(self, index_type: str):
+ def __init__(self, index_type: str | None):
self._index_type = index_type
def init_index_processor(self) -> BaseIndexProcessor:
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index a631f953ce2191..c66fa54d503e9f 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -27,12 +27,13 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
# Split the text documents into nodes.
splitter = self._get_splitter(
- processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
+ processing_rule=kwargs.get("process_rule", {}),
+ embedding_model_instance=kwargs.get("embedding_model_instance"),
)
all_documents = []
for document in documents:
# document clean
- document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
+ document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {}))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
@@ -41,8 +42,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
- document_node.metadata["doc_id"] = doc_id
- document_node.metadata["doc_hash"] = hash
+ if document_node.metadata is not None:
+ document_node.metadata["doc_id"] = doc_id
+ document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py
index 320f0157a10049..20fd16e8f39b65 100644
--- a/api/core/rag/index_processor/processor/qa_index_processor.py
+++ b/api/core/rag/index_processor/processor/qa_index_processor.py
@@ -32,15 +32,16 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
splitter = self._get_splitter(
- processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
+ processing_rule=kwargs.get("process_rule") or {},
+ embedding_model_instance=kwargs.get("embedding_model_instance"),
)
# Split the text documents into nodes.
- all_documents = []
- all_qa_documents = []
+ all_documents: list[Document] = []
+ all_qa_documents: list[Document] = []
for document in documents:
# document clean
- document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
+ document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {})
document.page_content = document_text
# parse document to nodes
@@ -50,8 +51,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
- document_node.metadata["doc_id"] = doc_id
- document_node.metadata["doc_hash"] = hash
+ if document_node.metadata is not None:
+ document_node.metadata["doc_id"] = doc_id
+ document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)
@@ -64,7 +66,7 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
document_format_thread = threading.Thread(
target=self._format_qa_document,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": kwargs.get("tenant_id"),
"document_node": doc,
"all_qa_documents": all_qa_documents,
@@ -148,11 +150,12 @@ def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, a
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy())
- doc_id = str(uuid.uuid4())
- hash = helper.generate_text_hash(result["question"])
- qa_document.metadata["answer"] = result["answer"]
- qa_document.metadata["doc_id"] = doc_id
- qa_document.metadata["doc_hash"] = hash
+ if qa_document.metadata is not None:
+ doc_id = str(uuid.uuid4())
+ hash = helper.generate_text_hash(result["question"])
+ qa_document.metadata["answer"] = result["answer"]
+ qa_document.metadata["doc_id"] = doc_id
+ qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py
index 6ae432a526b169..ac7a3f8bb857e4 100644
--- a/api/core/rag/rerank/rerank_model.py
+++ b/api/core/rag/rerank/rerank_model.py
@@ -30,7 +30,11 @@ def run(
doc_ids = set()
unique_documents = []
for document in documents:
- if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids:
+ if (
+ document.provider == "dify"
+ and document.metadata is not None
+ and document.metadata["doc_id"] not in doc_ids
+ ):
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
@@ -54,7 +58,8 @@ def run(
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
)
- rerank_document.metadata["score"] = result.score
- rerank_documents.append(rerank_document)
+ if rerank_document.metadata is not None:
+ rerank_document.metadata["score"] = result.score
+ rerank_documents.append(rerank_document)
return rerank_documents
diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py
index 4719be012f99cc..cbc96037bf2cc0 100644
--- a/api/core/rag/rerank/weight_rerank.py
+++ b/api/core/rag/rerank/weight_rerank.py
@@ -39,7 +39,7 @@ def run(
unique_documents = []
doc_ids = set()
for document in documents:
- if document.metadata["doc_id"] not in doc_ids:
+ if document.metadata is not None and document.metadata["doc_id"] not in doc_ids:
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
@@ -56,10 +56,11 @@ def run(
)
if score_threshold and score < score_threshold:
continue
- document.metadata["score"] = score
- rerank_documents.append(document)
+ if document.metadata is not None:
+ document.metadata["score"] = score
+ rerank_documents.append(document)
- rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
+ rerank_documents.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
@@ -76,8 +77,9 @@ def _calculate_keyword_score(self, query: str, documents: list[Document]) -> lis
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
- document.metadata["keywords"] = document_keywords
- documents_keywords.append(document_keywords)
+ if document.metadata is not None:
+ document.metadata["keywords"] = document_keywords
+ documents_keywords.append(document_keywords)
# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
@@ -162,7 +164,7 @@ def _calculate_cosine(
query_vector = cache_embedding.embed_query(query)
for document in documents:
# calculate cosine similarity
- if "score" in document.metadata:
+ if document.metadata and "score" in document.metadata:
query_vector_scores.append(document.metadata["score"])
else:
# transform to NumPy
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 7a5bf39fa63f48..a265f36671b04b 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -1,7 +1,7 @@
import math
import threading
from collections import Counter
-from typing import Optional, cast
+from typing import Any, Optional, cast
from flask import Flask, current_app
@@ -34,7 +34,7 @@
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
-default_retrieval_model = {
+default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -140,12 +140,12 @@ def retrieve(
user_from,
available_datasets,
query,
- retrieve_config.top_k,
- retrieve_config.score_threshold,
- retrieve_config.rerank_mode,
+ retrieve_config.top_k or 0,
+ retrieve_config.score_threshold or 0,
+ retrieve_config.rerank_mode or "reranking_model",
retrieve_config.reranking_model,
retrieve_config.weights,
- retrieve_config.reranking_enabled,
+ retrieve_config.reranking_enabled or True,
message_id,
)
@@ -300,10 +300,11 @@ def single_retrieve(
metadata=external_document.get("metadata"),
provider="external",
)
- document.metadata["score"] = external_document.get("score")
- document.metadata["title"] = external_document.get("title")
- document.metadata["dataset_id"] = dataset_id
- document.metadata["dataset_name"] = dataset.name
+ if document.metadata is not None:
+ document.metadata["score"] = external_document.get("score")
+ document.metadata["title"] = external_document.get("title")
+ document.metadata["dataset_id"] = dataset_id
+ document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
@@ -325,7 +326,7 @@ def single_retrieve(
score_threshold = 0.0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
- score_threshold = retrieval_model_config.get("score_threshold")
+ score_threshold = retrieval_model_config.get("score_threshold", 0.0)
with measure_time() as timer:
results = RetrievalService.retrieve(
@@ -358,14 +359,14 @@ def multiple_retrieve(
score_threshold: float,
reranking_mode: str,
reranking_model: Optional[dict] = None,
- weights: Optional[dict] = None,
+ weights: Optional[dict[str, Any]] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
if not available_datasets:
return []
threads = []
- all_documents = []
+ all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
@@ -392,15 +393,18 @@ def multiple_retrieve(
"The configured knowledge base list have different embedding model, please set reranking model."
)
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
- weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
- weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
+ if weights is not None:
+ weights["vector_setting"]["embedding_provider_name"] = available_datasets[
+ 0
+ ].embedding_model_provider
+ weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
@@ -439,21 +443,22 @@ def _on_retrieval_end(
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
- query = db.session.query(DocumentSegment).filter(
- DocumentSegment.index_node_id == document.metadata["doc_id"]
- )
+ if document.metadata is not None:
+ query = db.session.query(DocumentSegment).filter(
+ DocumentSegment.index_node_id == document.metadata["doc_id"]
+ )
- # if 'dataset_id' in document.metadata:
- if "dataset_id" in document.metadata:
- query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
+ # if 'dataset_id' in document.metadata:
+ if "dataset_id" in document.metadata:
+ query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
- # add hit count to document segment
- query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
+ # add hit count to document segment
+ query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
- db.session.commit()
+ db.session.commit()
# get tracing instance
- trace_manager: TraceQueueManager = (
+ trace_manager: Optional[TraceQueueManager] = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
@@ -504,10 +509,11 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int,
metadata=external_document.get("metadata"),
provider="external",
)
- document.metadata["score"] = external_document.get("score")
- document.metadata["title"] = external_document.get("title")
- document.metadata["dataset_id"] = dataset_id
- document.metadata["dataset_name"] = dataset.name
+ if document.metadata is not None:
+ document.metadata["score"] = external_document.get("score")
+ document.metadata["title"] = external_document.get("title")
+ document.metadata["dataset_id"] = dataset_id
+ document.metadata["dataset_name"] = dataset.name
all_documents.append(document)
else:
# get retrieval model , if the model is not setting , using default
@@ -607,19 +613,20 @@ def to_dataset_retriever_tool(
tools.append(tool)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
- tool = DatasetMultiRetrieverTool.from_dataset(
- dataset_ids=[dataset.id for dataset in available_datasets],
- tenant_id=tenant_id,
- top_k=retrieve_config.top_k or 2,
- score_threshold=retrieve_config.score_threshold,
- hit_callbacks=[hit_callback],
- return_resource=return_resource,
- retriever_from=invoke_from.to_source(),
- reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
- reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
- )
+ if retrieve_config.reranking_model is not None:
+ tool = DatasetMultiRetrieverTool.from_dataset(
+ dataset_ids=[dataset.id for dataset in available_datasets],
+ tenant_id=tenant_id,
+ top_k=retrieve_config.top_k or 2,
+ score_threshold=retrieve_config.score_threshold,
+ hit_callbacks=[hit_callback],
+ return_resource=return_resource,
+ retriever_from=invoke_from.to_source(),
+ reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
+ reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
+ )
- tools.append(tool)
+ tools.append(tool)
return tools
@@ -635,10 +642,11 @@ def calculate_keyword_score(self, query: str, documents: list[Document], top_k:
query_keywords = keyword_table_handler.extract_keywords(query, None)
documents_keywords = []
for document in documents:
- # get the document keywords
- document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
- document.metadata["keywords"] = document_keywords
- documents_keywords.append(document_keywords)
+ if document.metadata is not None:
+ # get the document keywords
+ document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
+ document.metadata["keywords"] = document_keywords
+ documents_keywords.append(document_keywords)
# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
@@ -696,8 +704,9 @@ def cosine_similarity(vec1, vec2):
for document, score in zip(documents, similarities):
# format document
- document.metadata["score"] = score
- documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+ if document.metadata is not None:
+ document.metadata["score"] = score
+ documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return documents[:top_k] if top_k else documents
def calculate_vector_score(
@@ -705,10 +714,12 @@ def calculate_vector_score(
) -> list[Document]:
filter_documents = []
for document in all_documents:
- if score_threshold is None or document.metadata["score"] >= score_threshold:
+ if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
filter_documents.append(document)
if not filter_documents:
return []
- filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
+ filter_documents = sorted(
+ filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
+ )
return filter_documents[:top_k] if top_k else filter_documents
diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
index 06147fe7b56544..b008d0df9c2f0e 100644
--- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
+++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
@@ -1,7 +1,8 @@
-from typing import Union
+from typing import Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
@@ -27,11 +28,14 @@ def invoke(
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
- result = model_instance.invoke_llm(
- prompt_messages=prompt_messages,
- tools=dataset_tools,
- stream=False,
- model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
+ result = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ tools=dataset_tools,
+ stream=False,
+ model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
+ ),
)
if result.message.tool_calls:
# get retrieval model config
diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py
index 68fab0c127a253..05e8d043dfe741 100644
--- a/api/core/rag/retrieval/router/multi_dataset_react_route.py
+++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py
@@ -1,9 +1,9 @@
from collections.abc import Generator, Sequence
-from typing import Union
+from typing import Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -92,6 +92,7 @@ def _react_invoke(
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
+ prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
if model_config.mode == "chat":
prompt = self.create_chat_prompt(
query=query,
@@ -149,12 +150,15 @@ def _invoke_llm(
:param stop: stop
:return:
"""
- invoke_result = model_instance.invoke_llm(
- prompt_messages=prompt_messages,
- model_parameters=completion_param,
- stop=stop,
- stream=True,
- user=user_id,
+ invoke_result = cast(
+ Generator[LLMResult, None, None],
+ model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ model_parameters=completion_param,
+ stop=stop,
+ stream=True,
+ user=user_id,
+ ),
)
# handle invoke result
@@ -172,7 +176,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage
:return:
"""
model = None
- prompt_messages = []
+ prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
for result in invoke_result:
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 53032b34d570c7..3376bd7f75dd96 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -26,8 +26,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder(
cls: type[TS],
embedding_model_instance: Optional[ModelInstance],
- allowed_special: Union[Literal[all], Set[str]] = set(),
- disallowed_special: Union[Literal[all], Collection[str]] = "all",
+ allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any,
):
def _token_encoder(text: str) -> int:
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 7dd62f8de18a15..4bfa541fd454ad 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -92,7 +92,7 @@ def split_documents(self, documents: Iterable[Document]) -> list[Document]:
texts, metadatas = [], []
for doc in documents:
texts.append(doc.page_content)
- metadatas.append(doc.metadata)
+ metadatas.append(doc.metadata or {})
return self.create_documents(texts, metadatas=metadatas)
def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
@@ -143,7 +143,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
"""Text splitter that uses HuggingFace tokenizer to count length."""
try:
- from transformers import PreTrainedTokenizerBase
+ from transformers import PreTrainedTokenizerBase # type: ignore
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")
diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py
index ddb1481276df67..975c374cae8356 100644
--- a/api/core/tools/entities/api_entities.py
+++ b/api/core/tools/entities/api_entities.py
@@ -14,7 +14,7 @@ class UserTool(BaseModel):
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]] = None
- labels: list[str] = None
+ labels: list[str] | None = None
UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]
diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py
index 0c15b2a3711f11..7c365dc69d3b39 100644
--- a/api/core/tools/entities/tool_bundle.py
+++ b/api/core/tools/entities/tool_bundle.py
@@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel):
# summary
summary: Optional[str] = None
# operation_id
- operation_id: str = None
+ operation_id: str | None = None
# parameters
parameters: Optional[list[ToolParameter]] = None
# author
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 4fc383f91baeba..260e4e457f083e 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -244,18 +244,19 @@ def get_simple_instance(
"""
# convert options to ToolParameterOption
if options:
- options = [
+ options_tool_parametor = [
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
]
return cls(
name=name,
label=I18nObject(en_US="", zh_Hans=""),
human_description=I18nObject(en_US="", zh_Hans=""),
+ placeholder=None,
type=type,
form=cls.ToolParameterForm.LLM,
llm_description=llm_description,
required=required,
- options=options,
+ options=options_tool_parametor,
)
@@ -331,7 +332,7 @@ def to_dict(self) -> dict:
"default": self.default,
"options": self.options,
"help": self.help.to_dict() if self.help else None,
- "label": self.label.to_dict(),
+ "label": self.label.to_dict() if self.label else None,
"url": self.url,
"placeholder": self.placeholder.to_dict() if self.placeholder else None,
}
@@ -374,7 +375,10 @@ def __init__(self, **data: Any):
pool[index] = ToolRuntimeImageVariable(**variable)
super().__init__(**data)
- def dict(self) -> dict:
+ def dict(self) -> dict: # type: ignore
+ """
+ FIXME: just ignore the type check for now
+ """
return {
"conversation_id": self.conversation_id,
"user_id": self.user_id,
diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py
index d99314e33a3204..f451edbf2ee969 100644
--- a/api/core/tools/provider/api_tool_provider.py
+++ b/api/core/tools/provider/api_tool_provider.py
@@ -1,9 +1,14 @@
+from typing import Optional
+
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolCredentialsOption,
+ ToolDescription,
+ ToolIdentity,
ToolProviderCredentials,
+ ToolProviderIdentity,
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
@@ -64,21 +69,18 @@ def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "Ap
pass
else:
raise ValueError(f"invalid auth type {auth_type}")
-
- user_name = db_provider.user.name if db_provider.user_id else ""
-
+ user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
return ApiToolProviderController(
- **{
- "identity": {
- "author": user_name,
- "name": db_provider.name,
- "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
- "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
- "icon": db_provider.icon,
- },
- "credentials_schema": credentials_schema,
- "provider_id": db_provider.id or "",
- }
+ identity=ToolProviderIdentity(
+ author=user_name,
+ name=db_provider.name,
+ label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
+ description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
+ icon=db_provider.icon,
+ ),
+ credentials_schema=credentials_schema,
+ provider_id=db_provider.id or "",
+ tools=None,
)
@property
@@ -93,24 +95,22 @@ def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
:return: the tool
"""
return ApiTool(
- **{
- "api_bundle": tool_bundle,
- "identity": {
- "author": tool_bundle.author,
- "name": tool_bundle.operation_id,
- "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
- "icon": self.identity.icon,
- "provider": self.provider_id,
- },
- "description": {
- "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
- "llm": tool_bundle.summary or "",
- },
- "parameters": tool_bundle.parameters or [],
- }
+ api_bundle=tool_bundle,
+ identity=ToolIdentity(
+ author=tool_bundle.author,
+ name=tool_bundle.operation_id or "",
+ label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
+ icon=self.identity.icon if self.identity else None,
+ provider=self.provider_id,
+ ),
+ description=ToolDescription(
+ human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
+ llm=tool_bundle.summary or "",
+ ),
+ parameters=tool_bundle.parameters or [],
)
- def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
+ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:
"""
load bundled tools
@@ -121,7 +121,7 @@ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
return self.tools
- def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
+ def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
fetch tools from database
@@ -131,6 +131,8 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
"""
if self.tools is not None:
return self.tools
+ if self.identity is None:
+ return None
tools: list[Tool] = []
@@ -151,7 +153,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
self.tools = tools
return tools
- def get_tool(self, tool_name: str) -> ApiTool:
+ def get_tool(self, tool_name: str) -> Tool:
"""
get tool by name
@@ -161,7 +163,9 @@ def get_tool(self, tool_name: str) -> ApiTool:
if self.tools is None:
self.get_tools()
- for tool in self.tools:
+ for tool in self.tools or []:
+ if tool.identity is None:
+ continue
if tool.identity.name == tool_name:
return tool
diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py
index 582ad636b1953a..fc29920acd40dc 100644
--- a/api/core/tools/provider/app_tool_provider.py
+++ b/api/core/tools/provider/app_tool_provider.py
@@ -1,9 +1,10 @@
import logging
-from typing import Any
+from typing import Any, Optional
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.model import App, AppModelConfig
@@ -20,10 +21,10 @@ def provider_type(self) -> ToolProviderType:
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass
- def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
+ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
pass
- def get_tools(self, user_id: str) -> list[Tool]:
+ def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
db_tools: list[PublishedAppTool] = (
db.session.query(PublishedAppTool)
.filter(
@@ -38,7 +39,7 @@ def get_tools(self, user_id: str) -> list[Tool]:
tools: list[Tool] = []
for db_tool in db_tools:
- tool = {
+ tool: dict[str, Any] = {
"identity": {
"author": db_tool.author,
"name": db_tool.tool_name,
@@ -52,7 +53,7 @@ def get_tools(self, user_id: str) -> list[Tool]:
"parameters": [],
}
# get app from db
- app: App = db_tool.app
+ app: Optional[App] = db_tool.app
if not app:
logger.error(f"app {db_tool.app_id} not found")
@@ -79,6 +80,7 @@ def get_tools(self, user_id: str) -> list[Tool]:
type=ToolParameter.ToolParameterType.STRING,
required=required,
default=default,
+ placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
elif form_type == "select":
@@ -92,6 +94,7 @@ def get_tools(self, user_id: str) -> list[Tool]:
type=ToolParameter.ToolParameterType.SELECT,
required=required,
default=default,
+ placeholder=I18nObject(en_US="", zh_Hans=""),
options=[
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
@@ -99,5 +102,5 @@ def get_tools(self, user_id: str) -> list[Tool]:
)
)
- tools.append(Tool(**tool))
+ tools.append(ApiTool(**tool))
return tools
diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py
index 5c10f72fdaed01..99a062f8c366aa 100644
--- a/api/core/tools/provider/builtin/_positions.py
+++ b/api/core/tools/provider/builtin/_positions.py
@@ -5,7 +5,7 @@
class BuiltinToolProviderSort:
- _position = {}
+ _position: dict[str, int] = {}
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py
index 38123f125ae974..cf10f5d2556edd 100644
--- a/api/core/tools/provider/builtin/aippt/tools/aippt.py
+++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py
@@ -4,7 +4,7 @@
from json import loads as json_loads
from threading import Lock
from time import sleep, time
-from typing import Any
+from typing import Any, Union
from httpx import get, post
from requests import get as requests_get
@@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter:
"""
_api_base_url = URL("https://co.aippt.cn/api")
- _api_token_cache = {}
- _style_cache = {}
+ _api_token_cache: dict[str, dict[str, Union[str, float]]] = {}
+ _style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {}
- _api_token_cache_lock = Lock()
- _style_cache_lock = Lock()
+ _api_token_cache_lock: Lock = Lock()
+ _style_cache_lock: Lock = Lock()
- _task = {}
+ _task: dict[str, Any] = {}
_task_type_map = {
"auto": 1,
"markdown": 7,
}
- _tool: BuiltinTool
+ _tool: BuiltinTool | None
- def __init__(self, tool: BuiltinTool = None):
+ def __init__(self, tool: BuiltinTool | None = None):
self._tool = tool
- def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+ def _invoke(
+ self, user_id: str, tool_parameters: dict[str, Any]
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invokes the AIPPT generate tool with the given user ID and tool parameters.
@@ -68,8 +70,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
)
# get suit
- color: str = tool_parameters.get("color")
- style: str = tool_parameters.get("style")
+ color: str = tool_parameters.get("color", "")
+ style: str = tool_parameters.get("style", "")
if color == "__default__":
color_id = ""
@@ -226,7 +228,7 @@ def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
return ""
- def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
+ def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]:
"""
Generate a ppt
@@ -362,7 +364,9 @@ def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str:
).decode("utf-8")
@classmethod
- def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
+ def _get_styles(
+ cls, credentials: dict[str, str], user_id: str
+ ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Get styles
"""
@@ -415,7 +419,7 @@ def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[di
return colors, styles
- def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
+ def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Get styles
@@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
- def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+ def _invoke(
+ self, user_id: str, tool_parameters: dict[str, Any]
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
def get_runtime_parameters(self) -> list[ToolParameter]:
diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
index 2d65ba2d6f4389..8bd16050ecf0a6 100644
--- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
+++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
@@ -1,7 +1,7 @@
import logging
from typing import Any, Optional
-import arxiv
+import arxiv # type: ignore
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
diff --git a/api/core/tools/provider/builtin/audio/tools/tts.py b/api/core/tools/provider/builtin/audio/tools/tts.py
index f83a64d041faab..8a33ac405bd4c3 100644
--- a/api/core/tools/provider/builtin/audio/tools/tts.py
+++ b/api/core/tools/provider/builtin/audio/tools/tts.py
@@ -11,19 +11,21 @@
class TTSTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
- provider, model = tool_parameters.get("model").split("#")
- voice = tool_parameters.get(f"voice#{provider}#{model}")
+ provider, model = tool_parameters.get("model", "").split("#")
+ voice = tool_parameters.get(f"voice#{provider}#{model}", "")
model_manager = ModelManager()
+ if not self.runtime:
+ raise ValueError("Runtime is required")
model_instance = model_manager.get_model_instance(
- tenant_id=self.runtime.tenant_id,
+ tenant_id=self.runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
)
tts = model_instance.invoke_tts(
- content_text=tool_parameters.get("text"),
+ content_text=tool_parameters.get("text", ""),
user=user_id,
- tenant_id=self.runtime.tenant_id,
+ tenant_id=self.runtime.tenant_id or "",
voice=voice,
)
buffer = io.BytesIO()
@@ -41,8 +43,11 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInv
]
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
+ if not self.runtime:
+ raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
- models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
+ tid: str = self.runtime.tenant_id or ""
+ models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
@@ -62,6 +67,8 @@ def get_runtime_parameters(self) -> list[ToolParameter]:
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
+ human_description=I18nObject(en_US=f"Select a voice for {model} model"),
+ placeholder=I18nObject(en_US="Select a voice"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
@@ -83,6 +90,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]:
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
+ placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
options=options,
),
)
diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py
index a04f5c0fe9f1af..b224ff5258c879 100644
--- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py
+++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py
@@ -2,8 +2,8 @@
import logging
from typing import Any, Union
-import boto3
-from botocore.exceptions import BotoCoreError
+import boto3 # type: ignore
+from botocore.exceptions import BotoCoreError # type: ignore
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py
index 989608122185c8..b6d16d2759c30e 100644
--- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py
+++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py
@@ -1,7 +1,7 @@
import json
from typing import Any, Union
-import boto3
+import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py
index f43f3b6fe05694..01bc596346c231 100644
--- a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py
+++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py
@@ -2,7 +2,7 @@
import logging
from typing import Any, Union
-import boto3
+import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
index bffcd058b509bf..715b1ddeddcae5 100644
--- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
+++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
@@ -2,7 +2,7 @@
import operator
from typing import Any, Union
-import boto3
+import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -10,8 +10,8 @@
class SageMakerReRankTool(BuiltinTool):
sagemaker_client: Any = None
- sagemaker_endpoint: str = None
- topk: int = None
+ sagemaker_endpoint: str | None = None
+ topk: int | None = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
inputs = [query_input] * len(docs)
diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
index 1fafe09b4d96bf..55cff89798a4eb 100644
--- a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
+++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
@@ -2,7 +2,7 @@
from enum import Enum
from typing import Any, Optional, Union
-import boto3
+import boto3 # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -17,7 +17,7 @@ class TTSModelType(Enum):
class SageMakerTTSTool(BuiltinTool):
sagemaker_client: Any = None
- sagemaker_endpoint: str = None
+ sagemaker_endpoint: str | None = None
s3_client: Any = None
comprehend_client: Any = None
diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py
index 7f69e833cb9046..a60062ca66abbf 100644
--- a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py
+++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py
@@ -1,6 +1,6 @@
from typing import Any, Union
-from zhipuai import ZhipuAI
+from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py
index a521f1c28a41b6..3e24b74d2598a7 100644
--- a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py
+++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py
@@ -1,7 +1,7 @@
from typing import Any, Union
import httpx
-from zhipuai import ZhipuAI
+from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py
index 12b4173fa40270..9aa781709a726c 100644
--- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py
+++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py
@@ -1,7 +1,7 @@
import random
from typing import Any, Union
-from zhipuai import ZhipuAI
+from zhipuai import ZhipuAI # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py
index c959496735e747..d58b42b82029ce 100644
--- a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py
+++ b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py
@@ -7,18 +7,22 @@
class SearchRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
- app_token = tool_parameters.get("app_token")
- table_id = tool_parameters.get("table_id")
- table_name = tool_parameters.get("table_name")
- view_id = tool_parameters.get("view_id")
- field_names = tool_parameters.get("field_names")
- sort = tool_parameters.get("sort")
- filters = tool_parameters.get("filter")
- page_token = tool_parameters.get("page_token")
+ app_token = tool_parameters.get("app_token", "")
+ table_id = tool_parameters.get("table_id", "")
+ table_name = tool_parameters.get("table_name", "")
+ view_id = tool_parameters.get("view_id", "")
+ field_names = tool_parameters.get("field_names", "")
+ sort = tool_parameters.get("sort", "")
+ filters = tool_parameters.get("filter", "")
+ page_token = tool_parameters.get("page_token", "")
automatic_fields = tool_parameters.get("automatic_fields", False)
user_id_type = tool_parameters.get("user_id_type", "open_id")
page_size = tool_parameters.get("page_size", 20)
diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py
index a7b036387500b0..31cf8e18d85b8d 100644
--- a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py
+++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py
@@ -7,14 +7,18 @@
class UpdateRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
- app_token = tool_parameters.get("app_token")
- table_id = tool_parameters.get("table_id")
- table_name = tool_parameters.get("table_name")
- records = tool_parameters.get("records")
+ app_token = tool_parameters.get("app_token", "")
+ table_id = tool_parameters.get("table_id", "")
+ table_name = tool_parameters.get("table_name", "")
+ records = tool_parameters.get("records", "")
user_id_type = tool_parameters.get("user_id_type", "open_id")
res = client.update_records(app_token, table_id, table_name, records, user_id_type)
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py
index 8f83aea5abbe3d..80287feca176e1 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py
@@ -7,12 +7,16 @@
class AddEventAttendeesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
- event_id = tool_parameters.get("event_id")
- attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email")
+ event_id = tool_parameters.get("event_id", "")
+ attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "")
need_notification = tool_parameters.get("need_notification", True)
res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification)
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py
index 144889692f9055..02e9b445219ac8 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py
@@ -7,11 +7,15 @@
class DeleteEventTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
- event_id = tool_parameters.get("event_id")
+ event_id = tool_parameters.get("event_id", "")
need_notification = tool_parameters.get("need_notification", True)
res = client.delete_event(event_id, need_notification)
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py
index a2cd5a8b17d0af..4dafe4b3baf0cd 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py
@@ -7,8 +7,12 @@
class GetPrimaryCalendarTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
user_id_type = tool_parameters.get("user_id_type", "open_id")
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py
index 8815b4c9c871cd..2e8ca968b3cc42 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py
@@ -7,14 +7,18 @@
class ListEventsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
- start_time = tool_parameters.get("start_time")
- end_time = tool_parameters.get("end_time")
- page_token = tool_parameters.get("page_token")
- page_size = tool_parameters.get("page_size")
+ start_time = tool_parameters.get("start_time", "")
+ end_time = tool_parameters.get("end_time", "")
+ page_token = tool_parameters.get("page_token", "")
+ page_size = tool_parameters.get("page_size", 50)
res = client.list_events(start_time, end_time, page_token, page_size)
diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py
index 85bcb1d3f63847..b20eb6c31828e4 100644
--- a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py
+++ b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py
@@ -7,16 +7,20 @@
class UpdateEventTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
- event_id = tool_parameters.get("event_id")
- summary = tool_parameters.get("summary")
- description = tool_parameters.get("description")
+ event_id = tool_parameters.get("event_id", "")
+ summary = tool_parameters.get("summary", "")
+ description = tool_parameters.get("description", "")
need_notification = tool_parameters.get("need_notification", True)
- start_time = tool_parameters.get("start_time")
- end_time = tool_parameters.get("end_time")
+ start_time = tool_parameters.get("start_time", "")
+ end_time = tool_parameters.get("end_time", "")
auto_record = tool_parameters.get("auto_record", False)
res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record)
diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py
index 090a0828e89bbf..1533f594172878 100644
--- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py
+++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py
@@ -7,13 +7,17 @@
class CreateDocumentTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
- title = tool_parameters.get("title")
- content = tool_parameters.get("content")
- folder_token = tool_parameters.get("folder_token")
+ title = tool_parameters.get("title", "")
+ content = tool_parameters.get("content", "")
+ folder_token = tool_parameters.get("folder_token", "")
res = client.create_document(title, content, folder_token)
return self.create_json_message(res)
diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py
index dd57c6870d0ba9..8ea68a2ed87855 100644
--- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py
+++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py
@@ -7,11 +7,15 @@
class ListDocumentBlockTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
+ if not self.runtime or not self.runtime.credentials:
+ raise ValueError("Runtime is not set")
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
+ if not app_id or not app_secret:
+ raise ValueError("app_id and app_secret are required")
client = FeishuRequest(app_id, app_secret)
- document_id = tool_parameters.get("document_id")
+ document_id = tool_parameters.get("document_id", "")
page_token = tool_parameters.get("page_token", "")
user_id_type = tool_parameters.get("user_id_type", "open_id")
page_size = tool_parameters.get("page_size", 500)
diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py
index fcab3d71a93cf9..06f6cacd5d6126 100644
--- a/api/core/tools/provider/builtin/json_process/tools/delete.py
+++ b/api/core/tools/provider/builtin/json_process/tools/delete.py
@@ -1,7 +1,7 @@
import json
from typing import Any, Union
-from jsonpath_ng import parse
+from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py
index 793c74e5f9df51..e825329a6d8f61 100644
--- a/api/core/tools/provider/builtin/json_process/tools/insert.py
+++ b/api/core/tools/provider/builtin/json_process/tools/insert.py
@@ -1,7 +1,7 @@
import json
from typing import Any, Union
-from jsonpath_ng import parse
+from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py
index f91432ee77f488..193017ba9a7c53 100644
--- a/api/core/tools/provider/builtin/json_process/tools/parse.py
+++ b/api/core/tools/provider/builtin/json_process/tools/parse.py
@@ -1,7 +1,7 @@
import json
from typing import Any, Union
-from jsonpath_ng import parse
+from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py
index 383825c2d0b259..feca0d8a7c2783 100644
--- a/api/core/tools/provider/builtin/json_process/tools/replace.py
+++ b/api/core/tools/provider/builtin/json_process/tools/replace.py
@@ -1,7 +1,7 @@
import json
from typing import Any, Union
-from jsonpath_ng import parse
+from jsonpath_ng import parse # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py
index 0c5b5e41cbe1e1..d3a497d1cd5c54 100644
--- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py
+++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py
@@ -1,7 +1,7 @@
import logging
from typing import Any, Union
-import numexpr as ne
+import numexpr as ne # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py
index db4adfd4ad4629..6473c509e1f4c2 100644
--- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py
+++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py
@@ -1,4 +1,4 @@
-from novita_client import (
+from novita_client import ( # type: ignore
Txt2ImgV3Embedding,
Txt2ImgV3HiresFix,
Txt2ImgV3LoRA,
diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
index 0b4f2edff3607f..097b234bd50640 100644
--- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
+++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
@@ -2,7 +2,7 @@
from copy import deepcopy
from typing import Any, Union
-from novita_client import (
+from novita_client import ( # type: ignore
NovitaClient,
)
diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py
index 9c61eab9f95784..297a27abba667a 100644
--- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py
+++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py
@@ -2,7 +2,7 @@
from copy import deepcopy
from typing import Any, Union
-from novita_client import (
+from novita_client import ( # type: ignore
NovitaClient,
)
diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py
index 165e93956eff38..704e0015d961a3 100644
--- a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py
+++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py
@@ -13,7 +13,7 @@
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- from pydub import AudioSegment
+ from pydub import AudioSegment # type: ignore
class PodcastAudioGeneratorTool(BuiltinTool):
diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
index d8ca20bde6ffc9..4a47c4211f4fd4 100644
--- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
+++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
@@ -2,10 +2,10 @@
import logging
from typing import Any, Union
-from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
-from qrcode.image.base import BaseImage
-from qrcode.image.pure import PyPNGImage
-from qrcode.main import QRCode
+from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore
+from qrcode.image.base import BaseImage # type: ignore
+from qrcode.image.pure import PyPNGImage # type: ignore
+from qrcode.main import QRCode # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/transcript/tools/transcript.py b/api/core/tools/provider/builtin/transcript/tools/transcript.py
index 27f700efbd6936..ac7565d9eef5b8 100644
--- a/api/core/tools/provider/builtin/transcript/tools/transcript.py
+++ b/api/core/tools/provider/builtin/transcript/tools/transcript.py
@@ -1,7 +1,7 @@
from typing import Any, Union
from urllib.parse import parse_qs, urlparse
-from youtube_transcript_api import YouTubeTranscriptApi
+from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py
index 5ee839baa56f02..98a108f4ec7e93 100644
--- a/api/core/tools/provider/builtin/twilio/tools/send_message.py
+++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py
@@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel):
def set_validator(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
try:
- from twilio.rest import Client
+ from twilio.rest import Client # type: ignore
except ImportError:
raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.")
account_sid = values.get("account_sid")
diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py
index b1d100aad93dba..649e03d185121c 100644
--- a/api/core/tools/provider/builtin/twilio/twilio.py
+++ b/api/core/tools/provider/builtin/twilio/twilio.py
@@ -1,7 +1,7 @@
from typing import Any
-from twilio.base.exceptions import TwilioRestException
-from twilio.rest import Client
+from twilio.base.exceptions import TwilioRestException # type: ignore
+from twilio.rest import Client # type: ignore
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py
index 1c7cb39c92b40b..a6afd2dddfc63a 100644
--- a/api/core/tools/provider/builtin/vanna/tools/vanna.py
+++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py
@@ -1,6 +1,6 @@
from typing import Any, Union
-from vanna.remote import VannaDefault
+from vanna.remote import VannaDefault # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
@@ -14,6 +14,9 @@ def _invoke(
"""
invoke tools
"""
+ # Ensure runtime and credentials
+ if not self.runtime or not self.runtime.credentials:
+ raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
api_key = self.runtime.credentials.get("api_key", None)
if not api_key:
raise ToolProviderCredentialValidationError("Please input api key")
diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
index cb88e9519a4346..edb96e722f7f33 100644
--- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
+++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
@@ -1,6 +1,6 @@
from typing import Any, Optional, Union
-import wikipedia
+import wikipedia # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py
index f044fbe5404b0a..95a65ba22fc8af 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py
@@ -3,7 +3,7 @@
import pandas as pd
from requests.exceptions import HTTPError, ReadTimeout
-from yfinance import download
+from yfinance import download # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py
index ff820430f9f366..c9ae0c4ca7fcc6 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/news.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/news.py
@@ -1,6 +1,6 @@
from typing import Any, Union
-import yfinance
+import yfinance # type: ignore
from requests.exceptions import HTTPError, ReadTimeout
from core.tools.entities.tool_entities import ToolInvokeMessage
diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py
index dfc7e460473c33..74d0d25addf04b 100644
--- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py
+++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py
@@ -1,7 +1,7 @@
from typing import Any, Union
from requests.exceptions import HTTPError, ReadTimeout
-from yfinance import Ticker
+from yfinance import Ticker # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py
index 95dec2eac9a752..a24fe89679b29b 100644
--- a/api/core/tools/provider/builtin/youtube/tools/videos.py
+++ b/api/core/tools/provider/builtin/youtube/tools/videos.py
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Union
-from googleapiclient.discovery import build
+from googleapiclient.discovery import build # type: ignore
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py
index 955a0add3b4513..61de75ac5e2ccd 100644
--- a/api/core/tools/provider/builtin_tool_provider.py
+++ b/api/core/tools/provider/builtin_tool_provider.py
@@ -1,6 +1,6 @@
from abc import abstractmethod
from os import listdir, path
-from typing import Any
+from typing import Any, Optional
from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
@@ -50,6 +50,8 @@ def _get_builtin_tools(self) -> list[Tool]:
"""
if self.tools:
return self.tools
+ if not self.identity:
+ return []
provider = self.identity.name
tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
@@ -86,7 +88,7 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
return self.credentials_schema.copy()
- def get_tools(self) -> list[Tool]:
+ def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
returns a list of tools that the provider can provide
@@ -94,11 +96,14 @@ def get_tools(self) -> list[Tool]:
"""
return self._get_builtin_tools()
- def get_tool(self, tool_name: str) -> Tool:
+ def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
returns the tool that the provider can provide
"""
- return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
+ tools = self.get_tools()
+ if tools is None:
+ raise ValueError("tools not found")
+ return next((t for t in tools if t.identity and t.identity.name == tool_name), None)
def get_parameters(self, tool_name: str) -> list[ToolParameter]:
"""
@@ -107,10 +112,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]:
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
- tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
+ tools = self.get_tools()
+ if tools is None:
+ raise ToolNotFoundError(f"tool {tool_name} not found")
+ tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
- return tool.parameters
+ return tool.parameters or []
@property
def need_credentials(self) -> bool:
@@ -144,6 +152,8 @@ def _get_tool_labels(self) -> list[ToolLabelEnum]:
"""
returns the labels of the provider
"""
+ if self.identity is None:
+ return []
return self.identity.tags or []
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
@@ -159,56 +169,56 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
- for parameter in tool_parameters:
- if parameter not in tool_parameters_need_to_validate:
- raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
+ for parameter_name in tool_parameters:
+ if parameter_name not in tool_parameters_need_to_validate:
+ raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}")
# check type
- parameter_schema = tool_parameters_need_to_validate[parameter]
+ parameter_schema = tool_parameters_need_to_validate[parameter_name]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
- if not isinstance(tool_parameters[parameter], str):
- raise ToolParameterValidationError(f"parameter {parameter} should be string")
+ if not isinstance(tool_parameters[parameter_name], str):
+ raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
- if not isinstance(tool_parameters[parameter], int | float):
- raise ToolParameterValidationError(f"parameter {parameter} should be number")
+ if not isinstance(tool_parameters[parameter_name], int | float):
+ raise ToolParameterValidationError(f"parameter {parameter_name} should be number")
- if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
+ if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min:
raise ToolParameterValidationError(
- f"parameter {parameter} should be greater than {parameter_schema.min}"
+ f"parameter {parameter_name} should be greater than {parameter_schema.min}"
)
- if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
+ if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max:
raise ToolParameterValidationError(
- f"parameter {parameter} should be less than {parameter_schema.max}"
+ f"parameter {parameter_name} should be less than {parameter_schema.max}"
)
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
- if not isinstance(tool_parameters[parameter], bool):
- raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
+ if not isinstance(tool_parameters[parameter_name], bool):
+ raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean")
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
- if not isinstance(tool_parameters[parameter], str):
- raise ToolParameterValidationError(f"parameter {parameter} should be string")
+ if not isinstance(tool_parameters[parameter_name], str):
+ raise ToolParameterValidationError(f"parameter {parameter_name} should be string")
options = parameter_schema.options
if not isinstance(options, list):
- raise ToolParameterValidationError(f"parameter {parameter} options should be list")
+ raise ToolParameterValidationError(f"parameter {parameter_name} options should be list")
- if tool_parameters[parameter] not in [x.value for x in options]:
- raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
+ if tool_parameters[parameter_name] not in [x.value for x in options]:
+ raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}")
- tool_parameters_need_to_validate.pop(parameter)
+ tool_parameters_need_to_validate.pop(parameter_name)
- for parameter in tool_parameters_need_to_validate:
- parameter_schema = tool_parameters_need_to_validate[parameter]
+ for parameter_name in tool_parameters_need_to_validate:
+ parameter_schema = tool_parameters_need_to_validate[parameter_name]
if parameter_schema.required:
- raise ToolParameterValidationError(f"parameter {parameter} is required")
+ raise ToolParameterValidationError(f"parameter {parameter_name} is required")
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = parameter_schema.type.cast_value(parameter_schema.default)
- tool_parameters[parameter] = default_value
+ tool_parameters[parameter_name] = default_value
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py
index bc05a11562b717..e35207e4f06404 100644
--- a/api/core/tools/provider/tool_provider.py
+++ b/api/core/tools/provider/tool_provider.py
@@ -24,10 +24,12 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
:return: the credentials schema
"""
+ if self.credentials_schema is None:
+ return {}
return self.credentials_schema.copy()
@abstractmethod
- def get_tools(self) -> list[Tool]:
+ def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
returns a list of tools that the provider can provide
@@ -36,7 +38,7 @@ def get_tools(self) -> list[Tool]:
pass
@abstractmethod
- def get_tool(self, tool_name: str) -> Tool:
+ def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
returns a tool that the provider can provide
@@ -51,10 +53,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]:
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
- tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
+ tools = self.get_tools()
+ if tools is None:
+ raise ToolNotFoundError(f"tool {tool_name} not found")
+ tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
- return tool.parameters
+ return tool.parameters or []
@property
def provider_type(self) -> ToolProviderType:
@@ -78,55 +83,55 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
- for parameter in tool_parameters:
- if parameter not in tool_parameters_need_to_validate:
- raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
+ for tool_parameter in tool_parameters:
+ if tool_parameter not in tool_parameters_need_to_validate:
+ raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}")
# check type
- parameter_schema = tool_parameters_need_to_validate[parameter]
+ parameter_schema = tool_parameters_need_to_validate[tool_parameter]
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
- if not isinstance(tool_parameters[parameter], str):
- raise ToolParameterValidationError(f"parameter {parameter} should be string")
+ if not isinstance(tool_parameters[tool_parameter], str):
+ raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
- if not isinstance(tool_parameters[parameter], int | float):
- raise ToolParameterValidationError(f"parameter {parameter} should be number")
+ if not isinstance(tool_parameters[tool_parameter], int | float):
+ raise ToolParameterValidationError(f"parameter {tool_parameter} should be number")
- if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
+ if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min:
raise ToolParameterValidationError(
- f"parameter {parameter} should be greater than {parameter_schema.min}"
+ f"parameter {tool_parameter} should be greater than {parameter_schema.min}"
)
- if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
+ if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max:
raise ToolParameterValidationError(
- f"parameter {parameter} should be less than {parameter_schema.max}"
+ f"parameter {tool_parameter} should be less than {parameter_schema.max}"
)
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
- if not isinstance(tool_parameters[parameter], bool):
- raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
+ if not isinstance(tool_parameters[tool_parameter], bool):
+ raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean")
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
- if not isinstance(tool_parameters[parameter], str):
- raise ToolParameterValidationError(f"parameter {parameter} should be string")
+ if not isinstance(tool_parameters[tool_parameter], str):
+ raise ToolParameterValidationError(f"parameter {tool_parameter} should be string")
options = parameter_schema.options
if not isinstance(options, list):
- raise ToolParameterValidationError(f"parameter {parameter} options should be list")
+ raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list")
- if tool_parameters[parameter] not in [x.value for x in options]:
- raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
+ if tool_parameters[tool_parameter] not in [x.value for x in options]:
+ raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}")
- tool_parameters_need_to_validate.pop(parameter)
+ tool_parameters_need_to_validate.pop(tool_parameter)
- for parameter in tool_parameters_need_to_validate:
- parameter_schema = tool_parameters_need_to_validate[parameter]
+ for tool_parameter_validate in tool_parameters_need_to_validate:
+ parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate]
if parameter_schema.required:
- raise ToolParameterValidationError(f"parameter {parameter} is required")
+ raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required")
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
- tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
+ tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default)
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
@@ -144,6 +149,8 @@ def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
+ if self.identity is None:
+ raise ValueError("identity is not set")
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.identity.name}"
)
diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py
index 5656dd09ab8c94..17fe2e20cf282e 100644
--- a/api/core/tools/provider/workflow_tool_provider.py
+++ b/api/core/tools/provider/workflow_tool_provider.py
@@ -11,6 +11,7 @@
ToolProviderType,
)
from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
@@ -116,6 +117,7 @@ def fetch_workflow_variable(variable_name: str):
llm_description=parameter.description,
required=variable.required,
options=options,
+ placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
elif features.file_upload:
@@ -128,6 +130,7 @@ def fetch_workflow_variable(variable_name: str):
llm_description=parameter.description,
required=False,
form=parameter.form,
+ placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
else:
@@ -157,7 +160,7 @@ def fetch_workflow_variable(variable_name: str):
label=db_provider.label,
)
- def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
+ def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
"""
fetch tools from database
@@ -168,7 +171,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
if self.tools is not None:
return self.tools
- db_providers: WorkflowToolProvider = (
+ db_providers: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == tenant_id,
@@ -179,12 +182,14 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
if not db_providers:
return []
+ if not db_providers.app:
+ raise ValueError("app not found")
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
return self.tools
- def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
+ def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
get tool by name
@@ -195,6 +200,8 @@ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
return None
for tool in self.tools:
+ if tool.identity is None:
+ continue
if tool.identity.name == tool_name:
return tool
diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py
index 48aac75dbb4115..9a00450290a660 100644
--- a/api/core/tools/tool/api_tool.py
+++ b/api/core/tools/tool/api_tool.py
@@ -32,11 +32,13 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
+ if self.api_bundle is None:
+ raise ValueError("api_bundle is required")
return self.__class__(
identity=self.identity.model_copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.model_copy() if self.description else None,
- api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
+ api_bundle=self.api_bundle.model_copy(),
runtime=Tool.Runtime(**runtime),
)
@@ -61,6 +63,8 @@ def tool_provider_type(self) -> ToolProviderType:
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers = {}
+ if self.runtime is None:
+ raise ValueError("runtime is required")
credentials = self.runtime.credentials or {}
if "auth_type" not in credentials:
@@ -88,7 +92,7 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
headers[api_key_header] = credentials["api_key_value"]
- needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
+ needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
for parameter in needed_parameters:
if parameter.required and parameter.name not in parameters:
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
@@ -137,7 +141,8 @@ def do_http_request(
params = {}
path_params = {}
- body = {}
+ # FIXME: body should be a dict[str, Any] but it changed a lot in this function
+ body: Any = {}
cookies = {}
files = []
@@ -198,7 +203,7 @@ def do_http_request(
body = body
if method in {"get", "head", "post", "put", "delete", "patch"}:
- response = getattr(ssrf_proxy, method)(
+ response: httpx.Response = getattr(ssrf_proxy, method)(
url,
params=params,
headers=headers,
@@ -288,6 +293,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
"""
invoke http request
"""
+ response: httpx.Response | str = ""
# assemble request
headers = self.assembling_request(tool_parameters)
diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py
index e2a81ed0a36edd..adda4297f38e8a 100644
--- a/api/core/tools/tool/builtin_tool.py
+++ b/api/core/tools/tool/builtin_tool.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, cast
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@@ -32,9 +32,12 @@ def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop:
:return: the model result
"""
# invoke model
+ if self.runtime is None or self.identity is None:
+ raise ValueError("runtime and identity are required")
+
return ModelInvocationUtils.invoke(
user_id=user_id,
- tenant_id=self.runtime.tenant_id,
+ tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_name=self.identity.name,
prompt_messages=prompt_messages,
@@ -50,8 +53,11 @@ def get_max_tokens(self) -> int:
:param model_config: the model config
:return: the max tokens
"""
+ if self.runtime is None:
+ raise ValueError("runtime is required")
+
return ModelInvocationUtils.get_max_llm_context_tokens(
- tenant_id=self.runtime.tenant_id,
+ tenant_id=self.runtime.tenant_id or "",
)
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@@ -61,7 +67,12 @@ def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
:param prompt_messages: the prompt messages
:return: the tokens
"""
- return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
+ if self.runtime is None:
+ raise ValueError("runtime is required")
+
+ return ModelInvocationUtils.calculate_tokens(
+ tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
+ )
def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens()
@@ -81,7 +92,7 @@ def summarize(content: str) -> str:
stop=[],
)
- return summary.message.content
+ return cast(str, summary.message.content)
lines = content.split("\n")
new_lines = []
@@ -102,16 +113,16 @@ def summarize(content: str) -> str:
# merge lines into messages with max tokens
messages: list[str] = []
- for i in new_lines:
+ for j in new_lines:
if len(messages) == 0:
- messages.append(i)
+ messages.append(j)
else:
- if len(messages[-1]) + len(i) < max_tokens * 0.5:
- messages[-1] += i
- if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
- messages.append(i)
+ if len(messages[-1]) + len(j) < max_tokens * 0.5:
+ messages[-1] += j
+ if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
+ messages.append(j)
else:
- messages[-1] += i
+ messages[-1] += j
summaries = []
for i in range(len(messages)):
diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
index ab7b40a2536db8..a4afea4b9df429 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
@@ -1,4 +1,5 @@
import threading
+from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
@@ -7,13 +8,14 @@
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
-default_retrieval_model = {
+default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -44,12 +46,12 @@ def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
def _run(self, query: str) -> str:
threads = []
- all_documents = []
+ all_documents: list[RagDocument] = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
- "flask_app": current_app._get_current_object(),
+ "flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"all_documents": all_documents,
@@ -77,11 +79,11 @@ def _run(self, query: str) -> str:
document_score_list = {}
for item in all_documents:
- if item.metadata.get("score"):
+ if item.metadata and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
- index_node_ids = [document.metadata["doc_id"] for document in all_documents]
+ index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
@@ -139,6 +141,7 @@ def _run(self, query: str) -> str:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
+ return ""
def _retriever(
self,
diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py
index dad8c773579099..a4d2de3b1c8ef3 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any, Optional
-from msal_extensions.persistence import ABC
+from msal_extensions.persistence import ABC # type: ignore
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
index 987f94a35046e9..b382016473055d 100644
--- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
@@ -1,3 +1,5 @@
+from typing import Any
+
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
@@ -69,25 +71,27 @@ def _run(self, query: str) -> str:
metadata=external_document.get("metadata"),
provider="external",
)
- document.metadata["score"] = external_document.get("score")
- document.metadata["title"] = external_document.get("title")
- document.metadata["dataset_id"] = dataset.id
- document.metadata["dataset_name"] = dataset.name
- results.append(document)
+ if document.metadata is not None:
+ document.metadata["score"] = external_document.get("score")
+ document.metadata["title"] = external_document.get("title")
+ document.metadata["dataset_id"] = dataset.id
+ document.metadata["dataset_name"] = dataset.name
+ results.append(document)
# deal with external documents
context_list = []
for position, item in enumerate(results, start=1):
- source = {
- "position": position,
- "dataset_id": item.metadata.get("dataset_id"),
- "dataset_name": item.metadata.get("dataset_name"),
- "document_name": item.metadata.get("title"),
- "data_source_type": "external",
- "retriever_from": self.retriever_from,
- "score": item.metadata.get("score"),
- "title": item.metadata.get("title"),
- "content": item.page_content,
- }
+ if item.metadata is not None:
+ source = {
+ "position": position,
+ "dataset_id": item.metadata.get("dataset_id"),
+ "dataset_name": item.metadata.get("dataset_name"),
+ "document_name": item.metadata.get("title"),
+ "data_source_type": "external",
+ "retriever_from": self.retriever_from,
+ "score": item.metadata.get("score"),
+ "title": item.metadata.get("title"),
+ "content": item.page_content,
+ }
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
@@ -95,7 +99,7 @@ def _run(self, query: str) -> str:
return str("\n".join([item.page_content for item in results]))
else:
# get retrieval model , if the model is not setting , using default
- retrieval_model = dataset.retrieval_model or default_retrieval_model
+ retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
@@ -113,11 +117,11 @@ def _run(self, query: str) -> str:
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
- reranking_model=retrieval_model.get("reranking_model", None)
+ reranking_model=retrieval_model.get("reranking_model")
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
- weights=retrieval_model.get("weights", None),
+ weights=retrieval_model.get("weights"),
)
else:
documents = []
@@ -127,7 +131,7 @@ def _run(self, query: str) -> str:
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
- if item.metadata.get("score"):
+ if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
@@ -155,20 +159,21 @@ def _run(self, query: str) -> str:
context_list = []
resource_number = 1
for segment in sorted_segments:
- context = {}
- document = Document.query.filter(
+ document_segment = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
- if dataset and document:
+ if not document_segment:
+ continue
+ if dataset and document_segment:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
- "document_id": document.id,
- "document_name": document.name,
- "data_source_type": document.data_source_type,
+ "document_id": document_segment.id,
+ "document_name": document_segment.name,
+ "data_source_type": document_segment.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py
index 3c9295c493c470..2d7e193e152645 100644
--- a/api/core/tools/tool/dataset_retriever_tool.py
+++ b/api/core/tools/tool/dataset_retriever_tool.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool):
def get_dataset_tools(
tenant_id: str,
dataset_ids: list[str],
- retrieve_config: DatasetRetrieveConfigEntity,
+ retrieve_config: Optional[DatasetRetrieveConfigEntity],
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
@@ -51,6 +51,8 @@ def get_dataset_tools(
invoke_from=invoke_from,
hit_callback=hit_callback,
)
+ if retrieval_tools is None:
+ return []
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
@@ -83,6 +85,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]:
llm_description="Query for the dataset to be used to retrieve the dataset.",
required=True,
default="",
+ placeholder=I18nObject(en_US="", zh_Hans=""),
),
]
@@ -102,7 +105,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
return self.create_text_message(text=result)
- def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
+ def validate_credentials(
+ self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
+ ) -> str | None:
"""
validate the credentials for dataset retriever tool
"""
diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py
index 8d4045038171a6..55f94d7619635b 100644
--- a/api/core/tools/tool/tool.py
+++ b/api/core/tools/tool/tool.py
@@ -91,7 +91,7 @@ def tool_provider_type(self) -> ToolProviderType:
:return: the tool provider type
"""
- def load_variables(self, variables: ToolRuntimeVariablePool):
+ def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None:
"""
load variables from database
@@ -105,6 +105,8 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None:
"""
if not self.variables:
return
+ if self.identity is None:
+ return
self.variables.set_file(self.identity.name, variable_name, image_key)
@@ -114,6 +116,8 @@ def set_text_variable(self, variable_name: str, text: str) -> None:
"""
if not self.variables:
return
+ if self.identity is None:
+ return
self.variables.set_text(self.identity.name, variable_name, text)
@@ -200,7 +204,11 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
# update tool_parameters
# TODO: Fix type error.
+ if self.runtime is None:
+ return []
if self.runtime.runtime_parameters:
+ # Convert Mapping to dict before updating
+ tool_parameters = dict(tool_parameters)
tool_parameters.update(self.runtime.runtime_parameters)
# try parse tool parameters into the correct type
@@ -221,7 +229,7 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) ->
Transform tool parameters type
"""
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
- result = deepcopy(tool_parameters)
+ result: dict[str, Any] = deepcopy(dict(tool_parameters))
for parameter in self.parameters or []:
if parameter.name in tool_parameters:
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
@@ -234,12 +242,15 @@ def _invoke(
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
pass
- def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
+ def validate_credentials(
+ self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
+ ) -> str | None:
"""
validate the credentials
:param credentials: the credentials
:param parameters: the parameters
+ :param format_only: only return the formatted
"""
pass
diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py
index 33b4ad021a5e7f..edff4a2d07cca2 100644
--- a/api/core/tools/tool/workflow_tool.py
+++ b/api/core/tools/tool/workflow_tool.py
@@ -68,20 +68,20 @@ def _invoke(
if data.get("error"):
raise Exception(data.get("error"))
- result = []
+ r = []
outputs = data.get("outputs")
if outputs == None:
outputs = {}
else:
- outputs, files = self._extract_files(outputs)
- for file in files:
- result.append(self.create_file_message(file))
+ outputs, extracted_files = self._extract_files(outputs)
+ for f in extracted_files:
+ r.append(self.create_file_message(f))
- result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
- result.append(self.create_json_message(outputs))
+ r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
+ r.append(self.create_json_message(outputs))
- return result
+ return r
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index f92b43608ed935..425a892527daa4 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -3,7 +3,7 @@
from copy import deepcopy
from datetime import UTC, datetime
from mimetypes import guess_type
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, cast
from yarl import URL
@@ -46,7 +46,7 @@ def agent_invoke(
invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler,
trace_manager: Optional[TraceQueueManager] = None,
- ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
+ ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
@@ -69,6 +69,8 @@ def agent_invoke(
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
# invoke the tool
+ if tool.identity is None:
+ raise ValueError("tool identity is not set")
try:
# hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
@@ -163,6 +165,8 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke
"""
Invoke the tool with the given arguments.
"""
+ if tool.identity is None:
+ raise ValueError("tool identity is not set")
started_at = datetime.now(UTC)
meta = ToolInvokeMeta(
time_cost=0.0,
@@ -171,7 +175,7 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke
"tool_name": tool.identity.name,
"tool_provider": tool.identity.provider,
"tool_provider_type": tool.tool_provider_type().value,
- "tool_parameters": deepcopy(tool.runtime.runtime_parameters),
+ "tool_parameters": deepcopy(tool.runtime.runtime_parameters) if tool.runtime else {},
"tool_icon": tool.identity.icon,
},
)
@@ -194,9 +198,9 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str
result = ""
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
- result += response.message
+ result += str(response.message) if response.message is not None else ""
elif response.type == ToolInvokeMessage.MessageType.LINK:
- result += f"result link: {response.message}. please tell user to check it."
+ result += f"result link: {response.message!r}. please tell user to check it."
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
result += (
"image has been created and sent to user already, you do not need to create it,"
@@ -205,7 +209,7 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}."
else:
- result += f"tool response: {response.message}."
+ result += f"tool response: {response.message!r}."
return result
@@ -223,7 +227,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis
mimetype = response.meta.get("mime_type")
else:
try:
- url = URL(response.message)
+ url = URL(cast(str, response.message))
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
@@ -237,7 +241,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis
result.append(
ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "image/jpeg"),
- url=response.message,
+ url=cast(str, response.message),
save_as=response.save_as,
)
)
@@ -245,7 +249,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis
result.append(
ToolInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "octet/stream"),
- url=response.message,
+ url=cast(str, response.message),
save_as=response.save_as,
)
)
@@ -257,7 +261,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis
mimetype=response.meta.get("mime_type", "octet/stream")
if response.meta
else "octet/stream",
- url=response.message,
+ url=cast(str, response.message),
save_as=response.save_as,
)
)
diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py
index 2a5a2944ef8471..e53985951b0627 100644
--- a/api/core/tools/tool_label_manager.py
+++ b/api/core/tools/tool_label_manager.py
@@ -84,13 +84,17 @@ def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
raise ValueError("Unsupported tool type")
- provider_ids = [controller.provider_id for controller in tool_providers]
+ provider_ids = [
+ controller.provider_id
+ for controller in tool_providers
+ if isinstance(controller, (ApiToolProviderController, WorkflowToolProviderController))
+ ]
labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
- tool_labels = {label.tool_id: [] for label in labels}
+ tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
for label in labels:
tool_labels[label.tool_id].append(label.label_name)
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index ac333162b6bb1c..5b2173a4d0ad69 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -4,7 +4,7 @@
from collections.abc import Generator
from os import listdir, path
from threading import Lock, Thread
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, cast
from configs import dify_config
from core.agent.entities import AgentToolEntity
@@ -15,15 +15,18 @@
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
-from core.tools.errors import ToolProviderNotFoundError
+from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+from core.tools.provider.tool_provider import ToolProviderController
+from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
+from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
@@ -33,9 +36,9 @@
class ToolManager:
_builtin_provider_lock = Lock()
- _builtin_providers = {}
+ _builtin_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
- _builtin_tools_labels = {}
+ _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
@@ -55,7 +58,7 @@ def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
return cls._builtin_providers[provider]
@classmethod
- def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
+ def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]:
"""
get the builtin tool
@@ -66,13 +69,15 @@ def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
"""
provider_controller = cls.get_builtin_provider(provider)
tool = provider_controller.get_tool(tool_name)
+ if tool is None:
+ raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
@classmethod
def get_tool(
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None
- ) -> Union[BuiltinTool, ApiTool]:
+ ) -> Union[BuiltinTool, ApiTool, Tool]:
"""
get the tool
@@ -103,7 +108,7 @@ def get_tool_runtime(
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
- ) -> Union[BuiltinTool, ApiTool]:
+ ) -> Union[BuiltinTool, ApiTool, Tool]:
"""
get the tool runtime
@@ -113,6 +118,7 @@ def get_tool_runtime(
:return: the tool
"""
+ controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController]
if provider_type == "builtin":
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
@@ -129,7 +135,7 @@ def get_tool_runtime(
)
# get credentials
- builtin_provider: BuiltinToolProvider = (
+ builtin_provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@@ -177,7 +183,7 @@ def get_tool_runtime(
}
)
elif provider_type == "workflow":
- workflow_provider = (
+ workflow_provider: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
@@ -187,8 +193,13 @@ def get_tool_runtime(
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
+ controller_tools: Optional[list[Tool]] = controller.get_tools(
+ user_id="", tenant_id=workflow_provider.tenant_id
+ )
+ if controller_tools is None or len(controller_tools) == 0:
+ raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
- return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
+ return controller_tools[0].fork_tool_runtime(
runtime={
"tenant_id": tenant_id,
"credentials": {},
@@ -215,7 +226,7 @@ def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
- options = [x.value for x in parameter_rule.options]
+ options = [x.value for x in parameter_rule.options or []]
if parameter_value is not None and parameter_value not in options:
raise ValueError(
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
@@ -267,6 +278,8 @@ def get_agent_tool_runtime(
identity_id=f"AGENT.{app_id}",
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
+ if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
+ raise ValueError("runtime not found or runtime parameters not found")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@@ -312,6 +325,9 @@ def get_workflow_tool_runtime(
if runtime_parameters:
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
+ if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
+ raise ValueError("runtime not found or runtime parameters not found")
+
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@@ -326,6 +342,8 @@ def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
"""
# get provider
provider_controller = cls.get_builtin_provider(provider)
+ if provider_controller.identity is None:
+ raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
absolute_path = path.join(
path.dirname(path.realpath(__file__)),
@@ -381,11 +399,15 @@ def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, Non
),
parent_type=BuiltinToolProviderController,
)
- provider: BuiltinToolProviderController = provider_class()
- cls._builtin_providers[provider.identity.name] = provider
- for tool in provider.get_tools():
+ provider_controller: BuiltinToolProviderController = provider_class()
+ if provider_controller.identity is None:
+ continue
+ cls._builtin_providers[provider_controller.identity.name] = provider_controller
+ for tool in provider_controller.get_tools() or []:
+ if tool.identity is None:
+ continue
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
- yield provider
+ yield provider_controller
except Exception as e:
logger.exception(f"load builtin provider {provider}")
@@ -449,9 +471,11 @@ def user_list_providers(
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
+ if provider.identity is None:
+ continue
if is_filtered(
- include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
- exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
+ include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
+ exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
data=provider,
name_func=lambda x: x.identity.name,
):
@@ -472,7 +496,7 @@ def user_list_providers(
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
)
- api_provider_controllers = [
+ api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
for provider in db_api_providers
]
@@ -495,7 +519,7 @@ def user_list_providers(
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
)
- workflow_provider_controllers = []
+ workflow_provider_controllers: list[WorkflowToolProviderController] = []
for provider in workflow_providers:
try:
workflow_provider_controllers.append(
@@ -505,7 +529,9 @@ def user_list_providers(
# app has been deleted
pass
- labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
+ labels = ToolLabelManager.get_tools_labels(
+ [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
+ )
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
@@ -527,7 +553,7 @@ def get_api_provider_controller(
:return: the provider controller, the credentials
"""
- provider: ApiToolProvider = (
+ provider: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.id == provider_id,
@@ -556,7 +582,7 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
get tool provider
"""
provider_name = provider
- provider: ApiToolProvider = (
+ provider_tool: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
@@ -565,17 +591,18 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
.first()
)
- if provider is None:
+ if provider_tool is None:
raise ValueError(f"you have not added provider {provider_name}")
try:
- credentials = json.loads(provider.credentials_str) or {}
+ credentials = json.loads(provider_tool.credentials_str) or {}
except:
credentials = {}
# package tool provider controller
controller = ApiToolProviderController.from_db(
- provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE
+ provider_tool,
+ ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
)
# init tool configuration
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
@@ -584,25 +611,28 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
try:
- icon = json.loads(provider.icon)
+ icon = json.loads(provider_tool.icon)
except:
icon = {"background": "#252525", "content": "\ud83d\ude01"}
# add tool labels
labels = ToolLabelManager.get_tool_labels(controller)
- return jsonable_encoder(
- {
- "schema_type": provider.schema_type,
- "schema": provider.schema,
- "tools": provider.tools,
- "icon": icon,
- "description": provider.description,
- "credentials": masked_credentials,
- "privacy_policy": provider.privacy_policy,
- "custom_disclaimer": provider.custom_disclaimer,
- "labels": labels,
- }
+ return cast(
+ dict,
+ jsonable_encoder(
+ {
+ "schema_type": provider_tool.schema_type,
+ "schema": provider_tool.schema,
+ "tools": provider_tool.tools,
+ "icon": icon,
+ "description": provider_tool.description,
+ "credentials": masked_credentials,
+ "privacy_policy": provider_tool.privacy_policy,
+ "custom_disclaimer": provider_tool.custom_disclaimer,
+ "labels": labels,
+ }
+ ),
)
@classmethod
@@ -617,6 +647,7 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) ->
"""
provider_type = provider_type
provider_id = provider_id
+ provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None
if provider_type == "builtin":
return (
dify_config.CONSOLE_API_URL
@@ -626,16 +657,21 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) ->
)
elif provider_type == "api":
try:
- provider: ApiToolProvider = (
+ provider = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
- return json.loads(provider.icon)
+ if provider is None:
+ raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
+ icon = json.loads(provider.icon)
+ if isinstance(icon, (str, dict)):
+ return icon
+ return {"background": "#252525", "content": "\ud83d\ude01"}
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == "workflow":
- provider: WorkflowToolProvider = (
+ provider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
@@ -643,7 +679,13 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) ->
if provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
- return json.loads(provider.icon)
+ try:
+ icon = json.loads(provider.icon)
+ if isinstance(icon, (str, dict)):
+ return icon
+ return {"background": "#252525", "content": "\ud83d\ude01"}
+ except:
+ return {"background": "#252525", "content": "\ud83d\ude01"}
else:
raise ValueError(f"provider type {provider_type} not found")
diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py
index 8b5e27f5382ee7..d7720928644701 100644
--- a/api/core/tools/utils/configuration.py
+++ b/api/core/tools/utils/configuration.py
@@ -72,9 +72,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str
return a deep copy of credentials with decrypted values
"""
+ identity_id = ""
+ if self.provider_controller.identity:
+ identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
+
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
- identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
+ identity_id=identity_id,
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
@@ -95,9 +99,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str
return credentials
def delete_tool_credentials_cache(self):
+ identity_id = ""
+ if self.provider_controller.identity:
+ identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
+
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
- identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
+ identity_id=identity_id,
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
@@ -199,6 +207,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
return a deep copy of parameters with decrypted values
"""
+ if self.tool_runtime is None or self.tool_runtime.identity is None:
+ raise ValueError("tool_runtime is required")
+
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type}.{self.provider_name}",
@@ -232,6 +243,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
return parameters
def delete_tool_parameters_cache(self):
+ if self.tool_runtime is None or self.tool_runtime.identity is None:
+ raise ValueError("tool_runtime is required")
+
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type}.{self.provider_name}",
diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py
index ea28037df03720..ecf60045aa8dc5 100644
--- a/api/core/tools/utils/feishu_api_utils.py
+++ b/api/core/tools/utils/feishu_api_utils.py
@@ -1,5 +1,5 @@
import json
-from typing import Optional
+from typing import Any, Optional, cast
import httpx
@@ -101,7 +101,7 @@ def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
"""
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
payload = {"app_id": app_id, "app_secret": app_secret}
- res = self._send_request(url, require_token=False, payload=payload)
+ res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def create_document(self, title: str, content: str, folder_token: str) -> dict:
@@ -126,15 +126,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict:
"content": content,
"folder_token": folder_token,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
url = f"{self.API_BASE_URL}/document/write_document"
payload = {"document_id": document_id, "content": content, "position": position}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
return res
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str:
@@ -155,9 +156,9 @@ def get_document_content(self, document_id: str, mode: str = "markdown", lang: s
"lang": lang,
}
url = f"{self.API_BASE_URL}/document/get_document_content"
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data").get("content")
+ return cast(str, res.get("data", {}).get("content"))
return ""
def list_document_blocks(
@@ -173,9 +174,10 @@ def list_document_blocks(
"page_token": page_token,
}
url = f"{self.API_BASE_URL}/document/list_document_blocks"
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
@@ -191,9 +193,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str,
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
@@ -203,7 +206,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
- res = self._send_request(url, require_token=False, payload=payload)
+ res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def get_chat_messages(
@@ -227,9 +230,10 @@ def get_chat_messages(
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_thread_messages(
@@ -245,9 +249,10 @@ def get_thread_messages(
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
@@ -260,9 +265,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti
"completed_at": completed_time,
"description": description,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def update_task(
@@ -278,9 +284,10 @@ def update_task(
"completed_time": completed_time,
"description": description,
}
- res = self._send_request(url, method="PATCH", payload=payload)
+ res: dict = self._send_request(url, method="PATCH", payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def delete_task(self, task_guid: str) -> dict:
@@ -289,7 +296,7 @@ def delete_task(self, task_guid: str) -> dict:
payload = {
"task_guid": task_guid,
}
- res = self._send_request(url, method="DELETE", payload=payload)
+ res: dict = self._send_request(url, method="DELETE", payload=payload)
return res
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
@@ -300,7 +307,7 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s
"member_phone_or_email": member_phone_or_email,
"member_role": member_role,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
return res
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
@@ -312,9 +319,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str,
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
@@ -322,9 +330,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
params = {
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_event(
@@ -347,9 +356,10 @@ def create_event(
"auto_record": auto_record,
"attendee_ability": attendee_ability,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def update_event(
@@ -363,7 +373,7 @@ def update_event(
auto_record: bool,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
- payload = {}
+ payload: dict[str, Any] = {}
if summary:
payload["summary"] = summary
if description:
@@ -376,7 +386,7 @@ def update_event(
payload["need_notification"] = need_notification
if auto_record:
payload["auto_record"] = auto_record
- res = self._send_request(url, method="PATCH", payload=payload)
+ res: dict = self._send_request(url, method="PATCH", payload=payload)
return res
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
@@ -384,7 +394,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
params = {
"need_notification": need_notification,
}
- res = self._send_request(url, method="DELETE", params=params)
+ res: dict = self._send_request(url, method="DELETE", params=params)
return res
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
@@ -395,9 +405,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def search_events(
@@ -418,9 +429,10 @@ def search_events(
"user_id_type": user_id_type,
"page_size": page_size,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
@@ -431,9 +443,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_
"attendee_phone_or_email": attendee_phone_or_email,
"need_notification": need_notification,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_spreadsheet(
@@ -447,9 +460,10 @@ def create_spreadsheet(
"title": title,
"folder_token": folder_token,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_spreadsheet(
@@ -463,9 +477,10 @@ def get_spreadsheet(
"spreadsheet_token": spreadsheet_token,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def list_spreadsheet_sheets(
@@ -477,9 +492,10 @@ def list_spreadsheet_sheets(
params = {
"spreadsheet_token": spreadsheet_token,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_rows(
@@ -499,9 +515,10 @@ def add_rows(
"length": length,
"values": values,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_cols(
@@ -521,9 +538,10 @@ def add_cols(
"length": length,
"values": values,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def read_rows(
@@ -545,9 +563,10 @@ def read_rows(
"num_rows": num_rows,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def read_cols(
@@ -569,9 +588,10 @@ def read_cols(
"num_cols": num_cols,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def read_table(
@@ -593,9 +613,10 @@ def read_table(
"query": query,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_base(
@@ -609,9 +630,10 @@ def create_base(
"name": name,
"folder_token": folder_token,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_records(
@@ -633,9 +655,10 @@ def add_records(
payload = {
"records": convert_add_records(records),
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def update_records(
@@ -657,9 +680,10 @@ def update_records(
payload = {
"records": convert_update_records(records),
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def delete_records(
@@ -686,9 +710,10 @@ def delete_records(
payload = {
"records": record_id_list,
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def search_record(
@@ -740,7 +765,7 @@ def search_record(
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
- payload = {}
+ payload: dict[str, Any] = {}
if view_id:
payload["view_id"] = view_id
@@ -752,10 +777,11 @@ def search_record(
payload["filter"] = filter_dict
if automatic_fields:
payload["automatic_fields"] = automatic_fields
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_base_info(
@@ -767,9 +793,10 @@ def get_base_info(
params = {
"app_token": app_token,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_table(
@@ -797,9 +824,10 @@ def create_table(
}
if default_view_name:
payload["default_view_name"] = default_view_name
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def delete_tables(
@@ -834,9 +862,10 @@ def delete_tables(
"table_names": table_name_list,
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def list_tables(
@@ -852,9 +881,10 @@ def list_tables(
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def read_records(
@@ -882,7 +912,8 @@ def read_records(
"record_ids": record_id_list,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params, payload=payload)
+ res: dict = self._send_request(url, method="GET", params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
diff --git a/api/core/tools/utils/lark_api_utils.py b/api/core/tools/utils/lark_api_utils.py
index 30cb0cb141d9a6..de394a39bf5a00 100644
--- a/api/core/tools/utils/lark_api_utils.py
+++ b/api/core/tools/utils/lark_api_utils.py
@@ -1,5 +1,5 @@
import json
-from typing import Optional
+from typing import Any, Optional, cast
import httpx
@@ -62,12 +62,10 @@ def convert_update_records(self, json_str):
def tenant_access_token(self) -> str:
feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token"
if redis_client.exists(feishu_tenant_access_token):
- return redis_client.get(feishu_tenant_access_token).decode()
- res = self.get_tenant_access_token(self.app_id, self.app_secret)
+ return str(redis_client.get(feishu_tenant_access_token).decode())
+ res: dict[str, str] = self.get_tenant_access_token(self.app_id, self.app_secret)
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
- if "tenant_access_token" in res:
- return res.get("tenant_access_token")
- return ""
+ return res.get("tenant_access_token", "")
def _send_request(
self,
@@ -91,7 +89,7 @@ def _send_request(
def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
payload = {"app_id": app_id, "app_secret": app_secret}
- res = self._send_request(url, require_token=False, payload=payload)
+ res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def create_document(self, title: str, content: str, folder_token: str) -> dict:
@@ -101,15 +99,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict:
"content": content,
"folder_token": folder_token,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
url = f"{self.API_BASE_URL}/document/write_document"
payload = {"document_id": document_id, "content": content, "position": position}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
return res
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict:
@@ -119,9 +118,9 @@ def get_document_content(self, document_id: str, mode: str = "markdown", lang: s
"lang": lang,
}
url = f"{self.API_BASE_URL}/document/get_document_content"
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data").get("content")
+ return cast(dict, res.get("data", {}).get("content"))
return ""
def list_document_blocks(
@@ -134,9 +133,10 @@ def list_document_blocks(
"page_token": page_token,
}
url = f"{self.API_BASE_URL}/document/list_document_blocks"
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
@@ -149,9 +149,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str,
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
@@ -161,7 +162,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
- res = self._send_request(url, require_token=False, payload=payload)
+ res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def get_chat_messages(
@@ -182,9 +183,10 @@ def get_chat_messages(
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_thread_messages(
@@ -197,9 +199,10 @@ def get_thread_messages(
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
@@ -211,9 +214,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti
"completed_at": completed_time,
"description": description,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def update_task(
@@ -228,9 +232,10 @@ def update_task(
"completed_time": completed_time,
"description": description,
}
- res = self._send_request(url, method="PATCH", payload=payload)
+ res: dict = self._send_request(url, method="PATCH", payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def delete_task(self, task_guid: str) -> dict:
@@ -238,9 +243,10 @@ def delete_task(self, task_guid: str) -> dict:
payload = {
"task_guid": task_guid,
}
- res = self._send_request(url, method="DELETE", payload=payload)
+ res: dict = self._send_request(url, method="DELETE", payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
@@ -250,9 +256,10 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s
"member_phone_or_email": member_phone_or_email,
"member_role": member_role,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
@@ -263,9 +270,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str,
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
@@ -273,9 +281,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
params = {
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_event(
@@ -298,9 +307,10 @@ def create_event(
"auto_record": auto_record,
"attendee_ability": attendee_ability,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def update_event(
@@ -314,7 +324,7 @@ def update_event(
auto_record: bool,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
- payload = {}
+ payload: dict[str, Any] = {}
if summary:
payload["summary"] = summary
if description:
@@ -327,7 +337,7 @@ def update_event(
payload["need_notification"] = need_notification
if auto_record:
payload["auto_record"] = auto_record
- res = self._send_request(url, method="PATCH", payload=payload)
+ res: dict = self._send_request(url, method="PATCH", payload=payload)
return res
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
@@ -335,7 +345,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
params = {
"need_notification": need_notification,
}
- res = self._send_request(url, method="DELETE", params=params)
+ res: dict = self._send_request(url, method="DELETE", params=params)
return res
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
@@ -346,9 +356,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def search_events(
@@ -369,9 +380,10 @@ def search_events(
"user_id_type": user_id_type,
"page_size": page_size,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
@@ -381,9 +393,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_
"attendee_phone_or_email": attendee_phone_or_email,
"need_notification": need_notification,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_spreadsheet(
@@ -396,9 +409,10 @@ def create_spreadsheet(
"title": title,
"folder_token": folder_token,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_spreadsheet(
@@ -411,9 +425,10 @@ def get_spreadsheet(
"spreadsheet_token": spreadsheet_token,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def list_spreadsheet_sheets(
@@ -424,9 +439,10 @@ def list_spreadsheet_sheets(
params = {
"spreadsheet_token": spreadsheet_token,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_rows(
@@ -445,9 +461,10 @@ def add_rows(
"length": length,
"values": values,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_cols(
@@ -466,9 +483,10 @@ def add_cols(
"length": length,
"values": values,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def read_rows(
@@ -489,9 +507,10 @@ def read_rows(
"num_rows": num_rows,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def read_cols(
@@ -512,9 +531,10 @@ def read_cols(
"num_cols": num_cols,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def read_table(
@@ -535,9 +555,10 @@ def read_table(
"query": query,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_base(
@@ -550,9 +571,10 @@ def create_base(
"name": name,
"folder_token": folder_token,
}
- res = self._send_request(url, payload=payload)
+ res: dict = self._send_request(url, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def add_records(
@@ -573,9 +595,10 @@ def add_records(
payload = {
"records": self.convert_add_records(records),
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def update_records(
@@ -596,9 +619,10 @@ def update_records(
payload = {
"records": self.convert_update_records(records),
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def delete_records(
@@ -624,9 +648,10 @@ def delete_records(
payload = {
"records": record_id_list,
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def search_record(
@@ -678,7 +703,7 @@ def search_record(
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
- payload = {}
+ payload: dict[str, Any] = {}
if view_id:
payload["view_id"] = view_id
@@ -690,9 +715,10 @@ def search_record(
payload["filter"] = filter_dict
if automatic_fields:
payload["automatic_fields"] = automatic_fields
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def get_base_info(
@@ -703,9 +729,10 @@ def get_base_info(
params = {
"app_token": app_token,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def create_table(
@@ -732,9 +759,10 @@ def create_table(
}
if default_view_name:
payload["default_view_name"] = default_view_name
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def delete_tables(
@@ -767,9 +795,10 @@ def delete_tables(
"table_ids": table_id_list,
"table_names": table_name_list,
}
- res = self._send_request(url, params=params, payload=payload)
+ res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def list_tables(
@@ -784,9 +813,10 @@ def list_tables(
"page_token": page_token,
"page_size": page_size,
}
- res = self._send_request(url, method="GET", params=params)
+ res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
def read_records(
@@ -814,7 +844,8 @@ def read_records(
"record_ids": record_id_list,
"user_id_type": user_id_type,
}
- res = self._send_request(url, method="POST", params=params, payload=payload)
+ res: dict = self._send_request(url, method="POST", params=params, payload=payload)
if "data" in res:
- return res.get("data")
+ data: dict = res.get("data", {})
+ return data
return res
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index e30c903a4b1146..3509f1e6e59f77 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -90,12 +90,12 @@ def transform_tool_invoke_messages(
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
- file = message.meta.get("file")
- if isinstance(file, File):
- if file.transfer_method == FileTransferMethod.TOOL_FILE:
- assert file.related_id is not None
- url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
- if file.type == FileType.IMAGE:
+ file_meta = message.meta.get("file")
+ if isinstance(file_meta, File):
+ if file_meta.transfer_method == FileTransferMethod.TOOL_FILE:
+ assert file_meta.related_id is not None
+ url = cls.get_tool_file_url(tool_file_id=file_meta.related_id, extension=file_meta.extension)
+ if file_meta.type == FileType.IMAGE:
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 4e226810d6ac90..3689dcc9e5ebfd 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -5,7 +5,7 @@
"""
import json
-from typing import cast
+from typing import Optional, cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
@@ -51,7 +51,7 @@ def get_max_llm_context_tokens(
if not schema:
raise InvokeModelError("No model schema found")
- max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
+ max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
@@ -133,14 +133,17 @@ def invoke(
db.session.commit()
try:
- response: LLMResult = model_instance.invoke_llm(
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=[],
- stop=[],
- stream=False,
- user=user_id,
- callbacks=[],
+ response: LLMResult = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=[],
+ stop=[],
+ stream=False,
+ user=user_id,
+ callbacks=[],
+ ),
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index ae44b1b99d447a..f1dc1123b9935f 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -6,7 +6,7 @@
from typing import Optional
from requests import get
-from yaml import YAMLError, safe_load
+from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
@@ -64,6 +64,9 @@ def parse_openapi_to_tool_bundle(
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
+ placeholder=I18nObject(
+ en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
+ ),
)
# check if there is a type
@@ -108,6 +111,9 @@ def parse_openapi_to_tool_bundle(
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
+ placeholder=I18nObject(
+ en_US=property.get("description", ""), zh_Hans=property.get("description", "")
+ ),
)
# check if there is a type
@@ -158,9 +164,9 @@ def parse_openapi_to_tool_bundle(
return bundles
@staticmethod
- def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
+ def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
parameter = parameter or {}
- typ = None
+ typ: Optional[str] = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
@@ -175,6 +181,8 @@ def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
+ else:
+ return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
@@ -236,7 +244,8 @@ def parse_swagger_to_openapi(swagger: dict, extra_info: Optional[dict], warning:
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
- warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
+ if warning is not None:
+ warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index 3aae31e93a1304..d42fd99fce5e80 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -9,13 +9,13 @@
import unicodedata
from contextlib import contextmanager
from pathlib import Path
-from typing import Optional
+from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import chardet
-import cloudscraper
-from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from regex import regex
+import cloudscraper # type: ignore
+from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
+from regex import regex # type: ignore
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
@@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
- return ExtractProcessor.load_from_url(url, return_text=True)
+ return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
@@ -125,7 +125,7 @@ def extract_using_readabilipy(html):
os.unlink(article_json_path)
os.unlink(html_path)
- article_json = {
+ article_json: dict[str, Any] = {
"title": None,
"byline": None,
"date": None,
@@ -300,7 +300,7 @@ def strip_control_characters(text):
def normalize_unicode(text):
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
- normal_form = "NFKC"
+ normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
@@ -332,6 +332,7 @@ def add_content_digest(element):
def content_digest(element):
+ digest: Any
if is_text(element):
# Hash
trimmed_string = element.string.strip()
diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py
index d92bfb9b90a9aa..08a112cfdb2b91 100644
--- a/api/core/tools/utils/workflow_configuration_sync.py
+++ b/api/core/tools/utils/workflow_configuration_sync.py
@@ -7,7 +7,7 @@
class WorkflowToolConfigurationUtils:
@classmethod
- def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
+ def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@@ -27,7 +27,7 @@ def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[Vari
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
- ) -> None:
+ ) -> bool:
"""
check is synced
diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py
index 42c7f85bc6daeb..ee7ca11e056625 100644
--- a/api/core/tools/utils/yaml_utils.py
+++ b/api/core/tools/utils/yaml_utils.py
@@ -2,7 +2,7 @@
from pathlib import Path
from typing import Any
-import yaml
+import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index 973e420961bb46..c32815b24d02ed 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -1,4 +1,5 @@
from collections.abc import Sequence
+from typing import cast
from uuid import uuid4
from pydantic import Field
@@ -78,7 +79,7 @@ class SecretVariable(StringVariable):
@property
def log(self) -> str:
- return encrypter.obfuscated_token(self.value)
+ return cast(str, encrypter.obfuscated_token(self.value))
class NoneVariable(NoneSegment, Variable):
diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py
index ed737e7316973c..b9c6b35ad3476a 100644
--- a/api/core/workflow/callbacks/workflow_logging_callback.py
+++ b/api/core/workflow/callbacks/workflow_logging_callback.py
@@ -33,7 +33,7 @@
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
- self.current_node_id = None
+ self.current_node_id: Optional[str] = None
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py
index ca01dcd7d8d4a8..ae5f117bf9b121 100644
--- a/api/core/workflow/entities/node_entities.py
+++ b/api/core/workflow/entities/node_entities.py
@@ -36,9 +36,9 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs
- process_data: Optional[dict[str, Any]] = None # process data
+ process_data: Optional[Mapping[str, Any]] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
- metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
+ metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
index bc3a15bd004ace..b8470aecbd83a2 100644
--- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
+++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
@@ -5,7 +5,7 @@
class ConditionRunConditionHandlerHandler(RunConditionHandler):
- def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
+ def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
"""
Check if the condition can be executed
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 800dd136afb57f..b3bcc3b2ccc309 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -1,4 +1,5 @@
import uuid
+from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, cast
@@ -310,26 +311,17 @@ def _recursively_add_parallels(
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
- parallel_branch_node_ids = {}
- condition_edge_mappings = {}
+ parallel_branch_node_ids = defaultdict(list)
+ condition_edge_mappings = defaultdict(list)
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
- if "default" not in parallel_branch_node_ids:
- parallel_branch_node_ids["default"] = []
-
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
- if condition_hash not in condition_edge_mappings:
- condition_edge_mappings[condition_hash] = []
-
condition_edge_mappings[condition_hash].append(graph_edge)
for condition_hash, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
- if condition_hash not in parallel_branch_node_ids:
- parallel_branch_node_ids[condition_hash] = []
-
for graph_edge in graph_edges:
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
@@ -418,7 +410,7 @@ def _recursively_add_parallels(
if condition_edge_mappings:
for condition_hash, graph_edges in condition_edge_mappings.items():
for graph_edge in graph_edges:
- current_parallel: GraphParallel | None = cls._get_current_parallel(
+ current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=condition_parallels.get(condition_hash),
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 854036b2c13212..db1e01f14fda59 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -40,6 +40,7 @@
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
+from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
@@ -66,7 +67,7 @@ def __init__(
self.max_submit_count = max_submit_count
self.submit_count = 0
- def submit(self, fn, *args, **kwargs):
+ def submit(self, fn, /, *args, **kwargs):
self.submit_count += 1
self.check_is_full()
@@ -140,7 +141,8 @@ def __init__(
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
- handle_exceptions = []
+ handle_exceptions: list[str] = []
+ stream_processor: StreamProcessor
try:
if self.init_params.workflow_type == WorkflowType.CHAT:
@@ -168,7 +170,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]:
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
self.graph_runtime_state.outputs = (
- item.route_node_state.node_run_result.outputs
+ dict(item.route_node_state.node_run_result.outputs)
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else {}
@@ -350,7 +352,7 @@ def _run(
if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results
- condition_edge_mappings = {}
+ condition_edge_mappings: dict[str, list[GraphEdge]] = {}
for edge in edge_mappings:
if edge.run_condition:
run_condition_hash = edge.run_condition.hash
@@ -364,6 +366,9 @@ def _run(
continue
edge = cast(GraphEdge, sub_edge_mappings[0])
+ if edge.run_condition is None:
+ logger.warning(f"Edge {edge.target_node_id} run condition is None")
+ continue
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@@ -387,11 +392,11 @@ def _run(
handle_exceptions=handle_exceptions,
)
- for item in parallel_generator:
- if isinstance(item, str):
- final_node_id = item
+ for parallel_result in parallel_generator:
+ if isinstance(parallel_result, str):
+ final_node_id = parallel_result
else:
- yield item
+ yield parallel_result
break
@@ -413,11 +418,11 @@ def _run(
handle_exceptions=handle_exceptions,
)
- for item in parallel_generator:
- if isinstance(item, str):
- final_node_id = item
+ for generated_item in parallel_generator:
+ if isinstance(generated_item, str):
+ final_node_id = generated_item
else:
- yield item
+ yield generated_item
if not final_node_id:
break
@@ -653,7 +658,7 @@ def _run_node(
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
- error=run_result.error,
+ error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
)
@@ -732,20 +737,20 @@ def _run_node(
variable_value=variable_value,
)
- # add parallel info to run result metadata
- if parallel_id and parallel_start_node_id:
- if not run_result.metadata:
- run_result.metadata = {}
+ # When setting metadata, convert to dict first
+ if not run_result.metadata:
+ run_result.metadata = {}
- run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
- run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
- parallel_start_node_id
- )
+ if parallel_id and parallel_start_node_id:
+ metadata_dict = dict(run_result.metadata)
+ metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
+ metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
- run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
- run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
+ metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
+ metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
+ run_result.metadata = metadata_dict
yield NodeRunSucceededEvent(
id=node_instance.id,
@@ -869,8 +874,8 @@ def _handle_continue_on_error(
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
- handle_exceptions.append(error_result.error)
- node_error_args = {
+ handle_exceptions.append(error_result.error or "")
+ node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": error_result.error,
"inputs": error_result.inputs,
diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py
index ed033e7f283961..40213bd151f7af 100644
--- a/api/core/workflow/nodes/answer/answer_stream_processor.py
+++ b/api/core/workflow/nodes/answer/answer_stream_processor.py
@@ -63,7 +63,7 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat
self._remove_unreachable_nodes(event)
# generate stream outputs
- yield from self._generate_stream_outputs_when_node_finished(event)
+ yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
else:
yield event
@@ -130,7 +130,7 @@ def _generate_stream_outputs_when_node_finished(
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
- from_variable_selector=value_selector,
+ from_variable_selector=list(value_selector),
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py
index d785397e130565..8ffb487ec108f8 100644
--- a/api/core/workflow/nodes/answer/base_stream_processor.py
+++ b/api/core/workflow/nodes/answer/base_stream_processor.py
@@ -3,7 +3,7 @@
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
+from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
logger = logging.getLogger(__name__)
@@ -19,7 +19,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
- def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
+ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
@@ -32,8 +32,8 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
return
if run_result.edge_source_handle:
- reachable_node_ids = []
- unreachable_first_node_ids = []
+ reachable_node_ids: list[str] = []
+ unreachable_first_node_ids: list[str] = []
if finished_node_id not in self.graph.edge_mapping:
logger.warning(f"node {finished_node_id} has no edge mapping")
return
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index 529fd7be74e9a1..6bf8899f5d698b 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -38,7 +38,8 @@ def _parse_json(value: str) -> Any:
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
"""Unified array type validation"""
- return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
+ # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
+ return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
@staticmethod
def _convert_number(value: str) -> float:
@@ -84,7 +85,7 @@ def validate_value_type(self) -> "DefaultValue":
},
}
- validator = type_validators.get(self.type)
+ validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 4e371ca43645a5..2f82bf8c382b55 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -125,7 +125,7 @@ def _transform_result(
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
- transformed_result = {}
+ transformed_result: dict[str, Any] = {}
if output_schema is None:
# validate output thought instance type
for output_name, output_value in result.items():
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
index e78183baf12389..a4540358883210 100644
--- a/api/core/workflow/nodes/code/entities.py
+++ b/api/core/workflow/nodes/code/entities.py
@@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
- children: Optional[dict[str, "Output"]] = None
+ children: Optional[dict[str, "CodeNodeData.Output"]] = None
class Dependency(BaseModel):
name: str
diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 6d82dbe6d70da3..0b1dc611c59da2 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -4,6 +4,7 @@
import logging
import os
import tempfile
+from typing import cast
import docx
import pandas as pd
@@ -159,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
"""Extract the content from yaml file"""
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
- return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
+ return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
except (UnicodeDecodeError, yaml.YAMLError) as e:
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
@@ -229,9 +230,9 @@ def _download_file_content(file: File) -> bytes:
raise FileDownloadError("Missing URL for remote file")
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
- return response.content
+ return cast(bytes, response.content)
else:
- return file_manager.download(file)
+ return cast(bytes, file_manager.download(file))
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py
index 0db1ba9f09d36e..b3678a82b73959 100644
--- a/api/core/workflow/nodes/end/end_stream_generate_router.py
+++ b/api/core/workflow/nodes/end/end_stream_generate_router.py
@@ -67,7 +67,7 @@ def extract_stream_variable_selector_from_node_data(
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == "text"
):
- value_selectors.append(variable_selector.value_selector)
+ value_selectors.append(list(variable_selector.value_selector))
return value_selectors
@@ -119,8 +119,7 @@ def _recursive_fetch_end_dependencies(
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
- reverse_edge_mapping: dict[str, list["GraphEdge"]],
- # type: ignore[name-defined]
+ reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
"""
diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py
index 1aecf863ac5fb9..a770eb951f6c8c 100644
--- a/api/core/workflow/nodes/end/end_stream_processor.py
+++ b/api/core/workflow/nodes/end/end_stream_processor.py
@@ -23,7 +23,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.route_position[end_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
self.has_output = False
- self.output_node_ids = set()
+ self.output_node_ids: set[str] = set()
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py
index 137b47655102af..9fea3fbda3141f 100644
--- a/api/core/workflow/nodes/event/event.py
+++ b/api/core/workflow/nodes/event/event.py
@@ -42,6 +42,6 @@ class RunRetryEvent(BaseModel):
class SingleStepRetryEvent(NodeRunResult):
"""Single step retry event"""
- status: str = WorkflowNodeExecutionStatus.RETRY.value
+ status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY
elapsed_time: float = Field(..., description="elapsed time")
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index 575db15d365efb..cdfdc6e6d51b77 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -107,9 +107,9 @@ def _init_params(self):
if not (key := key.strip()):
continue
- value = value[0].strip() if value else ""
+ value_str = value[0].strip() if value else ""
result.append(
- (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text)
+ (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
)
self.params = result
@@ -182,9 +182,10 @@ def _init_body(self):
self.variable_pool.convert_template(item.key).text: item.file
for item in filter(lambda item: item.type == "file", data)
}
+ files: dict[str, Any] = {}
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
files = {k: v for k, v in files.items() if v is not None}
- files = {k: variable.value for k, variable in files.items()}
+ files = {k: variable.value for k, variable in files.items() if variable is not None}
files = {
k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream")
for k, v in files.items()
@@ -258,7 +259,8 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
response = getattr(ssrf_proxy, self.method)(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
- return response
+ # FIXME: fix type ignore, this maybe httpx type issue
+ return response # type: ignore
def invoke(self) -> Response:
# assemble headers
@@ -300,37 +302,37 @@ def to_log(self):
continue
raw += f"{k}: {v}\r\n"
- body = ""
+ body_string = ""
if self.files:
for k, v in self.files.items():
- body += f"--{boundary}\r\n"
- body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
- body += f"{v[1]}\r\n"
- body += f"--{boundary}--\r\n"
+ body_string += f"--{boundary}\r\n"
+ body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
+ body_string += f"{v[1]}\r\n"
+ body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
if isinstance(self.content, str):
- body = self.content
+ body_string = self.content
elif isinstance(self.content, bytes):
- body = self.content.decode("utf-8", errors="replace")
+ body_string = self.content.decode("utf-8", errors="replace")
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
- body = urlencode(self.data)
+ body_string = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":
for key, value in self.data.items():
- body += f"--{boundary}\r\n"
- body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
- body += f"{value}\r\n"
- body += f"--{boundary}--\r\n"
+ body_string += f"--{boundary}\r\n"
+ body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
+ body_string += f"{value}\r\n"
+ body_string += f"--{boundary}--\r\n"
elif self.json:
- body = json.dumps(self.json)
+ body_string = json.dumps(self.json)
elif self.node_data.body.type == "raw-text":
if len(self.node_data.body.data) != 1:
raise RequestBodyError("raw-text body type should have exactly one item")
- body = self.node_data.body.data[0].value
- if body:
- raw += f"Content-Length: {len(body)}\r\n"
+ body_string = self.node_data.body.data[0].value
+ if body_string:
+ raw += f"Content-Length: {len(body_string)}\r\n"
raw += "\r\n" # Empty line between headers and body
- raw += body
+ raw += body_string
return raw
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index ebed690f6f3ffb..861119f26cb088 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -1,7 +1,7 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import Any, Optional
from configs import dify_config
from core.file import File, FileTransferMethod
@@ -36,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_type = NodeType.HTTP_REQUEST
@classmethod
- def get_default_config(cls, filters: dict | None = None) -> dict:
+ def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {
"type": "http-request",
"config": {
@@ -160,8 +160,8 @@ def _extract_variable_selector_to_variable_mapping(
)
mapping = {}
- for selector in selectors:
- mapping[node_id + "." + selector.variable] = selector.value_selector
+ for selector_iter in selectors:
+ mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector
return mapping
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index 6a89cbfad61684..f1289558fffa82 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -361,13 +361,16 @@ def _handle_event_metadata(
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
-
if NodeRunMetadataKey.ITERATION_ID not in metadata:
- metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
- if self.node_data.is_parallel:
- metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
- else:
- metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
+ metadata = {
+ **metadata,
+ NodeRunMetadataKey.ITERATION_ID: self.node_id,
+ NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
+ if self.node_data.is_parallel
+ else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
+ if self.node_data.is_parallel
+ else iter_run_index,
+ }
event.route_node_state.node_run_result.metadata = metadata
return event
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 4f9e415f4b83a3..bfd93c074dd6d5 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -147,6 +147,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query:
planning_strategy=planning_strategy,
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
+ if node_data.multiple_retrieval_config is None:
+ raise ValueError("multiple_retrieval_config is required")
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
@@ -157,6 +159,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query:
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
+ if node_data.multiple_retrieval_config.weights is None:
+ raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
@@ -180,7 +184,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query:
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,
- score_threshold=node_data.multiple_retrieval_config.score_threshold,
+ score_threshold=node_data.multiple_retrieval_config.score_threshold
+ if node_data.multiple_retrieval_config.score_threshold is not None
+ else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
@@ -205,7 +211,7 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query:
"content": item.page_content,
}
retrieval_resource_list.append(source)
- document_score_list = {}
+ document_score_list: dict[str, float] = {}
# deal with dify documents
if dify_documents:
document_score_list = {}
@@ -260,7 +266,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query:
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
- retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True
+ retrieval_resource_list,
+ key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
+ reverse=True,
)
position = 1
for item in retrieval_resource_list:
@@ -295,6 +303,8 @@ def _fetch_model_config(
:param node_data: node data
:return:
"""
+ if node_data.single_retrieval_config is None:
+ raise ValueError("single_retrieval_config is required")
model_name = node_data.single_retrieval_config.model.name
provider_name = node_data.single_retrieval_config.model.provider
diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py
index 79066cece4f93c..432c57294ecbe9 100644
--- a/api/core/workflow/nodes/list_operator/node.py
+++ b/api/core/workflow/nodes/list_operator/node.py
@@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
-from typing import Literal, Union
+from typing import Any, Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_type = NodeType.LIST_OPERATOR
def _run(self):
- inputs = {}
- process_data = {}
- outputs = {}
+ inputs: dict[str, list] = {}
+ process_data: dict[str, list] = {}
+ outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
if variable is None:
@@ -93,6 +93,8 @@ def _run(self):
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
+ filter_func: Callable[[Any], bool]
+ result: list[Any] = []
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
@@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
+ extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
raise InvalidKeyError(f"Invalid key: {key}")
-def _contains(value: str):
+def _contains(value: str) -> Callable[[str], bool]:
return lambda x: value in x
-def _startswith(value: str):
+def _startswith(value: str) -> Callable[[str], bool]:
return lambda x: x.startswith(value)
-def _endswith(value: str):
+def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
-def _is(value: str):
+def _is(value: str) -> Callable[[str], bool]:
return lambda x: x is value
-def _in(value: str | Sequence[str]):
+def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
return lambda x: x in value
-def _eq(value: int | float):
+def _eq(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x == value
-def _ne(value: int | float):
+def _ne(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x != value
-def _lt(value: int | float):
+def _lt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x < value
-def _le(value: int | float):
+def _le(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x <= value
-def _gt(value: int | float):
+def _gt(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x > value
-def _ge(value: int | float):
+def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
@@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
+ extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 55fac45576c821..6909b30c9e82ca 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
- def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
- node_inputs = None
+ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+ node_inputs: Optional[dict[str, Any]] = None
process_data = None
try:
@@ -196,7 +196,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
error_type=type(e).__name__,
)
)
- return
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -206,7 +205,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
process_data=process_data,
)
)
- return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@@ -302,7 +300,7 @@ def _transform_chat_messages(
return messages
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
- variables = {}
+ variables: dict[str, Any] = {}
if not node_data.prompt_config:
return variables
@@ -319,7 +317,7 @@ def parse_dict(input_dict: Mapping[str, Any]) -> str:
"""
# check if it's a context structure
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
- return input_dict["content"]
+ return str(input_dict["content"])
# else, parse the dict
try:
@@ -557,7 +555,8 @@ def _fetch_prompt_messages(
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
- prompt_messages = []
+ # FIXME: fix the type error caused by prompt_messages being typed quite a few times
+ prompt_messages: list[Any] = []
if isinstance(prompt_template, list):
# For chat model
@@ -783,7 +782,7 @@ def _extract_variable_selector_to_variable_mapping(
else:
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
- variable_mapping = {}
+ variable_mapping: dict[str, Any] = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
@@ -981,7 +980,7 @@ def _handle_memory_chat_mode(
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
- memory_messages = []
+ memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index 6fdff966026b63..a366c287c2ac56 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
- def _run(self) -> LoopState:
- return super()._run()
+ def _run(self) -> LoopState: # type: ignore
+ return super()._run() # type: ignore
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
@@ -28,7 +28,7 @@ def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
# TODO waiting for implementation
return [
- Condition(
+ Condition( # type: ignore
variable_selector=[node_id, "index"],
comparison_operator="≤",
value_type="value_selector",
diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py
index a001b44dc7dfee..369eb13b04e8c4 100644
--- a/api/core/workflow/nodes/parameter_extractor/entities.py
+++ b/api/core/workflow/nodes/parameter_extractor/entities.py
@@ -25,7 +25,7 @@ def validate_name(cls, value) -> str:
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
- return value
+ return str(value)
class ParameterExtractorNodeData(BaseNodeData):
@@ -52,7 +52,7 @@ def get_parameter_json_schema(self) -> dict:
:return: parameter json schema
"""
- parameters = {"type": "object", "properties": {}, "required": []}
+ parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index c8c854a43b3269..9c88047f2c8e57 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode):
Parameter Extractor Node.
"""
- _node_data_cls = ParameterExtractorNodeData
+ # FIXME: figure out why this differs from the superclass
+ _node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR
_model_instance: Optional[ModelInstance] = None
@@ -253,6 +254,9 @@ def _invoke(
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+ if text is None:
+ text = ""
+
return text, usage, tool_call
def _generate_function_call_prompt(
@@ -605,9 +609,10 @@ def extract_json(text):
json_str = extract_json(result[idx:])
if json_str:
try:
- return json.loads(json_str)
+ return cast(dict, json.loads(json_str))
except Exception:
pass
+ return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
"""
@@ -616,13 +621,13 @@ def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCal
if not tool_call or not tool_call.function.arguments:
return None
- return json.loads(tool_call.function.arguments)
+ return cast(dict, json.loads(tool_call.function.arguments))
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
"""
Generate default result.
"""
- result = {}
+ result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0
@@ -772,7 +777,7 @@ def _extract_variable_selector_to_variable_mapping(
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: ParameterExtractorNodeData,
+ node_data: ParameterExtractorNodeData, # type: ignore
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -781,6 +786,7 @@ def _extract_variable_selector_to_variable_mapping(
:param node_data: node data
:return:
"""
+ # FIXME: fix the type error later
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction:
diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py
index e603add1704544..6c3155ac9a54e3 100644
--- a/api/core/workflow/nodes/parameter_extractor/prompts.py
+++ b/api/core/workflow/nodes/parameter_extractor/prompts.py
@@ -1,3 +1,5 @@
+from typing import Any
+
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
@@ -35,7 +37,7 @@
""" # noqa: E501
-FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [
+FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
{
"user": {
"query": "What is the weather today in SF?",
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 31f8368d590ea9..0ec44eefacf52f 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -34,12 +34,9 @@
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
-if TYPE_CHECKING:
- from core.file import File
-
class QuestionClassifierNode(LLMNode):
- _node_data_cls = QuestionClassifierNodeData
+ _node_data_cls = QuestionClassifierNodeData # type: ignore
_node_type = NodeType.QUESTION_CLASSIFIER
def _run(self):
@@ -61,7 +58,7 @@ def _run(self):
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
- files: Sequence[File] = (
+ files = (
self._fetch_files(
selector=node_data.vision.configs.variable_selector,
)
@@ -168,7 +165,7 @@ def _extract_variable_selector_to_variable_mapping(
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: QuestionClassifierNodeData,
+ node_data: Any,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -177,6 +174,7 @@ def _extract_variable_selector_to_variable_mapping(
:param node_data: node data
:return:
"""
+ node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 983fa7e623177a..01d07e494944b4 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -9,7 +9,6 @@
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
-from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@@ -46,6 +45,8 @@ def _run(self) -> NodeRunResult:
# get tool runtime
try:
+ from core.tools.tool_manager import ToolManager
+
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
@@ -142,7 +143,7 @@ def _generate_parameters(
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
- result = {}
+ result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
parameter = tool_parameters_dictionary.get(parameter_name)
if not parameter:
@@ -264,9 +265,9 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) ->
"""
return "\n".join(
[
- f"{message.message}"
+ str(message.message)
if message.type == ToolInvokeMessage.MessageType.TEXT
- else f"Link: {message.message}"
+ else f"Link: {str(message.message)}"
for message in tool_response
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}
]
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py
index 8eb4bd5c2da573..9acc76f326eec9 100644
--- a/api/core/workflow/nodes/variable_assigner/v1/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v1/node.py
@@ -36,6 +36,8 @@ def _run(self) -> NodeRunResult:
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
+ if income_value is None:
+ raise VariableOperatorNodeError("income value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index d73c7442029225..0c4aae827c0a0f 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -1,5 +1,5 @@
import json
-from typing import Any
+from typing import Any, cast
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
@@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
- process_data = {}
+ process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variables: list[Variable] = []
@@ -119,7 +119,7 @@ def _run(self) -> NodeRunResult:
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
- conversation_id=conversation_id,
+ conversation_id=cast(str, conversation_id),
variable=variable,
)
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index 811e40c11e5407..b14c6fafbd9fdc 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -129,11 +129,11 @@ def single_step_run(
:return:
"""
# fetch node info from workflow graph
- graph = workflow.graph_dict
- if not graph:
+ workflow_graph = workflow.graph_dict
+ if not workflow_graph:
raise ValueError("workflow graph not found")
- nodes = graph.get("nodes")
+ nodes = workflow_graph.get("nodes")
if not nodes:
raise ValueError("nodes not found in workflow graph")
@@ -196,7 +196,8 @@ def single_step_run(
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
- return WorkflowEntry._handle_special_values(value)
+ result = WorkflowEntry._handle_special_values(value)
+ return result if isinstance(result, Mapping) or result is None else dict(result)
@staticmethod
def _handle_special_values(value: Any) -> Any:
@@ -208,10 +209,10 @@ def _handle_special_values(value: Any) -> Any:
res[k] = WorkflowEntry._handle_special_values(v)
return res
if isinstance(value, list):
- res = []
+ res_list = []
for item in value:
- res.append(WorkflowEntry._handle_special_values(item))
- return res
+ res_list.append(WorkflowEntry._handle_special_values(item))
+ return res_list
if isinstance(value, File):
return value.to_dict()
return value
diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py
index 24fa013697994c..8a677f6b6fc017 100644
--- a/api/events/event_handlers/create_document_index.py
+++ b/api/events/event_handlers/create_document_index.py
@@ -14,7 +14,7 @@
@document_index_created.connect
def handle(sender, **kwargs):
dataset_id = sender
- document_ids = kwargs.get("document_ids")
+ document_ids = kwargs.get("document_ids", [])
documents = []
start_at = time.perf_counter()
for document_id in document_ids:
diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py
index 1515661b2d45b8..5e7caf8cbed71e 100644
--- a/api/events/event_handlers/create_site_record_when_app_created.py
+++ b/api/events/event_handlers/create_site_record_when_app_created.py
@@ -8,18 +8,19 @@ def handle(sender, **kwargs):
"""Create site record when an app is created."""
app = sender
account = kwargs.get("account")
- site = Site(
- app_id=app.id,
- title=app.name,
- icon_type=app.icon_type,
- icon=app.icon,
- icon_background=app.icon_background,
- default_language=account.interface_language,
- customize_token_strategy="not_allow",
- code=Site.generate_code(16),
- created_by=app.created_by,
- updated_by=app.updated_by,
- )
+ if account is not None:
+ site = Site(
+ app_id=app.id,
+ title=app.name,
+ icon_type=app.icon_type,
+ icon=app.icon,
+ icon_background=app.icon_background,
+ default_language=account.interface_language,
+ customize_token_strategy="not_allow",
+ code=Site.generate_code(16),
+ created_by=app.created_by,
+ updated_by=app.updated_by,
+ )
- db.session.add(site)
- db.session.commit()
+ db.session.add(site)
+ db.session.commit()
diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py
index 843a2320968ced..1ed37efba0b3be 100644
--- a/api/events/event_handlers/deduct_quota_when_message_created.py
+++ b/api/events/event_handlers/deduct_quota_when_message_created.py
@@ -44,7 +44,7 @@ def handle(sender, **kwargs):
else:
used_quota = 1
- if used_quota is not None:
+ if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == model_config.provider,
diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
index 9c5955c8c5a1a5..f89fae24a56378 100644
--- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
+++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
@@ -8,7 +8,10 @@
@app_draft_workflow_was_synced.connect
def handle(sender, **kwargs):
app = sender
- for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []):
+ synced_draft_workflow = kwargs.get("synced_draft_workflow")
+ if synced_draft_workflow is None:
+ return
+ for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try:
tool_entity = ToolEntity(**node_data["data"])
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
index de7c0f4dfeb74f..408ed31096d2a0 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
@@ -8,16 +8,18 @@
def handle(sender, **kwargs):
app = sender
app_model_config = kwargs.get("app_model_config")
+ if app_model_config is None:
+ return
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
- removed_dataset_ids = []
+ removed_dataset_ids: set[int] = set()
if not app_dataset_joins:
added_dataset_ids = dataset_ids
else:
- old_dataset_ids = set()
+ old_dataset_ids: set[int] = set()
old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
added_dataset_ids = dataset_ids - old_dataset_ids
@@ -37,8 +39,8 @@ def handle(sender, **kwargs):
db.session.commit()
-def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
- dataset_ids = set()
+def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[int]:
+ dataset_ids: set[int] = set()
if not app_model_config:
return dataset_ids
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
index 453395e8d7dc1c..7a31c82f6adbc2 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
@@ -17,11 +17,11 @@ def handle(sender, **kwargs):
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
- removed_dataset_ids = []
+ removed_dataset_ids: set[int] = set()
if not app_dataset_joins:
added_dataset_ids = dataset_ids
else:
- old_dataset_ids = set()
+ old_dataset_ids: set[int] = set()
old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
added_dataset_ids = dataset_ids - old_dataset_ids
@@ -41,8 +41,8 @@ def handle(sender, **kwargs):
db.session.commit()
-def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
- dataset_ids = set()
+def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[int]:
+ dataset_ids: set[int] = set()
graph = published_workflow.graph_dict
if not graph:
return dataset_ids
@@ -60,7 +60,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
- dataset_ids.update(node_data.dataset_ids)
+ dataset_ids.update(int(dataset_id) for dataset_id in node_data.dataset_ids)
except Exception as e:
continue
diff --git a/api/extensions/__init__.py b/api/extensions/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py
index de1cdfeb984e86..b7d412d68deda1 100644
--- a/api/extensions/ext_app_metrics.py
+++ b/api/extensions/ext_app_metrics.py
@@ -54,12 +54,14 @@ def pool_stat():
from extensions.ext_database import db
engine = db.engine
+ # TODO: Fix the type error
+ # FIXME: maybe it's a SQLAlchemy issue
return {
"pid": os.getpid(),
- "pool_size": engine.pool.size(),
- "checked_in_connections": engine.pool.checkedin(),
- "checked_out_connections": engine.pool.checkedout(),
- "overflow_connections": engine.pool.overflow(),
- "connection_timeout": engine.pool.timeout(),
- "recycle_time": db.engine.pool._recycle,
+ "pool_size": engine.pool.size(), # type: ignore
+ "checked_in_connections": engine.pool.checkedin(), # type: ignore
+ "checked_out_connections": engine.pool.checkedout(), # type: ignore
+ "overflow_connections": engine.pool.overflow(), # type: ignore
+ "connection_timeout": engine.pool.timeout(), # type: ignore
+ "recycle_time": db.engine.pool._recycle, # type: ignore
}
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 9dbc4b93d46266..30f216ff95612b 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -1,8 +1,8 @@
from datetime import timedelta
import pytz
-from celery import Celery, Task
-from celery.schedules import crontab
+from celery import Celery, Task # type: ignore
+from celery.schedules import crontab # type: ignore
from configs import dify_config
from dify_app import DifyApp
@@ -47,7 +47,7 @@ def __call__(self, *args: object, **kwargs: object) -> object:
worker_log_format=dify_config.LOG_FORMAT,
worker_task_log_format=dify_config.LOG_FORMAT,
worker_hijack_root_logger=False,
- timezone=pytz.timezone(dify_config.LOG_TZ),
+ timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
)
if dify_config.BROKER_USE_SSL:
diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py
index 9c3a663af417ae..26ff6427bef1cc 100644
--- a/api/extensions/ext_compress.py
+++ b/api/extensions/ext_compress.py
@@ -7,7 +7,7 @@ def is_enabled() -> bool:
def init_app(app: DifyApp):
- from flask_compress import Compress
+ from flask_compress import Compress # type: ignore
compress = Compress()
compress.init_app(app)
diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py
index 9fc29b4eb17212..e1c459e8c17fd0 100644
--- a/api/extensions/ext_logging.py
+++ b/api/extensions/ext_logging.py
@@ -11,7 +11,7 @@
def init_app(app: DifyApp):
- log_handlers = []
+ log_handlers: list[logging.Handler] = []
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@@ -49,7 +49,8 @@ def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
for handler in logging.root.handlers:
- handler.formatter.converter = time_converter
+ if handler.formatter:
+ handler.formatter.converter = time_converter
def get_request_id():
diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py
index b2955307144d67..10fb89eb7370ee 100644
--- a/api/extensions/ext_login.py
+++ b/api/extensions/ext_login.py
@@ -1,6 +1,6 @@
import json
-import flask_login
+import flask_login # type: ignore
from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import Unauthorized
diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py
index 468aedd47ea90b..9240ebe7fcba73 100644
--- a/api/extensions/ext_mail.py
+++ b/api/extensions/ext_mail.py
@@ -26,7 +26,7 @@ def init_app(self, app: Flask):
match mail_type:
case "resend":
- import resend
+ import resend # type: ignore
api_key = dify_config.RESEND_API_KEY
if not api_key:
@@ -48,9 +48,9 @@ def init_app(self, app: Flask):
self._client = SMTPClient(
server=dify_config.SMTP_SERVER,
port=dify_config.SMTP_PORT,
- username=dify_config.SMTP_USERNAME,
- password=dify_config.SMTP_PASSWORD,
- _from=dify_config.MAIL_DEFAULT_SEND_FROM,
+ username=dify_config.SMTP_USERNAME or "",
+ password=dify_config.SMTP_PASSWORD or "",
+ _from=dify_config.MAIL_DEFAULT_SEND_FROM or "",
use_tls=dify_config.SMTP_USE_TLS,
opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
)
diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py
index 6d8f35c30d9c65..5f862181fa8540 100644
--- a/api/extensions/ext_migrate.py
+++ b/api/extensions/ext_migrate.py
@@ -2,7 +2,7 @@
def init_app(app: DifyApp):
- import flask_migrate
+ import flask_migrate # type: ignore
from extensions.ext_database import db
diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py
index 3b895ac95b5029..514e0658257293 100644
--- a/api/extensions/ext_proxy_fix.py
+++ b/api/extensions/ext_proxy_fix.py
@@ -6,4 +6,4 @@ def init_app(app: DifyApp):
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
from werkzeug.middleware.proxy_fix import ProxyFix
- app.wsgi_app = ProxyFix(app.wsgi_app)
+ app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore
diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py
index 3ec8ae6e1dc14e..3a74aace6a34cf 100644
--- a/api/extensions/ext_sentry.py
+++ b/api/extensions/ext_sentry.py
@@ -6,7 +6,7 @@ def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import openai
import sentry_sdk
- from langfuse import parse_error
+ from langfuse import parse_error # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException
diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py
index 42422263c4dd03..588bdb2d2717e0 100644
--- a/api/extensions/ext_storage.py
+++ b/api/extensions/ext_storage.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator
-from typing import Union
+from typing import Literal, Union, overload
from flask import Flask
@@ -79,6 +79,12 @@ def save(self, filename, data):
logger.exception(f"Failed to save file {filename}")
raise e
+ @overload
+ def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
+
+ @overload
+ def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
+
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
try:
if stream:
diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py
index 58c917dbd386bc..00bf5d4f93ae3b 100644
--- a/api/extensions/storage/aliyun_oss_storage.py
+++ b/api/extensions/storage/aliyun_oss_storage.py
@@ -1,7 +1,7 @@
import posixpath
from collections.abc import Generator
-import oss2 as aliyun_s3
+import oss2 as aliyun_s3 # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -33,7 +33,7 @@ def save(self, filename, data):
def load_once(self, filename: str) -> bytes:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
- data = obj.read()
+ data: bytes = obj.read()
return data
def load_stream(self, filename: str) -> Generator:
@@ -41,14 +41,14 @@ def load_stream(self, filename: str) -> Generator:
while chunk := obj.read(4096):
yield chunk
- def download(self, filename, target_filepath):
+ def download(self, filename: str, target_filepath):
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
- def exists(self, filename):
+ def exists(self, filename: str):
return self.client.object_exists(self.__wrapper_folder_filename(filename))
- def delete(self, filename):
+ def delete(self, filename: str):
self.client.delete_object(self.__wrapper_folder_filename(filename))
- def __wrapper_folder_filename(self, filename) -> str:
+ def __wrapper_folder_filename(self, filename: str) -> str:
return posixpath.join(self.folder, filename) if self.folder else filename
diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py
index ce36c2e7deeeda..7b6b2eedd62bf2 100644
--- a/api/extensions/storage/aws_s3_storage.py
+++ b/api/extensions/storage/aws_s3_storage.py
@@ -1,9 +1,9 @@
import logging
from collections.abc import Generator
-import boto3
-from botocore.client import Config
-from botocore.exceptions import ClientError
+import boto3 # type: ignore
+from botocore.client import Config # type: ignore
+from botocore.exceptions import ClientError # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -53,7 +53,7 @@ def save(self, filename, data):
def load_once(self, filename: str) -> bytes:
try:
- data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
+ data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")
diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py
index b26caa8671b6df..2f8532f4f8f653 100644
--- a/api/extensions/storage/azure_blob_storage.py
+++ b/api/extensions/storage/azure_blob_storage.py
@@ -27,7 +27,7 @@ def load_once(self, filename: str) -> bytes:
client = self._sync_client()
blob = client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
- data = blob.download_blob().readall()
+ data: bytes = blob.download_blob().readall()
return data
def load_stream(self, filename: str) -> Generator:
@@ -63,11 +63,11 @@ def _sync_client(self):
sas_token = cache_result.decode("utf-8")
else:
sas_token = generate_account_sas(
- account_name=self.account_name,
- account_key=self.account_key,
+ account_name=self.account_name or "",
+ account_key=self.account_key or "",
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
)
redis_client.set(cache_key, sas_token, ex=3000)
- return BlobServiceClient(account_url=self.account_url, credential=sas_token)
+ return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py
index e0d2140e91272c..b94efa08be7613 100644
--- a/api/extensions/storage/baidu_obs_storage.py
+++ b/api/extensions/storage/baidu_obs_storage.py
@@ -2,9 +2,9 @@
import hashlib
from collections.abc import Generator
-from baidubce.auth.bce_credentials import BceCredentials
-from baidubce.bce_client_configuration import BceClientConfiguration
-from baidubce.services.bos.bos_client import BosClient
+from baidubce.auth.bce_credentials import BceCredentials # type: ignore
+from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
+from baidubce.services.bos.bos_client import BosClient # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -36,7 +36,8 @@ def save(self, filename, data):
def load_once(self, filename: str) -> bytes:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
- return response.data.read()
+ data: bytes = response.data.read()
+ return data
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py
index 26b662d2f04daf..705639f42e716f 100644
--- a/api/extensions/storage/google_cloud_storage.py
+++ b/api/extensions/storage/google_cloud_storage.py
@@ -3,7 +3,7 @@
import json
from collections.abc import Generator
-from google.cloud import storage as google_cloud_storage
+from google.cloud import storage as google_cloud_storage # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -35,7 +35,7 @@ def save(self, filename, data):
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
- data = blob.download_as_bytes()
+ data: bytes = blob.download_as_bytes()
return data
def load_stream(self, filename: str) -> Generator:
diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py
index 20be70ef83dd7a..07f1d199701be4 100644
--- a/api/extensions/storage/huawei_obs_storage.py
+++ b/api/extensions/storage/huawei_obs_storage.py
@@ -1,6 +1,6 @@
from collections.abc import Generator
-from obs import ObsClient
+from obs import ObsClient # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -23,7 +23,7 @@ def save(self, filename, data):
self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
def load_once(self, filename: str) -> bytes:
- data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
+ data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
return data
def load_stream(self, filename: str) -> Generator:
diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py
index e671eff059ba21..b78fc94dae7843 100644
--- a/api/extensions/storage/opendal_storage.py
+++ b/api/extensions/storage/opendal_storage.py
@@ -3,7 +3,7 @@
from collections.abc import Generator
from pathlib import Path
-import opendal
+import opendal # type: ignore[import]
from dotenv import dotenv_values
from extensions.storage.base_storage import BaseStorage
@@ -18,7 +18,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str
if key.startswith(config_prefix):
kwargs[key[len(config_prefix) :].lower()] = value
- file_env_vars = dotenv_values(env_file_path)
+ file_env_vars: dict = dotenv_values(env_file_path) or {}
for key, value in file_env_vars.items():
if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
kwargs[key[len(config_prefix) :].lower()] = value
@@ -48,7 +48,7 @@ def load_once(self, filename: str) -> bytes:
if not self.exists(filename):
raise FileNotFoundError("File not found")
- content = self.op.read(path=filename)
+ content: bytes = self.op.read(path=filename)
logger.debug(f"file {filename} loaded")
return content
@@ -75,7 +75,7 @@ def exists(self, filename: str) -> bool:
# error handler here when opendal python-binding has a exists method, we should use it
# more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs
try:
- res = self.op.stat(path=filename).mode.is_file()
+ res: bool = self.op.stat(path=filename).mode.is_file()
logger.debug(f"file {filename} checked")
return res
except Exception:
diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py
index b59f83b8de90bf..82829f7fd50d65 100644
--- a/api/extensions/storage/oracle_oci_storage.py
+++ b/api/extensions/storage/oracle_oci_storage.py
@@ -1,7 +1,7 @@
from collections.abc import Generator
-import boto3
-from botocore.exceptions import ClientError
+import boto3 # type: ignore
+from botocore.exceptions import ClientError # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -27,7 +27,7 @@ def save(self, filename, data):
def load_once(self, filename: str) -> bytes:
try:
- data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
+ data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")
diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py
index 9f7c69a9ae6312..711c3f72117c86 100644
--- a/api/extensions/storage/supabase_storage.py
+++ b/api/extensions/storage/supabase_storage.py
@@ -32,7 +32,7 @@ def save(self, filename, data):
self.client.storage.from_(self.bucket_name).upload(filename, data)
def load_once(self, filename: str) -> bytes:
- content = self.client.storage.from_(self.bucket_name).download(filename)
+ content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
return content
def load_stream(self, filename: str) -> Generator:
diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py
index 13a6c9239c2d1e..9cdd3e67f75aab 100644
--- a/api/extensions/storage/tencent_cos_storage.py
+++ b/api/extensions/storage/tencent_cos_storage.py
@@ -1,6 +1,6 @@
from collections.abc import Generator
-from qcloud_cos import CosConfig, CosS3Client
+from qcloud_cos import CosConfig, CosS3Client # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -25,7 +25,7 @@ def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
def load_once(self, filename: str) -> bytes:
- data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
+ data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
return data
def load_stream(self, filename: str) -> Generator:
diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py
index de82be04ea87b7..55fe6545ec3d2d 100644
--- a/api/extensions/storage/volcengine_tos_storage.py
+++ b/api/extensions/storage/volcengine_tos_storage.py
@@ -1,6 +1,6 @@
from collections.abc import Generator
-import tos
+import tos # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
@@ -24,6 +24,8 @@ def save(self, filename, data):
def load_once(self, filename: str) -> bytes:
data = self.client.get_object(bucket=self.bucket_name, key=filename).read()
+ if not isinstance(data, bytes):
+ raise TypeError("Expected bytes, got {}".format(type(data).__name__))
return data
def load_stream(self, filename: str) -> Generator:
diff --git a/api/factories/__init__.py b/api/factories/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 13034f5cf5688b..856cf62e3ed243 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -64,7 +64,7 @@ def build_from_mapping(
if not build_func:
raise ValueError(f"Invalid file transfer method: {transfer_method}")
- file = build_func(
+ file: File = build_func(
mapping=mapping,
tenant_id=tenant_id,
transfer_method=transfer_method,
@@ -72,7 +72,7 @@ def build_from_mapping(
if config and not _is_file_valid_with_config(
input_file_type=mapping.get("type", FileType.CUSTOM),
- file_extension=file.extension,
+ file_extension=file.extension or "",
file_transfer_method=file.transfer_method,
config=config,
):
@@ -281,6 +281,7 @@ def _get_file_type_by_extension(extension: str) -> FileType | None:
return FileType.AUDIO
elif extension in DOCUMENT_EXTENSIONS:
return FileType.DOCUMENT
+ return None
def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index 16a578728aa16e..bbca8448ec0662 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import Any, cast
from uuid import uuid4
from configs import dify_config
@@ -84,6 +84,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError("missing value type")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
+ # FIXME: using Any here, fix it later
+ result: Any
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@@ -109,7 +111,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
- return result
+ return cast(Variable, result)
def build_segment(value: Any, /) -> Segment:
@@ -164,10 +166,13 @@ def segment_to_variable(
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
- return variable_class(
- id=id,
- name=name,
- description=description,
- value=segment.value,
- selector=selector,
+ return cast(
+ Variable,
+ variable_class(
+ id=id,
+ name=name,
+ description=description,
+ value=segment.value,
+ selector=selector,
+ ),
)
diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py
index 379dcc6d16fe56..1c58b3a2579087 100644
--- a/api/fields/annotation_fields.py
+++ b/api/fields/annotation_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py
index a85d4a34dbe7b1..d40407bfcc6193 100644
--- a/api/fields/api_based_extension_fields.py
+++ b/api/fields/api_based_extension_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py
index abb27fdad17d63..73800eab853cd3 100644
--- a/api/fields/app_fields.py
+++ b/api/fields/app_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index 6a9e347b1e04b4..c54554a6de8405 100644
--- a/api/fields/conversation_fields.py
+++ b/api/fields/conversation_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index 983e50e73ceb9f..c6385efb5a3cf1 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py
index 071071376fe6c8..608672121e2b50 100644
--- a/api/fields/data_source_fields.py
+++ b/api/fields/data_source_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py
index 533e3a0837b815..a74e6f54fb3858 100644
--- a/api/fields/dataset_fields.py
+++ b/api/fields/dataset_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py
index a83ec7bc97adee..2b2ac6243f4da5 100644
--- a/api/fields/document_fields.py
+++ b/api/fields/document_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField
diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py
index 99e529f9d1c076..aefa0b27580ca7 100644
--- a/api/fields/end_user_fields.py
+++ b/api/fields/end_user_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
simple_end_user_fields = {
"id": fields.String,
diff --git a/api/fields/external_dataset_fields.py b/api/fields/external_dataset_fields.py
index 2281460fe22146..9cc4e14a0575d7 100644
--- a/api/fields/external_dataset_fields.py
+++ b/api/fields/external_dataset_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py
index afaacc0568ea0c..f896c15f0fec70 100644
--- a/api/fields/file_fields.py
+++ b/api/fields/file_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py
index f36e80f8d493d5..aaafcab8ab6ba0 100644
--- a/api/fields/hit_testing_fields.py
+++ b/api/fields/hit_testing_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py
index e0b3e340f67b8c..16f265b9bb6d07 100644
--- a/api/fields/installed_app_fields.py
+++ b/api/fields/installed_app_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import AppIconUrlField, TimestampField
diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py
index 1cf8e408d13d32..0c854c640c3f98 100644
--- a/api/fields/member_fields.py
+++ b/api/fields/member_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index 5f6e7884a69c5e..0571faab08c134 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
diff --git a/api/fields/raws.py b/api/fields/raws.py
index 15ec16ab13e4a8..493d4b6cce7d31 100644
--- a/api/fields/raws.py
+++ b/api/fields/raws.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from core.file import File
diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py
index 2dd4cb45be409b..4413af31607897 100644
--- a/api/fields/segment_fields.py
+++ b/api/fields/segment_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from libs.helper import TimestampField
diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py
index 9af4fc57dd061c..986cd725f70910 100644
--- a/api/fields/tag_fields.py
+++ b/api/fields/tag_fields.py
@@ -1,3 +1,3 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String}
diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py
index a53b54624915c2..c45b33597b3978 100644
--- a/api/fields/workflow_app_log_fields.py
+++ b/api/fields/workflow_app_log_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index 0d860d6f406502..bd093d4063bc2e 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 74fdf8bd97b23a..ef59c57ec37957 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
diff --git a/api/libs/external_api.py b/api/libs/external_api.py
index 179617ac0a6588..922d2d9cd33324 100644
--- a/api/libs/external_api.py
+++ b/api/libs/external_api.py
@@ -1,8 +1,9 @@
import re
import sys
+from typing import Any
from flask import current_app, got_request_exception
-from flask_restful import Api, http_status_message
+from flask_restful import Api, http_status_message # type: ignore
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException
@@ -84,7 +85,7 @@ def handle_error(self, e):
# record the exception in the logs when we have a server error of status code: 500
if status_code and status_code >= 500:
- exc_info = sys.exc_info()
+ exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = None
current_app.log_exception(exc_info)
@@ -100,7 +101,7 @@ def handle_error(self, e):
resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype)
elif status_code == 400:
if isinstance(data.get("message"), dict):
- param_key, param_value = list(data.get("message").items())[0]
+ param_key, param_value = list(data.get("message", {}).items())[0]
data = {"code": "invalid_param", "message": param_value, "params": param_key}
else:
if "code" not in data:
diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py
index 83f9c74e339e17..2dae87e1710bf6 100644
--- a/api/libs/gmpy2_pkcs10aep_cipher.py
+++ b/api/libs/gmpy2_pkcs10aep_cipher.py
@@ -23,7 +23,7 @@
import Crypto.Hash.SHA1
import Crypto.Util.number
-import gmpy2
+import gmpy2 # type: ignore
from Crypto import Random
from Crypto.Signature.pss import MGF1
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
@@ -191,12 +191,12 @@ def decrypt(self, ciphertext):
# Step 3g
one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen]
- invalid = bord(y) | int(one_pos < hLen)
+ invalid = bord(y) | int(one_pos < hLen) # type: ignore
hash_compare = strxor(lHash1, lHash)
for x in hash_compare:
- invalid |= bord(x)
+ invalid |= bord(x) # type: ignore
for x in db[hLen:one_pos]:
- invalid |= bord(x)
+ invalid |= bord(x) # type: ignore
if invalid != 0:
raise ValueError("Incorrect decryption.")
# Step 4
diff --git a/api/libs/helper.py b/api/libs/helper.py
index 91b1d1fe173d6f..eaa4efdb714355 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -13,7 +13,7 @@
from zoneinfo import available_timezones
from flask import Response, stream_with_context
-from flask_restful import fields
+from flask_restful import fields # type: ignore
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
@@ -248,13 +248,13 @@ def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]
if token_data_json is None:
logging.warning(f"{token_type} token {token} not found with key {key}")
return None
- token_data = json.loads(token_data_json)
+ token_data: Optional[dict[str, Any]] = json.loads(token_data_json)
return token_data
@classmethod
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
key = cls._get_account_token_key(account_id, token_type)
- current_token = redis_client.get(key)
+ current_token: Optional[str] = redis_client.get(key)
return current_token
@classmethod
diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py
index 267af611f5e8cb..9ab53b6294db93 100644
--- a/api/libs/json_in_md_parser.py
+++ b/api/libs/json_in_md_parser.py
@@ -10,6 +10,7 @@ def parse_json_markdown(json_string: str) -> dict:
ends = ["```", "``", "`", "}"]
end_index = -1
start_index = 0
+ parsed: dict = {}
for s in starts:
start_index = json_string.find(s)
if start_index != -1:
diff --git a/api/libs/login.py b/api/libs/login.py
index 0ea191a185785d..5395534a6df93a 100644
--- a/api/libs/login.py
+++ b/api/libs/login.py
@@ -1,8 +1,9 @@
from functools import wraps
+from typing import Any
from flask import current_app, g, has_request_context, request
-from flask_login import user_logged_in
-from flask_login.config import EXEMPT_METHODS
+from flask_login import user_logged_in # type: ignore
+from flask_login.config import EXEMPT_METHODS # type: ignore
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy
@@ -12,7 +13,7 @@
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
-current_user = LocalProxy(lambda: _get_user())
+current_user: Any = LocalProxy(lambda: _get_user())
def login_required(func):
@@ -79,12 +80,12 @@ def decorated_view(*args, **kwargs):
# Login admin
if account:
account.current_tenant = tenant
- current_app.login_manager._update_request_context_with_user(account)
- user_logged_in.send(current_app._get_current_object(), user=_get_user())
+ current_app.login_manager._update_request_context_with_user(account) # type: ignore
+ user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass
elif not current_user.is_authenticated:
- return current_app.login_manager.unauthorized()
+ return current_app.login_manager.unauthorized() # type: ignore
# flask 1.x compatibility
# current_app.ensure_sync is only available in Flask >= 2.0
@@ -98,7 +99,7 @@ def decorated_view(*args, **kwargs):
def _get_user():
if has_request_context():
if "_login_user" not in g:
- current_app.login_manager._load_user()
+ current_app.login_manager._load_user() # type: ignore
return g._login_user
diff --git a/api/libs/oauth.py b/api/libs/oauth.py
index 6b6919de24f90f..df75b550195298 100644
--- a/api/libs/oauth.py
+++ b/api/libs/oauth.py
@@ -77,9 +77,9 @@ def get_raw_user_info(self, token: str):
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
- primary_email = next((email for email in email_info if email["primary"] == True), None)
+ primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
- return {**user_info, "email": primary_email["email"]}
+ return {**user_info, "email": primary_email.get("email", "")}
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
email = raw_info.get("email")
@@ -130,4 +130,4 @@ def get_raw_user_info(self, token: str):
return response.json()
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
- return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"])
+ return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py
index 1d39abd8fa7886..0c872a0066d127 100644
--- a/api/libs/oauth_data_source.py
+++ b/api/libs/oauth_data_source.py
@@ -1,8 +1,9 @@
import datetime
import urllib.parse
+from typing import Any
import requests
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from extensions.ext_database import db
from models.source import DataSourceOauthBinding
@@ -226,7 +227,7 @@ def notion_page_search(self, access_token: str):
has_more = True
while has_more:
- data = {
+ data: dict[str, Any] = {
"filter": {"value": "page", "property": "object"},
**({"start_cursor": next_cursor} if next_cursor else {}),
}
@@ -281,7 +282,7 @@ def notion_database_search(self, access_token: str):
has_more = True
while has_more:
- data = {
+ data: dict[str, Any] = {
"filter": {"value": "database", "property": "object"},
**({"start_cursor": next_cursor} if next_cursor else {}),
}
diff --git a/api/libs/threadings_utils.py b/api/libs/threadings_utils.py
index d356def418ab1d..e4d63fd3142ce2 100644
--- a/api/libs/threadings_utils.py
+++ b/api/libs/threadings_utils.py
@@ -9,8 +9,8 @@ def apply_gevent_threading_patch():
:return:
"""
if not dify_config.DEBUG:
- from gevent import monkey
- from grpc.experimental import gevent as grpc_gevent
+ from gevent import monkey # type: ignore
+ from grpc.experimental import gevent as grpc_gevent # type: ignore
# gevent
monkey.patch_all()
diff --git a/api/models/account.py b/api/models/account.py
index a8602d10a97308..88c96da1a149d5 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -1,7 +1,7 @@
import enum
import json
-from flask_login import UserMixin
+from flask_login import UserMixin # type: ignore
from sqlalchemy import func
from .engine import db
@@ -16,7 +16,7 @@ class AccountStatus(enum.StrEnum):
CLOSED = "closed"
-class Account(UserMixin, db.Model):
+class Account(UserMixin, db.Model): # type: ignore[name-defined]
__tablename__ = "accounts"
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
@@ -43,7 +43,8 @@ def is_password_set(self):
@property
def current_tenant(self):
- return self._current_tenant
+ # FIXME: fix the type error later; the type is important here and ignoring it may cause bugs
+ return self._current_tenant # type: ignore
@current_tenant.setter
def current_tenant(self, value: "Tenant"):
@@ -52,7 +53,8 @@ def current_tenant(self, value: "Tenant"):
if ta:
tenant.current_role = ta.role
else:
- tenant = None
+ # FIXME: fix the type error later; the type is important here and ignoring it may cause bugs
+ tenant = None # type: ignore
self._current_tenant = tenant
@property
@@ -89,7 +91,7 @@ def get_status(self) -> AccountStatus:
return AccountStatus(status_str)
@classmethod
- def get_by_openid(cls, provider: str, open_id: str) -> db.Model:
+ def get_by_openid(cls, provider: str, open_id: str):
account_integrate = (
db.session.query(AccountIntegrate)
.filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
@@ -134,7 +136,7 @@ class TenantAccountRole(enum.StrEnum):
@staticmethod
def is_valid_role(role: str) -> bool:
- return role and role in {
+ return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
@@ -144,15 +146,15 @@ def is_valid_role(role: str) -> bool:
@staticmethod
def is_privileged_role(role: str) -> bool:
- return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
+ return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
@staticmethod
def is_admin_role(role: str) -> bool:
- return role and role == TenantAccountRole.ADMIN
+ return role == TenantAccountRole.ADMIN
@staticmethod
def is_non_owner_role(role: str) -> bool:
- return role and role in {
+ return role in {
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
@@ -161,11 +163,11 @@ def is_non_owner_role(role: str) -> bool:
@staticmethod
def is_editing_role(role: str) -> bool:
- return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
+ return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
@staticmethod
def is_dataset_edit_role(role: str) -> bool:
- return role and role in {
+ return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
@@ -173,7 +175,7 @@ def is_dataset_edit_role(role: str) -> bool:
}
-class Tenant(db.Model):
+class Tenant(db.Model): # type: ignore[name-defined]
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
@@ -209,7 +211,7 @@ class TenantAccountJoinRole(enum.Enum):
DATASET_OPERATOR = "dataset_operator"
-class TenantAccountJoin(db.Model):
+class TenantAccountJoin(db.Model): # type: ignore[name-defined]
__tablename__ = "tenant_account_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@@ -228,7 +230,7 @@ class TenantAccountJoin(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class AccountIntegrate(db.Model):
+class AccountIntegrate(db.Model): # type: ignore[name-defined]
__tablename__ = "account_integrates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@@ -245,7 +247,7 @@ class AccountIntegrate(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class InvitationCode(db.Model):
+class InvitationCode(db.Model): # type: ignore[name-defined]
__tablename__ = "invitation_codes"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py
index fbffe7a3b2ee9d..6b6d808710afc0 100644
--- a/api/models/api_based_extension.py
+++ b/api/models/api_based_extension.py
@@ -13,7 +13,7 @@ class APIBasedExtensionPoint(enum.Enum):
APP_MODERATION_OUTPUT = "app.moderation.output"
-class APIBasedExtension(db.Model):
+class APIBasedExtension(db.Model): # type: ignore[name-defined]
__tablename__ = "api_based_extensions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 7279e8d5b3394a..b9b41dcf475bb1 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -9,6 +9,7 @@
import re
import time
from json import JSONDecodeError
+from typing import Any, cast
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
@@ -29,7 +30,7 @@ class DatasetPermissionEnum(enum.StrEnum):
PARTIAL_TEAM = "partial_members"
-class Dataset(db.Model):
+class Dataset(db.Model): # type: ignore[name-defined]
__tablename__ = "datasets"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
@@ -200,7 +201,7 @@ def gen_collection_name_by_id(dataset_id: str) -> str:
return f"Vector_index_{normalized_dataset_id}_Node"
-class DatasetProcessRule(db.Model):
+class DatasetProcessRule(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_process_rules"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
@@ -216,7 +217,7 @@ class DatasetProcessRule(db.Model):
MODES = ["automatic", "custom"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
- AUTOMATIC_RULES = {
+ AUTOMATIC_RULES: dict[str, Any] = {
"pre_processing_rules": [
{"id": "remove_extra_spaces", "enabled": True},
{"id": "remove_urls_emails", "enabled": False},
@@ -242,7 +243,7 @@ def rules_dict(self):
return None
-class Document(db.Model):
+class Document(db.Model): # type: ignore[name-defined]
__tablename__ = "documents"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"),
@@ -492,7 +493,7 @@ def from_dict(cls, data: dict):
)
-class DocumentSegment(db.Model):
+class DocumentSegment(db.Model): # type: ignore[name-defined]
__tablename__ = "document_segments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@@ -604,7 +605,7 @@ def get_sign_content(self):
return text
-class AppDatasetJoin(db.Model):
+class AppDatasetJoin(db.Model): # type: ignore[name-defined]
__tablename__ = "app_dataset_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
@@ -621,7 +622,7 @@ def app(self):
return db.session.get(App, self.app_id)
-class DatasetQuery(db.Model):
+class DatasetQuery(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_queries"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
@@ -638,7 +639,7 @@ class DatasetQuery(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
-class DatasetKeywordTable(db.Model):
+class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_keyword_tables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
@@ -683,7 +684,7 @@ def object_hook(self, dct):
return None
-class Embedding(db.Model):
+class Embedding(db.Model): # type: ignore[name-defined]
__tablename__ = "embeddings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
@@ -704,10 +705,10 @@ def set_embedding(self, embedding_data: list[float]):
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
def get_embedding(self) -> list[float]:
- return pickle.loads(self.embedding)
+ return cast(list[float], pickle.loads(self.embedding))
-class DatasetCollectionBinding(db.Model):
+class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_collection_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
@@ -722,7 +723,7 @@ class DatasetCollectionBinding(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class TidbAuthBinding(db.Model):
+class TidbAuthBinding(db.Model): # type: ignore[name-defined]
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@@ -742,7 +743,7 @@ class TidbAuthBinding(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class Whitelist(db.Model):
+class Whitelist(db.Model): # type: ignore[name-defined]
__tablename__ = "whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
@@ -754,7 +755,7 @@ class Whitelist(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class DatasetPermission(db.Model):
+class DatasetPermission(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
@@ -771,7 +772,7 @@ class DatasetPermission(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class ExternalKnowledgeApis(db.Model):
+class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
__tablename__ = "external_knowledge_apis"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
@@ -824,7 +825,7 @@ def dataset_bindings(self):
return dataset_bindings
-class ExternalKnowledgeBindings(db.Model):
+class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
diff --git a/api/models/model.py b/api/models/model.py
index 1417298c79c0a2..2a593f08298199 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -4,11 +4,11 @@
from collections.abc import Mapping
from datetime import datetime
from enum import Enum, StrEnum
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
-from flask_login import UserMixin
+from flask_login import UserMixin # type: ignore
from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column
@@ -28,7 +28,7 @@
from .workflow import Workflow
-class DifySetup(db.Model):
+class DifySetup(db.Model): # type: ignore[name-defined]
__tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
@@ -63,7 +63,7 @@ class IconType(Enum):
EMOJI = "emoji"
-class App(db.Model):
+class App(db.Model): # type: ignore[name-defined]
__tablename__ = "apps"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
@@ -86,7 +86,7 @@ class App(db.Model):
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
tracing = db.Column(db.Text, nullable=True)
- max_active_requests = db.Column(db.Integer, nullable=True)
+ max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True)
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
@@ -154,7 +154,7 @@ def mode_compatible_with_agent(self) -> str:
if self.mode == AppMode.CHAT.value and self.is_agent:
return AppMode.AGENT_CHAT.value
- return self.mode
+ return str(self.mode)
@property
def deleted_tools(self) -> list:
@@ -219,7 +219,7 @@ def tags(self):
return tags or []
-class AppModelConfig(db.Model):
+class AppModelConfig(db.Model): # type: ignore[name-defined]
__tablename__ = "app_model_configs"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
@@ -322,7 +322,7 @@ def external_data_tools_list(self) -> list[dict]:
return json.loads(self.external_data_tools) if self.external_data_tools else []
@property
- def user_input_form_list(self) -> dict:
+ def user_input_form_list(self) -> list[dict]:
return json.loads(self.user_input_form) if self.user_input_form else []
@property
@@ -344,7 +344,7 @@ def completion_prompt_config_dict(self) -> dict:
@property
def dataset_configs_dict(self) -> dict:
if self.dataset_configs:
- dataset_configs = json.loads(self.dataset_configs)
+ dataset_configs: dict = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
@@ -466,7 +466,7 @@ def copy(self):
return new_app_model_config
-class RecommendedApp(db.Model):
+class RecommendedApp(db.Model): # type: ignore[name-defined]
__tablename__ = "recommended_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
@@ -494,7 +494,7 @@ def app(self):
return app
-class InstalledApp(db.Model):
+class InstalledApp(db.Model): # type: ignore[name-defined]
__tablename__ = "installed_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
@@ -523,7 +523,7 @@ def tenant(self):
return tenant
-class Conversation(db.Model):
+class Conversation(db.Model): # type: ignore[name-defined]
__tablename__ = "conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@@ -602,6 +602,8 @@ def inputs(self, value: Mapping[str, Any]):
@property
def model_config(self):
model_config = {}
+ app_model_config: Optional[AppModelConfig] = None
+
if self.mode == AppMode.ADVANCED_CHAT.value:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
@@ -613,6 +615,7 @@ def model_config(self):
if "model" in override_model_configs:
app_model_config = AppModelConfig()
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
+ assert app_model_config is not None, "app model config not found"
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
@@ -755,7 +758,7 @@ def in_debug_mode(self):
return self.override_model_configs is not None
-class Message(db.Model):
+class Message(db.Model): # type: ignore[name-defined]
__tablename__ = "messages"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_pkey"),
@@ -995,7 +998,7 @@ def message_files(self):
if not current_app:
raise ValueError(f"App {self.app_id} not found")
- files: list[File] = []
+ files = []
for message_file in message_files:
if message_file.transfer_method == "local_file":
if message_file.upload_file_id is None:
@@ -1102,7 +1105,7 @@ def from_dict(cls, data: dict):
)
-class MessageFeedback(db.Model):
+class MessageFeedback(db.Model): # type: ignore[name-defined]
__tablename__ = "message_feedbacks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@@ -1129,7 +1132,7 @@ def from_account(self):
return account
-class MessageFile(db.Model):
+class MessageFile(db.Model): # type: ignore[name-defined]
__tablename__ = "message_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@@ -1170,7 +1173,7 @@ def __init__(
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class MessageAnnotation(db.Model):
+class MessageAnnotation(db.Model): # type: ignore[name-defined]
__tablename__ = "message_annotations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@@ -1201,7 +1204,7 @@ def annotation_create_account(self):
return account
-class AppAnnotationHitHistory(db.Model):
+class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1239,7 +1242,7 @@ def annotation_create_account(self):
return account
-class AppAnnotationSetting(db.Model):
+class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
__tablename__ = "app_annotation_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@@ -1287,7 +1290,7 @@ def collection_binding_detail(self):
return collection_binding_detail
-class OperationLog(db.Model):
+class OperationLog(db.Model): # type: ignore[name-defined]
__tablename__ = "operation_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
@@ -1304,7 +1307,7 @@ class OperationLog(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class EndUser(UserMixin, db.Model):
+class EndUser(UserMixin, db.Model): # type: ignore[name-defined]
__tablename__ = "end_users"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
@@ -1324,7 +1327,7 @@ class EndUser(UserMixin, db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class Site(db.Model):
+class Site(db.Model): # type: ignore[name-defined]
__tablename__ = "sites"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="site_pkey"),
@@ -1381,7 +1384,7 @@ def app_base_url(self):
return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
-class ApiToken(db.Model):
+class ApiToken(db.Model): # type: ignore[name-defined]
__tablename__ = "api_tokens"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
@@ -1408,7 +1411,7 @@ def generate_api_key(prefix, n):
return result
-class UploadFile(db.Model):
+class UploadFile(db.Model): # type: ignore[name-defined]
__tablename__ = "upload_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
@@ -1470,7 +1473,7 @@ def __init__(
self.source_url = source_url
-class ApiRequest(db.Model):
+class ApiRequest(db.Model): # type: ignore[name-defined]
__tablename__ = "api_requests"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
@@ -1487,7 +1490,7 @@ class ApiRequest(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class MessageChain(db.Model):
+class MessageChain(db.Model): # type: ignore[name-defined]
__tablename__ = "message_chains"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
@@ -1502,7 +1505,7 @@ class MessageChain(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
-class MessageAgentThought(db.Model):
+class MessageAgentThought(db.Model): # type: ignore[name-defined]
__tablename__ = "message_agent_thoughts"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@@ -1542,7 +1545,7 @@ class MessageAgentThought(db.Model):
@property
def files(self) -> list:
if self.message_files:
- return json.loads(self.message_files)
+ return cast(list[Any], json.loads(self.message_files))
else:
return []
@@ -1554,7 +1557,7 @@ def tools(self) -> list[str]:
def tool_labels(self) -> dict:
try:
if self.tool_labels_str:
- return json.loads(self.tool_labels_str)
+ return cast(dict, json.loads(self.tool_labels_str))
else:
return {}
except Exception as e:
@@ -1564,7 +1567,7 @@ def tool_labels(self) -> dict:
def tool_meta(self) -> dict:
try:
if self.tool_meta_str:
- return json.loads(self.tool_meta_str)
+ return cast(dict, json.loads(self.tool_meta_str))
else:
return {}
except Exception as e:
@@ -1612,9 +1615,11 @@ def tool_outputs_dict(self) -> dict:
except Exception as e:
if self.observation:
return dict.fromkeys(tools, self.observation)
+ else:
+ return {}
-class DatasetRetrieverResource(db.Model):
+class DatasetRetrieverResource(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_retriever_resources"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
@@ -1641,7 +1646,7 @@ class DatasetRetrieverResource(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
-class Tag(db.Model):
+class Tag(db.Model): # type: ignore[name-defined]
__tablename__ = "tags"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_pkey"),
@@ -1659,7 +1664,7 @@ class Tag(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class TagBinding(db.Model):
+class TagBinding(db.Model): # type: ignore[name-defined]
__tablename__ = "tag_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
@@ -1675,7 +1680,7 @@ class TagBinding(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class TraceAppConfig(db.Model):
+class TraceAppConfig(db.Model): # type: ignore[name-defined]
__tablename__ = "trace_app_config"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
diff --git a/api/models/provider.py b/api/models/provider.py
index fdd3e802d79211..abe673975c1ccc 100644
--- a/api/models/provider.py
+++ b/api/models/provider.py
@@ -36,7 +36,7 @@ def value_of(value):
raise ValueError(f"No matching enum found for value '{value}'")
-class Provider(db.Model):
+class Provider(db.Model): # type: ignore[name-defined]
"""
Provider model representing the API providers and their configurations.
"""
@@ -89,7 +89,7 @@ def is_enabled(self):
return self.is_valid and self.token_is_set
-class ProviderModel(db.Model):
+class ProviderModel(db.Model): # type: ignore[name-defined]
"""
Provider model representing the API provider_models and their configurations.
"""
@@ -114,7 +114,7 @@ class ProviderModel(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class TenantDefaultModel(db.Model):
+class TenantDefaultModel(db.Model): # type: ignore[name-defined]
__tablename__ = "tenant_default_models"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
@@ -130,7 +130,7 @@ class TenantDefaultModel(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class TenantPreferredModelProvider(db.Model):
+class TenantPreferredModelProvider(db.Model): # type: ignore[name-defined]
__tablename__ = "tenant_preferred_model_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
@@ -145,7 +145,7 @@ class TenantPreferredModelProvider(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class ProviderOrder(db.Model):
+class ProviderOrder(db.Model): # type: ignore[name-defined]
__tablename__ = "provider_orders"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
@@ -170,7 +170,7 @@ class ProviderOrder(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class ProviderModelSetting(db.Model):
+class ProviderModelSetting(db.Model): # type: ignore[name-defined]
"""
Provider model settings for record the model enabled status and load balancing status.
"""
@@ -192,7 +192,7 @@ class ProviderModelSetting(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class LoadBalancingModelConfig(db.Model):
+class LoadBalancingModelConfig(db.Model): # type: ignore[name-defined]
"""
Configurations for load balancing models.
"""
diff --git a/api/models/source.py b/api/models/source.py
index 114db8e1100e5d..881cfaac7d3998 100644
--- a/api/models/source.py
+++ b/api/models/source.py
@@ -7,7 +7,7 @@
from .types import StringUUID
-class DataSourceOauthBinding(db.Model):
+class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
@@ -25,7 +25,7 @@ class DataSourceOauthBinding(db.Model):
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
-class DataSourceApiKeyAuthBinding(db.Model):
+class DataSourceApiKeyAuthBinding(db.Model): # type: ignore[name-defined]
__tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
diff --git a/api/models/task.py b/api/models/task.py
index 27571e24746fe7..0db1c632299fcb 100644
--- a/api/models/task.py
+++ b/api/models/task.py
@@ -1,11 +1,11 @@
from datetime import UTC, datetime
-from celery import states
+from celery import states # type: ignore
from .engine import db
-class CeleryTask(db.Model):
+class CeleryTask(db.Model): # type: ignore[name-defined]
"""Task result/status."""
__tablename__ = "celery_taskmeta"
@@ -29,7 +29,7 @@ class CeleryTask(db.Model):
queue = db.Column(db.String(155), nullable=True)
-class CeleryTaskSet(db.Model):
+class CeleryTaskSet(db.Model): # type: ignore[name-defined]
"""TaskSet result."""
__tablename__ = "celery_tasksetmeta"
diff --git a/api/models/tools.py b/api/models/tools.py
index e90ab669c66f1e..4151a2e9f636a0 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -14,7 +14,7 @@
from .types import StringUUID
-class BuiltinToolProvider(db.Model):
+class BuiltinToolProvider(db.Model): # type: ignore[name-defined]
"""
This table stores the tool provider information for built-in tools for each tenant.
"""
@@ -41,10 +41,10 @@ class BuiltinToolProvider(db.Model):
@property
def credentials(self) -> dict:
- return json.loads(self.encrypted_credentials)
+ return dict(json.loads(self.encrypted_credentials))
-class PublishedAppTool(db.Model):
+class PublishedAppTool(db.Model): # type: ignore[name-defined]
"""
The table stores the apps published as a tool for each person.
"""
@@ -86,7 +86,7 @@ def app(self):
return db.session.query(App).filter(App.id == self.app_id).first()
-class ApiToolProvider(db.Model):
+class ApiToolProvider(db.Model): # type: ignore[name-defined]
"""
The table stores the api providers.
"""
@@ -133,7 +133,7 @@ def tools(self) -> list[ApiToolBundle]:
@property
def credentials(self) -> dict:
- return json.loads(self.credentials_str)
+ return dict(json.loads(self.credentials_str))
@property
def user(self) -> Account | None:
@@ -144,7 +144,7 @@ def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
-class ToolLabelBinding(db.Model):
+class ToolLabelBinding(db.Model): # type: ignore[name-defined]
"""
The table stores the labels for tools.
"""
@@ -164,7 +164,7 @@ class ToolLabelBinding(db.Model):
label_name = db.Column(db.String(40), nullable=False)
-class WorkflowToolProvider(db.Model):
+class WorkflowToolProvider(db.Model): # type: ignore[name-defined]
"""
The table stores the workflow providers.
"""
@@ -218,7 +218,7 @@ def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first()
-class ToolModelInvoke(db.Model):
+class ToolModelInvoke(db.Model): # type: ignore[name-defined]
"""
store the invoke logs from tool invoke
"""
@@ -255,7 +255,7 @@ class ToolModelInvoke(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class ToolConversationVariables(db.Model):
+class ToolConversationVariables(db.Model): # type: ignore[name-defined]
"""
store the conversation variables from tool invoke
"""
@@ -283,10 +283,10 @@ class ToolConversationVariables(db.Model):
@property
def variables(self) -> dict:
- return json.loads(self.variables_str)
+ return dict(json.loads(self.variables_str))
-class ToolFile(db.Model):
+class ToolFile(db.Model): # type: ignore[name-defined]
__tablename__ = "tool_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
diff --git a/api/models/web.py b/api/models/web.py
index 028a768519d99a..864428fe0931b6 100644
--- a/api/models/web.py
+++ b/api/models/web.py
@@ -6,7 +6,7 @@
from .types import StringUUID
-class SavedMessage(db.Model):
+class SavedMessage(db.Model): # type: ignore[name-defined]
__tablename__ = "saved_messages"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
@@ -25,7 +25,7 @@ def message(self):
return db.session.query(Message).filter(Message.id == self.message_id).first()
-class PinnedConversation(db.Model):
+class PinnedConversation(db.Model): # type: ignore[name-defined]
__tablename__ = "pinned_conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
diff --git a/api/models/workflow.py b/api/models/workflow.py
index d5be949bf44f2a..880e044d073a67 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -2,7 +2,7 @@
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from enum import Enum, StrEnum
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
import sqlalchemy as sa
from sqlalchemy import func
@@ -20,6 +20,9 @@
from .engine import db
from .types import StringUUID
+if TYPE_CHECKING:
+ from models.model import AppMode, Message
+
class WorkflowType(Enum):
"""
@@ -56,7 +59,7 @@ def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
-class Workflow(db.Model):
+class Workflow(db.Model): # type: ignore[name-defined]
"""
Workflow, for `Workflow App` and `Chat App workflow mode`.
@@ -182,7 +185,7 @@ def features(self, value: str) -> None:
self._features = value
@property
- def features_dict(self) -> Mapping[str, Any]:
+ def features_dict(self) -> dict[str, Any]:
return json.loads(self.features) if self.features else {}
def user_input_form(self, to_old_structure: bool = False) -> list:
@@ -199,7 +202,7 @@ def user_input_form(self, to_old_structure: bool = False) -> list:
return []
# get user_input_form from start node
- variables = start_node.get("data", {}).get("variables", [])
+ variables: list[Any] = start_node.get("data", {}).get("variables", [])
if to_old_structure:
old_structure_variables = []
@@ -344,7 +347,7 @@ def value_of(cls, value: str) -> "WorkflowRunStatus":
raise ValueError(f"invalid workflow run status value {value}")
-class WorkflowRun(db.Model):
+class WorkflowRun(db.Model): # type: ignore[name-defined]
"""
Workflow Run
@@ -546,7 +549,7 @@ def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
raise ValueError(f"invalid workflow node execution status value {value}")
-class WorkflowNodeExecution(db.Model):
+class WorkflowNodeExecution(db.Model): # type: ignore[name-defined]
"""
Workflow Node Execution
@@ -712,7 +715,7 @@ def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
raise ValueError(f"invalid workflow app log created from value {value}")
-class WorkflowAppLog(db.Model):
+class WorkflowAppLog(db.Model): # type: ignore[name-defined]
"""
Workflow App execution log, excluding workflow debugging records.
@@ -774,7 +777,7 @@ def created_by_end_user(self):
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
-class ConversationVariable(db.Model):
+class ConversationVariable(db.Model): # type: ignore[name-defined]
__tablename__ = "workflow_conversation_variables"
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
diff --git a/api/mypy.ini b/api/mypy.ini
new file mode 100644
index 00000000000000..2c754f9fcd7c63
--- /dev/null
+++ b/api/mypy.ini
@@ -0,0 +1,10 @@
+[mypy]
+warn_return_any = True
+warn_unused_configs = True
+check_untyped_defs = True
+exclude = (?x)(
+ core/tools/provider/builtin/
+ | core/model_runtime/model_providers/
+ | tests/
+ | migrations/
+ )
\ No newline at end of file
diff --git a/api/poetry.lock b/api/poetry.lock
index 35fda9b36fa42a..b42eb22dd40b8a 100644
--- a/api/poetry.lock
+++ b/api/poetry.lock
@@ -5643,6 +5643,58 @@ files = [
{file = "multitasking-0.0.11.tar.gz", hash = "sha256:4d6bc3cc65f9b2dca72fb5a787850a88dae8f620c2b36ae9b55248e51bcd6026"},
]
+[[package]]
+name = "mypy"
+version = "1.13.0"
+description = "Optional static typing for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"},
+ {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"},
+ {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"},
+ {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"},
+ {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"},
+ {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"},
+ {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"},
+ {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"},
+ {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"},
+ {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"},
+ {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"},
+ {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"},
+ {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"},
+ {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"},
+ {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"},
+ {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"},
+ {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"},
+ {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"},
+ {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"},
+ {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"},
+ {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"},
+ {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"},
+ {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"},
+ {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"},
+ {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"},
+ {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"},
+ {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"},
+ {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"},
+ {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"},
+ {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"},
+ {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"},
+ {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=1.0.0"
+typing-extensions = ">=4.6.0"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+faster-cache = ["orjson"]
+install-types = ["pip"]
+mypyc = ["setuptools (>=50)"]
+reports = ["lxml"]
+
[[package]]
name = "mypy-extensions"
version = "1.0.0"
@@ -6537,6 +6589,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.9.2)"]
+[[package]]
+name = "pandas-stubs"
+version = "2.2.3.241126"
+description = "Type annotations for pandas"
+optional = false
+python-versions = ">=3.10"
+files = [
+ {file = "pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = "sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267"},
+ {file = "pandas_stubs-2.2.3.241126.tar.gz", hash = "sha256:cf819383c6d9ae7d4dabf34cd47e1e45525bb2f312e6ad2939c2c204cb708acd"},
+]
+
+[package.dependencies]
+numpy = ">=1.23.5"
+types-pytz = ">=2022.1.1"
+
[[package]]
name = "pathos"
version = "0.3.3"
@@ -9255,13 +9322,13 @@ sqlcipher = ["sqlcipher3_binary"]
[[package]]
name = "sqlparse"
-version = "0.5.2"
+version = "0.5.3"
description = "A non-validating SQL parser."
optional = false
python-versions = ">=3.8"
files = [
- {file = "sqlparse-0.5.2-py3-none-any.whl", hash = "sha256:e99bc85c78160918c3e1d9230834ab8d80fc06c59d03f8db2618f65f65dda55e"},
- {file = "sqlparse-0.5.2.tar.gz", hash = "sha256:9e37b35e16d1cc652a2545f0997c1deb23ea28fa1f3eefe609eee3063c3b105f"},
+ {file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"},
+ {file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"},
]
[package.extras]
@@ -9847,6 +9914,17 @@ rich = ">=10.11.0"
shellingham = ">=1.3.0"
typing-extensions = ">=3.7.4.3"
+[[package]]
+name = "types-pytz"
+version = "2024.2.0.20241003"
+description = "Typing stubs for pytz"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "types-pytz-2024.2.0.20241003.tar.gz", hash = "sha256:575dc38f385a922a212bac00a7d6d2e16e141132a3c955078f4a4fd13ed6cb44"},
+ {file = "types_pytz-2024.2.0.20241003-py3-none-any.whl", hash = "sha256:3e22df1336c0c6ad1d29163c8fda82736909eb977281cb823c57f8bae07118b7"},
+]
+
[[package]]
name = "types-requests"
version = "2.32.0.20241016"
@@ -10313,82 +10391,82 @@ ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic
[[package]]
name = "watchfiles"
-version = "1.0.0"
+version = "1.0.3"
description = "Simple, modern and high performance file watching and code reload in python."
optional = false
python-versions = ">=3.9"
files = [
- {file = "watchfiles-1.0.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1d19df28f99d6a81730658fbeb3ade8565ff687f95acb59665f11502b441be5f"},
- {file = "watchfiles-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:28babb38cf2da8e170b706c4b84aa7e4528a6fa4f3ee55d7a0866456a1662041"},
- {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ab123135b2f42517f04e720526d41448667ae8249e651385afb5cda31fedc0"},
- {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:13a4f9ee0cd25682679eea5c14fc629e2eaa79aab74d963bc4e21f43b8ea1877"},
- {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e1d9284cc84de7855fcf83472e51d32daf6f6cecd094160192628bc3fee1b78"},
- {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ee5edc939f53466b329bbf2e58333a5461e6c7b50c980fa6117439e2c18b42d"},
- {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dccfc70480087567720e4e36ec381bba1ed68d7e5f368fe40c93b3b1eba0105"},
- {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c83a6d33a9eda0af6a7470240d1af487807adc269704fe76a4972dd982d16236"},
- {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:905f69aad276639eff3893759a07d44ea99560e67a1cf46ff389cd62f88872a2"},
- {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:09551237645d6bff3972592f2aa5424df9290e7a2e15d63c5f47c48cde585935"},
- {file = "watchfiles-1.0.0-cp310-none-win32.whl", hash = "sha256:d2b39aa8edd9e5f56f99a2a2740a251dc58515398e9ed5a4b3e5ff2827060755"},
- {file = "watchfiles-1.0.0-cp310-none-win_amd64.whl", hash = "sha256:2de52b499e1ab037f1a87cb8ebcb04a819bf087b1015a4cf6dcf8af3c2a2613e"},
- {file = "watchfiles-1.0.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fbd0ab7a9943bbddb87cbc2bf2f09317e74c77dc55b1f5657f81d04666c25269"},
- {file = "watchfiles-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:774ef36b16b7198669ce655d4f75b4c3d370e7f1cbdfb997fb10ee98717e2058"},
- {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b4fb98100267e6a5ebaff6aaa5d20aea20240584647470be39fe4823012ac96"},
- {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0fc3bf0effa2d8075b70badfdd7fb839d7aa9cea650d17886982840d71fdeabf"},
- {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:648e2b6db53eca6ef31245805cd528a16f56fa4cc15aeec97795eaf713c11435"},
- {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa13d604fcb9417ae5f2e3de676e66aa97427d888e83662ad205bed35a313176"},
- {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:936f362e7ff28311b16f0b97ec51e8f2cc451763a3264640c6ed40fb252d1ee4"},
- {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:245fab124b9faf58430da547512d91734858df13f2ddd48ecfa5e493455ffccb"},
- {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4ff9c7e84e8b644a8f985c42bcc81457240316f900fc72769aaedec9d088055a"},
- {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9c9a8d8fd97defe935ef8dd53d562e68942ad65067cd1c54d6ed8a088b1d931d"},
- {file = "watchfiles-1.0.0-cp311-none-win32.whl", hash = "sha256:a0abf173975eb9dd17bb14c191ee79999e650997cc644562f91df06060610e62"},
- {file = "watchfiles-1.0.0-cp311-none-win_amd64.whl", hash = "sha256:2a825ba4b32c214e3855b536eb1a1f7b006511d8e64b8215aac06eb680642d84"},
- {file = "watchfiles-1.0.0-cp311-none-win_arm64.whl", hash = "sha256:a5a7a06cfc65e34fd0a765a7623c5ba14707a0870703888e51d3d67107589817"},
- {file = "watchfiles-1.0.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:28fb64b5843d94e2c2483f7b024a1280662a44409bedee8f2f51439767e2d107"},
- {file = "watchfiles-1.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e3750434c83b61abb3163b49c64b04180b85b4dabb29a294513faec57f2ffdb7"},
- {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bedf84835069f51c7b026b3ca04e2e747ea8ed0a77c72006172c72d28c9f69fc"},
- {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:90004553be36427c3d06ec75b804233f8f816374165d5225b93abd94ba6e7234"},
- {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b46e15c34d4e401e976d6949ad3a74d244600d5c4b88c827a3fdf18691a46359"},
- {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:487d15927f1b0bd24e7df921913399bb1ab94424c386bea8b267754d698f8f0e"},
- {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ff236d7a3f4b0a42f699a22fc374ba526bc55048a70cbb299661158e1bb5e1f"},
- {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c01446626574561756067f00b37e6b09c8622b0fc1e9fdbc7cbcea328d4e514"},
- {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b551c465a59596f3d08170bd7e1c532c7260dd90ed8135778038e13c5d48aa81"},
- {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e1ed613ee107269f66c2df631ec0fc8efddacface85314d392a4131abe299f00"},
- {file = "watchfiles-1.0.0-cp312-none-win32.whl", hash = "sha256:5f75cd42e7e2254117cf37ff0e68c5b3f36c14543756b2da621408349bd9ca7c"},
- {file = "watchfiles-1.0.0-cp312-none-win_amd64.whl", hash = "sha256:cf517701a4a872417f4e02a136e929537743461f9ec6cdb8184d9a04f4843545"},
- {file = "watchfiles-1.0.0-cp312-none-win_arm64.whl", hash = "sha256:8a2127cd68950787ee36753e6d401c8ea368f73beaeb8e54df5516a06d1ecd82"},
- {file = "watchfiles-1.0.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:95de85c254f7fe8cbdf104731f7f87f7f73ae229493bebca3722583160e6b152"},
- {file = "watchfiles-1.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:533a7cbfe700e09780bb31c06189e39c65f06c7f447326fee707fd02f9a6e945"},
- {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2218e78e2c6c07b1634a550095ac2a429026b2d5cbcd49a594f893f2bb8c936"},
- {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9122b8fdadc5b341315d255ab51d04893f417df4e6c1743b0aac8bf34e96e025"},
- {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9272fdbc0e9870dac3b505bce1466d386b4d8d6d2bacf405e603108d50446940"},
- {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a3b33c3aefe9067ebd87846806cd5fc0b017ab70d628aaff077ab9abf4d06b3"},
- {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bc338ce9f8846543d428260fa0f9a716626963148edc937d71055d01d81e1525"},
- {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ac778a460ea22d63c7e6fb0bc0f5b16780ff0b128f7f06e57aaec63bd339285"},
- {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:53ae447f06f8f29f5ab40140f19abdab822387a7c426a369eb42184b021e97eb"},
- {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1f73c2147a453315d672c1ad907abe6d40324e34a185b51e15624bc793f93cc6"},
- {file = "watchfiles-1.0.0-cp313-none-win32.whl", hash = "sha256:eba98901a2eab909dbd79681190b9049acc650f6111fde1845484a4450761e98"},
- {file = "watchfiles-1.0.0-cp313-none-win_amd64.whl", hash = "sha256:d562a6114ddafb09c33246c6ace7effa71ca4b6a2324a47f4b09b6445ea78941"},
- {file = "watchfiles-1.0.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3d94fd83ed54266d789f287472269c0def9120a2022674990bd24ad989ebd7a0"},
- {file = "watchfiles-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48051d1c504448b2fcda71c5e6e3610ae45de6a0b8f5a43b961f250be4bdf5a8"},
- {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29cf884ad4285d23453c702ed03d689f9c0e865e3c85d20846d800d4787de00f"},
- {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d3572d4c34c4e9c33d25b3da47d9570d5122f8433b9ac6519dca49c2740d23cd"},
- {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c2696611182c85eb0e755b62b456f48debff484b7306b56f05478b843ca8ece"},
- {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:550109001920a993a4383b57229c717fa73627d2a4e8fcb7ed33c7f1cddb0c85"},
- {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b555a93c15bd2c71081922be746291d776d47521a00703163e5fbe6d2a402399"},
- {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:947ccba18a38b85c366dafeac8df2f6176342d5992ca240a9d62588b214d731f"},
- {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ffd98a299b0a74d1b704ef0ed959efb753e656a4e0425c14e46ae4c3cbdd2919"},
- {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f8c4f3a1210ed099a99e6a710df4ff2f8069411059ffe30fa5f9467ebed1256b"},
- {file = "watchfiles-1.0.0-cp39-none-win32.whl", hash = "sha256:1e176b6b4119b3f369b2b4e003d53a226295ee862c0962e3afd5a1c15680b4e3"},
- {file = "watchfiles-1.0.0-cp39-none-win_amd64.whl", hash = "sha256:2d9c0518fabf4a3f373b0a94bb9e4ea7a1df18dec45e26a4d182aa8918dee855"},
- {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f159ac795785cde4899e0afa539f4c723fb5dd336ce5605bc909d34edd00b79b"},
- {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c3d258d78341d5d54c0c804a5b7faa66cd30ba50b2756a7161db07ce15363b8d"},
- {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bbd0311588c2de7f9ea5cf3922ccacfd0ec0c1922870a2be503cc7df1ca8be7"},
- {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a13ac46b545a7d0d50f7641eefe47d1597e7d1783a5d89e09d080e6dff44b0"},
- {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2bca898c1dc073912d3db7fa6926cc08be9575add9e84872de2c99c688bac4e"},
- {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:06d828fe2adc4ac8a64b875ca908b892a3603d596d43e18f7948f3fef5fc671c"},
- {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:074c7618cd6c807dc4eaa0982b4a9d3f8051cd0b72793511848fd64630174b17"},
- {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95dc785bc284552d044e561b8f4fe26d01ab5ca40d35852a6572d542adfeb4bc"},
- {file = "watchfiles-1.0.0.tar.gz", hash = "sha256:37566c844c9ce3b5deb964fe1a23378e575e74b114618d211fbda8f59d7b5dab"},
+ {file = "watchfiles-1.0.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1da46bb1eefb5a37a8fb6fd52ad5d14822d67c498d99bda8754222396164ae42"},
+ {file = "watchfiles-1.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2b961b86cd3973f5822826017cad7f5a75795168cb645c3a6b30c349094e02e3"},
+ {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34e87c7b3464d02af87f1059fedda5484e43b153ef519e4085fe1a03dd94801e"},
+ {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d9dd2b89a16cf7ab9c1170b5863e68de6bf83db51544875b25a5f05a7269e678"},
+ {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b4691234d31686dca133c920f94e478b548a8e7c750f28dbbc2e4333e0d3da9"},
+ {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:90b0fe1fcea9bd6e3084b44875e179b4adcc4057a3b81402658d0eb58c98edf8"},
+ {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b90651b4cf9e158d01faa0833b073e2e37719264bcee3eac49fc3c74e7d304b"},
+ {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2e9fe695ff151b42ab06501820f40d01310fbd58ba24da8923ace79cf6d702d"},
+ {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62691f1c0894b001c7cde1195c03b7801aaa794a837bd6eef24da87d1542838d"},
+ {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:275c1b0e942d335fccb6014d79267d1b9fa45b5ac0639c297f1e856f2f532552"},
+ {file = "watchfiles-1.0.3-cp310-cp310-win32.whl", hash = "sha256:06ce08549e49ba69ccc36fc5659a3d0ff4e3a07d542b895b8a9013fcab46c2dc"},
+ {file = "watchfiles-1.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:f280b02827adc9d87f764972fbeb701cf5611f80b619c20568e1982a277d6146"},
+ {file = "watchfiles-1.0.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ffe709b1d0bc2e9921257569675674cafb3a5f8af689ab9f3f2b3f88775b960f"},
+ {file = "watchfiles-1.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:418c5ce332f74939ff60691e5293e27c206c8164ce2b8ce0d9abf013003fb7fe"},
+ {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f492d2907263d6d0d52f897a68647195bc093dafed14508a8d6817973586b6b"},
+ {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48c9f3bc90c556a854f4cab6a79c16974099ccfa3e3e150673d82d47a4bc92c9"},
+ {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75d3bcfa90454dba8df12adc86b13b6d85fda97d90e708efc036c2760cc6ba44"},
+ {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5691340f259b8f76b45fb31b98e594d46c36d1dc8285efa7975f7f50230c9093"},
+ {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e263cc718545b7f897baeac1f00299ab6fabe3e18caaacacb0edf6d5f35513c"},
+ {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c6cf7709ed3e55704cc06f6e835bf43c03bc8e3cb8ff946bf69a2e0a78d9d77"},
+ {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:703aa5e50e465be901e0e0f9d5739add15e696d8c26c53bc6fc00eb65d7b9469"},
+ {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bfcae6aecd9e0cb425f5145afee871465b98b75862e038d42fe91fd753ddd780"},
+ {file = "watchfiles-1.0.3-cp311-cp311-win32.whl", hash = "sha256:6a76494d2c5311584f22416c5a87c1e2cb954ff9b5f0988027bc4ef2a8a67181"},
+ {file = "watchfiles-1.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:cf745cbfad6389c0e331786e5fe9ae3f06e9d9c2ce2432378e1267954793975c"},
+ {file = "watchfiles-1.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:2dcc3f60c445f8ce14156854a072ceb36b83807ed803d37fdea2a50e898635d6"},
+ {file = "watchfiles-1.0.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:93436ed550e429da007fbafb723e0769f25bae178fbb287a94cb4ccdf42d3af3"},
+ {file = "watchfiles-1.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c18f3502ad0737813c7dad70e3e1cc966cc147fbaeef47a09463bbffe70b0a00"},
+ {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a5bc3ca468bb58a2ef50441f953e1f77b9a61bd1b8c347c8223403dc9b4ac9a"},
+ {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0d1ec043f02ca04bf21b1b32cab155ce90c651aaf5540db8eb8ad7f7e645cba8"},
+ {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f58d3bfafecf3d81c15d99fc0ecf4319e80ac712c77cf0ce2661c8cf8bf84066"},
+ {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1df924ba82ae9e77340101c28d56cbaff2c991bd6fe8444a545d24075abb0a87"},
+ {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:632a52dcaee44792d0965c17bdfe5dc0edad5b86d6a29e53d6ad4bf92dc0ff49"},
+ {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bf4b459d94a0387617a1b499f314aa04d8a64b7a0747d15d425b8c8b151da0"},
+ {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca94c85911601b097d53caeeec30201736ad69a93f30d15672b967558df02885"},
+ {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65ab1fb635476f6170b07e8e21db0424de94877e4b76b7feabfe11f9a5fc12b5"},
+ {file = "watchfiles-1.0.3-cp312-cp312-win32.whl", hash = "sha256:49bc1bc26abf4f32e132652f4b3bfeec77d8f8f62f57652703ef127e85a3e38d"},
+ {file = "watchfiles-1.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:48681c86f2cb08348631fed788a116c89c787fdf1e6381c5febafd782f6c3b44"},
+ {file = "watchfiles-1.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:9e080cf917b35b20c889225a13f290f2716748362f6071b859b60b8847a6aa43"},
+ {file = "watchfiles-1.0.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e153a690b7255c5ced17895394b4f109d5dcc2a4f35cb809374da50f0e5c456a"},
+ {file = "watchfiles-1.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ac1be85fe43b4bf9a251978ce5c3bb30e1ada9784290441f5423a28633a958a7"},
+ {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec98e31e1844eac860e70d9247db9d75440fc8f5f679c37d01914568d18721"},
+ {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0179252846be03fa97d4d5f8233d1c620ef004855f0717712ae1c558f1974a16"},
+ {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:995c374e86fa82126c03c5b4630c4e312327ecfe27761accb25b5e1d7ab50ec8"},
+ {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29b9cb35b7f290db1c31fb2fdf8fc6d3730cfa4bca4b49761083307f441cac5a"},
+ {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f8dc09ae69af50bead60783180f656ad96bd33ffbf6e7a6fce900f6d53b08f1"},
+ {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:489b80812f52a8d8c7b0d10f0d956db0efed25df2821c7a934f6143f76938bd6"},
+ {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:228e2247de583475d4cebf6b9af5dc9918abb99d1ef5ee737155bb39fb33f3c0"},
+ {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1550be1a5cb3be08a3fb84636eaafa9b7119b70c71b0bed48726fd1d5aa9b868"},
+ {file = "watchfiles-1.0.3-cp313-cp313-win32.whl", hash = "sha256:16db2d7e12f94818cbf16d4c8938e4d8aaecee23826344addfaaa671a1527b07"},
+ {file = "watchfiles-1.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:160eff7d1267d7b025e983ca8460e8cc67b328284967cbe29c05f3c3163711a3"},
+ {file = "watchfiles-1.0.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c05b021f7b5aa333124f2a64d56e4cb9963b6efdf44e8d819152237bbd93ba15"},
+ {file = "watchfiles-1.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:310505ad305e30cb6c5f55945858cdbe0eb297fc57378f29bacceb534ac34199"},
+ {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ddff3f8b9fa24a60527c137c852d0d9a7da2a02cf2151650029fdc97c852c974"},
+ {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46e86ed457c3486080a72bc837300dd200e18d08183f12b6ca63475ab64ed651"},
+ {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f79fe7993e230a12172ce7d7c7db061f046f672f2b946431c81aff8f60b2758b"},
+ {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea2b51c5f38bad812da2ec0cd7eec09d25f521a8b6b6843cbccedd9a1d8a5c15"},
+ {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fe4e740ea94978b2b2ab308cbf9270a246bcbb44401f77cc8740348cbaeac3d"},
+ {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9af037d3df7188ae21dc1c7624501f2f90d81be6550904e07869d8d0e6766655"},
+ {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:52bb50a4c4ca2a689fdba84ba8ecc6a4e6210f03b6af93181bb61c4ec3abaf86"},
+ {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c14a07bdb475eb696f85c715dbd0f037918ccbb5248290448488a0b4ef201aad"},
+ {file = "watchfiles-1.0.3-cp39-cp39-win32.whl", hash = "sha256:be37f9b1f8934cd9e7eccfcb5612af9fb728fecbe16248b082b709a9d1b348bf"},
+ {file = "watchfiles-1.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:ef9ec8068cf23458dbf36a08e0c16f0a2df04b42a8827619646637be1769300a"},
+ {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:84fac88278f42d61c519a6c75fb5296fd56710b05bbdcc74bdf85db409a03780"},
+ {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c68be72b1666d93b266714f2d4092d78dc53bd11cf91ed5a3c16527587a52e29"},
+ {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889a37e2acf43c377b5124166bece139b4c731b61492ab22e64d371cce0e6e80"},
+ {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca05cacf2e5c4a97d02a2878a24020daca21dbb8823b023b978210a75c79098"},
+ {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:8af4b582d5fc1b8465d1d2483e5e7b880cc1a4e99f6ff65c23d64d070867ac58"},
+ {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:127de3883bdb29dbd3b21f63126bb8fa6e773b74eaef46521025a9ce390e1073"},
+ {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:713f67132346bdcb4c12df185c30cf04bdf4bf6ea3acbc3ace0912cab6b7cb8c"},
+ {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abd85de513eb83f5ec153a802348e7a5baa4588b818043848247e3e8986094e8"},
+ {file = "watchfiles-1.0.3.tar.gz", hash = "sha256:f3ff7da165c99a5412fe5dd2304dd2dbaaaa5da718aad942dcb3a178eaa70c56"},
]
[package.dependencies]
@@ -11095,4 +11173,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.11,<3.13"
-content-hash = "14476bf95504a4df4b8d5a5c6608c6aa3dae7499d27d1e41ef39d761cc7c693d"
+content-hash = "f4accd01805cbf080c4c5295f97a06c8e4faec7365d2c43d0435e56b46461732"
diff --git a/api/pyproject.toml b/api/pyproject.toml
index da9eabecf55ccf..28e0305406a18b 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -60,6 +60,7 @@ oci = "~2.135.1"
openai = "~1.52.0"
openpyxl = "~3.1.5"
pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
+pandas-stubs = "~2.2.3.241009"
psycopg2-binary = "~2.9.6"
pycryptodome = "3.19.1"
pydantic = "~2.9.2"
@@ -84,6 +85,7 @@ tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tiktoken = "~0.8.0"
tokenizers = "~0.15.0"
transformers = "~4.35.0"
+types-pytz = "~2024.2.0.20241003"
unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
validators = "0.21.0"
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
@@ -173,6 +175,7 @@ optional = true
[tool.poetry.group.dev.dependencies]
coverage = "~7.2.4"
faker = "~32.1.0"
+mypy = "~1.13.0"
pytest = "~8.3.2"
pytest-benchmark = "~4.0.0"
pytest-env = "~1.1.3"
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index 97e5c77e95361a..48bdc872f41e5c 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -32,8 +32,9 @@ def clean_messages():
while True:
try:
# Main query with join and filter
+ # FIXME:for mypy no paginate method error
messages = (
- db.session.query(Message)
+ db.session.query(Message) # type: ignore
.filter(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py
index e12be649e4d02d..f66b3c47979435 100644
--- a/api/schedule/clean_unused_datasets_task.py
+++ b/api/schedule/clean_unused_datasets_task.py
@@ -52,8 +52,7 @@ def clean_unused_datasets_task():
# Main query with join and filter
datasets = (
- db.session.query(Dataset)
- .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
+ Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
Dataset.created_at < plan_sandbox_clean_day,
@@ -120,8 +119,7 @@ def clean_unused_datasets_task():
# Main query with join and filter
datasets = (
- db.session.query(Dataset)
- .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
+ Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
Dataset.created_at < plan_pro_clean_day,
diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py
index a20b500308a4d6..1c985461c6aa2e 100644
--- a/api/schedule/create_tidb_serverless_task.py
+++ b/api/schedule/create_tidb_serverless_task.py
@@ -36,14 +36,15 @@ def create_tidb_serverless_task():
def create_clusters(batch_size):
try:
+ # TODO: maybe we can set the default value for the following parameters in the config file
new_clusters = TidbService.batch_create_tidb_serverless_cluster(
- batch_size,
- dify_config.TIDB_PROJECT_ID,
- dify_config.TIDB_API_URL,
- dify_config.TIDB_IAM_API_URL,
- dify_config.TIDB_PUBLIC_KEY,
- dify_config.TIDB_PRIVATE_KEY,
- dify_config.TIDB_REGION,
+ batch_size=batch_size,
+ project_id=dify_config.TIDB_PROJECT_ID or "",
+ api_url=dify_config.TIDB_API_URL or "",
+ iam_url=dify_config.TIDB_IAM_API_URL or "",
+ public_key=dify_config.TIDB_PUBLIC_KEY or "",
+ private_key=dify_config.TIDB_PRIVATE_KEY or "",
+ region=dify_config.TIDB_REGION or "",
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py
index b2d8746f9ca8f4..11a39e60ee4ce5 100644
--- a/api/schedule/update_tidb_serverless_status_task.py
+++ b/api/schedule/update_tidb_serverless_status_task.py
@@ -36,13 +36,14 @@ def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
# batch 20
for i in range(0, len(tidb_serverless_list), 20):
items = tidb_serverless_list[i : i + 20]
+ # TODO: maybe we can set the default value for the following parameters in the config file
TidbService.batch_update_tidb_serverless_cluster_status(
- items,
- dify_config.TIDB_PROJECT_ID,
- dify_config.TIDB_API_URL,
- dify_config.TIDB_IAM_API_URL,
- dify_config.TIDB_PUBLIC_KEY,
- dify_config.TIDB_PRIVATE_KEY,
+ tidb_serverless_list=items,
+ project_id=dify_config.TIDB_PROJECT_ID or "",
+ api_url=dify_config.TIDB_API_URL or "",
+ iam_url=dify_config.TIDB_IAM_API_URL or "",
+ public_key=dify_config.TIDB_PUBLIC_KEY or "",
+ private_key=dify_config.TIDB_PRIVATE_KEY or "",
)
except Exception as e:
click.echo(click.style(f"Error: {e}", fg="red"))
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 22b54a3ab87473..91075ec46b16bf 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -6,7 +6,7 @@
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
-from typing import Any, Optional
+from typing import Any, Optional, cast
from pydantic import BaseModel
from sqlalchemy import func
@@ -119,7 +119,7 @@ def load_user(user_id: str) -> None | Account:
account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
- return account
+ return cast(Account, account)
@staticmethod
def get_account_jwt_token(account: Account) -> str:
@@ -132,7 +132,7 @@ def get_account_jwt_token(account: Account) -> str:
"sub": "Console API Passport",
}
- token = PassportService().issue(payload)
+ token: str = PassportService().issue(payload)
return token
@staticmethod
@@ -164,7 +164,7 @@ def authenticate(email: str, password: str, invite_token: Optional[str] = None)
db.session.commit()
- return account
+ return cast(Account, account)
@staticmethod
def update_account_password(account, password, new_password):
@@ -347,6 +347,8 @@ def send_reset_password_email(
language: Optional[str] = "en-US",
):
account_email = account.email if account else email
+ if account_email is None:
+ raise ValueError("Email must be provided.")
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
@@ -377,6 +379,8 @@ def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
+ if email is None:
+ raise ValueError("Email must be provided.")
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
@@ -669,7 +673,7 @@ def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoi
@staticmethod
def get_tenant_count() -> int:
"""Get tenant count"""
- return db.session.query(func.count(Tenant.id)).scalar()
+ return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
@@ -733,10 +737,10 @@ def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
db.session.commit()
@staticmethod
- def get_custom_config(tenant_id: str) -> None:
- tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404()
+ def get_custom_config(tenant_id: str) -> dict:
+ tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()
- return tenant.custom_config_dict
+ return cast(dict, tenant.custom_config_dict)
class RegisterService:
@@ -807,7 +811,7 @@ def register(
account.status = AccountStatus.ACTIVE.value if not status else status.value
account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
- if open_id is not None or provider is not None:
+ if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
if FeatureService.get_system_features().is_allow_create_workspace:
@@ -828,10 +832,11 @@ def register(
@classmethod
def invite_new_member(
- cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None
+ cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None
) -> str:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()
+ assert inviter is not None, "Inviter must be provided."
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
@@ -894,7 +899,9 @@ def revoke_token(cls, workspace_id: str, email: str, token: str):
redis_client.delete(cls._get_invitation_token_key(token))
@classmethod
- def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]:
+ def get_invitation_if_token_valid(
+ cls, workspace_id: Optional[str], email: str, token: str
+ ) -> Optional[dict[str, Any]]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
@@ -953,7 +960,7 @@ def _get_invitation_by_token(
if not data:
return None
- invitation = json.loads(data)
+ invitation: dict = json.loads(data)
return invitation
diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py
index d2cd7bea67c5b6..6dc1affa11d036 100644
--- a/api/services/advanced_prompt_template_service.py
+++ b/api/services/advanced_prompt_template_service.py
@@ -48,6 +48,8 @@ def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) ->
return cls.get_chat_prompt(
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
)
+ # default return empty dict
+ return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
@@ -91,3 +93,5 @@ def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
+ # default return empty dict
+ return {}
diff --git a/api/services/agent_service.py b/api/services/agent_service.py
index c8819535f11a39..b02f762ad267b8 100644
--- a/api/services/agent_service.py
+++ b/api/services/agent_service.py
@@ -1,5 +1,7 @@
+from typing import Optional
+
import pytz
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
from core.tools.tool_manager import ToolManager
@@ -14,7 +16,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -
"""
Service to get agent logs
"""
- conversation: Conversation = (
+ conversation: Optional[Conversation] = (
db.session.query(Conversation)
.filter(
Conversation.id == conversation_id,
@@ -26,7 +28,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
- message: Message = (
+ message: Optional[Message] = (
db.session.query(Message)
.filter(
Message.id == message_id,
@@ -72,7 +74,10 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -
}
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
- agent_tools = agent_config.tools
+ if not agent_config:
+ return result
+
+ agent_tools = agent_config.tools or []
def find_agent_tool(tool_name: str):
for agent_tool in agent_tools:
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index f45c21cb18f5e3..a946405c955cec 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -1,8 +1,9 @@
import datetime
import uuid
+from typing import cast
import pandas as pd
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from sqlalchemy import or_
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -71,7 +72,7 @@ def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> Messa
app_id,
annotation_setting.collection_binding_id,
)
- return annotation
+ return cast(MessageAnnotation, annotation)
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
@@ -124,8 +125,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo
raise NotFound("App not found")
if keyword:
annotations = (
- db.session.query(MessageAnnotation)
- .filter(MessageAnnotation.app_id == app_id)
+ MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
@@ -137,8 +137,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo
)
else:
annotations = (
- db.session.query(MessageAnnotation)
- .filter(MessageAnnotation.app_id == app_id)
+ MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
@@ -327,8 +326,7 @@ def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, lim
raise NotFound("Annotation not found")
annotation_hit_histories = (
- db.session.query(AppAnnotationHitHistory)
- .filter(
+ AppAnnotationHitHistory.query.filter(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 7c1a175988071e..b191fa2397fa9e 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -1,7 +1,7 @@
import logging
import uuid
from enum import StrEnum
-from typing import Optional
+from typing import Optional, cast
from uuid import uuid4
import yaml
@@ -103,7 +103,7 @@ def import_app(
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
- content = ""
+ content: bytes | str = b""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
@@ -136,7 +136,7 @@ def import_app(
)
try:
- content = content.decode("utf-8")
+ content = cast(bytes, content).decode("utf-8")
except UnicodeDecodeError as e:
return Import(
id=import_id,
@@ -362,6 +362,9 @@ def _create_or_update_app(
app.icon_background = icon_background or app_data.get("icon_background", app.icon_background)
app.updated_by = account.id
else:
+ if account.current_tenant_id is None:
+ raise ValueError("Current tenant is not set")
+
# Create new app
app = App()
app.id = str(uuid4())
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 9def7d15e928d4..51aef7ccab9a0c 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -118,7 +118,7 @@ def generate(
@staticmethod
def _get_max_active_requests(app_model: App) -> int:
max_active_requests = app_model.max_active_requests
- if app_model.max_active_requests is None:
+ if max_active_requests is None:
max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
return max_active_requests
@@ -150,7 +150,7 @@ def generate_more_like_this(
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
- ) -> Union[dict, Generator]:
+ ) -> Union[Mapping, Generator]:
"""
Generate more like this
:param app_model: app model
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 8d8ba735ecfa71..41c15bbf0a330b 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -1,9 +1,9 @@
import json
import logging
from datetime import UTC, datetime
-from typing import cast
+from typing import Optional, cast
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
@@ -83,7 +83,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
# get default model instance
try:
model_instance = model_manager.get_default_model_instance(
- tenant_id=account.current_tenant_id, model_type=ModelType.LLM
+ tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
model_instance = None
@@ -100,6 +100,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
else:
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+ if model_schema is None:
+ raise ValueError(f"model schema not found for model {model_instance.model}")
default_model_dict = {
"provider": model_instance.provider,
@@ -109,7 +111,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App:
}
else:
provider, model = model_manager.get_default_provider_model_name(
- tenant_id=account.current_tenant_id, model_type=ModelType.LLM
+ tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM
)
default_model_config["model"]["provider"] = provider
default_model_config["model"]["name"] = model
@@ -314,7 +316,7 @@ def get_app_meta(self, app_model: App) -> dict:
"""
app_mode = AppMode.value_of(app_model.mode)
- meta = {"tool_icons": {}}
+ meta: dict = {"tool_icons": {}}
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
@@ -336,7 +338,7 @@ def get_app_meta(self, app_model: App) -> dict:
}
)
else:
- app_model_config: AppModelConfig = app_model.app_model_config
+ app_model_config: Optional[AppModelConfig] = app_model.app_model_config
if not app_model_config:
return meta
@@ -352,16 +354,18 @@ def get_app_meta(self, app_model: App) -> dict:
keys = list(tool.keys())
if len(keys) >= 4:
# current tool standard
- provider_type = tool.get("provider_type")
- provider_id = tool.get("provider_id")
- tool_name = tool.get("tool_name")
+ provider_type = tool.get("provider_type", "")
+ provider_id = tool.get("provider_id", "")
+ tool_name = tool.get("tool_name", "")
if provider_type == "builtin":
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
- provider: ApiToolProvider = (
+ provider: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
)
+ if provider is None:
+ raise ValueError(f"provider not found for tool {tool_name}")
meta["tool_icons"][tool_name] = json.loads(provider.icon)
except:
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index 7a0cd5725b2a96..973110f5156523 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -110,6 +110,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None):
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
+ if not voice:
+ raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
@@ -121,6 +123,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None):
if message_id:
message = db.session.query(Message).filter(Message.id == message_id).first()
+ if message is None:
+ return None
if message.answer == "" and message.status == "normal":
return None
@@ -130,6 +134,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None):
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
else:
+ if not text:
+ raise ValueError("Text is required")
response = invoke_tts(text, app_model, voice)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type="audio/mpeg")
diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py
index afc491398f25f3..50e4edff140346 100644
--- a/api/services/auth/firecrawl/firecrawl.py
+++ b/api/services/auth/firecrawl/firecrawl.py
@@ -11,8 +11,8 @@ def __init__(self, credentials: dict):
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
- self.api_key = credentials.get("config").get("api_key", None)
- self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev")
+ self.api_key = credentials.get("config", {}).get("api_key", None)
+ self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev")
if not self.api_key:
raise ValueError("No API key provided")
diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py
index de898a1f94b763..6100e9afc8f278 100644
--- a/api/services/auth/jina.py
+++ b/api/services/auth/jina.py
@@ -11,7 +11,7 @@ def __init__(self, credentials: dict):
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
- self.api_key = credentials.get("config").get("api_key", None)
+ self.api_key = credentials.get("config", {}).get("api_key", None)
if not self.api_key:
raise ValueError("No API key provided")
diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py
index de898a1f94b763..6100e9afc8f278 100644
--- a/api/services/auth/jina/jina.py
+++ b/api/services/auth/jina/jina.py
@@ -11,7 +11,7 @@ def __init__(self, credentials: dict):
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
- self.api_key = credentials.get("config").get("api_key", None)
+ self.api_key = credentials.get("config", {}).get("api_key", None)
if not self.api_key:
raise ValueError("No API key provided")
diff --git a/api/services/billing_service.py b/api/services/billing_service.py
index edc51682179cc5..d98018648839a9 100644
--- a/api/services/billing_service.py
+++ b/api/services/billing_service.py
@@ -1,4 +1,5 @@
import os
+from typing import Optional
import httpx
from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed
@@ -58,11 +59,14 @@ def _send_request(cls, method, endpoint, json=None, params=None):
def is_tenant_owner_or_admin(current_user):
tenant_id = current_user.current_tenant_id
- join = (
+ join: Optional[TenantAccountJoin] = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
)
+ if not join:
+ raise ValueError("Tenant account join not found")
+
if not TenantAccountRole.is_privileged_role(join.role):
raise ValueError("Only team owner or team admin can perform this action")
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index 456dc3ebebaa28..6485cbf37d5b7f 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -72,8 +72,7 @@ def pagination_by_last_id(
sort_direction=sort_direction,
reference_conversation=current_page_last_conversation,
)
- count_stmt = stmt.where(rest_filter_condition)
- count_stmt = select(func.count()).select_from(count_stmt.subquery())
+ count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery())
rest_count = session.scalar(count_stmt) or 0
if rest_count > 0:
has_more = True
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 4e99c73ad4787a..d2d8a718d55c8a 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -6,7 +6,7 @@
import uuid
from typing import Any, Optional
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from sqlalchemy import func
from werkzeug.exceptions import NotFound
@@ -186,8 +186,9 @@ def create_empty_dataset(
return dataset
@staticmethod
- def get_dataset(dataset_id) -> Dataset:
- return Dataset.query.filter_by(id=dataset_id).first()
+ def get_dataset(dataset_id) -> Optional[Dataset]:
+ dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
+ return dataset
@staticmethod
def check_dataset_model_setting(dataset):
@@ -228,6 +229,8 @@ def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str,
@staticmethod
def update_dataset(dataset_id, data, user):
dataset = DatasetService.get_dataset(dataset_id)
+ if not dataset:
+ raise ValueError("Dataset not found")
DatasetService.check_dataset_permission(dataset, user)
if dataset.provider == "external":
@@ -371,7 +374,13 @@ def check_dataset_permission(dataset, user):
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
- def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
+ def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
+ if not dataset:
+ raise ValueError("Dataset not found")
+
+ if not user:
+ raise ValueError("User not found")
+
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -765,6 +774,11 @@ def save_document_with_dataset_id(
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
+ else:
+ logging.warn(
+ f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
+ )
+ return
db.session.add(dataset_process_rule)
db.session.commit()
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
@@ -1009,9 +1023,10 @@ def update_document_with_dataset_id(
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
- db.session.add(dataset_process_rule)
- db.session.commit()
- document.dataset_process_rule_id = dataset_process_rule.id
+ if dataset_process_rule is not None:
+ db.session.add(dataset_process_rule)
+ db.session.commit()
+ document.dataset_process_rule_id = dataset_process_rule.id
# update document data source
if document_data.get("data_source"):
file_name = ""
@@ -1554,7 +1569,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document
segment.word_count = len(content)
if document.doc_form == "qa_model":
segment.answer = segment_update_entity.answer
- segment.word_count += len(segment_update_entity.answer)
+ segment.word_count += len(segment_update_entity.answer or "")
word_count_change = segment.word_count - word_count_change
if segment_update_entity.keywords:
segment.keywords = segment_update_entity.keywords
@@ -1569,7 +1584,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document
db.session.add(document)
# update segment index task
if segment_update_entity.enabled:
- VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
+ keywords = segment_update_entity.keywords or []
+ VectorService.create_segments_vector([keywords], [segment], dataset)
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
@@ -1601,7 +1617,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document
segment.disabled_by = None
if document.doc_form == "qa_model":
segment.answer = segment_update_entity.answer
- segment.word_count += len(segment_update_entity.answer)
+ segment.word_count += len(segment_update_entity.answer or "")
word_count_change = segment.word_count - word_count_change
# update document word count
if word_count_change != 0:
@@ -1619,8 +1635,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document
segment.status = "error"
segment.error = str(e)
db.session.commit()
- segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
- return segment
+ new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
+ return new_segment
@classmethod
def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
@@ -1680,6 +1696,8 @@ def get_dataset_collection_binding_by_id_and_type(
.order_by(DatasetCollectionBinding.created_at)
.first()
)
+ if not dataset_collection_binding:
+ raise ValueError("Dataset collection binding not found")
return dataset_collection_binding
diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py
index 92098f06cca538..3c3f9704440342 100644
--- a/api/services/enterprise/base.py
+++ b/api/services/enterprise/base.py
@@ -8,8 +8,8 @@ class EnterpriseRequest:
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
proxies = {
- "http": None,
- "https": None,
+ "http": "",
+ "https": "",
}
@classmethod
diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py
index c519f0b0e51b68..334d009ee5f79f 100644
--- a/api/services/entities/model_provider_entities.py
+++ b/api/services/entities/model_provider_entities.py
@@ -4,7 +4,11 @@
from pydantic import BaseModel, ConfigDict
from configs import dify_config
-from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
+from core.entities.model_entities import (
+ ModelWithProviderEntity,
+ ProviderModelWithStatusEntity,
+ SimpleModelProviderEntity,
+)
from core.entities.provider_entities import QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
@@ -148,7 +152,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
Model with provider entity.
"""
- provider: SimpleProviderEntityResponse
+ provider: SimpleModelProviderEntity
def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.model_dump())
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index 7be20301a74b78..898624066bef7e 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -1,7 +1,7 @@
import json
from copy import deepcopy
from datetime import UTC, datetime
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, cast
import httpx
import validators
@@ -45,7 +45,10 @@ def validate_api_list(cls, api_settings: dict):
@staticmethod
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
- ExternalDatasetService.check_endpoint_and_api_key(args.get("settings"))
+ settings = args.get("settings")
+ if settings is None:
+ raise ValueError("settings is required")
+ ExternalDatasetService.check_endpoint_and_api_key(settings)
external_knowledge_api = ExternalKnowledgeApis(
tenant_id=tenant_id,
created_by=user_id,
@@ -86,11 +89,16 @@ def check_endpoint_and_api_key(settings: dict):
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
- return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first()
+ external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
+ id=external_knowledge_api_id
+ ).first()
+ if external_knowledge_api is None:
+ raise ValueError("api template not found")
+ return external_knowledge_api
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
- external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
+ external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
if external_knowledge_api is None:
@@ -127,7 +135,7 @@ def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bo
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
- external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
+ external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
if not external_knowledge_binding:
@@ -163,8 +171,9 @@ def process_external_api(
"follow_redirects": True,
}
- response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs)
-
+ response: httpx.Response = getattr(ssrf_proxy, settings.request_method)(
+ data=json.dumps(settings.params), files=files, **kwargs
+ )
return response
@staticmethod
@@ -265,15 +274,15 @@ def fetch_external_knowledge_retrieval(
"knowledge_id": external_knowledge_binding.external_knowledge_id,
}
- external_knowledge_api_setting = {
- "url": f"{settings.get('endpoint')}/retrieval",
- "request_method": "post",
- "headers": headers,
- "params": request_params,
- }
response = ExternalDatasetService.process_external_api(
- ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None
+ ExternalKnowledgeApiSetting(
+ url=f"{settings.get('endpoint')}/retrieval",
+ request_method="post",
+ headers=headers,
+ params=request_params,
+ ),
+ None,
)
if response.status_code == 200:
- return response.json().get("records", [])
+ return cast(list[Any], response.json().get("records", []))
return []
diff --git a/api/services/file_service.py b/api/services/file_service.py
index b12b95ca13558c..d417e81734c8af 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -3,7 +3,7 @@
import uuid
from typing import Any, Literal, Union
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from werkzeug.exceptions import NotFound
from configs import dify_config
@@ -61,14 +61,14 @@ def upload_file(
# end_user
current_tenant_id = user.tenant_id
- file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension
+ file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
# save file to storage
storage.save(file_key, content)
# save file to db
upload_file = UploadFile(
- tenant_id=current_tenant_id,
+ tenant_id=current_tenant_id or "",
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=filename,
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 7957b4dc82dfd4..41b4e1ec46374a 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -1,5 +1,6 @@
import logging
import time
+from typing import Any
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
@@ -24,7 +25,7 @@ def retrieve(
dataset: Dataset,
query: str,
account: Account,
- retrieval_model: dict,
+ retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
@@ -68,7 +69,7 @@ def retrieve(
db.session.add(dataset_query)
db.session.commit()
- return cls.compact_retrieve_response(dataset, query, all_documents)
+ return dict(cls.compact_retrieve_response(dataset, query, all_documents))
@classmethod
def external_retrieve(
@@ -102,13 +103,16 @@ def external_retrieve(
db.session.add(dataset_query)
db.session.commit()
- return cls.compact_external_retrieve_response(dataset, query, all_documents)
+ return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
records = []
for document in documents:
+ if document.metadata is None:
+ continue
+
index_node_id = document.metadata["doc_id"]
segment = (
@@ -140,7 +144,7 @@ def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list
}
@classmethod
- def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list):
+ def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
records = []
if dataset.provider == "external":
for document in documents:
@@ -152,11 +156,10 @@ def compact_external_retrieve_response(cls, dataset: Dataset, query: str, docume
}
records.append(record)
return {
- "query": {
- "content": query,
- },
+ "query": {"content": query},
"records": records,
}
+ return {"query": {"content": query}, "records": []}
@classmethod
def hit_testing_args_check(cls, args):
diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py
index 02fe1d19bc42be..8df1a6ba144d4e 100644
--- a/api/services/knowledge_service.py
+++ b/api/services/knowledge_service.py
@@ -1,4 +1,4 @@
-import boto3
+import boto3 # type: ignore
from configs import dify_config
diff --git a/api/services/message_service.py b/api/services/message_service.py
index be2922f4c58e76..c4447a84da5e09 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -157,7 +157,7 @@ def create_feedback(
user: Optional[Union[Account, EndUser]],
rating: Optional[str],
content: Optional[str],
- ) -> MessageFeedback:
+ ):
if not user:
raise ValueError("user cannot be None")
@@ -264,6 +264,8 @@ def get_suggested_questions_after_answer(
)
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
+ if not app_model_config:
+ raise ValueError("did not find app model config")
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if suggested_questions_after_answer.get("enabled", False) is False:
@@ -285,7 +287,7 @@ def get_suggested_questions_after_answer(
)
with measure_time() as timer:
- questions = LLMGenerator.generate_suggested_questions_after_answer(
+ questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer(
tenant_id=app_model.tenant_id, histories=histories
)
diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py
index b20bda87551ca9..bacd3a8ec3d04f 100644
--- a/api/services/model_load_balancing_service.py
+++ b/api/services/model_load_balancing_service.py
@@ -2,7 +2,7 @@
import json
import logging
from json import JSONDecodeError
-from typing import Optional
+from typing import Optional, Union
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
@@ -88,11 +88,11 @@ def get_load_balancing_configs(
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
- model_type = ModelType.value_of(model_type)
+ model_type_enum = ModelType.value_of(model_type)
# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
- model_type=model_type,
+ model_type=model_type_enum,
model=model,
)
@@ -106,7 +106,7 @@ def get_load_balancing_configs(
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
- LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.order_by(LoadBalancingModelConfig.created_at)
@@ -124,7 +124,7 @@ def get_load_balancing_configs(
if not inherit_config_exists:
# Initialize the inherit configuration
- inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
+ inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum)
# prepend the inherit configuration
load_balancing_configs.insert(0, inherit_config)
@@ -148,7 +148,7 @@ def get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=model,
- model_type=model_type,
+ model_type=model_type_enum,
config_id=load_balancing_config.id,
)
@@ -214,7 +214,7 @@ def get_load_balancing_config(
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
- model_type = ModelType.value_of(model_type)
+ model_type_enum = ModelType.value_of(model_type)
# Get load balancing configurations
load_balancing_model_config = (
@@ -222,7 +222,7 @@ def get_load_balancing_config(
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
- LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@@ -300,7 +300,7 @@ def update_load_balancing_configs(
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
- model_type = ModelType.value_of(model_type)
+ model_type_enum = ModelType.value_of(model_type)
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
@@ -310,7 +310,7 @@ def update_load_balancing_configs(
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
- LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
@@ -359,7 +359,7 @@ def update_load_balancing_configs(
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
- model_type=model_type,
+ model_type=model_type_enum,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_config,
@@ -395,7 +395,7 @@ def update_load_balancing_configs(
credentials = self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
- model_type=model_type,
+ model_type=model_type_enum,
model=model,
credentials=credentials,
validate=False,
@@ -405,7 +405,7 @@ def update_load_balancing_configs(
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
- model_type=model_type.to_origin_model_type(),
+ model_type=model_type_enum.to_origin_model_type(),
model_name=model,
name=name,
encrypted_config=json.dumps(credentials),
@@ -450,7 +450,7 @@ def validate_load_balancing_credentials(
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
- model_type = ModelType.value_of(model_type)
+ model_type_enum = ModelType.value_of(model_type)
load_balancing_model_config = None
if config_id:
@@ -460,7 +460,7 @@ def validate_load_balancing_credentials(
.filter(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
- LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+ LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@@ -474,7 +474,7 @@ def validate_load_balancing_credentials(
self._custom_credentials_validate(
tenant_id=tenant_id,
provider_configuration=provider_configuration,
- model_type=model_type,
+ model_type=model_type_enum,
model=model,
credentials=credentials,
load_balancing_model_config=load_balancing_model_config,
@@ -547,19 +547,14 @@ def _custom_credentials_validate(
def _get_credential_schema(
self, provider_configuration: ProviderConfiguration
- ) -> ModelCredentialSchema | ProviderCredentialSchema:
- """
- Get form schemas.
- :param provider_configuration: provider configuration
- :return:
- """
- # Get credential form schemas from model credential schema or provider credential schema
+ ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
+ """Get form schemas."""
if provider_configuration.provider.model_credential_schema:
- credential_schema = provider_configuration.provider.model_credential_schema
+ return provider_configuration.provider.model_credential_schema
+ elif provider_configuration.provider.provider_credential_schema:
+ return provider_configuration.provider.provider_credential_schema
else:
- credential_schema = provider_configuration.provider.provider_credential_schema
-
- return credential_schema
+ raise ValueError("No credential schema found")
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
"""
diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py
index 384a072b371fdd..b10c5ad2d616e9 100644
--- a/api/services/model_provider_service.py
+++ b/api/services/model_provider_service.py
@@ -7,7 +7,7 @@
import requests
from flask import current_app
-from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
+from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -100,23 +100,15 @@ def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWit
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
]
- def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
+ def get_provider_credentials(self, tenant_id: str, provider: str):
"""
get provider credentials.
-
- :param tenant_id:
- :param provider:
- :return:
"""
- # Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
-
- # Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
- # Get provider custom credentials from workspace
return provider_configuration.get_custom_credentials(obfuscated=True)
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
@@ -176,7 +168,7 @@ def remove_provider_credentials(self, tenant_id: str, provider: str) -> None:
# Remove custom provider credentials.
provider_configuration.delete_custom_credentials()
- def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
+ def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str):
"""
get model credentials.
@@ -287,7 +279,7 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
# Group models by provider
- provider_models = {}
+ provider_models: dict[str, list[ModelWithProviderEntity]] = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
@@ -362,7 +354,7 @@ def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -
return []
# Call get_parameter_rules method of model instance to get model parameter rules
- return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
+ return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials))
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
"""
@@ -422,6 +414,7 @@ def get_model_provider_icon(
"""
provider_instance = model_provider_factory.get_provider_instance(provider)
provider_schema = provider_instance.get_provider_schema()
+ file_name: str | None = None
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
@@ -439,6 +432,8 @@ def get_model_provider_icon(
file_name = provider_schema.icon_large.zh_Hans
else:
file_name = provider_schema.icon_large.en_US
+ if not file_name:
+ return None, None
root_path = current_app.root_path
provider_instance_path = os.path.dirname(
@@ -524,7 +519,7 @@ def disable_model(self, tenant_id: str, provider: str, model: str, model_type: s
def free_quota_submit(self, tenant_id: str, provider: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
- api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
+ api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
api_url = api_base_url + "/api/v1/providers/apply"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
@@ -545,7 +540,7 @@ def free_quota_submit(self, tenant_id: str, provider: str):
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
- api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
+ api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
api_url = api_base_url + "/api/v1/providers/qualification-verify"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py
index dfb21e767fc9b9..082afeed89a5e4 100644
--- a/api/services/moderation_service.py
+++ b/api/services/moderation_service.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
from extensions.ext_database import db
from models.model import App, AppModelConfig
@@ -5,7 +7,7 @@
class ModerationService:
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
- app_model_config: AppModelConfig = None
+ app_model_config: Optional[AppModelConfig] = None
app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
diff --git a/api/services/ops_service.py b/api/services/ops_service.py
index 1160a1f2751d74..fc1e08518b1945 100644
--- a/api/services/ops_service.py
+++ b/api/services/ops_service.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
from extensions.ext_database import db
from models.model import App, TraceAppConfig
@@ -12,7 +14,7 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str):
:param tracing_provider: tracing provider
:return:
"""
- trace_config_data: TraceAppConfig = (
+ trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@@ -22,7 +24,10 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str):
return None
# decrypt_token and obfuscated_token
- tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
+ tenant = db.session.query(App).filter(App.id == app_id).first()
+ if not tenant:
+ return None
+ tenant_id = tenant.tenant_id
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)
@@ -73,8 +78,9 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["other_keys"],
)
- default_config_instance = config_class(**tracing_config)
- for key in other_keys:
+ # FIXME: ignore type error
+ default_config_instance = config_class(**tracing_config) # type: ignore
+ for key in other_keys: # type: ignore
if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None)
@@ -92,7 +98,7 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c
project_url = None
# check if trace config already exists
- trace_config_data: TraceAppConfig = (
+ trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@@ -102,7 +108,10 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c
return None
# get tenant id
- tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
+ tenant = db.session.query(App).filter(App.id == app_id).first()
+ if not tenant:
+ return None
+ tenant_id = tenant.tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
if project_url:
tracing_config["project_url"] = project_url
@@ -139,7 +148,10 @@ def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c
return None
# get tenant id
- tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
+ tenant = db.session.query(App).filter(App.id == app_id).first()
+ if not tenant:
+ return None
+ tenant_id = tenant.tenant_id
tracing_config = OpsTraceManager.encrypt_tracing_config(
tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config
)
diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py
index 4704d533a950ed..523aebeed52a4e 100644
--- a/api/services/recommend_app/buildin/buildin_retrieval.py
+++ b/api/services/recommend_app/buildin/buildin_retrieval.py
@@ -41,7 +41,7 @@ def _get_builtin_data(cls) -> dict:
Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8")
)
- return cls.builtin_data
+ return cls.builtin_data or {}
@classmethod
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
@@ -50,8 +50,8 @@ def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
:param language: language
:return:
"""
- builtin_data = cls._get_builtin_data()
- return builtin_data.get("recommended_apps", {}).get(language)
+ builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
+ return builtin_data.get("recommended_apps", {}).get(language, {})
@classmethod
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
@@ -60,5 +60,5 @@ def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict
:param app_id: App ID
:return:
"""
- builtin_data = cls._get_builtin_data()
+ builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("app_details", {}).get(app_id)
diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py
index b0607a21323acb..80e1aefc01da85 100644
--- a/api/services/recommend_app/remote/remote_retrieval.py
+++ b/api/services/recommend_app/remote/remote_retrieval.py
@@ -47,8 +47,8 @@ def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optiona
response = requests.get(url, timeout=(3, 10))
if response.status_code != 200:
return None
-
- return response.json()
+ data: dict = response.json()
+ return data
@classmethod
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
@@ -63,7 +63,7 @@ def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
if response.status_code != 200:
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
- result = response.json()
+ result: dict = response.json()
if "categories" in result:
result["categories"] = sorted(result["categories"])
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index 4660316fcfcf71..54c58455155c03 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -33,5 +33,5 @@ def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
"""
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
- result = retrieval_instance.get_recommend_app_detail(app_id)
+ result: dict = retrieval_instance.get_recommend_app_detail(app_id)
return result
diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py
index 9fe3cecce7546d..4cb8700117e6f3 100644
--- a/api/services/saved_message_service.py
+++ b/api/services/saved_message_service.py
@@ -13,6 +13,8 @@ class SavedMessageService:
def pagination_by_last_id(
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
) -> InfiniteScrollPagination:
+ if not user:
+ raise ValueError("User is required")
saved_messages = (
db.session.query(SavedMessage)
.filter(
@@ -31,6 +33,8 @@ def pagination_by_last_id(
@classmethod
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
+ if not user:
+ return
saved_message = (
db.session.query(SavedMessage)
.filter(
@@ -59,6 +63,8 @@ def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_i
@classmethod
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
+ if not user:
+ return
saved_message = (
db.session.query(SavedMessage)
.filter(
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index a374bdcf002bef..9600601633cddb 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -1,7 +1,7 @@
import uuid
from typing import Optional
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from sqlalchemy import func
from werkzeug.exceptions import NotFound
@@ -21,7 +21,7 @@ def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = Non
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id)
- results = query.order_by(Tag.created_at.desc()).all()
+ results: list = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 78a80f70ab6b00..0e3bd3a7b83c68 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -1,6 +1,7 @@
import json
import logging
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional, cast
from httpx import get
@@ -28,12 +29,12 @@
class ApiToolManageService:
@staticmethod
- def parser_api_schema(schema: str) -> list[ApiToolBundle]:
+ def parser_api_schema(schema: str) -> Mapping[str, Any]:
"""
parse api schema to tool bundle
"""
try:
- warnings = {}
+ warnings: dict[str, str] = {}
try:
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
except Exception as e:
@@ -68,13 +69,16 @@ def parser_api_schema(schema: str) -> list[ApiToolBundle]:
),
]
- return jsonable_encoder(
- {
- "schema_type": schema_type,
- "parameters_schema": tool_bundles,
- "credentials_schema": credentials_schema,
- "warning": warnings,
- }
+ return cast(
+ Mapping,
+ jsonable_encoder(
+ {
+ "schema_type": schema_type,
+ "parameters_schema": tool_bundles,
+ "credentials_schema": credentials_schema,
+ "warning": warnings,
+ }
+ ),
)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
@@ -129,7 +133,7 @@ def create_api_tool_provider(
raise ValueError(f"provider {provider_name} already exists")
# parse openapi to tool bundle
- extra_info = {}
+ extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
@@ -262,9 +266,8 @@ def update_api_tool_provider(
if provider is None:
raise ValueError(f"api provider {provider_name} does not exists")
-
# parse openapi to tool bundle
- extra_info = {}
+ extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
@@ -416,7 +419,7 @@ def test_api_tool_preview(
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
- tool = tool.fork_tool_runtime(
+ runtime_tool = tool.fork_tool_runtime(
runtime={
"credentials": credentials,
"tenant_id": tenant_id,
@@ -454,7 +457,7 @@ def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
- for tool in tools:
+ for tool in tools or []:
user_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index fada881fdeb741..21adbb0074724e 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -50,8 +50,8 @@ def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
- result = []
- for tool in tools:
+ result: list[UserTool] = []
+ for tool in tools or []:
result.append(
ToolTransformService.tool_to_user_tool(
tool=tool,
@@ -217,6 +217,8 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
name_func=lambda x: x.identity.name,
):
continue
+ if provider_controller.identity is None:
+ continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
@@ -229,7 +231,7 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
ToolTransformService.repack_provider(user_builtin_provider)
tools = provider_controller.get_tools()
- for tool in tools:
+ for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index a4aa870dc80352..b501554bcd091d 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -1,6 +1,6 @@
import json
import logging
-from typing import Optional, Union
+from typing import Optional, Union, cast
from configs import dify_config
from core.tools.entities.api_entities import UserTool, UserToolProvider
@@ -35,7 +35,7 @@ def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str
return url_prefix + "builtin/" + provider_name + "/icon"
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try:
- return json.loads(icon)
+ return cast(dict, json.loads(icon))
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
@@ -53,8 +53,11 @@ def repack_provider(provider: Union[dict, UserToolProvider]):
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
)
elif isinstance(provider, UserToolProvider):
- provider.icon = ToolTransformService.get_tool_provider_icon_url(
- provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
+ provider.icon = cast(
+ str,
+ ToolTransformService.get_tool_provider_icon_url(
+ provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
+ ),
)
@staticmethod
@@ -66,6 +69,9 @@ def builtin_provider_to_user_provider(
"""
convert provider controller to user provider
"""
+ if provider_controller.identity is None:
+ raise ValueError("provider identity is None")
+
result = UserToolProvider(
id=provider_controller.identity.name,
author=provider_controller.identity.author,
@@ -93,7 +99,8 @@ def builtin_provider_to_user_provider(
# get credentials schema
schema = provider_controller.get_credentials_schema()
for name, value in schema.items():
- result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
+ assert result.masked_credentials is not None, "masked credentials is None"
+ result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type))
# check if the provider need credentials
if not provider_controller.need_credentials:
@@ -149,6 +156,9 @@ def workflow_provider_to_user_provider(
"""
convert provider controller to user provider
"""
+ if provider_controller.identity is None:
+ raise ValueError("provider identity is None")
+
return UserToolProvider(
id=provider_controller.provider_id,
author=provider_controller.identity.author,
@@ -180,6 +190,8 @@ def api_provider_to_user_provider(
convert provider controller to user provider
"""
username = "Anonymous"
+ if db_provider.user is None:
+ raise ValueError(f"user is None for api provider {db_provider.id}")
try:
username = db_provider.user.name
except Exception as e:
@@ -256,19 +268,25 @@ def tool_to_user_tool(
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
+ if tool.identity is None:
+ raise ValueError("tool identity is None")
+
return UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
- description=tool.description.human,
+ description=I18nObject(
+ en_US=tool.description.human if tool.description else "",
+ zh_Hans=tool.description.human if tool.description else "",
+ ),
parameters=current_parameters,
labels=labels,
)
if isinstance(tool, ApiToolBundle):
return UserTool(
author=tool.author,
- name=tool.operation_id,
- label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
+ name=tool.operation_id or "",
+ label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""),
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
parameters=tool.parameters,
labels=labels,
diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py
index 318107bebb5eb6..69430de432b143 100644
--- a/api/services/tools/workflow_tools_manage_service.py
+++ b/api/services/tools/workflow_tools_manage_service.py
@@ -6,8 +6,10 @@
from sqlalchemy import or_
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.tools.entities.api_entities import UserToolProvider
+from core.tools.entities.api_entities import UserTool, UserToolProvider
+from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
+from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
@@ -32,7 +34,7 @@ def create_workflow_tool(
label: str,
icon: dict,
description: str,
- parameters: Mapping[str, Any],
+ parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: Optional[list[str]] = None,
) -> dict:
@@ -97,7 +99,7 @@ def update_workflow_tool(
label: str,
icon: dict,
description: str,
- parameters: list[dict],
+ parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: Optional[list[str]] = None,
) -> dict:
@@ -131,7 +133,7 @@ def update_workflow_tool(
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
- workflow_tool_provider: WorkflowToolProvider = (
+ workflow_tool_provider: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -140,14 +142,14 @@ def update_workflow_tool(
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
- app: App = (
+ app: Optional[App] = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
- workflow: Workflow = app.workflow
+ workflow: Optional[Workflow] = app.workflow
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
@@ -193,7 +195,7 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo
# skip deleted tools
pass
- labels = ToolLabelManager.get_tools_labels(tools)
+ labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
result = []
@@ -202,10 +204,11 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo
provider_controller=tool, labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(user_tool_provider)
+ to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
+ if to_user_tool is None or len(to_user_tool) == 0:
+ continue
user_tool_provider.tools = [
- ToolTransformService.tool_to_user_tool(
- tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
- )
+ ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, []))
]
result.append(user_tool_provider)
@@ -236,7 +239,7 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too
:param workflow_app_id: the workflow app id
:return: the tool
"""
- db_tool: WorkflowToolProvider = (
+ db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -245,13 +248,19 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too
if db_tool is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
- workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
+ workflow_app: Optional[App] = (
+ db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
+ )
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+ to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
+ if to_user_tool is None or len(to_user_tool) == 0:
+ raise ValueError(f"Tool {workflow_tool_id} not found")
+
return {
"name": db_tool.name,
"label": db_tool.label,
@@ -261,9 +270,9 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
- tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
+ to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
),
- "synced": workflow_app.workflow.version == db_tool.version,
+ "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
"privacy_policy": db_tool.privacy_policy,
}
@@ -276,7 +285,7 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_
:param workflow_app_id: the workflow app id
:return: the tool
"""
- db_tool: WorkflowToolProvider = (
+ db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
@@ -285,12 +294,17 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_
if db_tool is None:
raise ValueError(f"Tool {workflow_app_id} not found")
- workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
+ workflow_app: Optional[App] = (
+ db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
+ )
if workflow_app is None:
raise ValueError(f"App {db_tool.app_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+ to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
+ if to_user_tool is None or len(to_user_tool) == 0:
+ raise ValueError(f"Tool {workflow_app_id} not found")
return {
"name": db_tool.name,
@@ -301,14 +315,14 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool(
- tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
+ to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool)
),
- "synced": workflow_app.workflow.version == db_tool.version,
+ "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False,
"privacy_policy": db_tool.privacy_policy,
}
@classmethod
- def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
+ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
"""
List workflow tool provider tools.
:param user_id: the user id
@@ -316,7 +330,7 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_
:param workflow_app_id: the workflow app id
:return: the list of tools
"""
- db_tool: WorkflowToolProvider = (
+ db_tool: Optional[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
@@ -326,9 +340,8 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+ to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id)
+ if to_user_tool is None or len(to_user_tool) == 0:
+ raise ValueError(f"Tool {workflow_tool_id} not found")
- return [
- ToolTransformService.tool_to_user_tool(
- tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
- )
- ]
+ return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))]
diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py
index 508fe20970a703..f698ed3084bdac 100644
--- a/api/services/web_conversation_service.py
+++ b/api/services/web_conversation_service.py
@@ -26,6 +26,8 @@ def pagination_by_last_id(
pinned: Optional[bool] = None,
sort_by="-updated_at",
) -> InfiniteScrollPagination:
+ if not user:
+ raise ValueError("User is required")
include_ids = None
exclude_ids = None
if pinned is not None and user:
@@ -59,6 +61,8 @@ def pagination_by_last_id(
@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
+ if not user:
+ return
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(
@@ -89,6 +93,8 @@ def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account,
@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
+ if not user:
+ return
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(
diff --git a/api/services/website_service.py b/api/services/website_service.py
index 230f5d78152f39..1ad7d0399d6edf 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -1,8 +1,9 @@
import datetime
import json
+from typing import Any
import requests
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
@@ -23,9 +24,9 @@ def document_create_args_validate(cls, args: dict):
@classmethod
def crawl_url(cls, args: dict) -> dict:
- provider = args.get("provider")
+ provider = args.get("provider", "")
url = args.get("url")
- options = args.get("options")
+ options = args.get("options", "")
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
@@ -164,16 +165,18 @@ def get_crawl_status(cls, job_id: str, provider: str) -> dict:
return crawl_status_data
@classmethod
- def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
+ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
+ # FIXME data is redefine too many times here, use Any to ease the type checking, fix it later
+ data: Any
if provider == "firecrawl":
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
- data = storage.load_once(file_key)
- if data:
- data = json.loads(data.decode("utf-8"))
+ d = storage.load_once(file_key)
+ if d:
+ data = json.loads(d.decode("utf-8"))
else:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
@@ -183,22 +186,17 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str
if data:
for item in data:
if item.get("source_url") == url:
- return item
+ return dict(item)
return None
elif provider == "jinareader":
- file_key = "website_files/" + job_id + ".txt"
- if storage.exists(file_key):
- data = storage.load_once(file_key)
- if data:
- data = json.loads(data.decode("utf-8"))
- elif not job_id:
+ if not job_id:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
- return response.json().get("data")
+ return dict(response.json().get("data", {}))
else:
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
response = requests.post(
@@ -218,12 +216,13 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str
data = response.json().get("data", {})
for item in data.get("processed", {}).values():
if item.get("data", {}).get("url") == url:
- return item.get("data", {})
+ return dict(item.get("data", {}))
+ return None
else:
raise ValueError("Invalid provider")
@classmethod
- def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
+ def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 90b5cc48362f3b..2b0d57bdfdeda3 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -1,5 +1,5 @@
import json
-from typing import Optional
+from typing import Any, Optional
from core.app.app_config.entities import (
DatasetEntity,
@@ -101,7 +101,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config:
app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
# init workflow graph
- graph = {"nodes": [], "edges": []}
+ graph: dict[str, Any] = {"nodes": [], "edges": []}
# Convert list:
# - variables -> start
@@ -118,7 +118,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config:
graph["nodes"].append(start_node)
# convert to http request node
- external_data_variable_node_mapping = {}
+ external_data_variable_node_mapping: dict[str, str] = {}
if app_config.external_data_variables:
http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
app_model=app_model,
@@ -199,15 +199,16 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config:
return workflow
def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
- app_mode = AppMode.value_of(app_model.mode)
- if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
+ app_mode_enum = AppMode.value_of(app_model.mode)
+ app_config: EasyUIBasedAppConfig
+ if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
- elif app_mode == AppMode.CHAT:
+ elif app_mode_enum == AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
- elif app_mode == AppMode.COMPLETION:
+ elif app_mode_enum == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
@@ -302,7 +303,7 @@ def _convert_to_http_request_node(
nodes.append(http_request_node)
# append code node for response body parsing
- code_node = {
+ code_node: dict[str, Any] = {
"id": f"code_{index}",
"position": None,
"data": {
@@ -401,6 +402,7 @@ def _convert_to_llm_node(
)
role_prefix = None
+ prompts: Any = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index d8ee323908a844..4343596a236f5f 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
@@ -92,7 +94,7 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
- def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
+ def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
"""
Get workflow run detail
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 84768d5af053e4..ea8192edde35cc 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -2,7 +2,7 @@
import time
from collections.abc import Sequence
from datetime import UTC, datetime
-from typing import Optional, cast
+from typing import Any, Optional, cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@@ -242,7 +242,7 @@ def run_draft_workflow_node(
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
- node_error_args = {
+ node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
@@ -338,7 +338,7 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
# convert to workflow
- new_app = workflow_converter.convert_to_workflow(
+ new_app: App = workflow_converter.convert_to_workflow(
app_model=app_model,
account=account,
name=args.get("name", "Default Name"),
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 8fcb12b1cb9664..7637b31454e556 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -1,4 +1,4 @@
-from flask_login import current_user
+from flask_login import current_user # type: ignore
from configs import dify_config
from extensions.ext_database import db
@@ -29,6 +29,7 @@ def get_tenant_info(cls, tenant: Tenant):
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.first()
)
+ assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role
can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
diff --git a/api/tasks/__init__.py b/api/tasks/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index 09be6612160471..50bb2b6e634fba 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py
index 25c55bcfafe11c..aab21a44109975 100644
--- a/api/tasks/annotation/add_annotation_to_index_task.py
+++ b/api/tasks/annotation/add_annotation_to_index_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py
index fa7e5ac9190f3c..06162b02d60f8b 100644
--- a/api/tasks/annotation/batch_import_annotations_task.py
+++ b/api/tasks/annotation/batch_import_annotations_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py
index f0f6b32b06c78c..a6a598ce4b6bca 100644
--- a/api/tasks/annotation/delete_annotation_index_task.py
+++ b/api/tasks/annotation/delete_annotation_index_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.rag.datasource.vdb.vector_factory import Vector
from models.dataset import Dataset
diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py
index a2f49135139b08..26bf1c7c9fa32e 100644
--- a/api/tasks/annotation/disable_annotation_reply_task.py
+++ b/api/tasks/annotation/disable_annotation_reply_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index 0bdcd0eccd7f72..b42af0c7faf67e 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py
index b685d84d07ad28..8c675feaa6e06f 100644
--- a/api/tasks/annotation/update_annotation_to_index_task.py
+++ b/api/tasks/annotation/update_annotation_to_index_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index dcb7009e44b938..26ae9f8736d79a 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -4,7 +4,7 @@
import uuid
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from sqlalchemy import func
from core.indexing_runner import IndexingRunner
@@ -58,12 +58,13 @@ def batch_create_segment_to_index_task(
model=dataset.embedding_model,
)
word_count_change = 0
+ segments_to_insert: list[str] = [] # Explicitly type hint the list as List[str]
for segment in content:
- content = segment["content"]
+ content_str = segment["content"]
doc_id = str(uuid.uuid4())
- segment_hash = helper.generate_text_hash(content)
+ segment_hash = helper.generate_text_hash(content_str)
# calc embedding use tokens
- tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0
+ tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == dataset_document.id)
@@ -90,6 +91,7 @@ def batch_create_segment_to_index_task(
word_count_change += segment_document.word_count
db.session.add(segment_document)
document_segments.append(segment_document)
+ segments_to_insert.append(str(segment)) # Cast to string if needed
# update document word count
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index a555fb28746697..d9278c03793877 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -71,6 +71,8 @@ def clean_dataset_task(
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+ if image_file is None:
+ continue
try:
storage.delete(image_file.key)
except Exception:
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 4d328643bfa165..3e80dd13771802 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -3,7 +3,7 @@
from typing import Optional
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -44,6 +44,8 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+ if image_file is None:
+ continue
try:
storage.delete(image_file.key)
except Exception:
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index 75d9e031306381..f5d6406d9cc04f 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 315b01f157bf13..dfa053a43cbc61 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -4,7 +4,7 @@
from typing import Optional
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index cfc54920e23caa..b025509aebe674 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index c3e0ea5d9fbb77..45a612c74550cd 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py
index 15e1e50076e8c9..f30a1cc7acfd6c 100644
--- a/api/tasks/disable_segment_from_index_task.py
+++ b/api/tasks/disable_segment_from_index_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 18316913932874..ac4e81f95d127e 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index 734dd2478a9847..21b571b6cb5bd4 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 1a52a6636b1d17..5f1e9a892f54e3 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index f4c3dbd2e2860c..6db2620eb6eef0 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@@ -26,6 +26,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ if dataset is None:
+ raise ValueError("Dataset not found")
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py
index 12639db9392677..2f6eb7b82a0633 100644
--- a/api/tasks/enable_segment_to_index_task.py
+++ b/api/tasks/enable_segment_to_index_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py
index d78fc2b8915520..5dc935548f90b8 100644
--- a/api/tasks/mail_email_code_login.py
+++ b/api/tasks/mail_email_code_login.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from flask import render_template
from extensions.ext_mail import mail
diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py
index c7dfb9bf6063ff..3094527fd40945 100644
--- a/api/tasks/mail_invite_member_task.py
+++ b/api/tasks/mail_invite_member_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from flask import render_template
from configs import dify_config
diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py
index 8596ca07cfcee3..d5be94431b6221 100644
--- a/api/tasks/mail_reset_password_task.py
+++ b/api/tasks/mail_reset_password_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from flask import render_template
from extensions.ext_mail import mail
diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py
index 34c62dc9237fc0..bb3b9e17ead6d2 100644
--- a/api/tasks/ops_trace_task.py
+++ b/api/tasks/ops_trace_task.py
@@ -1,7 +1,7 @@
import json
import logging
-from celery import shared_task
+from celery import shared_task # type: ignore
from flask import current_app
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py
index 934eb7430c90c3..b603d689ba9d8e 100644
--- a/api/tasks/recover_document_indexing_task.py
+++ b/api/tasks/recover_document_indexing_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 66f78636ecca60..c3910e2be3a499 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -3,7 +3,7 @@
from collections.abc import Callable
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index 1909eaf3418517..4ba6d1a83e32ae 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -2,7 +2,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 73471fd6e77c9b..485caa5152ea78 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -22,10 +22,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
Usage: retry_document_indexing_task.delay(dataset_id, document_id)
"""
- documents = []
+ documents: list[Document] = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise ValueError("Dataset not found")
+
for document_id in document_ids:
retry_indexing_cache_key = "document_{}_is_retried".format(document_id)
# check document limit
@@ -55,29 +58,31 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
+ if not document:
+ logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
+ return
try:
- if document:
- # clean old data
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+ # clean old data
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids)
+ segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids)
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.utcnow()
- db.session.add(document)
+ for segment in segments:
+ db.session.delete(segment)
db.session.commit()
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- redis_client.delete(retry_indexing_cache_key)
+ document.indexing_status = "parsing"
+ document.processing_started_at = datetime.datetime.utcnow()
+ db.session.add(document)
+ db.session.commit()
+
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
index 1d2a338c831764..5d6b069cf44919 100644
--- a/api/tasks/sync_website_document_indexing_task.py
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -3,7 +3,7 @@
import time
import click
-from celery import shared_task
+from celery import shared_task # type: ignore
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -25,6 +25,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ if dataset is None:
+ raise ValueError("Dataset not found")
sync_indexing_cache_key = "document_{}_is_sync".format(document_id)
# check document limit
@@ -52,29 +54,31 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
logging.info(click.style("Start sync website document: {}".format(document_id), fg="green"))
document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ if not document:
+ logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
+ return
try:
- if document:
- # clean old data
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+ # clean old data
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids)
+ segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids)
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.utcnow()
- db.session.add(document)
+ for segment in segments:
+ db.session.delete(segment)
db.session.commit()
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- redis_client.delete(sync_indexing_cache_key)
+ document.indexing_status = "parsing"
+ document.processing_started_at = datetime.datetime.utcnow()
+ db.session.add(document)
+ db.session.commit()
+
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
diff --git a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py
index 64f2884c4b828c..57fba317638de8 100644
--- a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py
+++ b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py
@@ -1,6 +1,6 @@
from typing import Any
-import toml
+import toml # type: ignore
def load_api_poetry_configs() -> dict[str, Any]:
@@ -38,7 +38,7 @@ def test_group_dependencies_version_operator():
)
-def test_duplicated_dependency_crossing_groups():
+def test_duplicated_dependency_crossing_groups() -> None:
all_dependency_names: list[str] = []
for dependencies in load_all_dependency_groups().values():
dependency_names = list(dependencies.keys())
diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py
index 6371694694653e..5e3ee6bedc7ebb 100644
--- a/api/tests/integration_tests/controllers/test_controllers.py
+++ b/api/tests/integration_tests/controllers/test_controllers.py
@@ -1,6 +1,6 @@
from unittest.mock import patch
-from app_fixture import app, mock_user
+from app_fixture import mock_user # type: ignore
def test_post_requires_login(app):
diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py
index 5ea86baa83dd4b..b90f8b444477d5 100644
--- a/api/tests/integration_tests/model_runtime/__mock/google.py
+++ b/api/tests/integration_tests/model_runtime/__mock/google.py
@@ -1,7 +1,7 @@
from collections.abc import Generator
from unittest.mock import MagicMock
-import google.generativeai.types.generation_types as generation_config_types
+import google.generativeai.types.generation_types as generation_config_types # type: ignore
import pytest
from _pytest.monkeypatch import MonkeyPatch
from google.ai import generativelanguage as glm
@@ -45,7 +45,7 @@ def generate_content_sync() -> GenerateContentResponse:
return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
@staticmethod
- def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
+ def generate_content_stream() -> MockGoogleResponseClass:
return MockGoogleResponseClass()
def generate_content(
diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py
index 97038ef5963e87..4de52514408a06 100644
--- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py
+++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py
@@ -2,7 +2,7 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from huggingface_hub import InferenceClient
+from huggingface_hub import InferenceClient # type: ignore
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py
index 9ee76c935c9873..77c7e7f5e4089c 100644
--- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py
+++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py
@@ -3,15 +3,15 @@
from typing import Any, Literal, Optional, Union
from _pytest.monkeypatch import MonkeyPatch
-from huggingface_hub import InferenceClient
-from huggingface_hub.inference._text_generation import (
+from huggingface_hub import InferenceClient # type: ignore
+from huggingface_hub.inference._text_generation import ( # type: ignore
Details,
StreamDetails,
TextGenerationResponse,
TextGenerationStreamResponse,
Token,
)
-from huggingface_hub.utils import BadRequestError
+from huggingface_hub.utils import BadRequestError # type: ignore
class MockHuggingfaceChatClass:
diff --git a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py
index 6a25398cbf069a..4e00660a29162f 100644
--- a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py
+++ b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py
@@ -6,7 +6,7 @@
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
-from nomic import embed
+from nomic import embed # type: ignore
def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict:
diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py
index 794f4b0585632e..e2abaa52b939a6 100644
--- a/api/tests/integration_tests/model_runtime/__mock/xinference.py
+++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py
@@ -6,14 +6,14 @@
from _pytest.monkeypatch import MonkeyPatch
from requests import Response
from requests.sessions import Session
-from xinference_client.client.restful.restful_client import (
+from xinference_client.client.restful.restful_client import ( # type: ignore
Client,
RESTfulChatModelHandle,
RESTfulEmbeddingModelHandle,
RESTfulGenerateModelHandle,
RESTfulRerankModelHandle,
)
-from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
+from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage # type: ignore
class MockXinferenceClass:
diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py
index 2dcfb92c63fee2..d37fcf897fc3a8 100644
--- a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py
+++ b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py
@@ -1,6 +1,6 @@
import os
-import dashscope
+import dashscope # type: ignore
import pytest
from core.model_runtime.entities.rerank_entities import RerankResult
diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py
index 83f4d70ce9ac2f..2860739f0e30b3 100644
--- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py
+++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py
@@ -1,5 +1,5 @@
from flask import Flask, request
-from flask_restful import Api, Resource
+from flask_restful import Api, Resource # type: ignore
app = Flask(__name__)
api = Api(app)
diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
index 0ea61369c0304e..4af35a8befcaf8 100644
--- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
+++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
@@ -4,11 +4,11 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from pymochow import MochowClient
-from pymochow.model.database import Database
-from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
-from pymochow.model.schema import HNSWParams, VectorIndex
-from pymochow.model.table import Table
+from pymochow import MochowClient # type: ignore
+from pymochow.model.database import Database # type: ignore
+from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore
+from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore
+from pymochow.model.table import Table # type: ignore
from requests.adapters import HTTPAdapter
diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py
index 61d6ed16560c09..68a1e290adc120 100644
--- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py
+++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py
@@ -4,12 +4,12 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter
-from tcvectordb import VectorDBClient
-from tcvectordb.model.database import Collection, Database
-from tcvectordb.model.document import Document, Filter
-from tcvectordb.model.enum import ReadConsistency
-from tcvectordb.model.index import Index
-from xinference_client.types import Embedding
+from tcvectordb import VectorDBClient # type: ignore
+from tcvectordb.model.database import Collection, Database # type: ignore
+from tcvectordb.model.document import Document, Filter # type: ignore
+from tcvectordb.model.enum import ReadConsistency # type: ignore
+from tcvectordb.model.index import Index # type: ignore
+from xinference_client.types import Embedding # type: ignore
class MockTcvectordbClass:
diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py
index 0f40337feba6ee..3ad72e55501f58 100644
--- a/api/tests/integration_tests/vdb/__mock/vikingdb.py
+++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py
@@ -4,7 +4,7 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from volcengine.viking_db import (
+from volcengine.viking_db import ( # type: ignore
Collection,
Data,
DistanceType,
diff --git a/api/tests/unit_tests/oss/__mock/aliyun_oss.py b/api/tests/unit_tests/oss/__mock/aliyun_oss.py
index 27e1c0ad85029b..4f6d8a2f54a4fd 100644
--- a/api/tests/unit_tests/oss/__mock/aliyun_oss.py
+++ b/api/tests/unit_tests/oss/__mock/aliyun_oss.py
@@ -4,8 +4,8 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from oss2 import Bucket
-from oss2.models import GetObjectResult, PutObjectResult
+from oss2 import Bucket # type: ignore
+from oss2.models import GetObjectResult, PutObjectResult # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,
diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py
index 5189b68e87132a..c77c5b08f37d15 100644
--- a/api/tests/unit_tests/oss/__mock/tencent_cos.py
+++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py
@@ -3,8 +3,8 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from qcloud_cos import CosS3Client
-from qcloud_cos.streambody import StreamBody
+from qcloud_cos import CosS3Client # type: ignore
+from qcloud_cos.streambody import StreamBody # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,
diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py
index 649d93a20261d3..88df59f91c3071 100644
--- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py
+++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py
@@ -4,8 +4,8 @@
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from tos import TosClientV2
-from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput
+from tos import TosClientV2 # type: ignore
+from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore
from tests.unit_tests.oss.__mock.base import (
get_example_bucket,
diff --git a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py
index 65d31352bd3437..380134bc46d02e 100644
--- a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py
+++ b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py
@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
-from oss2 import Auth
+from oss2 import Auth # type: ignore
from extensions.storage.aliyun_oss_storage import AliyunOssStorage
from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock
diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py
index 303f0493bda42f..d289751800633a 100644
--- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py
+++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py
@@ -1,7 +1,7 @@
from unittest.mock import patch
import pytest
-from qcloud_cos import CosConfig
+from qcloud_cos import CosConfig # type: ignore
from extensions.storage.tencent_cos_storage import TencentCosStorage
from tests.unit_tests.oss.__mock.base import (
diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py
index 5afbc9e8b4cb18..04988e85d85881 100644
--- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py
+++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py
@@ -1,5 +1,5 @@
import pytest
-from tos import TosClientV2
+from tos import TosClientV2 # type: ignore
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
from tests.unit_tests.oss.__mock.base import (
diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py
index 95b93651d57f80..8d645487278a5f 100644
--- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py
+++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py
@@ -1,7 +1,7 @@
from textwrap import dedent
import pytest
-from yaml import YAMLError
+from yaml import YAMLError # type: ignore
from core.tools.utils.yaml_utils import load_yaml_file
diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py
index e6644883018769..ee1b5c57e1d1d0 100644
--- a/sdks/python-client/dify_client/client.py
+++ b/sdks/python-client/dify_client/client.py
@@ -160,7 +160,10 @@ def get_result(self, workflow_run_id):
class KnowledgeBaseClient(DifyClient):
def __init__(
- self, api_key, base_url: str = "https://api.dify.ai/v1", dataset_id: str = None
+ self,
+ api_key,
+ base_url: str = "https://api.dify.ai/v1",
+ dataset_id: str | None = None,
):
"""
Construct a KnowledgeBaseClient object.
@@ -187,7 +190,9 @@ def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
"GET", f"/datasets?page={page}&limit={page_size}", **kwargs
)
- def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs):
+ def create_document_by_text(
+ self, name, text, extra_params: dict | None = None, **kwargs
+ ):
"""
Create a document by text.
@@ -225,7 +230,7 @@ def create_document_by_text(self, name, text, extra_params: dict = None, **kwarg
return self._send_request("POST", url, json=data, **kwargs)
def update_document_by_text(
- self, document_id, name, text, extra_params: dict = None, **kwargs
+ self, document_id, name, text, extra_params: dict | None = None, **kwargs
):
"""
Update a document by text.
@@ -262,7 +267,7 @@ def update_document_by_text(
return self._send_request("POST", url, json=data, **kwargs)
def create_document_by_file(
- self, file_path, original_document_id=None, extra_params: dict = None
+ self, file_path, original_document_id=None, extra_params: dict | None = None
):
"""
Create a document by file.
@@ -304,7 +309,7 @@ def create_document_by_file(
)
def update_document_by_file(
- self, document_id, file_path, extra_params: dict = None
+ self, document_id, file_path, extra_params: dict | None = None
):
"""
Update a document by file.
@@ -372,7 +377,11 @@ def delete_document(self, document_id):
return self._send_request("DELETE", url)
def list_documents(
- self, page: int = None, page_size: int = None, keyword: str = None, **kwargs
+ self,
+ page: int | None = None,
+ page_size: int | None = None,
+ keyword: str | None = None,
+ **kwargs,
):
"""
Get a list of documents in this dataset.
@@ -402,7 +411,11 @@ def add_segments(self, document_id, segments, **kwargs):
return self._send_request("POST", url, json=data, **kwargs)
def query_segments(
- self, document_id, keyword: str = None, status: str = None, **kwargs
+ self,
+ document_id,
+ keyword: str | None = None,
+ status: str | None = None,
+ **kwargs,
):
"""
Query segments in this document.
From cdaef30cc9e0ae5d944b95603842f41fe257b9d0 Mon Sep 17 00:00:00 2001
From: TinsFox
Date: Tue, 24 Dec 2024 19:13:24 +0800
Subject: [PATCH 03/65] refactor: replace div with button for better
accessibility (#12046)
---
web/app/(commonLayout)/apps/NewAppCard.tsx | 23 +++++++++++-----------
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/web/app/(commonLayout)/apps/NewAppCard.tsx b/web/app/(commonLayout)/apps/NewAppCard.tsx
index d353cf239431ff..a90af4ea85caf4 100644
--- a/web/app/(commonLayout)/apps/NewAppCard.tsx
+++ b/web/app/(commonLayout)/apps/NewAppCard.tsx
@@ -18,7 +18,6 @@ export type CreateAppCardProps = {
onSuccess?: () => void
}
-// eslint-disable-next-line react/display-name
const CreateAppCard = forwardRef(({ className, onSuccess }, ref) => {
const { t } = useTranslation()
const { onPlanInfoChanged } = useProviderContext()
@@ -44,24 +43,22 @@ const CreateAppCard = forwardRef(({ classNam
>
{t('app.createApp')}
-
setShowNewAppModal(true)}>
+
-
setShowNewAppTemplateDialog(true)}>
+
+
-
- setShowCreateFromDSLModal(true)}
- >
-
+
+
+
+
setShowNewAppModal(false)}
@@ -108,4 +105,6 @@ const CreateAppCard = forwardRef(({ classNam
)
})
+CreateAppCard.displayName = 'CreateAppCard'
export default CreateAppCard
+export { CreateAppCard }
From 49bc602fb237183de94cae25761033bd768e0109 Mon Sep 17 00:00:00 2001
From: eux
Date: Tue, 24 Dec 2024 21:58:05 +0800
Subject: [PATCH 04/65] fix: --name option for the create-tenant command does
not take effect (#11993)
---
api/commands.py | 9 +++++++--
api/services/account_service.py | 3 ++-
2 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/api/commands.py b/api/commands.py
index ad7ad972f3fd01..59dfce68e0c92f 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -561,8 +561,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
new_password = secrets.token_urlsafe(16)
# register account
- account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
-
+ account = RegisterService.register(
+ email=email,
+ name=account_name,
+ password=new_password,
+ language=language,
+ create_workspace_required=False,
+ )
TenantService.create_owner_tenant_if_not_exist(account, name)
click.echo(
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 91075ec46b16bf..2d37db391c899c 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -797,6 +797,7 @@ def register(
language: Optional[str] = None,
status: Optional[AccountStatus] = None,
is_setup: Optional[bool] = False,
+ create_workspace_required: Optional[bool] = True,
) -> Account:
db.session.begin_nested()
"""Register account"""
@@ -814,7 +815,7 @@ def register(
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
- if FeatureService.get_system_features().is_allow_create_workspace:
+ if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
From 0ea6a926c5e7d213b06c3825b29cbd4563f47ced Mon Sep 17 00:00:00 2001
From: yihong
Date: Tue, 24 Dec 2024 23:14:32 +0800
Subject: [PATCH 05/65] fix: tool can not run (#12054)
Signed-off-by: yihong0618
---
api/models/tools.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/api/models/tools.py b/api/models/tools.py
index 4151a2e9f636a0..13a112ee83b513 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -1,5 +1,5 @@
import json
-from typing import Optional
+from typing import Any, Optional
import sqlalchemy as sa
from sqlalchemy import ForeignKey, func
@@ -282,8 +282,8 @@ class ToolConversationVariables(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
- def variables(self) -> dict:
- return dict(json.loads(self.variables_str))
+ def variables(self) -> Any:
+ return json.loads(self.variables_str)
class ToolFile(db.Model): # type: ignore[name-defined]
From 7a24c957bdb3a1477e5d8e2128af25fcb91d6627 Mon Sep 17 00:00:00 2001
From: yihong
Date: Tue, 24 Dec 2024 23:14:51 +0800
Subject: [PATCH 06/65] fix: i18n error (#12052)
Signed-off-by: yihong0618
---
api/core/tools/entities/tool_entities.py | 8 +++++---
api/services/tools/tools_transform_service.py | 5 +----
2 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 260e4e457f083e..c87a90c03a6f7e 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -243,9 +243,11 @@ def get_simple_instance(
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
+ # FIXME fix the type error
if options:
- options_tool_parametor = [
- ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
+ options = [
+ ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) # type: ignore
+ for option in options # type: ignore
]
return cls(
name=name,
@@ -256,7 +258,7 @@ def get_simple_instance(
form=cls.ToolParameterForm.LLM,
llm_description=llm_description,
required=required,
- options=options_tool_parametor,
+ options=options, # type: ignore
)
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index b501554bcd091d..6e3a45be0da1c4 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -275,10 +275,7 @@ def tool_to_user_tool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
- description=I18nObject(
- en_US=tool.description.human if tool.description else "",
- zh_Hans=tool.description.human if tool.description else "",
- ),
+ description=tool.description.human if tool.description else "", # type: ignore
parameters=current_parameters,
labels=labels,
)
From 7da4fb68da065385b78db63ee51183615f74b351 Mon Sep 17 00:00:00 2001
From: yihong
Date: Wed, 25 Dec 2024 08:42:52 +0800
Subject: [PATCH 07/65] fix: can not find model bug (#12051)
Signed-off-by: yihong0618
---
api/core/entities/provider_configuration.py | 2 +-
api/services/entities/model_provider_entities.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 2e27b362d3092c..bff5a0ec9c6be7 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -872,7 +872,7 @@ def _get_system_provider_models(
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for model in provider_models:
- if model.model_type == ModelType.LLM and m.model not in restrict_model_names:
+ if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
model.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
model.status = ModelStatus.QUOTA_EXCEEDED
diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py
index 334d009ee5f79f..f1417c6cb94b80 100644
--- a/api/services/entities/model_provider_entities.py
+++ b/api/services/entities/model_provider_entities.py
@@ -7,7 +7,6 @@
from core.entities.model_entities import (
ModelWithProviderEntity,
ProviderModelWithStatusEntity,
- SimpleModelProviderEntity,
)
from core.entities.provider_entities import QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject
@@ -152,7 +151,8 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
Model with provider entity.
"""
- provider: SimpleModelProviderEntity
+ # FIXME type error ignore here
+ provider: SimpleProviderEntityResponse # type: ignore
def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.model_dump())
From 1d3f218662527db991e2b520e961ec3fa07603df Mon Sep 17 00:00:00 2001
From: yihong
Date: Wed, 25 Dec 2024 10:57:52 +0800
Subject: [PATCH 08/65] fix: like failed close #12057 (#12058)
Signed-off-by: yihong0618
---
api/controllers/console/explore/message.py | 2 +-
api/controllers/service_api/app/message.py | 2 +-
api/controllers/web/message.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index c3488de29929c9..690297048eb55c 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -69,7 +69,7 @@ def post(self, installed_app, message_id):
args = parser.parse_args()
try:
- MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"])
+ MessageService.create_feedback(app_model, message_id, current_user, args.get("rating"), args.get("content"))
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index 522c7509b9849d..bed89a99a58683 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -108,7 +108,7 @@ def post(self, app_model: App, end_user: EndUser, message_id):
args = parser.parse_args()
try:
- MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
+ MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content"))
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 0f47e643708570..b636e6be620ec6 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -108,7 +108,7 @@ def post(self, app_model, end_user, message_id):
args = parser.parse_args()
try:
- MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
+ MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content"))
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
From 3ea54e9d2574446664a137280875ea3ec04ba7a7 Mon Sep 17 00:00:00 2001
From: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com>
Date: Wed, 25 Dec 2024 12:00:45 +0900
Subject: [PATCH 09/65] =?UTF-8?q?fix:=20update=20S3=20and=20Azure=20config?=
=?UTF-8?q?uration=20typos=20in=20.env.example=20and=20corr=E2=80=A6=20(#1?=
=?UTF-8?q?2055)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
api/.env.example | 6 +++---
api/.ruff.toml | 2 +-
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/api/.env.example b/api/.env.example
index 071a200e680278..cc3e868717e2fb 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -65,7 +65,7 @@ OPENDAL_FS_ROOT=storage
# S3 Storage configuration
S3_USE_AWS_MANAGED_IAM=false
-S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
+S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com
S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
@@ -74,7 +74,7 @@ S3_REGION=your-region
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
-AZURE_BLOB_CONTAINER_NAME=yout-container-name
+AZURE_BLOB_CONTAINER_NAME=your-container-name
AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net
# Aliyun oss Storage configuration
@@ -88,7 +88,7 @@ ALIYUN_OSS_REGION=your-region
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
-GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
+GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
# Tencent COS Storage configuration
diff --git a/api/.ruff.toml b/api/.ruff.toml
index 26a1b977a9f6ac..f30275a943d806 100644
--- a/api/.ruff.toml
+++ b/api/.ruff.toml
@@ -67,7 +67,7 @@ ignore = [
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
- "SIM113", # eumerate-for-loop
+ "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
]
From c98d91e44d75cf03f395eee521e5af9a36a45ad8 Mon Sep 17 00:00:00 2001
From: jiangbo721 <365065261@qq.com>
Date: Wed, 25 Dec 2024 13:29:43 +0800
Subject: [PATCH 10/65] fix: o1 model error, use max_completion_tokens instead
of max_tokens. (#12037)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: 刘江波
---
.../model_providers/azure_openai/llm/llm.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
index c5d7a83a4ee69f..03818741f65875 100644
--- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
@@ -113,7 +113,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
- if "o1" in model:
+ if model.startswith("o1"):
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
@@ -311,7 +311,10 @@ def _chat_generate(
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
block_as_stream = False
- if "o1" in model:
+ if model.startswith("o1"):
+ if "max_tokens" in model_parameters:
+ model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
+ del model_parameters["max_tokens"]
if stream:
block_as_stream = True
stream = False
@@ -404,7 +407,7 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp
]
)
- if "o1" in model:
+ if model.startswith("o1"):
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
if system_message_count > 0:
new_prompt_messages = []
From b281a80150139d419e43df3ee08286aa4c4f6513 Mon Sep 17 00:00:00 2001
From: marvin-season <64943287+marvin-season@users.noreply.github.com>
Date: Wed, 25 Dec 2024 13:30:51 +0800
Subject: [PATCH 11/65] fix: zoom in/out click (#12056)
Co-authored-by: marvin
---
.../components/workflow/operator/zoom-in-out.tsx | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/web/app/components/workflow/operator/zoom-in-out.tsx b/web/app/components/workflow/operator/zoom-in-out.tsx
index 6c4bed3751088f..90b5b46256300c 100644
--- a/web/app/components/workflow/operator/zoom-in-out.tsx
+++ b/web/app/components/workflow/operator/zoom-in-out.tsx
@@ -129,7 +129,7 @@ const ZoomInOut: FC = () => {
crossAxis: -2,
}}
>
-
+
{
shortcuts={['ctrl', '-']}
>
{
+ if (zoom <= 0.25)
+ return
+
e.stopPropagation()
zoomOut()
}}
@@ -153,14 +156,17 @@ const ZoomInOut: FC = () => {
-
{parseFloat(`${zoom * 100}`).toFixed(0)}%
+
{parseFloat(`${zoom * 100}`).toFixed(0)}%
= 2 ? 'cursor-not-allowed' : 'cursor-pointer hover:bg-black/5'}`}
onClick={(e) => {
+ if (zoom >= 2)
+ return
+
e.stopPropagation()
zoomIn()
}}
From 83ea931e3cfc4467003b93949f86281052325902 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Wed, 25 Dec 2024 16:24:52 +0800
Subject: [PATCH 12/65] refactor: optimize database usage (#12071)
Signed-off-by: -LAN-
---
.../advanced_chat/generate_task_pipeline.py | 352 +++++++++---------
.../app/apps/message_based_app_generator.py | 1 -
.../apps/workflow/generate_task_pipeline.py | 192 +++++-----
.../based_generate_task_pipeline.py | 36 +-
.../easy_ui_based_generate_task_pipeline.py | 145 ++++----
.../app/task_pipeline/message_cycle_manage.py | 4 +-
.../task_pipeline/workflow_cycle_manage.py | 182 +++++----
api/core/ops/ops_trace_manager.py | 186 ++++-----
api/core/ops/utils.py | 2 +-
api/models/account.py | 3 +-
api/models/model.py | 10 +-
api/models/workflow.py | 48 +--
12 files changed, 587 insertions(+), 574 deletions(-)
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 1073a0f2e4f706..691d178ba2aefa 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -5,6 +5,9 @@
from threading import Thread
from typing import Any, Optional, Union
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -79,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
- _workflow: Workflow
- _user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None
@@ -96,32 +97,35 @@ def __init__(
stream: bool,
dialogue_count: int,
) -> None:
- """
- Initialize AdvancedChatAppGenerateTaskPipeline.
- :param application_generate_entity: application generate entity
- :param workflow: workflow
- :param queue_manager: queue manager
- :param conversation: conversation
- :param message: message
- :param user: user
- :param stream: stream
- :param dialogue_count: dialogue count
- """
- super().__init__(application_generate_entity, queue_manager, user, stream)
+ super().__init__(
+ application_generate_entity=application_generate_entity,
+ queue_manager=queue_manager,
+ stream=stream,
+ )
- if isinstance(self._user, EndUser):
- user_id = self._user.session_id
+ if isinstance(user, EndUser):
+ self._user_id = user.session_id
+ self._created_by_role = CreatedByRole.END_USER
+ elif isinstance(user, Account):
+ self._user_id = user.id
+ self._created_by_role = CreatedByRole.ACCOUNT
else:
- user_id = self._user.id
+ raise NotImplementedError(f"User type not supported: {type(user)}")
+
+ self._workflow_id = workflow.id
+ self._workflow_features_dict = workflow.features_dict
+
+ self._conversation_id = conversation.id
+ self._conversation_mode = conversation.mode
+
+ self._message_id = message.id
+ self._message_created_at = int(message.created_at.timestamp())
- self._workflow = workflow
- self._conversation = conversation
- self._message = message
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
- SystemVariableKey.USER_ID: user_id,
+ SystemVariableKey.USER_ID: self._user_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
@@ -139,13 +143,9 @@ def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStrea
Process generate task pipeline.
:return:
"""
- db.session.refresh(self._workflow)
- db.session.refresh(self._user)
- db.session.close()
-
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
- self._conversation, self._application_generate_entity.query
+ conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -171,12 +171,12 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
data=ChatbotAppBlockingResponse.Data(
- id=self._message.id,
- mode=self._conversation.mode,
- conversation_id=self._conversation.id,
- message_id=self._message.id,
+ id=self._message_id,
+ mode=self._conversation_mode,
+ conversation_id=self._conversation_id,
+ message_id=self._message_id,
answer=self._task_state.answer,
- created_at=int(self._message.created_at.timestamp()),
+ created_at=self._message_created_at,
**extras,
),
)
@@ -194,9 +194,9 @@ def _to_stream_response(
"""
for stream_response in generator:
yield ChatbotAppStreamResponse(
- conversation_id=self._conversation.id,
- message_id=self._message.id,
- created_at=int(self._message.created_at.timestamp()),
+ conversation_id=self._conversation_id,
+ message_id=self._message_id,
+ created_at=self._message_created_at,
stream_response=stream_response,
)
@@ -214,7 +214,7 @@ def _wrapper_process_stream_response(
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
- features_dict = self._workflow.features_dict
+ features_dict = self._workflow_features_dict
if (
features_dict.get("text_to_speech")
@@ -274,26 +274,33 @@ def _process_stream_response(
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
- err = self._handle_error(event, self._message)
+ with Session(db.engine) as session:
+ err = self._handle_error(event=event, session=session, message_id=self._message_id)
+ session.commit()
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
- # init workflow run
- workflow_run = self._handle_workflow_run_start()
-
- self._refetch_message()
- self._message.workflow_run_id = workflow_run.id
-
- db.session.commit()
- db.session.refresh(self._message)
- db.session.close()
-
- yield self._workflow_start_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
- )
+ with Session(db.engine) as session:
+ # init workflow run
+ workflow_run = self._handle_workflow_run_start(
+ session=session,
+ workflow_id=self._workflow_id,
+ user_id=self._user_id,
+ created_by_role=self._created_by_role,
+ )
+ message = self._get_message(session=session)
+ if not message:
+ raise ValueError(f"Message not found: {self._message_id}")
+ message.workflow_run_id = workflow_run.id
+ session.commit()
+
+ workflow_start_resp = self._workflow_start_to_stream_response(
+ session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+ )
+ yield workflow_start_resp
elif isinstance(
event,
QueueNodeRetryEvent,
@@ -304,28 +311,28 @@ def _process_stream_response(
workflow_run=workflow_run, event=event
)
- response = self._workflow_node_retry_to_stream_response(
+ node_retry_resp = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
- if response:
- yield response
+ if node_retry_resp:
+ yield node_retry_resp
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
- response_start = self._workflow_node_start_to_stream_response(
+ node_start_resp = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
- if response_start:
- yield response_start
+ if node_start_resp:
+ yield node_start_resp
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
@@ -333,25 +340,24 @@ def _process_stream_response(
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
- response_finish = self._workflow_node_finish_to_stream_response(
+ node_finish_resp = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
- if response_finish:
- yield response_finish
+ if node_finish_resp:
+ yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
- response_finish = self._workflow_node_finish_to_stream_response(
+ node_finish_resp = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
-
- if response:
- yield response
+ if node_finish_resp:
+ yield node_finish_resp
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
@@ -395,20 +401,24 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("workflow run not initialized.")
- workflow_run = self._handle_workflow_run_success(
- workflow_run=workflow_run,
- start_at=graph_runtime_state.start_at,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- conversation_id=self._conversation.id,
- trace_manager=trace_manager,
- )
+ with Session(db.engine) as session:
+ workflow_run = self._handle_workflow_run_success(
+ session=session,
+ workflow_run=workflow_run,
+ start_at=graph_runtime_state.start_at,
+ total_tokens=graph_runtime_state.total_tokens,
+ total_steps=graph_runtime_state.node_run_steps,
+ outputs=event.outputs,
+ conversation_id=self._conversation_id,
+ trace_manager=trace_manager,
+ )
- yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
- )
+ workflow_finish_resp = self._workflow_finish_to_stream_response(
+ session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+ )
+ session.commit()
+ yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
@@ -417,21 +427,25 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
- workflow_run = self._handle_workflow_run_partial_success(
- workflow_run=workflow_run,
- start_at=graph_runtime_state.start_at,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- exceptions_count=event.exceptions_count,
- conversation_id=None,
- trace_manager=trace_manager,
- )
+ with Session(db.engine) as session:
+ workflow_run = self._handle_workflow_run_partial_success(
+ session=session,
+ workflow_run=workflow_run,
+ start_at=graph_runtime_state.start_at,
+ total_tokens=graph_runtime_state.total_tokens,
+ total_steps=graph_runtime_state.node_run_steps,
+ outputs=event.outputs,
+ exceptions_count=event.exceptions_count,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ )
- yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
- )
+ workflow_finish_resp = self._workflow_finish_to_stream_response(
+ session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+ )
+ session.commit()
+ yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
@@ -440,71 +454,73 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
- workflow_run = self._handle_workflow_run_failed(
- workflow_run=workflow_run,
- start_at=graph_runtime_state.start_at,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowRunStatus.FAILED,
- error=event.error,
- conversation_id=self._conversation.id,
- trace_manager=trace_manager,
- exceptions_count=event.exceptions_count,
- )
-
- yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
- )
-
- err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
- yield self._error_to_stream_response(self._handle_error(err_event, self._message))
- break
- elif isinstance(event, QueueStopEvent):
- if workflow_run and graph_runtime_state:
+ with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
+ session=session,
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowRunStatus.STOPPED,
- error=event.get_stop_reason(),
- conversation_id=self._conversation.id,
+ status=WorkflowRunStatus.FAILED,
+ error=event.error,
+ conversation_id=self._conversation_id,
trace_manager=trace_manager,
+ exceptions_count=event.exceptions_count,
)
-
- yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+ workflow_finish_resp = self._workflow_finish_to_stream_response(
+ session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
-
- # Save message
- self._save_message(graph_runtime_state=graph_runtime_state)
+ err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
+ err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
+ session.commit()
+ yield workflow_finish_resp
+ yield self._error_to_stream_response(err)
+ break
+ elif isinstance(event, QueueStopEvent):
+ if workflow_run and graph_runtime_state:
+ with Session(db.engine) as session:
+ workflow_run = self._handle_workflow_run_failed(
+ session=session,
+ workflow_run=workflow_run,
+ start_at=graph_runtime_state.start_at,
+ total_tokens=graph_runtime_state.total_tokens,
+ total_steps=graph_runtime_state.node_run_steps,
+ status=WorkflowRunStatus.STOPPED,
+ error=event.get_stop_reason(),
+ conversation_id=self._conversation_id,
+ trace_manager=trace_manager,
+ )
+
+ workflow_finish_resp = self._workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_run=workflow_run,
+ )
+ # Save message
+ self._save_message(session=session, graph_runtime_state=graph_runtime_state)
+ session.commit()
+ yield workflow_finish_resp
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
- self._refetch_message()
-
- self._message.message_metadata = (
- json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
- )
-
- db.session.commit()
- db.session.refresh(self._message)
- db.session.close()
+ with Session(db.engine) as session:
+ message = self._get_message(session=session)
+ message.message_metadata = (
+ json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+ )
+ session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
- self._refetch_message()
-
- self._message.message_metadata = (
- json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
- )
-
- db.session.commit()
- db.session.refresh(self._message)
- db.session.close()
+ with Session(db.engine) as session:
+ message = self._get_message(session=session)
+ message.message_metadata = (
+ json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+ )
+ session.commit()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
@@ -521,7 +537,7 @@ def _process_stream_response(
self._task_state.answer += delta_text
yield self._message_to_stream_response(
- answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
+ answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
@@ -536,7 +552,9 @@ def _process_stream_response(
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
- self._save_message(graph_runtime_state=graph_runtime_state)
+ with Session(db.engine) as session:
+ self._save_message(session=session, graph_runtime_state=graph_runtime_state)
+ session.commit()
yield self._message_end_to_stream_response()
else:
@@ -549,54 +567,46 @@ def _process_stream_response(
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
- def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
- self._refetch_message()
-
- self._message.answer = self._task_state.answer
- self._message.provider_response_latency = time.perf_counter() - self._start_at
- self._message.message_metadata = (
+ def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
+ message = self._get_message(session=session)
+ message.answer = self._task_state.answer
+ message.provider_response_latency = time.perf_counter() - self._start_at
+ message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message_files = [
MessageFile(
- message_id=self._message.id,
+ message_id=message.id,
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
upload_file_id=file["related_id"],
created_by_role=CreatedByRole.ACCOUNT
- if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
+ if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER,
- created_by=self._message.from_account_id or self._message.from_end_user_id or "",
+ created_by=message.from_account_id or message.from_end_user_id or "",
)
for file in self._recorded_files
]
- db.session.add_all(message_files)
+ session.add_all(message_files)
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
- self._message.message_tokens = usage.prompt_tokens
- self._message.message_unit_price = usage.prompt_unit_price
- self._message.message_price_unit = usage.prompt_price_unit
- self._message.answer_tokens = usage.completion_tokens
- self._message.answer_unit_price = usage.completion_unit_price
- self._message.answer_price_unit = usage.completion_price_unit
- self._message.total_price = usage.total_price
- self._message.currency = usage.currency
-
+ message.message_tokens = usage.prompt_tokens
+ message.message_unit_price = usage.prompt_unit_price
+ message.message_price_unit = usage.prompt_price_unit
+ message.answer_tokens = usage.completion_tokens
+ message.answer_unit_price = usage.completion_unit_price
+ message.answer_price_unit = usage.completion_price_unit
+ message.total_price = usage.total_price
+ message.currency = usage.currency
self._task_state.metadata["usage"] = jsonable_encoder(usage)
else:
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
-
- db.session.commit()
-
message_was_created.send(
- self._message,
+ message,
application_generate_entity=self._application_generate_entity,
- conversation=self._conversation,
- is_first_message=self._application_generate_entity.conversation_id is None,
- extras=self._application_generate_entity.extras,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -613,7 +623,7 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
- id=self._message.id,
+ id=self._message_id,
files=self._recorded_files,
metadata=extras.get("metadata", {}),
)
@@ -641,11 +651,9 @@ def _handle_output_moderation_chunk(self, text: str) -> bool:
return False
- def _refetch_message(self) -> None:
- """
- Refetch message.
- :return:
- """
- message = db.session.query(Message).filter(Message.id == self._message.id).first()
- if message:
- self._message = message
+ def _get_message(self, *, session: Session):
+ stmt = select(Message).where(Message.id == self._message_id)
+ message = session.scalar(stmt)
+ if not message:
+ raise ValueError(f"Message not found: {self._message_id}")
+ return message
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index c2e35faf89ba15..dcd9463b8abd0f 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -70,7 +70,6 @@ def _handle_response(
queue_manager=queue_manager,
conversation=conversation,
message=message,
- user=user,
stream=stream,
)
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index c47b38f5600f4d..574596d4f5a77c 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -3,6 +3,8 @@
from collections.abc import Generator
from typing import Any, Optional, Union
+from sqlalchemy.orm import Session
+
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -50,6 +52,7 @@
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
+from models.enums import CreatedByRole
from models.model import EndUser
from models.workflow import (
Workflow,
@@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
- _workflow: Workflow
- _user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
@@ -83,25 +84,27 @@ def __init__(
user: Union[Account, EndUser],
stream: bool,
) -> None:
- """
- Initialize GenerateTaskPipeline.
- :param application_generate_entity: application generate entity
- :param workflow: workflow
- :param queue_manager: queue manager
- :param user: user
- :param stream: is streamed
- """
- super().__init__(application_generate_entity, queue_manager, user, stream)
+ super().__init__(
+ application_generate_entity=application_generate_entity,
+ queue_manager=queue_manager,
+ stream=stream,
+ )
- if isinstance(self._user, EndUser):
- user_id = self._user.session_id
+ if isinstance(user, EndUser):
+ self._user_id = user.session_id
+ self._created_by_role = CreatedByRole.END_USER
+ elif isinstance(user, Account):
+ self._user_id = user.id
+ self._created_by_role = CreatedByRole.ACCOUNT
else:
- user_id = self._user.id
+ raise ValueError(f"Invalid user type: {type(user)}")
+
+ self._workflow_id = workflow.id
+ self._workflow_features_dict = workflow.features_dict
- self._workflow = workflow
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
- SystemVariableKey.USER_ID: user_id,
+ SystemVariableKey.USER_ID: self._user_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
@@ -115,10 +118,6 @@ def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStr
Process generate task pipeline.
:return:
"""
- db.session.refresh(self._workflow)
- db.session.refresh(self._user)
- db.session.close()
-
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
@@ -185,7 +184,7 @@ def _wrapper_process_stream_response(
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
- features_dict = self._workflow.features_dict
+ features_dict = self._workflow_features_dict
if (
features_dict.get("text_to_speech")
@@ -242,18 +241,26 @@ def _process_stream_response(
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
- err = self._handle_error(event)
+ err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
- # init workflow run
- workflow_run = self._handle_workflow_run_start()
- yield self._workflow_start_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
- )
+ with Session(db.engine) as session:
+ # init workflow run
+ workflow_run = self._handle_workflow_run_start(
+ session=session,
+ workflow_id=self._workflow_id,
+ user_id=self._user_id,
+ created_by_role=self._created_by_role,
+ )
+ start_resp = self._workflow_start_to_stream_response(
+ session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+ )
+ session.commit()
+ yield start_resp
elif isinstance(
event,
QueueNodeRetryEvent,
@@ -350,22 +357,28 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
- workflow_run = self._handle_workflow_run_success(
- workflow_run=workflow_run,
- start_at=graph_runtime_state.start_at,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- conversation_id=None,
- trace_manager=trace_manager,
- )
-
- # save workflow app log
- self._save_workflow_app_log(workflow_run)
-
- yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
- )
+ with Session(db.engine) as session:
+ workflow_run = self._handle_workflow_run_success(
+ session=session,
+ workflow_run=workflow_run,
+ start_at=graph_runtime_state.start_at,
+ total_tokens=graph_runtime_state.total_tokens,
+ total_steps=graph_runtime_state.node_run_steps,
+ outputs=event.outputs,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ )
+
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+ workflow_finish_resp = self._workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_run=workflow_run,
+ )
+ session.commit()
+ yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")
@@ -373,49 +386,58 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
- workflow_run = self._handle_workflow_run_partial_success(
- workflow_run=workflow_run,
- start_at=graph_runtime_state.start_at,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- exceptions_count=event.exceptions_count,
- conversation_id=None,
- trace_manager=trace_manager,
- )
-
- # save workflow app log
- self._save_workflow_app_log(workflow_run)
-
- yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
- )
+ with Session(db.engine) as session:
+ workflow_run = self._handle_workflow_run_partial_success(
+ session=session,
+ workflow_run=workflow_run,
+ start_at=graph_runtime_state.start_at,
+ total_tokens=graph_runtime_state.total_tokens,
+ total_steps=graph_runtime_state.node_run_steps,
+ outputs=event.outputs,
+ exceptions_count=event.exceptions_count,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ )
+
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+ workflow_finish_resp = self._workflow_finish_to_stream_response(
+ session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+ )
+ session.commit()
+
+ yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
- workflow_run = self._handle_workflow_run_failed(
- workflow_run=workflow_run,
- start_at=graph_runtime_state.start_at,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowRunStatus.FAILED
- if isinstance(event, QueueWorkflowFailedEvent)
- else WorkflowRunStatus.STOPPED,
- error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
- conversation_id=None,
- trace_manager=trace_manager,
- exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
- )
-
- # save workflow app log
- self._save_workflow_app_log(workflow_run)
-
- yield self._workflow_finish_to_stream_response(
- task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
- )
+ with Session(db.engine) as session:
+ workflow_run = self._handle_workflow_run_failed(
+ session=session,
+ workflow_run=workflow_run,
+ start_at=graph_runtime_state.start_at,
+ total_tokens=graph_runtime_state.total_tokens,
+ total_steps=graph_runtime_state.node_run_steps,
+ status=WorkflowRunStatus.FAILED
+ if isinstance(event, QueueWorkflowFailedEvent)
+ else WorkflowRunStatus.STOPPED,
+ error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+ conversation_id=None,
+ trace_manager=trace_manager,
+ exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
+ )
+
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+ workflow_finish_resp = self._workflow_finish_to_stream_response(
+ session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+ )
+ session.commit()
+ yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
@@ -435,7 +457,7 @@ def _process_stream_response(
if tts_publisher:
tts_publisher.publish(None)
- def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
+ def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
@@ -457,12 +479,10 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
- workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
- workflow_app_log.created_by = self._user.id
+ workflow_app_log.created_by_role = self._created_by_role
+ workflow_app_log.created_by = self._user_id
- db.session.add(workflow_app_log)
- db.session.commit()
- db.session.close()
+ session.add(workflow_app_log)
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py
index 03a81353d02625..e363a7f64244d3 100644
--- a/api/core/app/task_pipeline/based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -1,6 +1,9 @@
import logging
import time
-from typing import Optional, Union
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
@@ -17,9 +20,7 @@
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
-from extensions.ext_database import db
-from models.account import Account
-from models.model import EndUser, Message
+from models.model import Message
logger = logging.getLogger(__name__)
@@ -36,7 +37,6 @@ def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
- user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
@@ -48,18 +48,11 @@ def __init__(
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
- self._user = user
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
- def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
- """
- Handle error event.
- :param event: event
- :param message: message
- :return:
- """
+ def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
@@ -71,16 +64,17 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non
else:
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
- if message:
- refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
-
- if refetch_message:
- err_desc = self._error_to_desc(err)
- refetch_message.status = "error"
- refetch_message.error = err_desc
+ if not message_id or not session:
+ return err
- db.session.commit()
+ stmt = select(Message).where(Message.id == message_id)
+ message = session.scalar(stmt)
+ if not message:
+ return err
+ err_desc = self._error_to_desc(err)
+ message.status = "error"
+ message.error = err_desc
return err
def _error_to_desc(self, e: Exception) -> str:
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index b9f8e7ca560ce7..c84f8ba3e450cc 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -5,6 +5,9 @@
from threading import Thread
from typing import Optional, Union, cast
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -55,8 +58,7 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
-from models.account import Account
-from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought
+from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
@@ -77,23 +79,21 @@ def __init__(
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
- user: Union[Account, EndUser],
stream: bool,
) -> None:
- """
- Initialize GenerateTaskPipeline.
- :param application_generate_entity: application generate entity
- :param queue_manager: queue manager
- :param conversation: conversation
- :param message: message
- :param user: user
- :param stream: stream
- """
- super().__init__(application_generate_entity, queue_manager, user, stream)
+ super().__init__(
+ application_generate_entity=application_generate_entity,
+ queue_manager=queue_manager,
+ stream=stream,
+ )
self._model_config = application_generate_entity.model_conf
self._app_config = application_generate_entity.app_config
- self._conversation = conversation
- self._message = message
+
+ self._conversation_id = conversation.id
+ self._conversation_mode = conversation.mode
+
+ self._message_id = message.id
+ self._message_created_at = int(message.created_at.timestamp())
self._task_state = EasyUITaskState(
llm_result=LLMResult(
@@ -113,18 +113,10 @@ def process(
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
- """
- Process generate task pipeline.
- :return:
- """
- db.session.refresh(self._conversation)
- db.session.refresh(self._message)
- db.session.close()
-
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
- self._conversation, self._application_generate_entity.query or ""
+ conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -148,15 +140,15 @@ def _to_blocking_response(
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
- if self._conversation.mode == AppMode.COMPLETION.value:
+ if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(
- id=self._message.id,
- mode=self._conversation.mode,
- message_id=self._message.id,
+ id=self._message_id,
+ mode=self._conversation_mode,
+ message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
- created_at=int(self._message.created_at.timestamp()),
+ created_at=self._message_created_at,
**extras,
),
)
@@ -164,12 +156,12 @@ def _to_blocking_response(
response = ChatbotAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=ChatbotAppBlockingResponse.Data(
- id=self._message.id,
- mode=self._conversation.mode,
- conversation_id=self._conversation.id,
- message_id=self._message.id,
+ id=self._message_id,
+ mode=self._conversation_mode,
+ conversation_id=self._conversation_id,
+ message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
- created_at=int(self._message.created_at.timestamp()),
+ created_at=self._message_created_at,
**extras,
),
)
@@ -190,15 +182,15 @@ def _to_stream_response(
for stream_response in generator:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
yield CompletionAppStreamResponse(
- message_id=self._message.id,
- created_at=int(self._message.created_at.timestamp()),
+ message_id=self._message_id,
+ created_at=self._message_created_at,
stream_response=stream_response,
)
else:
yield ChatbotAppStreamResponse(
- conversation_id=self._conversation.id,
- message_id=self._message.id,
- created_at=int(self._message.created_at.timestamp()),
+ conversation_id=self._conversation_id,
+ message_id=self._message_id,
+ created_at=self._message_created_at,
stream_response=stream_response,
)
@@ -265,7 +257,9 @@ def _process_stream_response(
event = message.event
if isinstance(event, QueueErrorEvent):
- err = self._handle_error(event, self._message)
+ with Session(db.engine) as session:
+ err = self._handle_error(event=event, session=session, message_id=self._message_id)
+ session.commit()
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@@ -283,10 +277,12 @@ def _process_stream_response(
self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
- # Save message
- self._save_message(trace_manager)
-
- yield self._message_end_to_stream_response()
+ with Session(db.engine) as session:
+ # Save message
+ self._save_message(session=session, trace_manager=trace_manager)
+ session.commit()
+ message_end_resp = self._message_end_to_stream_response()
+ yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
@@ -320,9 +316,15 @@ def _process_stream_response(
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
- yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
+ yield self._message_to_stream_response(
+ answer=cast(str, delta_text),
+ message_id=self._message_id,
+ )
else:
- yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
+ yield self._agent_message_to_stream_response(
+ answer=cast(str, delta_text),
+ message_id=self._message_id,
+ )
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
@@ -334,7 +336,7 @@ def _process_stream_response(
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
- def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
+ def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
"""
Save message.
:return:
@@ -342,53 +344,46 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No
llm_result = self._task_state.llm_result
usage = llm_result.usage
- message = db.session.query(Message).filter(Message.id == self._message.id).first()
+ message_stmt = select(Message).where(Message.id == self._message_id)
+ message = session.scalar(message_stmt)
if not message:
- raise Exception(f"Message {self._message.id} not found")
- self._message = message
- conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
+ raise ValueError(f"message {self._message_id} not found")
+ conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id)
+ conversation = session.scalar(conversation_stmt)
if not conversation:
- raise Exception(f"Conversation {self._conversation.id} not found")
- self._conversation = conversation
+ raise ValueError(f"Conversation {self._conversation_id} not found")
- self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+ message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
)
- self._message.message_tokens = usage.prompt_tokens
- self._message.message_unit_price = usage.prompt_unit_price
- self._message.message_price_unit = usage.prompt_price_unit
- self._message.answer = (
+ message.message_tokens = usage.prompt_tokens
+ message.message_unit_price = usage.prompt_unit_price
+ message.message_price_unit = usage.prompt_price_unit
+ message.answer = (
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)
- self._message.answer_tokens = usage.completion_tokens
- self._message.answer_unit_price = usage.completion_unit_price
- self._message.answer_price_unit = usage.completion_price_unit
- self._message.provider_response_latency = time.perf_counter() - self._start_at
- self._message.total_price = usage.total_price
- self._message.currency = usage.currency
- self._message.message_metadata = (
+ message.answer_tokens = usage.completion_tokens
+ message.answer_unit_price = usage.completion_unit_price
+ message.answer_price_unit = usage.completion_price_unit
+ message.provider_response_latency = time.perf_counter() - self._start_at
+ message.total_price = usage.total_price
+ message.currency = usage.currency
+ message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
- db.session.commit()
-
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
- TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
+ TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)
message_was_created.send(
- self._message,
+ message,
application_generate_entity=self._application_generate_entity,
- conversation=self._conversation,
- is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
- and hasattr(self._application_generate_entity, "conversation_id")
- and self._application_generate_entity.conversation_id is None,
- extras=self._application_generate_entity.extras,
)
def _handle_stop(self, event: QueueStopEvent) -> None:
@@ -434,7 +429,7 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
- id=self._message.id,
+ id=self._message_id,
metadata=extras.get("metadata", {}),
)
diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py
index 007543f6d0d1f2..15f2c25c66a3d2 100644
--- a/api/core/app/task_pipeline/message_cycle_manage.py
+++ b/api/core/app/task_pipeline/message_cycle_manage.py
@@ -36,7 +36,7 @@ class MessageCycleManage:
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
- def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
+ def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""
Generate conversation name.
:param conversation: conversation
@@ -56,7 +56,7 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) ->
target=self._generate_conversation_name_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
- "conversation_id": conversation.id,
+ "conversation_id": conversation_id,
"query": query,
},
)
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index f581e564f224ce..2692008c6653d6 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -5,6 +5,7 @@
from typing import Any, Optional, Union, cast
from uuid import uuid4
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
@@ -63,27 +64,34 @@
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
- _workflow: Workflow
- _user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
- def _handle_workflow_run_start(self) -> WorkflowRun:
- max_sequence = (
- db.session.query(db.func.max(WorkflowRun.sequence_number))
- .filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
- .filter(WorkflowRun.app_id == self._workflow.app_id)
- .scalar()
- or 0
+ def _handle_workflow_run_start(
+ self,
+ *,
+ session: Session,
+ workflow_id: str,
+ user_id: str,
+ created_by_role: CreatedByRole,
+ ) -> WorkflowRun:
+ workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
+ workflow = session.scalar(workflow_stmt)
+ if not workflow:
+ raise ValueError(f"Workflow not found: {workflow_id}")
+
+ max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
+ WorkflowRun.tenant_id == workflow.tenant_id,
+ WorkflowRun.app_id == workflow.app_id,
)
+ max_sequence = session.scalar(max_sequence_stmt) or 0
new_sequence_number = max_sequence + 1
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == "conversation":
continue
-
inputs[f"sys.{key.value}"] = value
triggered_from = (
@@ -96,33 +104,32 @@ def _handle_workflow_run_start(self) -> WorkflowRun:
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_run = WorkflowRun()
- system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
- workflow_run.id = system_id or str(uuid4())
- workflow_run.tenant_id = self._workflow.tenant_id
- workflow_run.app_id = self._workflow.app_id
- workflow_run.sequence_number = new_sequence_number
- workflow_run.workflow_id = self._workflow.id
- workflow_run.type = self._workflow.type
- workflow_run.triggered_from = triggered_from.value
- workflow_run.version = self._workflow.version
- workflow_run.graph = self._workflow.graph
- workflow_run.inputs = json.dumps(inputs)
- workflow_run.status = WorkflowRunStatus.RUNNING
- workflow_run.created_by_role = (
- CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
- )
- workflow_run.created_by = self._user.id
- workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
-
- session.add(workflow_run)
- session.commit()
+ workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
+
+ workflow_run = WorkflowRun()
+ workflow_run.id = workflow_run_id
+ workflow_run.tenant_id = workflow.tenant_id
+ workflow_run.app_id = workflow.app_id
+ workflow_run.sequence_number = new_sequence_number
+ workflow_run.workflow_id = workflow.id
+ workflow_run.type = workflow.type
+ workflow_run.triggered_from = triggered_from.value
+ workflow_run.version = workflow.version
+ workflow_run.graph = workflow.graph
+ workflow_run.inputs = json.dumps(inputs)
+ workflow_run.status = WorkflowRunStatus.RUNNING
+ workflow_run.created_by_role = created_by_role
+ workflow_run.created_by = user_id
+ workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
+
+ session.add(workflow_run)
return workflow_run
def _handle_workflow_run_success(
self,
+ *,
+ session: Session,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
@@ -141,7 +148,7 @@ def _handle_workflow_run_success(
:param conversation_id: conversation id
:return:
"""
- workflow_run = self._refetch_workflow_run(workflow_run.id)
+ workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
@@ -152,9 +159,6 @@ def _handle_workflow_run_success(
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
- db.session.commit()
- db.session.refresh(workflow_run)
-
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
@@ -165,12 +169,12 @@ def _handle_workflow_run_success(
)
)
- db.session.close()
-
return workflow_run
def _handle_workflow_run_partial_success(
self,
+ *,
+ session: Session,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
@@ -190,7 +194,7 @@ def _handle_workflow_run_partial_success(
:param conversation_id: conversation id
:return:
"""
- workflow_run = self._refetch_workflow_run(workflow_run.id)
+ workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
@@ -201,8 +205,6 @@ def _handle_workflow_run_partial_success(
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
- db.session.commit()
- db.session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
@@ -214,12 +216,12 @@ def _handle_workflow_run_partial_success(
)
)
- db.session.close()
-
return workflow_run
def _handle_workflow_run_failed(
self,
+ *,
+ session: Session,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
@@ -240,7 +242,7 @@ def _handle_workflow_run_failed(
:param error: error message
:return:
"""
- workflow_run = self._refetch_workflow_run(workflow_run.id)
+ workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
workflow_run.status = status.value
workflow_run.error = error
@@ -249,21 +251,18 @@ def _handle_workflow_run_failed(
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
- db.session.commit()
- running_workflow_node_executions = (
- db.session.query(WorkflowNodeExecution)
- .filter(
- WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
- WorkflowNodeExecution.app_id == workflow_run.app_id,
- WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
- WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
- WorkflowNodeExecution.workflow_run_id == workflow_run.id,
- WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
- )
- .all()
+ stmt = select(WorkflowNodeExecution).where(
+ WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
+ WorkflowNodeExecution.app_id == workflow_run.app_id,
+ WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+ WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+ WorkflowNodeExecution.workflow_run_id == workflow_run.id,
+ WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
+ running_workflow_node_executions = session.scalars(stmt).all()
+
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
@@ -271,13 +270,6 @@ def _handle_workflow_run_failed(
workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
- db.session.commit()
-
- db.session.close()
-
- # with Session(db.engine, expire_on_commit=False) as session:
- # session.add(workflow_run)
- # session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
@@ -485,14 +477,14 @@ def _handle_workflow_node_execution_retried(
#################################################
def _workflow_start_to_stream_response(
- self, task_id: str, workflow_run: WorkflowRun
+ self,
+ *,
+ session: Session,
+ task_id: str,
+ workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse:
- """
- Workflow start to stream response.
- :param task_id: task id
- :param workflow_run: workflow run
- :return:
- """
+ # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+ _ = session
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
@@ -506,36 +498,32 @@ def _workflow_start_to_stream_response(
)
def _workflow_finish_to_stream_response(
- self, task_id: str, workflow_run: WorkflowRun
+ self,
+ *,
+ session: Session,
+ task_id: str,
+ workflow_run: WorkflowRun,
) -> WorkflowFinishStreamResponse:
- """
- Workflow finish to stream response.
- :param task_id: task id
- :param workflow_run: workflow run
- :return:
- """
- # Attach WorkflowRun to an active session so "created_by_role" can be accessed.
- workflow_run = db.session.merge(workflow_run)
-
- # Refresh to ensure any expired attributes are fully loaded
- db.session.refresh(workflow_run)
-
created_by = None
- if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
- created_by_account = workflow_run.created_by_account
- if created_by_account:
+ if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
+ stmt = select(Account).where(Account.id == workflow_run.created_by)
+ account = session.scalar(stmt)
+ if account:
created_by = {
- "id": created_by_account.id,
- "name": created_by_account.name,
- "email": created_by_account.email,
+ "id": account.id,
+ "name": account.name,
+ "email": account.email,
}
- else:
- created_by_end_user = workflow_run.created_by_end_user
- if created_by_end_user:
+ elif workflow_run.created_by_role == CreatedByRole.END_USER:
+ stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
+ end_user = session.scalar(stmt)
+ if end_user:
created_by = {
- "id": created_by_end_user.id,
- "user": created_by_end_user.session_id,
+ "id": end_user.id,
+ "user": end_user.session_id,
}
+ else:
+ raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
return WorkflowFinishStreamResponse(
task_id=task_id,
@@ -895,14 +883,14 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any
return None
- def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
+ def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
- workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
-
+ stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
+ workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index f538eaef5bd570..691cb8d400964c 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -9,6 +9,8 @@
from uuid import UUID, uuid4
from flask import current_app
+from sqlalchemy import select
+from sqlalchemy.orm import Session
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
@@ -329,15 +331,15 @@ def __init__(
):
self.trace_type = trace_type
self.message_id = message_id
- self.workflow_run = workflow_run
+ self.workflow_run_id = workflow_run.id if workflow_run else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer
- self.kwargs = kwargs
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
-
self.app_id = None
+ self.kwargs = kwargs
+
def execute(self):
return self.preprocess()
@@ -345,19 +347,23 @@ def preprocess(self):
preprocess_map = {
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
- self.workflow_run, self.conversation_id, self.user_id
+ workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
+ ),
+ TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
+ TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
+ message_id=self.message_id, timer=self.timer, **self.kwargs
),
- TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
- TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
- self.message_id, self.timer, **self.kwargs
+ message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
- self.message_id, self.timer, **self.kwargs
+ message_id=self.message_id, timer=self.timer, **self.kwargs
+ ),
+ TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
+ message_id=self.message_id, timer=self.timer, **self.kwargs
),
- TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
- self.conversation_id, self.timer, **self.kwargs
+ conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
),
}
@@ -367,86 +373,100 @@ def preprocess(self):
def conversation_trace(self, **kwargs):
return kwargs
- def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
- if not workflow_run:
- raise ValueError("Workflow run not found")
-
- db.session.merge(workflow_run)
- db.session.refresh(workflow_run)
-
- workflow_id = workflow_run.workflow_id
- tenant_id = workflow_run.tenant_id
- workflow_run_id = workflow_run.id
- workflow_run_elapsed_time = workflow_run.elapsed_time
- workflow_run_status = workflow_run.status
- workflow_run_inputs = workflow_run.inputs_dict
- workflow_run_outputs = workflow_run.outputs_dict
- workflow_run_version = workflow_run.version
- error = workflow_run.error or ""
-
- total_tokens = workflow_run.total_tokens
-
- file_list = workflow_run_inputs.get("sys.file") or []
- query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
-
- # get workflow_app_log_id
- workflow_app_log_data = (
- db.session.query(WorkflowAppLog)
- .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id)
- .first()
- )
- workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
- # get message_id
- message_data = (
- db.session.query(Message.id)
- .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id)
- .first()
- )
- message_id = str(message_data.id) if message_data else None
-
- metadata = {
- "workflow_id": workflow_id,
- "conversation_id": conversation_id,
- "workflow_run_id": workflow_run_id,
- "tenant_id": tenant_id,
- "elapsed_time": workflow_run_elapsed_time,
- "status": workflow_run_status,
- "version": workflow_run_version,
- "total_tokens": total_tokens,
- "file_list": file_list,
- "triggered_form": workflow_run.triggered_from,
- "user_id": user_id,
- }
+ def workflow_trace(
+ self,
+ *,
+ workflow_run_id: str | None,
+ conversation_id: str | None,
+ user_id: str | None,
+ ):
+ if not workflow_run_id:
+ return {}
- workflow_trace_info = WorkflowTraceInfo(
- workflow_data=workflow_run.to_dict(),
- conversation_id=conversation_id,
- workflow_id=workflow_id,
- tenant_id=tenant_id,
- workflow_run_id=workflow_run_id,
- workflow_run_elapsed_time=workflow_run_elapsed_time,
- workflow_run_status=workflow_run_status,
- workflow_run_inputs=workflow_run_inputs,
- workflow_run_outputs=workflow_run_outputs,
- workflow_run_version=workflow_run_version,
- error=error,
- total_tokens=total_tokens,
- file_list=file_list,
- query=query,
- metadata=metadata,
- workflow_app_log_id=workflow_app_log_id,
- message_id=message_id,
- start_time=workflow_run.created_at,
- end_time=workflow_run.finished_at,
- )
+ with Session(db.engine) as session:
+ workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
+ workflow_run = session.scalars(workflow_run_stmt).first()
+ if not workflow_run:
+ raise ValueError("Workflow run not found")
+
+ workflow_id = workflow_run.workflow_id
+ tenant_id = workflow_run.tenant_id
+ workflow_run_id = workflow_run.id
+ workflow_run_elapsed_time = workflow_run.elapsed_time
+ workflow_run_status = workflow_run.status
+ workflow_run_inputs = workflow_run.inputs_dict
+ workflow_run_outputs = workflow_run.outputs_dict
+ workflow_run_version = workflow_run.version
+ error = workflow_run.error or ""
+
+ total_tokens = workflow_run.total_tokens
+
+ file_list = workflow_run_inputs.get("sys.file") or []
+ query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
+
+ # get workflow_app_log_id
+ workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
+ WorkflowAppLog.tenant_id == tenant_id,
+ WorkflowAppLog.app_id == workflow_run.app_id,
+ WorkflowAppLog.workflow_run_id == workflow_run.id,
+ )
+ workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
+ # get message_id
+ message_id = None
+ if conversation_id:
+ message_data_stmt = select(Message.id).where(
+ Message.conversation_id == conversation_id,
+ Message.workflow_run_id == workflow_run_id,
+ )
+ message_id = session.scalar(message_data_stmt)
+
+ metadata = {
+ "workflow_id": workflow_id,
+ "conversation_id": conversation_id,
+ "workflow_run_id": workflow_run_id,
+ "tenant_id": tenant_id,
+ "elapsed_time": workflow_run_elapsed_time,
+ "status": workflow_run_status,
+ "version": workflow_run_version,
+ "total_tokens": total_tokens,
+ "file_list": file_list,
+ "triggered_form": workflow_run.triggered_from,
+ "user_id": user_id,
+ }
+ workflow_trace_info = WorkflowTraceInfo(
+ workflow_data=workflow_run.to_dict(),
+ conversation_id=conversation_id,
+ workflow_id=workflow_id,
+ tenant_id=tenant_id,
+ workflow_run_id=workflow_run_id,
+ workflow_run_elapsed_time=workflow_run_elapsed_time,
+ workflow_run_status=workflow_run_status,
+ workflow_run_inputs=workflow_run_inputs,
+ workflow_run_outputs=workflow_run_outputs,
+ workflow_run_version=workflow_run_version,
+ error=error,
+ total_tokens=total_tokens,
+ file_list=file_list,
+ query=query,
+ metadata=metadata,
+ workflow_app_log_id=workflow_app_log_id,
+ message_id=message_id,
+ start_time=workflow_run.created_at,
+ end_time=workflow_run.finished_at,
+ )
return workflow_trace_info
- def message_trace(self, message_id):
+ def message_trace(self, message_id: str | None):
+ if not message_id:
+ return {}
message_data = get_message_data(message_id)
if not message_data:
return {}
- conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first()
+ conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
+ conversation_mode = db.session.scalars(conversation_mode_stmt).all()
+ if not conversation_mode or len(conversation_mode) == 0:
+ return {}
conversation_mode = conversation_mode[0]
created_at = message_data.created_at
inputs = message_data.message
diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py
index 998eba9ea9791b..8b06df1930595d 100644
--- a/api/core/ops/utils.py
+++ b/api/core/ops/utils.py
@@ -18,7 +18,7 @@ def filter_none_values(data: dict):
return new_data
-def get_message_data(message_id):
+def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first()
diff --git a/api/models/account.py b/api/models/account.py
index 88c96da1a149d5..35a28df7505943 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -3,6 +3,7 @@
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
+from sqlalchemy.orm import Mapped, mapped_column
from .engine import db
from .types import StringUUID
@@ -20,7 +21,7 @@ class Account(UserMixin, db.Model): # type: ignore[name-defined]
__tablename__ = "accounts"
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name = db.Column(db.String(255), nullable=False)
email = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=True)
diff --git a/api/models/model.py b/api/models/model.py
index 2a593f08298199..d2d4d5853fd2b9 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -530,13 +530,13 @@ class Conversation(db.Model): # type: ignore[name-defined]
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
app_model_config_id = db.Column(StringUUID, nullable=True)
model_provider = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text)
model_id = db.Column(db.String(255), nullable=True)
- mode = db.Column(db.String(255), nullable=False)
+ mode: Mapped[str] = mapped_column(db.String(255))
name = db.Column(db.String(255), nullable=False)
summary = db.Column(db.Text)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
@@ -770,7 +770,7 @@ class Message(db.Model): # type: ignore[name-defined]
db.Index("message_created_at_idx", "created_at"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True)
@@ -797,7 +797,7 @@ class Message(db.Model): # type: ignore[name-defined]
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
workflow_run_id = db.Column(StringUUID)
@@ -1322,7 +1322,7 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined]
external_user_id = db.Column(db.String(255), nullable=True)
name = db.Column(db.String(255))
is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- session_id = db.Column(db.String(255), nullable=False)
+ session_id: Mapped[str] = mapped_column()
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 880e044d073a67..78a7f8169fe634 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -392,40 +392,28 @@ class WorkflowRun(db.Model): # type: ignore[name-defined]
db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- app_id = db.Column(StringUUID, nullable=False)
- sequence_number = db.Column(db.Integer, nullable=False)
- workflow_id = db.Column(StringUUID, nullable=False)
- type = db.Column(db.String(255), nullable=False)
- triggered_from = db.Column(db.String(255), nullable=False)
- version = db.Column(db.String(255), nullable=False)
- graph = db.Column(db.Text)
- inputs = db.Column(db.Text)
- status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID)
+ app_id: Mapped[str] = mapped_column(StringUUID)
+ sequence_number: Mapped[int] = mapped_column()
+ workflow_id: Mapped[str] = mapped_column(StringUUID)
+ type: Mapped[str] = mapped_column(db.String(255))
+ triggered_from: Mapped[str] = mapped_column(db.String(255))
+ version: Mapped[str] = mapped_column(db.String(255))
+ graph: Mapped[str] = mapped_column(db.Text)
+ inputs: Mapped[str] = mapped_column(db.Text)
+ status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
- error = db.Column(db.Text)
+ error: Mapped[str] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
- total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
+ total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
- created_by_role = db.Column(db.String(255), nullable=False) # account, end_user
+ created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at = db.Column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0"))
- @property
- def created_by_account(self):
- created_by_role = CreatedByRole(self.created_by_role)
- return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
-
- @property
- def created_by_end_user(self):
- from models.model import EndUser
-
- created_by_role = CreatedByRole(self.created_by_role)
- return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
-
@property
def graph_dict(self):
return json.loads(self.graph) if self.graph else {}
@@ -750,11 +738,11 @@ class WorkflowAppLog(db.Model): # type: ignore[name-defined]
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- app_id = db.Column(StringUUID, nullable=False)
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID)
+ app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id = db.Column(StringUUID, nullable=False)
- workflow_run_id = db.Column(StringUUID, nullable=False)
+ workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
From 1885d3df9968d778664bacdde4ea7a4d9448070f Mon Sep 17 00:00:00 2001
From: Cemre Mengu
Date: Wed, 25 Dec 2024 11:31:01 +0300
Subject: [PATCH 13/65] fix: unquote urls in docker-compose.yaml (#12072)
Signed-off-by: -LAN-
Co-authored-by: -LAN-
---
docker/docker-compose.yaml | 56 +++++++++++++++++-----------------
docker/generate_docker_compose | 2 +-
2 files changed, 29 insertions(+), 29 deletions(-)
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 7122f4a6d0f768..e65ca45858118f 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -15,15 +15,15 @@ x-shared-env: &shared-api-worker-env
LOG_FILE: ${LOG_FILE:-/app/logs/server.log}
LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20}
LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5}
- LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"}
+ LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S}
LOG_TZ: ${LOG_TZ:-UTC}
DEBUG: ${DEBUG:-false}
FLASK_DEBUG: ${FLASK_DEBUG:-false}
SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U}
INIT_PASSWORD: ${INIT_PASSWORD:-}
DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION}
- CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-"https://updates.dify.ai"}
- OPENAI_API_BASE: ${OPENAI_API_BASE:-"https://api.openai.com/v1"}
+ CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-https://updates.dify.ai}
+ OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1}
MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true}
FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
@@ -69,7 +69,7 @@ x-shared-env: &shared-api-worker-env
REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false}
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
- CELERY_BROKER_URL: ${CELERY_BROKER_URL:-"redis://:difyai123456@redis:6379/1"}
+ CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false}
CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-}
@@ -88,13 +88,13 @@ x-shared-env: &shared-api-worker-env
AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai}
AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai}
AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container}
- AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-"https://.blob.core.windows.net"}
+ AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-https://.blob.core.windows.net}
GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-your-bucket-name}
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-your-google-service-account-json-base64-string}
ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-your-bucket-name}
ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-your-access-key}
ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-your-secret-key}
- ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-"https://oss-ap-southeast-1-internal.aliyuncs.com"}
+ ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-https://oss-ap-southeast-1-internal.aliyuncs.com}
ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1}
ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path}
@@ -103,7 +103,7 @@ x-shared-env: &shared-api-worker-env
TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id}
TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region}
TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme}
- OCI_ENDPOINT: ${OCI_ENDPOINT:-"https://objectstorage.us-ashburn-1.oraclecloud.com"}
+ OCI_ENDPOINT: ${OCI_ENDPOINT:-https://objectstorage.us-ashburn-1.oraclecloud.com}
OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name}
OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key}
OCI_SECRET_KEY: ${OCI_SECRET_KEY:-your-secret-key}
@@ -125,14 +125,14 @@ x-shared-env: &shared-api-worker-env
SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key}
SUPABASE_URL: ${SUPABASE_URL:-your-server-url}
VECTOR_STORE: ${VECTOR_STORE:-weaviate}
- WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-"http://weaviate:8080"}
+ WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080}
WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}
- QDRANT_URL: ${QDRANT_URL:-"http://qdrant:6333"}
+ QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333}
QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456}
QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20}
QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false}
QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334}
- MILVUS_URI: ${MILVUS_URI:-"http://127.0.0.1:19530"}
+ MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530}
MILVUS_TOKEN: ${MILVUS_TOKEN:-}
MILVUS_USER: ${MILVUS_USER:-root}
MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus}
@@ -142,7 +142,7 @@ x-shared-env: &shared-api-worker-env
MYSCALE_PASSWORD: ${MYSCALE_PASSWORD:-}
MYSCALE_DATABASE: ${MYSCALE_DATABASE:-dify}
MYSCALE_FTS_PARAMS: ${MYSCALE_FTS_PARAMS:-}
- COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-"couchbase://couchbase-server"}
+ COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-couchbase://couchbase-server}
COUCHBASE_USER: ${COUCHBASE_USER:-Administrator}
COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password}
COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings}
@@ -176,15 +176,15 @@ x-shared-env: &shared-api-worker-env
TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-}
TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-}
TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify}
- TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-"http://127.0.0.1"}
+ TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1}
TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify}
TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20}
TIDB_ON_QDRANT_GRPC_ENABLED: ${TIDB_ON_QDRANT_GRPC_ENABLED:-false}
TIDB_ON_QDRANT_GRPC_PORT: ${TIDB_ON_QDRANT_GRPC_PORT:-6334}
TIDB_PUBLIC_KEY: ${TIDB_PUBLIC_KEY:-dify}
TIDB_PRIVATE_KEY: ${TIDB_PRIVATE_KEY:-dify}
- TIDB_API_URL: ${TIDB_API_URL:-"http://127.0.0.1"}
- TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-"http://127.0.0.1"}
+ TIDB_API_URL: ${TIDB_API_URL:-http://127.0.0.1}
+ TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-http://127.0.0.1}
TIDB_REGION: ${TIDB_REGION:-regions/aws-us-east-1}
TIDB_PROJECT_ID: ${TIDB_PROJECT_ID:-dify}
TIDB_SPEND_LIMIT: ${TIDB_SPEND_LIMIT:-100}
@@ -209,7 +209,7 @@ x-shared-env: &shared-api-worker-env
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
- TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-"http://127.0.0.1"}
+ TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}
TENCENT_VECTOR_DB_USERNAME: ${TENCENT_VECTOR_DB_USERNAME:-dify}
@@ -221,7 +221,7 @@ x-shared-env: &shared-api-worker-env
ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
KIBANA_PORT: ${KIBANA_PORT:-5601}
- BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-"http://127.0.0.1:5287"}
+ BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287}
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000}
BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root}
BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify}
@@ -235,7 +235,7 @@ x-shared-env: &shared-api-worker-env
VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http}
VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30}
VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30}
- LINDORM_URL: ${LINDORM_URL:-"http://lindorm:30070"}
+ LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070}
LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm}
LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm}
OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase}
@@ -245,7 +245,7 @@ x-shared-env: &shared-api-worker-env
OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test}
OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G}
- UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-"https://xxx-vector.upstash.io"}
+ UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io}
UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify}
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
@@ -270,7 +270,7 @@ x-shared-env: &shared-api-worker-env
NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-}
MAIL_TYPE: ${MAIL_TYPE:-resend}
MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-}
- RESEND_API_URL: ${RESEND_API_URL:-"https://api.resend.com"}
+ RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com}
RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key}
SMTP_SERVER: ${SMTP_SERVER:-}
SMTP_PORT: ${SMTP_PORT:-465}
@@ -281,7 +281,7 @@ x-shared-env: &shared-api-worker-env
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5}
- CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-"http://sandbox:8194"}
+ CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194}
CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox}
CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807}
CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808}
@@ -303,8 +303,8 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
- SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-"http://ssrf_proxy:3128"}
- SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-"http://ssrf_proxy:3128"}
+ SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128}
+ SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128}
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
PGUSER: ${PGUSER:-${DB_USERNAME}}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
@@ -314,8 +314,8 @@ x-shared-env: &shared-api-worker-env
SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release}
SANDBOX_WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15}
SANDBOX_ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true}
- SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-"http://ssrf_proxy:3128"}
- SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-"http://ssrf_proxy:3128"}
+ SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
+ SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194}
WEAVIATE_PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate}
WEAVIATE_QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25}
@@ -338,8 +338,8 @@ x-shared-env: &shared-api-worker-env
ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000}
MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin}
MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin}
- ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-"etcd:2379"}
- MINIO_ADDRESS: ${MINIO_ADDRESS:-"minio:9000"}
+ ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
+ MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
MILVUS_AUTHORIZATION_ENABLED: ${MILVUS_AUTHORIZATION_ENABLED:-true}
PGVECTOR_PGUSER: ${PGVECTOR_PGUSER:-postgres}
PGVECTOR_POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456}
@@ -360,7 +360,7 @@ x-shared-env: &shared-api-worker-env
NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443}
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
- NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-"TLSv1.1 TLSv1.2 TLSv1.3"}
+ NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
@@ -374,7 +374,7 @@ x-shared-env: &shared-api-worker-env
SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid}
SSRF_REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194}
SSRF_SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox}
- COMPOSE_PROFILES: ${COMPOSE_PROFILES:-"${VECTOR_STORE:-weaviate}"}
+ COMPOSE_PROFILES: ${COMPOSE_PROFILES:-${VECTOR_STORE:-weaviate}}
EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80}
EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443}
POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-}
diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose
index 54b6d55217f8ba..dc4460f96cf9be 100755
--- a/docker/generate_docker_compose
+++ b/docker/generate_docker_compose
@@ -43,7 +43,7 @@ def generate_shared_env_block(env_vars, anchor_name="shared-api-worker-env"):
else:
# If default value contains special characters, wrap it in quotes
if re.search(r"[:\s]", default):
- default = f'"{default}"'
+ default = f"{default}"
lines.append(f" {key}: ${{{key}:-{default}}}")
return "\n".join(lines)
From 39ace9bdee5426f09197d7d8ab0bd347db6fc1c1 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Wed, 25 Dec 2024 16:34:38 +0800
Subject: [PATCH 14/65] =?UTF-8?q?fix(app=5Fgenerator):=20improve=20error?=
=?UTF-8?q?=20handling=20for=20closed=20file=20I/O=20operat=E2=80=A6=20(#1?=
=?UTF-8?q?2073)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: -LAN-
---
api/core/app/apps/advanced_chat/app_generator.py | 2 +-
api/core/app/apps/message_based_app_generator.py | 2 +-
api/core/app/apps/workflow/app_generator.py | 2 +-
api/core/tools/tool_engine.py | 2 +-
4 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index a18b40712b7ce6..b006de23699c1b 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -383,7 +383,7 @@ def _handle_advanced_chat_response(
try:
return generate_task_pipeline.process()
except ValueError as e:
- if e.args[0] == "I/O operation on closed file.": # ignore this error
+ if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}")
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index dcd9463b8abd0f..4e3aa840ceac8a 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -76,7 +76,7 @@ def _handle_response(
try:
return generate_task_pipeline.process()
except ValueError as e:
- if e.args[0] == "I/O operation on closed file.": # ignore this error
+ if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(f"Failed to handle response, conversation_id: {conversation.id}")
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 1d5f21b9e0cc07..42bc17277fd7c5 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -309,7 +309,7 @@ def _handle_response(
try:
return generate_task_pipeline.process()
except ValueError as e:
- if e.args[0] == "I/O operation on closed file.": # ignore this error
+ if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 425a892527daa4..f7a8ed63f401d5 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -113,7 +113,7 @@ def agent_invoke(
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
except ToolEngineInvokeError as e:
- meta = e.args[0]
+ meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
return error_response, [], meta
From 2b2263a349326e21da813640435e8fcd4c9ce536 Mon Sep 17 00:00:00 2001
From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Date: Wed, 25 Dec 2024 18:17:15 +0800
Subject: [PATCH 15/65] Feat/parent child retrieval (#12086)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: yihong0618
Signed-off-by: -LAN-
Co-authored-by: AkaraChen
Co-authored-by: nite-knite
Co-authored-by: Joel
Co-authored-by: Warren Chen
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com>
Co-authored-by: yihong
Co-authored-by: -LAN-
Co-authored-by: KVOJJJin
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com>
Co-authored-by: Charlie.Wei
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: huayaoyue6
Co-authored-by: kurokobo
Co-authored-by: Matsuda
Co-authored-by: shirochan
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: Huỳnh Gia Bôi
Co-authored-by: Julian Huynh
Co-authored-by: Hash Brown
Co-authored-by: 非法操作
Co-authored-by: Kazuki Takamatsu
Co-authored-by: Trey Dong <1346650911@qq.com>
Co-authored-by: VoidIsVoid <343750470@qq.com>
Co-authored-by: Gimling
Co-authored-by: xiandan-erizo
Co-authored-by: Muneyuki Noguchi
Co-authored-by: zhaobingshuang <1475195565@qq.com>
Co-authored-by: zhaobs
Co-authored-by: suzuki.sh
Co-authored-by: Yingchun Lai
Co-authored-by: huanshare
Co-authored-by: huanshare
Co-authored-by: orangeclk
Co-authored-by: 문정현 <120004247+JungHyunMoon@users.noreply.github.com>
Co-authored-by: barabicu
Co-authored-by: Wei Mingzhi
Co-authored-by: Paul van Oorschot <20116814+pvoo@users.noreply.github.com>
Co-authored-by: zkyTech
Co-authored-by: zhangkunyuan
Co-authored-by: Tommy <34446820+Asterovim@users.noreply.github.com>
Co-authored-by: zxhlyh
Co-authored-by: Novice <857526207@qq.com>
Co-authored-by: Novice Lee
Co-authored-by: Novice Lee
Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com>
Co-authored-by: liuzhenghua <1090179900@qq.com>
Co-authored-by: Jiang <65766008+AlwaysBluer@users.noreply.github.com>
Co-authored-by: jiangzhijie
Co-authored-by: Joe <79627742+ZhouhaoJiang@users.noreply.github.com>
Co-authored-by: Alok Shrivastwa
Co-authored-by: Alok Shrivastwa
Co-authored-by: JasonVV
Co-authored-by: Hiroshi Fujita
Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com>
Co-authored-by: NFish
Co-authored-by: Junyan Qin <1010553892@qq.com>
Co-authored-by: IWAI, Masaharu
Co-authored-by: IWAI, Masaharu
Co-authored-by: Bowen Liang
Co-authored-by: luckylhb90
Co-authored-by: hobo.l
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
---
api/poetry.lock | 13 +-
.../[datasetId]/layout.tsx | 156 +-
.../[datasetId]/settings/page.tsx | 6 +-
.../[datasetId]/style.module.css | 9 -
web/app/(commonLayout)/datasets/Container.tsx | 17 +-
.../(commonLayout)/datasets/DatasetCard.tsx | 10 +-
.../components/app-sidebar/dataset-info.tsx | 45 +
web/app/components/app-sidebar/index.tsx | 30 +-
.../dataset-config/settings-modal/index.tsx | 16 +-
.../components/base/app-icon/style.module.css | 19 +
.../base/auto-height-textarea/common.tsx | 2 +
web/app/components/base/badge.tsx | 6 +-
.../components/base/checkbox/assets/mixed.svg | 5 +
.../components/base/checkbox/index.module.css | 10 +
web/app/components/base/checkbox/index.tsx | 5 +-
web/app/components/base/divider/index.tsx | 2 +-
.../components/base/divider/with-label.tsx | 23 +
web/app/components/base/drawer/index.tsx | 4 +-
.../base/file-uploader/file-type-icon.tsx | 7 +-
.../icons/assets/public/knowledge/chunk.svg | 13 +
.../assets/public/knowledge/collapse.svg | 9 +
.../assets/public/knowledge/general-type.svg | 5 +
.../knowledge/layout-right-2-line-mod.svg | 5 +
.../public/knowledge/parent-child-type.svg | 7 +
.../assets/public/knowledge/selection-mod.svg | 13 +
.../icons/src/public/knowledge/Chunk.json | 116 ++
.../base/icons/src/public/knowledge/Chunk.tsx | 16 +
.../icons/src/public/knowledge/Collapse.json | 62 +
.../icons/src/public/knowledge/Collapse.tsx | 16 +
.../src/public/knowledge/GeneralType.json | 38 +
.../src/public/knowledge/GeneralType.tsx | 16 +
.../public/knowledge/LayoutRight2LineMod.json | 36 +
.../public/knowledge/LayoutRight2LineMod.tsx | 16 +
.../src/public/knowledge/ParentChildType.json | 56 +
.../src/public/knowledge/ParentChildType.tsx | 16 +
.../src/public/knowledge/SelectionMod.json | 116 ++
.../src/public/knowledge/SelectionMod.tsx | 16 +
.../base/icons/src/public/knowledge/index.ts | 6 +
.../base/icons/src/vender/features/index.ts | 2 +-
.../components/base/input-number/index.tsx | 86 +
.../base/linked-apps-panel/index.tsx | 62 +
web/app/components/base/pagination/index.tsx | 2 +-
web/app/components/base/param-item/index.tsx | 36 +-
web/app/components/base/radio-card/index.tsx | 23 +-
.../components/base/retry-button/index.tsx | 85 -
.../base/retry-button/style.module.css | 4 -
.../base/simple-pie-chart/index.tsx | 7 +-
web/app/components/base/skeleton/index.tsx | 13 +-
web/app/components/base/switch/index.tsx | 3 +
web/app/components/base/tag-input/index.tsx | 51 +-
web/app/components/base/toast/index.tsx | 19 +-
web/app/components/base/tooltip/index.tsx | 4 +-
.../billing/priority-label/index.tsx | 19 +-
web/app/components/datasets/chunk.tsx | 54 +
.../datasets/common/chunking-mode-label.tsx | 29 +
.../datasets/common/document-file-icon.tsx | 40 +
.../common/document-picker/document-list.tsx | 42 +
.../datasets/common/document-picker/index.tsx | 118 ++
.../preview-document-picker.tsx | 82 +
.../auto-disabled-document.tsx | 38 +
.../index-failed.tsx | 69 +
.../status-with-action.tsx | 65 +
.../index.tsx | 27 +-
.../common/retrieval-method-config/index.tsx | 89 +-
.../common/retrieval-method-info/index.tsx | 20 +-
.../common/retrieval-param-config/index.tsx | 38 +-
.../datasets/create/assets/family-mod.svg | 6 +
.../create/assets/file-list-3-fill.svg | 5 +
.../datasets/create/assets/gold.svg | 4 +
.../datasets/create/assets/note-mod.svg | 5 +
.../create/assets/option-card-effect-blue.svg | 12 +
.../assets/option-card-effect-orange.svg | 12 +
.../assets/option-card-effect-purple.svg | 12 +
.../create/assets/pattern-recognition-mod.svg | 12 +
.../datasets/create/assets/piggy-bank-mod.svg | 7 +
.../create/assets/progress-indicator.svg | 8 +
.../datasets/create/assets/rerank.svg | 13 +
.../datasets/create/assets/research-mod.svg | 6 +
.../datasets/create/assets/selection-mod.svg | 12 +
.../create/assets/setting-gear-mod.svg | 4 +
.../create/embedding-process/index.module.css | 52 +-
.../create/embedding-process/index.tsx | 192 ++-
.../create/file-preview/index.module.css | 3 +-
.../datasets/create/file-preview/index.tsx | 4 +-
.../create/file-uploader/index.module.css | 67 +-
.../datasets/create/file-uploader/index.tsx | 82 +-
web/app/components/datasets/create/icons.ts | 16 +
web/app/components/datasets/create/index.tsx | 58 +-
.../create/notion-page-preview/index.tsx | 4 +-
.../datasets/create/step-one/index.module.css | 38 +-
.../datasets/create/step-one/index.tsx | 287 +--
.../datasets/create/step-three/index.tsx | 53 +-
.../datasets/create/step-two/index.module.css | 38 +-
.../datasets/create/step-two/index.tsx | 1535 +++++++++--------
.../datasets/create/step-two/inputs.tsx | 77 +
.../create/step-two/language-select/index.tsx | 33 +-
.../datasets/create/step-two/option-card.tsx | 98 ++
.../datasets/create/stepper/index.tsx | 27 +
.../datasets/create/stepper/step.tsx | 46 +
.../datasets/create/top-bar/index.tsx | 41 +
.../create/website/base/error-message.tsx | 2 +-
.../create/website/jina-reader/index.tsx | 1 -
.../datasets/create/website/preview.tsx | 4 +-
.../detail/batch-modal/csv-downloader.tsx | 14 +-
.../documents/detail/batch-modal/index.tsx | 4 +-
.../detail/completed/InfiniteVirtualList.tsx | 98 --
.../detail/completed/SegmentCard.tsx | 20 +-
.../detail/completed/child-segment-detail.tsx | 134 ++
.../detail/completed/child-segment-list.tsx | 195 +++
.../completed/common/action-buttons.tsx | 86 +
.../detail/completed/common/add-another.tsx | 32 +
.../detail/completed/common/batch-action.tsx | 103 ++
.../detail/completed/common/chunk-content.tsx | 192 +++
.../documents/detail/completed/common/dot.tsx | 11 +
.../detail/completed/common/empty.tsx | 78 +
.../completed/common/full-screen-drawer.tsx | 35 +
.../detail/completed/common/keywords.tsx | 47 +
.../completed/common/regeneration-modal.tsx | 131 ++
.../completed/common/segment-index-tag.tsx | 40 +
.../documents/detail/completed/common/tag.tsx | 15 +
.../detail/completed/display-toggle.tsx | 40 +
.../documents/detail/completed/index.tsx | 873 ++++++----
.../detail/completed/new-child-segment.tsx | 175 ++
.../detail/completed/segment-card.tsx | 280 +++
.../detail/completed/segment-detail.tsx | 190 ++
.../detail/completed/segment-list.tsx | 116 ++
.../skeleton/full-doc-list-skeleton.tsx | 25 +
.../skeleton/general-list-skeleton.tsx | 74 +
.../skeleton/paragraph-list-skeleton.tsx | 76 +
.../skeleton/parent-chunk-card-skeleton.tsx | 45 +
.../detail/completed/status-item.tsx | 22 +
.../detail/completed/style.module.css | 13 +-
.../documents/detail/embedding/index.tsx | 302 ++--
.../detail/embedding/skeleton/index.tsx | 66 +
.../datasets/documents/detail/index.tsx | 219 ++-
.../documents/detail/metadata/index.tsx | 18 +-
.../detail/metadata/style.module.css | 13 +-
.../documents/detail/new-segment-modal.tsx | 156 --
.../datasets/documents/detail/new-segment.tsx | 208 +++
.../documents/detail/segment-add/index.tsx | 114 +-
.../documents/detail/settings/index.tsx | 40 +-
.../documents/detail/style.module.css | 10 +-
.../components/datasets/documents/index.tsx | 77 +-
.../components/datasets/documents/list.tsx | 440 +++--
.../datasets/documents/style.module.css | 17 +-
.../formatted-text/flavours/edit-slice.tsx | 115 ++
.../formatted-text/flavours/preview-slice.tsx | 56 +
.../formatted-text/flavours/shared.tsx | 60 +
.../datasets/formatted-text/flavours/type.ts | 5 +
.../datasets/formatted-text/formatted.tsx | 12 +
.../components/child-chunks-item.tsx | 30 +
.../components/chunk-detail-modal.tsx | 89 +
.../hit-testing/components/result-item.tsx | 121 ++
.../datasets/hit-testing/components/score.tsx | 25 +
.../datasets/hit-testing/hit-detail.tsx | 68 -
.../components/datasets/hit-testing/index.tsx | 144 +-
.../hit-testing/modify-retrieval-modal.tsx | 8 +-
.../datasets/hit-testing/style.module.css | 36 +-
.../datasets/hit-testing/textarea.tsx | 59 +-
.../utils/extension-to-file-type.ts | 31 +
web/app/components/datasets/loading.tsx | 0
.../components/datasets/preview/container.tsx | 29 +
.../components/datasets/preview/header.tsx | 23 +
web/app/components/datasets/preview/index.tsx | 0
.../datasets/settings/form/index.tsx | 130 +-
.../index-method-radio/index.module.css | 54 -
.../settings/index-method-radio/index.tsx | 105 +-
.../model-selector/model-trigger.tsx | 16 +-
web/app/components/header/indicator/index.tsx | 18 +-
web/context/dataset-detail.ts | 13 +-
web/hooks/use-metadata.ts | 4 +-
web/i18n/en-US/common.ts | 8 +-
web/i18n/en-US/dataset-creation.ts | 50 +-
web/i18n/en-US/dataset-documents.ts | 66 +-
web/i18n/en-US/dataset-hit-testing.ts | 14 +-
web/i18n/en-US/dataset-settings.ts | 9 +-
web/i18n/en-US/dataset.ts | 22 +-
web/i18n/zh-Hans/common.ts | 8 +-
web/i18n/zh-Hans/dataset-creation.ts | 29 +-
web/i18n/zh-Hans/dataset-documents.ts | 54 +-
web/i18n/zh-Hans/dataset-hit-testing.ts | 10 +-
web/i18n/zh-Hans/dataset-settings.ts | 9 +-
web/i18n/zh-Hans/dataset.ts | 22 +-
web/i18n/zh-Hant/dataset-creation.ts | 3 +-
web/models/datasets.ts | 111 +-
web/package.json | 2 +
web/public/screenshots/Light/Agent.png | Bin 0 -> 36209 bytes
web/public/screenshots/Light/Agent@2x.png | Bin 0 -> 103245 bytes
web/public/screenshots/Light/Agent@3x.png | Bin 0 -> 209674 bytes
web/public/screenshots/Light/ChatFlow.png | Bin 0 -> 28423 bytes
web/public/screenshots/Light/ChatFlow@2x.png | Bin 0 -> 81229 bytes
web/public/screenshots/Light/ChatFlow@3x.png | Bin 0 -> 160820 bytes
web/public/screenshots/Light/Chatbot.png | Bin 0 -> 31633 bytes
web/public/screenshots/Light/Chatbot@2x.png | Bin 0 -> 84515 bytes
web/public/screenshots/Light/Chatbot@3x.png | Bin 0 -> 142013 bytes
web/public/screenshots/Light/Chatflow.png | Bin 0 -> 28423 bytes
web/public/screenshots/Light/Chatflow@2x.png | Bin 0 -> 81229 bytes
web/public/screenshots/Light/Chatflow@3x.png | Bin 0 -> 160820 bytes
.../screenshots/Light/TextGenerator.png | Bin 0 -> 26627 bytes
.../screenshots/Light/TextGenerator@2x.png | Bin 0 -> 63818 bytes
.../screenshots/Light/TextGenerator@3x.png | Bin 0 -> 122391 bytes
web/public/screenshots/Light/Workflow.png | Bin 0 -> 22110 bytes
web/public/screenshots/Light/Workflow@2x.png | Bin 0 -> 62688 bytes
web/public/screenshots/Light/Workflow@3x.png | Bin 0 -> 147073 bytes
web/service/datasets.ts | 71 -
web/service/knowledge/use-create-dataset.ts | 223 +++
web/service/knowledge/use-dateset.ts | 0
web/service/knowledge/use-document.ts | 124 ++
web/service/knowledge/use-hit-testing.ts | 0
web/service/knowledge/use-import.ts | 0
web/service/knowledge/use-segment.ts | 169 ++
web/tailwind.config.js | 23 +-
web/themes/manual-dark.css | 8 +
web/themes/manual-light.css | 8 +
web/utils/time.ts | 12 +
web/yarn.lock | 10 +
216 files changed, 9038 insertions(+), 3088 deletions(-)
create mode 100644 web/app/components/app-sidebar/dataset-info.tsx
create mode 100644 web/app/components/base/app-icon/style.module.css
create mode 100644 web/app/components/base/checkbox/assets/mixed.svg
create mode 100644 web/app/components/base/checkbox/index.module.css
create mode 100644 web/app/components/base/divider/with-label.tsx
create mode 100644 web/app/components/base/icons/assets/public/knowledge/chunk.svg
create mode 100644 web/app/components/base/icons/assets/public/knowledge/collapse.svg
create mode 100644 web/app/components/base/icons/assets/public/knowledge/general-type.svg
create mode 100644 web/app/components/base/icons/assets/public/knowledge/layout-right-2-line-mod.svg
create mode 100644 web/app/components/base/icons/assets/public/knowledge/parent-child-type.svg
create mode 100644 web/app/components/base/icons/assets/public/knowledge/selection-mod.svg
create mode 100644 web/app/components/base/icons/src/public/knowledge/Chunk.json
create mode 100644 web/app/components/base/icons/src/public/knowledge/Chunk.tsx
create mode 100644 web/app/components/base/icons/src/public/knowledge/Collapse.json
create mode 100644 web/app/components/base/icons/src/public/knowledge/Collapse.tsx
create mode 100644 web/app/components/base/icons/src/public/knowledge/GeneralType.json
create mode 100644 web/app/components/base/icons/src/public/knowledge/GeneralType.tsx
create mode 100644 web/app/components/base/icons/src/public/knowledge/LayoutRight2LineMod.json
create mode 100644 web/app/components/base/icons/src/public/knowledge/LayoutRight2LineMod.tsx
create mode 100644 web/app/components/base/icons/src/public/knowledge/ParentChildType.json
create mode 100644 web/app/components/base/icons/src/public/knowledge/ParentChildType.tsx
create mode 100644 web/app/components/base/icons/src/public/knowledge/SelectionMod.json
create mode 100644 web/app/components/base/icons/src/public/knowledge/SelectionMod.tsx
create mode 100644 web/app/components/base/icons/src/public/knowledge/index.ts
create mode 100644 web/app/components/base/input-number/index.tsx
create mode 100644 web/app/components/base/linked-apps-panel/index.tsx
delete mode 100644 web/app/components/base/retry-button/index.tsx
delete mode 100644 web/app/components/base/retry-button/style.module.css
create mode 100644 web/app/components/datasets/chunk.tsx
create mode 100644 web/app/components/datasets/common/chunking-mode-label.tsx
create mode 100644 web/app/components/datasets/common/document-file-icon.tsx
create mode 100644 web/app/components/datasets/common/document-picker/document-list.tsx
create mode 100644 web/app/components/datasets/common/document-picker/index.tsx
create mode 100644 web/app/components/datasets/common/document-picker/preview-document-picker.tsx
create mode 100644 web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx
create mode 100644 web/app/components/datasets/common/document-status-with-action/index-failed.tsx
create mode 100644 web/app/components/datasets/common/document-status-with-action/status-with-action.tsx
create mode 100644 web/app/components/datasets/create/assets/family-mod.svg
create mode 100644 web/app/components/datasets/create/assets/file-list-3-fill.svg
create mode 100644 web/app/components/datasets/create/assets/gold.svg
create mode 100644 web/app/components/datasets/create/assets/note-mod.svg
create mode 100644 web/app/components/datasets/create/assets/option-card-effect-blue.svg
create mode 100644 web/app/components/datasets/create/assets/option-card-effect-orange.svg
create mode 100644 web/app/components/datasets/create/assets/option-card-effect-purple.svg
create mode 100644 web/app/components/datasets/create/assets/pattern-recognition-mod.svg
create mode 100644 web/app/components/datasets/create/assets/piggy-bank-mod.svg
create mode 100644 web/app/components/datasets/create/assets/progress-indicator.svg
create mode 100644 web/app/components/datasets/create/assets/rerank.svg
create mode 100644 web/app/components/datasets/create/assets/research-mod.svg
create mode 100644 web/app/components/datasets/create/assets/selection-mod.svg
create mode 100644 web/app/components/datasets/create/assets/setting-gear-mod.svg
create mode 100644 web/app/components/datasets/create/icons.ts
create mode 100644 web/app/components/datasets/create/step-two/inputs.tsx
create mode 100644 web/app/components/datasets/create/step-two/option-card.tsx
create mode 100644 web/app/components/datasets/create/stepper/index.tsx
create mode 100644 web/app/components/datasets/create/stepper/step.tsx
create mode 100644 web/app/components/datasets/create/top-bar/index.tsx
delete mode 100644 web/app/components/datasets/documents/detail/completed/InfiniteVirtualList.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/child-segment-detail.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/child-segment-list.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/add-another.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/batch-action.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/chunk-content.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/dot.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/empty.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/full-screen-drawer.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/keywords.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/regeneration-modal.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/segment-index-tag.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/common/tag.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/display-toggle.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/new-child-segment.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/segment-card.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/segment-detail.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/segment-list.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/skeleton/full-doc-list-skeleton.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/skeleton/general-list-skeleton.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/skeleton/paragraph-list-skeleton.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/skeleton/parent-chunk-card-skeleton.tsx
create mode 100644 web/app/components/datasets/documents/detail/completed/status-item.tsx
create mode 100644 web/app/components/datasets/documents/detail/embedding/skeleton/index.tsx
delete mode 100644 web/app/components/datasets/documents/detail/new-segment-modal.tsx
create mode 100644 web/app/components/datasets/documents/detail/new-segment.tsx
create mode 100644 web/app/components/datasets/formatted-text/flavours/edit-slice.tsx
create mode 100644 web/app/components/datasets/formatted-text/flavours/preview-slice.tsx
create mode 100644 web/app/components/datasets/formatted-text/flavours/shared.tsx
create mode 100644 web/app/components/datasets/formatted-text/flavours/type.ts
create mode 100644 web/app/components/datasets/formatted-text/formatted.tsx
create mode 100644 web/app/components/datasets/hit-testing/components/child-chunks-item.tsx
create mode 100644 web/app/components/datasets/hit-testing/components/chunk-detail-modal.tsx
create mode 100644 web/app/components/datasets/hit-testing/components/result-item.tsx
create mode 100644 web/app/components/datasets/hit-testing/components/score.tsx
delete mode 100644 web/app/components/datasets/hit-testing/hit-detail.tsx
create mode 100644 web/app/components/datasets/hit-testing/utils/extension-to-file-type.ts
create mode 100644 web/app/components/datasets/loading.tsx
create mode 100644 web/app/components/datasets/preview/container.tsx
create mode 100644 web/app/components/datasets/preview/header.tsx
create mode 100644 web/app/components/datasets/preview/index.tsx
delete mode 100644 web/app/components/datasets/settings/index-method-radio/index.module.css
create mode 100644 web/public/screenshots/Light/Agent.png
create mode 100644 web/public/screenshots/Light/Agent@2x.png
create mode 100644 web/public/screenshots/Light/Agent@3x.png
create mode 100644 web/public/screenshots/Light/ChatFlow.png
create mode 100644 web/public/screenshots/Light/ChatFlow@2x.png
create mode 100644 web/public/screenshots/Light/ChatFlow@3x.png
create mode 100644 web/public/screenshots/Light/Chatbot.png
create mode 100644 web/public/screenshots/Light/Chatbot@2x.png
create mode 100644 web/public/screenshots/Light/Chatbot@3x.png
create mode 100644 web/public/screenshots/Light/Chatflow.png
create mode 100644 web/public/screenshots/Light/Chatflow@2x.png
create mode 100644 web/public/screenshots/Light/Chatflow@3x.png
create mode 100644 web/public/screenshots/Light/TextGenerator.png
create mode 100644 web/public/screenshots/Light/TextGenerator@2x.png
create mode 100644 web/public/screenshots/Light/TextGenerator@3x.png
create mode 100644 web/public/screenshots/Light/Workflow.png
create mode 100644 web/public/screenshots/Light/Workflow@2x.png
create mode 100644 web/public/screenshots/Light/Workflow@3x.png
create mode 100644 web/service/knowledge/use-create-dataset.ts
create mode 100644 web/service/knowledge/use-dateset.ts
create mode 100644 web/service/knowledge/use-document.ts
create mode 100644 web/service/knowledge/use-hit-testing.ts
create mode 100644 web/service/knowledge/use-import.ts
create mode 100644 web/service/knowledge/use-segment.ts
create mode 100644 web/utils/time.ts
diff --git a/api/poetry.lock b/api/poetry.lock
index b42eb22dd40b8a..b2d22a887db2f6 100644
--- a/api/poetry.lock
+++ b/api/poetry.lock
@@ -1,4 +1,15 @@
-# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+
+[[package]]
+name = "aiofiles"
+version = "24.1.0"
+description = "File support for asyncio."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"},
+ {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"},
+]
[[package]]
name = "aiofiles"
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx
index b416659a6a1cfa..a6fb116fa8a50f 100644
--- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx
+++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx
@@ -7,85 +7,36 @@ import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import {
Cog8ToothIcon,
- // CommandLineIcon,
- Squares2X2Icon,
- // eslint-disable-next-line sort-imports
- PuzzlePieceIcon,
DocumentTextIcon,
PaperClipIcon,
- QuestionMarkCircleIcon,
} from '@heroicons/react/24/outline'
import {
Cog8ToothIcon as Cog8ToothSolidIcon,
// CommandLineIcon as CommandLineSolidIcon,
DocumentTextIcon as DocumentTextSolidIcon,
} from '@heroicons/react/24/solid'
-import Link from 'next/link'
+import { RiApps2AddLine, RiInformation2Line } from '@remixicon/react'
import s from './style.module.css'
import classNames from '@/utils/classnames'
import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets'
-import type { RelatedApp, RelatedAppResponse } from '@/models/datasets'
+import type { RelatedAppResponse } from '@/models/datasets'
import AppSideBar from '@/app/components/app-sidebar'
-import Divider from '@/app/components/base/divider'
-import AppIcon from '@/app/components/base/app-icon'
import Loading from '@/app/components/base/loading'
-import FloatPopoverContainer from '@/app/components/base/float-popover-container'
import DatasetDetailContext from '@/context/dataset-detail'
import { DataSourceType } from '@/models/datasets'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { LanguagesSupported } from '@/i18n/language'
import { useStore } from '@/app/components/app/store'
-import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication'
-import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel'
import { getLocaleOnClient } from '@/i18n'
import { useAppContext } from '@/context/app-context'
+import Tooltip from '@/app/components/base/tooltip'
+import LinkedAppsPanel from '@/app/components/base/linked-apps-panel'
export type IAppDetailLayoutProps = {
children: React.ReactNode
params: { datasetId: string }
}
-type ILikedItemProps = {
- type?: 'plugin' | 'app'
- appStatus?: boolean
- detail: RelatedApp
- isMobile: boolean
-}
-
-const LikedItem = ({
- type = 'app',
- detail,
- isMobile,
-}: ILikedItemProps) => {
- return (
-
-
-
- {type === 'app' && (
-
- {detail.mode === 'advanced-chat' && (
-
- )}
- {detail.mode === 'agent-chat' && (
-
- )}
- {detail.mode === 'chat' && (
-
- )}
- {detail.mode === 'completion' && (
-
- )}
- {detail.mode === 'workflow' && (
-
- )}
-
- )}
-
- {!isMobile && {detail?.name || '--'}
}
-
- )
-}
-
const TargetIcon = ({ className }: SVGProps) => {
return