Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(api): add services and dao&feat: add plugin module #3

Merged
merged 5 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions api/alembic/versions/f3069f28699a_update_plugin_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""update plugin table

Revision ID: f3069f28699a
Revises: ac21c38c5e56
Create Date: 2024-07-11 04:20:39.013033

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'f3069f28699a'
down_revision: Union[str, None] = 'ac21c38c5e56'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('plugin', sa.Column(
'name', sa.String(length=255), nullable=False))
op.add_column('plugin', sa.Column(
'avatar', sa.String(length=255), nullable=True))
op.add_column('plugin', sa.Column('description', sa.Text(), nullable=True))
op.add_column('plugin_api', sa.Column(
'name', sa.String(length=255), nullable=True))
op.add_column('plugin_config', sa.Column(
'plugin_openapi_desc', sa.Text(), nullable=True))
op.drop_column('plugin_config', 'avatar')
op.drop_column('plugin_config', 'description')
op.drop_column('plugin_config', 'name')
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('plugin_config', sa.Column('name', sa.VARCHAR(
length=255), autoincrement=False, nullable=False))
op.add_column('plugin_config', sa.Column(
'description', sa.TEXT(), autoincrement=False, nullable=True))
op.add_column('plugin_config', sa.Column('avatar', sa.VARCHAR(
length=255), autoincrement=False, nullable=True))
op.drop_column('plugin_config', 'plugin_openapi_desc')
op.drop_column('plugin_api', 'name')
op.drop_column('plugin', 'description')
op.drop_column('plugin', 'avatar')
op.drop_column('plugin', 'name')
op.alter_column('messages', 'message_type',
existing_type=sa.Enum(
'MARKDOWN', 'TEXT', name='messagetype'),
type_=sa.VARCHAR(length=50),
existing_nullable=False)
# ### end Alembic commands ###
Empty file added api/dao/__init__.py
Empty file.
File renamed without changes.
23 changes: 13 additions & 10 deletions api/routers/agent/crud.py → api/dao/agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from datetime import datetime
from typing import Any
from typing import Any, List
from sqlalchemy.orm import Session
from sqlalchemy import func
from models.agent_bot import AgentBotModel, AgentBotORM, AgentBotCreate, AgentBotUpdate
from models.agent_bot import AgentBotORM, AgentBotCreate, AgentBotUpdate
from models.agent_config import AgentConfigCreate, AgentConfigORM, AgentConfigUpdate
from fastapi_pagination.ext.sqlalchemy import paginate
from fastapi_pagination import Page


class AgentBotHelper:
Expand All @@ -19,12 +17,17 @@ def get(session: Session, bot_id: int) -> AgentBotORM | None:
return session.query(AgentBotORM).filter(AgentBotORM.id == bot_id).one_or_none()

@staticmethod
def get_all(session: Session, user_id: int) -> Page[AgentBotModel]:
return paginate(
session.query(AgentBotORM)
.filter(AgentBotORM.created_by == user_id)
.order_by(AgentBotORM.updated_at.desc())
)
def get_all(session: Session) -> List[AgentBotORM]:
return session.query(AgentBotORM).order_by(AgentBotORM.updated_at.desc()).all()

@staticmethod
def get_by_user(session: Session, user_id: int) -> List[AgentBotORM]:
return session.query(AgentBotORM).filter(AgentBotORM.created_by == user_id).order_by(AgentBotORM.updated_at.desc()).all()
# return paginate(
# session.query(AgentBotORM)
# .filter(AgentBotORM.created_by == user_id)
# .order_by(AgentBotORM.updated_at.desc())
# )

@staticmethod
def update(session: Session, operator: int, bot_model: AgentBotUpdate) -> int:
Expand Down
20 changes: 12 additions & 8 deletions api/routers/chat/crud.py → api/dao/chat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List
from datetime import datetime
from sqlalchemy.orm import Session
from models.agent_config import AgentConfigModel
from models.agent_config import AgentConfigModel, AgentConfigORM
from models.chat import ChatModel, ChatORM, MessageModel, MessageModelCreate, MessageORM
from routers.agent.crud import AgentConfigHelper
from dao.agent import AgentConfigHelper

from sqlalchemy import inspect

Expand All @@ -21,12 +21,16 @@ def getattr_from_column_name(instance, name, default=Ellipsis):

