From f31865036d8b893aae1e9b6f2c7100287133ec7e Mon Sep 17 00:00:00 2001 From: Blue Date: Sun, 29 Oct 2023 22:07:35 +0000 Subject: [PATCH] Lint with `black` and check lint on github workflow --- .github/workflows/python-package.yml | 7 +- setup.py | 14 +- sql_helper/_connection.py | 9 +- sql_helper/connection.py | 14 +- sql_helper/emoji.py | 14 +- sql_helper/guild_settings.py | 14 +- sql_helper/metrics.py | 10 +- sql_helper/mixins/__init__.py | 2 +- sql_helper/mixins/aliases.py | 46 +++--- sql_helper/mixins/blocked_emojis.py | 58 ++++---- sql_helper/mixins/command_messages.py | 8 +- sql_helper/mixins/emoji_hashes.py | 6 +- sql_helper/mixins/emojis.py | 197 ++++++++++++++++---------- sql_helper/mixins/emojis_used.py | 20 +-- sql_helper/mixins/guild_members.py | 15 +- sql_helper/mixins/guild_messages.py | 123 ++++++++++------ sql_helper/mixins/guild_settings.py | 19 +-- sql_helper/mixins/guild_webhooks.py | 14 +- sql_helper/mixins/packs.py | 37 +++-- sql_helper/mixins/personas.py | 21 ++- sql_helper/mixins/premium.py | 44 +++--- sql_helper/premium_user.py | 32 +++-- sql_helper/webhook.py | 8 +- tests/base.py | 17 ++- tests/test_emotes.py | 156 ++++++++++---------- 25 files changed, 533 insertions(+), 372 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 2b1c469..5f6adca 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -23,12 +23,17 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest-asyncio pytest-mock mock pytest flake8 hypothesis + python -m pip install pytest-asyncio pytest-mock mock pytest flake8 hypothesis black python -m pip install . - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + - name: Lint with Black + uses: psf/black@stable + with: + options: "--check --verbose" + src: "." - name: Test with pytest run: | pytest diff --git a/setup.py b/setup.py index e9b9b36..617920c 100644 --- a/setup.py +++ b/setup.py @@ -2,11 +2,11 @@ setup( - name="sql_helper", - version="0.1.0", - description="A helper for PostgreSQL ", - author='Blue', - url="https://nqn.blue/", - packages=find_packages(), - install_requires=["aiopg", "discord.py", "python-dateutil"] + name="sql_helper", + version="0.1.0", + description="A helper for PostgreSQL ", + author="Blue", + url="https://nqn.blue/", + packages=find_packages(), + install_requires=["aiopg", "discord.py", "python-dateutil"], ) diff --git a/sql_helper/_connection.py b/sql_helper/_connection.py index 8a82235..a0d1b07 100644 --- a/sql_helper/_connection.py +++ b/sql_helper/_connection.py @@ -4,7 +4,13 @@ class _PostgresConnection: - def __init__(self, pool, get_guild: Callable[[int], Optional[Guild]], get_emoji: Callable[[SQLEmoji], Optional[Emoji]], profiler=None): + def __init__( + self, + pool, + get_guild: Callable[[int], Optional[Guild]], + get_emoji: Callable[[SQLEmoji], Optional[Emoji]], + profiler=None, + ): self.pool = pool self.pool_acq = None self.conn = None @@ -31,4 +37,3 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self.cur_acq = None self.cur = None return rtn - diff --git a/sql_helper/connection.py b/sql_helper/connection.py index a311842..8790d96 100644 --- a/sql_helper/connection.py +++ b/sql_helper/connection.py @@ -25,11 +25,11 @@ class PostgresConnection( class SQLConnection: def __init__( - self, - pool, - get_guild: Optional[Callable[[int], Optional[Guild]]] = None, - get_emoji: Optional[Callable[[SQLEmoji], Optional[Emoji]]] = None, - profiler=None + self, + pool, + get_guild: Optional[Callable[[int], Optional[Guild]]] = None, + get_emoji: Optional[Callable[[SQLEmoji], Optional[Emoji]]] = None, + profiler=None, ): self.pool = pool self._get_guild = get_guild or (lambda id: None) @@ -37,4 +37,6 @@ def __init__( self.profiler = profiler def __call__(self) -> PostgresConnection: - return PostgresConnection(self.pool, self._get_guild, self._get_emoji, self.profiler) + return PostgresConnection( + self.pool, self._get_guild, self._get_emoji, self.profiler + ) diff --git a/sql_helper/emoji.py b/sql_helper/emoji.py index 8314d87..b18da77 100644 --- a/sql_helper/emoji.py +++ b/sql_helper/emoji.py @@ -1,5 +1,17 @@ from collections import namedtuple -SQLEmoji = namedtuple("SQLEmoji", ["emote_id", "emote_hash", "usable", "animated", "emote_sha", "guild_id", "name", "has_roles"]) +SQLEmoji = namedtuple( + "SQLEmoji", + [ + "emote_id", + "emote_hash", + "usable", + "animated", + "emote_sha", + "guild_id", + "name", + "has_roles", + ], +) EmojiCounts = namedtuple("EmojiCounts", ["static", "animated"]) diff --git a/sql_helper/guild_settings.py b/sql_helper/guild_settings.py index b2143c4..b2f87e7 100644 --- a/sql_helper/guild_settings.py +++ b/sql_helper/guild_settings.py @@ -16,13 +16,13 @@ class SettingsFlags(Flag): DEFAULTS = ( - SettingsFlags.stickers | - SettingsFlags.nitro | - SettingsFlags.replies | - SettingsFlags.pings | - SettingsFlags.user_content | - SettingsFlags.dashboard_posting | - SettingsFlags.phish_detection + SettingsFlags.stickers + | SettingsFlags.nitro + | SettingsFlags.replies + | SettingsFlags.pings + | SettingsFlags.user_content + | SettingsFlags.dashboard_posting + | SettingsFlags.phish_detection ) diff --git a/sql_helper/metrics.py b/sql_helper/metrics.py index fcae5ee..ee5870d 100644 --- a/sql_helper/metrics.py +++ b/sql_helper/metrics.py @@ -5,10 +5,7 @@ def sql_wrapper(namespace: str): sql_request = Histogram( - "sql_request_total", - "SQL Requests", - ["query"], - namespace=namespace + "sql_request_total", "SQL Requests", ["query"], namespace=namespace ) def middle(execute): @@ -26,9 +23,10 @@ async def execute_wrapper(query, *args, **kwargs): data={ "query": query, "params": kwargs.get("parameters"), - "total_time": total_time - } + "total_time": total_time, + }, ) return execute_wrapper + return middle diff --git a/sql_helper/mixins/__init__.py b/sql_helper/mixins/__init__.py index 39b7e43..5208a7c 100644 --- a/sql_helper/mixins/__init__.py +++ b/sql_helper/mixins/__init__.py @@ -10,4 +10,4 @@ from .emoji_hashes import EmojiHashesMixin from .aliases import AliasesMixin from .blocked_emojis import BlockedEmojisMixin -from .command_messages import CommandMessagesMixin \ No newline at end of file +from .command_messages import CommandMessagesMixin diff --git a/sql_helper/mixins/aliases.py b/sql_helper/mixins/aliases.py index 9493ae4..7c74566 100644 --- a/sql_helper/mixins/aliases.py +++ b/sql_helper/mixins/aliases.py @@ -8,34 +8,40 @@ class AliasesMixin(_PostgresConnection): async def get_user_aliases(self, user_id: int) -> List[PartialEmoji]: await self.cur.execute( - "SELECT animated, \"name\", emote_id FROM aliases WHERE user_id=%(user_id)s", - parameters={"user_id": user_id} + 'SELECT animated, "name", emote_id FROM aliases WHERE user_id=%(user_id)s', + parameters={"user_id": user_id}, ) return await _get_emotes(self.cur) - async def get_user_aliases_after(self, user_id: int, name: str, limit: int) -> List[PartialEmoji]: + async def get_user_aliases_after( + self, user_id: int, name: str, limit: int + ) -> List[PartialEmoji]: await self.cur.execute( - "SELECT animated, \"name\", emote_id FROM aliases WHERE user_id=%(user_id)s and name>%(name)s order by name limit %(limit)s", - parameters={"user_id": user_id, "name": name, "limit": limit} + 'SELECT animated, "name", emote_id FROM aliases WHERE user_id=%(user_id)s and name>%(name)s order by name limit %(limit)s', + parameters={"user_id": user_id, "name": name, "limit": limit}, ) return await _get_emotes(self.cur) - async def get_user_alias_name(self, user_id: int, name: str) -> Optional[PartialEmoji]: + async def get_user_alias_name( + self, user_id: int, name: str + ) -> Optional[PartialEmoji]: await self.cur.execute( - "SELECT animated, \"name\", emote_id FROM aliases WHERE user_id=%(user_id)s and lower(\"name\")=%(name)s", - parameters={"user_id": user_id, "name": name.lower()} + 'SELECT animated, "name", emote_id FROM aliases WHERE user_id=%(user_id)s and lower("name")=%(name)s', + parameters={"user_id": user_id, "name": name.lower()}, ) emotes = await _get_emotes(self.cur) if emotes: return next((emote for emote in emotes if emote.name == name), emotes[0]) - async def get_user_alias_name_with_guild(self, user_id: int, name: str, guild_id: int) -> Optional[PartialEmoji]: + async def get_user_alias_name_with_guild( + self, user_id: int, name: str, guild_id: int + ) -> Optional[PartialEmoji]: # Get an alias, but exclude where the current guild has has_roles set on the alias you're trying to use # Left join here, as if it's not in the main database we still want to be able to use it as it's obviously not blocked await self.cur.execute( - "SELECT aliases.animated, aliases.\"name\", aliases.emote_id from aliases left join emote_ids on aliases.emote_id=emote_ids.emote_id where user_id=%(user_id)s and lower(aliases.\"name\")=%(name)s and " + 'SELECT aliases.animated, aliases."name", aliases.emote_id from aliases left join emote_ids on aliases.emote_id=emote_ids.emote_id where user_id=%(user_id)s and lower(aliases."name")=%(name)s and ' "emote_ids.emote_hash not in (select emote_hash from emote_ids where guild_id=%(guild_id)s and has_roles=true)", - parameters={"user_id": user_id, "name": name.lower(), "guild_id": guild_id} + parameters={"user_id": user_id, "name": name.lower(), "guild_id": guild_id}, ) emotes = await _get_emotes(self.cur) if emotes: @@ -44,7 +50,7 @@ async def get_user_alias_name_with_guild(self, user_id: int, name: str, guild_id async def count_user_aliases(self, user_id: int) -> int: await self.cur.execute( "select count(*) from aliases where user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) results = await self.cur.fetchall() return results[0][0] @@ -52,7 +58,7 @@ async def count_user_aliases(self, user_id: int) -> int: async def user_has_aliases(self, user_id: int) -> bool: await self.cur.execute( "select 1 from aliases where user_id=%(user_id)s limit 1", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) return bool(await self.cur.fetchall()) @@ -69,37 +75,37 @@ async def set_user_aliases(self, user_id: int, aliases: List[PartialEmoji]): await self.cur.execute( "INSERT INTO aliases (user_id, emote_id, animated, name) VALUES (%(user_id)s, unnest(%(emote_ids)s), unnest(%(animateds)s), unnest(%(names)s)) " - "ON CONFLICT ON CONSTRAINT aliases_pk DO UPDATE SET emote_id = excluded.emote_id, animated = excluded.animated, \"name\" = excluded.name", + 'ON CONFLICT ON CONSTRAINT aliases_pk DO UPDATE SET emote_id = excluded.emote_id, animated = excluded.animated, "name" = excluded.name', parameters={ "user_id": user_id, "names": names, "emote_ids": ids, "animateds": animateds, - } + }, ) async def set_user_alias(self, user_id: int, alias: PartialEmoji) -> NoReturn: await self.cur.execute( "INSERT INTO aliases (user_id, emote_id, animated, name) VALUES (%(user_id)s, %(emote_id)s, %(animated)s, %(name)s) " - "ON CONFLICT ON CONSTRAINT aliases_pk DO UPDATE SET emote_id = excluded.emote_id, animated = excluded.animated, \"name\" = excluded.name", + 'ON CONFLICT ON CONSTRAINT aliases_pk DO UPDATE SET emote_id = excluded.emote_id, animated = excluded.animated, "name" = excluded.name', parameters={ "user_id": user_id, "name": alias.name, "emote_id": alias.id, "animated": alias.animated, - } + }, ) async def delete_user_alias(self, user_id: int, name: str) -> NoReturn: await self.cur.execute( - "DELETE FROM aliases WHERE user_id=%(user_id)s and \"name\"=%(name)s", - parameters={"user_id": user_id, "name": name} + 'DELETE FROM aliases WHERE user_id=%(user_id)s and "name"=%(name)s', + parameters={"user_id": user_id, "name": name}, ) async def delete_all_user_aliases(self, user_id: int) -> NoReturn: await self.cur.execute( "DELETE FROM aliases WHERE user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) diff --git a/sql_helper/mixins/blocked_emojis.py b/sql_helper/mixins/blocked_emojis.py index 2ea4554..2243e79 100644 --- a/sql_helper/mixins/blocked_emojis.py +++ b/sql_helper/mixins/blocked_emojis.py @@ -16,8 +16,7 @@ class Ordering(Enum): class BlockedEmojisMixin(_PostgresConnection): async def get_guilds_with_blocked_emojis(self) -> List[int]: await self.cur.execute( - "SELECT distinct guild_id FROM blocked_emotes", - parameters={} + "SELECT distinct guild_id FROM blocked_emotes", parameters={} ) results = await self.cur.fetchall() return [guild_id for guild_id, in results] @@ -30,7 +29,7 @@ async def has_blocked_emotes(self, guild_id: int, emote_ids: List[int]) -> bool: or exists(select emote_id from blocked_emotes where guild_id=%(guild_id)s and emote_hash in (select emote_hash from emote_ids where emote_id=any(%(emote_ids)s))) """, - parameters={"guild_id": guild_id, "emote_ids": emote_ids} + parameters={"guild_id": guild_id, "emote_ids": emote_ids}, ) results = await self.cur.fetchall() return results[0][0] @@ -47,7 +46,7 @@ async def emotes_blocked(self, guild_id: int, emote_ids: List[int]) -> List[int] union (select emote_id from emote_ids where emote_id = any(%(emote_ids)s) and emote_hash in (select emote_hash from guild_blocks)) """, - parameters={"guild_id": guild_id, "emote_ids": emote_ids} + parameters={"guild_id": guild_id, "emote_ids": emote_ids}, ) results = await self.cur.fetchall() return [guild_id for guild_id, in results] @@ -79,53 +78,59 @@ async def block_emotes(self, guild_id: int, emotes: List[PartialEmoji]): "names": names, "emote_ids": ids, "animateds": animateds, - } + }, ) async def get_blocked_emotes(self, guild_id: int) -> List[PartialEmoji]: await self.cur.execute( - "SELECT animated, \"name\", emote_id FROM blocked_emotes where guild_id = %(guild_id)s", - parameters={"guild_id": guild_id} + 'SELECT animated, "name", emote_id FROM blocked_emotes where guild_id = %(guild_id)s', + parameters={"guild_id": guild_id}, ) return await _get_emotes(self.cur) - async def get_blocked_emotes_search(self, guild_id: int, name: str, ordering: Ordering) -> Tuple[int, List[PartialEmoji]]: + async def get_blocked_emotes_search( + self, guild_id: int, name: str, ordering: Ordering + ) -> Tuple[int, List[PartialEmoji]]: await self.cur.execute( "SELECT count(*) FROM blocked_emotes where guild_id = %(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) count = await self.cur.fetchall() if ordering is Ordering.before: await self.cur.execute( - f"SELECT animated, \"name\", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and \"name\" < %(name)s order by \"name\" desc limit 10", - parameters={"guild_id": guild_id, "name": name} + f'SELECT animated, "name", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and "name" < %(name)s order by "name" desc limit 10', + parameters={"guild_id": guild_id, "name": name}, ) elif ordering is Ordering.equal: await self.cur.execute( - f"SELECT animated, \"name\", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and \"name\" >= %(name)s order by \"name\" limit 10", - parameters={"guild_id": guild_id, "name": name} + f'SELECT animated, "name", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and "name" >= %(name)s order by "name" limit 10', + parameters={"guild_id": guild_id, "name": name}, ) elif ordering is Ordering.after: await self.cur.execute( - f"SELECT animated, \"name\", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and \"name\" > %(name)s order by \"name\" limit 10", - parameters={"guild_id": guild_id, "name": name} + f'SELECT animated, "name", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and "name" > %(name)s order by "name" limit 10', + parameters={"guild_id": guild_id, "name": name}, ) emotes = await _get_emotes(self.cur) if name and ordering is Ordering.before: emotes.reverse() return count[0][0], emotes - async def get_blocked_emotes_with_prefix(self, guild_id: int, prefix: str) -> List[PartialEmoji]: + async def get_blocked_emotes_with_prefix( + self, guild_id: int, prefix: str + ) -> List[PartialEmoji]: await self.cur.execute( - f"SELECT animated, \"name\", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and starts_with(lower(\"name\"), %(prefix)s) order by \"name\" limit 25", - parameters={"guild_id": guild_id, "prefix": prefix.lower()} + f'SELECT animated, "name", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and starts_with(lower("name"), %(prefix)s) order by "name" limit 25', + parameters={"guild_id": guild_id, "prefix": prefix.lower()}, ) return await _get_emotes(self.cur) - async def get_blocked_emote_by_name(self, guild_id: int, name: str) -> Optional[PartialEmoji]: + async def get_blocked_emote_by_name( + self, guild_id: int, name: str + ) -> Optional[PartialEmoji]: await self.cur.execute( - "SELECT animated, \"name\", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and \"name\" = %(name)s limit 1", - parameters={"guild_id": guild_id, "name": name} + 'SELECT animated, "name", emote_id FROM blocked_emotes where guild_id = %(guild_id)s and "name" = %(name)s limit 1', + parameters={"guild_id": guild_id, "name": name}, ) emotes = await _get_emotes(self.cur) if emotes: @@ -135,20 +140,21 @@ async def unblock_emote(self, guild_id: int, emote_id: int) -> bool: # Returns if the guild still has blocked emotes await self.cur.execute( "delete from blocked_emotes where guild_id=%(guild_id)s and emote_id=%(emote_id)s", - parameters={"guild_id": guild_id, "emote_id": emote_id} + parameters={"guild_id": guild_id, "emote_id": emote_id}, ) await self.cur.execute( "select exists(select emote_id from blocked_emotes where guild_id = %(guild_id)s)", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) results = await self.cur.fetchall() return results[0][0] - async def share_hashes_with_blocked_emote(self, guild_id: int, emote_id: int, targets: List[int]) -> List[int]: + async def share_hashes_with_blocked_emote( + self, guild_id: int, emote_id: int, targets: List[int] + ) -> List[int]: await self.cur.execute( "select emote_id from emote_ids where emote_hash=(select emote_hash from blocked_emotes where guild_id=%(guild_id)s and emote_id=%(emote_id)s) and emote_id=ANY(%(targets)s)", - parameters={"guild_id": guild_id, "emote_id": emote_id, "targets": targets} + parameters={"guild_id": guild_id, "emote_id": emote_id, "targets": targets}, ) results = await self.cur.fetchall() return [emote_id for emote_id, in results] - diff --git a/sql_helper/mixins/command_messages.py b/sql_helper/mixins/command_messages.py index 7e4ed9e..27f3d7d 100644 --- a/sql_helper/mixins/command_messages.py +++ b/sql_helper/mixins/command_messages.py @@ -6,10 +6,7 @@ class CommandMessagesMixin(_PostgresConnection): async def add_command_message(self, *, message_id: int, user_id: int): await self.cur.execute( "INSERT INTO command_messages (message_id, user_id) VALUES (%(message_id)s, %(user_id)s)", - parameters={ - "message_id": message_id, - "user_id": user_id - } + parameters={"message_id": message_id, "user_id": user_id}, ) async def get_command_message_author(self, message_id: int) -> Optional[int]: @@ -17,10 +14,9 @@ async def get_command_message_author(self, message_id: int) -> Optional[int]: "SELECT user_id FROM command_messages WHERE message_id=%(message_id)s LIMIT 1", parameters={ "message_id": message_id, - } + }, ) results = await self.cur.fetchall() if results: return results[0][0] return bool(results) - diff --git a/sql_helper/mixins/emoji_hashes.py b/sql_helper/mixins/emoji_hashes.py index 3833bf5..6cfda45 100644 --- a/sql_helper/mixins/emoji_hashes.py +++ b/sql_helper/mixins/emoji_hashes.py @@ -7,7 +7,7 @@ class EmojiHashesMixin(_PostgresConnection): async def is_emote_hash_filtered(self, emote_hash: str) -> bool: await self.cur.execute( "SELECT 1 FROM emote_hashes WHERE emote_hash=%(emote_hash)s and filtered=true LIMIT 1", - parameters={"emote_hash": emote_hash} + parameters={"emote_hash": emote_hash}, ) results = await self.cur.fetchall() return bool(results) @@ -15,7 +15,7 @@ async def is_emote_hash_filtered(self, emote_hash: str) -> bool: async def get_all_emote_hashes_filtered(self, emote_hashes: List[str]) -> Set[str]: await self.cur.execute( "SELECT emote_hash FROM emote_hashes WHERE emote_hash=ANY(%(emote_hashes)s) and filtered=true", - parameters={"emote_hashes": emote_hashes} + parameters={"emote_hashes": emote_hashes}, ) results = await self.cur.fetchall() return set(hash for hash, in results) @@ -23,5 +23,5 @@ async def get_all_emote_hashes_filtered(self, emote_hashes: List[str]) -> Set[st async def mark_emote_hash_filtered(self, emote_hash: str): await self.cur.execute( "INSERT INTO emote_hashes (emote_hash, filtered) VALUES (%(emote_hash)s, true)", - parameters={"emote_hash": emote_hash} + parameters={"emote_hash": emote_hash}, ) diff --git a/sql_helper/mixins/emojis.py b/sql_helper/mixins/emojis.py index 31ab522..101046e 100644 --- a/sql_helper/mixins/emojis.py +++ b/sql_helper/mixins/emojis.py @@ -12,7 +12,7 @@ async def get_emote_hashes(self, emote_ids: List[int]) -> Dict[str, int]: return {} await self.cur.execute( "SELECT emote_id, emote_hash FROM emote_ids WHERE emote_id IN (SELECT(UNNEST(%(emote_ids)s)))", - parameters={"emote_ids": emote_ids} + parameters={"emote_ids": emote_ids}, ) emote_hashes = await self.cur.fetchall() return {emote_hash: emote_id for emote_id, emote_hash in emote_hashes} @@ -20,22 +20,24 @@ async def get_emote_hashes(self, emote_ids: List[int]) -> Dict[str, int]: async def get_emote_hash(self, emote_id: int) -> Optional[str]: await self.cur.execute( "SELECT emote_hash FROM emote_ids WHERE emote_id=%(emote_id)s", - parameters={"emote_id": emote_id} + parameters={"emote_id": emote_id}, ) emote_hash = await self.cur.fetchall() if emote_hash and emote_hash[0][0]: return emote_hash[0][0] - async def get_synonyms_for_emote(self, emote_hash: str, limit: Optional[int] = 10) -> List[int]: + async def get_synonyms_for_emote( + self, emote_hash: str, limit: Optional[int] = 10 + ) -> List[int]: if limit is None: await self.cur.execute( "SELECT emote_id FROM emote_ids WHERE emote_hash=%(emote_hash)s and usable=true and has_roles=false and manual_block=false", - parameters={"emote_hash": emote_hash} + parameters={"emote_hash": emote_hash}, ) else: await self.cur.execute( "SELECT emote_id FROM emote_ids WHERE emote_hash=%(emote_hash)s and usable=true and has_roles=false and manual_block=false LIMIT %(limit)s", - parameters={"emote_hash": emote_hash, "limit": limit} + parameters={"emote_hash": emote_hash, "limit": limit}, ) results = await self.cur.fetchall() @@ -44,7 +46,7 @@ async def get_synonyms_for_emote(self, emote_hash: str, limit: Optional[int] = 1 async def alternative_emote_names(self, emote_id: int) -> List[str]: await self.cur.execute( """select trim("name") as name from (select "name", count(*) from emote_ids where emote_hash=(select emote_hash from emote_ids where emote_id=%(emote_id)s) group by "name" order by count desc) a where count > 10 and lower("name") not like 'emoji%%' limit 10""", - parameters={"emote_id": emote_id} + parameters={"emote_id": emote_id}, ) results = await self.cur.fetchall() return [name for name, in results] @@ -52,7 +54,7 @@ async def alternative_emote_names(self, emote_id: int) -> List[str]: async def is_emote_blocked(self, emote_id: int) -> bool: await self.cur.execute( "SELECT 1 FROM emote_ids WHERE emote_id=%(emote_id)s and (has_roles=true or manual_block=true) LIMIT 1", - parameters={"emote_id": emote_id} + parameters={"emote_id": emote_id}, ) results = await self.cur.fetchall() return bool(results) @@ -60,7 +62,7 @@ async def is_emote_blocked(self, emote_id: int) -> bool: async def is_emote_usable(self, emote_id: int) -> bool: await self.cur.execute( "SELECT 1 FROM emote_ids WHERE emote_id=%(emote_id)s and has_roles=false and manual_block=false and usable=true LIMIT 1", - parameters={"emote_id": emote_id} + parameters={"emote_id": emote_id}, ) results = await self.cur.fetchall() return bool(results) @@ -68,29 +70,31 @@ async def is_emote_usable(self, emote_id: int) -> bool: async def purge_emojis(self, emote_ids: List[int]) -> Set[int]: await self.cur.execute( "SELECT emote_id FROM emote_ids WHERE emote_id=ANY(%(emote_ids)s) and guild_id is NULL", - parameters={"emote_ids": emote_ids} + parameters={"emote_ids": emote_ids}, ) results = await self.cur.fetchall() if results: await self.cur.execute( "DELETE FROM emote_ids WHERE emote_id=ANY(%(emote_ids)s) and guild_id is NULL", - parameters={"emote_ids": emote_ids} + parameters={"emote_ids": emote_ids}, ) return set(emote_id for emote_id, in results) async def share_hashes(self, emote_id_1: int, emote_id_2: int) -> bool: await self.cur.execute( "select 1 from emote_ids where emote_hash=(select emote_hash from emote_ids where emote_id = %(emote_id_1)s) and emote_id = %(emote_id_2)s LIMIT 1", - parameters={"emote_id_1": emote_id_1, "emote_id_2": emote_id_2} + parameters={"emote_id_1": emote_id_1, "emote_id_2": emote_id_2}, ) results = await self.cur.fetchall() return bool(results) - async def get_emote_like(self, emote_id: int, *, require_guild: bool = True) -> Optional[Emoji]: + async def get_emote_like( + self, emote_id: int, *, require_guild: bool = True + ) -> Optional[Emoji]: # Get the emoji we're searching for await self.cur.execute( "SELECT emote_id, emote_hash, usable, animated, emote_sha, guild_id, trim(name), has_roles FROM emote_ids WHERE emote_id=%(emote_id)s LIMIT 1", - parameters={"emote_id": emote_id} + parameters={"emote_id": emote_id}, ) results = await self.cur.fetchall() if not results: @@ -106,23 +110,35 @@ async def get_emote_like(self, emote_id: int, *, require_guild: bool = True) -> else: return self._get_emoji(await self._get_emote_like(emote, require_guild)) - async def _get_emote_like(self, emote: SQLEmoji, require_guild: bool) -> Optional[SQLEmoji]: + async def _get_emote_like( + self, emote: SQLEmoji, require_guild: bool + ) -> Optional[SQLEmoji]: # At this point, we can't use the original emoji. Let's look for a similar one if require_guild: await self.cur.execute( "SELECT emote_id, emote_hash, usable, animated, emote_sha, guild_id, trim(name), has_roles FROM emote_ids WHERE emote_hash=%(emote_hash)s and guild_id is not null and usable=true LIMIT 1", - parameters={"emote_hash": emote.emote_hash} + parameters={"emote_hash": emote.emote_hash}, ) else: await self.cur.execute( "SELECT emote_id, emote_hash, usable, animated, emote_sha, guild_id, trim(name), has_roles FROM emote_ids WHERE emote_hash=%(emote_hash)s and usable=true LIMIT 1", - parameters={"emote_hash": emote.emote_hash} + parameters={"emote_hash": emote.emote_hash}, ) results = await self.cur.fetchall() if results: return SQLEmoji(*results[0]) - async def set_emote_perceptual_data(self, emote_id: int, guild_id: int, emote_hash: str, emote_sha: str, animated: bool, usable: bool, name: Optional[str], has_roles: bool): + async def set_emote_perceptual_data( + self, + emote_id: int, + guild_id: int, + emote_hash: str, + emote_sha: str, + animated: bool, + usable: bool, + name: Optional[str], + has_roles: bool, + ): await self.cur.execute( "INSERT INTO emote_ids (emote_id, emote_hash, usable, animated, emote_sha, guild_id, name, has_roles) VALUES (%(emote_id)s, %(emote_hash)s, %(usable)s, %(animated)s, %(emote_sha)s, %(guild_id)s, %(name)s, %(has_roles)s) ON CONFLICT (emote_id) DO UPDATE SET name=%(name)s, has_roles=%(has_roles)s, guild_id=coalesce(emote_ids.guild_id, %(guild_id)s)", parameters={ @@ -134,20 +150,38 @@ async def set_emote_perceptual_data(self, emote_id: int, guild_id: int, emote_ha "usable": usable, "name": name, "has_roles": has_roles, - } + }, ) - async def set_emote_guild(self, emote_id: int, guild_id: Optional[int], usable: Optional[bool], has_roles: bool, name: Optional[str]): + async def set_emote_guild( + self, + emote_id: int, + guild_id: Optional[int], + usable: Optional[bool], + has_roles: bool, + name: Optional[str], + ): if usable is None: # We don't know if the emote is available or not await self.cur.execute( "UPDATE emote_ids SET guild_id=%(guild_id)s, name=%(name)s, has_roles=%(has_roles)s where emote_id=%(emote_id)s", - parameters={"emote_id": emote_id, "guild_id": guild_id, "name": name, "has_roles": has_roles} + parameters={ + "emote_id": emote_id, + "guild_id": guild_id, + "name": name, + "has_roles": has_roles, + }, ) else: await self.cur.execute( "UPDATE emote_ids SET guild_id=%(guild_id)s, usable=%(usable)s, name=%(name)s, has_roles=%(has_roles)s where emote_id=%(emote_id)s", - parameters={"emote_id": emote_id, "guild_id": guild_id, "usable": usable, "name": name, "has_roles": has_roles} + parameters={ + "emote_id": emote_id, + "guild_id": guild_id, + "usable": usable, + "name": name, + "has_roles": has_roles, + }, ) async def increment_guild_emote_score(self, emoji_guild_ids: List[Tuple[int, int]]): @@ -157,32 +191,32 @@ async def increment_guild_emote_score(self, emoji_guild_ids: List[Tuple[int, int # select (-128)>>4+7 -> -1; await self.cur.execute( 'UPDATE emote_ids SET score=((score::int) + 1)::"char" ' - 'where (guild_id, emote_id) in (SELECT * FROM unnest(%(emoji_guild_ids)s) as f(guild_id bigint, emote_id bigint)) and ' - 'trunc(random() * (2<<(((score::int)>>5)+3))) = 0 and ' - 'score::int != 127', # POSITIVE 127 is max number here as signed - parameters={"emoji_guild_ids": emoji_guild_ids} + "where (guild_id, emote_id) in (SELECT * FROM unnest(%(emoji_guild_ids)s) as f(guild_id bigint, emote_id bigint)) and " + "trunc(random() * (2<<(((score::int)>>5)+3))) = 0 and " + "score::int != 127", # POSITIVE 127 is max number here as signed + parameters={"emoji_guild_ids": emoji_guild_ids}, ) async def decay_guild_emote_score(self): await self.cur.execute( 'update emote_ids set score=(score::int - 1)::"char" ' - 'where guild_id in (select guild_id from emote_ids where score::int > -28 and guild_id is not null) and ' - 'score::int != -128', # -128 is smallest number - parameters={} + "where guild_id in (select guild_id from emote_ids where score::int > -28 and guild_id is not null) and " + "score::int != -128", # -128 is smallest number + parameters={}, ) async def guild_emote_scores(self, guild_id: int) -> Counter: await self.cur.execute( - 'select emote_id, score::int + 128 from emote_ids where guild_id=%(guild_id)s order by score desc', - parameters={"guild_id": guild_id} + "select emote_id, score::int + 128 from emote_ids where guild_id=%(guild_id)s order by score desc", + parameters={"guild_id": guild_id}, ) emote_scores = await self.cur.fetchall() return Counter({k: v for k, v in emote_scores}) async def emote_score(self, emote_id: int) -> int: await self.cur.execute( - 'select score::int + 128 from emote_ids where emote_id=%(emote_id)s', - parameters={"emote_id": emote_id} + "select score::int + 128 from emote_ids where emote_id=%(emote_id)s", + parameters={"emote_id": emote_id}, ) emote_score = await self.cur.fetchall() if not emote_score: @@ -192,7 +226,7 @@ async def emote_score(self, emote_id: int) -> int: async def daily_scores_for_emotes(self, emote_id: int) -> int: await self.cur.execute( "select score::int + 128 as score from emote_ids where emote_hash=(select emote_hash from emote_ids where emote_id=%(emote_id)s) and score::int > -28", - parameters={"emote_id": emote_id} + parameters={"emote_id": emote_id}, ) emote_scores = await self.cur.fetchall() return [score for score, in emote_scores] @@ -200,43 +234,42 @@ async def daily_scores_for_emotes(self, emote_id: int) -> int: async def clear_guild_emotes(self, guild_id: int): await self.cur.execute( "UPDATE emote_ids SET guild_id=null where guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) async def set_emote_usability(self, emote_id: int, usable: Optional[bool]): await self.cur.execute( "UPDATE emote_ids SET usable=%(usable)s where emote_id=%(emote_id)s", - parameters={"emote_id": emote_id, "usable": usable} + parameters={"emote_id": emote_id, "usable": usable}, ) async def set_emotes_unusuable(self, emote_ids: List[int]): await self.cur.execute( "UPDATE emote_ids SET usable=false where emote_id=ANY(%(emote_ids)s)", - parameters={"emote_ids": emote_ids} + parameters={"emote_ids": emote_ids}, ) async def clear_guild_emojis(self, guild_id: int): await self.cur.execute( "UPDATE emote_ids SET guild_id=null where guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) async def clear_guilds_emojis(self, guild_ids: List[int]): await self.cur.execute( "UPDATE emote_ids SET guild_id=null where guild_id=ANY(%(guild_ids)s)", - parameters={"guild_ids": guild_ids} + parameters={"guild_ids": guild_ids}, ) async def clear_guilds_from_emojis(self, emoji_ids: List[int]): await self.cur.execute( "UPDATE emote_ids SET guild_id=null where emote_id=ANY(%(emoji_ids)s)", - parameters={"emoji_ids": emoji_ids} + parameters={"emoji_ids": emoji_ids}, ) async def get_guilds_with_emojis(self) -> Set[int]: await self.cur.execute( - "SELECT guild_id FROM emote_ids WHERE guild_id IS NOT NULL", - parameters={} + "SELECT guild_id FROM emote_ids WHERE guild_id IS NOT NULL", parameters={} ) results = await self.cur.fetchall() return {guild_id for guild_id, in results} @@ -245,63 +278,71 @@ async def get_pack_emote(self, pack_name: str, emote_name: str) -> Optional[Emoj return await self._case_insensitive_get_emote( "guild_id=(select guild_id from packs where pack_name=%(pack_name)s)", emote_name, - parameters={"pack_name": pack_name} + parameters={"pack_name": pack_name}, ) async def get_guild_emote(self, guild_id: int, emote_name: str) -> Optional[Emoji]: return await self._case_insensitive_get_emote( - "guild_id=%(guild_id)s", - emote_name, - parameters={"guild_id": guild_id} + "guild_id=%(guild_id)s", emote_name, parameters={"guild_id": guild_id} ) async def guild_emote_counts(self, guild_id: int) -> EmojiCounts: await self.cur.execute( "SELECT COUNT(*) filter(where not animated) as static, COUNT(*) filter(where animated) as animated FROM emote_ids where guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) counts = await self.cur.fetchall() return EmojiCounts(static=counts[0][0], animated=counts[0][1]) - async def get_mutual_guild_emote(self, user_id: int, emote_name: str) -> Optional[Emoji]: + async def get_mutual_guild_emote( + self, user_id: int, emote_name: str + ) -> Optional[Emoji]: return await self._case_insensitive_get_emote( "guild_id in (select guild_id from members where user_id=%(user_id)s)", emote_name, - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) - async def get_pack_guild_emote(self, user_id: int, emote_name: str) -> Optional[Emoji]: + async def get_pack_guild_emote( + self, user_id: int, emote_name: str + ) -> Optional[Emoji]: return await self._case_insensitive_get_emote( "guild_id in (select guild_id from user_packs where user_id=%(user_id)s)", emote_name, - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) - async def get_guild_emotes(self, guild_id: int, *, order_by_score: bool = False) -> List[Emoji]: + async def get_guild_emotes( + self, guild_id: int, *, order_by_score: bool = False + ) -> List[Emoji]: return await self._get_emojis( "guild_id=%(guild_id)s", "order by score desc" if order_by_score else "", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) - async def get_mutual_guild_emotes(self, user_id: int, *, order_by_score: bool = False) -> List[Emoji]: + async def get_mutual_guild_emotes( + self, user_id: int, *, order_by_score: bool = False + ) -> List[Emoji]: return await self._get_emojis( "guild_id in (select guild_id from members where user_id=%(user_id)s)", "order by score desc" if order_by_score else "", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) - async def get_guilds_emotes(self, guild_ids: List[int], *, order_by_score: bool = False) -> List[Emoji]: + async def get_guilds_emotes( + self, guild_ids: List[int], *, order_by_score: bool = False + ) -> List[Emoji]: return await self._get_emojis( "guild_id = ANY(%(guild_ids)s)", "order by score desc" if order_by_score else "", - parameters={"guild_ids": guild_ids} + parameters={"guild_ids": guild_ids}, ) async def get_guilds_emotes_raw(self, guild_ids: List[int]) -> List: await self.cur.execute( f"select emote_id, guild_id, trim(name), has_roles from emote_ids where guild_id = ANY(%(guild_ids)s)", - parameters={"guild_ids": guild_ids} + parameters={"guild_ids": guild_ids}, ) results = await self.cur.fetchall() return results @@ -309,7 +350,7 @@ async def get_guilds_emotes_raw(self, guild_ids: List[int]) -> List: async def get_guild_emotes_raw(self, guild_id: int) -> List: await self.cur.execute( f"select emote_id, emote_hash, usable, animated, emote_sha, guild_id, trim(name), has_roles from emote_ids where guild_id = %(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) results = await self.cur.fetchall() return [SQLEmoji(*i) for i in results] @@ -317,7 +358,7 @@ async def get_guild_emotes_raw(self, guild_id: int) -> List: async def get_emotes_raw(self, emote_ids: List[int]) -> List: await self.cur.execute( f"select emote_id, emote_hash, usable, animated, emote_sha, guild_id, trim(name), has_roles from emote_ids where emote_id = ANY(%(emote_ids)s)", - parameters={"emote_ids": emote_ids} + parameters={"emote_ids": emote_ids}, ) results = await self.cur.fetchall() return [SQLEmoji(*i) for i in results] @@ -325,37 +366,41 @@ async def get_emotes_raw(self, emote_ids: List[int]) -> List: async def get_emotes(self, emote_ids: List[int]) -> List[Emoji]: return await self._get_emojis( "emote_id = ANY(%(emote_ids)s) and guild_id is not NULL", - parameters={"emote_ids": emote_ids} + parameters={"emote_ids": emote_ids}, ) async def get_pack_guild_emotes(self, user_id: int) -> List[Emoji]: return await self._get_emojis( "guild_id in (select guild_id from user_packs where user_id=%(user_id)s)", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) async def get_mega_emotes(self, guild_id: int, emote_name: str) -> List[Emoji]: await self.cur.execute( "select emote_id, emote_hash, usable, animated, emote_sha, guild_id, trim(name), has_roles from emote_ids where guild_id=%(guild_id)s and trim(name) ~ ('^' || %(emote_name)s || '_\d_\d$') and usable=true and has_roles=false and manual_block=false", - parameters={"guild_id": guild_id, "emote_name": emote_name} + parameters={"guild_id": guild_id, "emote_name": emote_name}, ) results = await self.cur.fetchall() return [self._get_emoji(SQLEmoji(*emote)) for emote in results] - async def _get_emojis(self, query_where, query_suffix: str = "", *, parameters) -> List[Emoji]: + async def _get_emojis( + self, query_where, query_suffix: str = "", *, parameters + ) -> List[Emoji]: await self.cur.execute( f"select emote_id, emote_hash, usable, animated, emote_sha, guild_id, trim(name), has_roles from emote_ids where {query_where} and usable=true and has_roles=false and manual_block=false {query_suffix}", - parameters=parameters + parameters=parameters, ) results = await self.cur.fetchall() emotes = (self._get_emoji(SQLEmoji(*i)) for i in results) return [emote for emote in emotes if emote] - async def _case_insensitive_get_emote(self, query_where, emote_name: str, parameters) -> Optional[Emoji]: + async def _case_insensitive_get_emote( + self, query_where, emote_name: str, parameters + ) -> Optional[Emoji]: case_insensitive = await self._get_emote_with_where( f"{query_where} and lower(trim(name))=lower(%(emote_name)s)", emote_name, - parameters + parameters, ) # If the least sensitive query didn't get us anything, we're done. if case_insensitive is None: @@ -367,16 +412,16 @@ async def _case_insensitive_get_emote(self, query_where, emote_name: str, parame case_sensitive = case_insensitive else: case_sensitive = await self._get_emote_with_where( - f"{query_where} and trim(name)=%(emote_name)s", - emote_name, - parameters + f"{query_where} and trim(name)=%(emote_name)s", emote_name, parameters ) if case_sensitive is not None: assert case_sensitive.name == emote_name if case_sensitive.usable: return self._get_emoji(case_sensitive) # See if we can find a copy that's usable! - case_sensitive = await self._get_emote_like(case_sensitive, require_guild=True) + case_sensitive = await self._get_emote_like( + case_sensitive, require_guild=True + ) if case_sensitive is not None: return self._get_emoji(case_sensitive) # At this point, give up on case sensitivity. Let's see if we can find something usable that's insensitive. @@ -384,18 +429,22 @@ async def _case_insensitive_get_emote(self, query_where, emote_name: str, parame usable_case_insensitive = await self._get_emote_with_where( f"{query_where} and lower(trim(name))=lower(%(emote_name)s) and usable=true", emote_name, - parameters + parameters, ) if usable_case_insensitive is None: # If not, well we found *something* insensitive earlier - usable_case_insensitive = await self._get_emote_like(case_insensitive, require_guild=True) + usable_case_insensitive = await self._get_emote_like( + case_insensitive, require_guild=True + ) if usable_case_insensitive is not None: return self._get_emoji(usable_case_insensitive) - async def _get_emote_with_where(self, query_where: str, emote_name: str, parameters) -> Optional[SQLEmoji]: + async def _get_emote_with_where( + self, query_where: str, emote_name: str, parameters + ) -> Optional[SQLEmoji]: await self.cur.execute( f"select emote_id, emote_hash, usable, animated, emote_sha, guild_id, trim(name), has_roles from emote_ids where {query_where} and has_roles=false and manual_block=false LIMIT 1", - parameters={**parameters, "emote_name": emote_name} + parameters={**parameters, "emote_name": emote_name}, ) results = await self.cur.fetchall() if results: diff --git a/sql_helper/mixins/emojis_used.py b/sql_helper/mixins/emojis_used.py index e1a2e21..2c8395f 100644 --- a/sql_helper/mixins/emojis_used.py +++ b/sql_helper/mixins/emojis_used.py @@ -23,29 +23,33 @@ async def add_used_emotes(self, to_cache: List[Tuple[int, PartialEmoji, datetime await self.cur.execute( f"INSERT INTO emotes_used (time, guild_id, name, emote_id, animated) VALUES (unnest(%(times)s), unnest(%(guild_ids)s), unnest(%(names)s), unnest(%(ids)s), unnest(%(animateds)s))", - parameters=params + parameters=params, ) - async def get_recently_used_emote(self, guild_id: int, name: str) -> Optional[PartialEmoji]: + async def get_recently_used_emote( + self, guild_id: int, name: str + ) -> Optional[PartialEmoji]: await self.cur.execute( - "SELECT first(animated, time), first(name, time), emote_id FROM emotes_used WHERE guild_id=%(guild_id)s and lower(\"name\")=%(name)s group by emote_id", - parameters={"guild_id": guild_id, "name": name.lower()} + 'SELECT first(animated, time), first(name, time), emote_id FROM emotes_used WHERE guild_id=%(guild_id)s and lower("name")=%(name)s group by emote_id', + parameters={"guild_id": guild_id, "name": name.lower()}, ) emotes = await _get_emotes(self.cur) if emotes: return next((emote for emote in emotes if emote.name == name), emotes[0]) - async def get_recently_used_emotes(self, guild_id: int, prefix: str, limit: int = 25) -> List[PartialEmoji]: + async def get_recently_used_emotes( + self, guild_id: int, prefix: str, limit: int = 25 + ) -> List[PartialEmoji]: if not prefix: await self.cur.execute( "select first(animated, time), first(name, time), emote_id from emotes_used where guild_id=%(guild_id)s group by emote_id limit %(limit)s", - parameters={"guild_id": guild_id, "limit": limit} + parameters={"guild_id": guild_id, "limit": limit}, ) else: # This one doesn't use an index for the whole thing. Should be OK though await self.cur.execute( - "select first(animated, time), first(name, time), emote_id from emotes_used where guild_id=%(guild_id)s and starts_with(lower(\"name\"), %(prefix)s) group by emote_id limit %(limit)s", - parameters={"guild_id": guild_id, "prefix": prefix, "limit": limit} + 'select first(animated, time), first(name, time), emote_id from emotes_used where guild_id=%(guild_id)s and starts_with(lower("name"), %(prefix)s) group by emote_id limit %(limit)s', + parameters={"guild_id": guild_id, "prefix": prefix, "limit": limit}, ) return await _get_emotes(self.cur) diff --git a/sql_helper/mixins/guild_members.py b/sql_helper/mixins/guild_members.py index 1a9248d..f3ad6cf 100644 --- a/sql_helper/mixins/guild_members.py +++ b/sql_helper/mixins/guild_members.py @@ -12,7 +12,7 @@ async def mutual_guild_ids(self, user_id: Union[User, int]) -> AsyncList: user_id = user_id.id await self.cur.execute( "SELECT guild_id FROM members WHERE user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) guilds = await self.cur.fetchall() return [g for g, in guilds] @@ -29,7 +29,7 @@ async def mutual_guilds(self, user_id: Union[User, int]) -> AsyncList: async def in_guild(self, *, user_id: int, guild_id: int) -> bool: await self.cur.execute( "SELECT 1 FROM members WHERE user_id=%(user_id)s AND guild_id=%(guild_id)s LIMIT 1", - parameters={"user_id": user_id, "guild_id": guild_id} + parameters={"user_id": user_id, "guild_id": guild_id}, ) return bool(await self.cur.fetchall()) @@ -37,7 +37,7 @@ async def random_member(self, guild_id: int) -> Optional[int]: for i in range(10): await self.cur.execute( "SELECT user_id FROM members WHERE guild_id=%(guild_id)s AND random() < (SELECT 2::decimal/count(*) FROM members WHERE guild_id=%(guild_id)s) LIMIT 1", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) user_id = await self.cur.fetchall() if user_id: @@ -48,15 +48,10 @@ async def set_user_guilds(self, user_id: int, guild_ids: List[int]): async with self.cur.begin(): await self.cur.execute( "DELETE FROM members WHERE members.user_id = %(user_id)s", - parameters={ - "user_id": user_id - } + parameters={"user_id": user_id}, ) for guild_id in guild_ids: await self.cur.execute( "INSERT INTO members (guild_id, user_id) VALUES (%(guild_id)s, %(user_id)s)", - parameters={ - "user_id": user_id, - "guild_id": guild_id - } + parameters={"user_id": user_id, "guild_id": guild_id}, ) diff --git a/sql_helper/mixins/guild_messages.py b/sql_helper/mixins/guild_messages.py index c87ad30..96e86f8 100644 --- a/sql_helper/mixins/guild_messages.py +++ b/sql_helper/mixins/guild_messages.py @@ -6,15 +6,17 @@ class GuildMessagesMixin(_PostgresConnection): - async def add_guild_message(self, *, message_id: int, guild_id: int, channel_id: int, user_id: int): + async def add_guild_message( + self, *, message_id: int, guild_id: int, channel_id: int, user_id: int + ): await self.cur.execute( "INSERT INTO guild_messages (guild_id, channel_id, message_id, user_id) VALUES (%(guild_id)s, %(channel_id)s, %(message_id)s, %(user_id)s) ON CONFLICT ON CONSTRAINT guild_messages_pk DO NOTHING", parameters={ "guild_id": guild_id, "channel_id": channel_id, "message_id": message_id, - "user_id": user_id - } + "user_id": user_id, + }, ) async def add_guild_message_bulk(self, messages: Iterable[GuildMessage]): @@ -34,49 +36,60 @@ async def add_guild_message_bulk(self, messages: Iterable[GuildMessage]): "channel_id": channels, "message_id": message_ids, "user_id": users, - } + }, ) async def delete_guild_message(self, *, message_id: int): await self.cur.execute( "DELETE FROM guild_messages where message_id=%(message_id)s", - parameters={ - "message_id": message_id - } + parameters={"message_id": message_id}, ) @async_list async def get_guild_messages( - self, - *, - message_id: Optional[int] = None, - guild_id: Optional[int] = None, - channel_id: Optional[int] = None, - user_id: Optional[int] = None, - offset: Optional[int] = None, - no_results: int + self, + *, + message_id: Optional[int] = None, + guild_id: Optional[int] = None, + channel_id: Optional[int] = None, + user_id: Optional[int] = None, + offset: Optional[int] = None, + no_results: int, ) -> List[GuildMessage]: - await self._get_guild_message("guild_id, channel_id, message_id, user_id", "ORDER BY message_id DESC", message_id, guild_id, channel_id, user_id, offset, no_results) + await self._get_guild_message( + "guild_id, channel_id, message_id, user_id", + "ORDER BY message_id DESC", + message_id, + guild_id, + channel_id, + user_id, + offset, + no_results, + ) results = await self.cur.fetchall() return [GuildMessage(*i) for i in results] @async_list - async def get_guild_message_ids(self, guild_id: int, message_ids: List[int]) -> List[Tuple[int, int]]: + async def get_guild_message_ids( + self, guild_id: int, message_ids: List[int] + ) -> List[Tuple[int, int]]: await self.cur.execute( f"SELECT message_id, user_id FROM guild_messages WHERE guild_id=%(guild_id)s and message_id=ANY(%(message_ids)s)", - parameters={"guild_id": guild_id, "message_ids": message_ids} + parameters={"guild_id": guild_id, "message_ids": message_ids}, ) return await self.cur.fetchall() async def get_guild_message_count( - self, - *, - message_id: Optional[int] = None, - guild_id: Optional[int] = None, - channel_id: Optional[int] = None, - user_id: Optional[int] = None, + self, + *, + message_id: Optional[int] = None, + guild_id: Optional[int] = None, + channel_id: Optional[int] = None, + user_id: Optional[int] = None, ) -> int: - await self._get_guild_message("count(*)", "", message_id, guild_id, channel_id, user_id, 0, 1) + await self._get_guild_message( + "count(*)", "", message_id, guild_id, channel_id, user_id, 0, 1 + ) results = await self.cur.fetchall() return results[0][0] @@ -84,7 +97,7 @@ async def get_guild_message_count( async def get_users_posted_since(self, since: int): await self.cur.execute( "select distinct user_id from guild_messages where message_id > %(message_id)s", - parameters={"message_id": since} + parameters={"message_id": since}, ) results = await self.cur.fetchall() return (i[0] for i in results) @@ -92,51 +105,75 @@ async def get_users_posted_since(self, since: int): async def user_has_posted(self, user_id: int) -> bool: await self.cur.execute( "select 1 from guild_messages where user_id = %(user_id)s LIMIT 1", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) results = await self.cur.fetchall() return bool(results) async def _get_guild_message( - self, - select: str, - order: str, - message_id: Optional[int] = None, - guild_id: Optional[int] = None, - channel_id: Optional[int] = None, - user_id: Optional[int] = None, - offset: Optional[int] = None, - no_results: int = 1 + self, + select: str, + order: str, + message_id: Optional[int] = None, + guild_id: Optional[int] = None, + channel_id: Optional[int] = None, + user_id: Optional[int] = None, + offset: Optional[int] = None, + no_results: int = 1, ): if message_id is not None: await self.cur.execute( f"SELECT {select} FROM guild_messages WHERE message_id=%(message_id)s LIMIT %(limit)s OFFSET %(offset)s", - parameters={"message_id": message_id, "limit": no_results, "offset": offset} + parameters={ + "message_id": message_id, + "limit": no_results, + "offset": offset, + }, ) elif channel_id and user_id: await self.cur.execute( f"SELECT {select} FROM guild_messages WHERE channel_id=%(channel_id)s AND user_id=%(user_id)s {order} LIMIT %(limit)s OFFSET %(offset)s", - parameters={"channel_id": channel_id, "user_id": user_id, "limit": no_results, "offset": offset} + parameters={ + "channel_id": channel_id, + "user_id": user_id, + "limit": no_results, + "offset": offset, + }, ) elif channel_id: await self.cur.execute( f"SELECT {select} FROM guild_messages WHERE channel_id=%(channel_id)s {order} LIMIT %(limit)s OFFSET %(offset)s", - parameters={"channel_id": channel_id, "limit": no_results, "offset": offset} + parameters={ + "channel_id": channel_id, + "limit": no_results, + "offset": offset, + }, ) elif guild_id and user_id: await self.cur.execute( f"SELECT {select} FROM guild_messages WHERE guild_id=%(guild_id)s AND user_id=%(user_id)s {order} LIMIT %(limit)s OFFSET %(offset)s", - parameters={"guild_id": guild_id, "user_id": user_id, "limit": no_results, "offset": offset} + parameters={ + "guild_id": guild_id, + "user_id": user_id, + "limit": no_results, + "offset": offset, + }, ) elif guild_id: await self.cur.execute( f"SELECT {select} FROM guild_messages WHERE guild_id=%(guild_id)s {order} LIMIT %(limit)s OFFSET %(offset)s", - parameters={"guild_id": guild_id, "limit": no_results, "offset": offset} + parameters={ + "guild_id": guild_id, + "limit": no_results, + "offset": offset, + }, ) elif user_id: await self.cur.execute( f"SELECT {select} FROM guild_messages WHERE user_id=%(user_id)s {order} LIMIT %(limit)s OFFSET %(offset)s", - parameters={"user_id": user_id, "limit": no_results, "offset": offset} + parameters={"user_id": user_id, "limit": no_results, "offset": offset}, ) else: - raise NotImplementedError(f"Invalid combination of requests: message_id={message_id} guild_id={guild_id} channel_id={channel_id} user_id={user_id}") + raise NotImplementedError( + f"Invalid combination of requests: message_id={message_id} guild_id={guild_id} channel_id={channel_id} user_id={user_id}" + ) diff --git a/sql_helper/mixins/guild_settings.py b/sql_helper/mixins/guild_settings.py index eec9244..71d1a15 100644 --- a/sql_helper/mixins/guild_settings.py +++ b/sql_helper/mixins/guild_settings.py @@ -23,9 +23,7 @@ async def personas_guilds(self) -> Set[str]: @async_list async def guild_prefixes(self) -> AsyncList: - await self.cur.execute( - "SELECT guild_id, prefix FROM guild_settings" - ) + await self.cur.execute("SELECT guild_id, prefix FROM guild_settings") return await self.cur.fetchall() @async_list @@ -36,10 +34,12 @@ async def guild_settings(self) -> AsyncList: results = await self.cur.fetchall() return [GuildSettings(*i) for i in results] - async def get_guild_settings(self, guild_id: Union[Guild, int]) -> Optional[GuildSettings]: + async def get_guild_settings( + self, guild_id: Union[Guild, int] + ) -> Optional[GuildSettings]: await self.cur.execute( "SELECT guild_id, prefix, locale, max_guildwide_emotes, nitro_role, boost_channel, boost_role, audit_channel, enable_stickers, enable_nitro, enable_replies, is_alias_server, enable_pings, enable_user_content, enable_personas, enable_dashboard_posting, enable_phish_detection FROM guild_settings WHERE guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) results = await self.cur.fetchall() if not results: @@ -49,7 +49,7 @@ async def get_guild_settings(self, guild_id: Union[Guild, int]) -> Optional[Guil async def get_guild_rank(self, max_emotes: int) -> int: await self.cur.execute( "SELECT count(*) FROM guild_settings WHERE max_guildwide_emotes > %(max_emotes)s", - parameters={"max_emotes": max_emotes} + parameters={"max_emotes": max_emotes}, ) results = await self.cur.fetchall() return results[0][0] @@ -66,11 +66,14 @@ async def set_guild_settings(self, guild_settings: GuildSettings): "(%(guild_id)s, %(prefix)s, %(nitro_role)s, %(boost_channel)s, %(boost_role)s, %(audit_channel)s, %(enable_stickers)s, %(enable_nitro)s, %(enable_replies)s, %(is_alias_server)s, %(locale)s, %(enable_pings)s, %(max_guildwide_emotes)s, %(enable_user_content)s, %(enable_personas)s, %(enable_dashboard_posting)s, %(enable_phish_detection)s)" 'ON CONFLICT (guild_id) DO UPDATE SET (prefix, nitro_role, boost_channel, boost_role, audit_channel, enable_stickers, enable_nitro, enable_replies, is_alias_server, "locale", enable_pings, max_guildwide_emotes, enable_user_content, enable_personas, enable_dashboard_posting, enable_phish_detection) = ' "(EXCLUDED.prefix, EXCLUDED.nitro_role, EXCLUDED.boost_channel, EXCLUDED.boost_role, EXCLUDED.audit_channel, EXCLUDED.enable_stickers, EXCLUDED.enable_nitro, EXCLUDED.enable_replies, EXCLUDED.is_alias_server, EXCLUDED.locale, EXCLUDED.enable_pings, EXCLUDED.max_guildwide_emotes, EXCLUDED.enable_user_content, EXCLUDED.enable_personas, EXCLUDED.enable_dashboard_posting, EXCLUDED.enable_phish_detection)", - parameters={field.name: getattr(guild_settings, field.name) for field in fields(guild_settings)} + parameters={ + field.name: getattr(guild_settings, field.name) + for field in fields(guild_settings) + }, ) async def delete_guild_settings(self, guild_id: int): await self.cur.execute( "DELETE FROM guild_settings WHERE guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) diff --git a/sql_helper/mixins/guild_webhooks.py b/sql_helper/mixins/guild_webhooks.py index a87ea85..83baaa7 100644 --- a/sql_helper/mixins/guild_webhooks.py +++ b/sql_helper/mixins/guild_webhooks.py @@ -11,7 +11,7 @@ class GuildWebhooksMixin(_PostgresConnection): async def get_channel_webhooks(self, channel_id: int) -> AsyncList: await self.cur.execute( "SELECT * FROM webhooks WHERE channel_id=%(channel_id)s", - parameters={"channel_id": channel_id} + parameters={"channel_id": channel_id}, ) results = await self.cur.fetchall() return [Webhook(*i) for i in results] @@ -19,18 +19,20 @@ async def get_channel_webhooks(self, channel_id: int) -> AsyncList: async def get_webhook_by_id(self, webhook_id: int) -> Optional[DiscordWebhook]: await self.cur.execute( "SELECT * FROM webhooks WHERE webhook_id=%(webhook_id)s", - parameters={"webhook_id": webhook_id} + parameters={"webhook_id": webhook_id}, ) results = await self.cur.fetchall() if results: return Webhook(*results[0]) - async def set_channel_webhooks(self, channel_id: int, webhooks: List[DiscordWebhook], *, delete: bool = True): + async def set_channel_webhooks( + self, channel_id: int, webhooks: List[DiscordWebhook], *, delete: bool = True + ): async with self.cur.begin(): if delete: await self.cur.execute( "DELETE FROM webhooks WHERE channel_id=%(channel_id)s", - parameters={"channel_id": channel_id} + parameters={"channel_id": channel_id}, ) for webhook in webhooks: try: @@ -46,6 +48,6 @@ async def set_channel_webhooks(self, channel_id: int, webhooks: List[DiscordWebh "channel_id": webhook.channel_id, "user_id": user_id, "token": webhook.token, - "name": webhook.name - } + "name": webhook.name, + }, ) diff --git a/sql_helper/mixins/packs.py b/sql_helper/mixins/packs.py index 0025477..a65847b 100644 --- a/sql_helper/mixins/packs.py +++ b/sql_helper/mixins/packs.py @@ -7,13 +7,11 @@ class PacksMixin(_PostgresConnection): - @async_list async def pack_ids(self): await self.cur.execute("SELECT guild_id, pack_name from packs") return await self.cur.fetchall() - @async_list async def all_packs(self): await self.cur.execute("SELECT * from packs") @@ -26,7 +24,7 @@ async def user_pack_ids(self, user_id: Union[User, int]) -> AsyncList: user_id = user_id.id await self.cur.execute( "SELECT guild_id FROM user_packs WHERE user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) guilds = await self.cur.fetchall() return [g for g, in guilds] @@ -35,7 +33,7 @@ async def user_pack_ids(self, user_id: Union[User, int]) -> AsyncList: async def get_top_packs_by_count(self) -> AsyncList: await self.cur.execute( "SELECT packs.* FROM (SELECT guild_id, COUNT(*) from user_packs GROUP BY guild_id ORDER BY count DESC LIMIT 25) a JOIN packs on a.guild_id=packs.guild_id", - parameters={} + parameters={}, ) packs = await self.cur.fetchall() return [Pack(*p) for p in packs] @@ -44,7 +42,7 @@ async def get_top_packs_by_count(self) -> AsyncList: async def get_packs_with_emoji(self, emote_id: int) -> AsyncList: await self.cur.execute( "select packs.pack_name from emote_ids join packs on emote_ids.guild_id=packs.guild_id where emote_ids.emote_hash=(SELECT emote_hash FROM emote_ids WHERE emote_id=%(emote_id)s) limit 100", - parameters={"emote_id": emote_id} + parameters={"emote_id": emote_id}, ) packs = await self.cur.fetchall() return [guild_id for guild_id, in packs] @@ -54,7 +52,7 @@ async def user_pack_count(self, user_id: Union[Object, int]) -> int: user_id = user_id.id await self.cur.execute( "SELECT count(*) FROM user_packs WHERE user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) results = await self.cur.fetchall() return results[0][0] @@ -65,7 +63,7 @@ async def user_packs(self, user_id: Union[User, int]) -> AsyncList: user_id = user_id.id await self.cur.execute( "SELECT packs.* FROM user_packs JOIN packs on packs.guild_id=user_packs.guild_id where user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) packs = await self.cur.fetchall() return [Pack(*p) for p in packs] @@ -76,7 +74,7 @@ async def pack_member_ids(self, guild_id: Union[Guild, int]) -> AsyncList: guild_id = guild_id.id await self.cur.execute( "SELECT user_id FROM user_packs WHERE guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) users = await self.cur.fetchall() return [g for g, in users] @@ -84,22 +82,21 @@ async def pack_member_ids(self, guild_id: Union[Guild, int]) -> AsyncList: @async_list async def all_pack_sizes(self) -> AsyncList: await self.cur.execute( - "SELECT guild_id, count(*) FROM user_packs GROUP BY guild_id", - parameters={} + "SELECT guild_id, count(*) FROM user_packs GROUP BY guild_id", parameters={} ) return await self.cur.fetchall() async def in_pack(self, *, user_id: int, guild_id: int) -> bool: await self.cur.execute( "SELECT 1 FROM user_packs WHERE user_id=%(user_id)s AND guild_id=%(guild_id)s LIMIT 1", - parameters={"user_id": user_id, "guild_id": guild_id} + parameters={"user_id": user_id, "guild_id": guild_id}, ) return bool(await self.cur.fetchall()) async def pack_size(self, guild_id: int) -> int: await self.cur.execute( "SELECT COUNT(*) FROM user_packs WHERE guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) users = await self.cur.fetchall() return users[0][0] @@ -107,7 +104,7 @@ async def pack_size(self, guild_id: int) -> int: async def delete_pack(self, guild_id: int): await self.cur.execute( "DELETE FROM packs WHERE guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) async def create_update_pack(self, guild_id: int, name: str, public: bool): @@ -118,13 +115,13 @@ async def create_update_pack(self, guild_id: int, name: str, public: bool): "guild_id": guild_id, "pack_name": name, "is_public": public, - } + }, ) async def get_pack_guild(self, guild_id: int) -> Optional[Pack]: await self.cur.execute( "SELECT * FROM packs WHERE guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) results = await self.cur.fetchall() if results: @@ -134,8 +131,7 @@ async def get_pack_guild(self, guild_id: int) -> Optional[Pack]: async def get_pack_name(self, name: str) -> Optional[Pack]: await self.cur.execute( - "SELECT * FROM packs WHERE pack_name=%(name)s", - parameters={"name": name} + "SELECT * FROM packs WHERE pack_name=%(name)s", parameters={"name": name} ) results = await self.cur.fetchall() if results: @@ -146,8 +142,7 @@ async def get_pack_name(self, name: str) -> Optional[Pack]: @async_list async def get_public_packs(self) -> AsyncList: await self.cur.execute( - "SELECT * FROM packs WHERE is_public=true", - parameters={} + "SELECT * FROM packs WHERE is_public=true", parameters={} ) results = await self.cur.fetchall() return [Pack(*i) for i in results] @@ -155,11 +150,11 @@ async def get_public_packs(self) -> AsyncList: async def join_pack(self, *, user_id: int, guild_id: int): await self.cur.execute( "INSERT INTO user_packs VALUES (%(user_id)s, %(guild_id)s)", - parameters={"guild_id": guild_id, "user_id": user_id} + parameters={"guild_id": guild_id, "user_id": user_id}, ) async def leave_pack(self, *, user_id: int, guild_id: int): await self.cur.execute( "DELETE FROM user_packs WHERE guild_id=%(guild_id)s AND user_id=%(user_id)s", - parameters={"guild_id": guild_id, "user_id": user_id} + parameters={"guild_id": guild_id, "user_id": user_id}, ) diff --git a/sql_helper/mixins/personas.py b/sql_helper/mixins/personas.py index 6498aef..d1daff2 100644 --- a/sql_helper/mixins/personas.py +++ b/sql_helper/mixins/personas.py @@ -10,7 +10,7 @@ class PersonasMixin(_PostgresConnection): async def personas(self, user_id: int) -> AsyncList: await self.cur.execute( "SELECT user_id, short_name, display_name, avatar_url FROM personas WHERE user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) personas = await self.cur.fetchall() return [Persona(*i) for i in personas] @@ -26,7 +26,7 @@ async def all_personas(self) -> AsyncList: async def get_persona(self, user_id: int, name: str) -> Optional[Persona]: await self.cur.execute( "SELECT user_id, short_name, display_name, avatar_url FROM personas WHERE user_id=%(user_id)s and (short_name=%(name)s or display_name=%(name)s)", - parameters={"user_id": user_id, "name": name} + parameters={"user_id": user_id, "name": name}, ) personas = await self.cur.fetchall() if personas: @@ -37,7 +37,7 @@ async def get_persona(self, user_id: int, name: str) -> Optional[Persona]: async def count_personas(self, user_id: int) -> int: await self.cur.execute( "select count(*) from personas where user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) results = await self.cur.fetchall() return results[0][0] @@ -45,11 +45,10 @@ async def count_personas(self, user_id: int) -> int: async def persona_exists(self, user_id: int, short_name: str) -> bool: await self.cur.execute( "SELECT 1 FROM personas WHERE user_id=%(user_id)s and short_name=%(name)s LIMIT 1", - parameters={"user_id": user_id, "name": short_name} + parameters={"user_id": user_id, "name": short_name}, ) return bool(await self.cur.fetchall()) - async def create_persona(self, persona: Persona): await self.cur.execute( "INSERT INTO personas (user_id, short_name, display_name, avatar_url) VALUES (%(user_id)s, %(short_name)s, %(display_name)s, %(avatar_url)s)", @@ -57,8 +56,8 @@ async def create_persona(self, persona: Persona): "user_id": persona.user_id, "short_name": persona.short_name, "display_name": persona.display_name, - "avatar_url": persona.avatar_url - } + "avatar_url": persona.avatar_url, + }, ) async def set_persona(self, original_name: str, persona: Persona): @@ -69,18 +68,18 @@ async def set_persona(self, original_name: str, persona: Persona): "original_name": original_name, "short_name": persona.short_name, "display_name": persona.display_name, - "avatar_url": persona.avatar_url - } + "avatar_url": persona.avatar_url, + }, ) async def delete_persona(self, user_id: int, short_name: str): await self.cur.execute( "DELETE FROM personas WHERE user_id=%(user_id)s and short_name=%(short_name)s", - parameters={"user_id": user_id, "short_name": short_name} + parameters={"user_id": user_id, "short_name": short_name}, ) async def delete_all_personas(self, user_id: int): await self.cur.execute( "DELETE FROM personas WHERE user_id=%(user_id)s", - parameters={"user_id": user_id} + parameters={"user_id": user_id}, ) diff --git a/sql_helper/mixins/premium.py b/sql_helper/mixins/premium.py index 24a7690..363216f 100644 --- a/sql_helper/mixins/premium.py +++ b/sql_helper/mixins/premium.py @@ -11,45 +11,55 @@ class PremiumMixin(_PostgresConnection): async def guild_features(self, *, guild_id: int) -> List[GuildFeature]: await self.cur.execute( "SELECT feature FROM guild_features WHERE guild_id=%(guild_id)s", - parameters={"guild_id": guild_id} + parameters={"guild_id": guild_id}, ) return [GuildFeature(feature) for feature, in await self.cur.fetchall()] - async def add_guild_feature(self, *, guild_ids: List[int], features: List[GuildFeature]): + async def add_guild_feature( + self, *, guild_ids: List[int], features: List[GuildFeature] + ): if guild_ids and features: await self.cur.execute( "INSERT INTO guild_features (SELECT * from unnest(%(guild_ids)s) as guild_id CROSS JOIN unnest(%(feature)s::guild_features_enum[]) as feature) " "ON CONFLICT (guild_id, feature) DO NOTHING", - parameters={"guild_ids": guild_ids, "feature": [feature.value for feature in features]} + parameters={ + "guild_ids": guild_ids, + "feature": [feature.value for feature in features], + }, ) - async def remove_guild_feature(self, *, guild_ids: List[int], features: List[GuildFeature]): + async def remove_guild_feature( + self, *, guild_ids: List[int], features: List[GuildFeature] + ): if guild_ids and features: await self.cur.execute( "DELETE FROM guild_features WHERE guild_id=ANY(%(guild_ids)s) AND feature=ANY(%(feature)s::guild_features_enum[])", - parameters={"guild_ids": guild_ids, "feature": [feature.value for feature in features]} + parameters={ + "guild_ids": guild_ids, + "feature": [feature.value for feature in features], + }, ) if GuildFeature.USER_CONTENT in features: await self.cur.execute( "UPDATE guild_settings SET enable_user_content=true WHERE guild_id=ANY(%(guild_ids)s)", - parameters={"guild_ids": guild_ids} + parameters={"guild_ids": guild_ids}, ) if GuildFeature.EMOTE_ROLES in features: await self.cur.execute( "UPDATE guild_settings SET nitro_role=null WHERE guild_id=ANY(%(guild_ids)s)", - parameters={"guild_ids": guild_ids} + parameters={"guild_ids": guild_ids}, ) async def set_premium_user_discord_override(self, member_id: str, discord_id: int): await self.cur.execute( "UPDATE premium_users SET discord_override_id=%(discord_id)s WHERE pledge_id=%(member_id)s", - parameters={"member_id": member_id, "discord_id": discord_id} + parameters={"member_id": member_id, "discord_id": discord_id}, ) async def get_premium_user_discord_ids(self) -> List[int]: await self.cur.execute( "SELECT COALESCE(discord_override_id, discord_patreon_id) FROM premium_users where discord_patreon_id is not null or discord_override_id is not null", - parameters={} + parameters={}, ) results = await self.cur.fetchall() return [user_id for user_id, in results] @@ -57,7 +67,7 @@ async def get_premium_user_discord_ids(self) -> List[int]: async def get_premium_guild_ids(self) -> List[int]: await self.cur.execute( "SELECT guild_id FROM guild_features where feature = 'premium'", - parameters={} + parameters={}, ) results = await self.cur.fetchall() return [guild_id for guild_id, in results] @@ -76,12 +86,14 @@ async def update_premium_users(self, users: List[PremiumUser]): discord_ids.append(user.discord_id) lifetime_support_cents.append(user.lifetime_support_cents) last_charge_dates.append(user.last_charge_date) - last_charge_status.append(user.last_charge_status and user.last_charge_status.value) + last_charge_status.append( + user.last_charge_status and user.last_charge_status.value + ) pledge_ids.append(user.pledge_id) await self.cur.execute( "INSERT INTO premium_users (patreon_id, discord_patreon_id, lifetime_support_cents, last_charge_date, last_charge_status, tokens, tokens_spent, pledge_id) VALUES " "(unnest(%(patreon_ids)s), unnest(%(discord_ids)s::bigint[]), unnest(%(lifetime_support_cents)s), unnest(%(last_charge_dates)s), unnest(%(last_charge_status)s)::premium_last_charge_status_enum, 0, 0, unnest(%(pledge_ids)s))" - 'ON CONFLICT (patreon_id) DO UPDATE SET (discord_patreon_id, lifetime_support_cents, last_charge_date, last_charge_status, tokens, tokens_spent, pledge_id) = ' + "ON CONFLICT (patreon_id) DO UPDATE SET (discord_patreon_id, lifetime_support_cents, last_charge_date, last_charge_status, tokens, tokens_spent, pledge_id) = " "(EXCLUDED.discord_patreon_id, EXCLUDED.lifetime_support_cents, EXCLUDED.last_charge_date, EXCLUDED.last_charge_status, premium_users.tokens, premium_users.tokens_spent, EXCLUDED.pledge_id)", parameters={ "patreon_ids": patreon_ids, @@ -90,13 +102,13 @@ async def update_premium_users(self, users: List[PremiumUser]): "last_charge_dates": last_charge_dates, "last_charge_status": last_charge_status, "pledge_ids": pledge_ids, - } + }, ) async def get_premium_user_patreon(self, member_id: str) -> Optional[PremiumUser]: await self.cur.execute( "SELECT * FROM premium_users WHERE pledge_id=%(pledge_id)s", - parameters={"pledge_id": member_id} + parameters={"pledge_id": member_id}, ) results = await self.cur.fetchall() if not results: @@ -106,7 +118,7 @@ async def get_premium_user_patreon(self, member_id: str) -> Optional[PremiumUser async def get_premium_user_discord(self, discord_id: int) -> Optional[PremiumUser]: await self.cur.execute( "SELECT * FROM premium_users WHERE coalesce(discord_override_id, discord_patreon_id)=%(discord_id)s", - parameters={"discord_id": discord_id} + parameters={"discord_id": discord_id}, ) results = await self.cur.fetchall() @@ -117,7 +129,7 @@ async def get_premium_user_discord(self, discord_id: int) -> Optional[PremiumUse async def get_premium_users(self) -> List[PremiumUser]: await self.cur.execute( "SELECT * FROM premium_users WHERE last_charge_status is not null", - parameters={} + parameters={}, ) results = await self.cur.fetchall() return [PremiumUser(*i) for i in results] diff --git a/sql_helper/premium_user.py b/sql_helper/premium_user.py index 0d679f2..f0ba6f4 100644 --- a/sql_helper/premium_user.py +++ b/sql_helper/premium_user.py @@ -8,20 +8,34 @@ ALLOWED_LAST_CHARGE_STATUSES = ( PremiumLastChargeStatus.PAID, PremiumLastChargeStatus.PENDING, - PremiumLastChargeStatus.DELETED + PremiumLastChargeStatus.DELETED, ) -class PremiumUser(namedtuple( - "PremiumUser", - ["patreon_id", "discord_patreon_id", "lifetime_support_cents", "last_charge_date", "last_charge_status", "tokens", "tokens_spent", "pledge_id", "discord_override_id"] -)): +class PremiumUser( + namedtuple( + "PremiumUser", + [ + "patreon_id", + "discord_patreon_id", + "lifetime_support_cents", + "last_charge_date", + "last_charge_status", + "tokens", + "tokens_spent", + "pledge_id", + "discord_override_id", + ], + ) +): def __new__(cls, *args, **kwargs): if args: args = list(args) args[4] = PremiumLastChargeStatus(args[4]) if kwargs: - kwargs["last_charge_date"] = kwargs.get("last_charge_date") and kwargs["last_charge_date"].date() + kwargs["last_charge_date"] = ( + kwargs.get("last_charge_date") and kwargs["last_charge_date"].date() + ) return super(PremiumUser, cls).__new__(cls, *args, **kwargs) @property @@ -30,7 +44,7 @@ def discord_id(self): def should_have_premium(self, discord_requirement: bool = True) -> bool: return ( - (self.discord_id or not discord_requirement) and - self.last_charge_status in ALLOWED_LAST_CHARGE_STATUSES and - self.last_charge_date + relativedelta(months=1, days=5) >= date.today() + (self.discord_id or not discord_requirement) + and self.last_charge_status in ALLOWED_LAST_CHARGE_STATUSES + and self.last_charge_date + relativedelta(months=1, days=5) >= date.today() ) diff --git a/sql_helper/webhook.py b/sql_helper/webhook.py index 78c5523..be84536 100644 --- a/sql_helper/webhook.py +++ b/sql_helper/webhook.py @@ -2,9 +2,11 @@ from collections import namedtuple -class Webhook(namedtuple("Webhook", [ - "webhook_id", "guild_id", "channel_id", "token", "name", "user_id" -])): +class Webhook( + namedtuple( + "Webhook", ["webhook_id", "guild_id", "channel_id", "token", "name", "user_id"] + ) +): index = "webhook" webhook_id: str diff --git a/tests/base.py b/tests/base.py index 8b0ac7f..e87a954 100644 --- a/tests/base.py +++ b/tests/base.py @@ -19,7 +19,7 @@ emote_sha=text(), guild_id=integers(), name=name, - has_roles=booleans() + has_roles=booleans(), ) @@ -45,5 +45,16 @@ def postgres(get_guild, get_emoji) -> PostgresConnection: return postgres -def get_sql_emoji(emote_id: int, emote_hash: str = "abc", usable: bool = True, animated: bool = False, emote_sha: str = "cde", guild_id: Optional[int] = 123, name: str = "foo", has_roles: bool= False) -> SQLEmoji: - return SQLEmoji(emote_id, emote_hash, usable, animated, emote_sha, guild_id, name, has_roles) +def get_sql_emoji( + emote_id: int, + emote_hash: str = "abc", + usable: bool = True, + animated: bool = False, + emote_sha: str = "cde", + guild_id: Optional[int] = 123, + name: str = "foo", + has_roles: bool = False, +) -> SQLEmoji: + return SQLEmoji( + emote_id, emote_hash, usable, animated, emote_sha, guild_id, name, has_roles + ) diff --git a/tests/test_emotes.py b/tests/test_emotes.py index 27c134c..d768cd6 100644 --- a/tests/test_emotes.py +++ b/tests/test_emotes.py @@ -10,9 +10,7 @@ @pytest.mark.asyncio async def test_get_emote_like_no_emote(postgres): - postgres.cur.fetchall.side_effect = [ - [] - ] + postgres.cur.fetchall.side_effect = [[]] rtn = await postgres.get_emote_like(emote_id=123) assert rtn is None assert postgres.cur.fetchall.call_count == 1 @@ -20,9 +18,7 @@ async def test_get_emote_like_no_emote(postgres): @pytest.mark.asyncio async def test_get_emote_like_default(postgres): - postgres.cur.fetchall.side_effect = [ - [get_sql_emoji(123)] - ] + postgres.cur.fetchall.side_effect = [[get_sql_emoji(123)]] rtn = await postgres.get_emote_like(emote_id=123) assert rtn == get_sql_emoji(123) assert postgres.cur.fetchall.call_count == 1 @@ -30,10 +26,7 @@ async def test_get_emote_like_default(postgres): @pytest.mark.asyncio async def test_get_emote_like_unusable_no_secondary(postgres): - postgres.cur.fetchall.side_effect = [ - [get_sql_emoji(123, usable=False)], - [] - ] + postgres.cur.fetchall.side_effect = [[get_sql_emoji(123, usable=False)], []] rtn = await postgres.get_emote_like(emote_id=123) assert rtn is None assert postgres.cur.fetchall.call_count == 2 @@ -43,7 +36,7 @@ async def test_get_emote_like_unusable_no_secondary(postgres): async def test_get_emote_like_unusable_with_secondary(postgres): postgres.cur.fetchall.side_effect = [ [get_sql_emoji(123, usable=False)], - [get_sql_emoji(234)] + [get_sql_emoji(234)], ] rtn = await postgres.get_emote_like(emote_id=123) assert rtn == get_sql_emoji(234) @@ -54,7 +47,7 @@ async def test_get_emote_like_unusable_with_secondary(postgres): async def test_get_emote_like_no_guild_no_secondary(postgres): postgres.cur.fetchall.side_effect = [ [get_sql_emoji(123, guild_id=None)], - [get_sql_emoji(234)] + [get_sql_emoji(234)], ] rtn = await postgres.get_emote_like(emote_id=123) assert rtn == get_sql_emoji(234) @@ -72,64 +65,71 @@ async def test_get_emote_like_no_require_guild(postgres): @pytest.mark.asyncio -@pytest.mark.parametrize("side_effect, expected", [ - ([None], None), - ( - [get_sql_emoji(1, name="test", usable=True)], - get_sql_emoji(1, name="test", usable=True) - ), - ( - [ - get_sql_emoji(2, name="Test", usable=True), +@pytest.mark.parametrize( + "side_effect, expected", + [ + ([None], None), + ( + [get_sql_emoji(1, name="test", usable=True)], get_sql_emoji(1, name="test", usable=True), - ], - get_sql_emoji(1, name="test", usable=True) - ), - ( - [ - get_sql_emoji(2, name="Test", usable=True), - get_sql_emoji(3, name="test", usable=False), + ), + ( + [ + get_sql_emoji(2, name="Test", usable=True), + get_sql_emoji(1, name="test", usable=True), + ], get_sql_emoji(1, name="test", usable=True), - ], - get_sql_emoji(1, name="test", usable=True) - ), - ( - [ + ), + ( + [ + get_sql_emoji(2, name="Test", usable=True), + get_sql_emoji(3, name="test", usable=False), + get_sql_emoji(1, name="test", usable=True), + ], + get_sql_emoji(1, name="test", usable=True), + ), + ( + [ + get_sql_emoji(2, name="Test", usable=True), + get_sql_emoji(3, name="test", usable=False), + None, + ], get_sql_emoji(2, name="Test", usable=True), - get_sql_emoji(3, name="test", usable=False), - None - ], - get_sql_emoji(2, name="Test", usable=True) - ), - ( - [ - get_sql_emoji(2, name="Test", usable=False), + ), + ( + [ + get_sql_emoji(2, name="Test", usable=False), + get_sql_emoji(3, name="test", usable=True), + ], get_sql_emoji(3, name="test", usable=True), - ], - get_sql_emoji(3, name="test", usable=True) - ), - ( - [ - get_sql_emoji(2, name="test", usable=False), - None, + ), + ( + [ + get_sql_emoji(2, name="test", usable=False), + None, + get_sql_emoji(4, name="Test", usable=True), + ], get_sql_emoji(4, name="Test", usable=True), - ], - get_sql_emoji(4, name="Test", usable=True) - ), - ( - [ - get_sql_emoji(2, name="test", usable=False), - None, - None, + ), + ( + [ + get_sql_emoji(2, name="test", usable=False), + None, + None, + get_sql_emoji(4, name="test", usable=True), + ], get_sql_emoji(4, name="test", usable=True), - ], - get_sql_emoji(4, name="test", usable=True) - ), -]) -async def test_case_insensitive_get_emote(postgres, side_effect: List[Optional[SQLEmoji]], expected: Optional[SQLEmoji]): + ), + ], +) +async def test_case_insensitive_get_emote( + postgres, side_effect: List[Optional[SQLEmoji]], expected: Optional[SQLEmoji] +): side_effect = [None if i is None else [i] for i in side_effect] postgres.cur.fetchall.side_effect = side_effect - rtn = await postgres._case_insensitive_get_emote(query_where="", emote_name="test", parameters={}) + rtn = await postgres._case_insensitive_get_emote( + query_where="", emote_name="test", parameters={} + ) assert rtn == expected assert postgres.cur.fetchall.call_count == len(side_effect) @@ -137,10 +137,12 @@ async def test_case_insensitive_get_emote(postgres, side_effect: List[Optional[S @pytest.mark.asyncio @given( lists(sqlemoji_strategy(name=sampled_from(["test", "Test", "TEST"]))), - sampled_from(["test", "Test", "TEST"]) + sampled_from(["test", "Test", "TEST"]), ) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) -async def test_case_insensitive_get_emote_never_unusable(postgres, emote_pool: List[SQLEmoji], emote_name: str): +async def test_case_insensitive_get_emote_never_unusable( + postgres, emote_pool: List[SQLEmoji], emote_name: str +): sql_query = None def _execute_side_effect(query: str, parameters): @@ -161,7 +163,9 @@ def _fetchall_side_effect(): postgres.cur.execute.side_effect = _execute_side_effect postgres.cur.fetchall.side_effect = _fetchall_side_effect - rtn: Optional[SQLEmoji] = await postgres._case_insensitive_get_emote(query_where="", emote_name=emote_name, parameters={}) + rtn: Optional[SQLEmoji] = await postgres._case_insensitive_get_emote( + query_where="", emote_name=emote_name, parameters={} + ) if rtn is None: assert all(emote.usable is False for emote in emote_pool) @@ -169,21 +173,21 @@ def _fetchall_side_effect(): assert rtn in emote_pool assert rtn.usable if rtn.name != emote_name: - assert not any(emote.name == emote_name and emote.usable for emote in emote_pool) + assert not any( + emote.name == emote_name and emote.usable for emote in emote_pool + ) @pytest.mark.asyncio @given( - sqlemoji_strategy( - emote_hash=sampled_from(["1", "2", "3", "4"]) - ), - lists(sqlemoji_strategy( - emote_hash=sampled_from(["1", "2", "3", "4"]) - )), + sqlemoji_strategy(emote_hash=sampled_from(["1", "2", "3", "4"])), + lists(sqlemoji_strategy(emote_hash=sampled_from(["1", "2", "3", "4"]))), booleans(), ) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) -async def test_get_emote_like(postgres, to_find: SQLEmoji, emote_pool: List[SQLEmoji], add_to_pool: bool): +async def test_get_emote_like( + postgres, to_find: SQLEmoji, emote_pool: List[SQLEmoji], add_to_pool: bool +): sql_query = None parameters_ = {} @@ -218,7 +222,11 @@ def _fetchall_side_effect(): # If the emote id we're looking for isn't in the pool, we never find anything sharing its hash assert rtn is None elif rtn is None: - assert not any(emote.usable for emote in emote_pool if emote.emote_hash == to_find.emote_hash) + assert not any( + emote.usable + for emote in emote_pool + if emote.emote_hash == to_find.emote_hash + ) else: assert rtn.emote_hash == to_find.emote_hash assert rtn.usable