Skip to content

Commit

Permalink
Implement better extension/library watchdog (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
No767 authored Jan 26, 2024
1 parent b9ae30b commit 90e78bb
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 54 deletions.
1 change: 1 addition & 0 deletions bot/libs/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __enter__(self) -> None:
max_bytes = 32 * 1024 * 1024 # 32 MiB
self.log.setLevel(logging.INFO)
logging.getLogger("discord").setLevel(logging.INFO)
logging.getLogger("watchfiles").setLevel(logging.WARNING)
handler = RotatingFileHandler(
filename="rodhaj.log",
encoding="utf-8",
Expand Down
88 changes: 88 additions & 0 deletions bot/libs/utils/reloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

import asyncio
import importlib
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from discord.ext import commands
from watchfiles import Change, awatch

if TYPE_CHECKING:
from rodhaj import Rodhaj


class Reloader:
"""An watchdog for reloading extensions and library files
This reloads/unloads extensions, and also reloads library modules.
This does not implement a deep reload, as there is no way to do so
that way.
"""

def __init__(self, bot: Rodhaj, path: Path):
self.bot = bot
self.path = path

self.loop = asyncio.get_running_loop()
self.logger = bot.logger
self._cogs_path = self.path / "cogs"
self._libs_path = self.path / "libs"

### Finding modules from the path directly

def find_modules_from_path(self, path: str) -> Optional[str]:
root, ext = os.path.splitext(path)
sys_path_index = len(sys.path[0].split("/"))
if ext != ".py":
return

local_path = root.split("/")[sys_path_index:]
return ".".join(item for item in local_path)

### Loading/reloading extensions and library modules

async def reload_or_load_extension(self, module: str) -> None:
try:
await self.bot.reload_extension(module)
self.logger.info("Reloaded extension: %s", module)
except commands.ExtensionNotLoaded:
await self.bot.load_extension(module)
self.logger.info("Loaded extension: %s", module)

async def reload_library(self, module: str) -> None:
try:
actual_module = sys.modules[module]
importlib.reload(actual_module)
self.logger.info("Reloaded lib module: %s", module)
except KeyError:
self.logger.warning("Failed to reload module %s. Does it exist?", module)

async def reload_extension_or_library(self, module: str) -> None:
if module.startswith("libs"):
await self.reload_library(module)
elif module.startswith("cogs"):
await self.reload_or_load_extension(module)

### Internal coroutine to start the watch

async def _start(self) -> None:
async for changes in awatch(self._cogs_path, self._libs_path):
for ctype, cpath in changes:
module = self.find_modules_from_path(cpath)
if module is None:
continue

if ctype == Change.modified or ctype == Change.added:
await self.reload_extension_or_library(module)
elif ctype == Change.deleted:
await self.bot.unload_extension(module)

### Public method to start the reloader

def start(self) -> None:
"""Starts the deep reloader"""
self.loop.create_task(self._start())
self.bot.dispatch("deepreloader_ready")
25 changes: 5 additions & 20 deletions bot/rodhaj.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,11 @@
RodhajHelp,
send_error_embed,
)
from libs.utils.reloader import Reloader

if TYPE_CHECKING:
from cogs.tickets import Tickets

_fsw = True
try:
from watchfiles import awatch
except ImportError:
_fsw = False

TRANSPROGRAMMER_GUILD_ID = 1183302385020436480


Expand Down Expand Up @@ -65,6 +60,7 @@ def __init__(
self.transprogrammer_guild_id = TRANSPROGRAMMER_GUILD_ID
self.version = str(VERSION)
self._dev_mode = dev_mode
self._reloader = Reloader(self, Path(__file__).parent)

### Ticket related utils
async def fetch_partial_config(self) -> Optional[PartialConfig]:
Expand Down Expand Up @@ -196,17 +192,6 @@ async def on_message(self, message: discord.Message) -> None:
return
await self.process_commands(message, ctx)

### Dev related utils

async def fs_watcher(self) -> None:
cogs_path = Path(__file__).parent.joinpath("cogs")
async for changes in awatch(cogs_path):
changes_list = list(changes)[0]
if changes_list[0].modified == 2:
reload_file = Path(changes_list[1])
self.logger.info(f"Reloading extension: {reload_file.name[:-3]}")
await self.reload_extension(f"cogs.{reload_file.name[:-3]}")

### Internal core overrides

async def setup_hook(self) -> None:
Expand All @@ -219,9 +204,9 @@ async def setup_hook(self) -> None:

self.partial_config = await self.fetch_partial_config()

if self._dev_mode is True and _fsw is True:
self.logger.info("Dev mode is enabled. Loading FSWatcher")
self.loop.create_task(self.fs_watcher())
if self._dev_mode:
self.logger.info("Dev mode is enabled. Loading Reloader")
self._reloader.start()

async def on_ready(self):
if not hasattr(self, "uptime"):
Expand Down
44 changes: 12 additions & 32 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ environs = "^10.3.0"
async-lru = "^2.0.4"
msgspec = "^0.18.6"
jishaku = "^2.5.2"
watchfiles = "^0.21.0"

[tool.poetry.group.dev.dependencies]
# These are pinned by major version
Expand All @@ -30,7 +31,6 @@ jishaku = "^2.5.2"
pre-commit = "^3"
pyright = "^1.1"
ruff = "^0.1"
watchfiles = "^0"

[tool.poetry.group.docs.dependencies]
sphinx = "^7.2.6"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ typing-extensions==4.9.0
environs==10.3.0
async-lru==2.0.4
msgspec==0.18.6
jishaku==2.5.2
jishaku==2.5.2
watchfiles>=0.21.0,<1

0 comments on commit 90e78bb

Please sign in to comment.