class ChatHelper:
@staticmethod
def get_chat_bot_config(session: Session, operator: int, coversation_id: int) -> AgentConfigModel:
chat_orm = session.query(ChatORM).filter(ChatORM.id ==
coversation_id).one_or_none()
chat_model = ChatModel.model_validate(chat_orm)
config_id = chat_model.bot_config_id
return AgentConfigHelper.get(session, config_id)
def get_chat(session: Session, chat_id: int) -> ChatORM:
chat_orm = session.query(ChatORM).filter(
ChatORM.id == chat_id).one_or_none()
return chat_orm

@staticmethod
def get_chat_by_agent_config(session: Session, agent_config_id: int) -> ChatORM:
chat_orm = session.query(ChatORM).filter(
ChatORM.bot_config_id == agent_config_id).one_or_none()
return chat_orm

@staticmethod
def get_or_create_bot_chat(session: Session, operator: int, agent_config_id: int) -> ChatORM:
Expand Down
48 changes: 30 additions & 18 deletions api/routers/plugin/crud.py → api/dao/plugin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from datetime import datetime
from typing import Any
from typing import Any, List
from sqlalchemy.orm import Session
from sqlalchemy import func
from models.plugin_config import PluginConfigCreate, PluginConfigORM, PluginConfigUpdate
from models.plugin import PluginCreate, PluginORM, PluginUpdate, PluginModel
from models.plugin_api import PluginApiCreate, PluginApiORM, PluginApiUpdate, PluginApiModel
from fastapi_pagination.ext.sqlalchemy import paginate
from fastapi_pagination import Page
from models.plugin import PluginCreate, PluginORM, PluginUpdate
from models.plugin_api import PluginApiCreate, PluginApiORM, PluginApiUpdate


class PluginHelper:
Expand All @@ -20,12 +18,21 @@ def get(session: Session, plugin_id: int) -> PluginORM | None:
return session.query(PluginORM).filter(PluginORM.id == plugin_id).one_or_none()

@staticmethod
def get_all(session: Session, user_id: int) -> Page[PluginModel]:
return paginate(
session.query(PluginORM)
.filter(PluginORM.created_by == user_id)
.order_by(PluginORM.updated_at.desc())
)
def get_user_plugin(session: Session, user_id: int) -> List[PluginORM]:
return session.query(PluginORM).filter(PluginORM.created_by == user_id).order_by(PluginORM.updated_at.desc()).all()
# return paginate(
# session.query(PluginORM)
# .filter(PluginORM.created_by == user_id)
# .order_by(PluginORM.updated_at.desc())
# )

@staticmethod
def get_all_plugin(session: Session) -> List[PluginORM]:
return session.query(PluginORM).order_by(PluginORM.updated_at.desc()).all()
# return paginate(
# session.query(PluginORM)
# .order_by(PluginORM.updated_at.desc())
# )

@staticmethod
def update(session: Session, operator: int, plugin_model: PluginUpdate) -> int:
Expand Down Expand Up @@ -73,7 +80,7 @@ def get_plugin_draft(session: Session, plugin_id: int) -> PluginConfigORM | None
def get_or_create_plugin_draft(session: Session, operator: int, plugin_id: int,) -> PluginConfigORM:
exist = PluginConfigHelper.get_plugin_draft(session, plugin_id)
if exist is None:
return PluginConfigHelper.create(session, operator, PluginConfigCreate(plugin_id=plugin_id, name='', avatar='', description='', is_draft=True))
return PluginConfigHelper.create(session, operator, PluginConfigCreate(plugin_id=plugin_id, plugin_openapi_desc='', is_draft=True))
else:
return exist

Expand Down Expand Up @@ -121,12 +128,17 @@ def get(session: Session, plugin_api_id: int) -> PluginApiORM | None:
return session.query(PluginApiORM).filter(PluginApiORM.id == plugin_api_id).one_or_none()

@staticmethod
def get_all(session: Session, user_id: int) -> Page[PluginApiModel]:
return paginate(
session.query(PluginApiORM)
.filter(PluginApiORM.created_by == user_id)
.order_by(PluginApiORM.updated_at.desc())
)
def get_user_all(session: Session, user_id: int) -> List[PluginApiORM]:
return session.query(PluginApiORM).filter(PluginApiORM.created_by == user_id).order_by(PluginApiORM.updated_at.desc()).all()
# paginate(
# session.query(PluginApiORM)
# .filter(PluginApiORM.created_by == user_id)
# .order_by(PluginApiORM.updated_at.desc())
# )

