Skip to content

Commit

Permalink
Refactor/autocomplete (#77)
Browse files Browse the repository at this point in the history
* Feat: support lb autocomplete and caching

* Feat: add autocomplete

* Add brief comment about runtime so people don't worry

---------

Co-authored-by: alexzhang13 <[email protected]>
  • Loading branch information
S1ro1 and alexzhang13 authored Dec 26, 2024
1 parent 30c6d7d commit 6f04bd5
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 14 deletions.
17 changes: 5 additions & 12 deletions src/discord-cluster-manager/cogs/leaderboard_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,12 @@
from consts import GitHubGPU, ModalGPU
from discord import Interaction, SelectOption, app_commands, ui
from discord.ext import commands
from leaderboard_db import leaderboard_name_autocomplete
from utils import extract_score, get_user_from_id, send_discord_message, setup_logging

logger = setup_logging()


async def leaderboard_name_autocomplete(
interaction: discord.Interaction,
current: str,
) -> list[app_commands.Choice[str]]:
"""Return leaderboard names that match the current typed name"""
bot = interaction.client
with bot.leaderboard_db as db:
leaderboards = db.get_leaderboards()
filtered = [lb["name"] for lb in leaderboards if current.lower() in lb["name"].lower()]
return [app_commands.Choice(name=name, value=name) for name in filtered[:25]]


class LeaderboardSubmitCog(app_commands.Group):
def __init__(
self,
Expand All @@ -37,6 +26,7 @@ def __init__(
leaderboard_name="Name of the competition / kernel to optimize",
script="The Python / CUDA script file to run",
)
@app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete)
# TODO: Modularize this so all the write functionality is in here. Haven't figured
# a good way to do this yet.
async def submit(
Expand All @@ -51,6 +41,7 @@ async def submit(
@app_commands.describe(
gpu_type="Choose the GPU type for Modal",
)
@app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete)
@app_commands.choices(
gpu_type=[app_commands.Choice(name=gpu.value, value=gpu.value) for gpu in ModalGPU]
)
Expand Down Expand Up @@ -104,6 +95,7 @@ async def submit_modal(

### GITHUB SUBCOMMAND
@app_commands.command(name="github", description="Submit leaderboard data for GitHub")
@app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete)
async def submit_github(
self,
interaction: discord.Interaction,
Expand Down Expand Up @@ -423,6 +415,7 @@ async def leaderboard_create(
)

@discord.app_commands.describe(leaderboard_name="Name of the leaderboard")
@app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete)
async def get_leaderboard_submissions(
self,
interaction: discord.Interaction,
Expand Down
26 changes: 25 additions & 1 deletion src/discord-cluster-manager/leaderboard_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

import discord
import psycopg2
from consts import (
DATABASE_URL,
Expand All @@ -10,7 +11,28 @@
POSTGRES_USER,
)
from psycopg2 import Error
from utils import LeaderboardItem, SubmissionItem
from utils import LeaderboardItem, LRUCache, SubmissionItem

leaderboard_name_cache = LRUCache(max_size=512)


async def leaderboard_name_autocomplete(
interaction: discord.Interaction,
current: str,
) -> list[discord.app_commands.Choice[str]]:
"""Return leaderboard names that match the current typed name"""
cached_value = leaderboard_name_cache[current]
if cached_value is not None:
return cached_value

bot = interaction.client
with bot.leaderboard_db as db:
leaderboards = db.get_leaderboards()
filtered = [lb["name"] for lb in leaderboards if current.lower() in lb["name"].lower()]
leaderboard_name_cache[current] = [
discord.app_commands.Choice(name=name, value=name) for name in filtered[:25]
]
return leaderboard_name_cache[current]


class LeaderboardDB:
Expand Down Expand Up @@ -83,6 +105,7 @@ def create_leaderboard(self, leaderboard: LeaderboardItem) -> Optional[None]:
)

self.connection.commit()
leaderboard_name_cache.invalidate() # Invalidate autocomplete cache
except psycopg2.Error as e:
self.connection.rollback() # Ensure rollback if error occurs
return f"Error during leaderboard creation: {e}"
Expand All @@ -98,6 +121,7 @@ def delete_leaderboard(self, leaderboard_name: str) -> Optional[str]:
(leaderboard_name,),
)
self.connection.commit()
leaderboard_name_cache.invalidate() # Invalidate autocomplete cache
except psycopg2.Error as e:
self.connection.rollback()
return f"Error during leaderboard deletion: {e}"
Expand Down
49 changes: 48 additions & 1 deletion src/discord-cluster-manager/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import re
import subprocess
from typing import List, TypedDict
from typing import Any, List, TypedDict

import discord

Expand Down Expand Up @@ -82,6 +82,53 @@ def extract_score(score_str: str) -> float:
return None


class LRUCache:
def __init__(self, max_size: int):
"""LRU Cache implementation, as functools.lru doesn't work in async code
Note: Implementation uses list for convenience because cache is small, so
runtime complexity does not matter here.
Args:
max_size (int): Maximum size of the cache
"""
self._cache = {}
self._max_size = max_size
self._q = []

def __getitem__(self, key: Any, default: Any = None) -> Any | None:
if key not in self._cache:
return default

self._q.remove(key)
self._q.append(key)
return self._cache[key]

def __setitem__(self, key: Any, value: Any) -> None:
if key in self._cache:
self._q.remove(key)
self._q.append(key)
self._cache[key] = value
return

if len(self._cache) >= self._max_size:
self._cache.pop(self._q.pop(0))

self._cache[key] = value
self._q.append(key)

def __contains__(self, key: Any) -> bool:
return key in self._cache

def __len__(self) -> int:
return len(self._cache)

def invalidate(self):
"""Invalidate the cache, clearing all entries, should be called when updating the underlying
data in db
"""
self._cache.clear()
self._q.clear()


class LeaderboardItem(TypedDict):
name: str
deadline: datetime.datetime
Expand Down

0 comments on commit 6f04bd5

Please sign in to comment.