Skip to content

Commit

Permalink
allow setting guild commands to all available guilds
Browse files Browse the repository at this point in the history
  • Loading branch information
odrling committed Apr 1, 2024
1 parent 044032a commit 3465e2a
Showing 1 changed file with 27 additions and 31 deletions.
58 changes: 27 additions & 31 deletions discord/app_commands/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
Union,
overload,
)
from collections import Counter
from collections import ChainMap, Counter


from .namespace import Namespace, ResolveKey
Expand Down Expand Up @@ -81,6 +81,9 @@
_log = logging.getLogger(__name__)


ALL_GUILDS = -1


def _retrieve_guild_ids(
command: Any, guild: Optional[Snowflake] = MISSING, guilds: Sequence[Snowflake] = MISSING
) -> Optional[Set[int]]:
Expand Down Expand Up @@ -355,12 +358,13 @@ def _context_menu_add_helper(
# adding it into the mapping. This ensures atomicity.
for guild_id in guild_ids:
commands = self._guild_commands.get(guild_id, {})
found = name in commands
check_commands = ChainMap(commands, self._guild_commands.get(ALL_GUILDS, {}))
found = name in check_commands
if found and not override:
raise CommandAlreadyRegistered(name, guild_id)

to_add = not (override and found)
if len(commands) + to_add > 100:
if len(check_commands) + to_add > 100:
raise CommandLimitReached(guild_id=guild_id, limit=100)

# Actually add the command now that it has been verified to be okay.
Expand Down Expand Up @@ -445,12 +449,9 @@ def remove_command(
if guild is None:
return self._global_commands.pop(command, None)
else:
try:
commands = self._guild_commands[guild.id]
except KeyError:
return None
else:
return commands.pop(command, None)
commands = ChainMap(self._guild_commands.get(ALL_GUILDS, {}), self._guild_commands.get(guild.id, {}))

return commands.pop(command, None)
elif type in (AppCommandType.user, AppCommandType.message):
guild_id = None if guild is None else guild.id
key = (command, guild_id, type.value)
Expand All @@ -477,7 +478,7 @@ def clear_commands(self, *, guild: Optional[Snowflake], type: Optional[AppComman
self._global_commands.clear()
else:
try:
commands = self._guild_commands[guild.id]
commands = ChainMap(self._guild_commands.get(ALL_GUILDS, {}), self._guild_commands.get(guild.id, {}))
except KeyError:
pass
else:
Expand Down Expand Up @@ -563,7 +564,7 @@ def get_command(
return self._global_commands.get(command)
else:
try:
commands = self._guild_commands[guild.id]
commands = ChainMap(self._guild_commands.get(ALL_GUILDS, {}), self._guild_commands.get(guild.id, {}))
except KeyError:
return None
else:
Expand Down Expand Up @@ -643,7 +644,7 @@ def get_commands(
return list(self._global_commands.values())
else:
try:
commands = self._guild_commands[guild.id]
commands = ChainMap(self._guild_commands.get(ALL_GUILDS, {}), self._guild_commands.get(guild.id, {}))
except KeyError:
return []
else:
Expand Down Expand Up @@ -711,7 +712,7 @@ def walk_commands(
yield from cmd.walk_commands()
else:
try:
commands = self._guild_commands[guild.id]
commands = ChainMap(self._guild_commands.get(ALL_GUILDS, {}), self._guild_commands.get(guild.id, {}))
except KeyError:
return
else:
Expand All @@ -734,16 +735,12 @@ def _get_all_commands(
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None)
return base
else:
try:
commands = self._guild_commands[guild.id]
except KeyError:
guild_id = guild.id
return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id]
else:
base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(commands.values())
guild_id = guild.id
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id)
return base
commands = ChainMap(self._guild_commands.get(ALL_GUILDS, {}), self._guild_commands.get(guild.id, {}))

base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(commands.values())
guild_id = guild.id
base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id)
return base

def _remove_with_module(self, name: str) -> None:
remove: List[Any] = []
Expand Down Expand Up @@ -1109,14 +1106,13 @@ def _get_app_command_options(

command_guild_id = _get_as_snowflake(data, 'guild_id')
if command_guild_id:
try:
guild_commands = self._guild_commands[command_guild_id]
except KeyError:
command = None if not self.fallback_to_global else self._global_commands.get(name)
else:
command = guild_commands.get(name)
if command is None and self.fallback_to_global:
command = self._global_commands.get(name)
guild_commands = ChainMap(
self._guild_commands.get(ALL_GUILDS, {}), self._guild_commands.get(command_guild_id, {})
)

command = guild_commands.get(name)
if command is None and self.fallback_to_global:
command = self._global_commands.get(name)
else:
command = self._global_commands.get(name)

Expand Down

0 comments on commit 3465e2a

Please sign in to comment.