@staticmethod
def get_all(session: Session) -> List[PluginApiORM]:
return session.query(PluginApiORM).order_by(PluginApiORM.updated_at.desc()).all()

@staticmethod
def update(session: Session, operator: int, plugin_api_model: PluginApiUpdate) -> int:
Expand Down
14 changes: 12 additions & 2 deletions api/models/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Column,
Integer,
DateTime,
String,
Text,
)

from db import Base
Expand All @@ -18,6 +20,9 @@ class PluginORM(Base):
'''
__tablename__ = "plugin"
id = Column(Integer, primary_key=True)
name = Column(String(255), nullable=False) # 插件名字
avatar = Column(String(255)) # 插件图标
description = Column(Text) # 插件描述
plugin_type = Column(Integer, nullable=False) # 插件创建类型
created_by = Column(Integer)
created_at = Column(
Expand All @@ -34,22 +39,27 @@ class PluginCreate(BaseModel):
plugin create
'''
plugin_type: int
name: str
avatar: Optional[str]
description: Optional[str]


class PluginUpdate(BaseModel):
'''
plugin update
'''
id: int
plugin_type: int
plugin_type: Optional[int]
name: Optional[str]
avatar: Optional[str]
description: Optional[str]


class PluginModel(PluginCreate):
'''
plugin
'''
id: int
plugin_type: int
created_by: int
created_at: datetime
updated_by: int
Expand Down
19 changes: 10 additions & 9 deletions api/models/plugin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ForeignKey,
Integer,
DateTime,
String,
Text,
Boolean
)
Expand All @@ -23,7 +24,9 @@ class PluginApiORM(Base):
plugin_config_id = Column(Integer, ForeignKey(
"plugin_config.id"), nullable=True)
description = Column(Text) # API 描述
name = Column(String(255)) # 插件名字
openapi_desc = Column(Text) # openapi 先调研

# request_params = Column(JSON) # 入参
# response_params = Column(JSON) # 出参
# debug_example = Column(JSON) # 调试示例
Expand All @@ -47,9 +50,10 @@ class PluginApiCreate(BaseModel):
plugin api create
'''
plugin_config_id: int
description: str
openapi_desc: str
disabled: bool
name: Optional[str]
description: Optional[str]
disabled: Optional[bool]
created_by: int
created_at: datetime
updated_by: int
Expand All @@ -61,20 +65,17 @@ class PluginApiUpdate(BaseModel):
plugin api update
'''
id: int
description: str
openapi_desc: str
disabled: bool
description: Optional[str]
name: Optional[str]
openapi_desc: Optional[str]
disabled: Optional[bool]


class PluginApiModel(PluginApiCreate):
'''
plugin api
'''
id: int
plugin_config_id: int
description: str
openapi_desc: str
disabled: bool
created_by: int
created_at: datetime
updated_by: int
Expand Down
19 changes: 3 additions & 16 deletions api/models/plugin_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Integer,
DateTime,
ForeignKey,
String,
Text,
)

Expand All @@ -20,12 +19,10 @@ class PluginConfigORM(Base):
'''
__tablename__ = "plugin_config"
id = Column(Integer, primary_key=True)
name = Column(String(255), nullable=False) # 插件名字
avatar = Column(String(255)) # 插件图标
description = Column(Text) # 插件描述
# meta_info = Column(JSON) # 插件元信息,包含 api url,service_token,oauth_info 等
plugin_id = Column(Integer, ForeignKey("plugin.id"), nullable=True)
is_draft = Column(Boolean, nullable=False) # 是否为草稿
plugin_openapi_desc = Column(Text) # openapi 先调研
# openapi_desc = Column(Text) # openapi 先调研
# plugin_desc = Column(Text)
created_by = Column(Integer)
Expand All @@ -42,10 +39,7 @@ class PluginConfigCreate(BaseModel):
plugin config create
'''
plugin_id: int
name: str
avatar: str
description: str

plugin_openapi_desc: str
is_draft: Optional[bool]


Expand All @@ -54,22 +48,15 @@ class PluginConfigUpdate(BaseModel):
plugin config update
'''
id: int
plugin_openapi_desc: Optional[str]
is_draft: Optional[bool]
name: str
avatar: str
description: str


class PluginConfigModel(PluginConfigCreate):
'''
plugin config
'''
id: int
name: str
avatar: str
description: str

plugin_id: int
created_by: int
created_at: datetime
updated_by: int
Expand Down
Loading
Loading