Skip to content

Commit

Permalink
Always attempt to retrieve thread parent for message in '_is_watched_…
Browse files Browse the repository at this point in the history
…channel' in 'anti_crosspost' extension
  • Loading branch information
Mega-JC committed Jul 30, 2024
1 parent f19cfd3 commit 8dd447f
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions pcbot/exts/anti_crosspost.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def on_message(self, message: discord.Message):
"""
if (
message.author.bot
or not self._is_valid_channel(message.channel) # type: ignore
or not await self._is_watched_channel(message.channel) # type: ignore
or message.type != discord.MessageType.default
or (
message.content
Expand Down Expand Up @@ -258,7 +258,9 @@ async def on_message(self, message: discord.Message):
user_cache["message_to_alert"][
message.id
] = alert_message.id
logger.debug(f"Sent alert message for crosspost URL {message.jump_url}")
logger.debug(
f"Sent alert message for crosspost URL {message.jump_url}"
)
except discord.HTTPException as e:
logger.debug(f"Failed to send alert message: {e}")
break
Expand Down Expand Up @@ -294,7 +296,7 @@ async def on_message_delete(self, message: discord.Message):
Args:
message (discord.Message): The message object.
"""
if not self._is_valid_channel(message.channel): # type: ignore
if not await self._is_watched_channel(message.channel): # type: ignore
return

if message.author.id not in self.crossposting_cache:
Expand Down Expand Up @@ -335,15 +337,15 @@ async def on_message_delete(self, message: discord.Message):
f"Failed to delete alert message ID {alert_message_id}: {e}"
)

def _is_valid_channel(self, channel: discord.abc.GuildChannel) -> bool:
async def _is_watched_channel(self, channel: discord.abc.GuildChannel) -> bool:
"""
Check if a channel is valid based on the configured channel IDs.
Check if a channel is watched for crossposts based on the configured channel IDs.
Args:
channel (discord.abc.GuildChannel): The channel to check.
Returns:
bool: True if the channel is valid, otherwise False.
bool: True if the channel is watched, otherwise False.
"""
if isinstance(channel, discord.abc.GuildChannel):
# Check if the channel ID or category ID is in the monitored channel IDs
Expand All @@ -357,8 +359,18 @@ def _is_valid_channel(self, channel: discord.abc.GuildChannel) -> bool:
if isinstance(channel, discord.Thread):
if channel.parent_id in self.channel_ids:
return True
if channel.parent and channel.parent.category_id in self.channel_ids:
return True

try:
parent = (
channel.parent
or channel.guild.get_channel(channel.parent_id)
or await channel.guild.fetch_channel(channel.parent_id)
)
except discord.NotFound:
pass
else:
if parent and parent.category_id in self.channel_ids:
return True

return False

Expand Down

0 comments on commit 8dd447f

Please sign in to comment.