Skip to content

Commit

Permalink
feat(repo): merge pull request #30 from Python-Code-Jam-2024-Royal-Re…
Browse files Browse the repository at this point in the history
…dshifts/hwittenborn/async-openai
  • Loading branch information
hwittenborn authored Jul 28, 2024
2 parents 415f1e1 + ad42e3d commit b168a42
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 23 deletions.
19 changes: 3 additions & 16 deletions citadel/commands/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_resp(self) -> tuple[str, discord.Interaction] | None:
'(e.g. "Shakespeare Texts")',
test_name='The name you want to assign the generated test (e.g. "The Renaissance")',
)
async def generate( # noqa: C901, PLR0915
async def generate( # noqa: C901
interaction: discord.Interaction,
msg_filter: str,
test_name: str,
Expand All @@ -133,21 +133,8 @@ async def generate( # noqa: C901, PLR0915
if not msg.content.startswith("/") and interaction.user != interaction.client.user
]

completion = globals.get_openai_client().chat.completions.create(
model=globals.get_openai_model(),
messages=[
{"role": "user", "content": NOTES_PROMPT.render(messages=messages, msg_filter=msg_filter)},
],
)

try:
content = completion.choices[0].message.content
if content is None:
raise RuntimeError("Unexpected empty output from OpenAI") # noqa: TRY003,EM101
output = json.loads(content)
except json.JSONDecodeError:
await interaction.followup.send("There was an error generating the notes. Please try again.")
raise
completion = await utils.get_openai_resp(NOTES_PROMPT.render(messages=messages, msg_filter=msg_filter))
output = json.loads(completion)

buttons = Buttons()
message = await interaction.followup.send(
Expand Down
6 changes: 3 additions & 3 deletions citadel/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import inflect
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from openai import OpenAI
from openai import AsyncOpenAI
from rich.logging import RichHandler
from sqlalchemy import Engine

Expand All @@ -22,12 +22,12 @@
# OpenAI client
#
# The API key and model name get set in `main::main`
OPENAI_CLIENT: OpenAI | None = None
OPENAI_CLIENT: AsyncOpenAI | None = None
OPENAI_MODEL: str | None = None
UNINITIALIZED_ERR = "Variable hasn't been initalized yet"


def get_openai_client() -> OpenAI:
def get_openai_client() -> AsyncOpenAI:
"""Get the global OpenAI client."""
if OPENAI_CLIENT is None:
raise NameError(UNINITIALIZED_ERR)
Expand Down
11 changes: 8 additions & 3 deletions citadel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typer
from discord import app_commands
from dotenv import load_dotenv
from openai import OpenAI
from openai import AsyncOpenAI
from sqlmodel import SQLModel, create_engine

from citadel import commands, globals
Expand Down Expand Up @@ -42,10 +42,11 @@ async def setup_hook(self) -> None:


@APP.command()
def main(
def main( # noqa: PLR0913
discord_token: Annotated[str, typer.Argument(envvar="CITADEL_DISCORD_TOKEN")],
openai_token: Annotated[str, typer.Argument(envvar="OPENAI_TOKEN")],
openai_model: Annotated[str, typer.Argument(envvar="OPENAI_MODEL")] = "gpt-4o",
openai_base: Annotated[str, typer.Argument(envvar="OPENAI_BASE")] = "https://api.openai.com/v1",
db_path: Annotated[str, typer.Argument(envvar="DB_PATH")] = "./citadel.db",
log_level: Annotated[LogLevel, typer.Argument(case_sensitive=False, envvar="LOG_LEVEL")] = LogLevel.INFO,
) -> None:
Expand All @@ -54,7 +55,7 @@ def main(
Environment variables can be set via the command-line, or in a file named `.env`.
"""
globals.LOGGER.setLevel(logging.getLevelName(log_level.value))
globals.OPENAI_CLIENT = OpenAI(api_key=openai_token)
globals.OPENAI_CLIENT = AsyncOpenAI(api_key=openai_token, base_url=openai_base)
globals.OPENAI_MODEL = openai_model

# Set up the database.
Expand All @@ -68,6 +69,10 @@ def main(
client.tree.add_command(commands.generate)
client.tree.add_command(commands.quiz)

@client.tree.error
async def on_app_command_error(interaction: discord.Interaction, error: app_commands.AppCommandError) -> None: # noqa: ARG001
await interaction.channel.send("An unknown error has occurred. Please check server logs for more information.")

client.run(discord_token)


Expand Down
2 changes: 1 addition & 1 deletion citadel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_responses(self) -> list[tuple[str, discord.Interaction]]:


async def get_openai_resp(msg: str) -> str:
completion = globals.get_openai_client().chat.completions.create(
completion = await globals.get_openai_client().chat.completions.create(
model=globals.get_openai_model(),
messages=[{"role": "user", "content": msg}],
)
Expand Down

0 comments on commit b168a42

Please sign in to comment.