Skip to content

Commit

Permalink
updatea and fix
Browse files Browse the repository at this point in the history
  • Loading branch information
DoroWolf committed Jan 3, 2025
1 parent 1378874 commit 3f6b983
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 27 deletions.
23 changes: 12 additions & 11 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class RestartBot(Exception):


def init_bot():
import core.scripts.config_generate # noqa
from core.config import Config, CFGManager # noqa
from core.config import Config # noqa
from core.constants.default import base_superuser_default # noqa
from core.database import BotDBUtil, session, DBVersion # noqa
from core.logger import Logger # noqa
Expand Down Expand Up @@ -75,13 +74,6 @@ def init_bot():
"The base superuser is not found, please setup it in the config file."
)

disabled_bots.clear()
for t in CFGManager.values:
if t.startswith("bot_") and not t.endswith("_secret"):
if "enable" in CFGManager.values[t][t]:
if not CFGManager.values[t][t]["enable"]:
disabled_bots.append(t[4:])


def multiprocess_run_until_complete(func):
p = multiprocessing.Process(
Expand Down Expand Up @@ -117,7 +109,7 @@ def go(bot_name: str, subprocess: bool = False, binary_mode: bool = False):

def run_bot():
from core.constants.path import cache_path # noqa
from core.config import Config # noqa
from core.config import Config, CFGManager # noqa
from core.logger import Logger # noqa

def restart_process(bot_name: str):
Expand Down Expand Up @@ -155,6 +147,14 @@ def restart_process(bot_name: str):
envs["PYTHONPATH"] = os.path.abspath(".")
lst = bots_and_required_configs.keys()

for t in CFGManager.values:
if t.startswith("bot_") and not t.endswith("_secret"):
if "enable" in CFGManager.values[t][t]:
if not CFGManager.values[t][t]["enable"]:
disabled_bots.append(t[4:])
else:
Logger.warning(f"Bot {t} cannot found config \"enable\".")

for bl in lst:
if bl in disabled_bots:
continue
Expand All @@ -163,7 +163,7 @@ def restart_process(bot_name: str):
for c in bots_and_required_configs[bl]:
if not Config(c, _global=True):
Logger.error(
f"Bot {bl} requires config {c} but not found, abort to launch."
f"Bot {bl} requires config \"{c}\" but not found, abort to launch."
)
abort = True
break
Expand Down Expand Up @@ -200,6 +200,7 @@ def restart_process(bot_name: str):


if __name__ == "__main__":
import core.scripts.config_generate # noqa
try:
while True:
try:
Expand Down
45 changes: 32 additions & 13 deletions bots/api/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter
from slowapi.util import get_remote_address

from bots.api.info import client_name
from core.constants import config_filename, config_path, logs_path
Expand Down Expand Up @@ -70,7 +72,7 @@ def save_csrf_tokens(tokens):


def verify_csrf_token(request: Request):
csrf_token = request.cookies.get("csrfToken")
csrf_token = request.cookies.get("XSRF-TOKEN")
if not csrf_token:
raise HTTPException(status_code=403, detail="Missing CSRF token")

Expand Down Expand Up @@ -106,6 +108,7 @@ def verify_jwt(request: Request):


app = FastAPI()
limiter = Limiter(key_func=get_remote_address)
ph = PasswordHasher()


Expand All @@ -130,18 +133,21 @@ async def startup_event():


@app.get("/favicon.ico", response_class=FileResponse)
async def favicon():
@limiter.limit("2/second")
async def favicon(request: Request):
favicon_path = os.path.join(assets_path, "favicon.ico")
return FileResponse(favicon_path)


@app.get("/api/verify-token")
@limiter.limit("2/second")
async def verify_token(request: Request):
verify_jwt(request)


@app.get("/api/get-csrf-token")
async def set_csrf_token(response: Response):
@limiter.limit("2/second")
async def set_csrf_token(request: Request, response: Response):
current_time = time.time()

token_entries = load_csrf_tokens()
Expand All @@ -160,7 +166,7 @@ async def set_csrf_token(response: Response):
save_csrf_tokens(token_entries)

response.set_cookie(
key="csrfToken",
key="XSRF-TOKEN",
value=csrf_token,
httponly=True,
secure=True,
Expand All @@ -172,6 +178,7 @@ async def set_csrf_token(response: Response):


@app.post("/api/auth")
@limiter.limit("10/minute")
async def auth(request: Request, response: Response):
try:
payload = {
Expand Down Expand Up @@ -232,6 +239,7 @@ async def auth(request: Request, response: Response):


@app.post("/api/change-password")
@limiter.limit("10/minute")
async def change_password(request: Request):
try:
verify_jwt(request)
Expand Down Expand Up @@ -271,6 +279,7 @@ async def change_password(request: Request):


@app.get("/api/server-info")
@limiter.limit("10/minute")
async def server_info(request: Request):
verify_jwt(request)
return {
Expand Down Expand Up @@ -306,6 +315,7 @@ async def server_info(request: Request):


@app.get("/api/config")
@limiter.limit("2/second")
async def get_config_list(request: Request):
verify_jwt(request)
try:
Expand All @@ -325,7 +335,8 @@ async def get_config_list(request: Request):


@app.get("/api/config/{cfg_filename}")
async def get_config_file(cfg_filename: str, request: Request):
@limiter.limit("2/second")
async def get_config_file(request: Request, cfg_filename: str):
verify_jwt(request)
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="Not found")
Expand All @@ -347,7 +358,8 @@ async def get_config_file(cfg_filename: str, request: Request):


@app.post("/api/config/{cfg_filename}")
async def edit_config_file(cfg_filename: str, request: Request):
@limiter.limit("10/minute")
async def edit_config_file(request: Request, cfg_filename: str):
verify_jwt(request)
verify_csrf_token(request)
if not os.path.exists(config_path):
Expand All @@ -370,7 +382,8 @@ async def edit_config_file(cfg_filename: str, request: Request):


@app.get("/api/target/{target_id}")
async def get_target(target_id: str):
@limiter.limit("2/second")
async def get_target(request: Request, target_id: str):
target = BotDBUtil.TargetInfo(target_id)
if not target.query:
return HTTPException(status_code=404, detail="Not found")
Expand Down Expand Up @@ -401,7 +414,8 @@ async def get_target(target_id: str):


@app.get("/api/sender/{sender_id}")
async def get_sender(sender_id: str):
@limiter.limit("2/second")
async def get_sender(request: Request, sender_id: str):
sender = BotDBUtil.SenderInfo(sender_id)
if not sender.query:
return HTTPException(status_code=404, detail="Not found")
Expand All @@ -418,12 +432,14 @@ async def get_sender(sender_id: str):


@app.get("/api/modules")
async def get_module_list():
@limiter.limit("2/second")
async def get_module_list(request: Request):
return {"modules": ModulesManager.return_modules_list()}


@app.get("/api/modules/{target_id}")
async def get_target_modules(target_id: str):
@limiter.limit("2/second")
async def get_target_modules(request: Request, target_id: str):
target_data = BotDBUtil.TargetInfo(target_id)
if not target_data.query:
return HTTPException(status_code=404, detail="Not found")
Expand All @@ -437,7 +453,8 @@ async def get_target_modules(target_id: str):


@app.post("/api/modules/{target_id}/enable")
async def enable_modules(target_id: str, request: Request):
@limiter.limit("10/minute")
async def enable_modules(request: Request, target_id: str):
try:
target_data = BotDBUtil.TargetInfo(target_id)
if not target_data.query:
Expand All @@ -462,7 +479,8 @@ async def enable_modules(target_id: str, request: Request):


@app.post("/api/modules/{target_id}/disable")
async def disable_modules(target_id: str, request: Request):
@limiter.limit("10/minute")
async def disable_modules(request: Request, target_id: str):
try:
target_data = BotDBUtil.TargetInfo(target_id)
if not target_data.query:
Expand All @@ -487,7 +505,8 @@ async def disable_modules(target_id: str, request: Request):


@app.get("/api/locale/{locale}/{string}")
async def get_locale(locale: str, string: str):
@limiter.limit("2/second")
async def get_locale(request: Request, locale: str, string: str):
try:
return {
"locale": locale,
Expand Down
1 change: 1 addition & 0 deletions console.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ async def send_command(msg):


if __name__ == "__main__":
import core.scripts.config_generate # noqa
init_bot()
Info.client_name = client_name
loop = asyncio.new_event_loop()
Expand Down
2 changes: 1 addition & 1 deletion core/scripts/config_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re
import shutil
import sys
import traceback # noqa
from time import sleep

if __name__ == '__main__':
Expand Down Expand Up @@ -42,6 +41,7 @@ def generate_config(dir_path, language):
'config.comments.config_version',
fallback_failed_prompt=False)}\n')
f.write('initialized = false\n')
f.close()

from core.config import Config, CFGManager # noqa

Expand Down
Loading

0 comments on commit 3f6b983

Please sign in to comment.