From 6f04bd502805c3e10273b0ba0cfa0b3a3dc5cde6 Mon Sep 17 00:00:00 2001 From: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com> Date: Thu, 26 Dec 2024 19:46:51 +0100 Subject: [PATCH] Refactor/autocomplete (#77) * Feat: support lb autocomplete and caching * Feat: add autocomplete * Add brief comment about runtime so people don't worry --------- Co-authored-by: alexzhang13 --- .../cogs/leaderboard_cog.py | 17 ++----- src/discord-cluster-manager/leaderboard_db.py | 26 +++++++++- src/discord-cluster-manager/utils.py | 49 ++++++++++++++++++- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/src/discord-cluster-manager/cogs/leaderboard_cog.py b/src/discord-cluster-manager/cogs/leaderboard_cog.py index d7c1e17..665da77 100644 --- a/src/discord-cluster-manager/cogs/leaderboard_cog.py +++ b/src/discord-cluster-manager/cogs/leaderboard_cog.py @@ -6,23 +6,12 @@ from consts import GitHubGPU, ModalGPU from discord import Interaction, SelectOption, app_commands, ui from discord.ext import commands +from leaderboard_db import leaderboard_name_autocomplete from utils import extract_score, get_user_from_id, send_discord_message, setup_logging logger = setup_logging() -async def leaderboard_name_autocomplete( - interaction: discord.Interaction, - current: str, -) -> list[app_commands.Choice[str]]: - """Return leaderboard names that match the current typed name""" - bot = interaction.client - with bot.leaderboard_db as db: - leaderboards = db.get_leaderboards() - filtered = [lb["name"] for lb in leaderboards if current.lower() in lb["name"].lower()] - return [app_commands.Choice(name=name, value=name) for name in filtered[:25]] - - class LeaderboardSubmitCog(app_commands.Group): def __init__( self, @@ -37,6 +26,7 @@ def __init__( leaderboard_name="Name of the competition / kernel to optimize", script="The Python / CUDA script file to run", ) + @app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete) # TODO: Modularize this so all the write functionality is in here. Haven't figured # a good way to do this yet. async def submit( @@ -51,6 +41,7 @@ async def submit( @app_commands.describe( gpu_type="Choose the GPU type for Modal", ) + @app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete) @app_commands.choices( gpu_type=[app_commands.Choice(name=gpu.value, value=gpu.value) for gpu in ModalGPU] ) @@ -104,6 +95,7 @@ async def submit_modal( ### GITHUB SUBCOMMAND @app_commands.command(name="github", description="Submit leaderboard data for GitHub") + @app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete) async def submit_github( self, interaction: discord.Interaction, @@ -423,6 +415,7 @@ async def leaderboard_create( ) @discord.app_commands.describe(leaderboard_name="Name of the leaderboard") + @app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete) async def get_leaderboard_submissions( self, interaction: discord.Interaction, diff --git a/src/discord-cluster-manager/leaderboard_db.py b/src/discord-cluster-manager/leaderboard_db.py index e36b1ed..0da899e 100644 --- a/src/discord-cluster-manager/leaderboard_db.py +++ b/src/discord-cluster-manager/leaderboard_db.py @@ -1,5 +1,6 @@ from typing import Optional +import discord import psycopg2 from consts import ( DATABASE_URL, @@ -10,7 +11,28 @@ POSTGRES_USER, ) from psycopg2 import Error -from utils import LeaderboardItem, SubmissionItem +from utils import LeaderboardItem, LRUCache, SubmissionItem + +leaderboard_name_cache = LRUCache(max_size=512) + + +async def leaderboard_name_autocomplete( + interaction: discord.Interaction, + current: str, +) -> list[discord.app_commands.Choice[str]]: + """Return leaderboard names that match the current typed name""" + cached_value = leaderboard_name_cache[current] + if cached_value is not None: + return cached_value + + bot = interaction.client + with bot.leaderboard_db as db: + leaderboards = db.get_leaderboards() + filtered = [lb["name"] for lb in leaderboards if current.lower() in lb["name"].lower()] + leaderboard_name_cache[current] = [ + discord.app_commands.Choice(name=name, value=name) for name in filtered[:25] + ] + return leaderboard_name_cache[current] class LeaderboardDB: @@ -83,6 +105,7 @@ def create_leaderboard(self, leaderboard: LeaderboardItem) -> Optional[None]: ) self.connection.commit() + leaderboard_name_cache.invalidate() # Invalidate autocomplete cache except psycopg2.Error as e: self.connection.rollback() # Ensure rollback if error occurs return f"Error during leaderboard creation: {e}" @@ -98,6 +121,7 @@ def delete_leaderboard(self, leaderboard_name: str) -> Optional[str]: (leaderboard_name,), ) self.connection.commit() + leaderboard_name_cache.invalidate() # Invalidate autocomplete cache except psycopg2.Error as e: self.connection.rollback() return f"Error during leaderboard deletion: {e}" diff --git a/src/discord-cluster-manager/utils.py b/src/discord-cluster-manager/utils.py index c0bed57..6b506cd 100644 --- a/src/discord-cluster-manager/utils.py +++ b/src/discord-cluster-manager/utils.py @@ -2,7 +2,7 @@ import logging import re import subprocess -from typing import List, TypedDict +from typing import Any, List, TypedDict import discord @@ -82,6 +82,53 @@ def extract_score(score_str: str) -> float: return None +class LRUCache: + def __init__(self, max_size: int): + """LRU Cache implementation, as functools.lru doesn't work in async code + Note: Implementation uses list for convenience because cache is small, so + runtime complexity does not matter here. + Args: + max_size (int): Maximum size of the cache + """ + self._cache = {} + self._max_size = max_size + self._q = [] + + def __getitem__(self, key: Any, default: Any = None) -> Any | None: + if key not in self._cache: + return default + + self._q.remove(key) + self._q.append(key) + return self._cache[key] + + def __setitem__(self, key: Any, value: Any) -> None: + if key in self._cache: + self._q.remove(key) + self._q.append(key) + self._cache[key] = value + return + + if len(self._cache) >= self._max_size: + self._cache.pop(self._q.pop(0)) + + self._cache[key] = value + self._q.append(key) + + def __contains__(self, key: Any) -> bool: + return key in self._cache + + def __len__(self) -> int: + return len(self._cache) + + def invalidate(self): + """Invalidate the cache, clearing all entries, should be called when updating the underlying + data in db + """ + self._cache.clear() + self._q.clear() + + class LeaderboardItem(TypedDict): name: str deadline: datetime.datetime