Skip to content

Commit

Permalink
Implement subclassed commands.Context (RoboContext)" -m "This all…
Browse files Browse the repository at this point in the history
…ows us to inject our own code that we may want within the context
  • Loading branch information
No767 committed Nov 23, 2023
1 parent fdacfc8 commit 10cb33e
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 17 deletions.
10 changes: 5 additions & 5 deletions bot/cogs/dev_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from cogs import EXTENSIONS
from discord.ext import commands
from discord.ext.commands import Context, Greedy
from libs.utils import RoboView
from libs.utils import RoboContext, RoboView

from rodhaj import Rodhaj


class MaybeView(RoboView):
def __init__(self, ctx: commands.Context) -> None:
def __init__(self, ctx: RoboContext) -> None:
super().__init__(ctx)

@discord.ui.button(label="eg")
Expand All @@ -26,7 +26,7 @@ class DevTools(commands.Cog, command_attrs=dict(hidden=True)):
def __init__(self, bot: Rodhaj):
self.bot = bot

async def cog_check(self, ctx: commands.Context) -> bool:
async def cog_check(self, ctx: RoboContext) -> bool:
return await self.bot.is_owner(ctx.author)

# Umbra's sync command
Expand Down Expand Up @@ -79,7 +79,7 @@ async def sync(

@commands.guild_only()
@commands.command(name="reload-all")
async def reload_all(self, ctx: commands.Context) -> None:
async def reload_all(self, ctx: RoboContext) -> None:
"""Reloads all cogs. Used in production to not produce any downtime"""
if not hasattr(self.bot, "uptime"):
await ctx.send("Bot + exts must be up and loaded before doing this")
Expand All @@ -90,7 +90,7 @@ async def reload_all(self, ctx: commands.Context) -> None:
await ctx.send("Successfully reloaded all extensions live")

@commands.command(name="view-test", hidden=True)
async def view_test(self, ctx: commands.Context) -> None:
async def view_test(self, ctx: RoboContext) -> None:
view = MaybeView(ctx)
view.message = await ctx.send("yeo", view=view)

Expand Down
8 changes: 4 additions & 4 deletions bot/cogs/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pygit2
from discord.ext import commands
from discord.utils import format_dt
from libs.utils import Embed, human_timedelta
from libs.utils import Embed, RoboContext, human_timedelta

from rodhaj import Rodhaj

Expand Down Expand Up @@ -49,7 +49,7 @@ def get_last_commits(self, count: int = 5):
return "\n".join(self.format_commit(c) for c in commits)

@commands.hybrid_command(name="about")
async def about(self, ctx: commands.Context) -> None:
async def about(self, ctx: RoboContext) -> None:
"""Shows some stats for Rodhaj"""
total_members = 0
total_unique = len(self.bot.users)
Expand Down Expand Up @@ -83,13 +83,13 @@ async def about(self, ctx: commands.Context) -> None:
await ctx.send(embed=embed)

@commands.hybrid_command(name="uptime")
async def uptime(self, ctx: commands.Context) -> None:
async def uptime(self, ctx: RoboContext) -> None:
"""Displays the bot's uptime"""
uptime_message = f"Uptime: {self.get_bot_uptime()}"
await ctx.send(uptime_message)

@commands.hybrid_command(name="version")
async def version(self, ctx: commands.Context) -> None:
async def version(self, ctx: RoboContext) -> None:
"""Displays the current build version"""
version_message = f"Version: {self.bot.version}"
await ctx.send(version_message)
Expand Down
71 changes: 71 additions & 0 deletions bot/libs/utils/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import discord
from discord.ext import commands

from .views import RoboView

if TYPE_CHECKING:
from bot.rodhaj import Rodhaj


class ConfirmationView(RoboView):
def __init__(self, ctx, timeout: float, delete_after: bool):
super().__init__(ctx, timeout)
self.value: Optional[bool] = None
self.delete_after = delete_after
self.message: Optional[discord.Message] = None

async def on_timeout(self) -> None:
if self.delete_after and self.message:
await self.message.delete()
elif self.message:
await self.message.edit(view=None)

async def delete_response(self, interaction: discord.Interaction):
await interaction.response.defer()
if self.delete_after:
await interaction.delete_original_response()

self.stop()

@discord.ui.button(
label="Confirm",
style=discord.ButtonStyle.green,
emoji="<:greenTick:596576670815879169>",
)
async def confirm(
self, interaction: discord.Interaction, button: discord.ui.Button
) -> None:
self.value = True
await self.delete_response(interaction)

@discord.ui.button(
label="Cancel",
style=discord.ButtonStyle.red,
emoji="<:redTick:596576672149667840>",
)
async def cancel(
self, interaction: discord.Interaction, button: discord.ui.Button
) -> None:
self.value = False
await interaction.response.defer()
await interaction.delete_original_response()
self.stop()


class RoboContext(commands.Context):
bot: Rodhaj

def __init__(self, **kwargs):
super().__init__(**kwargs)

async def prompt(
self, message: str, *, timeout: float = 60.0, delete_after: bool = False
) -> Optional[bool]:
view = ConfirmationView(ctx=self, timeout=timeout, delete_after=delete_after)
view.message = await self.send(message, view=view, ephemeral=delete_after)
await view.wait()
return view.value
6 changes: 3 additions & 3 deletions bot/libs/utils/modals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
from typing import TYPE_CHECKING

import discord
from discord.ext import commands

from .context import RoboContext
from .errors import produce_error_embed

if TYPE_CHECKING:
from bot.rodhaj import Rodhaj
pass

NO_CONTROL_MSG = "This modal cannot be controlled by you, sorry!"


class RoboModal(discord.ui.Modal):
"""Subclassed `discord.ui.Modal` that includes sane default configs"""

def __init__(self, ctx: commands.Context[Rodhaj], *args, **kwargs):
def __init__(self, ctx: RoboContext, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ctx = ctx

Expand Down
7 changes: 3 additions & 4 deletions bot/libs/utils/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
from typing import TYPE_CHECKING, Any, Optional

import discord
from discord.ext import commands

from .errors import produce_error_embed

if TYPE_CHECKING:
from bot.rodhaj import Rodhaj
from .context import RoboContext

NO_CONTROL_MSG = "This view cannot be controlled by you, sorry!"


class RoboView(discord.ui.View):
"""Subclassed `discord.ui.View` that includes sane default configs"""

def __init__(self, ctx: commands.Context[Rodhaj]):
super().__init__(timeout=5.0)
def __init__(self, ctx: RoboContext, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ctx = ctx
self.message: Optional[discord.Message]

Expand Down
13 changes: 12 additions & 1 deletion bot/rodhaj.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
import signal
from pathlib import Path
from typing import Union

import asyncpg
import discord
from aiohttp import ClientSession
from cogs import EXTENSIONS, VERSION
from discord.ext import commands
from libs.utils import RodhajCommandTree
from libs.utils import RoboContext, RodhajCommandTree, send_error_embed

_fsw = True
try:
Expand Down Expand Up @@ -54,6 +55,16 @@ async def fs_watcher(self) -> None:
self.logger.info(f"Reloading extension: {reload_file.name[:-3]}")
await self.reload_extension(f"cogs.{reload_file.name[:-3]}")

async def get_context(
self, origin: Union[discord.Interaction, discord.Message], /, *, cls=RoboContext
) -> RoboContext:
return await super().get_context(origin, cls=cls)

async def on_command_error(
self, ctx: commands.Context, error: commands.CommandError
) -> None:
await send_error_embed(ctx, error)

async def setup_hook(self) -> None:
def stop():
self.loop.create_task(self.close())
Expand Down

0 comments on commit 10cb33e

Please sign in to comment.