diff --git a/app/main.py b/app/main.py index e93c55a..9c89205 100755 --- a/app/main.py +++ b/app/main.py @@ -76,10 +76,10 @@ async def _calculate_per_user_costs( defaultdict(lambda: defaultdict(int)) ) for message in messages: - if message["role"] != "user": + if message.role != "user": continue - thread_id = message["thread_id"] + thread_id = message.thread_id thread = threads_cache.get(thread_id) if thread is None: thread = await threads.fetch_one(thread_id) @@ -92,9 +92,9 @@ async def _calculate_per_user_costs( continue threads_cache[thread_id] = thread - user_id = message["discord_user_id"] + user_id = message.discord_user_id per_model_input_tokens = per_user_per_model_input_tokens[user_id] - per_model_input_tokens[thread["model"]] += message["tokens_used"] + per_model_input_tokens[thread.model] += message.tokens_used per_user_cost: dict[int, float] = defaultdict(float) for user_id, per_model_tokens in per_user_per_model_input_tokens.items(): @@ -204,7 +204,7 @@ async def model( ) return - await threads.partial_update(thread["thread_id"], model=model) + await threads.partial_update(thread.thread_id, model=model) await interaction.followup.send( content="\n".join( @@ -253,7 +253,7 @@ async def context( return await threads.partial_update( - thread["thread_id"], + thread.thread_id, context_length=context_length, ) @@ -457,7 +457,7 @@ async def transcript( ) transcript_content = "\n".join( - "[{created_at:%d/%m/%Y %I:%M:%S%p}] {content}".format(**msg) + f"[{msg.created_at:%d/%m/%Y %I:%M:%S%p}] {msg.content}" for msg in current_thread_messages ) with io.BytesIO(transcript_content.encode()) as f: diff --git a/app/repositories/thread_messages.py b/app/repositories/thread_messages.py index 6aa078b..e1c940a 100644 --- a/app/repositories/thread_messages.py +++ b/app/repositories/thread_messages.py @@ -2,7 +2,8 @@ from datetime import datetime from typing import Any from typing import Literal -from typing import TypedDict + +from pydantic import BaseModel from app import state @@ -17,7 +18,7 @@ """ -class ThreadMessage(TypedDict): +class ThreadMessage(BaseModel): thread_message_id: int thread_id: int content: str @@ -28,15 +29,15 @@ class ThreadMessage(TypedDict): def deserialize(rec: Mapping[str, Any]) -> ThreadMessage: - return { - "thread_message_id": rec["thread_message_id"], - "thread_id": rec["thread_id"], - "content": rec["content"], - "discord_user_id": rec["discord_user_id"], - "role": rec["role"], - "tokens_used": rec["tokens_used"], - "created_at": rec["created_at"], - } + return ThreadMessage( + thread_message_id=rec["thread_message_id"], + thread_id=rec["thread_id"], + content=rec["content"], + discord_user_id=rec["discord_user_id"], + role=rec["role"], + tokens_used=rec["tokens_used"], + created_at=rec["created_at"], + ) async def create( diff --git a/app/repositories/threads.py b/app/repositories/threads.py index 3252326..dc58411 100644 --- a/app/repositories/threads.py +++ b/app/repositories/threads.py @@ -1,7 +1,8 @@ from collections.abc import Mapping from datetime import datetime from typing import Any -from typing import TypedDict + +from pydantic import BaseModel from app import state from app.adapters.openai.gpt import OpenAIModel @@ -15,7 +16,7 @@ """ -class Thread(TypedDict): +class Thread(BaseModel): thread_id: int initiator_user_id: int model: OpenAIModel @@ -24,13 +25,13 @@ class Thread(TypedDict): def deserialize(rec: Mapping[str, Any]) -> Thread: - return { - "thread_id": rec["thread_id"], - "initiator_user_id": rec["initiator_user_id"], - "model": OpenAIModel(rec["model"]), - "context_length": rec["context_length"], - "created_at": rec["created_at"], - } + return Thread( + thread_id=rec["thread_id"], + initiator_user_id=rec["initiator_user_id"], + model=OpenAIModel(rec["model"]), + context_length=rec["context_length"], + created_at=rec["created_at"], + ) async def create( diff --git a/app/usecases/ai_conversations.py b/app/usecases/ai_conversations.py index 041bee0..e040dea 100644 --- a/app/usecases/ai_conversations.py +++ b/app/usecases/ai_conversations.py @@ -87,10 +87,10 @@ async def send_message_to_thread( message_history: list[gpt.Message] = [ { - "role": m["role"], - "content": [{"type": "text", "text": m["content"]}], + "role": m.role, + "content": [{"type": "text", "text": m.content}], } - for m in thread_history[-tracked_thread["context_length"] :] + for m in thread_history[-tracked_thread.context_length :] ] # Append this new message (along w/ any attachments) to the history @@ -122,7 +122,7 @@ async def send_message_to_thread( functions = openai_functions.get_full_openai_functions_schema() try: gpt_response = await gpt.send( - model=tracked_thread["model"], + model=tracked_thread.model, messages=message_history, functions=functions, ) @@ -178,7 +178,7 @@ async def send_message_to_thread( ) try: gpt_response = await gpt.send( - model=tracked_thread["model"], + model=tracked_thread.model, messages=message_history, ) except Exception as exc: