Skip to content

Commit

Permalink
Migrate db models to pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
cmyui committed Oct 24, 2024
1 parent 9a64ddf commit db92572
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 32 deletions.
14 changes: 7 additions & 7 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ async def _calculate_per_user_costs(
defaultdict(lambda: defaultdict(int))
)
for message in messages:
if message["role"] != "user":
if message.role != "user":
continue

thread_id = message["thread_id"]
thread_id = message.thread_id
thread = threads_cache.get(thread_id)
if thread is None:
thread = await threads.fetch_one(thread_id)
Expand All @@ -92,9 +92,9 @@ async def _calculate_per_user_costs(
continue
threads_cache[thread_id] = thread

user_id = message["discord_user_id"]
user_id = message.discord_user_id
per_model_input_tokens = per_user_per_model_input_tokens[user_id]
per_model_input_tokens[thread["model"]] += message["tokens_used"]
per_model_input_tokens[thread.model] += message.tokens_used

per_user_cost: dict[int, float] = defaultdict(float)
for user_id, per_model_tokens in per_user_per_model_input_tokens.items():
Expand Down Expand Up @@ -204,7 +204,7 @@ async def model(
)
return

await threads.partial_update(thread["thread_id"], model=model)
await threads.partial_update(thread.thread_id, model=model)

await interaction.followup.send(
content="\n".join(
Expand Down Expand Up @@ -253,7 +253,7 @@ async def context(
return

await threads.partial_update(
thread["thread_id"],
thread.thread_id,
context_length=context_length,
)

Expand Down Expand Up @@ -457,7 +457,7 @@ async def transcript(
)

transcript_content = "\n".join(
"[{created_at:%d/%m/%Y %I:%M:%S%p}] {content}".format(**msg)
f"[{msg.created_at:%d/%m/%Y %I:%M:%S%p}] {msg.content}"
for msg in current_thread_messages
)
with io.BytesIO(transcript_content.encode()) as f:
Expand Down
23 changes: 12 additions & 11 deletions app/repositories/thread_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from datetime import datetime
from typing import Any
from typing import Literal
from typing import TypedDict

from pydantic import BaseModel

from app import state

Expand All @@ -17,7 +18,7 @@
"""


class ThreadMessage(TypedDict):
class ThreadMessage(BaseModel):
thread_message_id: int
thread_id: int
content: str
Expand All @@ -28,15 +29,15 @@ class ThreadMessage(TypedDict):


def deserialize(rec: Mapping[str, Any]) -> ThreadMessage:
return {
"thread_message_id": rec["thread_message_id"],
"thread_id": rec["thread_id"],
"content": rec["content"],
"discord_user_id": rec["discord_user_id"],
"role": rec["role"],
"tokens_used": rec["tokens_used"],
"created_at": rec["created_at"],
}
return ThreadMessage(
thread_message_id=rec["thread_message_id"],
thread_id=rec["thread_id"],
content=rec["content"],
discord_user_id=rec["discord_user_id"],
role=rec["role"],
tokens_used=rec["tokens_used"],
created_at=rec["created_at"],
)


async def create(
Expand Down
19 changes: 10 additions & 9 deletions app/repositories/threads.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from typing import TypedDict

from pydantic import BaseModel

from app import state
from app.adapters.openai.gpt import OpenAIModel
Expand All @@ -15,7 +16,7 @@
"""


class Thread(TypedDict):
class Thread(BaseModel):
thread_id: int
initiator_user_id: int
model: OpenAIModel
Expand All @@ -24,13 +25,13 @@ class Thread(TypedDict):


def deserialize(rec: Mapping[str, Any]) -> Thread:
return {
"thread_id": rec["thread_id"],
"initiator_user_id": rec["initiator_user_id"],
"model": OpenAIModel(rec["model"]),
"context_length": rec["context_length"],
"created_at": rec["created_at"],
}
return Thread(
thread_id=rec["thread_id"],
initiator_user_id=rec["initiator_user_id"],
model=OpenAIModel(rec["model"]),
context_length=rec["context_length"],
created_at=rec["created_at"],
)


async def create(
Expand Down
10 changes: 5 additions & 5 deletions app/usecases/ai_conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ async def send_message_to_thread(

message_history: list[gpt.Message] = [
{
"role": m["role"],
"content": [{"type": "text", "text": m["content"]}],
"role": m.role,
"content": [{"type": "text", "text": m.content}],
}
for m in thread_history[-tracked_thread["context_length"] :]
for m in thread_history[-tracked_thread.context_length :]
]

# Append this new message (along w/ any attachments) to the history
Expand Down Expand Up @@ -122,7 +122,7 @@ async def send_message_to_thread(
functions = openai_functions.get_full_openai_functions_schema()
try:
gpt_response = await gpt.send(
model=tracked_thread["model"],
model=tracked_thread.model,
messages=message_history,
functions=functions,
)
Expand Down Expand Up @@ -178,7 +178,7 @@ async def send_message_to_thread(
)
try:
gpt_response = await gpt.send(
model=tracked_thread["model"],
model=tracked_thread.model,
messages=message_history,
)
except Exception as exc:
Expand Down

0 comments on commit db92572

Please sign in to comment.