From 027ed5ed9c8aa7f58c5bf0d937285324733612c1 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 24 Jul 2024 18:40:08 +0800 Subject: [PATCH 1/9] update: refactor codes --- Dockerfile | 2 +- addons/plugins/helloworld/main.py | 2 - astrbot/bootstrap.py | 97 ++++ astrbot/core.py | 415 ----------------- .../message}/baidu_aip_judge.py | 0 astrbot/message/handler.py | 198 ++++++++ {util => astrbot/message}/unfit_words.py | 0 .../session.py => astrbot/persist/helper.py | 53 +-- astrbot/persist/initialization.sql | 18 + dashboard/__init__.py | 11 + .../dashboard => dashboard}/dist/_redirects | 0 .../dist/assets/BaseBreadcrumb-4d676ba5.css | 0 ...ue_vue_type_style_index_0_lang-1875d383.js | 0 .../dist/assets/BlankLayout-471b2640.js | 0 .../dist/assets/ColorPage-beafd674.js | 0 .../dist/assets/ConfigDetailCard-756c045d.js | 0 .../dist/assets/ConfigPage-56ea019d.js | 0 .../dist/assets/ConsolePage-bd0aea4c.js | 0 .../dist/assets/ConsolePage-ff373be6.css | 0 .../dist/assets/DefaultDashboard-ef4a941d.js | 0 .../dist/assets/Error404Page-11cf087a.css | 0 .../dist/assets/Error404Page-bcb6eec8.js | 0 .../dist/assets/ExtensionPage-701bf929.js | 0 .../dist/assets/FullLayout-35e69863.js | 0 .../dist/assets/LoginPage-5c692a20.js | 0 .../dist/assets/LoginPage-74e85ca7.css | 0 ...ue_type_script_setup_true_lang-d555e5be.js | 0 .../dist/assets/MaterialIcons-a926bc0c.js | 0 .../dist/assets/RegisterPage-799ed804.css | 0 .../dist/assets/RegisterPage-a274fd0f.js | 0 .../dist/assets/ShadowPage-4758709f.js | 0 .../dist/assets/TablerIcons-da1fd166.js | 0 .../dist/assets/TypographyPage-ee445c8b.js | 0 ...ue_type_script_setup_true_lang-b40a2daa.js | 0 .../_plugin-vue_export-helper-c27b6911.js | 0 .../dist/assets/img-error-bg-41f65efa.svg | 0 .../dist/assets/img-error-blue-f50c8e77.svg | 0 .../dist/assets/img-error-purple-b97a483b.svg | 0 .../dist/assets/img-error-text-630dc36d.svg | 0 .../dist/assets/index-0f1523f3.css | 0 .../dist/assets/index-5ac7c267.js | 0 .../materialdesignicons-webfont-67d24abe.eot | Bin .../materialdesignicons-webfont-80bb28b3.woff | Bin .../materialdesignicons-webfont-a58ecb54.ttf | Bin ...materialdesignicons-webfont-c1c004a9.woff2 | Bin .../dist/assets/md5-086248bf.js | 0 .../dist/assets/social-google-9b2fa67a.svg | 0 .../dashboard => dashboard}/dist/favicon.svg | 0 .../dashboard => dashboard}/dist/index.html | 0 {addons/dashboard => dashboard}/helper.py | 88 ++-- {addons/dashboard => dashboard}/server.py | 163 +++---- main.py | 92 +--- model/command/command.py | 10 +- model/command/internal_handler.py | 233 ++++++++++ model/command/manager.py | 99 ++++ model/command/openai_official.py | 255 ----------- model/command/openai_official_handler.py | 152 +++++++ model/command/parser.py | 19 + model/platform/__init__.py | 70 +++ model/platform/_message_parse.py | 114 ----- model/platform/_message_result.py | 9 - model/platform/_platfrom.py | 73 --- model/platform/manager.py | 87 ++++ model/platform/qq_aiocqhttp.py | 146 ++++++ model/platform/qq_gocq.py | 314 ------------- model/platform/qq_nakuru.py | 255 +++++++++++ model/platform/qq_official.py | 327 ++++++++------ model/plugin/command.py | 25 ++ model/plugin/manager.py | 256 +++++++++++ model/provider/openai_official.py | 19 +- requirements.txt | 3 +- type/{message.py => astrbot_message.py} | 33 +- type/command.py | 7 +- type/message_event.py | 55 +++ type/plugin.py | 26 ++ type/register.py | 28 +- type/types.py | 79 ++-- util/agent/web_searcher.py | 10 +- util/cmd_config.py | 62 --- util/config_utils.py | 109 +++++ util/general_utils.py | 425 +----------------- util/image_render/helper.py | 43 -- util/image_uploader.py | 21 + util/io.py | 128 ++++++ util/plugin_dev/api/v1/bot.py | 4 +- util/plugin_dev/api/v1/message.py | 13 +- util/plugin_dev/api/v1/platform.py | 4 +- util/plugin_dev/api/v1/register.py | 12 +- util/plugin_util.py | 361 --------------- util/search_engine_scraper/test.py | 22 - util/t2i/__init__.py | 1 + util/t2i/context.py | 11 + util/t2i/renderer.py | 25 ++ util/t2i/strategies/__init__.py | 3 + util/t2i/strategies/base_strategy.py | 7 + util/t2i/strategies/local_strategy.py | 278 ++++++++++++ util/t2i/strategies/network_strategy.py | 42 ++ .../strategies}/template/base.html | 0 util/updator.py | 209 --------- util/updator/astrbot_updator.py | 109 +++++ util/updator/plugin_updator.py | 48 ++ util/updator/zip_updator.py | 155 +++++++ .../bing.py | 4 +- .../config.py | 0 .../engine.py | 2 +- .../google.py | 4 +- .../sogo.py | 4 +- 107 files changed, 3142 insertions(+), 2807 deletions(-) create mode 100644 astrbot/bootstrap.py delete mode 100644 astrbot/core.py rename {addons => astrbot/message}/baidu_aip_judge.py (100%) create mode 100644 astrbot/message/handler.py rename {util => astrbot/message}/unfit_words.py (100%) rename persist/session.py => astrbot/persist/helper.py (87%) create mode 100644 astrbot/persist/initialization.sql create mode 100644 dashboard/__init__.py rename {addons/dashboard => dashboard}/dist/_redirects (100%) rename {addons/dashboard => dashboard}/dist/assets/BaseBreadcrumb-4d676ba5.css (100%) rename {addons/dashboard => dashboard}/dist/assets/BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js (100%) rename {addons/dashboard => dashboard}/dist/assets/BlankLayout-471b2640.js (100%) rename {addons/dashboard => dashboard}/dist/assets/ColorPage-beafd674.js (100%) rename {addons/dashboard => dashboard}/dist/assets/ConfigDetailCard-756c045d.js (100%) rename {addons/dashboard => dashboard}/dist/assets/ConfigPage-56ea019d.js (100%) rename {addons/dashboard => dashboard}/dist/assets/ConsolePage-bd0aea4c.js (100%) rename {addons/dashboard => dashboard}/dist/assets/ConsolePage-ff373be6.css (100%) rename {addons/dashboard => dashboard}/dist/assets/DefaultDashboard-ef4a941d.js (100%) rename {addons/dashboard => dashboard}/dist/assets/Error404Page-11cf087a.css (100%) rename {addons/dashboard => dashboard}/dist/assets/Error404Page-bcb6eec8.js (100%) rename {addons/dashboard => dashboard}/dist/assets/ExtensionPage-701bf929.js (100%) rename {addons/dashboard => dashboard}/dist/assets/FullLayout-35e69863.js (100%) rename {addons/dashboard => dashboard}/dist/assets/LoginPage-5c692a20.js (100%) rename {addons/dashboard => dashboard}/dist/assets/LoginPage-74e85ca7.css (100%) rename {addons/dashboard => dashboard}/dist/assets/LogoDark.vue_vue_type_script_setup_true_lang-d555e5be.js (100%) rename {addons/dashboard => dashboard}/dist/assets/MaterialIcons-a926bc0c.js (100%) rename {addons/dashboard => dashboard}/dist/assets/RegisterPage-799ed804.css (100%) rename {addons/dashboard => dashboard}/dist/assets/RegisterPage-a274fd0f.js (100%) rename {addons/dashboard => dashboard}/dist/assets/ShadowPage-4758709f.js (100%) rename {addons/dashboard => dashboard}/dist/assets/TablerIcons-da1fd166.js (100%) rename {addons/dashboard => dashboard}/dist/assets/TypographyPage-ee445c8b.js (100%) rename {addons/dashboard => dashboard}/dist/assets/UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js (100%) rename {addons/dashboard => dashboard}/dist/assets/_plugin-vue_export-helper-c27b6911.js (100%) rename {addons/dashboard => dashboard}/dist/assets/img-error-bg-41f65efa.svg (100%) rename {addons/dashboard => dashboard}/dist/assets/img-error-blue-f50c8e77.svg (100%) rename {addons/dashboard => dashboard}/dist/assets/img-error-purple-b97a483b.svg (100%) rename {addons/dashboard => dashboard}/dist/assets/img-error-text-630dc36d.svg (100%) rename {addons/dashboard => dashboard}/dist/assets/index-0f1523f3.css (100%) rename {addons/dashboard => dashboard}/dist/assets/index-5ac7c267.js (100%) rename {addons/dashboard => dashboard}/dist/assets/materialdesignicons-webfont-67d24abe.eot (100%) rename {addons/dashboard => dashboard}/dist/assets/materialdesignicons-webfont-80bb28b3.woff (100%) rename {addons/dashboard => dashboard}/dist/assets/materialdesignicons-webfont-a58ecb54.ttf (100%) rename {addons/dashboard => dashboard}/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 (100%) rename {addons/dashboard => dashboard}/dist/assets/md5-086248bf.js (100%) rename {addons/dashboard => dashboard}/dist/assets/social-google-9b2fa67a.svg (100%) rename {addons/dashboard => dashboard}/dist/favicon.svg (100%) rename {addons/dashboard => dashboard}/dist/index.html (100%) rename {addons/dashboard => dashboard}/helper.py (92%) rename {addons/dashboard => dashboard}/server.py (78%) create mode 100644 model/command/internal_handler.py create mode 100644 model/command/manager.py delete mode 100644 model/command/openai_official.py create mode 100644 model/command/openai_official_handler.py create mode 100644 model/command/parser.py create mode 100644 model/platform/__init__.py delete mode 100644 model/platform/_message_parse.py delete mode 100644 model/platform/_message_result.py delete mode 100644 model/platform/_platfrom.py create mode 100644 model/platform/manager.py create mode 100644 model/platform/qq_aiocqhttp.py delete mode 100644 model/platform/qq_gocq.py create mode 100644 model/platform/qq_nakuru.py create mode 100644 model/plugin/command.py create mode 100644 model/plugin/manager.py rename type/{message.py => astrbot_message.py} (53%) create mode 100644 type/message_event.py create mode 100644 util/config_utils.py delete mode 100644 util/image_render/helper.py create mode 100644 util/image_uploader.py create mode 100644 util/io.py delete mode 100644 util/plugin_util.py delete mode 100644 util/search_engine_scraper/test.py create mode 100644 util/t2i/__init__.py create mode 100644 util/t2i/context.py create mode 100644 util/t2i/renderer.py create mode 100644 util/t2i/strategies/__init__.py create mode 100644 util/t2i/strategies/base_strategy.py create mode 100644 util/t2i/strategies/local_strategy.py create mode 100644 util/t2i/strategies/network_strategy.py rename util/{image_render => t2i/strategies}/template/base.html (100%) delete mode 100644 util/updator.py create mode 100644 util/updator/astrbot_updator.py create mode 100644 util/updator/plugin_updator.py create mode 100644 util/updator/zip_updator.py rename util/{search_engine_scraper => websearch}/bing.py (89%) rename util/{search_engine_scraper => websearch}/config.py (100%) rename util/{search_engine_scraper => websearch}/engine.py (97%) rename util/{search_engine_scraper => websearch}/google.py (84%) rename util/{search_engine_scraper => websearch}/sogo.py (91%) diff --git a/Dockerfile b/Dockerfile index 98cff53..93c0a29 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10.13-bullseye +FROM python:3.10-slim WORKDIR /AstrBot COPY . /AstrBot/ diff --git a/addons/plugins/helloworld/main.py b/addons/plugins/helloworld/main.py index a954000..a088cda 100644 --- a/addons/plugins/helloworld/main.py +++ b/addons/plugins/helloworld/main.py @@ -1,5 +1,3 @@ -import os -import shutil from nakuru.entities.components import * flag_not_support = False try: diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py new file mode 100644 index 0000000..fdfc922 --- /dev/null +++ b/astrbot/bootstrap.py @@ -0,0 +1,97 @@ +import asyncio +import threading +from astrbot.message.handler import MessageHandler +from astrbot.persist.helper import dbConn +from dashboard.server import AstrBotDashBoard +from model.provider.provider import Provider +from model.command.manager import CommandManager +from model.command.internal_handler import InternalCommandHandler +from model.plugin.manager import PluginManager +from model.platform.manager import PlatformManager +from typing import Dict, List, Union +from type.types import Context +from SparkleLogging.utils.core import LogManager +from logging import Logger +from util.cmd_config import CmdConfig +from util.general_utils import upload_metrics +from util.config_utils import * +from util.updator.astrbot_updator import AstrBotUpdator + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +class AstrBotBootstrap(): + def __init__(self) -> None: + self.context = Context() + self.config_helper: CmdConfig = CmdConfig() + + # load configs and ensure the backward compatibility + init_configs() + try_migrate_config() + self.configs = inject_to_context(self.context) + logger.info("AstrBot v" + self.context.version) + self.context.config_helper = self.config_helper + + # apply proxy settings + http_proxy = self.context.base_config.get("http_proxy") + https_proxy = self.context.base_config.get("https_proxy") + if http_proxy: + os.environ['HTTP_PROXY'] = http_proxy + if https_proxy: + os.environ['HTTPS_PROXY'] = https_proxy + os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com' + + if http_proxy and https_proxy: + logger.info(f"使用代理: {http_proxy}, {https_proxy}") + else: + logger.info("未使用代理。") + + async def run(self): + + self.command_manager = CommandManager() + self.plugin_manager = PluginManager(self.context) + self.updator = AstrBotUpdator() + self.cmd_handler = InternalCommandHandler(self.command_manager, self.plugin_manager) + self.db_conn_helper = dbConn() + + # load llm provider + self.llm_instance: Provider = None + self.load_llm() + + self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper, self.llm_instance) + self.platfrom_manager = PlatformManager(self.context, self.message_handler) + self.dashboard = AstrBotDashBoard(self.context, plugin_manager=self.plugin_manager, astrbot_updator=self.updator) + + self.context.updator = self.updator + self.context.plugin_updator = self.plugin_manager.updator + + # load plugins, plugins' commands. + self.load_plugins() + self.command_manager.register_from_pcb(self.context.plugin_command_bridge) + + # load platforms + platform_tasks = self.load_platform() + # load metrics uploader + metrics_upload_task = upload_metrics(self.context) + # load dashboard + self.dashboard.run_http_server() + dashboard_task = self.dashboard.ws_server() + + await asyncio.gather(metrics_upload_task, dashboard_task, *platform_tasks) + + + def load_llm(self): + if 'openai' in self.configs and \ + len(self.configs['openai']['key']) and \ + self.configs['openai']['key'][0] is not None: + from model.provider.openai_official import ProviderOpenAIOfficial + from model.command.openai_official_handler import OpenAIOfficialCommandHandler + self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) + self.llm_instance = ProviderOpenAIOfficial(self.context) + logger.info("已启用 OpenAI API 支持。") + + def load_plugins(self): + self.plugin_manager.plugin_reload() + + def load_platform(self): + return self.platfrom_manager.load_platforms() \ No newline at end of file diff --git a/astrbot/core.py b/astrbot/core.py deleted file mode 100644 index e5a64ad..0000000 --- a/astrbot/core.py +++ /dev/null @@ -1,415 +0,0 @@ -import re -import threading -import asyncio -import time -import util.unfit_words as uw -import os -import sys -import traceback - -import util.agent.web_searcher as web_searcher -import util.plugin_util as putil - -from nakuru.entities.components import Plain, At, Image - -from addons.baidu_aip_judge import BaiduJudge -from model.provider.provider import Provider -from model.command.command import Command -from util import general_utils as gu -from util.general_utils import upload, run_monitor -from util.cmd_config import CmdConfig as cc -from util.cmd_config import init_astrbot_config_items -from type.types import GlobalObject -from type.register import * -from type.message import AstrBotMessage -from type.config import * -from addons.dashboard.helper import DashBoardHelper -from addons.dashboard.server import DashBoardData -from persist.session import dbConn -from model.platform._message_result import MessageResult -from SparkleLogging.utils.core import LogManager -from logging import Logger - -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - -# 用户发言频率 -user_frequency = {} -# 时间默认值 -frequency_time = 60 -# 计数默认值 -frequency_count = 10 - -# 语言模型 -OPENAI_OFFICIAL = 'openai_official' -NONE_LLM = 'none_llm' -chosen_provider = None -# 语言模型对象 -llm_instance: dict[str, Provider] = {} -llm_command_instance: dict[str, Command] = {} -llm_wake_prefix = "" - -# 百度内容审核实例 -baidu_judge = None - -# 全局对象 -_global_object: GlobalObject = None - - -def privider_chooser(cfg): - l = [] - if 'openai' in cfg and len(cfg['openai']['key']) and cfg['openai']['key'][0]: - l.append('openai_official') - return l - -def init(): - ''' - 初始化机器人 - ''' - global llm_instance, llm_command_instance - global baidu_judge, chosen_provider - global frequency_count, frequency_time - global _global_object - - init_astrbot_config_items() - cfg = cc.get_all() - - _event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(_event_loop) - - # 初始化 global_object - _global_object = GlobalObject() - _global_object.version = VERSION - _global_object.base_config = cfg - _global_object.logger = logger - logger.info("AstrBot v" + VERSION) - - if 'reply_prefix' in cfg: - # 适配旧版配置 - if isinstance(cfg['reply_prefix'], dict): - _global_object.reply_prefix = "" - cfg['reply_prefix'] = "" - cc.put("reply_prefix", "") - else: - _global_object.reply_prefix = cfg['reply_prefix'] - - default_personality_str = cc.get("default_personality_str", "") - if default_personality_str == "": - _global_object.default_personality = None - else: - _global_object.default_personality = { - "name": "default", - "prompt": default_personality_str, - } - - # 语言模型提供商 - logger.info("正在载入语言模型...") - prov = privider_chooser(cfg) - if OPENAI_OFFICIAL in prov: - logger.info("初始化:OpenAI官方") - if cfg['openai']['key'] is not None and cfg['openai']['key'] != [None]: - from model.provider.openai_official import ProviderOpenAIOfficial - from model.command.openai_official import CommandOpenAIOfficial - llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial( - cfg['openai']) - llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial( - llm_instance[OPENAI_OFFICIAL], _global_object) - _global_object.llms.append(RegisteredLLM( - llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal")) - chosen_provider = OPENAI_OFFICIAL - - instance = llm_instance[OPENAI_OFFICIAL] - assert isinstance(instance, ProviderOpenAIOfficial) - instance.DEFAULT_PERSONALITY = _global_object.default_personality - instance.curr_personality = instance.DEFAULT_PERSONALITY - - # 检查provider设置偏好 - p = cc.get("chosen_provider", None) - if p is not None and p in llm_instance: - chosen_provider = p - - # 百度内容审核 - if 'baidu_aip' in cfg and 'enable' in cfg['baidu_aip'] and cfg['baidu_aip']['enable']: - try: - baidu_judge = BaiduJudge(cfg['baidu_aip']) - logger.info("百度内容审核初始化成功") - except BaseException as e: - logger.info("百度内容审核初始化失败") - - threading.Thread(target=upload, args=( - _global_object, ), daemon=True).start() - - # 得到发言频率配置 - if 'limit' in cfg: - if 'count' in cfg['limit']: - frequency_count = cfg['limit']['count'] - if 'time' in cfg['limit']: - frequency_time = cfg['limit']['time'] - - try: - if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']: - _global_object.unique_session = True - else: - _global_object.unique_session = False - except BaseException as e: - logger.info("独立会话配置错误: "+str(e)) - - nick_qq = cc.get("nick_qq", None) - if not nick_qq: - nick_qq = ("ai", "!", "!") - if isinstance(nick_qq, str): - nick_qq = (nick_qq,) - if isinstance(nick_qq, list): - nick_qq = tuple(nick_qq) - _global_object.nick = nick_qq - - # 语言模型唤醒词 - global llm_wake_prefix - llm_wake_prefix = cc.get("llm_wake_prefix", "") - - logger.info("正在载入插件...") - # 加载插件 - _command = Command(None, _global_object) - ok, err = putil.plugin_reload(_global_object) - if ok: - logger.info( - f"成功载入 {len(_global_object.cached_plugins)} 个插件") - else: - logger.error(err) - - if chosen_provider is None: - llm_command_instance[NONE_LLM] = _command - chosen_provider = NONE_LLM - - logger.info("正在载入机器人消息平台") - # GOCQ - if 'gocqbot' in cfg and cfg['gocqbot']['enable']: - logger.info("启用 QQ_GOCQ 机器人消息平台") - threading.Thread(target=run_gocq_bot, args=( - cfg, _global_object), daemon=True).start() - - # QQ频道 - if 'qqbot' in cfg and cfg['qqbot']['enable'] and cfg['qqbot']['appid'] != None: - logger.info("启用 QQ_OFFICIAL 机器人消息平台") - threading.Thread(target=run_qqchan_bot, args=( - cfg, _global_object), daemon=True).start() - - # 初始化dashboard - _global_object.dashboard_data = DashBoardData( - stats={}, - configs={}, - logs={}, - plugins=_global_object.cached_plugins, - ) - dashboard_helper = DashBoardHelper(_global_object, config=cc.get_all()) - dashboard_thread = threading.Thread( - target=dashboard_helper.run, daemon=True) - dashboard_thread.start() - - # 运行 monitor - threading.Thread(target=run_monitor, args=( - _global_object,), daemon=True).start() - - logger.info( - "如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。") - logger.info("请给 https://github.com/Soulter/AstrBot 点个 star。") - logger.info(f"🎉 项目启动完成") - - dashboard_thread.join() - - -def run_qqchan_bot(cfg: dict, global_object: GlobalObject): - ''' - 运行 QQ_OFFICIAL 机器人 - ''' - try: - from model.platform.qq_official import QQOfficial - qqchannel_bot = QQOfficial( - cfg=cfg, message_handler=oper_msg, global_object=global_object) - global_object.platforms.append(RegisteredPlatform( - platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal")) - qqchannel_bot.run() - except BaseException as e: - logger.error("启动 QQ 频道机器人时出现错误, 原因如下: " + str(e)) - logger.error(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。") - - -def run_gocq_bot(cfg: dict, _global_object: GlobalObject): - ''' - 运行 QQ_GOCQ 机器人 - ''' - from model.platform.qq_gocq import QQGOCQ - noticed = False - host = cc.get("gocq_host", "127.0.0.1") - port = cc.get("gocq_websocket_port", 6700) - http_port = cc.get("gocq_http_port", 5700) - logger.info( - f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}") - while True: - if not gu.port_checker(port=port, host=host) or not gu.port_checker(port=http_port, host=host): - if not noticed: - noticed = True - logger.warning( - f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。") - time.sleep(5) - else: - logger.info("已连接到 gocq。") - break - try: - qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, - global_object=_global_object) - _global_object.platforms.append(RegisteredPlatform( - platform_name="gocq", platform_instance=qq_gocq, origin="internal")) - qq_gocq.run() - except BaseException as e: - input("启动QQ机器人出现错误"+str(e)) - - -def check_frequency(id) -> bool: - ''' - 检查发言频率 - ''' - ts = int(time.time()) - if id in user_frequency: - if ts-user_frequency[id]['time'] > frequency_time: - user_frequency[id]['time'] = ts - user_frequency[id]['count'] = 1 - return True - else: - if user_frequency[id]['count'] >= frequency_count: - return False - else: - user_frequency[id]['count'] += 1 - return True - else: - t = {'time': ts, 'count': 1} - user_frequency[id] = t - return True - - -async def record_message(platform: str, session_id: str): - # TODO: 这里会非常吃资源。然而 sqlite3 不支持多线程,所以暂时这样写。 - curr_ts = int(time.time()) - db_inst = dbConn() - db_inst.increment_stat_session(platform, session_id, 1) - db_inst.increment_stat_message(curr_ts, 1) - db_inst.increment_stat_platform(curr_ts, platform, 1) - - -async def oper_msg(message: AstrBotMessage, - session_id: str, - role: str = 'member', - platform: str = None, - ) -> MessageResult: - """ - 处理消息。 - message: 消息对象 - session_id: 该消息源的唯一识别号 - role: member | admin - platform: str 所注册的平台的名称。如果没有注册,将抛出一个异常。 - """ - global chosen_provider, _global_object - message_str = message.message_str - hit = False # 是否命中指令 - command_result = () # 调用指令返回的结果 - llm_result_str = "" - - # 获取平台实例 - reg_platform: RegisteredPlatform = None - for p in _global_object.platforms: - if p.platform_name == platform: - reg_platform = p - break - if not reg_platform: - raise Exception(f"未找到平台 {platform} 的实例。") - - # 统计数据,如频道消息量 - await record_message(platform, session_id) - - if not message_str: - return MessageResult("Hi~") - - # 检查发言频率 - if not check_frequency(message.sender.user_id): - return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。') - - # check commands and plugins - message_str_no_wake_prefix = message_str - for wake_prefix in _global_object.nick: # nick: tuple - if message_str.startswith(wake_prefix): - message_str_no_wake_prefix = message_str.removeprefix(wake_prefix) - break - hit, command_result = await llm_command_instance[chosen_provider].check_command( - message_str_no_wake_prefix, - session_id, - role, - reg_platform, - message, - ) - - # 没触发指令 - if not hit: - # 关键词拦截 - for i in uw.unfit_words_q: - matches = re.match(i, message_str.strip(), re.I | re.M) - if matches: - return MessageResult(f"你的提问得到的回复未通过【默认关键词拦截】服务, 不予回复。") - if baidu_judge != None: - check, msg = await asyncio.to_thread(baidu_judge.judge, message_str) - if not check: - return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}") - if chosen_provider == NONE_LLM: - logger.info("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。") - return - try: - if llm_wake_prefix and not message_str.startswith(llm_wake_prefix): - return - # check image url - image_url = None - for comp in message.message: - if isinstance(comp, Image): - if comp.url is None: - image_url = comp.file - break - else: - image_url = comp.url - break - # web search keyword - web_sch_flag = False - if message_str.startswith("ws ") and message_str != "ws ": - message_str = message_str[3:] - web_sch_flag = True - else: - message_str += "\n" + cc.get("llm_env_prompt", "") - if chosen_provider == OPENAI_OFFICIAL: - if _global_object.web_search or web_sch_flag: - official_fc = chosen_provider == OPENAI_OFFICIAL - llm_result_str = await web_searcher.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) - else: - llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url) - - llm_result_str = _global_object.reply_prefix + llm_result_str - except BaseException as e: - logger.error(f"调用异常:{traceback.format_exc()}") - return MessageResult(f"调用异常。详细原因:{str(e)}") - - if hit: - # 有指令或者插件触发 - # command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型) - if not command_result: - return - if not command_result[0]: - return MessageResult(f"指令调用错误: \n{str(command_result[1])}") - if isinstance(command_result[1], (list, str)): - return MessageResult(command_result[1]) - - # 敏感过滤 - # 过滤不合适的词 - for i in uw.unfit_words: - llm_result_str = re.sub(i, "***", llm_result_str) - # 百度内容审核服务二次审核 - if baidu_judge != None: - check, msg = await asyncio.to_thread(baidu_judge.judge, llm_result_str) - if not check: - return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}") - # 发送信息 - return MessageResult(llm_result_str) diff --git a/addons/baidu_aip_judge.py b/astrbot/message/baidu_aip_judge.py similarity index 100% rename from addons/baidu_aip_judge.py rename to astrbot/message/baidu_aip_judge.py diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py new file mode 100644 index 0000000..0ad9ac3 --- /dev/null +++ b/astrbot/message/handler.py @@ -0,0 +1,198 @@ +import time +import re +import asyncio +import traceback +import astrbot.message.unfit_words as uw + +from typing import Dict +from astrbot.persist.helper import dbConn +from model.provider.provider import Provider +from model.command.manager import CommandManager +from type.message_event import AstrMessageEvent, MessageResult +from type.types import Context +from type.command import CommandResult +from SparkleLogging.utils.core import LogManager +from logging import Logger +from nakuru.entities.components import Image +import util.agent.web_searcher as web_searcher + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +class RateLimitHelper(): + def __init__(self, context: Context) -> None: + self.user_rate_limit: Dict[int, int] = {} + self.rate_limit_time: int = 60 + self.rate_limit_count: int = 10 + self.user_frequency = {} + + if 'limit' in context.base_config: + if 'count' in context.base_config['limit']: + self.rate_limit_count = context.base_config['limit']['count'] + if 'time' in context.base_config['limit']: + self.rate_limit_time = context.base_config['limit']['time'] + + def check_frequency(self, session_id: str) -> bool: + ''' + 检查发言频率 + ''' + ts = int(time.time()) + if session_id in self.user_frequency: + if ts-self.user_frequency[session_id]['time'] > self.rate_limit_time: + self.user_frequency[session_id]['time'] = ts + self.user_frequency[session_id]['count'] = 1 + return True + else: + if self.user_frequency[session_id]['count'] >= self.rate_limit_count: + return False + else: + self.user_frequency[session_id]['count'] += 1 + return True + else: + t = {'time': ts, 'count': 1} + self.user_frequency[session_id] = t + return True + +class ContentSafetyHelper(): + def __init__(self, context: Context) -> None: + self.baidu_judge = None + if 'baidu_api' in context.base_config and \ + 'enable' in context.base_config['baidu_aip'] and \ + context.base_config['baidu_aip']['enable']: + try: + from astrbot.message.baidu_aip_judge import BaiduJudge + self.baidu_judge = BaiduJudge(context.base_config['baidu_aip']) + logger.info("已启用百度 AI 内容审核。") + except BaseException as e: + logger.error("百度 AI 内容审核初始化失败。") + logger.error(e) + + async def check_content(self, content: str) -> bool: + ''' + 检查文本内容是否合法 + ''' + for i in uw.unfit_words_q: + matches = re.match(i, content.strip(), re.I | re.M) + if matches: + return False + if self.baidu_judge != None: + check, msg = await asyncio.to_thread(self.baidu_judge.judge, content) + if not check: + logger.info(f"百度 AI 内容审核发现以下违规:{msg}") + return False + return True + + def filter_content(self, content: str) -> str: + ''' + 过滤文本内容 + ''' + for i in uw.unfit_words_q: + content = re.sub(i, "*", content, flags=re.I) + return content + + def baidu_check(self, content: str) -> bool: + ''' + 使用百度 AI 内容审核检查文本内容是否合法 + ''' + if self.baidu_judge != None: + check, msg = self.baidu_judge.judge(content) + if not check: + logger.info(f"百度 AI 内容审核发现以下违规:{msg}") + return False + return True + +class MessageHandler(): + def __init__(self, context: Context, + command_manager: CommandManager, + persist_manager: dbConn, + provider: Provider) -> None: + self.context = context + self.command_manager = command_manager + self.persist_manager = persist_manager + self.rate_limit_helper = RateLimitHelper(context) + self.content_safety_helper = ContentSafetyHelper(context) + self.llm_wake_prefix = self.context.base_config['llm_wake_prefix'] + self.nicks = self.context.nick + self.provider = provider + self.reply_prefix = self.context.reply_prefix + + async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult: + ''' + Handle the message event, including commands, plugins, etc. + + `llm_provider`: the provider to use for LLM. If None, use the default provider + ''' + msg_plain = message.message_str.strip() + provider = llm_provider if llm_provider else self.provider + inner_provider = False if llm_provider else True + + self.persist_manager.record_message(message.platform.platform_name, message.session_id) + + # TODO: this should be configurable + if not message.message_str: + return MessageResult("Hi~") + + # check the rate limit + if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id): + return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。') + + # remove the nick prefix + for nick in self.nicks: + if msg_plain.startswith(nick): + msg_plain = msg_plain.removeprefix(nick) + break + + # scan candidate commands + cmd_res = await self.command_manager.scan_command(message, self.context) + if cmd_res: + assert(isinstance(cmd_res, CommandResult)) + return MessageResult( + cmd_res.message_chain, + is_command_call=True + ) + + # check if the message is a llm-wake-up command + if not msg_plain.startswith(self.llm_wake_prefix): + return + + if not provider: + return + + # check the content safety + if not await self.content_safety_helper.check_content(msg_plain): + return MessageResult("信息包含违规内容,由于机器人管理者开启内容安全审核,你的此条消息已被停止继续处理。") + + image_url = None + for comp in message.message_obj.message: + if isinstance(comp, Image): + image_url = comp.url if comp.url else comp.file + break + web_search = self.context.web_search + if not web_search and msg_plain.startswith("ws"): + # leverage web search feature + web_search = True + msg_plain = msg_plain.removeprefix("ws").strip() + + try: + if web_search: + llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider) + else: + llm_result = await provider.text_chat( + msg_plain, message.session_id, image_url + ) + except BaseException as e: + logger.error(traceback.format_exc()) + logger.error(f"LLM 调用失败。") + return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e)) + + # concatenate the reply prefix + if self.reply_prefix: + llm_result = self.reply_prefix + llm_result + + # mask the unsafe content + llm_result = self.content_safety_helper.filter_content(llm_result) + check = self.content_safety_helper.baidu_check(llm_result) + if not check: + return MessageResult("LLM 输出的信息包含违规内容,由于机器人管理者开启了内容安全审核,该条消息已拦截。") + + return MessageResult(llm_result) \ No newline at end of file diff --git a/util/unfit_words.py b/astrbot/message/unfit_words.py similarity index 100% rename from util/unfit_words.py rename to astrbot/message/unfit_words.py diff --git a/persist/session.py b/astrbot/persist/helper.py similarity index 87% rename from persist/session.py rename to astrbot/persist/helper.py index a0a4511..89d5a4d 100644 --- a/persist/session.py +++ b/astrbot/persist/helper.py @@ -4,52 +4,25 @@ import time from typing import Tuple - class dbConn(): def __init__(self): db_path = "data/data.db" if os.path.exists("data.db"): shutil.copy("data.db", db_path) - conn = sqlite3.connect(db_path) - conn.text_factory = str - self.conn = conn - c = conn.cursor() - c.execute( - ''' - CREATE TABLE IF NOT EXISTS tb_session( - qq_id VARCHAR(32) PRIMARY KEY, - history TEXT - ); - ''' - ) - c.execute( - ''' - CREATE TABLE IF NOT EXISTS tb_stat_session( - platform VARCHAR(32), - session_id VARCHAR(32), - cnt INTEGER - ); - ''' - ) - c.execute( - ''' - CREATE TABLE IF NOT EXISTS tb_stat_message( - ts INTEGER, - cnt INTEGER - ); - ''' - ) - c.execute( - ''' - CREATE TABLE IF NOT EXISTS tb_stat_platform( - ts INTEGER, - platform VARCHAR(32), - cnt INTEGER - ); - ''' - ) + with open(os.path.dirname(__file__) + "/initialization.sql", "r") as f: + sql = f.read() - conn.commit() + self.conn = sqlite3.connect(db_path) + self.conn.text_factory = str + c = self.conn.cursor() + c.executescript(sql) + self.conn.commit() + + def record_message(self, platform, session_id): + curr_ts = int(time.time()) + self.increment_stat_session(platform, session_id, 1) + self.increment_stat_message(curr_ts, 1) + self.increment_stat_platform(curr_ts, platform, 1) def insert_session(self, qq_id, history): conn = self.conn diff --git a/astrbot/persist/initialization.sql b/astrbot/persist/initialization.sql new file mode 100644 index 0000000..b047c26 --- /dev/null +++ b/astrbot/persist/initialization.sql @@ -0,0 +1,18 @@ +CREATE TABLE IF NOT EXISTS tb_session( + qq_id VARCHAR(32) PRIMARY KEY, + history TEXT +); +CREATE TABLE IF NOT EXISTS tb_stat_session( + platform VARCHAR(32), + session_id VARCHAR(32), + cnt INTEGER +); +CREATE TABLE IF NOT EXISTS tb_stat_message( + ts INTEGER, + cnt INTEGER +); +CREATE TABLE IF NOT EXISTS tb_stat_platform( + ts INTEGER, + platform VARCHAR(32), + cnt INTEGER +); \ No newline at end of file diff --git a/dashboard/__init__.py b/dashboard/__init__.py new file mode 100644 index 0000000..8e4ce69 --- /dev/null +++ b/dashboard/__init__.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +class DashBoardData(): + stats: dict = {} + configs: dict = {} + +@dataclass +class Response(): + status: str + message: str + data: dict diff --git a/addons/dashboard/dist/_redirects b/dashboard/dist/_redirects similarity index 100% rename from addons/dashboard/dist/_redirects rename to dashboard/dist/_redirects diff --git a/addons/dashboard/dist/assets/BaseBreadcrumb-4d676ba5.css b/dashboard/dist/assets/BaseBreadcrumb-4d676ba5.css similarity index 100% rename from addons/dashboard/dist/assets/BaseBreadcrumb-4d676ba5.css rename to dashboard/dist/assets/BaseBreadcrumb-4d676ba5.css diff --git a/addons/dashboard/dist/assets/BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js b/dashboard/dist/assets/BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js similarity index 100% rename from addons/dashboard/dist/assets/BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js rename to dashboard/dist/assets/BaseBreadcrumb.vue_vue_type_style_index_0_lang-1875d383.js diff --git a/addons/dashboard/dist/assets/BlankLayout-471b2640.js b/dashboard/dist/assets/BlankLayout-471b2640.js similarity index 100% rename from addons/dashboard/dist/assets/BlankLayout-471b2640.js rename to dashboard/dist/assets/BlankLayout-471b2640.js diff --git a/addons/dashboard/dist/assets/ColorPage-beafd674.js b/dashboard/dist/assets/ColorPage-beafd674.js similarity index 100% rename from addons/dashboard/dist/assets/ColorPage-beafd674.js rename to dashboard/dist/assets/ColorPage-beafd674.js diff --git a/addons/dashboard/dist/assets/ConfigDetailCard-756c045d.js b/dashboard/dist/assets/ConfigDetailCard-756c045d.js similarity index 100% rename from addons/dashboard/dist/assets/ConfigDetailCard-756c045d.js rename to dashboard/dist/assets/ConfigDetailCard-756c045d.js diff --git a/addons/dashboard/dist/assets/ConfigPage-56ea019d.js b/dashboard/dist/assets/ConfigPage-56ea019d.js similarity index 100% rename from addons/dashboard/dist/assets/ConfigPage-56ea019d.js rename to dashboard/dist/assets/ConfigPage-56ea019d.js diff --git a/addons/dashboard/dist/assets/ConsolePage-bd0aea4c.js b/dashboard/dist/assets/ConsolePage-bd0aea4c.js similarity index 100% rename from addons/dashboard/dist/assets/ConsolePage-bd0aea4c.js rename to dashboard/dist/assets/ConsolePage-bd0aea4c.js diff --git a/addons/dashboard/dist/assets/ConsolePage-ff373be6.css b/dashboard/dist/assets/ConsolePage-ff373be6.css similarity index 100% rename from addons/dashboard/dist/assets/ConsolePage-ff373be6.css rename to dashboard/dist/assets/ConsolePage-ff373be6.css diff --git a/addons/dashboard/dist/assets/DefaultDashboard-ef4a941d.js b/dashboard/dist/assets/DefaultDashboard-ef4a941d.js similarity index 100% rename from addons/dashboard/dist/assets/DefaultDashboard-ef4a941d.js rename to dashboard/dist/assets/DefaultDashboard-ef4a941d.js diff --git a/addons/dashboard/dist/assets/Error404Page-11cf087a.css b/dashboard/dist/assets/Error404Page-11cf087a.css similarity index 100% rename from addons/dashboard/dist/assets/Error404Page-11cf087a.css rename to dashboard/dist/assets/Error404Page-11cf087a.css diff --git a/addons/dashboard/dist/assets/Error404Page-bcb6eec8.js b/dashboard/dist/assets/Error404Page-bcb6eec8.js similarity index 100% rename from addons/dashboard/dist/assets/Error404Page-bcb6eec8.js rename to dashboard/dist/assets/Error404Page-bcb6eec8.js diff --git a/addons/dashboard/dist/assets/ExtensionPage-701bf929.js b/dashboard/dist/assets/ExtensionPage-701bf929.js similarity index 100% rename from addons/dashboard/dist/assets/ExtensionPage-701bf929.js rename to dashboard/dist/assets/ExtensionPage-701bf929.js diff --git a/addons/dashboard/dist/assets/FullLayout-35e69863.js b/dashboard/dist/assets/FullLayout-35e69863.js similarity index 100% rename from addons/dashboard/dist/assets/FullLayout-35e69863.js rename to dashboard/dist/assets/FullLayout-35e69863.js diff --git a/addons/dashboard/dist/assets/LoginPage-5c692a20.js b/dashboard/dist/assets/LoginPage-5c692a20.js similarity index 100% rename from addons/dashboard/dist/assets/LoginPage-5c692a20.js rename to dashboard/dist/assets/LoginPage-5c692a20.js diff --git a/addons/dashboard/dist/assets/LoginPage-74e85ca7.css b/dashboard/dist/assets/LoginPage-74e85ca7.css similarity index 100% rename from addons/dashboard/dist/assets/LoginPage-74e85ca7.css rename to dashboard/dist/assets/LoginPage-74e85ca7.css diff --git a/addons/dashboard/dist/assets/LogoDark.vue_vue_type_script_setup_true_lang-d555e5be.js b/dashboard/dist/assets/LogoDark.vue_vue_type_script_setup_true_lang-d555e5be.js similarity index 100% rename from addons/dashboard/dist/assets/LogoDark.vue_vue_type_script_setup_true_lang-d555e5be.js rename to dashboard/dist/assets/LogoDark.vue_vue_type_script_setup_true_lang-d555e5be.js diff --git a/addons/dashboard/dist/assets/MaterialIcons-a926bc0c.js b/dashboard/dist/assets/MaterialIcons-a926bc0c.js similarity index 100% rename from addons/dashboard/dist/assets/MaterialIcons-a926bc0c.js rename to dashboard/dist/assets/MaterialIcons-a926bc0c.js diff --git a/addons/dashboard/dist/assets/RegisterPage-799ed804.css b/dashboard/dist/assets/RegisterPage-799ed804.css similarity index 100% rename from addons/dashboard/dist/assets/RegisterPage-799ed804.css rename to dashboard/dist/assets/RegisterPage-799ed804.css diff --git a/addons/dashboard/dist/assets/RegisterPage-a274fd0f.js b/dashboard/dist/assets/RegisterPage-a274fd0f.js similarity index 100% rename from addons/dashboard/dist/assets/RegisterPage-a274fd0f.js rename to dashboard/dist/assets/RegisterPage-a274fd0f.js diff --git a/addons/dashboard/dist/assets/ShadowPage-4758709f.js b/dashboard/dist/assets/ShadowPage-4758709f.js similarity index 100% rename from addons/dashboard/dist/assets/ShadowPage-4758709f.js rename to dashboard/dist/assets/ShadowPage-4758709f.js diff --git a/addons/dashboard/dist/assets/TablerIcons-da1fd166.js b/dashboard/dist/assets/TablerIcons-da1fd166.js similarity index 100% rename from addons/dashboard/dist/assets/TablerIcons-da1fd166.js rename to dashboard/dist/assets/TablerIcons-da1fd166.js diff --git a/addons/dashboard/dist/assets/TypographyPage-ee445c8b.js b/dashboard/dist/assets/TypographyPage-ee445c8b.js similarity index 100% rename from addons/dashboard/dist/assets/TypographyPage-ee445c8b.js rename to dashboard/dist/assets/TypographyPage-ee445c8b.js diff --git a/addons/dashboard/dist/assets/UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js b/dashboard/dist/assets/UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js similarity index 100% rename from addons/dashboard/dist/assets/UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js rename to dashboard/dist/assets/UiParentCard.vue_vue_type_script_setup_true_lang-b40a2daa.js diff --git a/addons/dashboard/dist/assets/_plugin-vue_export-helper-c27b6911.js b/dashboard/dist/assets/_plugin-vue_export-helper-c27b6911.js similarity index 100% rename from addons/dashboard/dist/assets/_plugin-vue_export-helper-c27b6911.js rename to dashboard/dist/assets/_plugin-vue_export-helper-c27b6911.js diff --git a/addons/dashboard/dist/assets/img-error-bg-41f65efa.svg b/dashboard/dist/assets/img-error-bg-41f65efa.svg similarity index 100% rename from addons/dashboard/dist/assets/img-error-bg-41f65efa.svg rename to dashboard/dist/assets/img-error-bg-41f65efa.svg diff --git a/addons/dashboard/dist/assets/img-error-blue-f50c8e77.svg b/dashboard/dist/assets/img-error-blue-f50c8e77.svg similarity index 100% rename from addons/dashboard/dist/assets/img-error-blue-f50c8e77.svg rename to dashboard/dist/assets/img-error-blue-f50c8e77.svg diff --git a/addons/dashboard/dist/assets/img-error-purple-b97a483b.svg b/dashboard/dist/assets/img-error-purple-b97a483b.svg similarity index 100% rename from addons/dashboard/dist/assets/img-error-purple-b97a483b.svg rename to dashboard/dist/assets/img-error-purple-b97a483b.svg diff --git a/addons/dashboard/dist/assets/img-error-text-630dc36d.svg b/dashboard/dist/assets/img-error-text-630dc36d.svg similarity index 100% rename from addons/dashboard/dist/assets/img-error-text-630dc36d.svg rename to dashboard/dist/assets/img-error-text-630dc36d.svg diff --git a/addons/dashboard/dist/assets/index-0f1523f3.css b/dashboard/dist/assets/index-0f1523f3.css similarity index 100% rename from addons/dashboard/dist/assets/index-0f1523f3.css rename to dashboard/dist/assets/index-0f1523f3.css diff --git a/addons/dashboard/dist/assets/index-5ac7c267.js b/dashboard/dist/assets/index-5ac7c267.js similarity index 100% rename from addons/dashboard/dist/assets/index-5ac7c267.js rename to dashboard/dist/assets/index-5ac7c267.js diff --git a/addons/dashboard/dist/assets/materialdesignicons-webfont-67d24abe.eot b/dashboard/dist/assets/materialdesignicons-webfont-67d24abe.eot similarity index 100% rename from addons/dashboard/dist/assets/materialdesignicons-webfont-67d24abe.eot rename to dashboard/dist/assets/materialdesignicons-webfont-67d24abe.eot diff --git a/addons/dashboard/dist/assets/materialdesignicons-webfont-80bb28b3.woff b/dashboard/dist/assets/materialdesignicons-webfont-80bb28b3.woff similarity index 100% rename from addons/dashboard/dist/assets/materialdesignicons-webfont-80bb28b3.woff rename to dashboard/dist/assets/materialdesignicons-webfont-80bb28b3.woff diff --git a/addons/dashboard/dist/assets/materialdesignicons-webfont-a58ecb54.ttf b/dashboard/dist/assets/materialdesignicons-webfont-a58ecb54.ttf similarity index 100% rename from addons/dashboard/dist/assets/materialdesignicons-webfont-a58ecb54.ttf rename to dashboard/dist/assets/materialdesignicons-webfont-a58ecb54.ttf diff --git a/addons/dashboard/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 b/dashboard/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 similarity index 100% rename from addons/dashboard/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 rename to dashboard/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 diff --git a/addons/dashboard/dist/assets/md5-086248bf.js b/dashboard/dist/assets/md5-086248bf.js similarity index 100% rename from addons/dashboard/dist/assets/md5-086248bf.js rename to dashboard/dist/assets/md5-086248bf.js diff --git a/addons/dashboard/dist/assets/social-google-9b2fa67a.svg b/dashboard/dist/assets/social-google-9b2fa67a.svg similarity index 100% rename from addons/dashboard/dist/assets/social-google-9b2fa67a.svg rename to dashboard/dist/assets/social-google-9b2fa67a.svg diff --git a/addons/dashboard/dist/favicon.svg b/dashboard/dist/favicon.svg similarity index 100% rename from addons/dashboard/dist/favicon.svg rename to dashboard/dist/favicon.svg diff --git a/addons/dashboard/dist/index.html b/dashboard/dist/index.html similarity index 100% rename from addons/dashboard/dist/index.html rename to dashboard/dist/index.html diff --git a/addons/dashboard/helper.py b/dashboard/helper.py similarity index 92% rename from addons/dashboard/helper.py rename to dashboard/helper.py index a8a4614..4a9a704 100644 --- a/addons/dashboard/helper.py +++ b/dashboard/helper.py @@ -1,20 +1,16 @@ -from addons.dashboard.server import AstrBotDashBoard, DashBoardData -from pydantic import BaseModel +import threading +import asyncio + +from . import DashBoardData from typing import Union, Optional -import uuid -from util import general_utils as gu from util.cmd_config import CmdConfig from dataclasses import dataclass -import sys -import os -import threading -import time -import asyncio from util.plugin_dev.api.v1.config import update_config from SparkleLogging.utils.core import LogManager from logging import Logger +from type.types import Context -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') +logger: Logger = LogManager.GetLogger(log_name='astrbot') @dataclass @@ -29,34 +25,12 @@ class DashBoardConfig(): class DashBoardHelper(): - def __init__(self, global_object, config: dict): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - dashboard_data = global_object.dashboard_data + def __init__(self, context: Context, dashboard_data: DashBoardData): dashboard_data.configs = { "data": [] } - self.parse_default_config(dashboard_data, config) - self.dashboard_data: DashBoardData = dashboard_data - self.dashboard = AstrBotDashBoard(global_object) - self.key_map = {} # key: uuid, value: config key name - self.cc = CmdConfig() - - @self.dashboard.register("post_configs") - def on_post_configs(post_configs: dict): - try: - if 'base_config' in post_configs: - self.save_config( - post_configs['base_config'], namespace='') # 基础配置 - self.save_config( - post_configs['config'], namespace=post_configs['namespace']) # 选定配置 - self.parse_default_config( - self.dashboard_data, self.cc.get_all()) - # 重启 - threading.Thread(target=self.dashboard.shutdown_bot, - args=(2,), daemon=True).start() - except Exception as e: - raise e + self.context = context + self.parse_default_config(dashboard_data, context.base_config) # 将 config.yaml、 中的配置解析到 dashboard_data.configs 中 def parse_default_config(self, dashboard_data: DashBoardData, config: dict): @@ -64,7 +38,7 @@ def parse_default_config(self, dashboard_data: DashBoardData, config: dict): try: qq_official_platform_group = DashBoardConfig( config_type="group", - name="QQ_OFFICIAL 平台配置", + name="QQ(官方)", description="", body=[ DashBoardConfig( @@ -119,7 +93,7 @@ def parse_default_config(self, dashboard_data: DashBoardData, config: dict): ) qq_gocq_platform_group = DashBoardConfig( config_type="group", - name="go-cqhttp", + name="QQ(nakuru)", description="", body=[ DashBoardConfig( @@ -197,6 +171,38 @@ def parse_default_config(self, dashboard_data: DashBoardData, config: dict): ] ) + qq_aiocqhttp_platform_group = DashBoardConfig( + config_type="group", + name="QQ(aiocqhttp)", + description="", + body=[ + DashBoardConfig( + config_type="item", + val_type="bool", + name="启用", + description="", + value=config['aiocqhttp']['enable'], + path="aiocqhttp.enable", + ), + DashBoardConfig( + config_type="item", + val_type="str", + name="WebSocket 反向连接 host", + description="", + value=config['aiocqhttp']['ws_reverse_host'], + path="aiocqhttp.ws_reverse_host", + ), + DashBoardConfig( + config_type="item", + val_type="int", + name="WebSocket 反向连接 port", + description="", + value=config['aiocqhttp']['ws_reverse_port'], + path="aiocqhttp.ws_reverse_port", + ), + ] + ) + general_platform_detail_group = DashBoardConfig( config_type="group", name="通用平台配置", @@ -467,7 +473,8 @@ def parse_default_config(self, dashboard_data: DashBoardData, config: dict): general_platform_detail_group, openai_official_llm_group, other_group, - baidu_aip_group + baidu_aip_group, + qq_aiocqhttp_platform_group ] except Exception as e: @@ -525,9 +532,6 @@ def save_config(self, post_config: list, namespace: str): def _write_config(self, namespace: str, key: str, value): if namespace == "" or namespace.startswith("internal_"): # 机器人自带配置,存到 config.yaml - self.cc.put_by_dot_str(key, value) + self.context.config_helper.put_by_dot_str(key, value) else: update_config(namespace, key, value) - - def run(self): - self.dashboard.run() diff --git a/addons/dashboard/server.py b/dashboard/server.py similarity index 78% rename from addons/dashboard/server.py rename to dashboard/server.py index ff9a4d1..8da717f 100644 --- a/addons/dashboard/server.py +++ b/dashboard/server.py @@ -1,55 +1,45 @@ -import util.plugin_util as putil import websockets import json import threading import asyncio import os import uuid -import time +import logging import traceback +from . import DashBoardData, Response from flask import Flask, request -from flask.logging import default_handler from werkzeug.serving import make_server -from util import general_utils as gu -from dataclasses import dataclass -from persist.session import dbConn -from type.register import RegisteredPlugin +from astrbot.persist.helper import dbConn +from type.types import Context from typing import List -from util.cmd_config import CmdConfig -from util.updator import check_update, update_project, request_release_info, _reboot from SparkleLogging.utils.core import LogManager from logging import Logger -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') +from dashboard.helper import DashBoardHelper +from util.io import get_local_ip_addresses +from model.plugin.manager import PluginManager +from util.updator.astrbot_updator import AstrBotUpdator -@dataclass -class DashBoardData(): - stats: dict - configs: dict - logs: dict - plugins: List[RegisteredPlugin] - -@dataclass -class Response(): - status: str - message: str - data: dict +logger: Logger = LogManager.GetLogger(log_name='astrbot') class AstrBotDashBoard(): - def __init__(self, global_object: 'gu.GlobalObject'): - self.global_object = global_object - self.loop = asyncio.get_event_loop() - asyncio.set_event_loop(self.loop) - self.dashboard_data: DashBoardData = global_object.dashboard_data - self.dashboard_be = Flask( - __name__, static_folder="dist", static_url_path="/") - self.funcs = {} - self.cc = CmdConfig() + def __init__(self, context: Context, plugin_manager: PluginManager, astrbot_updator: AstrBotUpdator): + self.context = context + self.plugin_manager = plugin_manager + self.astrbot_updator = astrbot_updator + self.dashboard_data = DashBoardData() + self.dashboard_helper = DashBoardHelper(self.context, self.dashboard_data) + + self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/") + logging.getLogger('werkzeug').setLevel(logging.ERROR) + self.dashboard_be.logger.setLevel(logging.ERROR) + self.ws_clients = {} # remote_ip: ws - # 启动 websocket 服务器 - self.ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186) + self.loop = asyncio.get_event_loop() + + self.http_server_thread: threading.Thread = None @self.dashboard_be.get("/") def index(): @@ -74,8 +64,8 @@ def rt_dashboard(): @self.dashboard_be.post("/api/authenticate") def authenticate(): - username = self.cc.get("dashboard_username", "") - password = self.cc.get("dashboard_password", "") + username = self.context.base_config.get("dashboard_username", "") + password = self.context.base_config.get("dashboard_password", "") # 获得请求体 post_data = request.json if post_data["username"] == username and post_data["password"] == password: @@ -96,11 +86,12 @@ def authenticate(): @self.dashboard_be.post("/api/change_password") def change_password(): - password = self.cc.get("dashboard_password", "") + password = self.context.base_config("dashboard_password", "") # 获得请求体 post_data = request.json if post_data["password"] == password: - self.cc.put("dashboard_password", post_data["new_password"]) + self.context.config_helper.put("dashboard_password", post_data["new_password"]) + self.context.base_config['dashboard_password'] = post_data["new_password"] return Response( status="success", message="修改成功。", @@ -159,7 +150,7 @@ def get_config_outline(): def post_configs(): post_configs = request.json try: - self.funcs["post_configs"](post_configs) + self.on_post_configs(post_configs) return Response( status="success", message="保存成功~ 机器人将在 2 秒内重启以应用新的配置。", @@ -175,7 +166,7 @@ def post_configs(): @self.dashboard_be.get("/api/extensions") def get_plugins(): _plugin_resp = [] - for plugin in self.dashboard_data.plugins: + for plugin in self.context.cached_plugins: _p = plugin.metadata _t = { "name": _p.plugin_name, @@ -197,7 +188,7 @@ def install_plugin(): repo_url = post_data["url"] try: logger.info(f"正在安装插件 {repo_url}") - putil.install_plugin(repo_url, global_object) + self.plugin_manager.install_plugin(repo_url) logger.info(f"安装插件 {repo_url} 成功") return Response( status="success", @@ -221,7 +212,7 @@ def upload_install_plugin(): # save file to temp/ file_path = f"temp/{uuid.uuid4()}.zip" file.save(file_path) - putil.install_plugin_from_file(file_path, global_object) + self.plugin_manager.install_plugin_from_file(file_path) logger.info(f"安装插件 {file.filename} 成功") return Response( status="success", @@ -242,8 +233,7 @@ def uninstall_plugin(): plugin_name = post_data["name"] try: logger.info(f"正在卸载插件 {plugin_name}") - putil.uninstall_plugin( - plugin_name, global_object) + self.plugin_manager.uninstall_plugin(plugin_name) logger.info(f"卸载插件 {plugin_name} 成功") return Response( status="success", @@ -264,7 +254,7 @@ def update_plugin(): plugin_name = post_data["name"] try: logger.info(f"正在更新插件 {plugin_name}") - putil.update_plugin(plugin_name, global_object) + self.plugin_manager.update_plugin(plugin_name) logger.info(f"更新插件 {plugin_name} 成功") return Response( status="success", @@ -284,7 +274,7 @@ def log(): for item in self.ws_clients: try: asyncio.run_coroutine_threadsafe( - self.ws_clients[item].send(request.data.decode()), self.loop) + self.ws_clients[item].send(request.data.decode()), self.loop).result() except Exception as e: pass return 'ok' @@ -292,12 +282,12 @@ def log(): @self.dashboard_be.get("/api/check_update") def get_update_info(): try: - ret = check_update() + ret = self.astrbot_updator.check_update(None, None) return Response( status="success", - message=ret, + message=str(ret), data={ - "has_new_version": ret != "当前已经是最新版本。" # 先这样吧,累了=.= + "has_new_version": ret is not None } ).__dict__ except Exception as e: @@ -317,8 +307,8 @@ def update_project_api(): else: latest = False try: - update_project(latest=latest, version=version) - threading.Thread(target=self.shutdown_bot, args=(3,)).start() + self.astrbot_updator.update(latest=latest, version=version) + threading.Thread(target=self.astrbot_updator._reboot, args=(3, )).start() return Response( status="success", message="更新成功,机器人将在 3 秒内重启。", @@ -335,7 +325,7 @@ def update_project_api(): @self.dashboard_be.get("/api/llm/list") def llm_list(): ret = [] - for llm in self.global_object.llms: + for llm in self.context.llms: ret.append(llm.llm_name) return Response( status="success", @@ -347,10 +337,9 @@ def llm_list(): def llm(): text = request.args["text"] llm = request.args["llm"] - for llm_ in self.global_object.llms: + for llm_ in self.context.llms: if llm_.llm_name == llm: try: - # ret = await llm_.llm_instance.text_chat(text) ret = asyncio.run_coroutine_threadsafe( llm_.llm_instance.text_chat(text), self.loop).result() return Response( @@ -370,10 +359,21 @@ def llm(): message="LLM not found.", data=None ).__dict__ - - def shutdown_bot(self, delay_s: int): - time.sleep(delay_s) - _reboot() + + def on_post_configs(self, post_configs: dict): + try: + if 'base_config' in post_configs: + self.dashboard_helper.save_config( + post_configs['base_config'], namespace='') # 基础配置 + self.dashboard_helper.save_config( + post_configs['config'], namespace=post_configs['namespace']) # 选定配置 + self.dashboard_helper.parse_default_config( + self.dashboard_data, self.context.config_helper.get_all()) + # 重启 + threading.Thread(target=self.astrbot_updator._reboot, + args=(2, ), daemon=True).start() + except Exception as e: + raise e def _get_configs(self, namespace: str): if namespace == "": @@ -387,6 +387,8 @@ def _get_configs(self, namespace: str): ret = [self.dashboard_data.configs['data'][2],] elif namespace == "internal_llm_openai_official": ret = [self.dashboard_data.configs['data'][3],] + elif namespace == "internal_platform_qq_aiocqhttp": + ret = [self.dashboard_data.configs['data'][6],] else: path = f"data/config/{namespace}.json" if not os.path.exists(path): @@ -417,16 +419,22 @@ def _generate_outline(self): "tag": "" }, { - "title": "QQ_OFFICIAL", - "desc": "QQ官方API。支持频道、群(需获得群权限)", + "title": "QQ(官方机器人 API)", + "desc": "QQ官方API。支持频道、群、私聊(需获得群权限)", "namespace": "internal_platform_qq_official", "tag": "" }, { - "title": "go-cqhttp", - "desc": "第三方 QQ 协议实现。支持频道、群", + "title": "QQ(nakuru 适配器)", + "desc": "适用于 go-cqhttp", "namespace": "internal_platform_qq_gocq", "tag": "" + }, + { + "title": "QQ(aiocqhttp 适配器)", + "desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。", + "namespace": "internal_platform_qq_aiocqhttp", + "tag": "" } ] }, @@ -443,7 +451,7 @@ def _generate_outline(self): ] } ] - for plugin in self.global_object.cached_plugins: + for plugin in self.context.cached_plugins: for item in outline: if item['type'] == plugin.metadata.plugin_type: item['body'].append({ @@ -454,15 +462,9 @@ def _generate_outline(self): }) return outline - def register(self, name: str): - def decorator(func): - self.funcs[name] = func - return func - return decorator - async def get_log_history(self): try: - with open("logs/astrbot-core/astrbot-core.log", "r", encoding="utf-8") as f: + with open("logs/astrbot/astrbot.log", "r", encoding="utf-8") as f: return f.readlines()[-100:] except Exception as e: logger.warning(f"读取日志历史失败: {e.__str__()}") @@ -486,19 +488,18 @@ async def __handle_msg(self, websocket, path): del self.ws_clients[address] break - def run_ws_server(self, loop): - asyncio.set_event_loop(loop) - loop.run_until_complete(self.ws_server) - loop.run_forever() - - def run(self): - threading.Thread(target=self.run_ws_server, args=(self.loop,)).start() - logger.info("已启动 websocket 服务器") - ip_address = gu.get_local_ip_addresses() - ip_str = f"http://{ip_address}:6185\n\thttp://localhost:6185" - logger.info( - f"\n==================\n您可访问:\n\n\t{ip_str}\n\n来登录可视化面板,默认账号密码为空。\n注意: 所有配置项现已全量迁移至 cmd_config.json 文件下,可登录可视化面板在线修改配置。\n==================\n") + async def ws_server(self): + ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186) + logger.info("WebSocket 服务器已启动。") + await ws_server + def http_server(self): http_server = make_server( '0.0.0.0', 6185, self.dashboard_be, threaded=True) http_server.serve_forever() + + def run_http_server(self): + self.http_server_thread = threading.Thread(target=self.http_server, daemon=True).start() + ip_address = get_local_ip_addresses() + ip_str = f"http://{ip_address}:6185" + logger.info(f"HTTP 服务器已启动,可访问: {ip_str} 等来登录可视化面板。") diff --git a/main.py b/main.py index 5b67508..59b00d9 100644 --- a/main.py +++ b/main.py @@ -1,15 +1,14 @@ import os +import asyncio import sys import warnings import traceback -import threading -from logging import Formatter, Logger -from util.cmd_config import CmdConfig, try_migrate_config +from astrbot.bootstrap import AstrBotBootstrap +from SparkleLogging.utils.core import LogManager +from logging import Formatter warnings.filterwarnings("ignore") - -logger: Logger = None logo_tmpl = """ ___ _______.___________..______ .______ ______ .___________. / \ / | || _ \ | _ \ / __ \ | | @@ -19,89 +18,38 @@ /__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__| """ - -def make_necessary_dirs(): - ''' - 创建必要的目录。 - ''' - os.makedirs("data/config", exist_ok=True) - os.makedirs("temp", exist_ok=True) - -def update_dept(): - ''' - 更新依赖库。 - ''' - # 获取 Python 可执行文件路径 - py = sys.executable - requirements_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "requirements.txt") - print(requirements_path) - # 更新依赖库 - mirror = "https://mirrors.aliyun.com/pypi/simple/" - os.system(f"{py} -m pip install -r {requirements_path} -i {mirror}") def main(): + global logger try: import botpy, logging - import astrbot.core as bot_core # delete qqbotpy's logger for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - except ImportError as import_error: - logger.error(import_error) - logger.error("检测到一些依赖库没有安装。由于兼容性问题,AstrBot 此版本将不会自动为您安装依赖库。请您先自行安装,然后重试。") - logger.info("如何安装?如果:") - logger.info("- Windows 启动器部署且使用启动器下载了 Python的:在 launcher.exe 所在目录下的地址框输入 powershell,然后执行 .\python\python.exe -m pip install .\AstrBot\requirements.txt") - logger.info("- Windows 启动器部署且使用自己之前下载的 Python的:在 launcher.exe 所在目录下的地址框输入 powershell,然后执行 python -m pip install .\AstrBot\requirements.txt") - logger.info("- 自行 clone 源码部署的:python -m pip install -r requirements.txt") - logger.info("- 如果还不会,加群 322154837 ") - input("按任意键退出。") - exit() - except FileNotFoundError as file_not_found: - logger.error(file_not_found) - input("配置文件不存在,请检查是否已经下载配置文件。") - exit() - except BaseException as e: - logger.error(traceback.format_exc()) - input("未知错误。") - exit() + logging.root.removeHandler(handler) - # 启动主程序(cores/qqbot/core.py) - bot_core.init() + bootstrap = AstrBotBootstrap() + asyncio.run(bootstrap.run()) + except KeyboardInterrupt: + logger.info("AstrBot 已退出。") + except BaseException as e: + logger.error(traceback.format_exc()) def check_env(): if not (sys.version_info.major == 3 and sys.version_info.minor >= 9): - logger.error("请使用 Python3.9+ 运行本项目。按任意键退出。") - input("") + logger.error("请使用 Python3.9+ 运行本项目。") exit() - + + os.makedirs("data/config", exist_ok=True) + os.makedirs("temp", exist_ok=True) + if __name__ == "__main__": - update_dept() - make_necessary_dirs() - try_migrate_config() - cc = CmdConfig() - http_proxy = cc.get("http_proxy") - https_proxy = cc.get("https_proxy") - if http_proxy: - os.environ['HTTP_PROXY'] = http_proxy - if https_proxy: - os.environ['HTTPS_PROXY'] = https_proxy - os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com' + check_env() - from SparkleLogging.utils.core import LogManager logger = LogManager.GetLogger( - log_name='astrbot-core', + log_name='astrbot', out_to_console=True, custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S") ) logger.info(logo_tmpl) - logger.info(f"使用代理: {http_proxy}, {https_proxy}") - - check_env() - t = threading.Thread(target=main, daemon=True) - t.start() - try: - t.join() - except KeyboardInterrupt as e: - logger.info("退出 AstrBot。") - exit() + main() diff --git a/model/command/command.py b/model/command/command.py index ce92b28..5637026 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -4,18 +4,16 @@ import asyncio import json -import util.plugin_util as putil import util.updator from nakuru.entities.components import ( Image ) from util import general_utils as gu -from util.image_render.helper import text_to_image_base from model.provider.provider import Provider from util.cmd_config import CmdConfig as cc from type.message import * -from type.types import GlobalObject +from type.types import Context from type.command import * from type.plugin import * from type.register import * @@ -24,7 +22,7 @@ from SparkleLogging.utils.core import LogManager from logging import Logger -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') +logger: Logger = LogManager.GetLogger(log_name='astrbot') PLATFORM_QQCHAN = 'qqchan' PLATFORM_GOCQ = 'gocq' @@ -33,7 +31,7 @@ class Command: - def __init__(self, provider: Provider, global_object: GlobalObject = None): + def __init__(self, provider: Provider, global_object: Context = None): self.provider = provider self.global_object = global_object @@ -138,7 +136,7 @@ def get_my_id(self, message_obj, platform): except BaseException as e: return False, f"在{platform}上获取你的ID失败,原因: {str(e)}", "plugin" - async def plugin_oper(self, message: str, role: str, ctx: GlobalObject, platform: str): + async def plugin_oper(self, message: str, role: str, ctx: Context, platform: str): l = message.split(" ") if len(l) < 2: p = await text_to_image_base("# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 重载插件: `plugin reload`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n") diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py new file mode 100644 index 0000000..9a4e6a4 --- /dev/null +++ b/model/command/internal_handler.py @@ -0,0 +1,233 @@ +import aiohttp + +from model.command.manager import CommandManager +from model.plugin.manager import PluginManager +from type.message_event import AstrMessageEvent +from type.command import CommandResult +from type.types import Context +from type.config import VERSION +from SparkleLogging.utils.core import LogManager +from logging import Logger +from nakuru.entities.components import Image + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +class InternalCommandHandler: + def __init__(self, manager: CommandManager, plugin_manager: PluginManager) -> None: + self.manager = manager + self.plugin_manager = plugin_manager + + self.manager.register("help", "查看帮助", 10, self.help) + self.manager.register("nick", "设置机器人昵称", 10, self.set_nick) + self.manager.register("update", "更新 AstrBot", 10, self.update) + self.manager.register("plugin", "插件管理", 10, self.plugin) + self.manager.register("reboot", "重启 AstrBot", 10, self.reboot) + self.manager.register("web_search", "网页搜索开关", 10, self.web_search) + self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle) + self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid) + + def set_nick(self, message: AstrMessageEvent, context: Context): + message_str = message.message_str + if message.role != "admin": + return CommandResult().message("你没有权限使用该指令。") + l = message_str.split(" ") + if len(l) == 1: + return CommandResult().message("【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3") + nick = l[1:] + context.config_helper.put("nick_qq", nick) + context.nick = tuple(nick) + return CommandResult( + hit=True, + success=True, + message_chain=f"已经成功将昵称设定为 {nick}", + ) + + def update(self, message: AstrMessageEvent, context: Context): + tokens = self.manager.command_parser.parse(message.message_str) + if message.role != "admin": + return CommandResult( + hit=True, + success=False, + message_chain="你没有权限使用该指令", + ) + if tokens.len == 1: + update_info = context.updator.check_update(None, None) + ret = "" + if not update_info: + ret = f"当前已经是最新版本 v{VERSION}。" + else: + ret = f"发现新版本 v{update_info['version']}。\n- 使用 /update latest 更新到最新版本。\n- 使用 /update vX.X.X 更新到指定版本。" + return CommandResult( + hit=True, + success=False, + message_chain=ret, + ) + else: + if tokens.get(1) == "latest": + try: + context.updator.update() + return CommandResult().message(f"已经成功更新到最新版本 v{update_info['version']}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启") + except BaseException as e: + return CommandResult().message(f"更新失败。原因:{str(e)}") + elif tokens.get(1).startswith("v"): + try: + context.updator.update(version=tokens.get(1)) + return CommandResult().message(f"已经成功更新到版本 v{tokens.get(1)}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启") + except BaseException as e: + return CommandResult().message(f"更新失败。原因:{str(e)}") + else: + return CommandResult().message("update: 参数错误。") + + def reboot(self, message: AstrMessageEvent, context: Context): + if message.role != "admin": + return CommandResult( + hit=True, + success=False, + message_chain="你没有权限使用该指令", + ) + context.updator._reboot(5) + return CommandResult( + hit=True, + success=True, + message_chain="AstrBot 将在 5s 后重启。", + ) + + def plugin(self, message: AstrMessageEvent, context: Context): + tokens = self.manager.command_parser.parse(message.message_str) + if tokens.len == 1: + ret = "# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n" + return CommandResult().message(ret) + + if tokens.get(1) == "l": + plugin_list_info = "" + for plugin in context.cached_plugins: + plugin_list_info += f"- `{plugin.metadata.plugin_name}` By {plugin.metadata.author}: {plugin.metadata.desc}\n" + if plugin_list_info.strip() == "": + return CommandResult().message("plugin v: 没有找到插件。") + return CommandResult().message(plugin_list_info) + + elif tokens.get(1) == "d": + if message.role != "admin": + return CommandResult().message("plugin d: 你没有权限使用该指令。") + if tokens.len == 2: + return CommandResult().message("plugin d: 请指定要卸载的插件名。") + plugin_name = tokens.get(2) + try: + self.plugin_manager.uninstall_plugin(plugin_name) + except BaseException as e: + return CommandResult().message(f"plugin d: 卸载插件失败。原因:{str(e)}") + return CommandResult().message(f"plugin d: 已经成功卸载插件 {plugin_name}。") + + elif tokens.get(1) == "i": + if message.role != "admin": + return CommandResult().message("plugin i: 你没有权限使用该指令。") + if tokens.len == 2: + return CommandResult().message("plugin i: 请指定要安装的插件的 Github 地址,或者前往可视化面板安装。") + plugin_url = tokens.get(2) + try: + self.plugin_manager.install_plugin(plugin_url) + except BaseException as e: + return CommandResult().message(f"plugin i: 安装插件失败。原因:{str(e)}") + return CommandResult().message("plugin i: 已经成功安装插件。") + + elif tokens.get(1) == "u": + if message.role != "admin": + return CommandResult().message("plugin u: 你没有权限使用该指令。") + if tokens.len == 2: + return CommandResult().message("plugin u: 请指定要更新的插件名。") + plugin_name = tokens.get(2) + try: + self.plugin_manager.update_plugin(plugin_name) + except BaseException as e: + return CommandResult().message(f"plugin u: 更新插件失败。原因:{str(e)}") + return CommandResult().message(f"plugin u: 已经成功更新插件 {plugin_name}。") + + return CommandResult().message("plugin: 参数错误。") + + async def help(self, message: AstrMessageEvent, context: Context): + notice = "" + try: + async with aiohttp.ClientSession() as session: + async with session.get("https://soulter.top/channelbot/notice.json") as resp: + notice = (await resp.json())["notice"] + except BaseException as e: + logger.warn("An error occurred while fetching astrbot notice. Never mind, it's not important.") + + msg = "# Help Center\n## 指令列表\n" + for key, value in self.manager.commands_handler.items(): + if value.plugin_metadata: + msg += f"`{key}` ({value.plugin_metadata.plugin_name}) - {value.description}\n" + else: msg += f"`{key}` - {value.description}\n" + # plugins + if context.cached_plugins != None: + plugin_list_info = "" + for plugin in context.cached_plugins: + plugin_list_info += f"`{plugin.metadata.plugin_name}` {plugin.metadata.desc}\n" + if plugin_list_info.strip() != "": + msg += "\n## 插件列表\n> 使用plugin v 插件名 查看插件帮助\n" + msg += plugin_list_info + msg += notice + + return CommandResult().message(msg) + + def web_search(self, message: AstrMessageEvent, context: Context): + l = message.message_str.split(' ') + if len(l) == 1: + return CommandResult( + hit=True, + success=True, + message_chain=f"网页搜索功能当前状态: {context.web_search}", + ) + elif l[1] == 'on': + context.web_search = True + return CommandResult( + hit=True, + success=True, + message_chain="已开启网页搜索", + ) + elif l[1] == 'off': + context.web_search = False + return CommandResult( + hit=True, + success=True, + message_chain="已关闭网页搜索", + ) + else: + return CommandResult( + hit=True, + success=False, + message_chain="参数错误", + ) + + def t2i_toggle(self, message: AstrMessageEvent, context: Context): + p = context.config_helper.get("qq_pic_mode", True) + if p: + context.config_helper.put("qq_pic_mode", False) + return CommandResult( + hit=True, + success=True, + message_chain="已关闭文本转图片模式。", + ) + context.config_helper.put("qq_pic_mode", True) + + return CommandResult( + hit=True, + success=True, + message_chain="已开启文本转图片模式。", + ) + + def myid(self, message: AstrMessageEvent, context: Context): + try: + user_id = str(message.message_obj.sender.user_id) + return CommandResult( + hit=True, + success=True, + message_chain=f"你在此平台上的ID:{user_id}", + ) + except BaseException as e: + return CommandResult( + hit=True, + success=False, + message_chain=f"在 {message.platform} 上获取你的ID失败,原因: {str(e)}", + ) diff --git a/model/command/manager.py b/model/command/manager.py new file mode 100644 index 0000000..34b3cc6 --- /dev/null +++ b/model/command/manager.py @@ -0,0 +1,99 @@ +import heapq +import inspect +import traceback +from typing import Dict +from type.types import Context +from type.plugin import PluginMetadata +from type.message_event import AstrMessageEvent +from type.command import CommandResult +from type.register import RegisteredPlugins +from model.command.parser import CommandParser +from model.plugin.command import PluginCommandBridge +from SparkleLogging.utils.core import LogManager +from logging import Logger +from dataclasses import dataclass + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + +@dataclass +class CommandMetadata(): + inner_command: bool + plugin_metadata: PluginMetadata + handler: callable + description: str + +class CommandManager(): + def __init__(self): + self.commands = [] + self.commands_handler: Dict[str, CommandMetadata] = {} + self.command_parser = CommandParser() + + def register(self, + command: str, + description: str, + priority: int, + handler: callable, + plugin_metadata: PluginMetadata = None, + ): + ''' + 优先级越高,越先被处理。 + ''' + if command in self.commands_handler: + raise ValueError(f"Command {command} already exists.") + if not handler: + raise ValueError(f"Handler of {command} is None.") + + heapq.heappush(self.commands, (-priority, command)) + self.commands_handler[command] = CommandMetadata( + inner_command=plugin_metadata == None, + plugin_metadata=plugin_metadata, + handler=handler, + description=description + ) + if plugin_metadata: + logger.info(f"已注册 {plugin_metadata.author}/{plugin_metadata.plugin_name} 的指令 {command}。") + else: + logger.info(f"已注册指令 {command}。") + + def register_from_pcb(self, pcb: PluginCommandBridge): + for request in pcb.plugin_commands_waitlist: + plugin = None + for registered_plugin in pcb.cached_plugins: + if registered_plugin.metadata.plugin_name == request.plugin_name: + plugin = registered_plugin + break + if not plugin: + logger.warning(f"插件 {request.plugin_name} 未找到,无法注册指令 {request.command_name}。") + self.register(request.command_name, request.description, request.priority, request.handler, plugin.metadata) + self.plugin_commands_waitlist = [] + + async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult: + message_str = message_event.message_str + for _, command in self.commands: + if message_str.startswith(command): + logger.info(f"触发 {command} 指令。") + command_result = await self.execute_handler(command, message_event, context) + return command_result + + async def execute_handler(self, + command: str, + message_event: AstrMessageEvent, + context: Context) -> CommandResult: + command_metadata = self.commands_handler[command] + handler = command_metadata.handler + # call handler + try: + if inspect.iscoroutinefunction(handler): + command_result = await handler(message_event, context) + else: + command_result = handler(message_event, context) + + if not isinstance(command_result, CommandResult): + raise ValueError(f"Command {command} handler should return CommandResult.") + return command_result + except BaseException as e: + logger.error(traceback.format_exc()) + if not command_metadata.inner_command: + logger.error(f"当执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时,发生了异常。") + else: + logger.error(f"当执行 {command} 指令时,发生了异常。") \ No newline at end of file diff --git a/model/command/openai_official.py b/model/command/openai_official.py deleted file mode 100644 index 34f3f18..0000000 --- a/model/command/openai_official.py +++ /dev/null @@ -1,255 +0,0 @@ -from model.command.command import Command -from model.provider.openai_official import ProviderOpenAIOfficial, MODELS -from util.personality import personalities -from util.general_utils import download_image_by_url -from type.types import GlobalObject -from type.command import CommandItem -from SparkleLogging.utils.core import LogManager -from logging import Logger -from openai._exceptions import NotFoundError -from nakuru.entities.components import Image - -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - -class CommandOpenAIOfficial(Command): - def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject): - self.provider = provider - self.global_object = global_object - self.personality_str = "" - self.commands = [ - CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"), - CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"), - CommandItem("status", self.status, "查看 GPT 配置信息和用量状态。", "内置"), - ] - super().__init__(provider, global_object) - - async def check_command(self, - message: str, - session_id: str, - role: str, - platform: str, - message_obj): - self.platform = platform - - # 检查基础指令 - hit, res = await super().check_command( - message, - session_id, - role, - platform, - message_obj - ) - - logger.debug(f"基础指令hit: {hit}, res: {res}") - - # 这里是这个 LLM 的专属指令 - if hit: - return True, res - if self.command_start_with(message, "reset", "重置"): - return True, await self.reset(session_id, message) - elif self.command_start_with(message, "his", "历史"): - return True, self.his(message, session_id) - elif self.command_start_with(message, "status"): - return True, self.status(session_id) - elif self.command_start_with(message, "help", "帮助"): - return True, await self.help() - elif self.command_start_with(message, "unset"): - return True, self.unset(session_id) - elif self.command_start_with(message, "set"): - return True, self.set(message, session_id) - elif self.command_start_with(message, "update"): - return True, self.update(message, role) - elif self.command_start_with(message, "画", "draw"): - return True, await self.draw(message) - elif self.command_start_with(message, "switch"): - return True, await self.switch(message) - elif self.command_start_with(message, "models"): - return True, await self.print_models() - elif self.command_start_with(message, "model"): - return True, await self.set_model(message) - return False, None - - async def get_models(self): - try: - models = await self.provider.client.models.list() - except NotFoundError as e: - bu = str(self.provider.client.base_url) - self.provider.client.base_url = bu + "/v1" - models = await self.provider.client.models.list() - finally: - return filter(lambda x: x.id.startswith("gpt"), models.data) - - async def print_models(self): - models = await self.get_models() - i = 1 - ret = "OpenAI GPT 类可用模型" - for model in models: - ret += f"\n{i}. {model.id}" - i += 1 - ret += "\nTips: 使用 /model 模型名/编号,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" - logger.debug(ret) - return True, ret, "models" - - - async def set_model(self, message: str): - l = message.split(" ") - if len(l) == 1: - return True, "请输入 /model 模型名/编号", "model" - model = str(l[1]) - if model.isdigit(): - models = await self.get_models() - models = list(models) - if int(model) <= len(models) and int(model) >= 1: - model = models[int(model)-1] - self.provider.set_model(model.id) - return True, f"模型已设置为 {model.id}", "model" - else: - self.provider.set_model(model) - return True, f"模型已设置为 {model} (自定义)", "model" - - - async def help(self): - commands = super().general_commands() - commands['画'] = '调用 OpenAI DallE 模型生成图片' - commands['/set'] = '人格设置面板' - commands['/status'] = '查看 Api Key 状态和配置信息' - commands['/token'] = '查看本轮会话 token' - commands['/reset'] = '重置当前与 LLM 的会话,但保留人格(system prompt)' - commands['/reset p'] = '重置当前与 LLM 的会话,并清除人格。' - commands['/models'] = '获取当前可用的模型' - commands['/model'] = '更换模型' - - return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" - - async def reset(self, session_id: str, message: str = "reset"): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "reset" - l = message.split(" ") - if len(l) == 1: - await self.provider.forget(session_id, keep_system_prompt=True) - return True, "重置成功", "reset" - if len(l) == 2 and l[1] == "p": - await self.provider.forget(session_id) - - def his(self, message: str, session_id: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "his" - size_per_page = 3 - page = 1 - l = message.split(" ") - if len(l) == 2: - try: - page = int(l[1]) - except BaseException as e: - return True, "页码不合法", "his" - contexts, total_num = self.provider.dump_contexts_page(session_id, size_per_page, page=page) - t_pages = total_num // size_per_page + 1 - return True, f"历史记录如下:\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n*输入 /his 2 跳转到第 2 页", "his" - - def status(self, session_id: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "status" - keys_data = self.provider.get_keys_data() - ret = "OpenAI Key" - for k in keys_data: - status = "🟢" if keys_data[k] else "🔴" - ret += "\n|- " + k[:8] + " " + status - - conf = self.provider.get_configs() - ret += "\n当前模型:" + conf['model'] - if conf['model'] in MODELS: - ret += "\n最大上下文窗口:" + str(MODELS[conf['model']]) + " tokens" - - if session_id in self.provider.session_memory and len(self.provider.session_memory[session_id]): - ret += "\n你的会话上下文:" + str(self.provider.session_memory[session_id][-1]['usage_tokens']) + " tokens" - - return True, ret, "status" - - async def switch(self, message: str): - ''' - 切换账号 - ''' - l = message.split(" ") - if len(l) == 1: - _, ret, _ = self.status() - curr_ = self.provider.get_curr_key() - if curr_ is None: - ret += "当前您未选择账号。输入/switch <账号序号>切换账号。" - else: - ret += f"当前您选择的账号为:{curr_[-8:]}。输入/switch <账号序号>切换账号。" - return True, ret, "switch" - elif len(l) == 2: - try: - key_stat = self.provider.get_keys_data() - index = int(l[1]) - if index > len(key_stat) or index < 1: - return True, "账号序号不合法。", "switch" - else: - try: - new_key = list(key_stat.keys())[index-1] - self.provider.set_key(new_key) - except BaseException as e: - return True, "账号切换失败,原因: " + str(e), "switch" - return True, f"账号切换成功。", "switch" - except BaseException as e: - return True, "未知错误: "+str(e), "switch" - else: - return True, "参数过多。", "switch" - - def unset(self, session_id: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "unset" - self.provider.curr_personality = {} - self.provider.forget(session_id) - return True, "已清除人格并重置历史记录。", "unset" - - def set(self, message: str, session_id: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "set" - l = message.split(" ") - if len(l) == 1: - return True, f"【人格文本由PlexPt开源项目awesome-chatgpt-pr \ - ompts-zh提供】\n设置人格: \n/set 人格名。例如/set 编剧\n人格列表: /set list\n人格详细信息: \ - /set view 人格名\n自定义人格: /set 人格文本\n重置会话(清除人格): /reset\n重置会话(保留人格): /reset p\n【当前人格】: {str(self.provider.curr_personality)}", "set" - elif l[1] == "list": - msg = "人格列表:\n" - for key in personalities.keys(): - msg += f" |-{key}\n" - msg += '\n\n*输入/set view 人格名查看人格详细信息' - msg += '\n*不定时更新人格库,请及时更新本项目。' - return True, msg, "set" - elif l[1] == "view": - if len(l) == 2: - return True, "请输入/set view 人格名", "set" - ps = l[2].strip() - if ps in personalities: - msg = f"人格{ps}的详细信息:\n" - msg += f"{personalities[ps]}\n" - else: - msg = f"人格{ps}不存在" - return True, msg, "set" - else: - ps = l[1].strip() - if ps in personalities: - self.provider.curr_personality = { - 'name': ps, - 'prompt': personalities[ps] - } - self.provider.personality_set(ps, session_id) - return True, f"人格{ps}已设置。", "set" - else: - self.provider.curr_personality = { - 'name': '自定义人格', - 'prompt': ps - } - self.provider.personality_set(ps, session_id) - return True, f"自定义人格已设置。 \n人格信息: {ps}", "set" - - async def draw(self, message: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "draw" - message = message.removeprefix("/").removeprefix("画") - img_url = await self.provider.image_generate(message) - p = await download_image_by_url(url=img_url) - with open(p, 'rb') as f: - return True, [Image.fromBytes(f.read())], "draw" \ No newline at end of file diff --git a/model/command/openai_official_handler.py b/model/command/openai_official_handler.py new file mode 100644 index 0000000..759149b --- /dev/null +++ b/model/command/openai_official_handler.py @@ -0,0 +1,152 @@ +from model.command.manager import CommandManager +from type.message_event import AstrMessageEvent +from type.command import CommandResult +from type.types import Context +from SparkleLogging.utils.core import LogManager +from logging import Logger +from nakuru.entities.components import Image +from model.provider.openai_official import ProviderOpenAIOfficial, MODELS +from util.personality import personalities +from util.io import download_image_by_url + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +class OpenAIOfficialCommandHandler(): + def __init__(self, manager: CommandManager) -> None: + self.manager = manager + + self.provider = None + + self.manager.register("reset", "重置会话", 10, self.reset) + self.manager.register("his", "查看历史记录", 10, self.his) + self.manager.register("status", "查看当前状态", 10, self.status) + self.manager.register("switch", "切换账号", 10, self.switch) + self.manager.register("unset", "清除个性化人格设置", 10, self.unset) + self.manager.register("set", "设置个性化人格", 10, self.set) + self.manager.register("draw", "调用 DallE 模型画图", 10, self.draw) + self.manager.register("画", "调用 DallE 模型画图", 10, self.draw) + + def set_provider(self, provider): + self.provider = provider + + async def reset(self, message: AstrMessageEvent, context: Context): + tokens = self.manager.command_parser.parse(message.message_str) + if tokens.len == 1: + await self.provider.forget(message.session_id, keep_system_prompt=True) + return CommandResult().message("重置成功") + elif tokens.get(1) == 'p': + await self.provider.forget(message.session_id) + + def his(self, message: AstrMessageEvent, context: Context): + tokens = self.manager.command_parser.parse(message.message_str) + size_per_page = 3 + page = 1 + if tokens.len == 2: + try: + page = int(tokens.get(1)) + except BaseException as e: + return CommandResult().message("页码格式错误") + contexts, total_num = self.provider.dump_contexts_page(message.session_id, size_per_page, page=page) + t_pages = total_num // size_per_page + 1 + return CommandResult().message(f"历史记录如下:\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n*输入 /his 2 跳转到第 2 页") + + def status(self, message: AstrMessageEvent, context: Context): + keys_data = self.provider.get_keys_data() + ret = "OpenAI Key" + for k in keys_data: + status = "🟢" if keys_data[k] else "🔴" + ret += "\n|- " + k[:8] + " " + status + + conf = self.provider.get_configs() + ret += "\n当前模型: " + conf['model'] + if conf['model'] in MODELS: + ret += "\n最大上下文窗口: " + str(MODELS[conf['model']]) + " tokens" + + if message.session_id in self.provider.session_memory and len(self.provider.session_memory[message.session_id]): + ret += "\n你的会话上下文: " + str(self.provider.session_memory[message.session_id][-1]['usage_tokens']) + " tokens" + + return CommandResult().message(ret) + + async def switch(self, message: AstrMessageEvent, context: Context): + ''' + 切换账号 + ''' + tokens = self.manager.command_parser.parse(message.message_str) + if tokens.len == 1: + _, ret, _ = self.status() + curr_ = self.provider.get_curr_key() + if curr_ is None: + ret += "当前您未选择账号。输入/switch <账号序号>切换账号。" + else: + ret += f"当前您选择的账号为:{curr_[-8:]}。输入/switch <账号序号>切换账号。" + return CommandResult().message(ret) + elif tokens.len == 2: + try: + key_stat = self.provider.get_keys_data() + index = int(tokens.get(1)) + if index > len(key_stat) or index < 1: + return CommandResult().message("账号序号错误。") + else: + try: + new_key = list(key_stat.keys())[index-1] + self.provider.set_key(new_key) + except BaseException as e: + return CommandResult().message("切换账号未知错误: "+str(e)) + return CommandResult().message("切换账号成功。") + except BaseException as e: + return CommandResult().message("切换账号错误。") + else: + return CommandResult().message("参数过多。") + + def unset(self, message: AstrMessageEvent, context: Context): + self.provider.curr_personality = {} + self.provider.forget(message.session_id) + return CommandResult().message("已清除个性化设置。") + + + def set(self, message: AstrMessageEvent, context: Context): + l = message.message_str.split(" ") + if len(l) == 1: + return CommandResult().message("【人格文本由PlexPt开源项目awesome-chatgpt-prompts-zh提供】\n设置人格: \n/set 人格名。例如/set 编剧\n人格列表: /set list\n人格详细信息: /set view 人格名\n自定义人格: /set 人格文本\n重置会话(清除人格): /reset\n重置会话(保留人格): /reset p\n【当前人格】: " + str(self.provider.curr_personality)) + elif l[1] == "list": + msg = "人格列表:\n" + for key in personalities.keys(): + msg += f" |-{key}\n" + msg += '\n\n*输入/set view 人格名查看人格详细信息' + return CommandResult().message(msg) + elif l[1] == "view": + if len(l) == 2: + return CommandResult().message("请输入人格名") + ps = l[2].strip() + if ps in personalities: + msg = f"人格{ps}的详细信息:\n" + msg += f"{personalities[ps]}\n" + else: + msg = f"人格{ps}不存在" + return CommandResult().message(msg) + else: + ps = l[1].strip() + if ps in personalities: + self.provider.curr_personality = { + 'name': ps, + 'prompt': personalities[ps] + } + self.provider.personality_set(ps, message.session_id) + return CommandResult().message(f"人格已设置。 \n人格信息: {ps}") + else: + self.provider.curr_personality = { + 'name': '自定义人格', + 'prompt': ps + } + self.provider.personality_set(ps, message.session_id) + return CommandResult().message(f"人格已设置。 \n人格信息: {ps}") + + async def draw(self, message: AstrMessageEvent, context: Context): + message = message.message_str.removeprefix("画") + img_url = await self.provider.image_generate(message) + p = await download_image_by_url(url=img_url) + with open(p, 'rb') as f: + return CommandResult( + message_chain=[Image.fromBytes(f.read())], + ) \ No newline at end of file diff --git a/model/command/parser.py b/model/command/parser.py new file mode 100644 index 0000000..db289a1 --- /dev/null +++ b/model/command/parser.py @@ -0,0 +1,19 @@ +class CommandTokens(): + def __init__(self) -> None: + self.tokens = [] + self.len = 0 + + def get(self, idx: int): + if idx >= self.len: + return None + return self.tokens[idx].strip() + +class CommandParser(): + def __init__(self): + pass + + def parse(self, message: str): + cmd_tokens = CommandTokens() + cmd_tokens.tokens = message.split(" ") + cmd_tokens.len = len(cmd_tokens.tokens) + return cmd_tokens \ No newline at end of file diff --git a/model/platform/__init__.py b/model/platform/__init__.py new file mode 100644 index 0000000..2df17e3 --- /dev/null +++ b/model/platform/__init__.py @@ -0,0 +1,70 @@ +import abc +from typing import Union, Any, List +from nakuru.entities.components import Plain, At, Image, BaseMessageComponent +from type.astrbot_message import AstrBotMessage + + +class Platform(): + def __init__(self) -> None: + pass + + @abc.abstractmethod + async def handle_msg(self, message: AstrBotMessage): + ''' + 处理到来的消息 + ''' + pass + + @abc.abstractmethod + async def reply_msg(self, message: AstrBotMessage, + result_message: List[BaseMessageComponent]): + ''' + 回复用户唤醒机器人的消息。(被动回复) + ''' + pass + + @abc.abstractmethod + async def send_msg(self, target: Any, result_message: Union[List[BaseMessageComponent], str]): + ''' + 发送消息(主动) + ''' + pass + + def parse_message_outline(self, message: AstrBotMessage) -> str: + ''' + 将消息解析成大纲消息形式,如: xxxxx[图片]xxxxx。用于输出日志等。 + ''' + if isinstance(message, str): + return message + ret = '' + parsed = message if isinstance(message, list) else message.message + try: + for node in parsed: + if isinstance(node, Plain): + ret += node.text.replace('\n', ' ') + elif isinstance(node, At): + ret += f'[At: {node.name}/{node.qq}]' + elif isinstance(node, Image): + ret += '[图片]' + except Exception as e: + pass + return ret[:100] if len(ret) > 100 else ret + + def check_nick(self, message_str: str) -> bool: + if self.context.nick: + for nick in self.context.nick: + if nick and message_str.strip().startswith(nick): + return True + return False + + async def convert_to_t2i_chain(self, message_result: list) -> list: + plain_str = "" + rendered_images = [] + for i in message_result: + if isinstance(i, Plain): + plain_str += i.text + if plain_str and len(plain_str) > 50: + p = await self.context.image_renderer.render(plain_str) + rendered_images.append(Image.fromFileSystem(p)) + return rendered_images + \ No newline at end of file diff --git a/model/platform/_message_parse.py b/model/platform/_message_parse.py deleted file mode 100644 index bf2e2ee..0000000 --- a/model/platform/_message_parse.py +++ /dev/null @@ -1,114 +0,0 @@ -from nakuru.entities.components import Plain, At, Image, BaseMessageComponent -from nakuru import ( - GuildMessage, - GroupMessage, - FriendMessage -) -import botpy.message -from type.message import * -from typing import List, Union -from util.general_utils import save_temp_img -import time, base64 - -# QQ官方消息类型转换 - - -def qq_official_message_parse(message: List[BaseMessageComponent]): - plain_text = "" - image_path = None # only one img supported - for i in message: - if isinstance(i, Plain): - plain_text += i.text - elif isinstance(i, Image) and not image_path: - if i.path: - image_path = i.path - elif i.file and i.file.startswith("base64://"): - img_data = base64.b64decode(i.file[9:]) - image_path = save_temp_img(img_data) - else: - image_path = save_temp_img(i.file) - return plain_text, image_path - -# QQ官方消息类型 2 AstrBotMessage - - -def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.message.GroupMessage], - message_type: MessageType) -> AstrBotMessage: - abm = AstrBotMessage() - abm.type = message_type - abm.timestamp = int(time.time()) - abm.raw_message = message - abm.message_id = message.id - abm.tag = "qqchan" - msg: List[BaseMessageComponent] = [] - - if message_type == MessageType.GROUP_MESSAGE: - abm.sender = MessageMember( - message.author.member_openid, - "" - ) - abm.message_str = message.content.strip() - abm.self_id = "unknown_selfid" - - msg.append(Plain(abm.message_str)) - if message.attachments: - for i in message.attachments: - if i.content_type.startswith("image"): - url = i.url - if not url.startswith("http"): - url = "https://"+url - img = Image.fromURL(url) - msg.append(img) - abm.message = msg - - elif message_type == MessageType.GUILD_MESSAGE or message_type == MessageType.FRIEND_MESSAGE: - # 目前对于 FRIEND_MESSAGE 只处理频道私聊 - try: - abm.self_id = str(message.mentions[0].id) - except: - abm.self_id = "" - - plain_content = message.content.replace( - "<@!"+str(abm.self_id)+">", "").strip() - msg.append(Plain(plain_content)) - if message.attachments: - for i in message.attachments: - if i.content_type.startswith("image"): - url = i.url - if not url.startswith("http"): - url = "https://"+url - img = Image.fromURL(url) - msg.append(img) - abm.message = msg - abm.message_str = plain_content - abm.sender = MessageMember( - str(message.author.id), - str(message.author.username) - ) - else: - raise ValueError(f"Unknown message type: {message_type}") - return abm - - -def nakuru_message_parse_rev(message: Union[GuildMessage, GroupMessage, FriendMessage]) -> AstrBotMessage: - abm = AstrBotMessage() - abm.type = MessageType(message.type) - abm.timestamp = int(time.time()) - abm.raw_message = message - abm.message_id = message.message_id - - plain_content = "" - for i in message.message: - if isinstance(i, Plain): - plain_content += i.text - abm.message_str = plain_content - - abm.self_id = str(message.self_id) - abm.sender = MessageMember( - str(message.sender.user_id), - str(message.sender.nickname) - ) - abm.tag = "gocq" - abm.message = message.message - - return abm diff --git a/model/platform/_message_result.py b/model/platform/_message_result.py deleted file mode 100644 index c4421e1..0000000 --- a/model/platform/_message_result.py +++ /dev/null @@ -1,9 +0,0 @@ -from dataclasses import dataclass -from typing import Union, Optional - - -@dataclass -class MessageResult(): - result_message: Union[str, list] - is_command_call: Optional[bool] = False - callback: Optional[callable] = None diff --git a/model/platform/_platfrom.py b/model/platform/_platfrom.py deleted file mode 100644 index 3efbb98..0000000 --- a/model/platform/_platfrom.py +++ /dev/null @@ -1,73 +0,0 @@ -import abc -from typing import Union -from nakuru import ( - GuildMessage, - GroupMessage, - FriendMessage, -) -from nakuru.entities.components import Plain, At, Image - - -class Platform(): - def __init__(self, message_handler: callable) -> None: - ''' - 初始化平台的各种接口 - ''' - self.message_handler = message_handler - self.cnt_receive = 0 - self.cnt_reply = 0 - pass - - @abc.abstractmethod - async def handle_msg(self): - ''' - 处理到来的消息 - ''' - self.cnt_receive += 1 - pass - - @abc.abstractmethod - async def reply_msg(self): - ''' - 回复消息(被动发送) - ''' - self.cnt_reply += 1 - pass - - @abc.abstractmethod - async def send_msg(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): - ''' - 发送消息(主动发送) - ''' - self.cnt_reply += 1 - pass - - @abc.abstractmethod - async def send(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): - ''' - 发送消息(主动发送)同 send_msg() - ''' - self.cnt_reply += 1 - pass - - def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str: - ''' - 将消息解析成大纲消息形式。 - 如: xxxxx[图片]xxxxx - ''' - if isinstance(message, str): - return message - ret = '' - ls_to_parse = message if isinstance(message, list) else message.message - try: - for node in ls_to_parse: - if isinstance(node, Plain): - ret += node.text - elif isinstance(node, At): - ret += f'[At: {node.name}/{node.qq}]' - elif isinstance(node, Image): - ret += '[图片]' - except Exception as e: - pass - ret.replace('\n', '') - return ret diff --git a/model/platform/manager.py b/model/platform/manager.py new file mode 100644 index 0000000..bec7df7 --- /dev/null +++ b/model/platform/manager.py @@ -0,0 +1,87 @@ +import time, threading + +from util.io import port_checker +from type.register import RegisteredPlatform +from type.types import Context +from SparkleLogging.utils.core import LogManager +from logging import Logger +from astrbot.message.handler import MessageHandler + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +class PlatformManager(): + def __init__(self, context: Context, message_handler: MessageHandler) -> None: + self.context = context + self.config = context.base_config + self.msg_handler = message_handler + + def load_platforms(self): + tasks = [] + + if 'gocqbot' in self.config and self.config['gocqbot']['enable']: + logger.info("启用 QQ(nakuru 适配器)") + tasks.append(self.gocq_bot()) + + if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']: + logger.info("启用 QQ(aiocqhttp 适配器)") + tasks.append(self.aiocq_bot()) + + # QQ频道 + if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None: + logger.info("启用 QQ(官方 API) 机器人消息平台") + tasks.append(self.qqchan_bot()) + + return tasks + + def gocq_bot(self): + ''' + 运行 QQ(nakuru 适配器) + ''' + from model.platform.qq_nakuru import QQGOCQ + noticed = False + host = self.config.get("gocq_host", "127.0.0.1") + port = self.config.get("gocq_websocket_port", 6700) + http_port = self.config.get("gocq_http_port", 5700) + logger.info( + f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}") + while True: + if not port_checker(port=port, host=host) or not port_checker(port=http_port, host=host): + if not noticed: + noticed = True + logger.warning( + f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。") + time.sleep(5) + else: + logger.info("nakuru 适配器已连接。") + break + try: + qq_gocq = QQGOCQ(self.context, self.msg_handler) + self.context.platforms.append(RegisteredPlatform( + platform_name="gocq", platform_instance=qq_gocq, origin="internal")) + return qq_gocq.run() + except BaseException as e: + logger.error("启动 nakuru 适配器时出现错误: " + str(e)) + + def aiocq_bot(self): + ''' + 运行 QQ(aiocqhttp 适配器) + ''' + from model.platform.qq_aiocqhttp import AIOCQHTTP + qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler) + self.context.platforms.append(RegisteredPlatform( + platform_name="aiocqhttp", platform_instance=qq_aiocqhttp, origin="internal")) + return qq_aiocqhttp.run_aiocqhttp() + + def qqchan_bot(self): + ''' + 运行 QQ 官方机器人适配器 + ''' + try: + from model.platform.qq_official import QQOfficial + qqchannel_bot = QQOfficial(self.context, self.msg_handler) + self.context.platforms.append(RegisteredPlatform( + platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal")) + return qqchannel_bot.run() + except BaseException as e: + logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e)) diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py new file mode 100644 index 0000000..777f4cc --- /dev/null +++ b/model/platform/qq_aiocqhttp.py @@ -0,0 +1,146 @@ +import time +import traceback +import logging +from aiocqhttp import CQHttp, Event +from . import Platform +from type.astrbot_message import * +from type.message_event import * +from typing import Union, List, Dict +from nakuru.entities.components import * +from SparkleLogging.utils.core import LogManager +from logging import Logger +from astrbot.message.handler import MessageHandler + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + +class AIOCQHTTP(Platform): + def __init__(self, context: Context, message_handler: MessageHandler) -> None: + self.message_handler = message_handler + self.waiting = {} + self.context = context + self.unique_session = self.context.unique_session + self.announcement = self.context.base_config.get("announcement", "欢迎新人!") + self.host = self.context.base_config['aiocqhttp']['ws_reverse_host'] + self.port = self.context.base_config['aiocqhttp']['ws_reverse_port'] + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + def compat_onebot2astrbotmsg(self, event: Event) -> AstrBotMessage: + + abm = AstrBotMessage() + abm.self_id = str(event.self_id) + abm.tag = "aiocqhttp" + + abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname']) + + if event['message_type'] == 'group': + abm.type = MessageType.GROUP_MESSAGE + elif event['message_type'] == 'private': + abm.type = MessageType.FRIEND_MESSAGE + + if self.unique_session: + abm.session_id = abm.sender.user_id + else: + abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id + + abm.message_id = str(event.message_id) + abm.message = [] + + message_str = "" + for m in event.message: + t = m['type'] + a = None + if t == 'at': + a = At(**m['data']) + abm.message.append(a) + if t == 'text': + a = Plain(text=m['data']['text']) + message_str += m['data']['text'].strip() + abm.message.append(a) + if t == 'image': + a = Image(file=m['data']['file']) + abm.message.append(a) + abm.timestamp = int(time.time()) + abm.message_str = message_str + abm.raw_message = event + return abm + + def run_aiocqhttp(self): + if not self.host or not self.port: + return + self.bot = CQHttp(use_ws_reverse=True) + + @self.bot.on_message('group') + async def group(event: Event): + abm = self.compat_onebot2astrbotmsg(event) + if abm: + await self.handle_msg(event, abm) + return {'reply': event.message} + + @self.bot.on_message('private') + async def private(event: Event): + abm = self.compat_onebot2astrbotmsg(event) + if abm: + await self.handle_msg(event, abm) + return {'reply': event.message} + + return self.bot.run_task(host=self.host, port=int(self.port)) + + async def handle_msg(self, message: AstrBotMessage): + logger.info( + f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") + + # 解析 role + sender_id = str(message.sender.user_id) + if sender_id == self.context.config_helper.get('admin_qq', '') or \ + sender_id in self.context.config_helper.get('other_admins', []): + role = 'admin' + else: + role = 'member' + + # construct astrbot message event + ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", message.session_id, role) + + # transfer control to message handler + message_result = await self.message_handler.handle(ame) + if not message_result: return + + await self.reply_msg(message, message_result.result_message) + if message_result.callback: + message_result.callback() + + # 如果是等待回复的消息 + if message.session_id in self.waiting and self.waiting[message.session_id] == '': + self.waiting[message.session_id] = message + + + async def reply_msg(self, + message: AstrBotMessage, + result_message: list): + await super().reply_msg() + """ + 回复用户唤醒机器人的消息。(被动回复) + """ + logger.info( + f"{message.sender.user_id} <- {self.parse_message_outline(message)}") + + if isinstance(res, str): + res = [Plain(text=res), ] + + # if image mode, put all Plain texts into a new picture. + if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list): + rendered_images = await self.convert_to_t2i_chain(res) + if rendered_images: + try: + await self._reply(message, rendered_images) + return + except BaseException as e: + logger.warn(traceback.format_exc()) + logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") + + async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageComponent]): + if isinstance(message_chain, str): + message_chain = [Plain(text=message_chain), ] + + await self.bot.send(message.raw_message, message_chain) \ No newline at end of file diff --git a/model/platform/qq_gocq.py b/model/platform/qq_gocq.py deleted file mode 100644 index 593225e..0000000 --- a/model/platform/qq_gocq.py +++ /dev/null @@ -1,314 +0,0 @@ -from nakuru.entities.components import Plain, At, Image, Node -from util import general_utils as gu -from util.image_render.helper import text_to_image_base -from util.cmd_config import CmdConfig -import asyncio -from nakuru import ( - CQHTTP, - GuildMessage, - GroupMessage, - FriendMessage, - GroupMemberIncrease, - Notify -) -from typing import Union -from type.types import GlobalObject -import time - -from ._platfrom import Platform -from ._message_parse import nakuru_message_parse_rev -from type.message import * -from SparkleLogging.utils.core import LogManager -from logging import Logger - -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - - -class FakeSource: - def __init__(self, type, group_id): - self.type = type - self.group_id = group_id - - -class QQGOCQ(Platform): - def __init__(self, cfg: dict, message_handler: callable, global_object: GlobalObject) -> None: - super().__init__(message_handler) - - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - self.waiting = {} - self.cc = CmdConfig() - self.cfg = cfg - - self.context = global_object - - self.unique_session = cfg['uniqueSessionMode'] - - self.client = CQHTTP( - host=self.cc.get("gocq_host", "127.0.0.1"), - port=self.cc.get("gocq_websocket_port", 6700), - http_port=self.cc.get("gocq_http_port", 5700), - ) - gocq_app = self.client - - self.announcement = self.cc.get("announcement", "欢迎新人!") - - @gocq_app.receiver("GroupMessage") - async def _(app: CQHTTP, source: GroupMessage): - if self.cc.get("gocq_react_group", True): - abm = nakuru_message_parse_rev(source) - if isinstance(source.message[0], Plain): - await self.handle_msg(abm) - elif isinstance(source.message[0], At): - if source.message[0].qq == source.self_id: - await self.handle_msg(abm) - else: - return - - @gocq_app.receiver("FriendMessage") - async def _(app: CQHTTP, source: FriendMessage): - if self.cc.get("gocq_react_friend", True): - abm = nakuru_message_parse_rev(source) - if isinstance(source.message[0], Plain): - await self.handle_msg(abm) - else: - return - - @gocq_app.receiver("GroupMemberIncrease") - async def _(app: CQHTTP, source: GroupMemberIncrease): - if self.cc.get("gocq_react_group_increase", True): - await app.sendGroupMessage(source.group_id, [ - Plain(text=self.announcement) - ]) - - # @gocq_app.receiver("Notify") - # async def _(app: CQHTTP, source: Notify): - # print(source) - # if source.sub_type == "poke" and source.target_id == source.self_id: - # await self.handle_msg(source) - - @gocq_app.receiver("GuildMessage") - async def _(app: CQHTTP, source: GuildMessage): - if self.cc.get("gocq_react_guild", True): - abm = nakuru_message_parse_rev(source) - if isinstance(source.message[0], Plain): - await self.handle_msg(abm) - elif isinstance(source.message[0], At): - if source.message[0].qq == source.self_tiny_id: - await self.handle_msg(abm) - else: - return - - def run(self): - self.client.run() - - async def handle_msg(self, message: AstrBotMessage): - await super().handle_msg() - logger.info( - f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") - - assert isinstance(message.raw_message, - (GroupMessage, FriendMessage, GuildMessage)) - is_group = message.type != MessageType.FRIEND_MESSAGE - - # 判断是否响应消息 - resp = False - if not is_group: - resp = True - else: - for i in message.message: - if isinstance(i, At): - if message.type.value == "GuildMessage": - if str(i.qq) == str(message.raw_message.user_id) or str(i.qq) == str(message.raw_message.self_tiny_id): - resp = True - if message.type.value == "FriendMessage": - if str(i.qq) == str(message.self_id): - resp = True - if message.type.value == "GroupMessage": - if str(i.qq) == str(message.self_id): - resp = True - elif isinstance(i, Plain) and self.context.nick: - for nick in self.context.nick: - if nick != '' and i.text.strip().startswith(nick): - resp = True - break - - if not resp: - return - - # 解析 session_id - if self.unique_session or not is_group: - session_id = message.raw_message.user_id - elif message.type == MessageType.GROUP_MESSAGE: - session_id = message.raw_message.group_id - elif message.type == MessageType.GUILD_MESSAGE: - session_id = message.raw_message.channel_id - else: - session_id = message.raw_message.user_id - - message.session_id = session_id - - # 解析 role - sender_id = str(message.raw_message.user_id) - if sender_id == self.cc.get('admin_qq', '') or \ - sender_id in self.cc.get('other_admins', []): - role = 'admin' - else: - role = 'member' - - message_result = await self.message_handler( - message=message, - session_id=session_id, - role=role, - platform='gocq' - ) - - if message_result is None: - return - await self.reply_msg(message, message_result.result_message) - if message_result.callback is not None: - message_result.callback() - - # 如果是等待回复的消息 - if session_id in self.waiting and self.waiting[session_id] == '': - self.waiting[session_id] = message - - async def reply_msg(self, - message: Union[AstrBotMessage, GuildMessage, GroupMessage, FriendMessage], - result_message: list): - await super().reply_msg() - """ - 插件开发者请使用send方法, 可以不用直接调用这个方法。 - """ - if isinstance(message, AstrBotMessage): - source = message.raw_message - else: - source = message - - res = result_message - - logger.info( - f"{source.user_id} <- {self.parse_message_outline(res)}") - - if isinstance(source, int): - source = FakeSource("GroupMessage", source) - - # str convert to CQ Message Chain - if isinstance(res, str): - res_str = res - res = [] - if source.type == "GroupMessage" and not isinstance(source, FakeSource): - res.append(At(qq=source.user_id)) - res.append(Plain(text=res_str)) - - # if image mode, put all Plain texts into a new picture. - if self.cc.get("qq_pic_mode", False) and isinstance(res, list): - plains = [] - news = [] - for i in res: - if isinstance(i, Plain): - plains.append(i.text) - else: - news.append(i) - plains_str = "".join(plains).strip() - if plains_str != "" and len(plains_str) > 50: - # p = gu.create_markdown_image("".join(plains)) - p = await text_to_image_base(plains_str) - news.append(Image.fromFileSystem(p)) - res = news - - # 回复消息链 - if isinstance(res, list) and len(res) > 0: - if source.type == "GuildMessage": - await self.client.sendGuildChannelMessage(source.guild_id, source.channel_id, res) - return - elif source.type == "FriendMessage": - await self.client.sendFriendMessage(source.user_id, res) - return - elif source.type == "GroupMessage": - # 过长时forward发送 - plain_text_len = 0 - image_num = 0 - for i in res: - if isinstance(i, Plain): - plain_text_len += len(i.text) - elif isinstance(i, Image): - image_num += 1 - if plain_text_len > self.cc.get('qq_forward_threshold', 200): - # 删除At - for i in res: - if isinstance(i, At): - res.remove(i) - node = Node(res) - # node.content = res - node.uin = 123456 - node.name = f"bot" - node.time = int(time.time()) - # print(node) - nodes = [node] - await self.client.sendGroupForwardMessage(source.group_id, nodes) - return - await self.client.sendGroupMessage(source.group_id, res) - return - - async def send_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, AstrBotMessage], result_message: list): - ''' - 提供给插件的发送QQ消息接口。 - 参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。 - ''' - await super().reply_msg() - try: - await self.reply_msg(message, result_message) - except BaseException as e: - raise e - - async def send(self, - to, - res): - ''' - 同 send_msg() - ''' - await super().reply_msg() - await self.reply_msg(to, res) - - async def create_text_image(text: str): - ''' - 文本转图片。 - text: 文本内容 - 返回:文件路径 - ''' - try: - return await text_to_image_base(text) - except Exception as e: - raise e - - def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]: - ''' - 等待下一条消息,超时 300s 后抛出异常 - ''' - self.waiting[group_id] = '' - cnt = 0 - while True: - if group_id in self.waiting and self.waiting[group_id] != '': - # 去掉 - ret = self.waiting[group_id] - del self.waiting[group_id] - return ret - cnt += 1 - if cnt > 300: - raise Exception("等待消息超时。") - time.sleep(1) - - def get_client(self): - return self.client - - async def nakuru_method_invoker(self, func, *args, **kwargs): - """ - 返回一个方法调用器,可以用来立即调用nakuru的方法。 - """ - try: - ret = func(*args, **kwargs) - return ret - except BaseException as e: - raise e diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py new file mode 100644 index 0000000..a9cc625 --- /dev/null +++ b/model/platform/qq_nakuru.py @@ -0,0 +1,255 @@ +import time, asyncio, traceback + +from nakuru.entities.components import Plain, At, Image, Node, BaseMessageComponent +from nakuru import ( + CQHTTP, + GuildMessage, + GroupMessage, + FriendMessage, + GroupMemberIncrease, + MessageItemType +) +from typing import Union, List, Dict +from type.types import Context +from . import Platform +from type.astrbot_message import * +from type.message_event import * +from SparkleLogging.utils.core import LogManager +from logging import Logger +from astrbot.message.handler import MessageHandler + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +class FakeSource: + def __init__(self, type, group_id): + self.type = type + self.group_id = group_id + + +class QQGOCQ(Platform): + def __init__(self, context: Context, message_handler: MessageHandler) -> None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.message_handler = message_handler + self.waiting = {} + self.context = context + self.unique_session = self.context.unique_session + self.announcement = self.context.base_config.get("announcement", "欢迎新人!") + + self.client = CQHTTP( + host=self.context.base_config.get("gocq_host", "127.0.0.1"), + port=self.context.base_config.get("gocq_websocket_port", 6700), + http_port=self.context.base_config.get("gocq_http_port", 5700), + ) + gocq_app = self.client + + @gocq_app.receiver("GroupMessage") + async def _(app: CQHTTP, source: GroupMessage): + if self.context.base_config.get("gocq_react_group", True): + abm = self.convert_message(source) + await self.handle_msg(abm) + + @gocq_app.receiver("FriendMessage") + async def _(app: CQHTTP, source: FriendMessage): + if self.context.base_config.get("gocq_react_friend", True): + abm = self.convert_message(source) + await self.handle_msg(abm) + + @gocq_app.receiver("GroupMemberIncrease") + async def _(app: CQHTTP, source: GroupMemberIncrease): + if self.context.base_config.get("gocq_react_group_increase", True): + await app.sendGroupMessage(source.group_id, [ + Plain(text=self.announcement) + ]) + + @gocq_app.receiver("GuildMessage") + async def _(app: CQHTTP, source: GuildMessage): + if self.cc.get("gocq_react_guild", True): + abm = self.convert_message(source) + await self.handle_msg(abm) + + def pre_check(self, message: AstrBotMessage) -> bool: + # if message chain contains Plain components or At components which points to self_id, return True + if message.type == MessageType.FRIEND_MESSAGE: + return True + for comp in message.message: + if isinstance(comp, Plain): + return True + if isinstance(comp, At) and str(comp.qq) == message.self_id: + return True + # check nicks + if self.check_nick(message.message_str): + return True + return False + + def run(self): + return self.client._run() + + async def handle_msg(self, message: AstrBotMessage): + logger.info( + f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") + + assert isinstance(message.raw_message, + (GroupMessage, FriendMessage, GuildMessage)) + + # 判断是否响应消息 + if not self.pre_check(message): + return + + # 解析 session_id + if self.unique_session or message.type == MessageType.FRIEND_MESSAGE: + session_id = message.raw_message.user_id + elif message.type == MessageType.GROUP_MESSAGE: + session_id = message.raw_message.group_id + elif message.type == MessageType.GUILD_MESSAGE: + session_id = message.raw_message.channel_id + else: + session_id = message.raw_message.user_id + + message.session_id = session_id + + # 解析 role + sender_id = str(message.raw_message.user_id) + if sender_id == self.context.config_helper.get('admin_qq', '') or \ + sender_id in self.context.config_helper.get('other_admins', []): + role = 'admin' + else: + role = 'member' + + # construct astrbot message event + ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", session_id, role) + + # transfer control to message handler + message_result = await self.message_handler.handle(ame) + if not message_result: return + + await self.reply_msg(message, message_result.result_message) + if message_result.callback: + message_result.callback() + + # 如果是等待回复的消息 + if session_id in self.waiting and self.waiting[session_id] == '': + self.waiting[session_id] = message + + async def reply_msg(self, + message: AstrBotMessage, + result_message: List[BaseMessageComponent]): + """ + 回复用户唤醒机器人的消息。(被动回复) + """ + source = message.raw_message + res = result_message + + assert isinstance(source, + (GroupMessage, FriendMessage, GuildMessage)) + + logger.info( + f"{source.user_id} <- {self.parse_message_outline(res)}") + + if isinstance(res, str): + res = [Plain(text=res), ] + + # if image mode, put all Plain texts into a new picture. + if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list): + rendered_images = await self.convert_to_t2i_chain(res) + if rendered_images: + try: + await self._reply(source, rendered_images) + return + except BaseException as e: + logger.warn(traceback.format_exc()) + logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") + + await self._reply(source, res) + + async def _reply(self, source, message_chain: List[BaseMessageComponent]): + if isinstance(message_chain, str): + message_chain = [Plain(text=message_chain), ] + + is_dict = isinstance(source, dict) + if source.type == "GuildMessage": + guild_id = source['guild_id'] if is_dict else source.guild_id + chan_id = source['channel_id'] if is_dict else source.channel_id + await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain) + elif source.type == "FriendMessage": + user_id = source['user_id'] if is_dict else source.user_id + await self.client.sendFriendMessage(user_id, message_chain) + elif source.type == "GroupMessage": + group_id = source['group_id'] if is_dict else source.group_id + # 过长时forward发送 + plain_text_len = 0 + image_num = 0 + for i in message_chain: + if isinstance(i, Plain): + plain_text_len += len(i.text) + elif isinstance(i, Image): + image_num += 1 + if plain_text_len > self.context.config_helper.get('qq_forward_threshold', 200): + # 删除At + for i in message_chain: + if isinstance(i, At): + message_chain.remove(i) + node = Node(message_chain) + node.uin = 123456 + node.name = f"bot" + node.time = int(time.time()) + nodes = [node] + await self.client.sendGroupForwardMessage(group_id, nodes) + return + await self.client.sendGroupMessage(group_id, message_chain) + + async def send_msg(self, target: Dict[str, int], result_message: Union[List[BaseMessageComponent], str]): + ''' + 以主动的方式给用户、群或者频道发送一条消息。 + + `target` 接收一个 dict 类型的值引用。 + + - 要发给 QQ 下的某个用户,请添加 key `user_id`,值为 int 类型的 qq 号; + - 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号; + - 要发给某个频道,请添加 key `guild_id`, `channel_id`。均为 int 类型。 + + guild_id 不是频道号。 + ''' + await self._reply(target, result_message) + + def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage: + abm = AstrBotMessage() + abm.type = MessageType(message.type) + abm.raw_message = message + abm.message_id = message.message_id + + plain_content = "" + for i in message.message: + if isinstance(i, Plain): + plain_content += i.text + abm.message_str = plain_content.strip() + if message.type == MessageItemType.GuildMessage: + abm.self_id = str(message.self_tiny_id) + else: + abm.self_id = str(message.self_id) + abm.sender = MessageMember( + str(message.sender.user_id), + str(message.sender.nickname) + ) + abm.tag = "gocq" + abm.message = message.message + return abm + + def wait_for_message(self, group_id) -> Union[GroupMessage, FriendMessage, GuildMessage]: + ''' + 等待下一条消息,超时 300s 后抛出异常 + ''' + self.waiting[group_id] = '' + cnt = 0 + while True: + if group_id in self.waiting and self.waiting[group_id] != '': + # 去掉 + ret = self.waiting[group_id] + del self.waiting[group_id] + return ret + cnt += 1 + if cnt > 300: + raise Exception("等待消息超时。") + time.sleep(1) \ No newline at end of file diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index aa3594e..060ea5d 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -1,83 +1,77 @@ -import io import botpy -from PIL import Image as PILImage -import botpy.message import re +import time +import traceback import asyncio -import aiohttp +import botpy.message import botpy.types import botpy.types.message -from util import general_utils as gu -from botpy.types.message import Reference +from botpy.types.message import Reference, Media from botpy import Client -import time -from ._platfrom import Platform -from ._message_parse import ( - qq_official_message_parse_rev, - qq_official_message_parse -) -from type.message import * -from typing import Union, List +from util.io import save_temp_img +from . import Platform +from type.astrbot_message import * +from type.message_event import * +from typing import Union, List, Dict from nakuru.entities.components import * -from util.image_render.helper import text_to_image_base -from util.cmd_config import CmdConfig from SparkleLogging.utils.core import LogManager from logging import Logger +from astrbot.message.handler import MessageHandler -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') +logger: Logger = LogManager.GetLogger(log_name='astrbot') # QQ 机器人官方框架 class botClient(Client): def set_platform(self, platform: 'QQOfficial'): self.platform = platform - + + # 收到群消息 async def on_group_at_message_create(self, message: botpy.message.GroupMessage): - abm = qq_official_message_parse_rev(message, MessageType.GROUP_MESSAGE) + abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE) await self.platform.handle_msg(abm) # 收到频道消息 async def on_at_message_create(self, message: botpy.message.Message): # 转换层 - abm = qq_official_message_parse_rev(message, MessageType.GUILD_MESSAGE) + abm = self.platform._parse_from_qqofficial(message, MessageType.GUILD_MESSAGE) await self.platform.handle_msg(abm) # 收到私聊消息 async def on_direct_message_create(self, message: botpy.message.DirectMessage): # 转换层 - abm = qq_official_message_parse_rev( - message, MessageType.FRIEND_MESSAGE) + abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) await self.platform.handle_msg(abm) class QQOfficial(Platform): - def __init__(self, cfg: dict, message_handler: callable, global_object) -> None: - super().__init__(message_handler) - + def __init__(self, context: Context, message_handler: MessageHandler) -> None: + super().__init__() self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - + + self.message_handler = message_handler self.waiting: dict = {} + self.context = context + + self.appid = context.base_config['qqbot']['appid'] + self.token = context.base_config['qqbot']['token'] + self.secret = context.base_config['qqbot_secret'] + self.unique_session = context.unique_session + qq_group = context.base_config['qqofficial_enable_group_message'] - self.cc = CmdConfig() - self.cfg = cfg - self.appid = cfg['qqbot']['appid'] - self.token = cfg['qqbot']['token'] - self.secret = cfg['qqbot_secret'] - self.unique_session = cfg['uniqueSessionMode'] - qq_group = cfg['qqofficial_enable_group_message'] if qq_group: self.intents = botpy.Intents( public_messages=True, public_guild_messages=True, - direct_message=cfg['direct_message_mode'] + direct_message=context.base_config['direct_message_mode'] ) else: self.intents = botpy.Intents( public_guild_messages=True, - direct_message=cfg['direct_message_mode'] + direct_message=context.base_config['direct_message_mode'] ) self.client = botClient( intents=self.intents, @@ -85,27 +79,100 @@ def __init__(self, cfg: dict, message_handler: callable, global_object) -> None: ) self.client.set_platform(self) + + def _parse_to_qqofficial(self, message: List[BaseMessageComponent]): + plain_text = "" + image_path = None # only one img supported + for i in message: + if isinstance(i, Plain): + plain_text += i.text + elif isinstance(i, Image) and not image_path: + if i.path: + image_path = i.path + elif i.file and i.file.startswith("base64://"): + img_data = base64.b64decode(i.file[9:]) + image_path = save_temp_img(img_data) + else: + image_path = save_temp_img(i.file) + return plain_text, image_path + + def _parse_from_qqofficial(self, message: Union[botpy.message.Message, botpy.message.GroupMessage], + message_type: MessageType): + abm = AstrBotMessage() + abm.type = message_type + abm.timestamp = int(time.time()) + abm.raw_message = message + abm.message_id = message.id + abm.tag = "qqchan" + msg: List[BaseMessageComponent] = [] + + if message_type == MessageType.GROUP_MESSAGE: + abm.sender = MessageMember( + message.author.member_openid, + "" + ) + abm.message_str = message.content.strip() + abm.self_id = "unknown_selfid" + + msg.append(Plain(abm.message_str)) + if message.attachments: + for i in message.attachments: + if i.content_type.startswith("image"): + url = i.url + if not url.startswith("http"): + url = "https://"+url + img = Image.fromURL(url) + msg.append(img) + abm.message = msg + + elif message_type == MessageType.GUILD_MESSAGE or message_type == MessageType.FRIEND_MESSAGE: + # 目前对于 FRIEND_MESSAGE 只处理频道私聊 + try: + abm.self_id = str(message.mentions[0].id) + except: + abm.self_id = "" + + plain_content = message.content.replace( + "<@!"+str(abm.self_id)+">", "").strip() + msg.append(Plain(plain_content)) + if message.attachments: + for i in message.attachments: + if i.content_type.startswith("image"): + url = i.url + if not url.startswith("http"): + url = "https://"+url + img = Image.fromURL(url) + msg.append(img) + abm.message = msg + abm.message_str = plain_content + abm.sender = MessageMember( + str(message.author.id), + str(message.author.username) + ) + else: + raise ValueError(f"Unknown message type: {message_type}") + return abm def run(self): try: - self.loop.run_until_complete(self.client.run( + return self.client.start( appid=self.appid, secret=self.secret - )) + ) except BaseException as e: - print(e) + # 早期的 qq-botpy 版本使用 token 登录。 + logger.error(traceback.format_exc()) self.client = botClient( intents=self.intents, bot_log=False ) self.client.set_platform(self) - self.client.run( + return self.client.start( appid=self.appid, token=self.token ) async def handle_msg(self, message: AstrBotMessage): - await super().handle_msg() assert isinstance(message.raw_message, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage)) is_group = message.type != MessageType.FRIEND_MESSAGE @@ -128,24 +195,21 @@ async def handle_msg(self, message: AstrBotMessage): # 解析出 role sender_id = message.sender.user_id - if sender_id == self.cfg['admin_qqchan'] or \ - sender_id in self.cfg['other_admins']: + if sender_id == self.context.config_helper.get('admin_qqchan', None) or \ + sender_id in self.context.config_helper.get('other_admins', None): role = 'admin' else: role = 'member' - - message_result = await self.message_handler( - message=message, - session_id=session_id, - role=role, - platform='qqchan' - ) - - if message_result is None: + + # construct astrbot message event + ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqchan", session_id, role) + + message_result = await self.message_handler.handle(ame) + if not message_result: return await self.reply_msg(message, message_result.result_message) - if message_result.callback is not None: + if message_result.callback: message_result.callback() # 如果是等待回复的消息 @@ -153,73 +217,29 @@ async def handle_msg(self, message: AstrBotMessage): self.waiting[session_id] = message async def reply_msg(self, - message: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage], - res: Union[str, list]): + message: AstrBotMessage, + result_message: List[BaseMessageComponent]): ''' 回复频道消息 ''' - await super().reply_msg() - if isinstance(message, AstrBotMessage): - source = message.raw_message - else: - source = message + source = message.raw_message assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage)) logger.info( - f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(res)}") + f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(result_message)}") plain_text = '' image_path = '' msg_ref = None - - # if isinstance(res, list): - # plain_text, image_path = qq_official_message_parse(res) - # elif isinstance(res, str): - # plain_text = res - - # if self.cfg['qq_pic_mode']: - # # 文本转图片,并且加上原来的图片 - # if plain_text != '' or image_path != '': - # if image_path is not None and image_path != '': - # if image_path.startswith("http"): - # plain_text += "\n\n" + "![](" + image_path + ")" - # else: - # plain_text += "\n\n" + \ - # "![](file:///" + image_path + ")" - # # image_path = gu.create_markdown_image("".join(plain_text)) - # image_path = await text_to_image_base("".join(plain_text)) - # plain_text = "" - - # else: - # if image_path is not None and image_path != '': - # msg_ref = None - # if image_path.startswith("http"): - # async with aiohttp.ClientSession() as session: - # async with session.get(image_path) as response: - # if response.status == 200: - # image = PILImage.open(io.BytesIO(await response.read())) - # image_path = gu.save_temp_img(image) - if self.cc.get("qq_pic_mode", False): - plains = [] - news = [] - if isinstance(res, str): - res = [Plain(text=res, convert=False),] - for i in res: - if isinstance(i, Plain): - plains.append(i.text) - else: - news.append(i) - plains_str = "".join(plains).strip() - if plains_str and len(plains_str) > 50: - p = await text_to_image_base(plains_str, return_url=False) - with open(p, "rb") as f: - news.append(Image.fromBytes(f.read())) - res = news - - if isinstance(res, list): - plain_text, image_path = qq_official_message_parse(res) + rendered_images = [] + + if self.context.config_helper.get("qq_pic_mode", False) and isinstance(result_message, list): + rendered_images = await self.convert_to_t2i_chain(result_message) + + if isinstance(result_message, list): + plain_text, image_path = self._parse_to_qqofficial(result_message) else: - plain_text = res + plain_text = result_message if source and not image_path: # file_image与message_reference不能同时传入 msg_ref = Reference(message_id=source.id, @@ -231,22 +251,33 @@ async def reply_msg(self, 'msg_id': message.message_id, 'message_reference': msg_ref } + if message.type == MessageType.GROUP_MESSAGE: data['group_openid'] = str(source.group_openid) elif message.type == MessageType.GUILD_MESSAGE: data['channel_id'] = source.channel_id elif message.type == MessageType.FRIEND_MESSAGE: - # 目前只处理频道私聊 data['guild_id'] = source.guild_id - else: - raise ValueError(f"未知的消息类型: {message.type}") if image_path: data['file_image'] = image_path + if rendered_images: + # 文转图 + _data = data.copy() + _data['content'] = '' + _data['file_image'] = rendered_images[0].file + _data['message_reference'] = None + + try: + await self._reply(**_data) + return + except BaseException as e: + logger.warn(traceback.format_exc()) + logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") try: - await self._send_wrapper(**data) + await self._reply(**data) except BaseException as e: - print(e) + logger.error(traceback.format_exc()) # 分割过长的消息 if "msg over length" in str(e): split_res = [] @@ -254,68 +285,70 @@ async def reply_msg(self, split_res.append(plain_text[len(plain_text)//2:]) for i in split_res: data['content'] = i - await self._send_wrapper(**data) + await self._reply(**data) else: - # 发送qq信息 try: # 防止被qq频道过滤消息 plain_text = plain_text.replace(".", " . ") - await self._send_wrapper(**data) - + await self._reply(**data) except BaseException as e: try: data['content'] = str.join(" ", plain_text) - await self._send_wrapper(**data) + await self._reply(**data) except BaseException as e: plain_text = re.sub( r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) plain_text = plain_text.replace(".", "·") data['content'] = plain_text - await self._send_wrapper(**data) + await self._reply(**data) - async def _send_wrapper(self, **kwargs): + async def _reply(self, **kwargs): if 'group_openid' in kwargs: # QQ群组消息 - media = None # qq群组消息需要自行上传,暂时不处理 - # if 'file_image' in kwargs: - # file_image_path = kwargs['file_image'] - # if file_image_path != "": - # media = await self.upload_img(file_image_path, kwargs['group_openid']) - # del kwargs['file_image'] - # if media is not None: - # kwargs['msg_type'] = 7 # 富媒体 - await self.client.api.post_group_message(media=media, **kwargs) + if 'file_image' in kwargs: + file_image_path = kwargs['file_image'].replace("file:///", "") + if file_image_path: + logger.debug(f"上传图片: {file_image_path}") + image_url = await self.context.image_uploader.upload_image(file_image_path) + logger.debug(f"上传成功: {image_url}") + media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url) + del kwargs['file_image'] + kwargs['media'] = media + kwargs['msg_type'] = 7 # 富媒体 + await self.client.api.post_group_message(**kwargs) elif 'channel_id' in kwargs: # 频道消息 if 'file_image' in kwargs: - kwargs['file_image'] = kwargs['file_image'].replace( - "file://", "") + kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") await self.client.api.post_message(**kwargs) else: # 频道私聊消息 if 'file_image' in kwargs: - kwargs['file_image'] = kwargs['file_image'].replace( - "file://", "") + kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") await self.client.api.post_dms(**kwargs) - async def send_msg(self, - message_obj: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage], - message_chain: List[BaseMessageComponent], - ): + async def send_msg(self, target: Dict[str, str], result_message: Union[List[BaseMessageComponent], str]): ''' - 发送消息。目前只支持被动回复消息(即拥有一个 botpy Message 类型的 message_obj 传入) + 以主动的方式给用户、群或者频道发送一条消息。 + + `target` 接收一个 dict 类型的值引用。 + + - 如果目标是 QQ 群,请添加 key `group_openid`。 + - 如果目标是 频道消息,请添加 key `channel_id`。 + - 如果目标是 频道私聊,请添加 key `guild_id`。 ''' - await self.reply_msg(message_obj, message_chain) + if isinstance(result_message, list): + plain_text, image_path = self._parse_to_qqofficial(result_message) + else: + plain_text = result_message - async def send(self, - message_obj: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage], - message_chain: List[BaseMessageComponent], - ): - ''' - 发送消息。目前只支持被动回复消息(即拥有一个 botpy Message 类型的 message_obj 传入) - ''' - await self.reply_msg(message_obj, message_chain) + payload = { + 'content': plain_text, + 'file_image': image_path, + **target + } + await self._reply(**payload) def wait_for_message(self, channel_id: int) -> AstrBotMessage: ''' diff --git a/model/plugin/command.py b/model/plugin/command.py new file mode 100644 index 0000000..bd6c8e4 --- /dev/null +++ b/model/plugin/command.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from type.register import RegisteredPlugins +from typing import List, Union, Callable +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +@dataclass +class CommandRegisterRequest(): + command_name: str + description: str + priority: int + handler: Callable + plugin_name: str = None + +class PluginCommandBridge(): + def __init__(self, cached_plugins: RegisteredPlugins): + self.plugin_commands_waitlist: List[CommandRegisterRequest] = [] + self.cached_plugins = cached_plugins + + def register_command(self, plugin_name, command_name, description, priority, handler): + self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, plugin_name)) + \ No newline at end of file diff --git a/model/plugin/manager.py b/model/plugin/manager.py new file mode 100644 index 0000000..11d33a0 --- /dev/null +++ b/model/plugin/manager.py @@ -0,0 +1,256 @@ +import inspect +import os +import sys +import traceback +import uuid +import shutil +import yaml + +from util.updator.plugin_updator import PluginUpdator +from util.io import remove_dir, download_file +from types import ModuleType +from type.types import Context +from type.plugin import * +from type.register import * +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + +class PluginManager(): + def __init__(self, context: Context): + self.updator = PluginUpdator() + self.plugin_store_path = self.updator.get_plugin_store_path() + self.context = context + + def get_classes(self, arg: ModuleType): + classes = [] + clsmembers = inspect.getmembers(arg, inspect.isclass) + for (name, _) in clsmembers: + if name.lower().endswith("plugin") or name.lower() == "main": + classes.append(name) + break + return classes + + def get_modules(self, path): + modules = [] + + dirs = os.listdir(path) + # 遍历文件夹,找到 main.py 或者和文件夹同名的文件 + for d in dirs: + if os.path.isdir(os.path.join(path, d)): + if os.path.exists(os.path.join(path, d, "main.py")): + module_str = 'main' + elif os.path.exists(os.path.join(path, d, d + ".py")): + module_str = d + else: + print(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。") + continue + if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(os.path.join(path, d, d + ".py")): + modules.append({ + "pname": d, + "module": module_str, + "module_path": os.path.join(path, d, module_str) + }) + return modules + + def get_plugin_modules(self): + plugins = [] + try: + plugin_dir = self.plugin_store_path + if os.path.exists(plugin_dir): + plugins = self.get_modules(plugin_dir) + return plugins + except BaseException as e: + raise e + + def check_plugin_dept_update(self, target_plugin: str = None): + plugin_dir = self.plugin_store_path + if not os.path.exists(plugin_dir): + return False + to_update = [] + if target_plugin: + to_update.append(target_plugin) + else: + for p in self.context.cached_plugins: + to_update.append(p.root_dir_name) + for p in to_update: + plugin_path = os.path.join(plugin_dir, p) + if os.path.exists(os.path.join(plugin_path, "requirements.txt")): + pth = os.path.join(plugin_path, "requirements.txt") + logger.info(f"正在检查更新插件 {p} 的依赖: {pth}") + self.update_plugin_dept(os.path.join(plugin_path, "requirements.txt")) + + def update_plugin_dept(self, path): + mirror = "https://mirrors.aliyun.com/pypi/simple/" + py = sys.executable + os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet") + + def install_plugin(self, repo_url: str): + ppath = self.plugin_store_path + + # we no longer use Git anymore :) + # Repo.clone_from(repo_url, to_path=plugin_path, branch='master') + + plugin_path = self.updator.update(repo_url) + with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: + f.write(repo_url) + + ok, err = self.plugin_reload() + if not ok: + raise Exception(err) + + def download_from_repo_url(self, target_path: str, repo_url: str): + repo_namespace = repo_url.split("/")[-2:] + author = repo_namespace[0] + repo = repo_namespace[1] + + logger.info(f"正在下载插件 {repo} ...") + release_url = f"https://api.github.com/repos/{author}/{repo}/releases" + releases = self.updator.fetch_release_info(url=release_url) + if not releases: + # download from the default branch directly. + logger.warn(f"未在插件 {author}/{repo} 中找到任何发布版本,将从默认分支下载。") + release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" + else: + release_url = releases[0]['zipball_url'] + + download_file(release_url, target_path + ".zip") + + def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin: + for p in self.context.cached_plugins: + if p.metadata.plugin_name == plugin_name: + return p + + def uninstall_plugin(self, plugin_name: str): + plugin = self.get_registered_plugin(plugin_name, self.context.cached_plugins) + if not plugin: + raise Exception("插件不存在。") + root_dir_name = plugin.root_dir_name + ppath = self.plugin_store_path + self.context.cached_plugins.remove(plugin) + if not remove_dir(os.path.join(ppath, root_dir_name)): + raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") + + def update_plugin(self, plugin_name: str): + plugin = self.get_registered_plugin(plugin_name, self.context.cached_plugins) + if not plugin: + raise Exception("插件不存在。") + + self.updator.update(plugin) + + def plugin_reload(self): + cached_plugins = self.context.cached_plugins + plugins = self.get_plugin_modules() + if plugins is None: + return False, "未找到任何插件模块" + fail_rec = "" + + registered_map = {} + for p in cached_plugins: + registered_map[p.module_path] = None + + for plugin in plugins: + try: + p = plugin['module'] + module_path = plugin['module_path'] + root_dir_name = plugin['pname'] + + # self.check_plugin_dept_update(cached_plugins, root_dir_name) + + module = __import__("addons.plugins." + + root_dir_name + "." + p, fromlist=[p]) + + cls = self.get_classes(module) + + try: + # 尝试传入 ctx + obj = getattr(module, cls[0])(ctx=self.context) + except: + obj = getattr(module, cls[0])() + + metadata = None + + plugin_path = os.path.join(self.plugin_store_path, root_dir_name) + metadata = self.load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj) + + logger.info(f"插件 {metadata.plugin_name}({metadata.author}) 加载成功。") + + if module_path not in registered_map: + cached_plugins.append(RegisteredPlugin( + metadata=metadata, + plugin_instance=obj, + module=module, + module_path=module_path, + root_dir_name=root_dir_name + )) + except BaseException as e: + traceback.print_exc() + fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n" + + if not fail_rec: + return True, None + else: + return False, fail_rec + + def install_plugin_from_file(self, zip_file_path: str): + # try to unzip + temp_dir = os.path.join(os.path.dirname(zip_file_path), str(uuid.uuid4())) + self.updator.unzip_file(zip_file_path, temp_dir) + # check if the plugin has metadata.yaml + if not os.path.exists(os.path.join(temp_dir, "metadata.yaml")): + remove_dir(temp_dir) + raise Exception("插件缺少 metadata.yaml 文件。") + + metadata = self.load_plugin_metadata(temp_dir) + plugin_name = metadata.plugin_name + if not plugin_name: + remove_dir(temp_dir) + raise Exception("插件 metadata.yaml 文件中 name 字段为空。") + plugin_name = self.updator.format_name(plugin_name) + + ppath = self.plugin_store_path + plugin_path = os.path.join(ppath, plugin_name) + if os.path.exists(plugin_path): + remove_dir(plugin_path) + + # move to the target path + shutil.move(temp_dir, plugin_path) + + if metadata.repo: + with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: + f.write(metadata.repo) + + # remove the temp dir + remove_dir(temp_dir) + + ok, err = self.plugin_reload() + if not ok: + raise Exception(err) + + def load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata: + metadata = None + + if not os.path.exists(plugin_path): + raise Exception("插件不存在。") + + if os.path.exists(os.path.join(plugin_path, "metadata.yaml")): + with open(os.path.join(plugin_path, "metadata.yaml"), "r", encoding='utf-8') as f: + metadata = yaml.safe_load(f) + elif plugin_obj: + # 使用 info() 函数 + metadata = plugin_obj.info() + + if isinstance(metadata, dict): + if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata: + raise Exception("插件元数据信息不完整。") + metadata = PluginMetadata( + plugin_name=metadata['name'], + plugin_type=PluginType.COMMON if 'plugin_type' not in metadata else PluginType(metadata['plugin_type']), + author=metadata['author'], + desc=metadata['desc'], + version=metadata['version'], + repo=metadata['repo'] if 'repo' in metadata else None + ) + + return metadata \ No newline at end of file diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 7be1326..ab7516f 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -12,7 +12,7 @@ from openai.types.chat.chat_completion import ChatCompletion from openai._exceptions import * -from persist.session import dbConn +from astrbot.persist.helper import dbConn from model.provider.provider import Provider from util import general_utils as gu from util.cmd_config import CmdConfig @@ -20,7 +20,9 @@ from logging import Logger from typing import List, Dict -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') +from type.types import Context + +logger: Logger = LogManager.GetLogger(log_name='astrbot') MODELS = { "gpt-4o": 128000, @@ -46,7 +48,7 @@ } class ProviderOpenAIOfficial(Provider): - def __init__(self, cfg) -> None: + def __init__(self, context: Context) -> None: super().__init__() os.makedirs("data/openai", exist_ok=True) @@ -57,6 +59,8 @@ def __init__(self, cfg) -> None: self.chosen_api_key = None self.base_url = None self.keys_data = {} # 记录超额 + + cfg = context.base_config['openai'] if cfg['key']: self.api_keys = cfg['key'] if cfg['api_base']: self.base_url = cfg['api_base'] @@ -78,11 +82,12 @@ def __init__(self, cfg) -> None: self.session_memory: Dict[str, List] = {} # 会话记忆 self.session_memory_lock = threading.Lock() self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 + + logger.info("正在载入分词器 cl100k_base...") self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器 - self.DEFAULT_PERSONALITY = { - "name": "default", - "prompt": "你是一个很有帮助的 AI 助手。" - } + logger.info("分词器载入完成。") + + self.DEFAULT_PERSONALITY = context.default_personality self.curr_personality = self.DEFAULT_PERSONALITY self.session_personality = {} # 记录了某个session是否已设置人格。 # 从 SQLite DB 读取历史记录 diff --git a/requirements.txt b/requirements.txt index c2dab44..270244f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ pydantic~=1.10.4 aiohttp requests -openai~=1.2.3 +openai qq-botpy chardet~=5.1.0 Pillow @@ -16,3 +16,4 @@ flask psutil lxml_html_clean SparkleLogging +aiocqhttp diff --git a/type/message.py b/type/astrbot_message.py similarity index 53% rename from type/message.py rename to type/astrbot_message.py index ca49d51..61719fb 100644 --- a/type/message.py +++ b/type/astrbot_message.py @@ -1,17 +1,14 @@ +import time from enum import Enum from typing import List from dataclasses import dataclass from nakuru.entities.components import BaseMessageComponent -from type.register import RegisteredPlatform -from type.types import GlobalObject - class MessageType(Enum): GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 GUILD_MESSAGE = 'GuildMessage' # 频道消息 - @dataclass class MessageMember(): user_id: str # 发送者id @@ -32,31 +29,9 @@ class AstrBotMessage(): message_str: str # 最直观的纯文本消息字符串 raw_message: object timestamp: int # 消息时间戳 + + def __init__(self) -> None: + self.timestamp = int(time.time()) def __str__(self) -> str: return str(self.__dict__) - -class AstrMessageEvent(): - ''' - 消息事件。 - ''' - context: GlobalObject # 一些公用数据 - message_str: str # 纯消息字符串 - message_obj: AstrBotMessage # 消息对象 - platform: RegisteredPlatform # 来源平台 - role: str # 基本身份。`admin` 或 `member` - session_id: int # 会话 id - - def __init__(self, - message_str: str, - message_obj: AstrBotMessage, - platform: RegisteredPlatform, - role: str, - context: GlobalObject, - session_id: str = None): - self.context = context - self.message_str = message_str - self.message_obj = message_obj - self.platform = platform - self.role = role - self.session_id = session_id diff --git a/type/command.py b/type/command.py index 73ed3ef..f09a0a0 100644 --- a/type/command.py +++ b/type/command.py @@ -1,5 +1,6 @@ from typing import Union, List, Callable from dataclasses import dataclass +from nakuru.entities.components import Plain @dataclass @@ -18,11 +19,15 @@ class CommandResult(): 用于在Command中返回多个值 ''' - def __init__(self, hit: bool, success: bool = False, message_chain: list = [], command_name: str = "unknown_command") -> None: + def __init__(self, hit: bool = True, success: bool = True, message_chain: list = [], command_name: str = "unknown_command") -> None: self.hit = hit self.success = success self.message_chain = message_chain self.command_name = command_name + + def message(self, message: str): + self.message_chain = [Plain(message), ] + return self def _result_tuple(self): return (self.success, self.message_chain, self.command_name) diff --git a/type/message_event.py b/type/message_event.py new file mode 100644 index 0000000..12d3931 --- /dev/null +++ b/type/message_event.py @@ -0,0 +1,55 @@ +from typing import List, Union, Optional +from dataclasses import dataclass +from type.register import RegisteredPlatform +from type.types import Context +from type.astrbot_message import AstrBotMessage + +class AstrMessageEvent(): + + def __init__(self, + message_str: str, + message_obj: AstrBotMessage, + platform: RegisteredPlatform, + role: str, + context: Context, + session_id: str = None): + ''' + AstrBot 消息事件。 + + `message_str`: 纯消息字符串 + `message_obj`: AstrBotMessage 对象 + `platform`: 平台对象 + `role`: 角色,`admin` or `member` + `context`: 全局对象 + `session_id`: 会话id + ''' + self.context = context + self.message_str = message_str + self.message_obj = message_obj + self.platform = platform + self.role = role + self.session_id = session_id + + def from_astrbot_message(message: AstrBotMessage, + context: Context, + platform_name: str, + session_id: str, + role: str = "member"): + + ame = AstrMessageEvent(message.message_str, + message, + context.find_platform(platform_name), + role, + context, + session_id) + return ame + + + + + +@dataclass +class MessageResult(): + result_message: Union[str, list] + is_command_call: Optional[bool] = False + callback: Optional[callable] = None diff --git a/type/plugin.py b/type/plugin.py index c0aff62..23dc976 100644 --- a/type/plugin.py +++ b/type/plugin.py @@ -1,4 +1,6 @@ from enum import Enum +from types import ModuleType +from typing import List from dataclasses import dataclass class PluginType(Enum): @@ -25,3 +27,27 @@ class PluginMetadata: def __str__(self) -> str: return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})" + +@dataclass +class RegisteredPlugin: + ''' + 注册在 AstrBot 中的插件。 + ''' + metadata: PluginMetadata + plugin_instance: object + module_path: str + module: ModuleType + root_dir_name: str + trig_cnt: int = 0 + + def reset_trig_cnt(self): + self.trig_cnt = 0 + + def trig(self): + self.trig_cnt += 1 + + def __str__(self) -> str: + return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" + + +RegisteredPlugins = List[RegisteredPlugin] diff --git a/type/register.py b/type/register.py index d26fbfd..1cf679a 100644 --- a/type/register.py +++ b/type/register.py @@ -1,35 +1,9 @@ from model.provider.provider import Provider as LLMProvider -from model.platform._platfrom import Platform +from model.platform import Platform from type.plugin import * from typing import List -from types import ModuleType from dataclasses import dataclass -@dataclass -class RegisteredPlugin: - ''' - 注册在 AstrBot 中的插件。 - ''' - metadata: PluginMetadata - plugin_instance: object - module_path: str - module: ModuleType - root_dir_name: str - trig_cnt: int = 0 - - def reset_trig_cnt(self): - self.trig_cnt = 0 - - def trig(self): - self.trig_cnt += 1 - - def __str__(self) -> str: - return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" - - -RegisteredPlugins = List[RegisteredPlugin] - - @dataclass class RegisteredPlatform: ''' diff --git a/type/types.py b/type/types.py index 3e4ff82..6dd569d 100644 --- a/type/types.py +++ b/type/types.py @@ -1,36 +1,63 @@ - from type.register import * from typing import List from logging import Logger +from util.cmd_config import CmdConfig +from util.t2i.renderer import TextToImageRenderer +from util.updator.astrbot_updator import AstrBotUpdator +from util.image_uploader import ImageUploader +from util.updator.plugin_updator import PluginUpdator +from model.plugin.command import PluginCommandBridge + -class GlobalObject: +class Context: ''' 存放一些公用的数据,用于在不同模块(如core与command)之间传递 ''' - version: str # 机器人版本 - nick: tuple # 用户定义的机器人的别名 - base_config: dict # config.json 中导出的配置 - cached_plugins: List[RegisteredPlugin] # 加载的插件 - platforms: List[RegisteredPlatform] - llms: List[RegisteredLLM] - - web_search: bool # 是否开启了网页搜索 - reply_prefix: str # 回复前缀 - unique_session: bool # 是否开启了独立会话 - default_personality: dict - dashboard_data = None - - logger: Logger = None def __init__(self): - self.nick = None # gocq 的昵称 - self.base_config = None # config.yaml - self.cached_plugins = [] # 缓存的插件 - self.web_search = False # 是否开启了网页搜索 - self.reply_prefix = None - self.unique_session = False - self.platforms = [] - self.llms = [] - self.default_personality = None - self.dashboard_data = None + self.logger: Logger = None + self.base_config: dict = None # 配置(期望启动机器人后是不变的) + self.config_helper: CmdConfig = None + self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件 + self.platforms: List[RegisteredPlatform] = [] + self.llms: List[RegisteredLLM] = [] + self.default_personality: dict = None + + self.unique_session = False # 独立会话 + self.version: str = None # 机器人版本 + self.nick = None # gocq 的唤醒词 self.stat = {} + self.t2i_mode = False + self.web_search = False # 是否开启了网页搜索 + self.reply_prefix = "" + self.updator: AstrBotUpdator = None + self.plugin_updator: PluginUpdator = None + + self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins) + self.image_renderer = TextToImageRenderer() + self.image_uploader = ImageUploader() + + def register_commands(self, + plugin_name: str, + command_name: str, + description: str, + priority: int, + handler: callable): + ''' + 注册插件指令。 + + `plugin_name`: 插件名,注意需要和你的 metadata 中的一致。 + `command_name`: 指令名,如 "help"。不需要带前缀。 + `description`: 指令描述。 + `priority`: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 + `handler`: 指令处理函数。 + ''' + self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler) + + + def find_platform(self, platform_name: str) -> RegisteredPlatform: + for platform in self.platforms: + if platform_name == platform.platform_name: + return platform + + raise ValueError("couldn't find the platform you specified") diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py index e07f808..70a6f11 100644 --- a/util/agent/web_searcher.py +++ b/util/agent/web_searcher.py @@ -9,15 +9,15 @@ from bs4 import BeautifulSoup from openai.types.chat.chat_completion_message_tool_call import Function from util.agent.func_call import FuncCall -from util.search_engine_scraper.config import HEADERS, USER_AGENTS -from util.search_engine_scraper.bing import Bing -from util.search_engine_scraper.sogo import Sogo -from util.search_engine_scraper.google import Google +from util.websearch.config import HEADERS, USER_AGENTS +from util.websearch.bing import Bing +from util.websearch.sogo import Sogo +from util.websearch.google import Google from model.provider.provider import Provider from SparkleLogging.utils.core import LogManager from logging import Logger -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') +logger: Logger = LogManager.GetLogger(log_name='astrbot') bing_search = Bing() diff --git a/util/cmd_config.py b/util/cmd_config.py index 279f191..337534b 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -1,6 +1,5 @@ import os import json -import yaml from typing import Union cpath = "data/cmd_config.json" @@ -11,7 +10,6 @@ def check_exist(): json.dump({}, f, indent=4, ensure_ascii=False) f.flush() - class CmdConfig(): @staticmethod @@ -83,63 +81,3 @@ def init_attributes(key: Union[str, list], init_val=""): with open(cpath, "w", encoding="utf-8-sig") as f: json.dump(d, f, indent=4, ensure_ascii=False) f.flush() - - -def init_astrbot_config_items(): - # 加载默认配置 - cc = CmdConfig() - cc.init_attributes("qq_forward_threshold", 200) - cc.init_attributes("qq_welcome", "") - cc.init_attributes("qq_pic_mode", False) - cc.init_attributes("gocq_host", "127.0.0.1") - cc.init_attributes("gocq_http_port", 5700) - cc.init_attributes("gocq_websocket_port", 6700) - cc.init_attributes("gocq_react_group", True) - cc.init_attributes("gocq_react_guild", True) - cc.init_attributes("gocq_react_friend", True) - cc.init_attributes("gocq_react_group_increase", True) - cc.init_attributes("other_admins", []) - cc.init_attributes("CHATGPT_BASE_URL", "") - cc.init_attributes("qqbot_secret", "") - cc.init_attributes("qqofficial_enable_group_message", False) - cc.init_attributes("admin_qq", "") - cc.init_attributes("nick_qq", ["!", "!", "ai"]) - cc.init_attributes("admin_qqchan", "") - cc.init_attributes("llm_env_prompt", "") - cc.init_attributes("llm_wake_prefix", "") - cc.init_attributes("default_personality_str", "") - cc.init_attributes("openai_image_generate", { - "model": "dall-e-3", - "size": "1024x1024", - "style": "vivid", - "quality": "standard", - }) - cc.init_attributes("http_proxy", "") - cc.init_attributes("https_proxy", "") - cc.init_attributes("dashboard_username", "") - cc.init_attributes("dashboard_password", "") - - - -def try_migrate_config(): - ''' - 将 cmd_config.json 迁移至 data/cmd_config.json - ''' - print("try migrate configs") - if os.path.exists("cmd_config.json"): - with open("cmd_config.json", "r", encoding="utf-8-sig") as f: - data = json.load(f) - with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f: - json.dump(data, f, indent=2, ensure_ascii=False) - try: - os.remove("cmd_config.json") - except Exception as e: - pass - if not os.path.exists("cmd_config.json") and not os.path.exists("data/cmd_config.json"): - # 从 configs/config.yaml 上拿数据 - configs_pth = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../configs/config.yaml")) - with open(configs_pth, encoding='utf-8') as f: - data = yaml.load(f, Loader=yaml.Loader) - print(data) - with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f: - json.dump(data, f, indent=2, ensure_ascii=False) diff --git a/util/config_utils.py b/util/config_utils.py new file mode 100644 index 0000000..688bd74 --- /dev/null +++ b/util/config_utils.py @@ -0,0 +1,109 @@ +import json, os +from util.cmd_config import CmdConfig +from type.config import VERSION +from type.types import Context + +def init_configs(): + ''' + 初始化必需的配置项 + ''' + cc = CmdConfig() + + cc.init_attributes("qq_forward_threshold", 200) + cc.init_attributes("qq_welcome", "") + cc.init_attributes("qq_pic_mode", False) + cc.init_attributes("gocq_host", "127.0.0.1") + cc.init_attributes("gocq_http_port", 5700) + cc.init_attributes("gocq_websocket_port", 6700) + cc.init_attributes("gocq_react_group", True) + cc.init_attributes("gocq_react_guild", True) + cc.init_attributes("gocq_react_friend", True) + cc.init_attributes("gocq_react_group_increase", True) + cc.init_attributes("other_admins", []) + cc.init_attributes("CHATGPT_BASE_URL", "") + cc.init_attributes("qqbot_secret", "") + cc.init_attributes("qqofficial_enable_group_message", False) + cc.init_attributes("admin_qq", "") + cc.init_attributes("nick_qq", ["!", "!", "ai"]) + cc.init_attributes("admin_qqchan", "") + cc.init_attributes("llm_env_prompt", "") + cc.init_attributes("llm_wake_prefix", "") + cc.init_attributes("default_personality_str", "") + cc.init_attributes("openai_image_generate", { + "model": "dall-e-3", + "size": "1024x1024", + "style": "vivid", + "quality": "standard", + }) + cc.init_attributes("http_proxy", "") + cc.init_attributes("https_proxy", "") + cc.init_attributes("dashboard_username", "") + cc.init_attributes("dashboard_password", "") + + # aiocqhttp 适配器 + cc.init_attributes("aiocqhttp", { + "enable": False, + "ws_reverse_host": "", + "ws_reverse_port": 0, + }) + +def try_migrate_config(): + ''' + 将 cmd_config.json 迁移至 data/cmd_config.json (如果存在的话) + ''' + if os.path.exists("cmd_config.json"): + with open("cmd_config.json", "r", encoding="utf-8-sig") as f: + data = json.load(f) + with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + try: + os.remove("cmd_config.json") + except Exception as e: + pass + +def inject_to_context(context: Context): + ''' + 将配置注入到 Context 中。 + this method returns all the configs + ''' + cc = CmdConfig() + + context.version = VERSION + context.base_config = cc.get_all() + + cfg = context.base_config + + if 'reply_prefix' in cfg: + # 适配旧版配置 + if isinstance(cfg['reply_prefix'], dict): + context.reply_prefix = "" + cfg['reply_prefix'] = "" + cc.put("reply_prefix", "") + else: + context.reply_prefix = cfg['reply_prefix'] + + default_personality_str = cc.get("default_personality_str", "") + if default_personality_str == "": + context.default_personality = None + else: + context.default_personality = { + "name": "default", + "prompt": default_personality_str, + } + + if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']: + context.unique_session = True + else: + context.unique_session = False + + nick_qq = cc.get("nick_qq", None) + if nick_qq == None: + nick_qq = ("/", ) + if isinstance(nick_qq, str): + nick_qq = (nick_qq, ) + if isinstance(nick_qq, list): + nick_qq = tuple(nick_qq) + context.nick = nick_qq + context.t2i_mode = cc.get("qq_pic_mode", False) + + return cfg \ No newline at end of file diff --git a/util/general_utils.py b/util/general_utils.py index effcaf0..28bdabd 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -1,415 +1,29 @@ import time -import socket -import os -import re +import asyncio import requests -import aiohttp -import socket import json import sys import psutil -import ssl -import zipfile -import shutil -import stat -from PIL import Image, ImageDraw, ImageFont -from type.types import GlobalObject +from type.types import Context from SparkleLogging.utils.core import LogManager from logging import Logger -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') +logger: Logger = LogManager.GetLogger(log_name='astrbot') - -def port_checker(port: int, host: str = "localhost"): - sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sk.settimeout(1) - try: - sk.connect((host, port)) - sk.close() - return True - except Exception: - sk.close() - return False - -def get_font(size: int) -> ImageFont.FreeTypeFont: - # get yahei first - # common and default fonts on Windows, macOS and Linux - fonts = ["msyh.ttc", "NotoSansCJK-Regular.ttc", "msyhbd.ttc", "PingFang.ttc", "Heiti.ttc"] - for font in fonts: - try: - font = ImageFont.truetype(font, size) - return font - except Exception as e: - pass - - -def render_markdown(markdown_text, image_width=800, image_height=600, font_size=26, font_color=(0, 0, 0), bg_color=(255, 255, 255)): - - HEADER_MARGIN = 20 - HEADER_FONT_STANDARD_SIZE = 42 - - QUOTE_LEFT_LINE_MARGIN = 10 - QUOTE_FONT_LINE_MARGIN = 6 # 引用文字距离左边线的距离和上下的距离 - QUOTE_LEFT_LINE_HEIGHT = font_size + QUOTE_FONT_LINE_MARGIN * 2 - QUOTE_LEFT_LINE_WIDTH = 5 - QUOTE_LEFT_LINE_COLOR = (180, 180, 180) - QUOTE_FONT_SIZE = font_size - QUOTE_FONT_COLOR = (180, 180, 180) - # QUOTE_BG_COLOR = (255, 255, 255) - - CODE_BLOCK_MARGIN = 10 - CODE_BLOCK_FONT_SIZE = font_size - CODE_BLOCK_FONT_COLOR = (255, 255, 255) - CODE_BLOCK_BG_COLOR = (240, 240, 240) - CODE_BLOCK_CODES_MARGIN_VERTICAL = 5 # 代码块和代码之间的距离 - CODE_BLOCK_CODES_MARGIN_HORIZONTAL = 5 # 代码块和代码之间的距离 - CODE_BLOCK_TEXT_MARGIN = 4 # 代码和代码之间的距离 - - INLINE_CODE_MARGIN = 8 - INLINE_CODE_FONT_SIZE = font_size - INLINE_CODE_FONT_COLOR = font_color - INLINE_CODE_FONT_MARGIN = 4 - INLINE_CODE_BG_COLOR = (230, 230, 230) - INLINE_CODE_BG_HEIGHT = INLINE_CODE_FONT_SIZE + INLINE_CODE_FONT_MARGIN * 2 - - LIST_MARGIN = 8 - LIST_FONT_SIZE = font_size - LIST_FONT_COLOR = font_color - - TEXT_LINE_MARGIN = 8 - - IMAGE_MARGIN = 15 - # 用于匹配图片的正则表达式 - IMAGE_REGEX = r"!\s*\[.*?\]\s*\((.*?)\)" - - # 加载字体 - font = get_font(font_size) - - images: Image = {} - - # pre_process, get height of each line - pre_lines = markdown_text.split('\n') - height = 0 - pre_in_code = False - i = -1 - _pre_lines = [] - for line in pre_lines: - i += 1 - # 处理图片 - if re.search(IMAGE_REGEX, line): - try: - image_url = re.findall(IMAGE_REGEX, line)[0] - print(image_url) - image_res = Image.open(requests.get( - image_url, stream=True, timeout=5).raw) - images[i] = image_res - # 最大不得超过image_width的50% - img_height = image_res.size[1] - - if image_res.size[0] > image_width*0.5: - image_res = image_res.resize( - (int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) - img_height = image_res.size[1] - - height += img_height + IMAGE_MARGIN*2 - - line = re.sub(IMAGE_REGEX, "", line) - except Exception as e: - print(e) - line = re.sub(IMAGE_REGEX, "\n[加载失败的图片]\n", line) - continue - - line.replace("\t", " ") - if font.getsize(line)[0] > image_width: - cp = line - _width = 0 - _word_cnt = 0 - for ii in range(len(line)): - # 检测是否是中文 - _width += font.getsize(line[ii])[0] - _word_cnt += 1 - if _width > image_width: - _pre_lines.append(cp[:_word_cnt]) - cp = cp[_word_cnt:] - _word_cnt = 0 - _width = 0 - _pre_lines.append(cp) - else: - _pre_lines.append(line) - pre_lines = _pre_lines - - i = -1 - for line in pre_lines: - if line == "": - height += TEXT_LINE_MARGIN - continue - i += 1 - line = line.strip() - if pre_in_code and not line.startswith("```"): - height += font_size + CODE_BLOCK_TEXT_MARGIN - # pre_codes.append(line) - continue - if line.startswith("#"): - header_level = line.count("#") - height += HEADER_FONT_STANDARD_SIZE + HEADER_MARGIN*2 - header_level * 4 - elif line.startswith("-"): - height += font_size+LIST_MARGIN*2 - elif line.startswith(">"): - height += font_size+QUOTE_LEFT_LINE_MARGIN*2 - elif line.startswith("```"): - if pre_in_code: - pre_in_code = False - # pre_codes = [] - height += CODE_BLOCK_MARGIN - else: - pre_in_code = True - height += CODE_BLOCK_MARGIN - elif re.search(r"`(.*?)`", line): - height += font_size+INLINE_CODE_FONT_MARGIN*2+INLINE_CODE_MARGIN*2 - else: - height += font_size + TEXT_LINE_MARGIN*2 - - markdown_text = '\n'.join(pre_lines) - image_height = height - if image_height < 100: - image_height = 100 - image_width += 20 - - # 创建空白图像 - image = Image.new('RGB', (image_width, image_height), bg_color) - draw = ImageDraw.Draw(image) - - # 设置初始位置 - x, y = 10, 10 - - # 解析Markdown文本 - lines = markdown_text.split("\n") - # lines = pre_lines - - in_code_block = False - code_block_start_y = 0 - code_block_codes = [] - - index = -1 - for line in lines: - index += 1 - if in_code_block and not line.startswith("```"): - code_block_codes.append(line) - y += font_size + CODE_BLOCK_TEXT_MARGIN - continue - line = line.strip() - - if line.startswith("#"): - # 处理标题 - header_level = line.count("#") - line = line.strip("#").strip() - font_size_header = HEADER_FONT_STANDARD_SIZE - header_level * 4 - font = get_font(font_size_header) - y += HEADER_MARGIN # 上边距 - # 字间距 - draw.text((x, y), line, font=font, fill=font_color) - draw.line((x, y + font_size_header + 8, image_width - 10, - y + font_size_header + 8), fill=(230, 230, 230), width=3) - y += font_size_header + HEADER_MARGIN - - elif line.startswith(">"): - # 处理引用 - quote_text = line.strip(">") - y += QUOTE_LEFT_LINE_MARGIN - draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), - fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH) - font = get_font(QUOTE_FONT_SIZE) - draw.text((x + QUOTE_FONT_LINE_MARGIN, y + QUOTE_FONT_LINE_MARGIN), - quote_text, font=font, fill=QUOTE_FONT_COLOR) - y += font_size + QUOTE_LEFT_LINE_HEIGHT + QUOTE_LEFT_LINE_MARGIN - - elif line.startswith("-"): - # 处理列表 - list_text = line.strip("-").strip() - font = get_font(LIST_FONT_SIZE) - y += LIST_MARGIN - draw.text((x, y), " · " + list_text, - font=font, fill=LIST_FONT_COLOR) - y += font_size + LIST_MARGIN - - elif line.startswith("```"): - if not in_code_block: - code_block_start_y = y+CODE_BLOCK_MARGIN - in_code_block = True - else: - # print(code_block_codes) - in_code_block = False - codes = "\n".join(code_block_codes) - code_block_codes = [] - draw.rounded_rectangle((x, code_block_start_y, image_width - 10, y+CODE_BLOCK_CODES_MARGIN_VERTICAL + - CODE_BLOCK_TEXT_MARGIN), radius=5, fill=CODE_BLOCK_BG_COLOR, width=2) - font = get_font(CODE_BLOCK_FONT_SIZE) - draw.text((x + CODE_BLOCK_CODES_MARGIN_HORIZONTAL, code_block_start_y + - CODE_BLOCK_CODES_MARGIN_VERTICAL), codes, font=font, fill=font_color) - y += CODE_BLOCK_CODES_MARGIN_VERTICAL + CODE_BLOCK_MARGIN - # y += font_size+10 - elif re.search(r"`(.*?)`", line): - y += INLINE_CODE_MARGIN # 上边距 - # 处理行内代码 - code_regex = r"`(.*?)`" - parts_inline = re.findall(code_regex, line) - # print(parts_inline) - parts = re.split(code_regex, line) - # print(parts) - for part in parts: - # the judge has a tiny bug. - # when line is like "hi`hi`". all the parts will be in parts_inline. - if part in parts_inline: - font = get_font(INLINE_CODE_FONT_SIZE) - code_text = part.strip("`") - code_width = font.getsize( - code_text)[0] + INLINE_CODE_FONT_MARGIN*2 - x += INLINE_CODE_MARGIN - code_box = (x, y, x + code_width, - y + INLINE_CODE_BG_HEIGHT) - draw.rounded_rectangle( - code_box, radius=5, fill=INLINE_CODE_BG_COLOR, width=2) # 使用灰色填充矩形框作为引用背景 - draw.text((x+INLINE_CODE_FONT_MARGIN, y), - code_text, font=font, fill=font_color) - x += code_width+INLINE_CODE_MARGIN-INLINE_CODE_FONT_MARGIN - else: - font = get_font(font_size) - draw.text((x, y), part, font=font, fill=font_color) - x += font.getsize(part)[0] - y += font_size + INLINE_CODE_MARGIN - x = 10 - - else: - # 处理普通文本 - if line == "": - y += TEXT_LINE_MARGIN - else: - font = get_font(font_size) - - draw.text((x, y), line, font=font, fill=font_color) - y += font_size + TEXT_LINE_MARGIN*2 - - # 图片特殊处理 - if index in images: - image_res = images[index] - # 最大不得超过image_width的50% - if image_res.size[0] > image_width*0.5: - image_res = image_res.resize( - (int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) - image.paste(image_res, (IMAGE_MARGIN, y)) - y += image_res.size[1] + IMAGE_MARGIN*2 - return image - - -def save_temp_img(img: Image) -> str: - if not os.path.exists("temp"): - os.makedirs("temp") - - # 获得文件创建时间,清除超过1小时的 - try: - for f in os.listdir("temp"): - path = os.path.join("temp", f) - if os.path.isfile(path): - ctime = os.path.getctime(path) - if time.time() - ctime > 3600: - os.remove(path) - except Exception as e: - print(f"清除临时文件失败: {e}") - - # 获得时间戳 - timestamp = int(time.time()) - p = f"temp/{timestamp}.jpg" - - if isinstance(img, Image.Image): - img.save(p) - else: - with open(p, "wb") as f: - f.write(img) - logger.info(f"保存临时图片: {p}") - return p - -async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str: - ''' - 下载图片 - ''' - try: - logger.info(f"下载图片: {url}") - async with aiohttp.ClientSession() as session: - if post: - async with session.post(url, json=post_data) as resp: - return save_temp_img(await resp.read()) - else: - async with session.get(url) as resp: - return save_temp_img(await resp.read()) - except aiohttp.client_exceptions.ClientConnectorSSLError as e: - # 关闭SSL验证 - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - async with aiohttp.ClientSession(trust_env=False) as session: - if post: - async with session.get(url, ssl=ssl_context) as resp: - return save_temp_img(await resp.read()) - else: - async with session.get(url, ssl=ssl_context) as resp: - return save_temp_img(await resp.read()) - except Exception as e: - raise e - -def download_file(url: str, path: str): - ''' - 从指定 url 下载文件到指定路径 path - ''' - try: - logger.info(f"下载文件: {url}") - with requests.get(url, stream=True) as r: - with open(path, 'wb') as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - except Exception as e: - raise e - - -def create_markdown_image(text: str): - ''' - markdown文本转图片。 - 返回:文件路径 - ''' - try: - img = render_markdown(text) - p = save_temp_img(img) - return p - except Exception as e: - raise e - - -def get_local_ip_addresses(): - ip = '' - try: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.connect(('8.8.8.8', 80)) - ip = s.getsockname()[0] - except BaseException as e: - pass - finally: - s.close() - return ip - - -def upload(_global_object: GlobalObject): +async def upload_metrics(_global_object: Context): ''' 上传相关非敏感统计数据 ''' - time.sleep(10) + await asyncio.sleep(10) while True: platform_stats = {} llm_stats = {} plugin_stats = {} for platform in _global_object.platforms: platform_stats[platform.platform_name] = { - "cnt_receive": platform.platform_instance.cnt_receive, - "cnt_reply": platform.platform_instance.cnt_reply + "cnt_receive": 0, + "cnt_reply": 0 } for llm in _global_object.llms: @@ -442,7 +56,7 @@ def upload(_global_object: GlobalObject): _global_object.cnt_total = 0 except BaseException as e: pass - time.sleep(30*60) + await asyncio.sleep(30*60) def retry(n: int = 3): ''' @@ -459,7 +73,7 @@ def wrapper(*args, **kwargs): return wrapper return decorator -def run_monitor(global_object: GlobalObject): +def run_monitor(global_object: Context): ''' 监测机器性能 - Bot 内存使用量 @@ -476,24 +90,3 @@ def run_monitor(global_object: GlobalObject): } stat['sys_start_time'] = start_time time.sleep(30) - -def remove_dir(file_path) -> bool: - if not os.path.exists(file_path): return True - try: - shutil.rmtree(file_path, onerror=on_error) - return True - except BaseException as e: - logger.error(f"删除文件/文件夹 {file_path} 失败: {str(e)}") - return False - -def on_error(func, path, exc_info): - ''' - a callback of the rmtree function. - ''' - print(f"remove {path} failed.") - import stat - if not os.access(path, os.W_OK): - os.chmod(path, stat.S_IWUSR) - func(path) - else: - raise \ No newline at end of file diff --git a/util/image_render/helper.py b/util/image_render/helper.py deleted file mode 100644 index 7c80e38..0000000 --- a/util/image_render/helper.py +++ /dev/null @@ -1,43 +0,0 @@ -import aiohttp, os -from util.general_utils import download_image_by_url, create_markdown_image -from type.config import VERSION - -BASE_RENDER_URL = "https://t2i.soulter.top/text2img" -TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template") - -async def text_to_image_base(text: str, return_url: bool = False) -> str: - ''' - 返回图像的文件路径 - ''' - with open(os.path.join(TEMPLATE_PATH, "base.html"), "r", encoding='utf-8') as f: - tmpl_str = f.read() - - assert(tmpl_str) - - text = text.replace("`", "\`") - - post_data = { - "tmpl": tmpl_str, - "json": return_url, - "tmpldata": { - "text": text, - "version": f"v{VERSION}", - }, - "options": { - "full_page": True - } - } - - if return_url: - async with aiohttp.ClientSession() as session: - async with session.post(f"{BASE_RENDER_URL}/generate", json=post_data) as resp: - ret = await resp.json() - return f"{BASE_RENDER_URL}/{ret['data']['id']}" - else: - image_path = "" - try: - image_path = await download_image_by_url(f"{BASE_RENDER_URL}/generate", post=True, post_data=post_data) - except Exception as e: - print(f"调用 markdown 渲染 API 失败,错误信息:{e},将使用本地渲染方式。") - image_path = create_markdown_image(text) - return image_path \ No newline at end of file diff --git a/util/image_uploader.py b/util/image_uploader.py new file mode 100644 index 0000000..a9a76d0 --- /dev/null +++ b/util/image_uploader.py @@ -0,0 +1,21 @@ +import aiohttp +import uuid + +class ImageUploader(): + def __init__(self) -> None: + self.S3_URL = "https://s3.neko.soulter.top/astrbot-s3" + + async def upload_image(self, image_path: str) -> str: + ''' + 上传图像文件到S3 + ''' + with open(image_path, "rb") as f: + image = f.read() + + image_url = f"{self.S3_URL}/{uuid.uuid4().hex}.jpg" + + async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session: + async with session.put(image_url, data=image) as resp: + if resp.status != 200: + raise Exception(f"Failed to upload image: {resp.status}") + return image_url diff --git a/util/io.py b/util/io.py new file mode 100644 index 0000000..b6032d3 --- /dev/null +++ b/util/io.py @@ -0,0 +1,128 @@ +import os +import ssl +import shutil +import socket +import time +import aiohttp +import requests + +from PIL import Image +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +def on_error(func, path, exc_info): + ''' + a callback of the rmtree function. + ''' + print(f"remove {path} failed.") + import stat + if not os.access(path, os.W_OK): + os.chmod(path, stat.S_IWUSR) + func(path) + else: + raise + +def remove_dir(file_path) -> bool: + if not os.path.exists(file_path): return True + try: + shutil.rmtree(file_path, onerror=on_error) + return True + except BaseException as e: + logger.error(f"删除文件/文件夹 {file_path} 失败: {str(e)}") + return False + +def port_checker(port: int, host: str = "localhost"): + sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sk.settimeout(1) + try: + sk.connect((host, port)) + sk.close() + return True + except Exception: + sk.close() + return False + + +def save_temp_img(img: Image) -> str: + if not os.path.exists("temp"): + os.makedirs("temp") + + # 获得文件创建时间,清除超过1小时的 + try: + for f in os.listdir("temp"): + path = os.path.join("temp", f) + if os.path.isfile(path): + ctime = os.path.getctime(path) + if time.time() - ctime > 3600: + os.remove(path) + except Exception as e: + print(f"清除临时文件失败: {e}") + + # 获得时间戳 + timestamp = int(time.time()) + p = f"temp/{timestamp}.jpg" + + if isinstance(img, Image.Image): + img.save(p) + else: + with open(p, "wb") as f: + f.write(img) + logger.info(f"保存临时图片: {p}") + return p + +async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str: + ''' + 下载图片, 返回 path + ''' + try: + logger.info(f"下载图片: {url}") + async with aiohttp.ClientSession() as session: + if post: + async with session.post(url, json=post_data) as resp: + return save_temp_img(await resp.read()) + else: + async with session.get(url) as resp: + return save_temp_img(await resp.read()) + except aiohttp.client_exceptions.ClientConnectorSSLError as e: + # 关闭SSL验证 + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + async with aiohttp.ClientSession(trust_env=False) as session: + if post: + async with session.get(url, ssl=ssl_context) as resp: + return save_temp_img(await resp.read()) + else: + async with session.get(url, ssl=ssl_context) as resp: + return save_temp_img(await resp.read()) + except Exception as e: + raise e + +def download_file(url: str, path: str): + ''' + 从指定 url 下载文件到指定路径 path + ''' + try: + logger.info(f"下载文件: {url}") + with requests.get(url, stream=True) as r: + with open(path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + except Exception as e: + raise e + + +def get_local_ip_addresses(): + ip = '' + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + except BaseException as e: + pass + finally: + s.close() + return ip diff --git a/util/plugin_dev/api/v1/bot.py b/util/plugin_dev/api/v1/bot.py index 46a3b7f..c334872 100644 --- a/util/plugin_dev/api/v1/bot.py +++ b/util/plugin_dev/api/v1/bot.py @@ -1,5 +1,5 @@ from type.plugin import PluginMetadata, PluginType from type.register import RegisteredLLM, RegisteredPlatform, RegisteredPlugin, RegisteredPlugins -from type.types import GlobalObject -from type.message import AstrMessageEvent +from type.types import Context +from type.message_event import AstrMessageEvent from type.command import CommandResult \ No newline at end of file diff --git a/util/plugin_dev/api/v1/message.py b/util/plugin_dev/api/v1/message.py index c1b3243..182e911 100644 --- a/util/plugin_dev/api/v1/message.py +++ b/util/plugin_dev/api/v1/message.py @@ -1,10 +1,3 @@ -from astrbot.core import oper_msg -from type.message import * -from type.command import CommandResult -from model.platform._message_result import MessageResult - -''' -消息处理。在消息平台接收到消息后,调用此函数进行处理。 -集成了指令检测、指令处理、LLM 调用等功能。 -''' -message_handler = oper_msg +from type.message_event import * +from type.astrbot_message import * +from type.command import CommandResult \ No newline at end of file diff --git a/util/plugin_dev/api/v1/platform.py b/util/plugin_dev/api/v1/platform.py index bf93986..47bd617 100644 --- a/util/plugin_dev/api/v1/platform.py +++ b/util/plugin_dev/api/v1/platform.py @@ -5,7 +5,7 @@ 消息平台的具体实现类需要继承Platform类,并实现其中的抽象方法。 ''' -from model.platform._platfrom import Platform +from model.platform import Platform -from model.platform.qq_gocq import QQGOCQ +from model.platform.qq_nakuru import QQGOCQ from model.platform.qq_official import QQOfficial \ No newline at end of file diff --git a/util/plugin_dev/api/v1/register.py b/util/plugin_dev/api/v1/register.py index 83ccd54..6ebef89 100644 --- a/util/plugin_dev/api/v1/register.py +++ b/util/plugin_dev/api/v1/register.py @@ -4,11 +4,11 @@ 必须分别实现 Platform 和 LLMProvider 中涉及的接口 ''' from model.provider.provider import Provider as LLMProvider -from model.platform._platfrom import Platform -from type.types import GlobalObject +from model.platform import Platform +from type.types import Context from type.register import RegisteredPlatform, RegisteredLLM -def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None: +def register_platform(platform_name: str, platform_instance: Platform, context: Context) -> None: ''' 注册一个消息平台。 @@ -24,7 +24,7 @@ def register_platform(platform_name: str, platform_instance: Platform, context: context.platforms.append(RegisteredPlatform(platform_name, platform_instance)) -def register_llm(llm_name: str, llm_instance: LLMProvider, context: GlobalObject) -> None: +def register_llm(llm_name: str, llm_instance: LLMProvider, context: Context) -> None: ''' 注册一个大语言模型。 @@ -39,7 +39,7 @@ def register_llm(llm_name: str, llm_instance: LLMProvider, context: GlobalObject context.llms.append(RegisteredLLM(llm_name, llm_instance)) -def unregister_platform(platform_name: str, context: GlobalObject) -> None: +def unregister_platform(platform_name: str, context: Context) -> None: ''' 注销一个消息平台。 @@ -51,7 +51,7 @@ def unregister_platform(platform_name: str, context: GlobalObject) -> None: context.platforms.pop(i) return -def unregister_llm(llm_name: str, context: GlobalObject) -> None: +def unregister_llm(llm_name: str, context: Context) -> None: ''' 注销一个大语言模型。 diff --git a/util/plugin_util.py b/util/plugin_util.py deleted file mode 100644 index da9cf9c..0000000 --- a/util/plugin_util.py +++ /dev/null @@ -1,361 +0,0 @@ -''' -插件工具函数 -''' -import os, sys, zipfile, shutil, yaml -import inspect -import traceback -import uuid - -from types import ModuleType -from type.plugin import * -from type.register import * -from SparkleLogging.utils.core import LogManager -from logging import Logger -from type.types import GlobalObject -from util.general_utils import download_file, remove_dir -from util.updator import request_release_info - -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - - - -# 找出模块里所有的类名 -def get_classes(p_name, arg: ModuleType): - classes = [] - clsmembers = inspect.getmembers(arg, inspect.isclass) - for (name, _) in clsmembers: - if name.lower().endswith("plugin") or name.lower() == "main": - classes.append(name) - break - # if p_name.lower() == name.lower()[:-6] or name.lower() == "main": - return classes - -# 获取一个文件夹下所有的模块, 文件名和文件夹名相同 - - -def get_modules(path): - modules = [] - - # 得到其下的所有文件夹 - dirs = os.listdir(path) - # 遍历文件夹,找到 main.py 或者和文件夹同名的文件 - for d in dirs: - if os.path.isdir(os.path.join(path, d)): - if os.path.exists(os.path.join(path, d, "main.py")): - module_str = 'main' - elif os.path.exists(os.path.join(path, d, d + ".py")): - module_str = d - else: - print(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。") - continue - if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(os.path.join(path, d, d + ".py")): - modules.append({ - "pname": d, - "module": module_str, - "module_path": os.path.join(path, d, module_str) - }) - return modules - - -def get_plugin_store_path(): - plugin_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../addons/plugins")) - return plugin_dir - -def get_plugin_modules(): - plugins = [] - try: - plugin_dir = get_plugin_store_path() - if os.path.exists(plugin_dir): - plugins = get_modules(plugin_dir) - return plugins - except BaseException as e: - raise e - -def check_plugin_dept_update(cached_plugins: RegisteredPlugins, target_plugin: str = None): - plugin_dir = get_plugin_store_path() - if not os.path.exists(plugin_dir): - return False - to_update = [] - if target_plugin: - to_update.append(target_plugin) - else: - for p in cached_plugins: - to_update.append(p.root_dir_name) - for p in to_update: - plugin_path = os.path.join(plugin_dir, p) - if os.path.exists(os.path.join(plugin_path, "requirements.txt")): - pth = os.path.join(plugin_path, "requirements.txt") - logger.info(f"正在检查更新插件 {p} 的依赖: {pth}") - update_plugin_dept(os.path.join(plugin_path, "requirements.txt")) - - -def has_init_param(cls, param_name): - try: - # 获取 __init__ 方法的签名 - init_signature = inspect.signature(cls.__init__) - - # 检查参数名是否在签名中 - return param_name in init_signature.parameters - except (AttributeError, ValueError): - # 如果类没有 __init__ 方法或者无法获取签名 - return False - -def plugin_reload(ctx: GlobalObject): - cached_plugins = ctx.cached_plugins - plugins = get_plugin_modules() - if plugins is None: - return False, "未找到任何插件模块" - fail_rec = "" - - registered_map = {} - for p in cached_plugins: - registered_map[p.module_path] = None - - for plugin in plugins: - try: - p = plugin['module'] - module_path = plugin['module_path'] - root_dir_name = plugin['pname'] - - check_plugin_dept_update(cached_plugins, root_dir_name) - - module = __import__("addons.plugins." + - root_dir_name + "." + p, fromlist=[p]) - - cls = get_classes(p, module) - - try: - # 尝试传入 ctx - obj = getattr(module, cls[0])(ctx=ctx) - except: - obj = getattr(module, cls[0])() - - metadata = None - try: - info = obj.info() - if isinstance(info, dict): - if 'name' not in info or 'desc' not in info or 'version' not in info or 'author' not in info: - fail_rec += f"注册插件 {module_path} 失败,原因: 插件信息不完整\n" - continue - else: - metadata = PluginMetadata( - plugin_name=info['name'], - plugin_type=PluginType.COMMON if 'plugin_type' not in info else PluginType(info['plugin_type']), - author=info['author'], - desc=info['desc'], - version=info['version'], - repo=info['repo'] if 'repo' in info else None - ) - elif isinstance(info, PluginMetadata): - metadata = info - else: - fail_rec += f"注册插件 {module_path} 失败,原因: info 函数返回值类型错误\n" - continue - except BaseException as e: - fail_rec += f"注册插件 {module_path} 失败, 原因: {str(e)}\n" - continue - - if module_path not in registered_map: - cached_plugins.append(RegisteredPlugin( - metadata=metadata, - plugin_instance=obj, - module=module, - module_path=module_path, - root_dir_name=root_dir_name - )) - except BaseException as e: - traceback.print_exc() - fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n" - if fail_rec == "": - return True, None - else: - return False, fail_rec - -def update_plugin_dept(path): - mirror = "https://mirrors.aliyun.com/pypi/simple/" - py = sys.executable - os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet") - - -def install_plugin(repo_url: str, ctx: GlobalObject): - ppath = get_plugin_store_path() - if repo_url.endswith("/"): - repo_url = repo_url[:-1] - - repo_namespace = repo_url.split("/")[-2:] - repo = repo_namespace[1] - - plugin_path = os.path.join(ppath, repo.replace("-", "_").lower()) - if os.path.exists(plugin_path): remove_dir(plugin_path) - - # we no longer use Git anymore :) - # Repo.clone_from(repo_url, to_path=plugin_path, branch='master') - - download_from_repo_url(plugin_path, repo_url) - unzip_file(plugin_path + ".zip", plugin_path) - - with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: - f.write(repo_url) - - ok, err = plugin_reload(ctx) - if not ok: - raise Exception(err) - -def install_plugin_from_file(zip_file_path: str, ctx: GlobalObject): - # try to unzip - temp_dir = os.path.join(os.path.dirname(zip_file_path), str(uuid.uuid4())) - unzip_file(zip_file_path, temp_dir) - # check if the plugin has metadata.yaml - if not os.path.exists(os.path.join(temp_dir, "metadata.yaml")): - remove_dir(temp_dir) - raise Exception("插件缺少 metadata.yaml 文件。") - - metadata = load_plugin_metadata(temp_dir) - plugin_name = metadata.plugin_name - if not plugin_name: - remove_dir(temp_dir) - raise Exception("插件 metadata.yaml 文件中 name 字段为空。") - plugin_name = plugin_name.replace("-", "_").lower() - - ppath = get_plugin_store_path() - plugin_path = os.path.join(ppath, plugin_name.replace("-", "_").lower()) - if os.path.exists(plugin_path): remove_dir(plugin_path) - - # move to the target path - shutil.move(temp_dir, plugin_path) - - with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: - if metadata.repo: f.write(metadata.repo) - - # remove the temp dir - remove_dir(temp_dir) - - ok, err = plugin_reload(ctx) - if not ok: - raise Exception(err) - -def load_plugin_metadata(plugin_path: str) -> PluginMetadata: - if not os.path.exists(plugin_path): - raise Exception("插件不存在。") - if not os.path.exists(os.path.join(plugin_path, "metadata.yaml")): - raise Exception("插件缺少 metadata.yaml 文件。") - metadata = None - with open(os.path.join(plugin_path, "metadata.yaml"), "r", encoding='utf-8') as f: - metadata = yaml.safe_load(f) - if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata: - raise Exception("插件 metadata.yaml 信息不完整。") - return PluginMetadata( - plugin_name=metadata['name'], - plugin_type=PluginType.COMMON if 'plugin_type' not in metadata else PluginType(metadata['plugin_type']), - author=metadata['author'], - desc=metadata['desc'], - version=metadata['version'], - repo=metadata['repo'] if 'repo' in metadata else None - ) - - -def download_from_repo_url(target_path: str, repo_url: str): - repo_namespace = repo_url.split("/")[-2:] - author = repo_namespace[0] - repo = repo_namespace[1] - - logger.info(f"正在下载插件 {repo} ...") - release_url = f"https://api.github.com/repos/{author}/{repo}/releases" - releases = request_release_info(latest=True, url=release_url, mirror_url=release_url) - if not releases: - # download from the default branch directly. - logger.warn(f"未在插件 {author}/{repo} 中找到任何发布版本,将从默认分支下载。") - release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" - else: - release_url = releases[0]['zipball_url'] - - download_file(release_url, target_path + ".zip") - - -def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) -> RegisteredPlugin: - ret = None - for p in cached_plugins: - if p.metadata.plugin_name == plugin_name: - ret = p - break - return ret - - -def uninstall_plugin(plugin_name: str, ctx: GlobalObject): - plugin = get_registered_plugin(plugin_name, ctx.cached_plugins) - if not plugin: - raise Exception("插件不存在。") - root_dir_name = plugin.root_dir_name - ppath = get_plugin_store_path() - ctx.cached_plugins.remove(plugin) - if not remove_dir(os.path.join(ppath, root_dir_name)): - raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") - - -def update_plugin(plugin_name: str, ctx: GlobalObject): - plugin = get_registered_plugin(plugin_name, ctx.cached_plugins) - if not plugin: - raise Exception("插件不存在。") - ppath = get_plugin_store_path() - root_dir_name = plugin.root_dir_name - plugin_path = os.path.join(ppath, root_dir_name) - - if not os.path.exists(os.path.join(plugin_path, "REPO")): - raise Exception("插件更新信息文件 `REPO` 不存在,请手动升级,或者先卸载然后重新安装该插件。") - - repo_url = None - with open(os.path.join(plugin_path, "REPO"), "r", encoding='utf-8') as f: - repo_url = f.read() - - download_from_repo_url(plugin_path, repo_url) - try: - remove_dir(plugin_path) - except BaseException as e: - logger.error(f"删除旧版本插件 {plugin_name} 文件夹失败: {str(e)},使用覆盖安装。") - unzip_file(plugin_path + ".zip", plugin_path) - - ok, err = plugin_reload(ctx) - if not ok: - raise Exception(err) - -def unzip_file(zip_path: str, target_dir: str): - ''' - 解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir - ''' - os.makedirs(target_dir, exist_ok=True) - update_dir = "" - logger.info(f"解压文件: {zip_path}") - with zipfile.ZipFile(zip_path, 'r') as z: - update_dir = z.namelist()[0] - z.extractall(target_dir) - - files = os.listdir(os.path.join(target_dir, update_dir)) - for f in files: - logger.info(f"移动更新文件/目录: {f}") - if os.path.isdir(os.path.join(target_dir, update_dir, f)): - if os.path.exists(os.path.join(target_dir, f)): - shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) - else: - if os.path.exists(os.path.join(target_dir, f)): - os.remove(os.path.join(target_dir, f)) - shutil.move(os.path.join(target_dir, update_dir, f), target_dir) - - try: - logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") - shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) - os.remove(zip_path) - except: - logger.warn(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") - - -def on_error(func, path, exc_info): - ''' - a callback of the rmtree function. - ''' - print(f"remove {path} failed.") - import stat - if not os.access(path, os.W_OK): - os.chmod(path, stat.S_IWUSR) - func(path) - else: - raise \ No newline at end of file diff --git a/util/search_engine_scraper/test.py b/util/search_engine_scraper/test.py deleted file mode 100644 index a1be25e..0000000 --- a/util/search_engine_scraper/test.py +++ /dev/null @@ -1,22 +0,0 @@ -from sogo import Sogo -from bing import Bing - -sogo_search = Sogo() -bing_search = Bing() -async def search(keyword: str) -> str: - results = await sogo_search.search(keyword, 5) - # results = await bing_search.search(keyword, 5) - ret = "" - if len(results) == 0: - return "没有搜索到结果" - - idx = 1 - for i in results: - ret += f"{idx}. {i.title}({i.url})\n{i.snippet}\n\n" - idx += 1 - - return ret - -import asyncio -ret = asyncio.run(search("gpt4orelease")) -print(ret) \ No newline at end of file diff --git a/util/t2i/__init__.py b/util/t2i/__init__.py new file mode 100644 index 0000000..047b584 --- /dev/null +++ b/util/t2i/__init__.py @@ -0,0 +1 @@ +from .renderer import TextToImageRenderer \ No newline at end of file diff --git a/util/t2i/context.py b/util/t2i/context.py new file mode 100644 index 0000000..255fb89 --- /dev/null +++ b/util/t2i/context.py @@ -0,0 +1,11 @@ +from util.t2i.strategies.base_strategy import RenderStrategy + +class RenderContext: + def __init__(self, strategy: RenderStrategy): + self._strategy = strategy + + def set_strategy(self, strategy: RenderStrategy): + self._strategy = strategy + + async def render(self, text: str) -> str: + return await self._strategy.render(text) diff --git a/util/t2i/renderer.py b/util/t2i/renderer.py new file mode 100644 index 0000000..28df04b --- /dev/null +++ b/util/t2i/renderer.py @@ -0,0 +1,25 @@ +from util.t2i.strategies.local_strategy import LocalRenderStrategy +from util.t2i.strategies.network_strategy import NetworkRenderStrategy +from util.t2i.context import RenderContext +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + +class TextToImageRenderer: + def __init__(self): + self.network_strategy = NetworkRenderStrategy() + self.local_strategy = LocalRenderStrategy() + self.context = RenderContext(self.network_strategy) + + async def render(self, text: str, use_network: bool = True, return_url: bool = False): + if use_network: + try: + return await self.context.render(text) + except BaseException as e: + logger.error(f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.") + self.context.set_strategy(self.local_strategy) + return await self.context.render(text) + else: + self.context.set_strategy(self.local_strategy) + return await self.context.render(text) diff --git a/util/t2i/strategies/__init__.py b/util/t2i/strategies/__init__.py new file mode 100644 index 0000000..1fd9d92 --- /dev/null +++ b/util/t2i/strategies/__init__.py @@ -0,0 +1,3 @@ +from .base_strategy import RenderStrategy +from .local_strategy import LocalRenderStrategy +from .network_strategy import NetworkRenderStrategy diff --git a/util/t2i/strategies/base_strategy.py b/util/t2i/strategies/base_strategy.py new file mode 100644 index 0000000..35fbb1b --- /dev/null +++ b/util/t2i/strategies/base_strategy.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod + +class RenderStrategy(ABC): + @abstractmethod + def render(self, text: str, return_url: bool) -> str: + pass + \ No newline at end of file diff --git a/util/t2i/strategies/local_strategy.py b/util/t2i/strategies/local_strategy.py new file mode 100644 index 0000000..56210bc --- /dev/null +++ b/util/t2i/strategies/local_strategy.py @@ -0,0 +1,278 @@ +import re +import requests + +from .base_strategy import RenderStrategy +from PIL import ImageFont, Image, ImageDraw +from util.io import save_temp_img + +class LocalRenderStrategy(RenderStrategy): + + def get_font(self, size: int) -> ImageFont.FreeTypeFont: + # common and default fonts on Windows, macOS and Linux + fonts = ["msyh.ttc", "NotoSansCJK-Regular.ttc", "msyhbd.ttc", "PingFang.ttc", "Heiti.ttc"] + for font in fonts: + try: + font = ImageFont.truetype(font, size) + return font + except Exception as e: + pass + + async def render(self, text: str): + font_size = 26 + image_width = 800 + image_height = 600 + font_color = (0, 0, 0) + bg_color = (255, 255, 255) + + HEADER_MARGIN = 20 + HEADER_FONT_STANDARD_SIZE = 42 + + QUOTE_LEFT_LINE_MARGIN = 10 + QUOTE_FONT_LINE_MARGIN = 6 # 引用文字距离左边线的距离和上下的距离 + QUOTE_LEFT_LINE_HEIGHT = font_size + QUOTE_FONT_LINE_MARGIN * 2 + QUOTE_LEFT_LINE_WIDTH = 5 + QUOTE_LEFT_LINE_COLOR = (180, 180, 180) + QUOTE_FONT_SIZE = font_size + QUOTE_FONT_COLOR = (180, 180, 180) + # QUOTE_BG_COLOR = (255, 255, 255) + + CODE_BLOCK_MARGIN = 10 + CODE_BLOCK_FONT_SIZE = font_size + CODE_BLOCK_FONT_COLOR = (255, 255, 255) + CODE_BLOCK_BG_COLOR = (240, 240, 240) + CODE_BLOCK_CODES_MARGIN_VERTICAL = 5 # 代码块和代码之间的距离 + CODE_BLOCK_CODES_MARGIN_HORIZONTAL = 5 # 代码块和代码之间的距离 + CODE_BLOCK_TEXT_MARGIN = 4 # 代码和代码之间的距离 + + INLINE_CODE_MARGIN = 8 + INLINE_CODE_FONT_SIZE = font_size + INLINE_CODE_FONT_COLOR = font_color + INLINE_CODE_FONT_MARGIN = 4 + INLINE_CODE_BG_COLOR = (230, 230, 230) + INLINE_CODE_BG_HEIGHT = INLINE_CODE_FONT_SIZE + INLINE_CODE_FONT_MARGIN * 2 + + LIST_MARGIN = 8 + LIST_FONT_SIZE = font_size + LIST_FONT_COLOR = font_color + + TEXT_LINE_MARGIN = 8 + + IMAGE_MARGIN = 15 + # 用于匹配图片的正则表达式 + IMAGE_REGEX = r"!\s*\[.*?\]\s*\((.*?)\)" + + # 加载字体 + font = self.get_font(font_size) + + images: Image = {} + + # pre_process, get height of each line + pre_lines = text.split('\n') + height = 0 + pre_in_code = False + i = -1 + _pre_lines = [] + for line in pre_lines: + i += 1 + # 处理图片 + if re.search(IMAGE_REGEX, line): + try: + image_url = re.findall(IMAGE_REGEX, line)[0] + print(image_url) + image_res = Image.open(requests.get( + image_url, stream=True, timeout=5).raw) + images[i] = image_res + # 最大不得超过image_width的50% + img_height = image_res.size[1] + + if image_res.size[0] > image_width*0.5: + image_res = image_res.resize( + (int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) + img_height = image_res.size[1] + + height += img_height + IMAGE_MARGIN*2 + + line = re.sub(IMAGE_REGEX, "", line) + except Exception as e: + print(e) + line = re.sub(IMAGE_REGEX, "\n[加载失败的图片]\n", line) + continue + + line.replace("\t", " ") + if font.getsize(line)[0] > image_width: + cp = line + _width = 0 + _word_cnt = 0 + for ii in range(len(line)): + # 检测是否是中文 + _width += font.getsize(line[ii])[0] + _word_cnt += 1 + if _width > image_width: + _pre_lines.append(cp[:_word_cnt]) + cp = cp[_word_cnt:] + _word_cnt = 0 + _width = 0 + _pre_lines.append(cp) + else: + _pre_lines.append(line) + pre_lines = _pre_lines + + i = -1 + for line in pre_lines: + if line == "": + height += TEXT_LINE_MARGIN + continue + i += 1 + line = line.strip() + if pre_in_code and not line.startswith("```"): + height += font_size + CODE_BLOCK_TEXT_MARGIN + # pre_codes.append(line) + continue + if line.startswith("#"): + header_level = line.count("#") + height += HEADER_FONT_STANDARD_SIZE + HEADER_MARGIN*2 - header_level * 4 + elif line.startswith("-"): + height += font_size+LIST_MARGIN*2 + elif line.startswith(">"): + height += font_size+QUOTE_LEFT_LINE_MARGIN*2 + elif line.startswith("```"): + if pre_in_code: + pre_in_code = False + # pre_codes = [] + height += CODE_BLOCK_MARGIN + else: + pre_in_code = True + height += CODE_BLOCK_MARGIN + elif re.search(r"`(.*?)`", line): + height += font_size+INLINE_CODE_FONT_MARGIN*2+INLINE_CODE_MARGIN*2 + else: + height += font_size + TEXT_LINE_MARGIN*2 + + text = '\n'.join(pre_lines) + image_height = height + if image_height < 100: + image_height = 100 + image_width += 20 + + # 创建空白图像 + image = Image.new('RGB', (image_width, image_height), bg_color) + draw = ImageDraw.Draw(image) + + # 设置初始位置 + x, y = 10, 10 + + # 解析Markdown文本 + lines = text.split("\n") + # lines = pre_lines + + in_code_block = False + code_block_start_y = 0 + code_block_codes = [] + + index = -1 + for line in lines: + index += 1 + if in_code_block and not line.startswith("```"): + code_block_codes.append(line) + y += font_size + CODE_BLOCK_TEXT_MARGIN + continue + line = line.strip() + + if line.startswith("#"): + # 处理标题 + header_level = line.count("#") + line = line.strip("#").strip() + font_size_header = HEADER_FONT_STANDARD_SIZE - header_level * 4 + font = self.get_font(font_size_header) + y += HEADER_MARGIN # 上边距 + # 字间距 + draw.text((x, y), line, font=font, fill=font_color) + draw.line((x, y + font_size_header + 8, image_width - 10, + y + font_size_header + 8), fill=(230, 230, 230), width=3) + y += font_size_header + HEADER_MARGIN + + elif line.startswith(">"): + # 处理引用 + quote_text = line.strip(">") + y += QUOTE_LEFT_LINE_MARGIN + draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), + fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH) + font = self.get_font(QUOTE_FONT_SIZE) + draw.text((x + QUOTE_FONT_LINE_MARGIN, y + QUOTE_FONT_LINE_MARGIN), + quote_text, font=font, fill=QUOTE_FONT_COLOR) + y += font_size + QUOTE_LEFT_LINE_HEIGHT + QUOTE_LEFT_LINE_MARGIN + + elif line.startswith("-"): + # 处理列表 + list_text = line.strip("-").strip() + font = self.get_font(LIST_FONT_SIZE) + y += LIST_MARGIN + draw.text((x, y), " · " + list_text, + font=font, fill=LIST_FONT_COLOR) + y += font_size + LIST_MARGIN + + elif line.startswith("```"): + if not in_code_block: + code_block_start_y = y+CODE_BLOCK_MARGIN + in_code_block = True + else: + # print(code_block_codes) + in_code_block = False + codes = "\n".join(code_block_codes) + code_block_codes = [] + draw.rounded_rectangle((x, code_block_start_y, image_width - 10, y+CODE_BLOCK_CODES_MARGIN_VERTICAL + + CODE_BLOCK_TEXT_MARGIN), radius=5, fill=CODE_BLOCK_BG_COLOR, width=2) + font = self.get_font(CODE_BLOCK_FONT_SIZE) + draw.text((x + CODE_BLOCK_CODES_MARGIN_HORIZONTAL, code_block_start_y + + CODE_BLOCK_CODES_MARGIN_VERTICAL), codes, font=font, fill=font_color) + y += CODE_BLOCK_CODES_MARGIN_VERTICAL + CODE_BLOCK_MARGIN + # y += font_size+10 + elif re.search(r"`(.*?)`", line): + y += INLINE_CODE_MARGIN # 上边距 + # 处理行内代码 + code_regex = r"`(.*?)`" + parts_inline = re.findall(code_regex, line) + parts = re.split(code_regex, line) + for part in parts: + # the judge has a tiny bug. + # when line is like "hi`hi`". all the parts will be in parts_inline. + if part in parts_inline: + font = self.get_font(INLINE_CODE_FONT_SIZE) + code_text = part.strip("`") + code_width = font.getsize( + code_text)[0] + INLINE_CODE_FONT_MARGIN*2 + x += INLINE_CODE_MARGIN + code_box = (x, y, x + code_width, + y + INLINE_CODE_BG_HEIGHT) + draw.rounded_rectangle( + code_box, radius=5, fill=INLINE_CODE_BG_COLOR, width=2) # 使用灰色填充矩形框作为引用背景 + draw.text((x+INLINE_CODE_FONT_MARGIN, y), + code_text, font=font, fill=font_color) + x += code_width+INLINE_CODE_MARGIN-INLINE_CODE_FONT_MARGIN + else: + font = self.get_font(font_size) + draw.text((x, y), part, font=font, fill=font_color) + x += font.getsize(part)[0] + y += font_size + INLINE_CODE_MARGIN + x = 10 + + else: + # 处理普通文本 + if line == "": + y += TEXT_LINE_MARGIN + else: + font = self.get_font(font_size) + + draw.text((x, y), line, font=font, fill=font_color) + y += font_size + TEXT_LINE_MARGIN*2 + + # 图片特殊处理 + if index in images: + image_res = images[index] + # 最大不得超过image_width的50% + if image_res.size[0] > image_width*0.5: + image_res = image_res.resize( + (int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) + image.paste(image_res, (IMAGE_MARGIN, y)) + y += image_res.size[1] + IMAGE_MARGIN*2 + return save_temp_img(image) diff --git a/util/t2i/strategies/network_strategy.py b/util/t2i/strategies/network_strategy.py new file mode 100644 index 0000000..cd59ce3 --- /dev/null +++ b/util/t2i/strategies/network_strategy.py @@ -0,0 +1,42 @@ +import aiohttp +import os + +from .base_strategy import RenderStrategy +from type.config import VERSION +from util.io import download_image_by_url + +class NetworkRenderStrategy(RenderStrategy): + def __init__(self) -> None: + super().__init__() + self.BASE_RENDER_URL = "https://t2i.soulter.top/text2img" + self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template") + + async def render(self, text: str, return_url: bool=False) -> str: + ''' + 返回图像的文件路径 + ''' + with open(os.path.join(self.TEMPLATE_PATH, "base.html"), "r", encoding='utf-8') as f: + tmpl_str = f.read() + + assert(tmpl_str) + + text = text.replace("`", "\`") + + post_data = { + "tmpl": tmpl_str, + "json": return_url, + "tmpldata": { + "text": text, + "version": f"v{VERSION}", + }, + "options": { + "full_page": True + } + } + + if return_url: + async with aiohttp.ClientSession() as session: + async with session.post(f"{self.BASE_RENDER_URL}/generate", json=post_data) as resp: + ret = await resp.json() + return f"{self.BASE_RENDER_URL}/{ret['data']['id']}" + return await download_image_by_url(f"{self.BASE_RENDER_URL}/generate", post=True, post_data=post_data) diff --git a/util/image_render/template/base.html b/util/t2i/strategies/template/base.html similarity index 100% rename from util/image_render/template/base.html rename to util/t2i/strategies/template/base.html diff --git a/util/updator.py b/util/updator.py deleted file mode 100644 index 3ca2f44..0000000 --- a/util/updator.py +++ /dev/null @@ -1,209 +0,0 @@ -import sys, os, zipfile, shutil -import requests -import psutil -from type.config import VERSION -from SparkleLogging.utils.core import LogManager -from logging import Logger - -from util.general_utils import download_file - -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - -ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases" -MIRROR_ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" # 0-10 分钟的缓存时间 - -def get_main_path(): - ret = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) - return ret - -def terminate_child_processes(): - try: - parent = psutil.Process(os.getpid()) - children = parent.children(recursive=True) - logger.info(f"正在终止 {len(children)} 个子进程。") - for child in children: - logger.info(f"正在终止子进程 {child.pid}") - child.terminate() - try: - child.wait(timeout=3) - except psutil.NoSuchProcess: - continue - except psutil.TimeoutExpired: - logger.info(f"子进程 {child.pid} 没有被正常终止, 正在强行杀死。") - child.kill() - except psutil.NoSuchProcess: - pass - -def _reboot(): - py = sys.executable - terminate_child_processes() - os.execl(py, py, *sys.argv) - -def request_release_info(latest: bool = True, url: str = ASTRBOT_RELEASE_API, mirror_url: str = MIRROR_ASTRBOT_RELEASE_API) -> list: - ''' - 请求版本信息。 - 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 - ''' - try: - result = requests.get(mirror_url).json() - except BaseException as e: - result = requests.get(url).json() - try: - if not result: return [] - if latest: - ret = github_api_release_parser([result[0]]) - else: - ret = github_api_release_parser(result) - except BaseException as e: - logger.error(f"解析版本信息失败: {result}") - raise Exception(f"解析版本信息失败: {result}") - return ret - -def github_api_release_parser(releases: list) -> list: - ''' - 解析 GitHub API 返回的 releases 信息。 - 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 - ''' - ret = [] - for release in releases: - version = release['name'] - commit_hash = '' - # 规范是: v3.0.7.xxxxxx,其中xxxxxx为 commit hash - _t = version.split(".") - if len(_t) == 4: - commit_hash = _t[3] - ret.append({ - "version": release['name'], - "published_at": release['published_at'], - "body": release['body'], - "commit_hash": commit_hash, - "tag_name": release['tag_name'], - "zipball_url": release['zipball_url'] - }) - return ret - -def compare_version(v1: str, v2: str) -> int: - ''' - 比较两个版本号的大小。 - 返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。 - ''' - v1 = v1.replace('v', '') - v2 = v2.replace('v', '') - v1 = v1.split('.') - v2 = v2.split('.') - - for i in range(3): - if int(v1[i]) > int(v2[i]): - return 1 - elif int(v1[i]) < int(v2[i]): - return -1 - return 0 - -def check_update() -> str: - update_data = request_release_info() - tag_name = update_data[0]['tag_name'] - logger.debug(f"当前版本: v{VERSION}") - logger.debug(f"最新版本: {tag_name}") - - if compare_version(VERSION, tag_name) >= 0: - return "当前已经是最新版本。" - - update_info = f"""# 当前版本 -v{VERSION} - -# 最新版本 -{update_data[0]['version']} - -# 发布时间 -{update_data[0]['published_at']} - -# 更新内容 ---- -{update_data[0]['body']} ----""" - return update_info - -def update_project(reboot: bool = False, - latest: bool = True, - version: str = ''): - update_data = request_release_info(latest) - if latest: - latest_version = update_data[0]['tag_name'] - if compare_version(VERSION, latest_version) >= 0: - raise Exception("当前已经是最新版本。") - else: - try: - download_file(update_data[0]['zipball_url'], "temp.zip") - unzip_file("temp.zip", get_main_path()) - if reboot: _reboot() - except BaseException as e: - raise e - else: - # 更新到指定版本 - flag = False - print(f"请求更新到指定版本: {version}") - for data in update_data: - if data['tag_name'] == version: - try: - download_file(data['zipball_url'], "temp.zip") - unzip_file("temp.zip", get_main_path()) - flag = True - if reboot: _reboot() - except BaseException as e: - raise e - if not flag: - raise Exception("未找到指定版本。") - -def unzip_file(zip_path: str, target_dir: str): - ''' - 解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir - ''' - os.makedirs(target_dir, exist_ok=True) - update_dir = "" - logger.info(f"解压文件: {zip_path}") - with zipfile.ZipFile(zip_path, 'r') as z: - update_dir = z.namelist()[0] - z.extractall(target_dir) - - avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir] - # copy addons/plugins to the target_dir temporarily - if os.path.exists(os.path.join(target_dir, "addons/plugins")): - logger.info("备份插件目录:从 addons/plugins 到 temp_plugins") - shutil.copytree(os.path.join(target_dir, "addons/plugins"), "temp_plugins") - - files = os.listdir(os.path.join(target_dir, update_dir)) - for f in files: - logger.info(f"移动更新文件/目录: {f}") - if os.path.isdir(os.path.join(target_dir, update_dir, f)): - if f in avoid_dirs: continue - if os.path.exists(os.path.join(target_dir, f)): - shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) - else: - if os.path.exists(os.path.join(target_dir, f)): - os.remove(os.path.join(target_dir, f)) - shutil.move(os.path.join(target_dir, update_dir, f), target_dir) - - # move back - if os.path.exists("temp_plugins"): - logger.info("恢复插件目录:从 temp_plugins 到 addons/plugins") - shutil.rmtree(os.path.join(target_dir, "addons/plugins"), onerror=on_error) - shutil.move("temp_plugins", os.path.join(target_dir, "addons/plugins")) - - try: - logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") - shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) - os.remove(zip_path) - except: - logger.warn(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") - -def on_error(func, path, exc_info): - ''' - a callback of the rmtree function. - ''' - print(f"remove {path} failed.") - import stat - if not os.access(path, os.W_OK): - os.chmod(path, stat.S_IWUSR) - func(path) - else: - raise \ No newline at end of file diff --git a/util/updator/astrbot_updator.py b/util/updator/astrbot_updator.py new file mode 100644 index 0000000..2ce31a8 --- /dev/null +++ b/util/updator/astrbot_updator.py @@ -0,0 +1,109 @@ +import os, psutil, sys, zipfile, shutil, time +from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator +from SparkleLogging.utils.core import LogManager +from logging import Logger +from type.config import VERSION +from util.io import on_error + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + +class AstrBotUpdator(RepoZipUpdator): + def __init__(self): + self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases" + + def terminate_child_processes(self): + try: + parent = psutil.Process(os.getpid()) + children = parent.children(recursive=True) + logger.info(f"正在终止 {len(children)} 个子进程。") + for child in children: + logger.info(f"正在终止子进程 {child.pid}") + child.terminate() + try: + child.wait(timeout=3) + except psutil.NoSuchProcess: + continue + except psutil.TimeoutExpired: + logger.info(f"子进程 {child.pid} 没有被正常终止, 正在强行杀死。") + child.kill() + except psutil.NoSuchProcess: + pass + + def _reboot(self, delay: int = None): + if delay: time.sleep(delay) + py = sys.executable + self.terminate_child_processes() + os.execl(py, py, *sys.argv) + + def check_update(self, url: str, current_version: str) -> ReleaseInfo: + return super().check_update(self.ASTRBOT_RELEASE_API, VERSION) + + def update(self, reboot = False, latest = True, version = None): + update_data = self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) + file_url = None + + if latest: + latest_version = update_data[0]['tag_name'] + if self.compare_version(VERSION, latest_version) >= 0: + raise Exception("当前已经是最新版本。") + file_url = update_data[0]['zipball_url'] + else: + # 更新到指定版本 + print(f"请求更新到指定版本: {version}") + for data in update_data: + if data['tag_name'] == version: + file_url = data['zipball_url'] + if not file_url: + raise Exception(f"未找到版本号为 {version} 的更新文件。") + + try: + self.download_from_repo_url("temp", data['zipball_url']) + self.unzip_file("temp.zip", self.MAIN_PATH) + except BaseException as e: + raise e + + if reboot: + self._reboot() + + def unzip_file(self, zip_path: str, target_dir: str): + ''' + 解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir + ''' + os.makedirs(target_dir, exist_ok=True) + update_dir = "" + logger.info(f"解压文件: {zip_path}") + with zipfile.ZipFile(zip_path, 'r') as z: + update_dir = z.namelist()[0] + z.extractall(target_dir) + + avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir] + # copy addons/plugins to the target_dir temporarily + if os.path.exists(os.path.join(target_dir, "addons/plugins")): + logger.info("备份插件目录:从 addons/plugins 到 temp_plugins") + shutil.copytree(os.path.join(target_dir, "addons/plugins"), "temp_plugins") + + files = os.listdir(os.path.join(target_dir, update_dir)) + for f in files: + logger.info(f"移动更新文件/目录: {f}") + if os.path.isdir(os.path.join(target_dir, update_dir, f)): + if f in avoid_dirs: continue + if os.path.exists(os.path.join(target_dir, f)): + shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) + else: + if os.path.exists(os.path.join(target_dir, f)): + os.remove(os.path.join(target_dir, f)) + shutil.move(os.path.join(target_dir, update_dir, f), target_dir) + + # move back + if os.path.exists("temp_plugins"): + logger.info("恢复插件目录:从 temp_plugins 到 addons/plugins") + shutil.rmtree(os.path.join(target_dir, "addons/plugins"), onerror=on_error) + shutil.move("temp_plugins", os.path.join(target_dir, "addons/plugins")) + + try: + logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") + shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) + os.remove(zip_path) + except: + logger.warn(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") diff --git a/util/updator/plugin_updator.py b/util/updator/plugin_updator.py new file mode 100644 index 0000000..6303f2b --- /dev/null +++ b/util/updator/plugin_updator.py @@ -0,0 +1,48 @@ +import os + +from util.updator.zip_updator import RepoZipUpdator +from util.io import remove_dir +from type.plugin import PluginMetadata +from type.register import RegisteredPlugin +from typing import Union +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + + +class PluginUpdator(RepoZipUpdator): + def __init__(self) -> None: + self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../addons/plugins")) + + def get_plugin_store_path(self) -> str: + return self.plugin_store_path + + def update(self, plugin: Union[RegisteredPlugin, str]) -> str: + repo_url = None + + if not isinstance(plugin, str): + plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name) + if not os.path.exists(os.path.join(plugin_path, "REPO")): + raise Exception("插件更新信息文件 `REPO` 不存在,请手动升级,或者先卸载然后重新安装该插件。") + + with open(os.path.join(plugin_path, "REPO"), "r", encoding='utf-8') as f: + repo_url = f.read() + else: + repo_url = plugin + plugin_path = os.path.join(self.plugin_store_path, self.format_repo_name(repo_url)) + + logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}") + self.download_from_repo_url(plugin_path, repo_url) + + try: + remove_dir(plugin_path) + except BaseException as e: + logger.error(f"删除旧版本插件 {plugin.metadata.plugin_name} 文件夹失败: {str(e)},使用覆盖安装。") + + self.unzip_file(plugin_path + ".zip", plugin_path) + + return plugin_path + + + \ No newline at end of file diff --git a/util/updator/zip_updator.py b/util/updator/zip_updator.py new file mode 100644 index 0000000..a49feee --- /dev/null +++ b/util/updator/zip_updator.py @@ -0,0 +1,155 @@ +import requests, os, zipfile, shutil +from SparkleLogging.utils.core import LogManager +from logging import Logger +from util.io import on_error, download_file + +logger: Logger = LogManager.GetLogger(log_name='astrbot') + +class ReleaseInfo(): + version: str + published_at: str + body: str + + def __str__(self) -> str: + return f"新版本: {self.version}, 发布于: {self.published_at}, 详细内容: {self.body}" + +class RepoZipUpdator(): + def __init__(self, path): + self.path = path + self.rm_on_error = on_error + + def fetch_release_info(self, url: str, latest: bool = True) -> list: + ''' + 请求版本信息。 + 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 + ''' + result = requests.get(url).json() + + try: + if not result: return [] + if latest: + ret = self.github_api_release_parser([result[0]]) + else: + ret = self.github_api_release_parser(result) + except BaseException as e: + logger.error(f"解析版本信息失败: {result}") + raise Exception(f"解析版本信息失败: {result}") + return ret + + def github_api_release_parser(self, releases: list) -> list: + ''' + 解析 GitHub API 返回的 releases 信息。 + 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 + ''' + ret = [] + for release in releases: + version = release['name'] + commit_hash = '' + # 规范是: v3.0.7.xxxxxx,其中xxxxxx为 commit hash + _t = version.split(".") + if len(_t) == 4: commit_hash = _t[3] + ret.append({ + "version": release['name'], + "published_at": release['published_at'], + "body": release['body'], + "commit_hash": commit_hash, + "tag_name": release['tag_name'], + "zipball_url": release['zipball_url'] + }) + return ret + + def unzip(self): + raise NotImplementedError() + + def update(self): + raise NotImplementedError() + + def compare_version(self, v1: str, v2: str) -> int: + ''' + 比较两个版本号的大小。 + 返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。 + ''' + v1 = v1.replace('v', '') + v2 = v2.replace('v', '') + v1 = v1.split('.') + v2 = v2.split('.') + + for i in range(3): + if int(v1[i]) > int(v2[i]): + return 1 + elif int(v1[i]) < int(v2[i]): + return -1 + return 0 + + def check_update(self, url: str, current_version: str) -> ReleaseInfo: + update_data = self.fetch_release_info(url) + tag_name = update_data[0]['tag_name'] + + if self.compare_version(current_version, tag_name) >= 0: + return None + return ReleaseInfo( + version=tag_name, + published_at=update_data[0]['published_at'], + body=update_data[0]['body'] + ) + + def download_from_repo_url(self, target_path: str, repo_url: str): + repo_namespace = repo_url.split("/")[-2:] + author = repo_namespace[0] + repo = repo_namespace[1] + + logger.info(f"正在下载更新 {repo} ...") + release_url = f"https://api.github.com/repos/{author}/{repo}/releases" + releases = self.fetch_release_info(url=release_url) + if not releases: + # download from the default branch directly. + logger.warn(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。") + release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" + else: + release_url = releases[0]['zipball_url'] + + download_file(release_url, target_path + ".zip") + + + def unzip_file(self, zip_path: str, target_dir: str): + ''' + 解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir + ''' + os.makedirs(target_dir, exist_ok=True) + update_dir = "" + logger.info(f"解压文件: {zip_path}") + with zipfile.ZipFile(zip_path, 'r') as z: + update_dir = z.namelist()[0] + z.extractall(target_dir) + + files = os.listdir(os.path.join(target_dir, update_dir)) + for f in files: + logger.info(f"移动更新文件/目录: {f}") + if os.path.isdir(os.path.join(target_dir, update_dir, f)): + if os.path.exists(os.path.join(target_dir, f)): + shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) + else: + if os.path.exists(os.path.join(target_dir, f)): + os.remove(os.path.join(target_dir, f)) + shutil.move(os.path.join(target_dir, update_dir, f), target_dir) + + try: + logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") + shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) + os.remove(zip_path) + except: + logger.warn(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") + + def format_repo_name(self, repo_url: str) -> str: + if repo_url.endswith("/"): + repo_url = repo_url[:-1] + + repo_namespace = repo_url.split("/")[-2:] + repo = repo_namespace[1] + + repo = self.format_name(repo) + + return repo + + def format_name(self, name: str) -> str: + return name.replace("-", "_").lower() \ No newline at end of file diff --git a/util/search_engine_scraper/bing.py b/util/websearch/bing.py similarity index 89% rename from util/search_engine_scraper/bing.py rename to util/websearch/bing.py index f51a2a0..c0801dc 100644 --- a/util/search_engine_scraper/bing.py +++ b/util/websearch/bing.py @@ -1,8 +1,8 @@ from typing import List try: - from util.search_engine_scraper.engine import SearchEngine, SearchResult - from util.search_engine_scraper.config import HEADERS, USER_AGENT_BING + from util.websearch.engine import SearchEngine, SearchResult + from util.websearch.config import HEADERS, USER_AGENT_BING except ImportError: from engine import SearchEngine, SearchResult from config import HEADERS, USER_AGENT_BING diff --git a/util/search_engine_scraper/config.py b/util/websearch/config.py similarity index 100% rename from util/search_engine_scraper/config.py rename to util/websearch/config.py diff --git a/util/search_engine_scraper/engine.py b/util/websearch/engine.py similarity index 97% rename from util/search_engine_scraper/engine.py rename to util/websearch/engine.py index d78f5bf..57bf2f7 100644 --- a/util/search_engine_scraper/engine.py +++ b/util/websearch/engine.py @@ -1,6 +1,6 @@ import random try: - from util.search_engine_scraper.config import HEADERS, USER_AGENTS + from util.websearch.config import HEADERS, USER_AGENTS except ImportError: from config import HEADERS, USER_AGENTS diff --git a/util/search_engine_scraper/google.py b/util/websearch/google.py similarity index 84% rename from util/search_engine_scraper/google.py rename to util/websearch/google.py index 9eff698..f1d7847 100644 --- a/util/search_engine_scraper/google.py +++ b/util/websearch/google.py @@ -2,8 +2,8 @@ from googlesearch import search try: - from util.search_engine_scraper.engine import SearchEngine, SearchResult - from util.search_engine_scraper.config import HEADERS, USER_AGENTS + from util.websearch.engine import SearchEngine, SearchResult + from util.websearch.config import HEADERS, USER_AGENTS except ImportError: from engine import SearchEngine, SearchResult from config import HEADERS, USER_AGENTS diff --git a/util/search_engine_scraper/sogo.py b/util/websearch/sogo.py similarity index 91% rename from util/search_engine_scraper/sogo.py rename to util/websearch/sogo.py index 4ddf88d..07a7a5b 100644 --- a/util/search_engine_scraper/sogo.py +++ b/util/websearch/sogo.py @@ -2,8 +2,8 @@ from bs4 import BeautifulSoup try: - from util.search_engine_scraper.engine import SearchEngine, SearchResult - from util.search_engine_scraper.config import HEADERS, USER_AGENTS + from util.websearch.engine import SearchEngine, SearchResult + from util.websearch.config import HEADERS, USER_AGENTS except ImportError: from engine import SearchEngine, SearchResult from config import HEADERS, USER_AGENTS From d2bca2c2183de07d36a1795245f7439344fc9027 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 24 Jul 2024 09:19:43 -0400 Subject: [PATCH 2/9] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=BA=86=E4=B8=80?= =?UTF-8?q?=E4=BA=9Bbug=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/bootstrap.py | 21 +++++--- dashboard/server.py | 6 +-- model/command/internal_handler.py | 2 +- model/command/manager.py | 8 ++- model/command/openai_official_handler.py | 8 ++- model/platform/__init__.py | 7 ++- model/platform/manager.py | 14 +++--- model/platform/qq_aiocqhttp.py | 64 +++++++++++++++++------- model/platform/qq_nakuru.py | 7 ++- model/platform/qq_official.py | 10 ++-- util/t2i/context.py | 4 +- util/t2i/renderer.py | 2 +- util/t2i/strategies/local_strategy.py | 2 +- 13 files changed, 101 insertions(+), 54 deletions(-) diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index fdfc922..bcaa765 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -1,5 +1,5 @@ import asyncio -import threading +import traceback from astrbot.message.handler import MessageHandler from astrbot.persist.helper import dbConn from dashboard.server import AstrBotDashBoard @@ -72,13 +72,21 @@ async def run(self): # load platforms platform_tasks = self.load_platform() # load metrics uploader - metrics_upload_task = upload_metrics(self.context) + metrics_upload_task = asyncio.create_task(upload_metrics(self.context)) # load dashboard self.dashboard.run_http_server() - dashboard_task = self.dashboard.ws_server() - - await asyncio.gather(metrics_upload_task, dashboard_task, *platform_tasks) - + dashboard_task = asyncio.create_task(self.dashboard.ws_server()) + tasks = [metrics_upload_task, dashboard_task, *platform_tasks] + tasks = [self.handle_task(task) for task in tasks] + await asyncio.gather(*tasks) + + async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]): + try: + result = await task + return result + except Exception as e: + logger.error(traceback.format_exc()) + return None def load_llm(self): if 'openai' in self.configs and \ @@ -88,6 +96,7 @@ def load_llm(self): from model.command.openai_official_handler import OpenAIOfficialCommandHandler self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) self.llm_instance = ProviderOpenAIOfficial(self.context) + self.openai_command_handler.set_provider(self.llm_instance) logger.info("已启用 OpenAI API 支持。") def load_plugins(self): diff --git a/dashboard/server.py b/dashboard/server.py index 8da717f..bbcd35d 100644 --- a/dashboard/server.py +++ b/dashboard/server.py @@ -419,19 +419,19 @@ def _generate_outline(self): "tag": "" }, { - "title": "QQ(官方机器人 API)", + "title": "QQ(官方)", "desc": "QQ官方API。支持频道、群、私聊(需获得群权限)", "namespace": "internal_platform_qq_official", "tag": "" }, { - "title": "QQ(nakuru 适配器)", + "title": "QQ(nakuru)", "desc": "适用于 go-cqhttp", "namespace": "internal_platform_qq_gocq", "tag": "" }, { - "title": "QQ(aiocqhttp 适配器)", + "title": "QQ(aiocqhttp)", "desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。", "namespace": "internal_platform_qq_aiocqhttp", "tag": "" diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index 9a4e6a4..5f0c58c 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -23,7 +23,7 @@ def __init__(self, manager: CommandManager, plugin_manager: PluginManager) -> No self.manager.register("update", "更新 AstrBot", 10, self.update) self.manager.register("plugin", "插件管理", 10, self.plugin) self.manager.register("reboot", "重启 AstrBot", 10, self.reboot) - self.manager.register("web_search", "网页搜索开关", 10, self.web_search) + self.manager.register("websearch", "网页搜索开关", 10, self.web_search) self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle) self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid) diff --git a/model/command/manager.py b/model/command/manager.py index 34b3cc6..bb4ba7c 100644 --- a/model/command/manager.py +++ b/model/command/manager.py @@ -93,7 +93,11 @@ async def execute_handler(self, return command_result except BaseException as e: logger.error(traceback.format_exc()) + if not command_metadata.inner_command: - logger.error(f"当执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时,发生了异常。") + text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}" + logger.error(text) else: - logger.error(f"当执行 {command} 指令时,发生了异常。") \ No newline at end of file + text = f"执行 {command} 指令时发生了异常。{e}" + logger.error(text) + return CommandResult().message(text) \ No newline at end of file diff --git a/model/command/openai_official_handler.py b/model/command/openai_official_handler.py index 759149b..49b645a 100644 --- a/model/command/openai_official_handler.py +++ b/model/command/openai_official_handler.py @@ -145,8 +145,6 @@ def set(self, message: AstrMessageEvent, context: Context): async def draw(self, message: AstrMessageEvent, context: Context): message = message.message_str.removeprefix("画") img_url = await self.provider.image_generate(message) - p = await download_image_by_url(url=img_url) - with open(p, 'rb') as f: - return CommandResult( - message_chain=[Image.fromBytes(f.read())], - ) \ No newline at end of file + return CommandResult( + message_chain=[Image.fromURL(img_url)], + ) \ No newline at end of file diff --git a/model/platform/__init__.py b/model/platform/__init__.py index 2df17e3..8073625 100644 --- a/model/platform/__init__.py +++ b/model/platform/__init__.py @@ -64,7 +64,10 @@ async def convert_to_t2i_chain(self, message_result: list) -> list: if isinstance(i, Plain): plain_str += i.text if plain_str and len(plain_str) > 50: - p = await self.context.image_renderer.render(plain_str) - rendered_images.append(Image.fromFileSystem(p)) + p = await self.context.image_renderer.render(plain_str, return_url=True) + if p.startswith('http'): + rendered_images.append(Image.fromURL(p)) + else: + rendered_images.append(Image.fromFileSystem(p)) return rendered_images \ No newline at end of file diff --git a/model/platform/manager.py b/model/platform/manager.py index bec7df7..1f287f4 100644 --- a/model/platform/manager.py +++ b/model/platform/manager.py @@ -1,4 +1,4 @@ -import time, threading +import asyncio from util.io import port_checker from type.register import RegisteredPlatform @@ -21,20 +21,20 @@ def load_platforms(self): if 'gocqbot' in self.config and self.config['gocqbot']['enable']: logger.info("启用 QQ(nakuru 适配器)") - tasks.append(self.gocq_bot()) + tasks.append(asyncio.create_task(self.gocq_bot())) if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']: logger.info("启用 QQ(aiocqhttp 适配器)") - tasks.append(self.aiocq_bot()) + tasks.append(asyncio.create_task(self.aiocq_bot())) # QQ频道 if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None: logger.info("启用 QQ(官方 API) 机器人消息平台") - tasks.append(self.qqchan_bot()) + tasks.append(asyncio.create_task(self.qqchan_bot())) return tasks - def gocq_bot(self): + async def gocq_bot(self): ''' 运行 QQ(nakuru 适配器) ''' @@ -51,7 +51,7 @@ def gocq_bot(self): noticed = True logger.warning( f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。") - time.sleep(5) + await asyncio.sleep(5) else: logger.info("nakuru 适配器已连接。") break @@ -59,7 +59,7 @@ def gocq_bot(self): qq_gocq = QQGOCQ(self.context, self.msg_handler) self.context.platforms.append(RegisteredPlatform( platform_name="gocq", platform_instance=qq_gocq, origin="internal")) - return qq_gocq.run() + await qq_gocq.run() except BaseException as e: logger.error("启动 nakuru 适配器时出现错误: " + str(e)) diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 777f4cc..84ffcf4 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -22,11 +22,8 @@ def __init__(self, context: Context, message_handler: MessageHandler) -> None: self.announcement = self.context.base_config.get("announcement", "欢迎新人!") self.host = self.context.base_config['aiocqhttp']['ws_reverse_host'] self.port = self.context.base_config['aiocqhttp']['ws_reverse_port'] - - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - def compat_onebot2astrbotmsg(self, event: Event) -> AstrBotMessage: + def convert_message(self, event: Event) -> AstrBotMessage: abm = AstrBotMessage() abm.self_id = str(event.self_id) @@ -69,28 +66,48 @@ def compat_onebot2astrbotmsg(self, event: Event) -> AstrBotMessage: def run_aiocqhttp(self): if not self.host or not self.port: return - self.bot = CQHttp(use_ws_reverse=True) - + self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp') @self.bot.on_message('group') async def group(event: Event): - abm = self.compat_onebot2astrbotmsg(event) + abm = self.convert_message(event) if abm: - await self.handle_msg(event, abm) - return {'reply': event.message} + await self.handle_msg(abm) + # return {'reply': event.message} @self.bot.on_message('private') async def private(event: Event): - abm = self.compat_onebot2astrbotmsg(event) + abm = self.convert_message(event) if abm: - await self.handle_msg(event, abm) - return {'reply': event.message} + await self.handle_msg(abm) + # return {'reply': event.message} + + bot = self.bot.run_task(host=self.host, port=int(self.port)) - return self.bot.run_task(host=self.host, port=int(self.port)) + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.getLogger('aiocqhttp').setLevel(logging.ERROR) + + return bot + + def pre_check(self, message: AstrBotMessage) -> bool: + # if message chain contains Plain components or At components which points to self_id, return True + if message.type == MessageType.FRIEND_MESSAGE: + return True + for comp in message.message: + if isinstance(comp, At) and str(comp.qq) == message.self_id: + return True + # check nicks + if self.check_nick(message.message_str): + return True + return False async def handle_msg(self, message: AstrBotMessage): logger.info( f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") + if not self.pre_check(message): + return + # 解析 role sender_id = str(message.sender.user_id) if sender_id == self.context.config_helper.get('admin_qq', '') or \ @@ -100,7 +117,7 @@ async def handle_msg(self, message: AstrBotMessage): role = 'member' # construct astrbot message event - ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", message.session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role) # transfer control to message handler message_result = await self.message_handler.handle(ame) @@ -118,13 +135,14 @@ async def handle_msg(self, message: AstrBotMessage): async def reply_msg(self, message: AstrBotMessage, result_message: list): - await super().reply_msg() """ 回复用户唤醒机器人的消息。(被动回复) """ logger.info( f"{message.sender.user_id} <- {self.parse_message_outline(message)}") + res = result_message + if isinstance(res, str): res = [Plain(text=res), ] @@ -138,9 +156,21 @@ async def reply_msg(self, except BaseException as e: logger.warn(traceback.format_exc()) logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") + + await self._reply(message, res) async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageComponent]): if isinstance(message_chain, str): message_chain = [Plain(text=message_chain), ] - - await self.bot.send(message.raw_message, message_chain) \ No newline at end of file + + ret = [] + for segment in message_chain: + d = segment.toDict() + if isinstance(segment, Plain): + d['type'] = 'text' + if isinstance(segment, Image): + # d['data']['file'] = + pass + ret.append(d) + + await self.bot.send(message.raw_message, ret) \ No newline at end of file diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py index a9cc625..d332b0b 100644 --- a/model/platform/qq_nakuru.py +++ b/model/platform/qq_nakuru.py @@ -75,8 +75,6 @@ def pre_check(self, message: AstrBotMessage) -> bool: if message.type == MessageType.FRIEND_MESSAGE: return True for comp in message.message: - if isinstance(comp, Plain): - return True if isinstance(comp, At) and str(comp.qq) == message.self_id: return True # check nicks @@ -85,7 +83,8 @@ def pre_check(self, message: AstrBotMessage) -> bool: return False def run(self): - return self.client._run() + coro = self.client._run() + return coro async def handle_msg(self, message: AstrBotMessage): logger.info( @@ -119,7 +118,7 @@ async def handle_msg(self, message: AstrBotMessage): role = 'member' # construct astrbot message event - ame = AstrMessageEvent().from_astrbot_message(message, self.context, "gocq", session_id, role) + ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role) # transfer control to message handler message_result = await self.message_handler.handle(ame) diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 060ea5d..f2d60ea 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -309,9 +309,13 @@ async def _reply(self, **kwargs): if 'file_image' in kwargs: file_image_path = kwargs['file_image'].replace("file:///", "") if file_image_path: - logger.debug(f"上传图片: {file_image_path}") - image_url = await self.context.image_uploader.upload_image(file_image_path) - logger.debug(f"上传成功: {image_url}") + + if file_image_path.startswith("http"): + image_url = file_image_path + else: + logger.debug(f"上传图片: {file_image_path}") + image_url = await self.context.image_uploader.upload_image(file_image_path) + logger.debug(f"上传成功: {image_url}") media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url) del kwargs['file_image'] kwargs['media'] = media diff --git a/util/t2i/context.py b/util/t2i/context.py index 255fb89..ac9f837 100644 --- a/util/t2i/context.py +++ b/util/t2i/context.py @@ -7,5 +7,5 @@ def __init__(self, strategy: RenderStrategy): def set_strategy(self, strategy: RenderStrategy): self._strategy = strategy - async def render(self, text: str) -> str: - return await self._strategy.render(text) + async def render(self, text: str, return_url: bool = False): + return await self._strategy.render(text, return_url) diff --git a/util/t2i/renderer.py b/util/t2i/renderer.py index 28df04b..bf141bc 100644 --- a/util/t2i/renderer.py +++ b/util/t2i/renderer.py @@ -15,7 +15,7 @@ def __init__(self): async def render(self, text: str, use_network: bool = True, return_url: bool = False): if use_network: try: - return await self.context.render(text) + return await self.context.render(text, return_url=return_url) except BaseException as e: logger.error(f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.") self.context.set_strategy(self.local_strategy) diff --git a/util/t2i/strategies/local_strategy.py b/util/t2i/strategies/local_strategy.py index 56210bc..5711a48 100644 --- a/util/t2i/strategies/local_strategy.py +++ b/util/t2i/strategies/local_strategy.py @@ -17,7 +17,7 @@ def get_font(self, size: int) -> ImageFont.FreeTypeFont: except Exception as e: pass - async def render(self, text: str): + async def render(self, text: str, **kwargs): font_size = 26 image_width = 800 image_height = 600 From 8d6f1240d652321bcb45415357c13ac83938c597 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 24 Jul 2024 09:48:25 -0400 Subject: [PATCH 3/9] update: metrics refactoring --- astrbot/bootstrap.py | 7 +- model/command/command.py | 317 --------------------------------------- model/command/manager.py | 3 + type/types.py | 1 + util/general_utils.py | 62 -------- util/metrics.py | 66 ++++++++ 6 files changed, 74 insertions(+), 382 deletions(-) delete mode 100644 model/command/command.py create mode 100644 util/metrics.py diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index bcaa765..bba3544 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -13,7 +13,7 @@ from SparkleLogging.utils.core import LogManager from logging import Logger from util.cmd_config import CmdConfig -from util.general_utils import upload_metrics +from util.metrics import MetricUploader from util.config_utils import * from util.updator.astrbot_updator import AstrBotUpdator @@ -47,7 +47,6 @@ def __init__(self) -> None: logger.info("未使用代理。") async def run(self): - self.command_manager = CommandManager() self.plugin_manager = PluginManager(self.context) self.updator = AstrBotUpdator() @@ -61,7 +60,9 @@ async def run(self): self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper, self.llm_instance) self.platfrom_manager = PlatformManager(self.context, self.message_handler) self.dashboard = AstrBotDashBoard(self.context, plugin_manager=self.plugin_manager, astrbot_updator=self.updator) + self.metrics_uploader = MetricUploader(self.context) + self.context.metrics_uploader = self.metrics_uploader self.context.updator = self.updator self.context.plugin_updator = self.plugin_manager.updator @@ -72,7 +73,7 @@ async def run(self): # load platforms platform_tasks = self.load_platform() # load metrics uploader - metrics_upload_task = asyncio.create_task(upload_metrics(self.context)) + metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics()) # load dashboard self.dashboard.run_http_server() dashboard_task = asyncio.create_task(self.dashboard.ws_server()) diff --git a/model/command/command.py b/model/command/command.py deleted file mode 100644 index 5637026..0000000 --- a/model/command/command.py +++ /dev/null @@ -1,317 +0,0 @@ -import json -import inspect -import aiohttp -import asyncio -import json - -import util.updator - -from nakuru.entities.components import ( - Image -) -from util import general_utils as gu -from model.provider.provider import Provider -from util.cmd_config import CmdConfig as cc -from type.message import * -from type.types import Context -from type.command import * -from type.plugin import * -from type.register import * - -from typing import List -from SparkleLogging.utils.core import LogManager -from logging import Logger - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - -PLATFORM_QQCHAN = 'qqchan' -PLATFORM_GOCQ = 'gocq' - -# 指令功能的基类,通用的(不区分语言模型)的指令就在这实现 - - -class Command: - def __init__(self, provider: Provider, global_object: Context = None): - self.provider = provider - self.global_object = global_object - - async def check_command(self, - message, - session_id: str, - role: str, - platform: RegisteredPlatform, - message_obj): - self.platform = platform - # 插件 - cached_plugins = self.global_object.cached_plugins - # 将消息封装成 AstrMessageEvent 对象 - ame = AstrMessageEvent( - message_str=message, - message_obj=message_obj, - platform=platform, - role=role, - context=self.global_object, - session_id=session_id - ) - # 从已启动的插件中查找是否有匹配的指令 - for plugin in cached_plugins: - # 过滤掉平台类插件 - if plugin.metadata.plugin_type == PluginType.PLATFORM: - continue - try: - if inspect.iscoroutinefunction(plugin.plugin_instance.run): - result = await plugin.plugin_instance.run(ame) - else: - result = await asyncio.to_thread(plugin.plugin_instance.run, ame) - if not result: - continue - if isinstance(result, CommandResult): - hit = result.hit - res = result._result_tuple() - elif isinstance(result, tuple): - hit = result[0] - res = result[1] - else: - raise TypeError("插件返回值格式错误。") - if hit: - plugin.trig() - logger.debug("hit plugin: " + plugin.metadata.plugin_name) - return True, res - except TypeError as e: - # 参数不匹配,尝试使用旧的参数方案 - try: - if inspect.iscoroutinefunction(plugin.plugin_instance.run): - hit, res = await plugin.plugin_instance.run(message, role, platform, message_obj, self.global_object.platform_qq) - else: - hit, res = await asyncio.to_thread(plugin.plugin_instance.run, message, role, platform, message_obj, self.global_object.platform_qq) - if hit: - return True, res - except BaseException as e: - logger.error( - f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。") - except BaseException as e: - logger.error( - f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。") - - if self.command_start_with(message, "nick"): - return True, self.set_nick(message, platform, role) - if self.command_start_with(message, "plugin"): - return True, await self.plugin_oper(message, role, self.global_object, platform) - if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"): - return True, self.get_my_id(message_obj, platform) - if self.command_start_with(message, "web"): # 网页搜索 - return True, self.web_search(message) - if self.command_start_with(message, "update"): - return True, self.update(message, role) - if message == "t2i": - return True, self.t2i_toggle(message, role) - if not self.provider and message == "help": - return True, await self.help() - - return False, None - - def web_search(self, message): - l = message.split(' ') - if len(l) == 1: - return True, f"网页搜索功能当前状态: {self.global_object.web_search}", "web" - elif l[1] == 'on': - self.global_object.web_search = True - return True, "已开启网页搜索", "web" - elif l[1] == 'off': - self.global_object.web_search = False - return True, "已关闭网页搜索", "web" - - def t2i_toggle(self, message, role): - p = cc.get("qq_pic_mode", True) - if p: - cc.put("qq_pic_mode", False) - return True, "已关闭文本转图片模式。", "t2i" - cc.put("qq_pic_mode", True) - return True, "已开启文本转图片模式。", "t2i" - - def get_my_id(self, message_obj, platform): - try: - user_id = str(message_obj.sender.user_id) - return True, f"你在此平台上的ID:{user_id}", "plugin" - except BaseException as e: - return False, f"在{platform}上获取你的ID失败,原因: {str(e)}", "plugin" - - async def plugin_oper(self, message: str, role: str, ctx: Context, platform: str): - l = message.split(" ") - if len(l) < 2: - p = await text_to_image_base("# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 重载插件: `plugin reload`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n") - with open(p, 'rb') as f: - return True, [Image.fromBytes(f.read())], "plugin" - else: - if l[1] == "i": - if role != "admin": - return False, f"你的身份组{role}没有权限安装插件", "plugin" - try: - putil.install_plugin(l[2], ) - return True, "插件拉取并载入成功~", "plugin" - except BaseException as e: - return False, f"拉取插件失败,原因: {str(e)}", "plugin" - elif l[1] == "d": - if role != "admin": - return False, f"你的身份组{role}没有权限删除插件", "plugin" - try: - putil.uninstall_plugin(l[2], ctx) - return True, "插件卸载成功~", "plugin" - except BaseException as e: - return False, f"卸载插件失败,原因: {str(e)}", "plugin" - elif l[1] == "u": - try: - putil.update_plugin(l[2], ctx) - return True, "\n更新插件成功!!", "plugin" - except BaseException as e: - return False, f"更新插件失败,原因: {str(e)}。\n建议: 使用 plugin i 指令进行覆盖安装(插件数据可能会丢失)", "plugin" - elif l[1] == "l": - try: - plugin_list_info = "" - for plugin in ctx.cached_plugins: - plugin_list_info += f"### {plugin.metadata.plugin_name} \n- 名称: {plugin.metadata.plugin_name}\n- 简介: {plugin.metadata.desc}\n- 版本: {plugin.metadata.version}\n- 作者: {plugin.metadata.author}\n" - p = await text_to_image_base(f"# 已激活的插件\n{plugin_list_info}\n> 使用plugin v 插件名 查看插件帮助\n") - with open(p, 'rb') as f: - return True, [Image.fromBytes(f.read())], "plugin" - except BaseException as e: - return False, f"获取插件列表失败,原因: {str(e)}", "plugin" - elif l[1] == "v": - try: - info = None - for i in ctx.cached_plugins: - if i.metadata.plugin_name == l[2]: - info = i.metadata - break - if info: - p = await text_to_image_base(f"# `{info.plugin_name}` 插件信息\n- 类型: {info.plugin_type}\n- 简介{info.desc}\n- 版本: {info.version}\n- 作者: {info.author}") - with open(p, 'rb') as f: - return True, [Image.fromBytes(f.read())], "plugin" - else: - return False, "未找到该插件", "plugin" - except BaseException as e: - return False, f"获取插件信息失败,原因: {str(e)}", "plugin" - - ''' - nick: 存储机器人的昵称 - ''' - - def set_nick(self, message: str, platform: RegisteredPlatform, role: str = "member"): - if role != "admin": - return True, "你无权使用该指令 :P", "nick" - if str(platform) == PLATFORM_GOCQ: - l = message.split(" ") - if len(l) == 1: - return True, "【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3", "nick" - nick = l[1:] - cc.put("nick_qq", nick) - self.global_object.nick = tuple(nick) - return True, f"设置成功!现在你可以叫我这些昵称来提问我啦~", "nick" - elif str(platform) == PLATFORM_QQCHAN: - nick = message.split(" ")[2] - return False, "QQ频道平台不支持为机器人设置昵称。", "nick" - - def general_commands(self): - return { - "help": "帮助", - "keyword": "设置关键词/关键指令回复", - "update": "更新项目", - "nick": "设置机器人唤醒词", - "plugin": "插件安装、卸载和重载", - "web on/off": "LLM 网页搜索能力", - "t2i": "启用/关闭文本转图片模式" - } - - async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None): - try: - async with aiohttp.ClientSession() as session: - async with session.get("https://soulter.top/channelbot/notice.json") as resp: - notice = (await resp.json())["notice"] - except BaseException as e: - notice = "" - msg = "## 指令列表\n" - for key, value in commands.items(): - msg += f"- `{key}`: {value}\n" - # plugins - if cached_plugins: - plugin_list_info = "" - for plugin in cached_plugins: - plugin_list_info += f"- `{plugin.metadata.plugin_name}`: {plugin.metadata.desc}\n" - if plugin_list_info.strip(): - msg += "\n## 插件列表\n> 使用 plugin v 插件名 查看插件帮助\n" - msg += plugin_list_info - msg += notice - - try: - p = await text_to_image_base(msg) - with open(p, 'rb') as f: - return [Image.fromBytes(f.read()),] - except BaseException as e: - logger.error(str(e)) - return msg - - def command_start_with(self, message: str, *args): - ''' - 当消息以指定的指令开头时返回True - ''' - for arg in args: - if message.startswith(arg) or message.startswith('/'+arg): - return True - return False - - def update(self, message: str, role: str): - if role != "admin": - return True, "你没有权限使用该指令", "update" - l = message.split(" ") - if len(l) == 1: - try: - update_info = util.updator.check_update() - update_info += "\n> Tips: 输入「update latest」更新到最新版本,输入「update <版本号如v3.1.3>」切换到指定版本,输入「update r」重启机器人\n" - return True, update_info, "update" - except BaseException as e: - return False, "检查更新失败: "+str(e), "update" - else: - if l[1] == "latest": - try: - util.updator.update_project() - return True, "更新成功,重启生效。可输入「update r」重启", "update" - except BaseException as e: - return False, "更新失败: "+str(e), "update" - elif l[1] == "r": - util.updator._reboot() - else: - if l[1].lower().startswith('v'): - try: - util.updator.update_project(latest=False, version=l[1]) - return True, "更新成功,重启生效。可输入「update r」重启", "update" - except BaseException as e: - return False, "更新失败: "+str(e), "update" - else: - return False, "版本号格式错误", "update" - - def reset(self): - return False - - def set(self): - return False - - def unset(self): - return False - - def key(self): - return False - - async def help(self): - ret = await self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins) - return True, ret, "help" - - def status(self): - return False - - def token(self): - return False - - def his(self): - return False - - def draw(self): - return False diff --git a/model/command/manager.py b/model/command/manager.py index bb4ba7c..4dcfcbb 100644 --- a/model/command/manager.py +++ b/model/command/manager.py @@ -90,6 +90,9 @@ async def execute_handler(self, if not isinstance(command_result, CommandResult): raise ValueError(f"Command {command} handler should return CommandResult.") + + context.metrics_uploader.command_stats[command] += 1 + return command_result except BaseException as e: logger.error(traceback.format_exc()) diff --git a/type/types.py b/type/types.py index 6dd569d..b3b63e4 100644 --- a/type/types.py +++ b/type/types.py @@ -32,6 +32,7 @@ def __init__(self): self.reply_prefix = "" self.updator: AstrBotUpdator = None self.plugin_updator: PluginUpdator = None + self.metrics_uploader = None self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins) self.image_renderer = TextToImageRenderer() diff --git a/util/general_utils.py b/util/general_utils.py index 28bdabd..270faf4 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -11,68 +11,6 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot') -async def upload_metrics(_global_object: Context): - ''' - 上传相关非敏感统计数据 - ''' - await asyncio.sleep(10) - while True: - platform_stats = {} - llm_stats = {} - plugin_stats = {} - for platform in _global_object.platforms: - platform_stats[platform.platform_name] = { - "cnt_receive": 0, - "cnt_reply": 0 - } - - for llm in _global_object.llms: - stat = llm.llm_instance.model_stat - for k in stat: - llm_stats[llm.llm_name + "#" + k] = stat[k] - llm.llm_instance.reset_model_stat() - - for plugin in _global_object.cached_plugins: - plugin_stats[plugin.metadata.plugin_name] = { - "metadata": plugin.metadata, - "trig_cnt": plugin.trig_cnt - } - plugin.reset_trig_cnt() - - try: - res = { - "stat_version": "moon", - "version": _global_object.version, # 版本号 - "platform_stats": platform_stats, # 过去 30 分钟各消息平台交互消息数 - "llm_stats": llm_stats, - "plugin_stats": plugin_stats, - "sys": sys.platform, # 系统版本 - } - resp = requests.post( - 'https://api.soulter.top/upload', data=json.dumps(res), timeout=5) - if resp.status_code == 200: - ok = resp.json() - if ok['status'] == 'ok': - _global_object.cnt_total = 0 - except BaseException as e: - pass - await asyncio.sleep(30*60) - -def retry(n: int = 3): - ''' - 重试装饰器 - ''' - def decorator(func): - def wrapper(*args, **kwargs): - for i in range(n): - try: - return func(*args, **kwargs) - except Exception as e: - if i == n-1: raise e - logger.warning(f"函数 {func.__name__} 第 {i+1} 次重试... {e}") - return wrapper - return decorator - def run_monitor(global_object: Context): ''' 监测机器性能 diff --git a/util/metrics.py b/util/metrics.py new file mode 100644 index 0000000..3134150 --- /dev/null +++ b/util/metrics.py @@ -0,0 +1,66 @@ +import asyncio +import requests +import json +import sys + +from type.types import Context +from collections import defaultdict + +class MetricUploader(): + def __init__(self, context: Context) -> None: + self.platform_stats = {} + self.llm_stats = {} + self.plugin_stats = {} + self.command_stats = defaultdict(int) + self.context = context + + async def upload_metrics(self): + ''' + 上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。 + + 这些数据包含: + - AstrBot 版本 + - OS 版本 + - 平台消息上下行数量 + - LLM 模型名称、调用次数 + - 加载的插件的元数据 + ''' + await asyncio.sleep(10) + context = self.context + while True: + for llm in context.llms: + stat = llm.llm_instance.model_stat + for k in stat: + self.llm_stats[llm.llm_name + "#" + k] = stat[k] + llm.llm_instance.reset_model_stat() + + for plugin in context.cached_plugins: + self.plugin_stats[plugin.metadata.plugin_name] = { + "metadata": plugin.metadata + } + + try: + res = { + "stat_version": "moon", + "version": context.version, # 版本号 + "platform_stats": self.platform_stats, # 过去 30 分钟各消息平台交互消息数 + "llm_stats": self.llm_stats, + "plugin_stats": self.plugin_stats, + "command_stats": self.command_stats, + "sys": sys.platform, # 系统版本 + } + resp = requests.post( + 'https://api.soulter.top/upload', data=json.dumps(res), timeout=5) + if resp.status_code == 200: + ok = resp.json() + if ok['status'] == 'ok': + self.clear() + except BaseException as e: + pass + await asyncio.sleep(30*60) + + def clear(self): + self.platform_stats.clear() + self.llm_stats.clear() + self.plugin_stats.clear() + self.command_stats.clear() \ No newline at end of file From c0ea6737648dde6b281ee9c8b56d1e37a19037bc Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 25 Jul 2024 10:44:17 -0400 Subject: [PATCH 4/9] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E6=8C=87=E4=BB=A4=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/command/internal_handler.py | 18 ++++----- model/command/openai_official_handler.py | 48 +++++++++++++++++++++--- model/platform/qq_official.py | 19 +++++++--- model/provider/openai_official.py | 17 +++++---- util/config_utils.py | 4 +- util/t2i/strategies/local_strategy.py | 2 +- util/t2i/strategies/network_strategy.py | 4 +- util/updator/zip_updator.py | 5 +++ 8 files changed, 85 insertions(+), 32 deletions(-) diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index 5f0c58c..50f0552 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -19,7 +19,7 @@ def __init__(self, manager: CommandManager, plugin_manager: PluginManager) -> No self.plugin_manager = plugin_manager self.manager.register("help", "查看帮助", 10, self.help) - self.manager.register("nick", "设置机器人昵称", 10, self.set_nick) + self.manager.register("wake", "设置机器人唤醒词", 10, self.set_nick) self.manager.register("update", "更新 AstrBot", 10, self.update) self.manager.register("plugin", "插件管理", 10, self.plugin) self.manager.register("reboot", "重启 AstrBot", 10, self.reboot) @@ -33,14 +33,14 @@ def set_nick(self, message: AstrMessageEvent, context: Context): return CommandResult().message("你没有权限使用该指令。") l = message_str.split(" ") if len(l) == 1: - return CommandResult().message("【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3") + return CommandResult().message("设置机器人唤醒词,支持多唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称1 昵称2 昵称3") nick = l[1:] context.config_helper.put("nick_qq", nick) context.nick = tuple(nick) return CommandResult( hit=True, success=True, - message_chain=f"已经成功将昵称设定为 {nick}", + message_chain=f"已经成功将唤醒词设定为 {nick}", ) def update(self, message: AstrMessageEvent, context: Context): @@ -51,13 +51,13 @@ def update(self, message: AstrMessageEvent, context: Context): success=False, message_chain="你没有权限使用该指令", ) + update_info = context.updator.check_update(None, None) if tokens.len == 1: - update_info = context.updator.check_update(None, None) ret = "" if not update_info: ret = f"当前已经是最新版本 v{VERSION}。" else: - ret = f"发现新版本 v{update_info['version']}。\n- 使用 /update latest 更新到最新版本。\n- 使用 /update vX.X.X 更新到指定版本。" + ret = f"发现新版本 {update_info.version},更新内容如下:\n---\n{update_info.body}\n---\n- 使用 /update latest 更新到最新版本。\n- 使用 /update vX.X.X 更新到指定版本。" return CommandResult( hit=True, success=False, @@ -67,7 +67,7 @@ def update(self, message: AstrMessageEvent, context: Context): if tokens.get(1) == "latest": try: context.updator.update() - return CommandResult().message(f"已经成功更新到最新版本 v{update_info['version']}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启") + return CommandResult().message(f"已经成功更新到最新版本 v{update_info.version}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启") except BaseException as e: return CommandResult().message(f"更新失败。原因:{str(e)}") elif tokens.get(1).startswith("v"): @@ -157,13 +157,13 @@ async def help(self, message: AstrMessageEvent, context: Context): msg = "# Help Center\n## 指令列表\n" for key, value in self.manager.commands_handler.items(): if value.plugin_metadata: - msg += f"`{key}` ({value.plugin_metadata.plugin_name}) - {value.description}\n" - else: msg += f"`{key}` - {value.description}\n" + msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n" + else: msg += f"- `{key}`: {value.description}\n" # plugins if context.cached_plugins != None: plugin_list_info = "" for plugin in context.cached_plugins: - plugin_list_info += f"`{plugin.metadata.plugin_name}` {plugin.metadata.desc}\n" + plugin_list_info += f"- `{plugin.metadata.plugin_name}` {plugin.metadata.desc}\n" if plugin_list_info.strip() != "": msg += "\n## 插件列表\n> 使用plugin v 插件名 查看插件帮助\n" msg += plugin_list_info diff --git a/model/command/openai_official_handler.py b/model/command/openai_official_handler.py index 49b645a..1f78361 100644 --- a/model/command/openai_official_handler.py +++ b/model/command/openai_official_handler.py @@ -25,6 +25,7 @@ def __init__(self, manager: CommandManager) -> None: self.manager.register("unset", "清除个性化人格设置", 10, self.unset) self.manager.register("set", "设置个性化人格", 10, self.set) self.manager.register("draw", "调用 DallE 模型画图", 10, self.draw) + self.manager.register("model", "切换模型", 10, self.model) self.manager.register("画", "调用 DallE 模型画图", 10, self.draw) def set_provider(self, provider): @@ -37,6 +38,41 @@ async def reset(self, message: AstrMessageEvent, context: Context): return CommandResult().message("重置成功") elif tokens.get(1) == 'p': await self.provider.forget(message.session_id) + + async def model(self, message: AstrMessageEvent, context: Context): + tokens = self.manager.command_parser.parse(message.message_str) + if tokens.len == 1: + ret = await self._print_models() + return CommandResult().message(ret) + model = tokens.get(1) + if model.isdigit(): + try: + models = await self.provider.get_models() + except BaseException as e: + logger.error(f"获取模型列表失败: {str(e)}") + return CommandResult().message("获取模型列表失败,无法使用编号切换模型。可以尝试直接输入模型名来切换,如 gpt-4o。") + models = list(models) + if int(model) <= len(models) and int(model) >= 1: + model = models[int(model)-1] + self.provider.set_model(model.id) + return CommandResult().message(f"模型已设置为 {model.id}") + else: + self.provider.set_model(model) + return CommandResult().message(f"模型已设置为 {model} (自定义)") + + async def _print_models(self): + try: + models = await self.provider.get_models() + except BaseException as e: + return "获取模型列表失败: " + str(e) + i = 1 + ret = "OpenAI GPT 类可用模型" + for model in models: + ret += f"\n{i}. {model.id}" + i += 1 + ret += "\nTips: 使用 /model 模型名/编号,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + logger.debug(ret) + return ret def his(self, message: AstrMessageEvent, context: Context): tokens = self.manager.command_parser.parse(message.message_str) @@ -108,12 +144,12 @@ def unset(self, message: AstrMessageEvent, context: Context): def set(self, message: AstrMessageEvent, context: Context): l = message.message_str.split(" ") if len(l) == 1: - return CommandResult().message("【人格文本由PlexPt开源项目awesome-chatgpt-prompts-zh提供】\n设置人格: \n/set 人格名。例如/set 编剧\n人格列表: /set list\n人格详细信息: /set view 人格名\n自定义人格: /set 人格文本\n重置会话(清除人格): /reset\n重置会话(保留人格): /reset p\n【当前人格】: " + str(self.provider.curr_personality)) + return CommandResult().message("- 设置人格: \nset 人格名。例如 set 编剧\n- 人格列表: set list\n- 人格详细信息: set view 人格名\n- 自定义人格: set 人格文本\n- 重置会话(清除人格): reset\n- 重置会话(保留人格): reset p\n\n【当前人格】: " + str(self.provider.curr_personality['prompt'])) elif l[1] == "list": msg = "人格列表:\n" for key in personalities.keys(): - msg += f" |-{key}\n" - msg += '\n\n*输入/set view 人格名查看人格详细信息' + msg += f"- {key}\n" + msg += '\n\n*输入 set view 人格名 查看人格详细信息' return CommandResult().message(msg) elif l[1] == "view": if len(l) == 2: @@ -126,20 +162,20 @@ def set(self, message: AstrMessageEvent, context: Context): msg = f"人格{ps}不存在" return CommandResult().message(msg) else: - ps = l[1].strip() + ps = "".join(l[1:]).strip() if ps in personalities: self.provider.curr_personality = { 'name': ps, 'prompt': personalities[ps] } - self.provider.personality_set(ps, message.session_id) + self.provider.personality_set(self.provider.curr_personality, message.session_id) return CommandResult().message(f"人格已设置。 \n人格信息: {ps}") else: self.provider.curr_personality = { 'name': '自定义人格', 'prompt': ps } - self.provider.personality_set(ps, message.session_id) + self.provider.personality_set(self.provider.curr_personality, message.session_id) return CommandResult().message(f"人格已设置。 \n人格信息: {ps}") async def draw(self, message: AstrMessageEvent, context: Context): diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index f2d60ea..8c7fc2f 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -9,7 +9,7 @@ from botpy.types.message import Reference, Media from botpy import Client -from util.io import save_temp_img +from util.io import save_temp_img, download_image_by_url from . import Platform from type.astrbot_message import * from type.message_event import * @@ -80,7 +80,7 @@ def __init__(self, context: Context, message_handler: MessageHandler) -> None: self.client.set_platform(self) - def _parse_to_qqofficial(self, message: List[BaseMessageComponent]): + async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False): plain_text = "" image_path = None # only one img supported for i in message: @@ -92,8 +92,9 @@ def _parse_to_qqofficial(self, message: List[BaseMessageComponent]): elif i.file and i.file.startswith("base64://"): img_data = base64.b64decode(i.file[9:]) image_path = save_temp_img(img_data) - else: - image_path = save_temp_img(i.file) + elif i.file and i.file.startswith("http"): + # 如果是群消息,不需要下载 + image_path = await download_image_by_url(i.file) if not is_group else i.file return plain_text, image_path def _parse_from_qqofficial(self, message: Union[botpy.message.Message, botpy.message.GroupMessage], @@ -237,7 +238,7 @@ async def reply_msg(self, rendered_images = await self.convert_to_t2i_chain(result_message) if isinstance(result_message, list): - plain_text, image_path = self._parse_to_qqofficial(result_message) + plain_text, image_path = await self._parse_to_qqofficial(result_message, message.type == MessageType.GROUP_MESSAGE) else: plain_text = result_message @@ -319,17 +320,23 @@ async def _reply(self, **kwargs): media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url) del kwargs['file_image'] kwargs['media'] = media + logger.debug(f"发送群图片: {media}") kwargs['msg_type'] = 7 # 富媒体 await self.client.api.post_group_message(**kwargs) elif 'channel_id' in kwargs: # 频道消息 if 'file_image' in kwargs: kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") + # 频道消息发图只支持本地 + if kwargs['file_image'].startswith("http"): + kwargs['file_image'] = await download_image_by_url(kwargs['file_image']) await self.client.api.post_message(**kwargs) else: # 频道私聊消息 if 'file_image' in kwargs: kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") + if kwargs['file_image'].startswith("http"): + kwargs['file_image'] = await download_image_by_url(kwargs['file_image']) await self.client.api.post_dms(**kwargs) async def send_msg(self, target: Dict[str, str], result_message: Union[List[BaseMessageComponent], str]): @@ -343,7 +350,7 @@ async def send_msg(self, target: Dict[str, str], result_message: Union[List[Base - 如果目标是 频道私聊,请添加 key `guild_id`。 ''' if isinstance(result_message, list): - plain_text, image_path = self._parse_to_qqofficial(result_message) + plain_text, image_path = await self._parse_to_qqofficial(result_message) else: plain_text = result_message diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index ab7516f..73e5baa 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -188,12 +188,14 @@ async def is_lvm(self): return self.model_configs['model'].startswith("gpt-4") async def get_models(self): - ''' - 获取所有模型 - ''' - models = await self.client.models.list() - logger.info(f"OpenAI 模型列表:{models}") - return models + try: + models = await self.client.models.list() + except NotFoundError as e: + bu = str(self.client.base_url) + self.client.base_url = bu + "/v1" + models = await self.client.models.list() + finally: + return filter(lambda x: x.id.startswith("gpt"), models.data) async def assemble_context(self, session_id: str, prompt: str, image_url: str = None): ''' @@ -376,7 +378,8 @@ async def text_chat(self, if "maximum context length" in str(e): logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") self.pop_record(session_id) - + + logger.warning(traceback.format_exc()) logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。") time.sleep(1) diff --git a/util/config_utils.py b/util/config_utils.py index 688bd74..935793c 100644 --- a/util/config_utils.py +++ b/util/config_utils.py @@ -11,7 +11,7 @@ def init_configs(): cc.init_attributes("qq_forward_threshold", 200) cc.init_attributes("qq_welcome", "") - cc.init_attributes("qq_pic_mode", False) + cc.init_attributes("qq_pic_mode", True) cc.init_attributes("gocq_host", "127.0.0.1") cc.init_attributes("gocq_http_port", 5700) cc.init_attributes("gocq_websocket_port", 6700) @@ -104,6 +104,6 @@ def inject_to_context(context: Context): if isinstance(nick_qq, list): nick_qq = tuple(nick_qq) context.nick = nick_qq - context.t2i_mode = cc.get("qq_pic_mode", False) + context.t2i_mode = cc.get("qq_pic_mode", True) return cfg \ No newline at end of file diff --git a/util/t2i/strategies/local_strategy.py b/util/t2i/strategies/local_strategy.py index 5711a48..0a27601 100644 --- a/util/t2i/strategies/local_strategy.py +++ b/util/t2i/strategies/local_strategy.py @@ -17,7 +17,7 @@ def get_font(self, size: int) -> ImageFont.FreeTypeFont: except Exception as e: pass - async def render(self, text: str, **kwargs): + async def render(self, text: str, return_url: bool=False) -> str: font_size = 26 image_width = 800 image_height = 600 diff --git a/util/t2i/strategies/network_strategy.py b/util/t2i/strategies/network_strategy.py index cd59ce3..67f406d 100644 --- a/util/t2i/strategies/network_strategy.py +++ b/util/t2i/strategies/network_strategy.py @@ -30,7 +30,9 @@ async def render(self, text: str, return_url: bool=False) -> str: "version": f"v{VERSION}", }, "options": { - "full_page": True + "full_page": True, + "type": "jpeg", + "quality": 25, } } diff --git a/util/updator/zip_updator.py b/util/updator/zip_updator.py index a49feee..102c9ee 100644 --- a/util/updator/zip_updator.py +++ b/util/updator/zip_updator.py @@ -10,6 +10,11 @@ class ReleaseInfo(): published_at: str body: str + def __init__(self, version: str = '', published_at: str = '', body: str = '') -> None: + self.version = version + self.published_at = published_at + self.body = body + def __str__(self) -> str: return f"新版本: {self.version}, 发布于: {self.published_at}, 详细内容: {self.body}" From 475257c6962b8bb5b11f2e8284ade485e750aba6 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 25 Jul 2024 12:33:31 -0400 Subject: [PATCH 5/9] =?UTF-8?q?perf:=20=E5=A2=9E=E5=BC=BA=20aiocqhttp=20?= =?UTF-8?q?=E5=8F=91=E5=9B=BE=E7=9A=84=E7=A8=B3=E5=AE=9A=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/platform/qq_aiocqhttp.py | 27 ++++++++++++++++++++----- model/plugin/manager.py | 2 +- type/command.py | 35 ++++++++++++++++++++++++++++++++- type/types.py | 2 +- util/plugin_dev/api/v1/types.py | 6 ++++-- 5 files changed, 62 insertions(+), 10 deletions(-) diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 84ffcf4..7bb1c49 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -2,6 +2,7 @@ import traceback import logging from aiocqhttp import CQHttp, Event +from aiocqhttp.exceptions import ActionFailed from . import Platform from type.astrbot_message import * from type.message_event import * @@ -164,13 +165,29 @@ async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageC message_chain = [Plain(text=message_chain), ] ret = [] - for segment in message_chain: + image_idx = [] + for idx, segment in enumerate(message_chain): d = segment.toDict() if isinstance(segment, Plain): d['type'] = 'text' if isinstance(segment, Image): - # d['data']['file'] = - pass + image_idx.append(idx) ret.append(d) - - await self.bot.send(message.raw_message, ret) \ No newline at end of file + try: + await self.bot.send(message.raw_message, ret) + except ActionFailed as e: + logger.error(traceback.format_exc()) + logger.error(f"回复消息失败: {e}") + if e.retcode == 1200: + # ENOENT + if not image_idx: + raise e + logger.info("检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。") + for idx in image_idx: + if ret[idx]['data']['file'].startswith('file://'): + logger.info(f"正在上传图片: {ret[idx]['data']['path']}") + image_url = await self.context.image_uploader.upload_image(ret[idx]['data']['path']) + logger.info(f"上传成功。") + ret[idx]['data']['file'] = image_url + ret[idx]['data']['path'] = image_url + await self.bot.send(message.raw_message, ret) \ No newline at end of file diff --git a/model/plugin/manager.py b/model/plugin/manager.py index 11d33a0..7284105 100644 --- a/model/plugin/manager.py +++ b/model/plugin/manager.py @@ -165,7 +165,7 @@ def plugin_reload(self): try: # 尝试传入 ctx - obj = getattr(module, cls[0])(ctx=self.context) + obj = getattr(module, cls[0])(context=self.context) except: obj = getattr(module, cls[0])() diff --git a/type/command.py b/type/command.py index f09a0a0..5411c29 100644 --- a/type/command.py +++ b/type/command.py @@ -1,6 +1,6 @@ from typing import Union, List, Callable from dataclasses import dataclass -from nakuru.entities.components import Plain +from nakuru.entities.components import Plain, Image @dataclass @@ -26,8 +26,41 @@ def __init__(self, hit: bool = True, success: bool = True, message_chain: list = self.command_name = command_name def message(self, message: str): + ''' + 快捷回复消息。 + + CommandResult().message("Hello, world!") + ''' + self.message_chain = [Plain(message), ] + return self + + def error(self, message: str): + ''' + 快捷回复消息。 + + CommandResult().error("Hello, world!") + ''' + self.success = False self.message_chain = [Plain(message), ] return self + + def url_image(self, url: str): + ''' + 快捷回复图片(网络url的格式)。 + + CommandResult().image("https://example.com/image.jpg") + ''' + self.message_chain = [Image.fromURL(url), ] + return self + + def file_image(self, path: str): + ''' + 快捷回复图片(本地文件路径的格式)。 + + CommandResult().image("image.jpg") + ''' + self.message_chain = [Image.fromFileSystem(path), ] + return self def _result_tuple(self): return (self.success, self.message_chain, self.command_name) diff --git a/type/types.py b/type/types.py index b3b63e4..8773bfb 100644 --- a/type/types.py +++ b/type/types.py @@ -51,7 +51,7 @@ def register_commands(self, `command_name`: 指令名,如 "help"。不需要带前缀。 `description`: 指令描述。 `priority`: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 - `handler`: 指令处理函数。 + `handler`: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context ''' self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler) diff --git a/util/plugin_dev/api/v1/types.py b/util/plugin_dev/api/v1/types.py index 96d07a7..eaaef76 100644 --- a/util/plugin_dev/api/v1/types.py +++ b/util/plugin_dev/api/v1/types.py @@ -1,5 +1,7 @@ ''' -插件类型 +插件类型、消息组件类型 ''' -from type.plugin import PluginType \ No newline at end of file +from type.plugin import PluginType + +from nakuru.entities.components import Image, Plain, At, Node, BaseMessageComponent \ No newline at end of file From b59ac8ab5d70661fa45f4ecb28f8164511b36d5f Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 25 Jul 2024 12:58:45 -0400 Subject: [PATCH 6/9] =?UTF-8?q?perf:=20=E5=9C=A8=20context=20=E4=B8=AD?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0message=5Fhandler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/bootstrap.py | 1 + type/types.py | 1 + 2 files changed, 2 insertions(+) diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index bba3544..200d259 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -65,6 +65,7 @@ async def run(self): self.context.metrics_uploader = self.metrics_uploader self.context.updator = self.updator self.context.plugin_updator = self.plugin_manager.updator + self.context.message_handler = self.message_handler # load plugins, plugins' commands. self.load_plugins() diff --git a/type/types.py b/type/types.py index 8773bfb..78b0589 100644 --- a/type/types.py +++ b/type/types.py @@ -37,6 +37,7 @@ def __init__(self): self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins) self.image_renderer = TextToImageRenderer() self.image_uploader = ImageUploader() + self.message_handler = None # see astrbot/message/handler.py def register_commands(self, plugin_name: str, From a93746fdcd4211df630c02dd7688eef78602cf18 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 26 Jul 2024 05:02:29 -0400 Subject: [PATCH 7/9] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20aiocqhttp=20?= =?UTF-8?q?=E8=BF=90=E8=A1=8C=E5=AF=BC=E8=87=B4=20ctrl+c=20=E6=97=A0?= =?UTF-8?q?=E6=B3=95=E9=80=80=E5=87=BA=20bot=20=E7=9A=84=E9=97=AE=E9=A2=98?= =?UTF-8?q?=20perf:=20=E6=94=AF=E6=8C=81=E9=80=9A=E8=BF=87context=E6=B3=A8?= =?UTF-8?q?=E5=86=8Ctask?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/bootstrap.py | 20 +++++++++++--------- model/platform/manager.py | 6 +++--- model/platform/qq_aiocqhttp.py | 7 ++++++- type/types.py | 14 +++++++++++++- util/plugin_dev/api/v1/llm.py | 5 ----- util/plugin_dev/api/v1/message.py | 3 ++- util/plugin_dev/api/v1/register.py | 8 ++++++-- 7 files changed, 41 insertions(+), 22 deletions(-) diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index 200d259..ef548dd 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -74,21 +74,23 @@ async def run(self): # load platforms platform_tasks = self.load_platform() # load metrics uploader - metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics()) + metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader") # load dashboard self.dashboard.run_http_server() - dashboard_task = asyncio.create_task(self.dashboard.ws_server()) - tasks = [metrics_upload_task, dashboard_task, *platform_tasks] + dashboard_task = asyncio.create_task(self.dashboard.ws_server(), name="dashboard") + tasks = [metrics_upload_task, dashboard_task, *platform_tasks, *self.context.ext_tasks] tasks = [self.handle_task(task) for task in tasks] await asyncio.gather(*tasks) async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]): - try: - result = await task - return result - except Exception as e: - logger.error(traceback.format_exc()) - return None + while True: + try: + result = await task + return result + except Exception as e: + logger.error(traceback.format_exc()) + logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。") + await asyncio.sleep(5) def load_llm(self): if 'openai' in self.configs and \ diff --git a/model/platform/manager.py b/model/platform/manager.py index 1f287f4..5ca2173 100644 --- a/model/platform/manager.py +++ b/model/platform/manager.py @@ -21,16 +21,16 @@ def load_platforms(self): if 'gocqbot' in self.config and self.config['gocqbot']['enable']: logger.info("启用 QQ(nakuru 适配器)") - tasks.append(asyncio.create_task(self.gocq_bot())) + tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter")) if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']: logger.info("启用 QQ(aiocqhttp 适配器)") - tasks.append(asyncio.create_task(self.aiocq_bot())) + tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter")) # QQ频道 if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None: logger.info("启用 QQ(官方 API) 机器人消息平台") - tasks.append(asyncio.create_task(self.qqchan_bot())) + tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter")) return tasks diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 7bb1c49..e0d8105 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -1,4 +1,5 @@ import time +import asyncio import traceback import logging from aiocqhttp import CQHttp, Event @@ -82,7 +83,7 @@ async def private(event: Event): await self.handle_msg(abm) # return {'reply': event.message} - bot = self.bot.run_task(host=self.host, port=int(self.port)) + bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) @@ -90,6 +91,10 @@ async def private(event: Event): return bot + async def shutdown_trigger_placeholder(self): + while True: + await asyncio.sleep(1) + def pre_check(self, message: AstrBotMessage) -> bool: # if message chain contains Plain components or At components which points to self_id, return True if message.type == MessageType.FRIEND_MESSAGE: diff --git a/type/types.py b/type/types.py index 78b0589..4acf0fc 100644 --- a/type/types.py +++ b/type/types.py @@ -1,5 +1,7 @@ +import asyncio +from asyncio import Task from type.register import * -from typing import List +from typing import List, Awaitable from logging import Logger from util.cmd_config import CmdConfig from util.t2i.renderer import TextToImageRenderer @@ -38,6 +40,7 @@ def __init__(self): self.image_renderer = TextToImageRenderer() self.image_uploader = ImageUploader() self.message_handler = None # see astrbot/message/handler.py + self.ext_tasks: List[Task] = [] def register_commands(self, plugin_name: str, @@ -56,6 +59,15 @@ def register_commands(self, ''' self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler) + def register_task(self, coro: Awaitable, task_name: str): + ''' + 注册任务。适用于需要长时间运行的插件。 + + `coro`: 协程对象 + `task_name`: 任务名,用于标识任务。自定义即可。 + ''' + task = asyncio.create_task(coro, name=task_name) + self.ext_tasks.append(task) def find_platform(self, platform_name: str) -> RegisteredPlatform: for platform in self.platforms: diff --git a/util/plugin_dev/api/v1/llm.py b/util/plugin_dev/api/v1/llm.py index e585689..11d2be5 100644 --- a/util/plugin_dev/api/v1/llm.py +++ b/util/plugin_dev/api/v1/llm.py @@ -1,6 +1 @@ -''' -大语言模型. - -插件开发者可以继承这个类来做实现。 -''' from model.provider.provider import Provider as LLMProvider \ No newline at end of file diff --git a/util/plugin_dev/api/v1/message.py b/util/plugin_dev/api/v1/message.py index 182e911..218498c 100644 --- a/util/plugin_dev/api/v1/message.py +++ b/util/plugin_dev/api/v1/message.py @@ -1,3 +1,4 @@ from type.message_event import * from type.astrbot_message import * -from type.command import CommandResult \ No newline at end of file +from type.command import CommandResult +from astrbot.message.handler import MessageHandler \ No newline at end of file diff --git a/util/plugin_dev/api/v1/register.py b/util/plugin_dev/api/v1/register.py index 6ebef89..da58819 100644 --- a/util/plugin_dev/api/v1/register.py +++ b/util/plugin_dev/api/v1/register.py @@ -8,13 +8,17 @@ from type.types import Context from type.register import RegisteredPlatform, RegisteredLLM -def register_platform(platform_name: str, platform_instance: Platform, context: Context) -> None: +def register_platform(platform_name: str, context: Context, platform_instance: Platform = None) -> None: ''' 注册一个消息平台。 Args: platform_name: 平台名称。 - platform_instance: 平台实例。 + platform_instance: 平台实例,可为空。 + context: 上下文对象。 + + Note: + 当插件类被加载时,AstrBot 会传给插件 context 对象。插件可以通过 context 对象注册指令、长任务等。 ''' # check 是否已经注册 From f6276caf5060960559909d0a2183ad7c98b0a3b9 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 26 Jul 2024 05:15:41 -0400 Subject: [PATCH 8/9] chore: update default plugin --- addons/plugins/helloworld/main.py | 64 +++++++------------------ addons/plugins/helloworld/metadata.yaml | 4 +- 2 files changed, 19 insertions(+), 49 deletions(-) diff --git a/addons/plugins/helloworld/main.py b/addons/plugins/helloworld/main.py index a088cda..56eb991 100644 --- a/addons/plugins/helloworld/main.py +++ b/addons/plugins/helloworld/main.py @@ -1,62 +1,32 @@ -from nakuru.entities.components import * flag_not_support = False try: + from util.plugin_dev.api.v1.bot import Context, AstrMessageEvent, CommandResult from util.plugin_dev.api.v1.config import * - from util.plugin_dev.api.v1.bot import ( - AstrMessageEvent, - CommandResult, - ) except ImportError: flag_not_support = True print("导入接口失败。请升级到 AstrBot 最新版本。") - ''' -注意改插件名噢!格式:XXXPlugin 或 Main -小提示:把此模板仓库 fork 之后 clone 到机器人文件夹下的 addons/plugins/ 目录下,然后用 Pycharm/VSC 等工具打开可获更棒的编程体验(自动补全等) +注意以格式 XXXPlugin 或 Main 来修改插件名。 +提示:把此模板仓库 fork 之后 clone 到机器人文件夹下的 addons/plugins/ 目录下,然后用 Pycharm/VSC 等工具打开可获更棒的编程体验(自动补全等) ''' class HelloWorldPlugin: """ - 初始化函数, 可以选择直接pass + AstrBot 会传递 context 给插件。 + + - context.register_commands: 注册指令 + - context.register_task: 注册任务 + - context.message_handler: 消息处理器(平台类插件用) """ - def __init__(self) -> None: - pass + def __init__(self, context: Context) -> None: + self.context = context + self.context.register_commands("helloworld", "helloworld", "内置测试指令。", 1, self.helloworld) """ - 机器人程序会调用此函数。 - """ - def run(self, ame: AstrMessageEvent): - if ame.message_str.startswith("helloworld"): # 如果消息文本以"helloworld"开头 - return CommandResult( - hit=True, # 代表插件会响应此消息 - success=True, # 插件响应类型为成功响应 - message_chain=[Plain("Hello World!!")], # 消息链 - command_name="helloworld" # 指令名 - ) - return CommandResult( - hit=False, # 插件不会响应此消息 - success=False, - message_chain=None - ) - """ - 插件元信息。 - 当用户输入 plugin v 插件名称 时,会调用此函数,返回帮助信息。 - 返回参数要求(必填):dict{ - "name": str, # 插件名称 - "desc": str, # 插件简短描述 - "help": str, # 插件帮助信息 - "version": str, # 插件版本 - "author": str, # 插件作者 - "repo": str, # 插件仓库地址 [ 可选 ] - "homepage": str, # 插件主页 [ 可选 ] - } + 指令处理函数。 + + - 需要接收两个参数:message: AstrMessageEvent, context: Context + - 返回 CommandResult 对象 """ - def info(self): - return { - "name": "helloworld", - "desc": "这是 AstrBot 的默认插件,支持关键词回复。", - "help": "输入 /keyword 查看关键词回复帮助。", - "version": "v1.3", - "author": "Soulter", - "repo": "https://github.com/Soulter/helloworld" - } + def helloworld(self, message: AstrMessageEvent, context: Context): + return CommandResult().message("Hello, World!") \ No newline at end of file diff --git a/addons/plugins/helloworld/metadata.yaml b/addons/plugins/helloworld/metadata.yaml index eefb68c..41e34ef 100644 --- a/addons/plugins/helloworld/metadata.yaml +++ b/addons/plugins/helloworld/metadata.yaml @@ -1,6 +1,6 @@ name: helloworld # 这是你的插件的唯一识别名。 -desc: 这是 AstrBot 的默认插件,支持关键词回复。 # 插件简短描述 -help: 输入 /keyword 查看关键词回复帮助。 # 插件的帮助信息 +desc: 这是 AstrBot 的默认插件。 +help: version: v1.3 # 插件版本号。格式:v1.1.1 或者 v1.1 author: Soulter # 作者 repo: https://github.com/Soulter/helloworld # 插件的仓库地址 From 6f5f3659902c735fd45d14e26b15fe207a40814c Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 26 Jul 2024 06:24:11 -0400 Subject: [PATCH 9/9] update version --- type/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/type/config.py b/type/config.py index 261a1a1..a040388 100644 --- a/type/config.py +++ b/type/config.py @@ -1 +1 @@ -VERSION = '3.2.4' +VERSION = '3.3.0' \ No newline at end of file