diff --git a/api/alembic/versions/df31d751e3d7_update_plugin_api_table.py b/api/alembic/versions/df31d751e3d7_update_plugin_api_table.py new file mode 100644 index 0000000..9fd1532 --- /dev/null +++ b/api/alembic/versions/df31d751e3d7_update_plugin_api_table.py @@ -0,0 +1,51 @@ +"""update plugin api table + +Revision ID: df31d751e3d7 +Revises: a7a038514b7c +Create Date: 2024-07-17 09:25:05.071105 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'df31d751e3d7' +down_revision: Union[str, None] = 'a7a038514b7c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('plugin_api', sa.Column( + 'server_url', sa.String(length=255), nullable=True)) + op.add_column('plugin_api', sa.Column( + 'method', sa.String(length=255), nullable=True)) + op.add_column('plugin_api', sa.Column( + 'summary', sa.String(length=255), nullable=True)) + op.add_column('plugin_api', sa.Column( + 'parameters', sa.JSON(), nullable=True)) + op.add_column('plugin_api', sa.Column( + 'operation_id', sa.String(length=255), nullable=True)) + op.alter_column('plugin_api', 'openapi_desc', + existing_type=sa.TEXT(), + type_=sa.JSON(), + existing_nullable=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('plugin_api', 'openapi_desc', + existing_type=sa.JSON(), + type_=sa.TEXT(), + existing_nullable=True) + op.drop_column('plugin_api', 'operation_id') + op.drop_column('plugin_api', 'parameters') + op.drop_column('plugin_api', 'summary') + op.drop_column('plugin_api', 'method') + op.drop_column('plugin_api', 'server_url') + # ### end Alembic commands ### diff --git a/api/alembic/versions/f54e1fc0bd90_update_plugin_config_table.py b/api/alembic/versions/f54e1fc0bd90_update_plugin_config_table.py new file mode 100644 index 0000000..3e5584c --- /dev/null +++ b/api/alembic/versions/f54e1fc0bd90_update_plugin_config_table.py @@ -0,0 +1,34 @@ +"""update plugin config table + +Revision ID: f54e1fc0bd90 +Revises: df31d751e3d7 +Create Date: 2024-07-21 14:41:06.251765 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'f54e1fc0bd90' +down_revision: Union[str, None] = 'df31d751e3d7' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('plugin_config', sa.Column('apis', sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('plugin_config', 'apis') + op.alter_column('messages', 'message_type', + existing_type=sa.Enum('MARKDOWN', 'TEXT', name='messagetype'), + type_=sa.VARCHAR(length=50), + existing_nullable=False) + # ### end Alembic commands ### diff --git a/api/constants/constant.py b/api/constants/constant.py new file mode 100644 index 0000000..28f76bb --- /dev/null +++ b/api/constants/constant.py @@ -0,0 +1,115 @@ +defaultPluginOpenapiDesc: str = """ +info: + description: arxiv查找论文 + title: arxiv_app + version: v1 +openapi: 3.0.1 +paths: + /search: + get: + operationId: search + parameters: + - description: 输入搜索关键词,必须是英文 + in: query + name: search_query + required: true + schema: + type: string + - description: 搜索数量 + in: query + name: count + schema: + type: integer + - description: 可以是1:降序2:升序 + in: query + name: sort_order + schema: + type: integer + - description: 可以是1:相关性2.lastUpdateDate3.submittedDate + in: query + name: sort_by + schema: + type: integer + - description: 报告编号过滤器,用于获取报告编号中包含给定关键字的结果 + in: query + name: report_number + schema: + type: string + - description: 主题类别过滤器,用于获取主题类别中包含给定关键字的结果。 + in: query + name: subject_category + schema: + type: string + - description: 标题过滤器,用于获取标题包含给定关键字的结果 + in: query + name: title + schema: + type: string + - description: 摘要过滤器,用于获取摘要中包含给定关键字的结果 + in: query + name: abstract + schema: + type: string + - description: 作者过滤器,用于获取包含给定关键字的作者的结果 + in: query + name: author + schema: + type: string + - description: 评论过滤器,用于获取评论中包含给定关键字的结果 + in: query + name: comment + schema: + type: string + - description: 期刊参考过滤器,用于获取期刊参考中包含给定关键字的结果 + in: query + name: journal_reference + schema: + type: string + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + application/json: + schema: + properties: + items: + items: + properties: + authors: + description: 作者列表 + items: + description: 作者 + type: string + type: array + entry_id: + description: https://arxiv.org/abs/{id} 链接 + type: string + pdf_url: + description: pdf 下载链接 + type: string + published: + description: 发表时间 + type: string + summary: + description: 摘要 + type: string + title: + description: 标题 + type: string + type: object + type: array + total: + description: 结果数目 + type: integer + type: object + description: new desc + default: + description: "" + summary: search +servers: + - url: http://env-00jxgx8y5a87.dev-hz.cloudbasefunction.cn +""" diff --git a/api/core/api_tool.py b/api/core/api_tool.py new file mode 100644 index 0000000..2f02945 --- /dev/null +++ b/api/core/api_tool.py @@ -0,0 +1,206 @@ +import json +from typing import Any +from urllib.parse import urlencode +import httpx +from models.plugin_api import PluginApiModel +import core.utils.ssrf_proxy as ssrf_proxy + + +class ApiTool: + plugin_api_model: PluginApiModel + + def __init__(self, plugin_api_model): + self.plugin_api_model = plugin_api_model + + @staticmethod + def get_parameter_value(parameter, parameters): + if parameter['name'] in parameters: + return parameters[parameter['name']] + elif parameter.get('required', False): + raise Exception(f"Missing required parameter {parameter['name']}") + else: + return (parameter.get('schema', {}) or {}).get('default', None) + + def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], + max_recursive=10) -> Any: + if max_recursive <= 0: + raise Exception("Max recursion depth reached") + for option in any_of or []: + try: + if 'type' in option: + # Attempt to convert the value based on the type. + if option['type'] == 'integer' or option['type'] == 'int': + return int(value) + elif option['type'] == 'number': + if '.' in str(value): + return float(value) + else: + return int(value) + elif option['type'] == 'string': + return str(value) + elif option['type'] == 'boolean': + if str(value).lower() in ['true', '1']: + return True + elif str(value).lower() in ['false', '0']: + return False + else: + continue # Not a boolean, try next option + elif option['type'] == 'null' and not value: + return None + else: + continue # Unsupported type, try next option + elif 'anyOf' in option and isinstance(option['anyOf'], list): + # Recursive call to handle nested anyOf + return self._convert_body_property_any_of(property, value, option['anyOf'], max_recursive - 1) + except ValueError: + continue # Conversion failed, try next option + # If no option succeeded, you might want to return the value as is or raise an error + # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf") + return value + + def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any: + try: + if 'type' in property: + if property['type'] == 'integer' or property['type'] == 'int': + return int(value) + elif property['type'] == 'number': + # check if it is a float + if '.' in value: + return float(value) + else: + return int(value) + elif property['type'] == 'string': + return str(value) + elif property['type'] == 'boolean': + return bool(value) + elif property['type'] == 'null': + if value is None: + return None + elif property['type'] == 'object': + if isinstance(value, str): + try: + return json.loads(value) + except ValueError: + return value + elif isinstance(value, dict): + return value + else: + return value + else: + raise ValueError( + f"Invalid type {property['type']} for property {property}") + elif 'anyOf' in property and isinstance(property['anyOf'], list): + return self._convert_body_property_any_of(property, value, property['anyOf']) + except ValueError as e: + return value + + def do_http_request(self, url: str, method: str, headers: dict[str, Any], + parameters: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + method = method.lower() + + params = {} + path_params = {} + body = {} + cookies = {} + + # check parameters + for parameter in self.plugin_api_model.openapi_desc.get('parameters', []): + value = self.get_parameter_value(parameter, parameters) + if value is not None: + if parameter['in'] == 'path': + path_params[parameter['name']] = value + + elif parameter['in'] == 'query': + params[parameter['name']] = value + + elif parameter['in'] == 'cookie': + cookies[parameter['name']] = value + + elif parameter['in'] == 'header': + headers[parameter['name']] = value + + # check if there is a request body and handle it + if 'requestBody' in self.plugin_api_model.openapi_desc and self.plugin_api_model.openapi_desc['requestBody'] is not None: + # handle json request body + if 'content' in self.plugin_api_model.openapi_desc['requestBody']: + for content_type in self.plugin_api_model.openapi_desc['requestBody']['content']: + headers['Content-Type'] = content_type + body_schema = self.plugin_api_model.openapi_desc[ + 'requestBody']['content'][content_type]['schema'] + required = body_schema.get('required', []) + properties = body_schema.get('properties', {}) + for name, property in properties.items(): + if name in parameters: + # convert type + body[name] = self._convert_body_property_type( + property, parameters[name]) + elif name in required: + raise Exception( + f"Missing required parameter {name} in operation {self.plugin_api_model.operation_id}" + ) + elif 'default' in property: + body[name] = property['default'] + else: + body[name] = None + break + + # replace path parameters + for name, value in path_params.items(): + url = url.replace(f'{{{name}}}', f'{value}') + + # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored + if 'Content-Type' in headers: + if headers['Content-Type'] == 'application/json': + body = json.dumps(body) + elif headers['Content-Type'] == 'application/x-www-form-urlencoded': + body = urlencode(body) + else: + body = body + + if method in ('get', 'head', 'post', 'put', 'delete', 'patch'): + response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, data=body, + follow_redirects=True) + return response + else: + raise ValueError(f'Invalid http method') + + def validate_and_parse_response(self, response: httpx.Response) -> str: + """ + validate the response + """ + if isinstance(response, httpx.Response): + if response.status_code >= 400: + raise Exception( + f"Request failed with status code {response.status_code} and {response.text}") + if not response.content: + return 'Empty response from the tool, please check your parameters and try again.' + try: + response = response.json() + try: + return json.dumps(response, ensure_ascii=False) + except Exception as e: + return json.dumps(response) + except Exception as e: + return response.text + else: + raise ValueError(f'Invalid response type {type(response)}') + + def _invoke(self, tool_parameters: dict[str, Any]) -> str | list[str]: + """ + invoke http request + """ + # assemble request + headers = {} + + # do http request + response = self.do_http_request( + self.plugin_api_model.server_url, self.plugin_api_model.method, headers, tool_parameters) + + # validate response + response = self.validate_and_parse_response(response) + + # assemble invoke message + return response diff --git a/api/core/au_agent_provider.py b/api/core/au_agent_provider.py index 5f73494..c85eef3 100644 --- a/api/core/au_agent_provider.py +++ b/api/core/au_agent_provider.py @@ -171,6 +171,8 @@ def can_handle(self, config: ConfigMeta) -> int: return 20 if config.model.key == 'gpt-4o': return 20 + if config.model.key == 'gpt-3.5-turbo': + return 20 return -1 def provide(self, config: ConfigMeta, chat_id: int) -> Agent: diff --git a/api/core/base.py b/api/core/base.py index 77fa1bb..d080ca9 100644 --- a/api/core/base.py +++ b/api/core/base.py @@ -13,7 +13,7 @@ class ConfigMeta(BaseModel): '''agent meta model''' persona: str model: ModelMeta - tools: List[Any] = [] + plugins: List[Any] = [] class SSEType(Enum): diff --git a/api/core/chat.py b/api/core/chat.py index e2c5eed..c1e2403 100644 --- a/api/core/chat.py +++ b/api/core/chat.py @@ -29,9 +29,9 @@ def to_message(msg: MessageModel) -> BaseMessage: def chat(agent_config: AgentConfigModel, chat_id: int, history: List[MessageModel], message: MessageModel): config = get_config_meta(agent_config) - system_msg = SystemMessage(content=config.persona) - msgs = [to_message(m) for m in history] - msgs.insert(0, system_msg) + # system_msg = SystemMessage(content=config.persona) + # msgs = [to_message(m) for m in history] + # msgs.insert(0, system_msg) provider = agent_manager.get_provider(config) agent = provider.provide(config, chat_id) answer = agent.invoke(config, chat_id, history, message) diff --git a/api/core/langchain_agent_provider.py b/api/core/langchain_agent_provider.py index 28ded53..33ba6c3 100644 --- a/api/core/langchain_agent_provider.py +++ b/api/core/langchain_agent_provider.py @@ -8,7 +8,10 @@ from langchain_community.callbacks.manager import get_openai_callback from langchain_community.chat_models.openai import ChatOpenAI +from core.api_tool import ApiTool +from db import get_session from models.chat import MessageModel, MessageSenderType +from services.plugin import PluginAPIService from .agent_provider import AgentProvider from .agent import Agent @@ -32,7 +35,24 @@ def invoke( message: MessageModel, **kwargs, ) -> BaseMessage | None: - raise Exception('not implement') + """Chat with agent""" + llm = ChatOpenAI(model=config.model.key) + api_id = config.plugins[0]['api'] + with get_session() as session: + plugin_api_model = PluginAPIService.get_by_id(api_id, session) + api_tool = ApiTool(plugin_api_model=plugin_api_model) + print('api_tool_invoke', api_tool._invoke( + {"search_query": "deep learning"})) + tools = [] + system_msg = SystemMessage(content=config.persona) + msgs = [to_message(m) for m in history] + msgs.insert(0, system_msg) + try: + with get_openai_callback() as cb: + result = llm.invoke(msgs) + return result + except Exception as e: + return None def invoke_astream(self, config: ConfigMeta, @@ -44,6 +64,9 @@ def invoke_astream(self, *history_models, question_model = history msgs = [to_message(m) for m in history_models] llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + for plugin in config.plugins: + print('plugin_agent', plugin) + # b = [{'id': x['key']} for plugin in config.plugins] tools = [] prompt = ChatPromptTemplate.from_messages([ system_msg, @@ -144,14 +167,16 @@ class LangchainAgentProvider(AgentProvider): def can_handle(self, config: ConfigMeta) -> int: '''Priority of provider for handling agent config''' if config.model.key == 'gpt-4': - return 10 + return 30 if config.model.key == 'gpt-4o': - return 10 + return 30 + if config.model.key == 'gpt-3.5-turbo': + return 30 return -1 def provide(self, config: ConfigMeta, chat_id: int) -> Agent: """get suitable agent""" - if len(config.tools) > 0: + if len(config.plugins) > 0: return ToolCallingAgent(meta=config) return LLMChat(meta=config) diff --git a/api/core/utils/__init__.py b/api/core/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/core/utils/ssrf_proxy.py b/api/core/utils/ssrf_proxy.py new file mode 100644 index 0000000..019b27f --- /dev/null +++ b/api/core/utils/ssrf_proxy.py @@ -0,0 +1,48 @@ +""" +Proxy requests to avoid SSRF +""" +import os + +import httpx + +SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') +SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') +SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') + +proxies = { + 'http://': SSRF_PROXY_HTTP_URL, + 'https://': SSRF_PROXY_HTTPS_URL +} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None + + +def make_request(method, url, **kwargs): + if SSRF_PROXY_ALL_URL: + return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) + elif proxies: + return httpx.request(method=method, url=url, proxies=proxies, **kwargs) + else: + return httpx.request(method=method, url=url, **kwargs) + + +def get(url, **kwargs): + return make_request('GET', url, **kwargs) + + +def post(url, **kwargs): + return make_request('POST', url, **kwargs) + + +def put(url, **kwargs): + return make_request('PUT', url, **kwargs) + + +def patch(url, **kwargs): + return make_request('PATCH', url, **kwargs) + + +def delete(url, **kwargs): + return make_request('DELETE', url, **kwargs) + + +def head(url, **kwargs): + return make_request('HEAD', url, **kwargs) diff --git a/api/dao/plugin.py b/api/dao/plugin.py index 331f372..cd20932 100644 --- a/api/dao/plugin.py +++ b/api/dao/plugin.py @@ -1,10 +1,11 @@ from datetime import datetime from typing import Any, List from sqlalchemy.orm import Session -from sqlalchemy import func +from sqlalchemy import desc, func from models.plugin_config import PluginConfigCreate, PluginConfigORM, PluginConfigUpdate from models.plugin import PluginCreate, PluginORM, PluginUpdate from models.plugin_api import PluginApiCreate, PluginApiORM, PluginApiUpdate +from constants.constant import defaultPluginOpenapiDesc class PluginHelper: @@ -20,19 +21,10 @@ def get(session: Session, plugin_id: int) -> PluginORM | None: @staticmethod def get_user_plugin(session: Session, user_id: int) -> List[PluginORM]: return session.query(PluginORM).filter(PluginORM.created_by == user_id).order_by(PluginORM.updated_at.desc()).all() - # return paginate( - # session.query(PluginORM) - # .filter(PluginORM.created_by == user_id) - # .order_by(PluginORM.updated_at.desc()) - # ) @staticmethod def get_all_plugin(session: Session) -> List[PluginORM]: return session.query(PluginORM).order_by(PluginORM.updated_at.desc()).all() - # return paginate( - # session.query(PluginORM) - # .order_by(PluginORM.updated_at.desc()) - # ) @staticmethod def update(session: Session, operator: int, plugin_model: PluginUpdate) -> int: @@ -76,11 +68,18 @@ def get_plugin_draft(session: Session, plugin_id: int) -> PluginConfigORM | None PluginConfigORM.is_draft == True ).one_or_none() + @staticmethod + def get_latest_publish_config(session: Session, plugin_id: int) -> PluginConfigORM | None: + return session.query(PluginConfigORM).filter( + PluginConfigORM.plugin_id == plugin_id, + PluginConfigORM.is_draft == False + ).order_by(desc(PluginConfigORM.updated_at)).first() + @staticmethod def get_or_create_plugin_draft(session: Session, operator: int, plugin_id: int,) -> PluginConfigORM: exist = PluginConfigHelper.get_plugin_draft(session, plugin_id) if exist is None: - return PluginConfigHelper.create(session, operator, PluginConfigCreate(plugin_id=plugin_id, plugin_openapi_desc='', is_draft=True)) + return PluginConfigHelper.create(session, operator, PluginConfigCreate(plugin_id=plugin_id, plugin_openapi_desc=defaultPluginOpenapiDesc, is_draft=True)) else: return exist @@ -127,14 +126,13 @@ def count(session: Session) -> int: def get(session: Session, plugin_api_id: int) -> PluginApiORM | None: return session.query(PluginApiORM).filter(PluginApiORM.id == plugin_api_id).one_or_none() + @staticmethod + def get_by_config_id(session: Session, plugin_config_id: int) -> List[PluginApiORM]: + return session.query(PluginApiORM).filter(PluginApiORM.plugin_config_id == plugin_config_id).all() + @staticmethod def get_user_all(session: Session, user_id: int) -> List[PluginApiORM]: return session.query(PluginApiORM).filter(PluginApiORM.created_by == user_id).order_by(PluginApiORM.updated_at.desc()).all() - # paginate( - # session.query(PluginApiORM) - # .filter(PluginApiORM.created_by == user_id) - # .order_by(PluginApiORM.updated_at.desc()) - # ) @staticmethod def get_all(session: Session) -> List[PluginApiORM]: diff --git a/api/db.py b/api/db.py index cc66afe..469015e 100644 --- a/api/db.py +++ b/api/db.py @@ -1,6 +1,9 @@ +from contextlib import contextmanager +from typing import Generator from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker,Session + from config import settings SQLALCHEMY_DATABASE_URL = str(settings.SQLALCHEMY_DATABASE_URI) @@ -10,6 +13,14 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() +# 上下文管理器用于手动管理会话 +@contextmanager +def get_session() -> Generator[Session, None, None]: + session = SessionLocal() + try: + yield session + finally: + session.close() def get_db(): try: diff --git a/api/main.py b/api/main.py index 78b4ac8..7369066 100644 --- a/api/main.py +++ b/api/main.py @@ -1,9 +1,8 @@ from fastapi import FastAPI from fastapi_pagination import add_pagination -from routers.main import api_router - from dao.account import init_db +from routers.main import api_router from db import SessionLocal diff --git a/api/models/plugin_api.py b/api/models/plugin_api.py index 13afd83..670fb19 100644 --- a/api/models/plugin_api.py +++ b/api/models/plugin_api.py @@ -1,8 +1,9 @@ -from typing import Optional -from pydantic import BaseModel +from typing import List, Optional +from pydantic import BaseModel, Json from datetime import datetime from sqlalchemy import ( + JSON, Column, ForeignKey, Integer, @@ -22,10 +23,15 @@ class PluginApiORM(Base): __tablename__ = "plugin_api" id = Column(Integer, primary_key=True) plugin_config_id = Column(Integer, ForeignKey( - "plugin_config.id"), nullable=True) + "plugin_config.id")) description = Column(Text) # API 描述 name = Column(String(255)) # 插件名字 - openapi_desc = Column(Text) # openapi 先调研 + server_url = Column(String(255)) + method = Column(String(255)) + summary = Column(String(255)) + parameters = Column(JSON) + openapi_desc = Column(JSON) # openapi 先调研 + operation_id = Column(String(255)) # request_params = Column(JSON) # 入参 # response_params = Column(JSON) # 出参 @@ -50,9 +56,14 @@ class PluginApiCreate(BaseModel): plugin api create ''' plugin_config_id: int - openapi_desc: str + openapi_desc: dict name: Optional[str] description: Optional[str] + server_url: str + method: str + summary: Optional[str] + parameters: List[dict] + operation_id: str disabled: Optional[bool] created_by: int created_at: datetime @@ -68,6 +79,11 @@ class PluginApiUpdate(BaseModel): description: Optional[str] name: Optional[str] openapi_desc: Optional[str] + server_url: Optional[str] + method: Optional[str] + summary: Optional[str] + parameters: Optional[Json] + operation_id: Optional[str] disabled: Optional[bool] diff --git a/api/models/plugin_config.py b/api/models/plugin_config.py index 9bcc821..d74a5b9 100644 --- a/api/models/plugin_config.py +++ b/api/models/plugin_config.py @@ -1,7 +1,8 @@ -from typing import Optional +from typing import List, Optional from datetime import datetime from pydantic import BaseModel from sqlalchemy import ( + JSON, Boolean, Column, Integer, @@ -23,6 +24,7 @@ class PluginConfigORM(Base): plugin_id = Column(Integer, ForeignKey("plugin.id"), nullable=True) is_draft = Column(Boolean, nullable=False) # 是否为草稿 plugin_openapi_desc = Column(Text) # openapi 先调研 + apis = Column(JSON) # openapi_desc = Column(Text) # openapi 先调研 # plugin_desc = Column(Text) created_by = Column(Integer) @@ -48,7 +50,9 @@ class PluginConfigUpdate(BaseModel): plugin config update ''' id: int + plugin_id: Optional[int] plugin_openapi_desc: Optional[str] + apis: Optional[List[int]] is_draft: Optional[bool] @@ -61,6 +65,7 @@ class PluginConfigModel(PluginConfigCreate): created_at: datetime updated_by: int updated_at: datetime + apis: Optional[List[int]] = None class Config: from_attributes = True diff --git a/api/pyproject.toml b/api/pyproject.toml index 1409189..0763d95 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "langchain-community>=0.0.38", "sse-starlette>=2.1.2", "agentuniverse>=0.0.11", + "langchain_openai", ] readme = "README.md" requires-python = ">=3.8.1,<3.12" diff --git a/api/routers/plugin/router.py b/api/routers/plugin/router.py index 85eb603..3b5c1e2 100644 --- a/api/routers/plugin/router.py +++ b/api/routers/plugin/router.py @@ -26,7 +26,7 @@ def get_plugins(user_id: int, session: Session = Depends(get_db)): return data -@router.get("/", response_model=Page[PluginModel]) +@router.get("/all/", response_model=Page[PluginModel]) def get_all_plugins(session: Session = Depends(get_db)): data = paginate(PluginService.get_all(session)) return data diff --git a/api/services/plugin.py b/api/services/plugin.py index ffb2c1b..9a6a299 100644 --- a/api/services/plugin.py +++ b/api/services/plugin.py @@ -1,9 +1,11 @@ +from datetime import datetime from typing import List from sqlalchemy.orm import Session -from fastapi import Depends +from yaml import safe_load +import re +import uuid from dao.plugin import PluginAPIHelper, PluginConfigHelper, PluginHelper -from db import get_db from models.plugin import PluginCreate, PluginModel, PluginUpdate from models.plugin_api import PluginApiCreate, PluginApiModel, PluginApiUpdate from models.plugin_config import PluginConfigCreate, PluginConfigModel, PluginConfigUpdate @@ -37,6 +39,7 @@ def get_by_id(plugin_id: int, session: Session) -> PluginModel | None: @staticmethod def get_all(session: Session) -> List[PluginModel]: plugin_orms = PluginHelper.get_all_plugin(session) + print('plugin_orms', plugin_orms) return [PluginModel.model_validate(plugin_orm) for plugin_orm in plugin_orms] @staticmethod @@ -48,7 +51,7 @@ def get_user_plugin(user_id: int, session: Session) -> List[PluginModel]: class PluginConfigService: @staticmethod - def get_by_id(config_id: int, session: Session ) -> PluginConfigModel | None: + def get_by_id(config_id: int, session: Session) -> PluginConfigModel | None: plugin_config_orm = PluginConfigHelper.get(session, config_id) if plugin_config_orm is None: return None @@ -64,6 +67,15 @@ def get_plugin_draft(plugin_id: int, session: Session) -> PluginConfigModel | No else: return PluginConfigModel.model_validate(plugin_config_orm) + @staticmethod + def get_latest_publish_config(plugin_id: int, session: Session) -> PluginConfigModel | None: + plugin_config_orm = PluginConfigHelper.get_latest_publish_config( + session, plugin_id) + if plugin_config_orm is None: + return None + else: + return PluginConfigModel.model_validate(plugin_config_orm) + @staticmethod def get_or_create_plugin_draft(operator: int, plugin_id: int, session: Session) -> PluginConfigModel: exist = PluginConfigHelper.get_or_create_plugin_draft( @@ -78,6 +90,12 @@ def create(operator: int, config_model: PluginConfigCreate, session: Session) -> @staticmethod def update(operator: int, config_model: PluginConfigUpdate, session: Session) -> int: + if config_model.plugin_openapi_desc is not None: + plugin_api_bundles = PluginAPIService.parse_openapi_yaml_to_plugin_api_bundle( + config_model.id, config_model.plugin_openapi_desc, operator) + apis = [PluginAPIService.create( + operator, plugin_api, session).id for plugin_api in plugin_api_bundles] + config_model.apis = apis res = PluginConfigHelper.update( session, operator, config_model) return res @@ -85,12 +103,12 @@ def update(operator: int, config_model: PluginConfigUpdate, session: Session) -> class PluginAPIService: @staticmethod - async def count(session: Session ) -> int: + async def count(session: Session) -> int: cnt = PluginAPIHelper.count(session) return cnt @staticmethod - def get_by_id(plugin_api_id: int, session: Session ) -> PluginApiModel | None: + def get_by_id(plugin_api_id: int, session: Session) -> PluginApiModel | None: plugin_api_orm = PluginAPIHelper.get(session, plugin_api_id) if plugin_api_orm is None: return None @@ -98,18 +116,155 @@ def get_by_id(plugin_api_id: int, session: Session ) -> PluginApiModel | None: return PluginApiModel.model_validate(plugin_api_orm) @staticmethod - def get_all(session: Session ) -> List[PluginApiModel]: + def get_by_config_id(plugin_config_id: int, session: Session) -> List[PluginApiModel]: + plugin_api_orms = PluginAPIHelper.get_by_config_id( + session, plugin_config_id) + return [PluginApiModel.model_validate(plugin_api_orm) for plugin_api_orm in plugin_api_orms] + + @staticmethod + def get_all(session: Session) -> List[PluginApiModel]: plugin_api_orms = PluginAPIHelper.get_all(session) return [PluginApiModel.model_validate(plugin_api_orm) for plugin_api_orm in plugin_api_orms] @staticmethod - def update(operator: int, plugin_api_model: PluginApiUpdate, session: Session ) -> int: + def update(operator: int, plugin_api_model: PluginApiUpdate, session: Session) -> int: + res = PluginAPIHelper.update( session, operator, plugin_api_model) return res @staticmethod - def create(operator: int, plugin_api_model: PluginApiCreate, session: Session ) -> PluginApiModel: + def create(operator: int, plugin_api_model: PluginApiCreate, session: Session) -> PluginApiModel: plugin_api_orm = PluginAPIHelper.create( session, operator, plugin_api_model) return PluginApiModel.model_validate(plugin_api_orm) + + @staticmethod + def parse_openapi_yaml_to_plugin_api_bundle(plugin_config_id: int, yaml: str, operator: int, extra_info: dict | None = None) -> list[PluginApiCreate]: + """ + parse openapi yaml to tool bundle + + :param yaml: the yaml string + :return: the tool bundle + """ + extra_info = extra_info if extra_info is not None else {} + + openapi: dict = safe_load(yaml) + if openapi is None: + raise Exception('Invalid openapi yaml.') + + extra_info = extra_info if extra_info is not None else {} + now = datetime.now() + + # set description to extra_info + extra_info['description'] = openapi['info'].get('description', '') + + if len(openapi['servers']) == 0: + raise Exception('No server found in the openapi yaml.') + + server_url = str(openapi['servers'][0]['url']) + + # list all interfaces + interfaces = [] + for path, path_item in openapi['paths'].items(): + methods = ['get', 'post', 'put', 'delete', + 'patch', 'head', 'options', 'trace'] + for method in methods: + if method in path_item: + interfaces.append({ + 'path': path, + 'method': method, + 'operation': path_item[method], + }) + + # get all parameters + bundles = [] + for interface in interfaces: + # convert parameters + parameters = [] + if 'parameters' in interface['operation']: + for parameter in interface['operation']['parameters']: + parameters.append(parameter) + # create tool bundle + # check if there is a request body + if 'requestBody' in interface['operation']: + request_body = interface['operation']['requestBody'] + if 'content' in request_body: + for content_type, content in request_body['content'].items(): + # if there is a reference, get the reference and overwrite the content + if 'schema' not in content: + continue + + if '$ref' in content['schema']: + # get the reference + root = openapi + reference = content['schema']['$ref'].split( + '/')[1:] + for ref in reference: + root = root[ref] + # overwrite the content + interface['operation']['requestBody']['content'][content_type]['schema'] = root + + # parse body parameters + if 'schema' in interface['operation']['requestBody']['content'][content_type]: + body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] + required = body_schema.get('required', []) + properties = body_schema.get('properties', {}) + print('properties', properties) + # for name, property in properties.items(): + # tool = ToolParameter( + # name=name, + # label=I18nObject( + # en_US=name, + # zh_Hans=name + # ), + # human_description=I18nObject( + # en_US=property.get('description', ''), + # zh_Hans=property.get('description', '') + # ), + # type=ToolParameter.ToolParameterType.STRING, + # required=name in required, + # form=ToolParameter.ToolParameterForm.LLM, + # llm_description=property.get( + # 'description', ''), + # default=property.get('default', None), + # ) + + # check if there is a type + # typ = ApiBasedToolSchemaParser._get_tool_parameter_type( + # property) + # if typ: + # tool.type = typ + + # parameters.append(tool) + + if 'operationId' not in interface['operation']: + # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ + path = interface['path'] + if interface['path'].startswith('/'): + path = interface['path'][1:] + # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ + path = re.sub(r'[^a-zA-Z0-9_-]', '', path) + if not path: + path = str(uuid.uuid4()) + + interface['operation']['operationId'] = f'{path}_{interface["method"]}' + bundles.append(PluginApiCreate( + plugin_config_id=plugin_config_id, + openapi_desc=interface['operation'], + name=openapi['info']['title'], + description=openapi['info']['description'], + created_at=now, + created_by=operator, + updated_at=now, + updated_by=operator, + server_url=server_url + interface['path'], + method=interface['method'], + summary=interface['operation']['description'] if 'description' in interface['operation'] else interface['operation'].get( + 'summary', None), + operation_id=interface['operation']['operationId'], + parameters=parameters, + disabled=False, + )) + + return bundles diff --git a/requirements-dev.lock b/requirements-dev.lock index e70fbee..6eed4f6 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -203,7 +203,10 @@ langchain-core==0.1.52 # via langchain # via langchain-anthropic # via langchain-community + # via langchain-openai # via langchain-text-splitters +langchain-openai==0.1.5 + # via magent-api langchain-text-splitters==0.0.2 # via langchain langsmith==0.1.81 @@ -261,6 +264,7 @@ onnxruntime==1.18.1 # via chromadb openai==1.13.3 # via agentuniverse + # via langchain-openai opentelemetry-api==1.25.0 # via chromadb # via opentelemetry-exporter-otlp-proto-grpc @@ -467,6 +471,7 @@ tenacity==8.4.1 # via qianfan tiktoken==0.5.2 # via agentuniverse + # via langchain-openai tokenizers==0.19.1 # via anthropic # via chromadb diff --git a/requirements.lock b/requirements.lock index e70fbee..6eed4f6 100644 --- a/requirements.lock +++ b/requirements.lock @@ -203,7 +203,10 @@ langchain-core==0.1.52 # via langchain # via langchain-anthropic # via langchain-community + # via langchain-openai # via langchain-text-splitters +langchain-openai==0.1.5 + # via magent-api langchain-text-splitters==0.0.2 # via langchain langsmith==0.1.81 @@ -261,6 +264,7 @@ onnxruntime==1.18.1 # via chromadb openai==1.13.3 # via agentuniverse + # via langchain-openai opentelemetry-api==1.25.0 # via chromadb # via opentelemetry-exporter-otlp-proto-grpc @@ -467,6 +471,7 @@ tenacity==8.4.1 # via qianfan tiktoken==0.5.2 # via agentuniverse + # via langchain-openai tokenizers==0.19.1 # via anthropic # via chromadb diff --git a/web/app/src/constant/default.ts b/web/app/src/constant/default.ts index 994ad78..3699f38 100644 --- a/web/app/src/constant/default.ts +++ b/web/app/src/constant/default.ts @@ -3,7 +3,6 @@ export const defaultAgentBotMeta = { name: 'default_bot', avatar: 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/05/HONDA_ASIMO.jpg/440px-HONDA_ASIMO.jpg', - pluginId: 'xxx', }; export const defaultPluginMeta = { diff --git a/web/app/src/modules/chat/chat.ts b/web/app/src/modules/chat/chat.ts index 7b4f78d..d4d9901 100644 --- a/web/app/src/modules/chat/chat.ts +++ b/web/app/src/modules/chat/chat.ts @@ -149,7 +149,7 @@ export class Chat extends AsyncModel { }; protected doSendMessage = async (msg: ChatMessageCreate) => { - const url = `api/v1/chats/${this.id!}/messages`; + const url = `api/v1/chats/${this.id!}/messages/compress`; const res = await this.axios.post(url, msg); if (res.status === 200) { const model = res.data; @@ -174,7 +174,8 @@ export class Chat extends AsyncModel { }; protected doSendMessageStream = async (msg: ChatMessageCreate) => { - const url = `api/v1/chats/${this.id!}/messages`; + // const url = `api/v1/chats/${this.id!}/messages`; + const url = `api/v1/chats/${this.id!}/messages/compress`; const res = await this.axios.post>(url, msg, { headers: { Accept: 'text/event-stream', diff --git a/web/app/src/modules/chat/components/chat.tsx b/web/app/src/modules/chat/components/chat.tsx index 6007f78..86752dd 100644 --- a/web/app/src/modules/chat/components/chat.tsx +++ b/web/app/src/modules/chat/components/chat.tsx @@ -46,7 +46,8 @@ export function Chat(props: ChatProps) { icon={} onClick={() => chat.clear()} > - chat.sendMessageStream(v)} /> + {/* chat.sendMessageStream(v)} /> */} + chat.sendMessage(v)} />
内容由AI生成,无法确保真实准确,仅供参考。
diff --git a/web/app/src/modules/model/built-in-model-meta.ts b/web/app/src/modules/model/built-in-model-meta.ts index e499acf..fd96fc9 100644 --- a/web/app/src/modules/model/built-in-model-meta.ts +++ b/web/app/src/modules/model/built-in-model-meta.ts @@ -10,5 +10,10 @@ const gpt4o: ModelMeta = { name: 'gpt-4o', icon: 'https://static.intercomassets.com/avatars/7363200/square_128/AI-1715100858.png', }; +const gpt35: ModelMeta = { + key: 'gpt-3.5-turbo', + name: 'gpt-3.5-turbo', + icon: 'https://static.intercomassets.com/avatars/7363200/square_128/AI-1715100858.png', +}; -export const builtinModels = [gpt4, gpt4o]; +export const builtinModels = [gpt4, gpt4o, gpt35]; diff --git a/web/app/src/modules/plugin/plugin-config.ts b/web/app/src/modules/plugin/plugin-config.ts index d4df5f0..f6663df 100644 --- a/web/app/src/modules/plugin/plugin-config.ts +++ b/web/app/src/modules/plugin/plugin-config.ts @@ -14,6 +14,8 @@ export class PluginConfig extends AsyncModel { pluginId: number; + apis?: number[] = []; + isDraft = true; protected onChanagedEmitter = new Emitter(); @@ -58,6 +60,7 @@ export class PluginConfig extends AsyncModel { this.id = option.id; this.pluginId = option.plugin_id; this.isDraft = option.is_draft || true; + this.apis = option.apis; this._pluginOpenapiDesc = option.plugin_openapi_desc; super.fromMeta(option); } @@ -79,6 +82,7 @@ export class PluginConfig extends AsyncModel { plugin_id: this.pluginId, is_draft: this.isDraft, plugin_openapi_desc: this.pluginOpenapiDesc, + apis: this.apis, }; } diff --git a/web/app/src/modules/plugin/plugin.ts b/web/app/src/modules/plugin/plugin.ts index dbf4931..a972206 100644 --- a/web/app/src/modules/plugin/plugin.ts +++ b/web/app/src/modules/plugin/plugin.ts @@ -1,10 +1,12 @@ import { Deferred, inject, prop, transient } from '@difizen/mana-app'; + import { AsyncModel } from '../../common/async-model.js'; -import { PluginOption, PluginType } from './protocol.js'; -import type { PluginConfigOption } from './protocol.js'; import { AxiosClient } from '../axios-client/index.js'; -import type { PluginConfig } from './plugin-config.js'; + import { PluginConfigManager } from './plugin-config-manager.js'; +import type { PluginConfig } from './plugin-config.js'; +import { PluginOption, PluginType } from './protocol.js'; +import type { PluginConfigOption } from './protocol.js'; @transient() export class Plugin extends AsyncModel { diff --git a/web/app/src/modules/plugin/protocol.ts b/web/app/src/modules/plugin/protocol.ts index c35f405..95af1ba 100644 --- a/web/app/src/modules/plugin/protocol.ts +++ b/web/app/src/modules/plugin/protocol.ts @@ -17,6 +17,7 @@ export interface PluginConfigOption { plugin_id: number; plugin_openapi_desc: string; is_draft?: boolean; + apis?: number[]; } export type PluginConfigFactory = (options: PluginConfigOption) => PluginConfig; diff --git a/web/app/src/pages/bot/bot-provider.ts b/web/app/src/pages/bot/bot-provider.ts index 05287fc..22519b2 100644 --- a/web/app/src/pages/bot/bot-provider.ts +++ b/web/app/src/pages/bot/bot-provider.ts @@ -1,15 +1,16 @@ import { inject, prop, singleton } from '@difizen/mana-app'; import { DeferredModel } from '../../common/async-model.js'; -import { defaultAgentBotMeta } from '../../constant/default.js'; +import { defaultAgentBotMeta, defaultPluginMeta } from '../../constant/default.js'; import type { AgentBot } from '../../modules/agent-bot/index.js'; import { AgentBotManager } from '../../modules/agent-bot/index.js'; -import { PluginManager } from '../../modules/plugin/index.js'; +import { PluginConfigManager, PluginManager } from '../../modules/plugin/index.js'; @singleton() export class BotProvider extends DeferredModel { @inject(AgentBotManager) agentBotManager: AgentBotManager; @inject(PluginManager) pluginManager: PluginManager; + @inject(PluginConfigManager) pluginConfigManager: PluginConfigManager; @prop() _current?: AgentBot; @@ -34,12 +35,34 @@ export class BotProvider extends DeferredModel { if (!botId) { const page = await this.agentBotManager.getMyBots(); const meta = page.items[0]; - const plugins = await this.pluginManager.getPlugins(); + const pluginPage = await this.pluginManager.getPlugins(); + const pluginOption = pluginPage.items[0]; if (!meta) { this.current = await this.agentBotManager.createBot(defaultAgentBotMeta); } else { this.current = await this.agentBotManager.getBot(meta); } + // if (!pluginOption) { + // await this.pluginManager.createPlugin(defaultPluginMeta); + // } else { + // const plugin = await this.pluginManager.getPlugin(pluginOption); + // console.log('🚀 ~ BotProvider ~ init ~ plugin:', plugin); + // this.current.draftReady.then(() => { + // if (this.current.draft) { + // plugin.draftReady.then(() => { + // console.log('🚀 ~ plugin ~ init ~ draftReady:', plugin); + // // plugin.draft.save(); + // this.current.draft.plugins = [ + // { key: plugin.id.toString(), api: plugin.draft?.apis[0] }, + // ]; + // console.log( + // '🚀 ~ BotProvider ~ this.current.draftReady.then ~ this.current.draft.plugins:', + // this.current.draft.plugins, + // ); + // }); + // } + // }); + // } } else { this.current = await this.agentBotManager.getBot({ id: botId }); } diff --git a/web/app/src/pages/bot/bot-setting/view.tsx b/web/app/src/pages/bot/bot-setting/view.tsx index 7aaf614..1ce2764 100644 --- a/web/app/src/pages/bot/bot-setting/view.tsx +++ b/web/app/src/pages/bot/bot-setting/view.tsx @@ -1,14 +1,36 @@ import { BaseView, singleton, view } from '@difizen/mana-app'; import './index.less'; import type { CollapseProps } from 'antd'; -import { Collapse } from 'antd'; +import { Avatar, Collapse, List } from 'antd'; import { forwardRef } from 'react'; +const data = [ + { + title: 'Arxiv', + }, +]; + const items: CollapseProps['items'] = [ { key: '1', label: '插件', - children: 'test', + children: ( + ( + + + } + title={item.title} + description="Arxiv Plugin" + /> + + )} + /> + ), }, ];