diff --git a/bot/cogs/config.py b/bot/cogs/config.py index d523bfc..1f8d0f7 100644 --- a/bot/cogs/config.py +++ b/bot/cogs/config.py @@ -1,14 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Annotated, Optional, Union import asyncpg import discord import msgspec from async_lru import alru_cache +from discord import app_commands from discord.ext import commands from libs.utils import GuildContext from libs.utils.checks import bot_check_permissions, check_permissions +from libs.utils.embeds import Embed +from libs.utils.prefix import get_prefix if TYPE_CHECKING: from rodhaj import Rodhaj @@ -96,6 +99,16 @@ class SetupFlags(commands.FlagConverter): ) +class PrefixConverter(commands.Converter): + async def convert(self, ctx: commands.Context, argument: str): + user_id = ctx.bot.user.id + if argument.startswith((f"<@{user_id}>", f"<@!{user_id}>")): + raise commands.BadArgument("That is a reserved prefix already in use.") + if len(argument) > 100: + raise commands.BadArgument("That prefix is too long.") + return argument + + class Config(commands.Cog): """Config and setup commands for Rodhaj""" @@ -120,6 +133,12 @@ async def get_guild_config(self, guild_id: int) -> Optional[GuildConfig]: config = GuildConfig(bot=self.bot, **dict(rows)) return config + def clean_prefixes(self, prefixes: Union[str, list[str]]) -> str: + if isinstance(prefixes, str): + return f"`{prefixes}`" + + return ", ".join(f"`{prefix}`" for prefix in prefixes[2:]) + @check_permissions(manage_guild=True) @bot_check_permissions(manage_channels=True, manage_webhooks=True) @commands.guild_only() @@ -334,6 +353,99 @@ async def delete(self, ctx: GuildContext) -> None: else: await ctx.send("Cancelling.") + @check_permissions(manage_guild=True) + @commands.guild_only() + @config.group(name="prefix", fallback="info") + async def prefix(self, ctx: GuildContext) -> None: + """Shows and manages custom prefixes for the guild + + Passing in no subcommands will effectively show the currently set prefixes. + """ + prefixes = await get_prefix(self.bot, ctx.message) + embed = Embed() + embed.add_field( + name="Prefixes", value=self.clean_prefixes(prefixes), inline=False + ) + embed.add_field(name="Total", value=len(prefixes) - 2, inline=False) + embed.set_author(name=ctx.guild.name, icon_url=ctx.guild.icon.url) # type: ignore + await ctx.send(embed=embed) + + @prefix.command(name="add") + @app_commands.describe(prefix="The new prefix to add") + async def prefix_add( + self, ctx: GuildContext, prefix: Annotated[str, PrefixConverter] + ) -> None: + """Adds an custom prefix""" + prefixes = await get_prefix(self.bot, ctx.message) + + # 2 are the mention prefixes, which are always prepended on the list of prefixes + if isinstance(prefixes, list) and len(prefixes) > 12: + await ctx.send( + "You can not have more than 10 custom prefixes for your server" + ) + return + elif prefix in prefixes: + await ctx.send("The prefix you want to set already exists") + return + + query = """ + UPDATE guild_config + SET prefix = ARRAY_APPEND(prefix, $1) + WHERE id = $2; + """ + await self.pool.execute(query, prefix, ctx.guild.id) + get_prefix.cache_invalidate(self.bot, ctx.message) + await ctx.send(f"Added prefix: `{prefix}`") + + @prefix.command(name="edit") + @app_commands.describe( + old="The prefix to edit", new="A new prefix to replace the old" + ) + @app_commands.rename(old="old_prefix", new="new_prefix") + async def prefix_edit( + self, + ctx: GuildContext, + old: Annotated[str, PrefixConverter], + new: Annotated[str, PrefixConverter], + ) -> None: + """Edits and replaces a prefix""" + query = """ + UPDATE guild_config + SET prefix = ARRAY_REPLACE(prefix, $1, $2) + WHERE id = $3; + """ + prefixes = await get_prefix(self.bot, ctx.message) + + guild_id = ctx.guild.id + if old in prefixes: + await self.pool.execute(query, old, new, guild_id) + get_prefix.cache_invalidate(self.bot, ctx.message) + await ctx.send(f"Prefix updated to from `{old}` to `{new}`") + else: + await ctx.send("The prefix is not in the list of prefixes for your server") + + @prefix.command(name="delete") + @app_commands.describe(prefix="The prefix to delete") + async def prefix_delete( + self, ctx: GuildContext, prefix: Annotated[str, PrefixConverter] + ) -> None: + """Deletes a set prefix""" + query = """ + UPDATE guild_config + SET prefix = ARRAY_REMOVE(prefix, $1) + WHERE id=$2; + """ + msg = f"Do you want to delete the following prefix: {prefix}" + confirm = await ctx.prompt(msg, timeout=120.0, delete_after=True) + if confirm: + await self.pool.execute(query, prefix, ctx.guild.id) + get_prefix.cache_invalidate(self.bot, ctx.message) + await ctx.send(f"The prefix `{prefix}` has been successfully deleted") + elif confirm is None: + await ctx.send("Confirmation timed out. Cancelled deletion...") + else: + await ctx.send("Confirmation cancelled. Please try again") + async def setup(bot: Rodhaj) -> None: await bot.add_cog(Config(bot)) diff --git a/bot/libs/utils/prefix.py b/bot/libs/utils/prefix.py new file mode 100644 index 0000000..8bbe0f6 --- /dev/null +++ b/bot/libs/utils/prefix.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Union + +import discord +from async_lru import alru_cache + +if TYPE_CHECKING: + from bot.rodhaj import Rodhaj + + +@alru_cache(maxsize=1024) +async def get_prefix(bot: Rodhaj, message: discord.Message) -> Union[str, list[str]]: + """Obtains the prefix for the guild + + This coroutine is heavily cached in order to reduce database calls + and improved performance + + + Args: + bot (Rodhaj): An instance of `Rodhaj` + message (discord.Message): The message that is processed + + Returns: + Union[str, List[str]]: The default prefix or + a list of prefixes (including the default) + """ + user_id = bot.user.id # type: ignore + + # By putting the base with the mentions, we are effectively + # doing the exact same thing as commands.when_mentioned + base = [f"<@!{user_id}> ", f"<@{user_id}> ", bot.default_prefix] + if message.guild is None: + get_prefix.cache_invalidate(bot, message) + return base + + query = """ + SELECT prefix + FROM guild_config + WHERE id = $1; + """ + prefixes = await bot.pool.fetchval(query, message.guild.id) + if prefixes is None: + get_prefix.cache_invalidate(bot, message) + return base + base.extend(item for item in prefixes) + return base diff --git a/bot/rodhaj.py b/bot/rodhaj.py index e1428c2..f947f31 100644 --- a/bot/rodhaj.py +++ b/bot/rodhaj.py @@ -20,6 +20,7 @@ send_error_embed, ) from libs.utils.config import RodhajConfig +from libs.utils.prefix import get_prefix from libs.utils.reloader import Reloader if TYPE_CHECKING: @@ -45,13 +46,14 @@ def __init__( allowed_mentions=discord.AllowedMentions( everyone=False, replied_user=False ), - command_prefix=["r>", "?", "!"], + command_prefix=get_prefix, help_command=RodhajHelp(), intents=intents, tree_cls=RodhajCommandTree, *args, **kwargs, ) + self.default_prefix = "r>" self.logger = logging.getLogger("rodhaj") self.session = session self.partial_config: Optional[PartialConfig] = None