diff --git a/api/alembic/versions/f3069f28699a_update_plugin_table.py b/api/alembic/versions/f3069f28699a_update_plugin_table.py new file mode 100644 index 0000000..a6bc9a9 --- /dev/null +++ b/api/alembic/versions/f3069f28699a_update_plugin_table.py @@ -0,0 +1,56 @@ +"""update plugin table + +Revision ID: f3069f28699a +Revises: ac21c38c5e56 +Create Date: 2024-07-11 04:20:39.013033 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'f3069f28699a' +down_revision: Union[str, None] = 'ac21c38c5e56' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('plugin', sa.Column( + 'name', sa.String(length=255), nullable=False)) + op.add_column('plugin', sa.Column( + 'avatar', sa.String(length=255), nullable=True)) + op.add_column('plugin', sa.Column('description', sa.Text(), nullable=True)) + op.add_column('plugin_api', sa.Column( + 'name', sa.String(length=255), nullable=True)) + op.add_column('plugin_config', sa.Column( + 'plugin_openapi_desc', sa.Text(), nullable=True)) + op.drop_column('plugin_config', 'avatar') + op.drop_column('plugin_config', 'description') + op.drop_column('plugin_config', 'name') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('plugin_config', sa.Column('name', sa.VARCHAR( + length=255), autoincrement=False, nullable=False)) + op.add_column('plugin_config', sa.Column( + 'description', sa.TEXT(), autoincrement=False, nullable=True)) + op.add_column('plugin_config', sa.Column('avatar', sa.VARCHAR( + length=255), autoincrement=False, nullable=True)) + op.drop_column('plugin_config', 'plugin_openapi_desc') + op.drop_column('plugin_api', 'name') + op.drop_column('plugin', 'description') + op.drop_column('plugin', 'avatar') + op.drop_column('plugin', 'name') + op.alter_column('messages', 'message_type', + existing_type=sa.Enum( + 'MARKDOWN', 'TEXT', name='messagetype'), + type_=sa.VARCHAR(length=50), + existing_nullable=False) + # ### end Alembic commands ### diff --git a/api/dao/__init__.py b/api/dao/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/routers/account/crud.py b/api/dao/account.py similarity index 100% rename from api/routers/account/crud.py rename to api/dao/account.py diff --git a/api/routers/agent/crud.py b/api/dao/agent.py similarity index 83% rename from api/routers/agent/crud.py rename to api/dao/agent.py index c815b2f..7eccff2 100644 --- a/api/routers/agent/crud.py +++ b/api/dao/agent.py @@ -1,11 +1,9 @@ from datetime import datetime -from typing import Any +from typing import Any, List from sqlalchemy.orm import Session from sqlalchemy import func -from models.agent_bot import AgentBotModel, AgentBotORM, AgentBotCreate, AgentBotUpdate +from models.agent_bot import AgentBotORM, AgentBotCreate, AgentBotUpdate from models.agent_config import AgentConfigCreate, AgentConfigORM, AgentConfigUpdate -from fastapi_pagination.ext.sqlalchemy import paginate -from fastapi_pagination import Page class AgentBotHelper: @@ -19,12 +17,17 @@ def get(session: Session, bot_id: int) -> AgentBotORM | None: return session.query(AgentBotORM).filter(AgentBotORM.id == bot_id).one_or_none() @staticmethod - def get_all(session: Session, user_id: int) -> Page[AgentBotModel]: - return paginate( - session.query(AgentBotORM) - .filter(AgentBotORM.created_by == user_id) - .order_by(AgentBotORM.updated_at.desc()) - ) + def get_all(session: Session) -> List[AgentBotORM]: + return session.query(AgentBotORM).order_by(AgentBotORM.updated_at.desc()).all() + + @staticmethod + def get_by_user(session: Session, user_id: int) -> List[AgentBotORM]: + return session.query(AgentBotORM).filter(AgentBotORM.created_by == user_id).order_by(AgentBotORM.updated_at.desc()).all() + # return paginate( + # session.query(AgentBotORM) + # .filter(AgentBotORM.created_by == user_id) + # .order_by(AgentBotORM.updated_at.desc()) + # ) @staticmethod def update(session: Session, operator: int, bot_model: AgentBotUpdate) -> int: diff --git a/api/routers/chat/crud.py b/api/dao/chat.py similarity index 85% rename from api/routers/chat/crud.py rename to api/dao/chat.py index 705d10b..8e954f1 100644 --- a/api/routers/chat/crud.py +++ b/api/dao/chat.py @@ -1,9 +1,9 @@ from typing import List from datetime import datetime from sqlalchemy.orm import Session -from models.agent_config import AgentConfigModel +from models.agent_config import AgentConfigModel, AgentConfigORM from models.chat import ChatModel, ChatORM, MessageModel, MessageModelCreate, MessageORM -from routers.agent.crud import AgentConfigHelper +from dao.agent import AgentConfigHelper from sqlalchemy import inspect @@ -21,12 +21,16 @@ def getattr_from_column_name(instance, name, default=Ellipsis): class ChatHelper: @staticmethod - def get_chat_bot_config(session: Session, operator: int, coversation_id: int) -> AgentConfigModel: - chat_orm = session.query(ChatORM).filter(ChatORM.id == - coversation_id).one_or_none() - chat_model = ChatModel.model_validate(chat_orm) - config_id = chat_model.bot_config_id - return AgentConfigHelper.get(session, config_id) + def get_chat(session: Session, chat_id: int) -> ChatORM: + chat_orm = session.query(ChatORM).filter( + ChatORM.id == chat_id).one_or_none() + return chat_orm + + @staticmethod + def get_chat_by_agent_config(session: Session, agent_config_id: int) -> ChatORM: + chat_orm = session.query(ChatORM).filter( + ChatORM.bot_config_id == agent_config_id).one_or_none() + return chat_orm @staticmethod def get_or_create_bot_chat(session: Session, operator: int, agent_config_id: int) -> ChatORM: diff --git a/api/routers/plugin/crud.py b/api/dao/plugin.py similarity index 76% rename from api/routers/plugin/crud.py rename to api/dao/plugin.py index 7413ce1..331f372 100644 --- a/api/routers/plugin/crud.py +++ b/api/dao/plugin.py @@ -1,12 +1,10 @@ from datetime import datetime -from typing import Any +from typing import Any, List from sqlalchemy.orm import Session from sqlalchemy import func from models.plugin_config import PluginConfigCreate, PluginConfigORM, PluginConfigUpdate -from models.plugin import PluginCreate, PluginORM, PluginUpdate, PluginModel -from models.plugin_api import PluginApiCreate, PluginApiORM, PluginApiUpdate, PluginApiModel -from fastapi_pagination.ext.sqlalchemy import paginate -from fastapi_pagination import Page +from models.plugin import PluginCreate, PluginORM, PluginUpdate +from models.plugin_api import PluginApiCreate, PluginApiORM, PluginApiUpdate class PluginHelper: @@ -20,12 +18,21 @@ def get(session: Session, plugin_id: int) -> PluginORM | None: return session.query(PluginORM).filter(PluginORM.id == plugin_id).one_or_none() @staticmethod - def get_all(session: Session, user_id: int) -> Page[PluginModel]: - return paginate( - session.query(PluginORM) - .filter(PluginORM.created_by == user_id) - .order_by(PluginORM.updated_at.desc()) - ) + def get_user_plugin(session: Session, user_id: int) -> List[PluginORM]: + return session.query(PluginORM).filter(PluginORM.created_by == user_id).order_by(PluginORM.updated_at.desc()).all() + # return paginate( + # session.query(PluginORM) + # .filter(PluginORM.created_by == user_id) + # .order_by(PluginORM.updated_at.desc()) + # ) + + @staticmethod + def get_all_plugin(session: Session) -> List[PluginORM]: + return session.query(PluginORM).order_by(PluginORM.updated_at.desc()).all() + # return paginate( + # session.query(PluginORM) + # .order_by(PluginORM.updated_at.desc()) + # ) @staticmethod def update(session: Session, operator: int, plugin_model: PluginUpdate) -> int: @@ -73,7 +80,7 @@ def get_plugin_draft(session: Session, plugin_id: int) -> PluginConfigORM | None def get_or_create_plugin_draft(session: Session, operator: int, plugin_id: int,) -> PluginConfigORM: exist = PluginConfigHelper.get_plugin_draft(session, plugin_id) if exist is None: - return PluginConfigHelper.create(session, operator, PluginConfigCreate(plugin_id=plugin_id, name='', avatar='', description='', is_draft=True)) + return PluginConfigHelper.create(session, operator, PluginConfigCreate(plugin_id=plugin_id, plugin_openapi_desc='', is_draft=True)) else: return exist @@ -121,12 +128,17 @@ def get(session: Session, plugin_api_id: int) -> PluginApiORM | None: return session.query(PluginApiORM).filter(PluginApiORM.id == plugin_api_id).one_or_none() @staticmethod - def get_all(session: Session, user_id: int) -> Page[PluginApiModel]: - return paginate( - session.query(PluginApiORM) - .filter(PluginApiORM.created_by == user_id) - .order_by(PluginApiORM.updated_at.desc()) - ) + def get_user_all(session: Session, user_id: int) -> List[PluginApiORM]: + return session.query(PluginApiORM).filter(PluginApiORM.created_by == user_id).order_by(PluginApiORM.updated_at.desc()).all() + # paginate( + # session.query(PluginApiORM) + # .filter(PluginApiORM.created_by == user_id) + # .order_by(PluginApiORM.updated_at.desc()) + # ) + + @staticmethod + def get_all(session: Session) -> List[PluginApiORM]: + return session.query(PluginApiORM).order_by(PluginApiORM.updated_at.desc()).all() @staticmethod def update(session: Session, operator: int, plugin_api_model: PluginApiUpdate) -> int: diff --git a/api/models/plugin.py b/api/models/plugin.py index 41be034..7ce72ef 100644 --- a/api/models/plugin.py +++ b/api/models/plugin.py @@ -6,6 +6,8 @@ Column, Integer, DateTime, + String, + Text, ) from db import Base @@ -18,6 +20,9 @@ class PluginORM(Base): ''' __tablename__ = "plugin" id = Column(Integer, primary_key=True) + name = Column(String(255), nullable=False) # 插件名字 + avatar = Column(String(255)) # 插件图标 + description = Column(Text) # 插件描述 plugin_type = Column(Integer, nullable=False) # 插件创建类型 created_by = Column(Integer) created_at = Column( @@ -34,6 +39,9 @@ class PluginCreate(BaseModel): plugin create ''' plugin_type: int + name: str + avatar: Optional[str] + description: Optional[str] class PluginUpdate(BaseModel): @@ -41,7 +49,10 @@ class PluginUpdate(BaseModel): plugin update ''' id: int - plugin_type: int + plugin_type: Optional[int] + name: Optional[str] + avatar: Optional[str] + description: Optional[str] class PluginModel(PluginCreate): @@ -49,7 +60,6 @@ class PluginModel(PluginCreate): plugin ''' id: int - plugin_type: int created_by: int created_at: datetime updated_by: int diff --git a/api/models/plugin_api.py b/api/models/plugin_api.py index 223579f..13afd83 100644 --- a/api/models/plugin_api.py +++ b/api/models/plugin_api.py @@ -7,6 +7,7 @@ ForeignKey, Integer, DateTime, + String, Text, Boolean ) @@ -23,7 +24,9 @@ class PluginApiORM(Base): plugin_config_id = Column(Integer, ForeignKey( "plugin_config.id"), nullable=True) description = Column(Text) # API 描述 + name = Column(String(255)) # 插件名字 openapi_desc = Column(Text) # openapi 先调研 + # request_params = Column(JSON) # 入参 # response_params = Column(JSON) # 出参 # debug_example = Column(JSON) # 调试示例 @@ -47,9 +50,10 @@ class PluginApiCreate(BaseModel): plugin api create ''' plugin_config_id: int - description: str openapi_desc: str - disabled: bool + name: Optional[str] + description: Optional[str] + disabled: Optional[bool] created_by: int created_at: datetime updated_by: int @@ -61,9 +65,10 @@ class PluginApiUpdate(BaseModel): plugin api update ''' id: int - description: str - openapi_desc: str - disabled: bool + description: Optional[str] + name: Optional[str] + openapi_desc: Optional[str] + disabled: Optional[bool] class PluginApiModel(PluginApiCreate): @@ -71,10 +76,6 @@ class PluginApiModel(PluginApiCreate): plugin api ''' id: int - plugin_config_id: int - description: str - openapi_desc: str - disabled: bool created_by: int created_at: datetime updated_by: int diff --git a/api/models/plugin_config.py b/api/models/plugin_config.py index b035a0f..9bcc821 100644 --- a/api/models/plugin_config.py +++ b/api/models/plugin_config.py @@ -7,7 +7,6 @@ Integer, DateTime, ForeignKey, - String, Text, ) @@ -20,12 +19,10 @@ class PluginConfigORM(Base): ''' __tablename__ = "plugin_config" id = Column(Integer, primary_key=True) - name = Column(String(255), nullable=False) # 插件名字 - avatar = Column(String(255)) # 插件图标 - description = Column(Text) # 插件描述 # meta_info = Column(JSON) # 插件元信息,包含 api url,service_token,oauth_info 等 plugin_id = Column(Integer, ForeignKey("plugin.id"), nullable=True) is_draft = Column(Boolean, nullable=False) # 是否为草稿 + plugin_openapi_desc = Column(Text) # openapi 先调研 # openapi_desc = Column(Text) # openapi 先调研 # plugin_desc = Column(Text) created_by = Column(Integer) @@ -42,10 +39,7 @@ class PluginConfigCreate(BaseModel): plugin config create ''' plugin_id: int - name: str - avatar: str - description: str - + plugin_openapi_desc: str is_draft: Optional[bool] @@ -54,10 +48,8 @@ class PluginConfigUpdate(BaseModel): plugin config update ''' id: int + plugin_openapi_desc: Optional[str] is_draft: Optional[bool] - name: str - avatar: str - description: str class PluginConfigModel(PluginConfigCreate): @@ -65,11 +57,6 @@ class PluginConfigModel(PluginConfigCreate): plugin config ''' id: int - name: str - avatar: str - description: str - - plugin_id: int created_by: int created_at: datetime updated_by: int diff --git a/api/routers/account/router.py b/api/routers/account/router.py index fe4eb44..c4aac43 100644 --- a/api/routers/account/router.py +++ b/api/routers/account/router.py @@ -1,32 +1,29 @@ -from fastapi import APIRouter, HTTPException, Depends -from sqlalchemy.orm import Session +from fastapi import APIRouter, HTTPException from models.account import AccountModel, AccountCreate -from db import get_db - -from .crud import AccountHelper +from services.account import AccountService router = APIRouter() account_router = router @router.post("/", response_model=AccountModel) -async def create_account(model: AccountCreate, session: Session = Depends(get_db), ): - model = await AccountHelper.create(session, model) - return AccountModel.model_validate(model) +async def create_account(model: AccountCreate) -> AccountModel: + account_model = await AccountService.create(model) + return account_model @router.get("/{user_id}", response_model=AccountModel) -async def get_account_by_id(user_id, db: Session = Depends(get_db)): - model = AccountHelper.get_by_id(db, user_id) +async def get_account_by_id(user_id): + model = AccountService.get_by_id(user_id) if model is None: raise HTTPException(404) - return AccountModel.model_validate(model) + return model @router.get("/email/{email}", response_model=AccountModel) -async def get_account_by_email(email, db: Session = Depends(get_db)): - model = AccountHelper.get_by_email(db, email) +async def get_account_by_email(email): + model = AccountService.get_by_email(email) if model is None: raise HTTPException(404) return AccountModel.model_validate(model) diff --git a/api/routers/agent/router.py b/api/routers/agent/router.py index cb1ed7c..b0169b2 100644 --- a/api/routers/agent/router.py +++ b/api/routers/agent/router.py @@ -1,15 +1,11 @@ -from fastapi import APIRouter, HTTPException, Depends -from sqlalchemy.orm import Session +from fastapi import APIRouter, HTTPException -from fastapi_pagination import Page +from fastapi_pagination import Page, paginate from models.agent_bot import AgentBotModel, AgentBotCreate, AgentBotUpdate from models.agent_config import AgentConfigModel, AgentConfigUpdate, AgentConfigCreate -from db import get_db - -from .crud import AgentBotHelper, AgentConfigHelper - +from services.agent import AgentConfigService, AgentService router = APIRouter() @@ -17,59 +13,58 @@ @router.post("/bots", response_model=AgentBotModel) -def create_agent_bot(user_id: int, bot: AgentBotCreate, session: Session = Depends(get_db)): - model = AgentBotHelper.create(session, user_id, bot) - return AgentBotModel.model_validate(model) +def create_agent_bot(user_id: int, bot: AgentBotCreate): + model = AgentService.create(user_id, bot) + return model @router.get("/bots", response_model=Page[AgentBotModel]) -def get_agent_bots(user_id: int, session: Session = Depends(get_db)): - data = AgentBotHelper.get_all(session, user_id) - return data +def get_agent_bots(user_id: int): + data = AgentService.get_by_user(user_id) + res = paginate(data) + return res @router.get("/bots/{bot_id}", response_model=AgentBotModel) -async def get_agent_bot(bot_id, user_id: int, with_draft=False, session: Session = Depends(get_db)): - model = AgentBotHelper.get(session, bot_id) +async def get_agent_bot(bot_id, user_id: int, with_draft=False): + model = AgentService.get_by_id(bot_id) if model is None: raise HTTPException(404) - bot = AgentBotModel.model_validate(model) if with_draft: - draft = AgentConfigHelper.get_or_create_bot_draft( - session, user_id, bot.id) - bot.draft = draft - return bot + draft = AgentConfigService.get_or_create_bot_draft(user_id, bot_id) + model.draft = draft + return model @router.get("/bots/{bot_id}/draft", response_model=AgentConfigModel) -async def get_or_create_agent_bot_draft_config(user_id: int, bot_id, session: Session = Depends(get_db)): - model = AgentConfigHelper.get_or_create_bot_draft(session, user_id, bot_id) +async def get_or_create_agent_bot_draft_config(user_id: int, bot_id): + model = AgentConfigService.get_or_create_bot_draft(user_id, bot_id) if model is None: raise HTTPException(404) - return AgentConfigModel.model_validate(model) + return model @router.put("/bots/{bot_id}") -async def update_agent_bot(user_id: int, bot: AgentBotUpdate, db: Session = Depends(get_db)): - success = AgentBotHelper.update(db, user_id, bot) +async def update_agent_bot(user_id: int, bot: AgentBotUpdate): + success = AgentService.update(user_id, bot) return success @router.get("/configs/{config_id}", response_model=AgentConfigModel) -async def get_agent_config(config_id, session: Session = Depends(get_db)): - model = AgentConfigHelper.get(session, config_id) +async def get_agent_config(config_id): + model = AgentConfigService.get_by_id(config_id) if model is None: raise HTTPException(404) - return AgentConfigModel.model_validate(model) + return model @router.put("/configs/{bot_id}") -async def update_agent_config(user_id: int, config: AgentConfigUpdate, db: Session = Depends(get_db)): - success = AgentConfigHelper.update(db, user_id, config) +async def update_agent_config(user_id: int, config: AgentConfigUpdate): + success = AgentConfigService.update(user_id, config) return success @router.post("/configs", response_model=AgentConfigModel) -async def create_agent_config(user_id: int, config: AgentConfigCreate, session: Session = Depends(get_db)): - model = AgentConfigHelper.create(session, user_id, config) - return AgentConfigModel.model_validate(model) +async def create_agent_config(user_id: int, config: AgentConfigCreate): + model = AgentConfigService.create(user_id, config) + return model diff --git a/api/routers/chat/router.py b/api/routers/chat/router.py index d947f26..6aa2294 100644 --- a/api/routers/chat/router.py +++ b/api/routers/chat/router.py @@ -9,8 +9,9 @@ from typing import List, AsyncIterable, Optional from sse_starlette import ServerSentEvent from sse_starlette.sse import EventSourceResponse +from services.agent import AgentConfigService +from services.chat import ChatService -from routers.agent.crud import AgentConfigHelper from models.chat import ( ChatModel, MessageModel, MessageModelCreate, MessageSenderType ) @@ -19,8 +20,6 @@ from core.langchain_utils import get_message_str, message_content_to_str from core.chat import chat, chat_stream -from .crud import ChatHelper - logger = logging.getLogger('uvicorn.error') logger.setLevel(logging.DEBUG) @@ -41,13 +40,11 @@ def chat_with_bot_in_debug_mode(bot_id, user_id: int, session: Session = Depends ''' get or create chat ''' - config_orm = AgentConfigHelper.get_bot_draft(session, bot_id) - if config_orm is None: + config = AgentConfigService.get_bot_draft(bot_id, session) + if config is None: raise HTTPException(404) - config = AgentConfigModel.model_validate(config_orm) - model = ChatHelper.get_or_create_bot_chat( - session, user_id, config.id) - return ChatModel.model_validate(model) + model = ChatService.get_or_create_bot_chat(user_id, config.id, session) + return model @router.delete("/{chat_id}/messages", response_model=int) @@ -57,14 +54,14 @@ async def clear_messages_in_chat(chat_id: int, chat_id = chat_id sender_id = user_id - config_orm = ChatHelper.get_chat_bot_config( - session, sender_id, chat_id) - config = AgentConfigModel.model_validate(config_orm) + config = ChatService.get_chat_bot_config( + sender_id, chat_id, session) # TODO: delete or soft delete - if config.is_draft: - return ChatHelper.clear_messages(session, chat_id) - else: - return ChatHelper.clear_messages(session, chat_id) + if config is not None: + if config.is_draft: + return ChatService.clear_messages(chat_id, session) + else: + return ChatService.clear_messages(chat_id, session) class ChatReply(BaseModel): @@ -82,21 +79,10 @@ async def send_message_in_chat(chat_id: int, ''' msg.sender_id = user_id msg.chat_id = chat_id - msg_orm = ChatHelper.insert_message( - session, msg) - msg_model = MessageModel.model_validate(msg_orm) - + msg_model = ChatService.insert_message(msg, session) reply_msg = ChatReply(send=msg_model) - - config_orm = ChatHelper.get_chat_bot_config( - session, user_id, chat_id) - config = AgentConfigModel.model_validate(config_orm) - - history_orm_list = ChatHelper.get_messages( - session, chat_id=chat_id) - - history = [MessageModel.model_validate( - orm_obj) for orm_obj in history_orm_list] + config = ChatService.get_chat_bot_config(user_id, chat_id, session) + history = ChatService.get_messages(chat_id, session) answer_msg = chat(config, chat_id, history, msg_model) if answer_msg is not None: new_msg = MessageModelCreate(sender_id=0, @@ -104,9 +90,7 @@ async def send_message_in_chat(chat_id: int, chat_turn_id=msg_model.chat_turn_id, chat_id=chat_id, content=get_message_str(answer_msg)) - new_msg_orm = ChatHelper.insert_message( - session, new_msg) - new_msg_model = MessageModel.model_validate(new_msg_orm) + new_msg_model = ChatService.insert_message(new_msg, session) reply_msg.reply = [new_msg_model] return reply_msg @@ -130,13 +114,8 @@ async def send_message(session: Session, message: MessageModel) -> AsyncIterable # user send new message yield ServerSentEvent(event=SSEType.MESSAGE.value, id=str(message.chat_turn_id), data=message.model_dump_json()) - config_orm = ChatHelper.get_chat_bot_config( - session, sender_id, chat_id) - config = AgentConfigModel.model_validate(config_orm) - history_orm_list = ChatHelper.get_messages( - session, chat_id=chat_id) - history = [MessageModel.model_validate( - orm_obj) for orm_obj in history_orm_list] + config = ChatService.get_chat_bot_config(sender_id, chat_id, session) + history = ChatService.get_messages(chat_id, session) answer_msg = None msg_iterator = chat_stream(config, chat_id, history, message) @@ -152,9 +131,7 @@ async def send_message(session: Session, message: MessageModel) -> AsyncIterable chat_turn_id=chat_turn_id, chat_id=chat_id, content=chunk_content) - new_msg_orm = ChatHelper.insert_message( - session, new_msg) - answer_msg = MessageModel.model_validate(new_msg_orm) + answer_msg = ChatService.insert_message(new_msg, session) # ai send new message answer_msg.complete = False yield ServerSentEvent(event=SSEType.MESSAGE.value, id=str(answer_msg.chat_turn_id), data=answer_msg.model_dump_json()) @@ -167,7 +144,7 @@ async def send_message(session: Session, message: MessageModel) -> AsyncIterable ) yield ServerSentEvent(event=SSEType.CHUNK.value, id=str(answer_msg.chat_turn_id), data=new_chunk.model_dump_json()) if answer_msg is not None: - ChatHelper.update_message_content(session, answer_msg, content) + ChatService.update_message_content(answer_msg, content, session) @router.post("/{chat_id}/messages") @@ -181,8 +158,6 @@ async def send_message_stream_in_chat(chat_id: int, msg.sender_id = user_id msg.chat_id = chat_id - msg_orm = ChatHelper.insert_message( - session, msg) - msg_model = MessageModel.model_validate(msg_orm) + msg_model = ChatService.insert_message(msg, session) return EventSourceResponse(send_message(session, msg_model), media_type="text/event-stream") diff --git a/api/routers/plugin/router.py b/api/routers/plugin/router.py index 556e924..23eed79 100644 --- a/api/routers/plugin/router.py +++ b/api/routers/plugin/router.py @@ -1,13 +1,10 @@ -from fastapi import APIRouter, HTTPException, Depends -from sqlalchemy.orm import Session +from fastapi import APIRouter, HTTPException -from fastapi_pagination import Page +from fastapi_pagination import Page, paginate from models.plugin import PluginModel, PluginCreate, PluginUpdate from models.plugin_api import PluginApiCreate, PluginApiModel, PluginApiUpdate from models.plugin_config import PluginConfigModel, PluginConfigCreate, PluginConfigUpdate - -from db import get_db -from .crud import PluginHelper, PluginConfigHelper, PluginAPIHelper +from services.plugin import PluginAPIService, PluginConfigService, PluginService router = APIRouter() @@ -15,80 +12,85 @@ @router.post("/", response_model=PluginModel) -def create_plugin(user_id: int, plugin: PluginCreate, session: Session = Depends(get_db)): - model = PluginHelper.create(session, user_id, plugin) - return PluginModel.model_validate(model) +def create_plugin(user_id: int, plugin: PluginCreate): + plugin_model = PluginService.create(user_id, plugin) + return plugin_model @router.get("/", response_model=Page[PluginModel]) -def get_plugins(user_id: int, session: Session = Depends(get_db)): - data = PluginHelper.get_all(session, user_id) +def get_plugins(user_id: int): + data = paginate(PluginService.get_user_plugin(user_id)) + return data + + +@router.get("/", response_model=Page[PluginModel]) +def get_all_plugins(): + data = paginate(PluginService.get_all()) return data @router.get("/{plugin_id}", response_model=PluginModel) -async def get_plugin(plugin_id, user_id: int, with_draft=False, session: Session = Depends(get_db)): - model = PluginHelper.get(session, plugin_id) - if model is None: +async def get_plugin(plugin_id, user_id: int, with_draft=False): + plugin_model = PluginService.get_by_id(plugin_id) + if plugin_model is None: raise HTTPException(404) - plugin_model = PluginModel.model_validate(model) if with_draft: - draft = PluginConfigHelper.get_or_create_plugin_draft( - session, user_id, plugin_model.id) + draft = PluginConfigService.get_or_create_plugin_draft( + user_id, plugin_model.id) plugin_model.draft = draft return plugin_model @router.get("/{plugin_id}/draft", response_model=PluginConfigModel) -async def get_or_create_plugin_draft_config(user_id: int, plugin_id, session: Session = Depends(get_db)): - model = PluginConfigHelper.get_or_create_plugin_draft( - session, user_id, plugin_id) +async def get_or_create_plugin_draft_config(user_id: int, plugin_id): + model = PluginConfigService.get_or_create_plugin_draft( + user_id, plugin_id) if model is None: raise HTTPException(404) - return PluginConfigModel.model_validate(model) + return model @router.put("/{plugin_id}") -async def update_plugin(user_id: int, plugin: PluginUpdate, db: Session = Depends(get_db)): - success = PluginHelper.update(db, user_id, plugin) +async def update_plugin(user_id: int, plugin: PluginUpdate): + success = PluginService.update(user_id, plugin) return success @router.get("/configs/{config_id}", response_model=PluginConfigModel) -async def get_plugin_config(config_id, session: Session = Depends(get_db)): - model = PluginConfigHelper.get(session, config_id) +async def get_plugin_config(config_id): + model = PluginConfigService.get_by_id(config_id) if model is None: raise HTTPException(404) - return PluginConfigModel.model_validate(model) + return model @router.put("/configs/{config_id}") -async def update_plugin_config(user_id: int, config: PluginConfigUpdate, db: Session = Depends(get_db)): - success = PluginConfigHelper.update(db, user_id, config) +async def update_plugin_config(user_id: int, config: PluginConfigUpdate): + success = PluginConfigService.update(user_id, config) return success @router.post("/configs", response_model=PluginConfigModel) -async def create_plugin_config(user_id: int, config: PluginConfigCreate, session: Session = Depends(get_db)): - model = PluginConfigHelper.create(session, user_id, config) - return PluginConfigModel.model_validate(model) +async def create_plugin_config(user_id: int, config: PluginConfigCreate): + model = PluginConfigService.create(user_id, config) + return model @router.get("/api/{api_id}", response_model=PluginApiModel) -async def get_plugin_api(api_id, session: Session = Depends(get_db)): - model = PluginAPIHelper.get(session, api_id) +async def get_plugin_api(api_id): + model = PluginAPIService.get_by_id(api_id) if model is None: raise HTTPException(404) - return PluginApiModel.model_validate(model) + return model @router.put("/api/{api_id}") -async def update_plugin_api(user_id: int, api: PluginApiUpdate, db: Session = Depends(get_db)): - success = PluginAPIHelper.update(db, user_id, api) +async def update_plugin_api(user_id: int, api: PluginApiUpdate): + success = PluginAPIService.update(user_id, api) return success @router.post("/api", response_model=PluginApiModel) -async def create_plugin_api(user_id: int, api: PluginApiCreate, session: Session = Depends(get_db)): - model = PluginAPIHelper.create(session, user_id, api) - return PluginApiModel.model_validate(model) +async def create_plugin_api(user_id: int, api: PluginApiCreate): + model = PluginAPIService.create(user_id, api) + return model diff --git a/api/services/__init__.py b/api/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/services/account.py b/api/services/account.py new file mode 100644 index 0000000..67d48cf --- /dev/null +++ b/api/services/account.py @@ -0,0 +1,30 @@ +from sqlalchemy.orm import Session +from fastapi import Depends +from dao.account import AccountHelper + +from db import get_db +from models.account import AccountCreate, AccountModel + + +class AccountService: + + @staticmethod + async def create(account: AccountCreate, session: Session = Depends(get_db)) -> AccountModel: + account_orm = await AccountHelper.create(session, account) + return AccountModel.model_validate(account_orm) + + @staticmethod + def get_by_id(user_id: int, session: Session = Depends(get_db)) -> AccountModel | None: + account_orm = AccountHelper.get_by_id(session, user_id) + if account_orm is None: + return None + else: + return AccountModel.model_validate(account_orm) + + @staticmethod + def get_by_email(email: str, session: Session = Depends(get_db)) -> AccountModel | None: + account_orm = AccountHelper.get_by_email(session, email) + if account_orm is None: + return None + else: + return AccountModel.model_validate(account_orm) diff --git a/api/services/agent.py b/api/services/agent.py new file mode 100644 index 0000000..4565ae6 --- /dev/null +++ b/api/services/agent.py @@ -0,0 +1,83 @@ +from typing import List +from sqlalchemy.orm import Session +from fastapi import Depends + +from dao.agent import AgentBotHelper, AgentConfigHelper +from db import get_db +from models.agent_bot import AgentBotCreate, AgentBotModel, AgentBotUpdate +from models.agent_config import AgentConfigCreate, AgentConfigModel, AgentConfigUpdate + + +class AgentService: + + @staticmethod + async def count_agent(session: Session = Depends(get_db)) -> int: + cnt = AgentBotHelper.count(session) + return cnt + + @staticmethod + def create(operator: int, bot_model: AgentBotCreate, session: Session = Depends(get_db)) -> AgentBotModel: + agent_orm = AgentBotHelper.create(session, operator, bot_model) + return AgentBotModel.model_validate(agent_orm) + + @staticmethod + def update(operator: int, bot_model: AgentBotUpdate, session: Session = Depends(get_db)) -> int: + res = AgentBotHelper.update(session, operator, bot_model) + return res + + @staticmethod + def get_by_id(bot_id: int, session: Session = Depends(get_db)) -> AgentBotModel | None: + agent_orm = AgentBotHelper.get(session, bot_id) + if agent_orm is None: + return None + else: + return AgentBotModel.model_validate(agent_orm) + + @staticmethod + def get_all(session: Session = Depends(get_db)) -> List[AgentBotModel]: + agent_orms = AgentBotHelper.get_all(session) + return [AgentBotModel.model_validate(agent_orm) for agent_orm in agent_orms] + + @staticmethod + def get_by_user(user_id: int, session: Session = Depends(get_db)) -> List[AgentBotModel]: + agent_orms = AgentBotHelper.get_by_user(session, user_id) + return [AgentBotModel.model_validate(agent_orm) for agent_orm in agent_orms] + + +class AgentConfigService: + + @staticmethod + def get_by_id(config_id: int, session: Session = Depends(get_db)) -> AgentConfigModel | None: + agent_config_orm = AgentConfigHelper.get(session, config_id) + if agent_config_orm is None: + return None + else: + return AgentConfigModel.model_validate(agent_config_orm) + + @staticmethod + def get_bot_draft(bot_id: int, session: Session = Depends(get_db)) -> AgentConfigModel | None: + agent_config_orm = AgentConfigHelper.get_bot_draft(session, bot_id) + if agent_config_orm is None: + return None + else: + return AgentConfigModel.model_validate(agent_config_orm) + + @staticmethod + def get_or_create_bot_draft(operator: int, bot_id: int, session: Session = Depends(get_db)) -> AgentConfigModel: + exist = AgentConfigHelper.get_bot_draft(session, bot_id) + if exist is None: + return AgentConfigHelper.create(session, operator, AgentConfigCreate(bot_id=bot_id, is_draft=True, config=None)) + else: + return exist + + @staticmethod + def create(operator: int, config_model: AgentConfigCreate, session: Session = Depends(get_db)) -> AgentConfigModel: + agent_config_orm = AgentConfigHelper.create( + session, operator, config_model) + return AgentConfigModel.model_validate(agent_config_orm) + + @staticmethod + def update(operator: int, bot_model: AgentConfigUpdate, session: Session = Depends(get_db)) -> int: + res = AgentConfigHelper.update( + session, operator, bot_model) + return res diff --git a/api/services/chat.py b/api/services/chat.py new file mode 100644 index 0000000..ff68f30 --- /dev/null +++ b/api/services/chat.py @@ -0,0 +1,71 @@ +from typing import List +from sqlalchemy.orm import Session +from fastapi import Depends + +from dao.agent import AgentConfigHelper +from dao.chat import ChatHelper +from db import get_db +from models.agent_config import AgentConfigModel +from models.chat import ChatModel, MessageModel, MessageModelCreate + + +class ChatService: + + @staticmethod + def get_chat_bot_config(operator: int, chat_id: int, session: Session = Depends(get_db)) -> AgentConfigModel: + chat_orm = ChatHelper.get_chat(session, chat_id) + chat_model = ChatModel.model_validate(chat_orm) + config_id = chat_model.bot_config_id + agent_config_orm = AgentConfigHelper.get(session, config_id) + return AgentConfigModel.model_validate(agent_config_orm) + + # @staticmethod + # def get_bot_draft(bot_id: int, session: Session = Depends(get_db)) -> AgentConfigModel | None: + # agent_config_orm = AgentConfigHelper.get_bot_draft(session, bot_id) + # if agent_config_orm is None: + # return None + # else: + # return AgentConfigModel.model_validate(agent_config_orm) + + @staticmethod + def get_or_create_bot_chat(operator: int, agent_config_id: int, session: Session = Depends(get_db)) -> ChatModel: + chat_orm = ChatHelper.get_or_create_bot_chat( + session, operator, agent_config_id) + chat_model = ChatModel.model_validate(chat_orm) + return chat_model + + @staticmethod + def create(operator: int, agent_bot_id: int, agent_config_id: int, session: Session = Depends(get_db)) -> ChatModel: + chat_orm = ChatHelper.create( + session, operator, agent_bot_id, agent_config_id) + return ChatModel.model_validate(chat_orm) + + @staticmethod + def insert_message(message: MessageModelCreate, session: Session = Depends(get_db)) -> MessageModel: + msg_orm = ChatHelper.insert_message(session, message) + return MessageModel.model_validate(msg_orm) + + @staticmethod + def update_message_content(message: MessageModel, content: str, session: Session = Depends(get_db)) -> int: + flag = ChatHelper.update_message_content(session, message, content) + return flag + + @staticmethod + def delete_message(message_id: int, session: Session = Depends(get_db)) -> int: + flag = ChatHelper.delete_message(session, message_id) + return flag + + @staticmethod + def get_message(message_id: int, session: Session = Depends(get_db)) -> MessageModel: + msg_orm = ChatHelper.get_message(session, message_id) + return MessageModel.model_validate(msg_orm) + + @staticmethod + def get_messages(chat_id: int, session: Session = Depends(get_db)) -> List[MessageModel]: + msg_orm_list = ChatHelper.get_messages(session, chat_id) + return [MessageModel.model_validate(msg_orm) for msg_orm in msg_orm_list] + + @staticmethod + def clear_messages(chat_id: int, session: Session = Depends(get_db)) -> int: + flag = ChatHelper.clear_messages(session, chat_id) + return flag diff --git a/api/services/plugin.py b/api/services/plugin.py new file mode 100644 index 0000000..4eaf6f3 --- /dev/null +++ b/api/services/plugin.py @@ -0,0 +1,115 @@ +from typing import List +from sqlalchemy.orm import Session +from fastapi import Depends + +from dao.plugin import PluginAPIHelper, PluginConfigHelper, PluginHelper +from db import get_db +from models.plugin import PluginCreate, PluginModel, PluginUpdate +from models.plugin_api import PluginApiCreate, PluginApiModel, PluginApiUpdate +from models.plugin_config import PluginConfigCreate, PluginConfigModel, PluginConfigUpdate + + +class PluginService: + + @staticmethod + async def count(session: Session = Depends(get_db)) -> int: + cnt = PluginHelper.count(session) + return cnt + + @staticmethod + def create(operator: int, plugin_model: PluginCreate, session: Session = Depends(get_db)) -> PluginModel: + plugin_orm = PluginHelper.create(session, operator, plugin_model) + return PluginModel.model_validate(plugin_orm) + + @staticmethod + def update(operator: int, plugin_model: PluginUpdate, session: Session = Depends(get_db)) -> int: + res = PluginHelper.update(session, operator, plugin_model) + return res + + @staticmethod + def get_by_id(plugin_id: int, session: Session = Depends(get_db)) -> PluginModel | None: + plugin_orm = PluginHelper.get(session, plugin_id) + if plugin_orm is None: + return None + else: + return PluginModel.model_validate(plugin_orm) + + @staticmethod + def get_all(session: Session = Depends(get_db)) -> List[PluginModel]: + plugin_orms = PluginHelper.get_all_plugin(session) + return [PluginModel.model_validate(plugin_orm) for plugin_orm in plugin_orms] + + @staticmethod + def get_user_plugin(user_id: int, session: Session = Depends(get_db)) -> List[PluginModel]: + plugin_orms = PluginHelper.get_user_plugin(session, user_id) + return [PluginModel.model_validate(plugin_orm) for plugin_orm in plugin_orms] + + +class PluginConfigService: + + @staticmethod + def get_by_id(config_id: int, session: Session = Depends(get_db)) -> PluginConfigModel | None: + plugin_config_orm = PluginConfigHelper.get(session, config_id) + if plugin_config_orm is None: + return None + else: + return PluginConfigModel.model_validate(plugin_config_orm) + + @staticmethod + def get_plugin_draft(plugin_id: int, session: Session = Depends(get_db)) -> PluginConfigModel | None: + plugin_config_orm = PluginConfigHelper.get_plugin_draft( + session, plugin_id) + if plugin_config_orm is None: + return None + else: + return PluginConfigModel.model_validate(plugin_config_orm) + + @staticmethod + def get_or_create_plugin_draft(operator: int, plugin_id: int, session: Session = Depends(get_db)) -> PluginConfigModel: + exist = PluginConfigHelper.get_or_create_plugin_draft( + session, operator, plugin_id) + return exist + + @staticmethod + def create(operator: int, config_model: PluginConfigCreate, session: Session = Depends(get_db)) -> PluginConfigModel: + plugin_config_orm = PluginConfigHelper.create( + session, operator, config_model) + return PluginConfigModel.model_validate(plugin_config_orm) + + @staticmethod + def update(operator: int, config_model: PluginConfigUpdate, session: Session = Depends(get_db)) -> int: + res = PluginConfigHelper.update( + session, operator, config_model) + return res + + +class PluginAPIService: + @staticmethod + async def count(session: Session = Depends(get_db)) -> int: + cnt = PluginAPIHelper.count(session) + return cnt + + @staticmethod + def get_by_id(plugin_api_id: int, session: Session = Depends(get_db)) -> PluginApiModel | None: + plugin_api_orm = PluginAPIHelper.get(session, plugin_api_id) + if plugin_api_orm is None: + return None + else: + return PluginApiModel.model_validate(plugin_api_orm) + + @staticmethod + def get_all(session: Session = Depends(get_db)) -> List[PluginApiModel]: + plugin_api_orms = PluginAPIHelper.get_all(session) + return [PluginApiModel.model_validate(plugin_api_orm) for plugin_api_orm in plugin_api_orms] + + @staticmethod + def update(operator: int, plugin_api_model: PluginApiUpdate, session: Session = Depends(get_db)) -> int: + res = PluginAPIHelper.update( + session, operator, plugin_api_model) + return res + + @staticmethod + def create(operator: int, plugin_api_model: PluginApiCreate, session: Session = Depends(get_db)) -> PluginApiModel: + plugin_api_orm = PluginAPIHelper.create( + session, operator, plugin_api_model) + return PluginApiModel.model_validate(plugin_api_orm) diff --git a/web/app/src/constant/default.ts b/web/app/src/constant/default.ts index 4929dc9..994ad78 100644 --- a/web/app/src/constant/default.ts +++ b/web/app/src/constant/default.ts @@ -3,6 +3,15 @@ export const defaultAgentBotMeta = { name: 'default_bot', avatar: 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/05/HONDA_ASIMO.jpg/440px-HONDA_ASIMO.jpg', + pluginId: 'xxx', +}; + +export const defaultPluginMeta = { + plugin_type: 0, + name: '查找论文', + avatar: + 'https://mdn.alipayobjects.com/huamei_zabatk/afts/img/A*IR_ISbcKmdcAAAAAAAAAAAAADvyTAQ/original', + description: '查找论文', }; export const defaultSuperUserEmal = 'default@magent.com'; diff --git a/web/app/src/modules/agent-bot/agent-config.ts b/web/app/src/modules/agent-bot/agent-config.ts index 7395f74..79ad425 100644 --- a/web/app/src/modules/agent-bot/agent-config.ts +++ b/web/app/src/modules/agent-bot/agent-config.ts @@ -4,7 +4,11 @@ import debounce from 'lodash.debounce'; import { AsyncModel } from '../../common/async-model.js'; import { AxiosClient } from '../axios-client/index.js'; -import type { AgentConfigInfo, AgentConfigModelMeta } from './protocol.js'; +import type { + AgentConfigInfo, + AgentConfigModelMeta, + AgentConfigPluginMeta, +} from './protocol.js'; import { AgentConfigOption, AgentConfigType } from './protocol.js'; @transient() @@ -43,6 +47,16 @@ export class AgentConfig extends AsyncModel { this.changed(); } + @prop() + protected _plugins?: AgentConfigPluginMeta[]; + get plugins() { + return this._plugins; + } + set plugins(v: AgentConfigPluginMeta[] | undefined) { + this._plugins = v; + this.changed(); + } + @prop() config?: AgentConfigInfo; @@ -70,6 +84,7 @@ export class AgentConfig extends AsyncModel { this.is_draft = option.is_draft || true; this._persona = option.config?.persona; this._model = option.config?.model; + this._plugins = option.config?.plugins; // TODO: tools & datasets super.fromMeta(option); } @@ -94,6 +109,7 @@ export class AgentConfig extends AsyncModel { ...(this.config || {}), persona: this.persona, model: this.model, + plugins: this.plugins, }, }; } diff --git a/web/app/src/modules/agent-bot/protocol.ts b/web/app/src/modules/agent-bot/protocol.ts index 40eb7b3..5e36735 100644 --- a/web/app/src/modules/agent-bot/protocol.ts +++ b/web/app/src/modules/agent-bot/protocol.ts @@ -7,6 +7,10 @@ export interface AgentConfigToolMeta { key: string; [key: string]: any; } +export interface AgentConfigPluginMeta { + key: string; + [key: string]: any; +} export interface AgentConfigDatasetMeta { key: string; [key: string]: any; @@ -18,6 +22,7 @@ export interface AgentConfigModelMeta { export interface AgentConfigInfo { persona: string; tools: AgentConfigToolMeta[]; + plugins: AgentConfigPluginMeta[]; datasets: AgentConfigDatasetMeta[]; model: AgentConfigModelMeta; [key: string]: any; diff --git a/web/app/src/modules/agent-module.ts b/web/app/src/modules/agent-module.ts index 9e53c9f..3131855 100644 --- a/web/app/src/modules/agent-module.ts +++ b/web/app/src/modules/agent-module.ts @@ -4,10 +4,12 @@ import { AgentBotModule } from './agent-bot/module.js'; import { AxiosClientModule } from './axios-client/module.js'; import { ChatModule } from './chat/module.js'; import { ModelModule } from './model/module.js'; +import { PluginModule } from './plugin/module.js'; export const AgentModule = ManaModule.create().dependOn( ModelModule, AgentBotModule, AxiosClientModule, ChatModule, + PluginModule, ); diff --git a/web/app/src/modules/plugin/index.ts b/web/app/src/modules/plugin/index.ts new file mode 100644 index 0000000..cf3d721 --- /dev/null +++ b/web/app/src/modules/plugin/index.ts @@ -0,0 +1,6 @@ +export * from './protocol.js'; +export * from './plugin.js'; +export * from './plugin-manager.js'; +export * from './plugin-config.js'; +export * from './plugin-config-manager.js'; +export * from './module.js'; diff --git a/web/app/src/modules/plugin/module.ts b/web/app/src/modules/plugin/module.ts new file mode 100644 index 0000000..864ed51 --- /dev/null +++ b/web/app/src/modules/plugin/module.ts @@ -0,0 +1,39 @@ +import { ManaModule } from '@difizen/mana-app'; + +import { Plugin } from './plugin.js'; +import { + PluginConfigFactory, + PluginConfigOption, + PluginFactory, + PluginOption, +} from './protocol.js'; +import { PluginManager } from './plugin-manager.js'; +import { PluginConfig } from './plugin-config.js'; +import { PluginConfigManager } from './plugin-config-manager.js'; + +export const PluginModule = ManaModule.create().register( + PluginManager, + Plugin, + PluginConfig, + PluginConfigManager, + { + token: PluginConfigFactory, + useFactory: (ctx) => { + return (option: PluginConfigOption) => { + const child = ctx.container.createChild(); + child.register({ token: PluginConfigOption, useValue: option }); + return child.get(PluginConfig); + }; + }, + }, + { + token: PluginFactory, + useFactory: (ctx) => { + return (option: any) => { + const child = ctx.container.createChild(); + child.register({ token: PluginOption, useValue: option }); + return child.get(Plugin); + }; + }, + }, +); diff --git a/web/app/src/modules/plugin/plugin-config-manager.ts b/web/app/src/modules/plugin/plugin-config-manager.ts new file mode 100644 index 0000000..310c237 --- /dev/null +++ b/web/app/src/modules/plugin/plugin-config-manager.ts @@ -0,0 +1,20 @@ +import { inject, singleton } from '@difizen/mana-app'; + +import { AxiosClient } from '../axios-client/index.js'; +import { UserManager } from '../user/index.js'; +import { PluginConfigFactory, type PluginConfigOption } from './protocol.js'; +import type { PluginConfig } from './plugin-config.js'; + +@singleton() +export class PluginConfigManager { + @inject(UserManager) userManager: UserManager; + @inject(AxiosClient) axios: AxiosClient; + @inject(PluginConfigFactory) configFactory: PluginConfigFactory; + getDraft = async (option: PluginConfigOption): Promise => { + return this.configFactory(option); + }; + + create = (option: PluginConfigOption): PluginConfig => { + return this.configFactory(option); + }; +} diff --git a/web/app/src/modules/plugin/plugin-config.ts b/web/app/src/modules/plugin/plugin-config.ts new file mode 100644 index 0000000..d4df5f0 --- /dev/null +++ b/web/app/src/modules/plugin/plugin-config.ts @@ -0,0 +1,100 @@ +import { Emitter, inject, prop, transient } from '@difizen/mana-app'; +import debounce from 'lodash.debounce'; + +import { AsyncModel } from '../../common/async-model.js'; +import { AxiosClient } from '../axios-client/index.js'; + +import { PluginConfigOption, PluginConfigType } from './protocol.js'; +import {} from './protocol.js'; + +@transient() +export class PluginConfig extends AsyncModel { + protected axios: AxiosClient; + id: number; + + pluginId: number; + + isDraft = true; + + protected onChanagedEmitter = new Emitter(); + + get onChanged() { + return this.onChanagedEmitter.event; + } + + @prop() + protected _pluginOpenapiDesc: string; + + get pluginOpenapiDesc() { + return this._pluginOpenapiDesc; + } + set pluginOpenapiDesc(v: string) { + this._pluginOpenapiDesc = v; + this.changed(); + } + + // @prop() + // config?: PluginConfigInfo; + + option?: PluginConfigOption; + + constructor( + @inject(PluginConfigOption) option: PluginConfigOption, + @inject(AxiosClient) axios: AxiosClient, + ) { + super(); + this.option = option; + this.id = option.id; + this.axios = axios; + + this.initialize(option); + } + + override shouldInitFromMeta(option: PluginConfigOption): boolean { + return PluginConfigType.isFullOption(option); + } + + protected override fromMeta(option: PluginConfigOption): void { + this.id = option.id; + this.pluginId = option.plugin_id; + this.isDraft = option.is_draft || true; + this._pluginOpenapiDesc = option.plugin_openapi_desc; + super.fromMeta(option); + } + + override async fetchInfo(option: PluginConfigOption): Promise { + const res = await this.axios.get( + `api/v1/plugins/configs/${option.id}`, + ); + if (res.status === 200) { + if (this.shouldInitFromMeta(res.data)) { + this.fromMeta(res.data); + } + } + } + + toMeta(): PluginConfigOption { + return { + id: this.id, + plugin_id: this.pluginId, + is_draft: this.isDraft, + plugin_openapi_desc: this.pluginOpenapiDesc, + }; + } + + protected deferSave: (meta?: PluginConfigOption) => Promise | undefined = + debounce(this.save.bind(this), 500); + + protected changed = () => { + this.onChanagedEmitter.fire(this); + this.deferSave(); + }; + + async save(meta = this.toMeta()): Promise { + const res = await this.axios.put(`api/v1/plugins/configs/${this.id}`, meta); + if (res.status === 200) { + return res.data === 1; + } + return false; + } +} diff --git a/web/app/src/modules/plugin/plugin-manager.ts b/web/app/src/modules/plugin/plugin-manager.ts new file mode 100644 index 0000000..44e4ca3 --- /dev/null +++ b/web/app/src/modules/plugin/plugin-manager.ts @@ -0,0 +1,76 @@ +import { inject, singleton } from '@difizen/mana-app'; +import qs from 'query-string'; + +import { AxiosClient } from '../axios-client/index.js'; +import { UserManager } from '../user/index.js'; + +import { + PluginFactory, + type Plugin, + type PluginOption, + type PluginMeta, +} from './protocol.js'; + +@singleton() +export class PluginManager { + protected cache: Map = new Map(); + @inject(PluginFactory) pluginFactory: PluginFactory; + @inject(UserManager) userManager: UserManager; + @inject(AxiosClient) axios: AxiosClient; + + getMyPlugins = async (): Promise> => { + const user = await this.userManager.currentReady; + const defaultValue = { items: [], total: 0, page: 0, size: 0, pages: 0 }; + if (!user) { + return defaultValue; + } + const query = qs.stringify({ + page: 1, + size: 10, + user_id: user.id, + }); + const res = await this.axios.get(`api/v1/plugins?${query}`); + if (res.status === 200) { + return res.data; + } + return defaultValue; + }; + + getPlugins = async (): Promise> => { + const defaultValue = { items: [], total: 0, page: 0, size: 0, pages: 0 }; + const query = qs.stringify({ + page: 1, + size: 10, + }); + const res = await this.axios.get(`api/v1/plugins?${query}`); + if (res.status === 200) { + return res.data; + } + return defaultValue; + }; + + createPlugin = async (meta: PluginMeta): Promise => { + const user = await this.userManager.currentReady; + if (!user) { + throw new Error('cannot get user info'); + } + const query = qs.stringify({ + user_id: user.id, + }); + const res = await this.axios.post(`api/v1/plugins?${query}`, meta); + if (res.status !== 200) { + throw new Error('failed to create agent bot'); + } + return this.pluginFactory(res.data); + }; + + getPlugin = (option: PluginOption): Plugin => { + const exist = this.cache.get(option.id); + if (exist) { + return exist; + } + const plugin = this.pluginFactory(option); + this.cache.set(plugin.id, plugin); + return plugin; + }; +} diff --git a/web/app/src/modules/plugin/plugin.ts b/web/app/src/modules/plugin/plugin.ts new file mode 100644 index 0000000..dbf4931 --- /dev/null +++ b/web/app/src/modules/plugin/plugin.ts @@ -0,0 +1,121 @@ +import { Deferred, inject, prop, transient } from '@difizen/mana-app'; +import { AsyncModel } from '../../common/async-model.js'; +import { PluginOption, PluginType } from './protocol.js'; +import type { PluginConfigOption } from './protocol.js'; +import { AxiosClient } from '../axios-client/index.js'; +import type { PluginConfig } from './plugin-config.js'; +import { PluginConfigManager } from './plugin-config-manager.js'; + +@transient() +export class Plugin extends AsyncModel { + axios: AxiosClient; + configManager: PluginConfigManager; + + id: number; + + @prop() + name: string; + @prop() + pluginType: number; + @prop() + avatar?: string; + @prop() + description?: string; + + @prop() + draft?: PluginConfig; + + protected draftDeferred = new Deferred(); + + get draftReady() { + return this.draftDeferred.promise; + } + + option: PluginOption; + + constructor( + @inject(PluginOption) option: PluginOption, + @inject(PluginConfigManager) configManager: PluginConfigManager, + @inject(AxiosClient) axios: AxiosClient, + ) { + super(); + this.option = option; + this.configManager = configManager; + this.axios = axios; + + this.id = option.id; + this.initialize(option); + this.ensureDraft(option); + } + + shouldInitFromMeta(option: PluginOption = this.option): boolean { + return PluginType.isFullOption(option); + } + + async ensureDraft( + option: PluginOption = this.option, + ): Promise { + await this.ready; + if (this.draft) { + return this.draft; + } + let draftConfig = option.draft; + if (!draftConfig) { + draftConfig = await this.fetchDraftInfo(option); + } + if (draftConfig) { + this.draft = this.configManager.create(draftConfig); + this.draftDeferred.resolve(this.draft); + } + return this.draft; + } + + protected override fromMeta(option: PluginOption = this.option) { + this.id = option.id; + this.pluginType = option.plugin_type!; + this.name = option.name!; + this.avatar = option.avatar; + this.description = option.description; + super.fromMeta(option); + } + + async fetchInfo(option: PluginOption = this.option) { + const res = await this.axios.get(`api/v1/plugins/${option.id}`); + if (res.status === 200) { + if (this.shouldInitFromMeta(res.data)) { + this.fromMeta(res.data); + } + } + } + + async fetchDraftInfo(option: PluginOption = this.option) { + const res = await this.axios.get( + `api/v1/plugins/${option.id}/draft`, + ); + if (res.status === 200) { + return res.data; + } + return undefined; + } + + toMeta(): PluginOption { + return { + id: this.id, + plugin_type: this.pluginType, + description: this.description, + name: this.name, + avatar: this.avatar, + }; + } + + async save(): Promise { + const res = await this.axios.put( + `api/v1/plugins/${this.id}`, + this.toMeta(), + ); + if (res.status === 200) { + return res.data === 1; + } + return false; + } +} diff --git a/web/app/src/modules/plugin/protocol.ts b/web/app/src/modules/plugin/protocol.ts new file mode 100644 index 0000000..c35f405 --- /dev/null +++ b/web/app/src/modules/plugin/protocol.ts @@ -0,0 +1,74 @@ +import { Syringe } from '@difizen/mana-app'; + +import type { PluginConfig } from './plugin-config.js'; +import type { Plugin } from './plugin.js'; + +// export interface PluginConfigInfo { +// pluginOpenapiDesc: string; +// [key: string]: any; +// } + +export const PluginConfigOption = Syringe.defineToken('PluginConfigOption', { + multiple: false, +}); + +export interface PluginConfigOption { + id: number; + plugin_id: number; + plugin_openapi_desc: string; + is_draft?: boolean; +} + +export type PluginConfigFactory = (options: PluginConfigOption) => PluginConfig; +export const PluginConfigFactory = Syringe.defineToken('PluginConfigFactory', { + multiple: false, +}); + +export const PluginConfigType = { + isOption(data?: Record): data is PluginConfigOption { + return !!(data && 'id' in data); + }, + isFullOption(data?: Record): boolean { + return ( + PluginConfigType.isOption(data) && + 'plugin_id' in data && + 'plugin_openapi_desc' in data + ); + }, +}; + +export type { Plugin } from './plugin.js'; + +export interface PluginMeta { + name: string; + plugin_type: number; + avatar?: string; + description?: string; +} + +export interface PluginOption extends Partial { + id: number; + draft?: PluginConfigOption | null; +} + +export const PluginOption = Syringe.defineToken('PluginOption', { + multiple: false, +}); + +export type PluginFactory = (options: PluginOption) => Plugin; +export const PluginFactory = Syringe.defineToken('PluginFactory', { + multiple: false, +}); + +export const PluginType = { + isOption(data?: Record): data is PluginOption { + return !!(data && 'id' in data); + }, + isFullOption(data?: Record): boolean { + return PluginType.isOption(data) && 'name' in data; + }, +}; + +export const PluginInstance = Syringe.defineToken('PluginInstance', { + multiple: false, +}); diff --git a/web/app/src/pages/bot/bot-provider.ts b/web/app/src/pages/bot/bot-provider.ts index ba27a3f..05287fc 100644 --- a/web/app/src/pages/bot/bot-provider.ts +++ b/web/app/src/pages/bot/bot-provider.ts @@ -4,10 +4,12 @@ import { DeferredModel } from '../../common/async-model.js'; import { defaultAgentBotMeta } from '../../constant/default.js'; import type { AgentBot } from '../../modules/agent-bot/index.js'; import { AgentBotManager } from '../../modules/agent-bot/index.js'; +import { PluginManager } from '../../modules/plugin/index.js'; @singleton() export class BotProvider extends DeferredModel { @inject(AgentBotManager) agentBotManager: AgentBotManager; + @inject(PluginManager) pluginManager: PluginManager; @prop() _current?: AgentBot; @@ -32,6 +34,7 @@ export class BotProvider extends DeferredModel { if (!botId) { const page = await this.agentBotManager.getMyBots(); const meta = page.items[0]; + const plugins = await this.pluginManager.getPlugins(); if (!meta) { this.current = await this.agentBotManager.createBot(defaultAgentBotMeta); } else { diff --git a/web/app/src/pages/bot/bot-setting/index.less b/web/app/src/pages/bot/bot-setting/index.less new file mode 100644 index 0000000..b1399f5 --- /dev/null +++ b/web/app/src/pages/bot/bot-setting/index.less @@ -0,0 +1,3 @@ +.bot-skill { + display: block; +} diff --git a/web/app/src/pages/bot/bot-setting/index.ts b/web/app/src/pages/bot/bot-setting/index.ts new file mode 100644 index 0000000..8c69392 --- /dev/null +++ b/web/app/src/pages/bot/bot-setting/index.ts @@ -0,0 +1 @@ +export * from './view.js'; diff --git a/web/app/src/pages/bot/bot-setting/view.tsx b/web/app/src/pages/bot/bot-setting/view.tsx new file mode 100644 index 0000000..7aaf614 --- /dev/null +++ b/web/app/src/pages/bot/bot-setting/view.tsx @@ -0,0 +1,30 @@ +import { BaseView, singleton, view } from '@difizen/mana-app'; +import './index.less'; +import type { CollapseProps } from 'antd'; +import { Collapse } from 'antd'; +import { forwardRef } from 'react'; + +const items: CollapseProps['items'] = [ + { + key: '1', + label: '插件', + children: 'test', + }, +]; + +const BotSettingComponent = forwardRef( + function BotSkillComponent(props, ref) { + return ( +
+
技能
+ +
+ ); + }, +); + +@singleton() +@view('bot-setting') +export class BotSettingView extends BaseView { + override view = BotSettingComponent; +} diff --git a/web/app/src/pages/bot/module.ts b/web/app/src/pages/bot/module.ts index 763cba9..8606912 100644 --- a/web/app/src/pages/bot/module.ts +++ b/web/app/src/pages/bot/module.ts @@ -15,6 +15,7 @@ import { BotLayoutView, BotLayoutSlots } from './bot-layout/layout.js'; import { BotPersonaView } from './bot-persona/view.js'; import { BotPreviewerView } from './bot-previewer/index.js'; import { BotProvider } from './bot-provider.js'; +import { BotSettingView } from './bot-setting/view.js'; import { MagentBrandView } from './brand/view.js'; import { ModelSelectorView } from './model-selector/view.js'; @@ -29,6 +30,7 @@ export const BotModule = ManaModule.create().register( UserAvatarView, ModelSelectorView, BotPersonaView, + BotSettingView, createSlotPreference({ slot: RootSlotId, view: MagentBaseLayoutView, @@ -69,4 +71,8 @@ export const BotModule = ManaModule.create().register( slot: AgentBotConfigSlots.contentLeft, view: BotPersonaView, }), + createSlotPreference({ + slot: AgentBotConfigSlots.contentRight, + view: BotSettingView, + }), );