diff --git a/bot/cogs/dev_tools.py b/bot/cogs/dev_tools.py index 0ca7c6e..f109cd4 100644 --- a/bot/cogs/dev_tools.py +++ b/bot/cogs/dev_tools.py @@ -4,13 +4,13 @@ from cogs import EXTENSIONS from discord.ext import commands from discord.ext.commands import Context, Greedy -from libs.utils import RoboView +from libs.utils import RoboContext, RoboView from rodhaj import Rodhaj class MaybeView(RoboView): - def __init__(self, ctx: commands.Context) -> None: + def __init__(self, ctx: RoboContext) -> None: super().__init__(ctx) @discord.ui.button(label="eg") @@ -26,7 +26,7 @@ class DevTools(commands.Cog, command_attrs=dict(hidden=True)): def __init__(self, bot: Rodhaj): self.bot = bot - async def cog_check(self, ctx: commands.Context) -> bool: + async def cog_check(self, ctx: RoboContext) -> bool: return await self.bot.is_owner(ctx.author) # Umbra's sync command @@ -79,7 +79,7 @@ async def sync( @commands.guild_only() @commands.command(name="reload-all") - async def reload_all(self, ctx: commands.Context) -> None: + async def reload_all(self, ctx: RoboContext) -> None: """Reloads all cogs. Used in production to not produce any downtime""" if not hasattr(self.bot, "uptime"): await ctx.send("Bot + exts must be up and loaded before doing this") @@ -90,7 +90,7 @@ async def reload_all(self, ctx: commands.Context) -> None: await ctx.send("Successfully reloaded all extensions live") @commands.command(name="view-test", hidden=True) - async def view_test(self, ctx: commands.Context) -> None: + async def view_test(self, ctx: RoboContext) -> None: view = MaybeView(ctx) view.message = await ctx.send("yeo", view=view) diff --git a/bot/cogs/utilities.py b/bot/cogs/utilities.py index 5a412d4..8b0bba8 100644 --- a/bot/cogs/utilities.py +++ b/bot/cogs/utilities.py @@ -7,7 +7,7 @@ import pygit2 from discord.ext import commands from discord.utils import format_dt -from libs.utils import Embed, human_timedelta +from libs.utils import Embed, RoboContext, human_timedelta from rodhaj import Rodhaj @@ -49,7 +49,7 @@ def get_last_commits(self, count: int = 5): return "\n".join(self.format_commit(c) for c in commits) @commands.hybrid_command(name="about") - async def about(self, ctx: commands.Context) -> None: + async def about(self, ctx: RoboContext) -> None: """Shows some stats for Rodhaj""" total_members = 0 total_unique = len(self.bot.users) @@ -83,13 +83,13 @@ async def about(self, ctx: commands.Context) -> None: await ctx.send(embed=embed) @commands.hybrid_command(name="uptime") - async def uptime(self, ctx: commands.Context) -> None: + async def uptime(self, ctx: RoboContext) -> None: """Displays the bot's uptime""" uptime_message = f"Uptime: {self.get_bot_uptime()}" await ctx.send(uptime_message) @commands.hybrid_command(name="version") - async def version(self, ctx: commands.Context) -> None: + async def version(self, ctx: RoboContext) -> None: """Displays the current build version""" version_message = f"Version: {self.bot.version}" await ctx.send(version_message) diff --git a/bot/libs/utils/context.py b/bot/libs/utils/context.py new file mode 100644 index 0000000..fa572d5 --- /dev/null +++ b/bot/libs/utils/context.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import discord +from discord.ext import commands + +from .views import RoboView + +if TYPE_CHECKING: + from bot.rodhaj import Rodhaj + + +class ConfirmationView(RoboView): + def __init__(self, ctx, timeout: float, delete_after: bool): + super().__init__(ctx, timeout) + self.value: Optional[bool] = None + self.delete_after = delete_after + self.message: Optional[discord.Message] = None + + async def on_timeout(self) -> None: + if self.delete_after and self.message: + await self.message.delete() + elif self.message: + await self.message.edit(view=None) + + async def delete_response(self, interaction: discord.Interaction): + await interaction.response.defer() + if self.delete_after: + await interaction.delete_original_response() + + self.stop() + + @discord.ui.button( + label="Confirm", + style=discord.ButtonStyle.green, + emoji="<:greenTick:596576670815879169>", + ) + async def confirm( + self, interaction: discord.Interaction, button: discord.ui.Button + ) -> None: + self.value = True + await self.delete_response(interaction) + + @discord.ui.button( + label="Cancel", + style=discord.ButtonStyle.red, + emoji="<:redTick:596576672149667840>", + ) + async def cancel( + self, interaction: discord.Interaction, button: discord.ui.Button + ) -> None: + self.value = False + await interaction.response.defer() + await interaction.delete_original_response() + self.stop() + + +class RoboContext(commands.Context): + bot: Rodhaj + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def prompt( + self, message: str, *, timeout: float = 60.0, delete_after: bool = False + ) -> Optional[bool]: + view = ConfirmationView(ctx=self, timeout=timeout, delete_after=delete_after) + view.message = await self.send(message, view=view, ephemeral=delete_after) + await view.wait() + return view.value diff --git a/bot/libs/utils/modals.py b/bot/libs/utils/modals.py index d0b6fc9..b426e1d 100644 --- a/bot/libs/utils/modals.py +++ b/bot/libs/utils/modals.py @@ -3,12 +3,12 @@ from typing import TYPE_CHECKING import discord -from discord.ext import commands +from .context import RoboContext from .errors import produce_error_embed if TYPE_CHECKING: - from bot.rodhaj import Rodhaj + pass NO_CONTROL_MSG = "This modal cannot be controlled by you, sorry!" @@ -16,7 +16,7 @@ class RoboModal(discord.ui.Modal): """Subclassed `discord.ui.Modal` that includes sane default configs""" - def __init__(self, ctx: commands.Context[Rodhaj], *args, **kwargs): + def __init__(self, ctx: RoboContext, *args, **kwargs): super().__init__(*args, **kwargs) self.ctx = ctx diff --git a/bot/libs/utils/views.py b/bot/libs/utils/views.py index 7421665..b89ff7d 100644 --- a/bot/libs/utils/views.py +++ b/bot/libs/utils/views.py @@ -3,12 +3,11 @@ from typing import TYPE_CHECKING, Any, Optional import discord -from discord.ext import commands from .errors import produce_error_embed if TYPE_CHECKING: - from bot.rodhaj import Rodhaj + from .context import RoboContext NO_CONTROL_MSG = "This view cannot be controlled by you, sorry!" @@ -16,8 +15,8 @@ class RoboView(discord.ui.View): """Subclassed `discord.ui.View` that includes sane default configs""" - def __init__(self, ctx: commands.Context[Rodhaj]): - super().__init__(timeout=5.0) + def __init__(self, ctx: RoboContext, *args, **kwargs): + super().__init__(*args, **kwargs) self.ctx = ctx self.message: Optional[discord.Message] diff --git a/bot/rodhaj.py b/bot/rodhaj.py index bde8ad2..f4a2fe4 100644 --- a/bot/rodhaj.py +++ b/bot/rodhaj.py @@ -1,13 +1,14 @@ import logging import signal from pathlib import Path +from typing import Union import asyncpg import discord from aiohttp import ClientSession from cogs import EXTENSIONS, VERSION from discord.ext import commands -from libs.utils import RodhajCommandTree +from libs.utils import RoboContext, RodhajCommandTree, send_error_embed _fsw = True try: @@ -54,6 +55,16 @@ async def fs_watcher(self) -> None: self.logger.info(f"Reloading extension: {reload_file.name[:-3]}") await self.reload_extension(f"cogs.{reload_file.name[:-3]}") + async def get_context( + self, origin: Union[discord.Interaction, discord.Message], /, *, cls=RoboContext + ) -> RoboContext: + return await super().get_context(origin, cls=cls) + + async def on_command_error( + self, ctx: commands.Context, error: commands.CommandError + ) -> None: + await send_error_embed(ctx, error) + async def setup_hook(self) -> None: def stop(): self.loop.create_task(self.close())