From c3b6f4d0847a7cd002bee098b562639b9d42312b Mon Sep 17 00:00:00 2001 From: Hunter Wittenborn Date: Sun, 28 Jul 2024 10:53:00 -0500 Subject: [PATCH] feat(openai): use asynchronous OpenAI endpoints Also log any errors to the Discord channel, if they happen. --- citadel/commands/generate.py | 19 +++---------------- citadel/globals.py | 6 +++--- citadel/main.py | 11 ++++++++--- citadel/utils.py | 2 +- 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/citadel/commands/generate.py b/citadel/commands/generate.py index aeecdd7..f988e54 100644 --- a/citadel/commands/generate.py +++ b/citadel/commands/generate.py @@ -117,7 +117,7 @@ def get_resp(self) -> tuple[str, discord.Interaction] | None: '(e.g. "Shakespeare Texts")', test_name='The name you want to assign the generated test (e.g. "The Renaissance")', ) -async def generate( # noqa: C901, PLR0915 +async def generate( # noqa: C901 interaction: discord.Interaction, msg_filter: str, test_name: str, @@ -133,21 +133,8 @@ async def generate( # noqa: C901, PLR0915 if not msg.content.startswith("/") and interaction.user != interaction.client.user ] - completion = globals.get_openai_client().chat.completions.create( - model=globals.get_openai_model(), - messages=[ - {"role": "user", "content": NOTES_PROMPT.render(messages=messages, msg_filter=msg_filter)}, - ], - ) - - try: - content = completion.choices[0].message.content - if content is None: - raise RuntimeError("Unexpected empty output from OpenAI") # noqa: TRY003,EM101 - output = json.loads(content) - except json.JSONDecodeError: - await interaction.followup.send("There was an error generating the notes. Please try again.") - raise + completion = await utils.get_openai_resp(NOTES_PROMPT.render(messages=messages, msg_filter=msg_filter)) + output = json.loads(completion) buttons = Buttons() message = await interaction.followup.send( diff --git a/citadel/globals.py b/citadel/globals.py index 57ecc35..f6e6281 100644 --- a/citadel/globals.py +++ b/citadel/globals.py @@ -5,7 +5,7 @@ import inflect from jinja2 import Environment, FileSystemLoader, StrictUndefined -from openai import OpenAI +from openai import AsyncOpenAI from rich.logging import RichHandler from sqlalchemy import Engine @@ -22,12 +22,12 @@ # OpenAI client # # The API key and model name get set in `main::main` -OPENAI_CLIENT: OpenAI | None = None +OPENAI_CLIENT: AsyncOpenAI | None = None OPENAI_MODEL: str | None = None UNINITIALIZED_ERR = "Variable hasn't been initalized yet" -def get_openai_client() -> OpenAI: +def get_openai_client() -> AsyncOpenAI: """Get the global OpenAI client.""" if OPENAI_CLIENT is None: raise NameError(UNINITIALIZED_ERR) diff --git a/citadel/main.py b/citadel/main.py index ebe9d15..23ecf71 100644 --- a/citadel/main.py +++ b/citadel/main.py @@ -6,7 +6,7 @@ import typer from discord import app_commands from dotenv import load_dotenv -from openai import OpenAI +from openai import AsyncOpenAI from sqlmodel import SQLModel, create_engine from citadel import commands, globals @@ -42,10 +42,11 @@ async def setup_hook(self) -> None: @APP.command() -def main( +def main( # noqa: PLR0913 discord_token: Annotated[str, typer.Argument(envvar="CITADEL_DISCORD_TOKEN")], openai_token: Annotated[str, typer.Argument(envvar="OPENAI_TOKEN")], openai_model: Annotated[str, typer.Argument(envvar="OPENAI_MODEL")] = "gpt-4o", + openai_base: Annotated[str, typer.Argument(envvar="OPENAI_BASE")] = "https://api.openai.com/v1", db_path: Annotated[str, typer.Argument(envvar="DB_PATH")] = "./citadel.db", log_level: Annotated[LogLevel, typer.Argument(case_sensitive=False, envvar="LOG_LEVEL")] = LogLevel.INFO, ) -> None: @@ -54,7 +55,7 @@ def main( Environment variables can be set via the command-line, or in a file named `.env`. """ globals.LOGGER.setLevel(logging.getLevelName(log_level.value)) - globals.OPENAI_CLIENT = OpenAI(api_key=openai_token) + globals.OPENAI_CLIENT = AsyncOpenAI(api_key=openai_token, base_url=openai_base) globals.OPENAI_MODEL = openai_model # Set up the database. @@ -68,6 +69,10 @@ def main( client.tree.add_command(commands.generate) client.tree.add_command(commands.quiz) + @client.tree.error + async def on_app_command_error(interaction: discord.Interaction, error: app_commands.AppCommandError) -> None: # noqa: ARG001 + await interaction.channel.send("An unknown error has occurred. Please check server logs for more information.") + client.run(discord_token) diff --git a/citadel/utils.py b/citadel/utils.py index e9041f5..8527bff 100644 --- a/citadel/utils.py +++ b/citadel/utils.py @@ -42,7 +42,7 @@ def get_responses(self) -> list[tuple[str, discord.Interaction]]: async def get_openai_resp(msg: str) -> str: - completion = globals.get_openai_client().chat.completions.create( + completion = await globals.get_openai_client().chat.completions.create( model=globals.get_openai_model(), messages=[{"role": "user", "content": msg}], )