-
Notifications
You must be signed in to change notification settings - Fork 895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add dialog entries summarization #496
Changes from 14 commits
b4f3ad7
bdee6c8
28523d0
04d2f6f
4905b67
b87aaf0
1a0f37a
b8f62df
b646b3e
2e9b2a3
9574110
dd517e3
979636b
9df467b
fa5b52d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,83 @@ | ||
#!/usr/bin/env python3 | ||
|
||
|
||
import pandas as pd | ||
import asyncio | ||
from uuid import UUID, uuid4 | ||
|
||
from beartype import beartype | ||
from temporalio import activity | ||
|
||
# from agents_api.models.entry.entries_summarization import ( | ||
# entries_summarization_query, | ||
# get_toplevel_entries_query, | ||
# ) | ||
|
||
|
||
# TODO: Implement entry summarization queries | ||
# SCRUM-3 | ||
def entries_summarization_query(*args, **kwargs) -> pd.DataFrame: | ||
return pd.DataFrame() | ||
|
||
|
||
def get_toplevel_entries_query(*args, **kwargs) -> pd.DataFrame: | ||
return pd.DataFrame() | ||
|
||
|
||
# TODO: Implement entry summarization activities | ||
# SCRUM-4 | ||
from agents_api.autogen.openapi_model import Entry | ||
from agents_api.common.utils.datetime import utcnow | ||
from agents_api.env import summarization_model_name | ||
from agents_api.models.entry.entries_summarization import ( | ||
entries_summarization_query, | ||
get_toplevel_entries_query, | ||
) | ||
from agents_api.rec_sum.entities import get_entities | ||
from agents_api.rec_sum.summarize import summarize_messages | ||
from agents_api.rec_sum.trim import trim_messages | ||
|
||
|
||
@activity.defn | ||
@beartype | ||
async def summarization(session_id: str) -> None: | ||
raise NotImplementedError() | ||
|
||
# session_id = UUID(session_id) | ||
# entries = [] | ||
# entities_entry_ids = [] | ||
# for _, row in get_toplevel_entries_query(session_id=session_id).iterrows(): | ||
# if row["role"] == "system" and row.get("name") == "entities": | ||
# entities_entry_ids.append(UUID(row["entry_id"], version=4)) | ||
# else: | ||
# entries.append(row) | ||
|
||
# assert len(entries) > 0, "no need to summarize on empty entries list" | ||
|
||
# summarized, entities = await asyncio.gather( | ||
# summarize_messages(entries, model=summarization_model_name), | ||
# get_entities(entries, model=summarization_model_name), | ||
# ) | ||
# trimmed_messages = await trim_messages(summarized, model=summarization_model_name) | ||
# ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2 | ||
# new_entities_entry = Entry( | ||
# session_id=session_id, | ||
# source="summarizer", | ||
# role="system", | ||
# name="entities", | ||
# content=entities["content"], | ||
# timestamp=entries[0]["timestamp"] + ts_delta, | ||
# ) | ||
|
||
# entries_summarization_query( | ||
# session_id=session_id, | ||
# new_entry=new_entities_entry, | ||
# old_entry_ids=entities_entry_ids, | ||
# ) | ||
|
||
# trimmed_map = { | ||
# m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None | ||
# } | ||
|
||
# for idx, msg in enumerate(summarized): | ||
# new_entry = Entry( | ||
# session_id=session_id, | ||
# source="summarizer", | ||
# role="system", | ||
# name="information", | ||
# content=trimmed_map.get(idx, msg["content"]), | ||
# timestamp=entries[-1]["timestamp"] + 0.01, | ||
# ) | ||
|
||
# entries_summarization_query( | ||
# session_id=session_id, | ||
# new_entry=new_entry, | ||
# old_entry_ids=[ | ||
# UUID(entries[idx - 1]["entry_id"], version=4) | ||
# for idx in msg["summarizes"] | ||
# ], | ||
# ) | ||
session_id = UUID(session_id) | ||
entries = [] | ||
entities_entry_ids = [] | ||
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows(): | ||
if row["role"] == "system" and row.get("name") == "entities": | ||
entities_entry_ids.append(UUID(row["entry_id"], version=4)) | ||
else: | ||
entries.append(row) | ||
|
||
assert len(entries) > 0, "no need to summarize on empty entries list" | ||
|
||
summarized, entities = await asyncio.gather( | ||
summarize_messages(entries, model=summarization_model_name), | ||
get_entities(entries, model=summarization_model_name), | ||
) | ||
trimmed_messages = await trim_messages(summarized, model=summarization_model_name) | ||
ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2 | ||
# TODO: set tokenizer, double check token_count calculation | ||
new_entities_entry = Entry( | ||
id=uuid4(), | ||
session_id=session_id, | ||
source="summarizer", | ||
role="system", | ||
name="entities", | ||
content=entities["content"], | ||
timestamp=entries[0]["timestamp"] + ts_delta, | ||
token_count=sum([len(c) // 3.5 for c in entities["content"]]), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
created_at=utcnow(), | ||
tokenizer="", | ||
) | ||
|
||
entries_summarization_query( | ||
session_id=session_id, | ||
new_entry=new_entities_entry, | ||
old_entry_ids=entities_entry_ids, | ||
) | ||
|
||
trimmed_map = { | ||
m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None | ||
} | ||
|
||
for idx, msg in enumerate(summarized): | ||
new_entry = Entry( | ||
session_id=session_id, | ||
source="summarizer", | ||
role="system", | ||
name="information", | ||
content=trimmed_map.get(idx, msg["content"]), | ||
timestamp=entries[-1]["timestamp"] + 0.01, | ||
) | ||
|
||
entries_summarization_query( | ||
session_id=session_id, | ||
new_entry=new_entry, | ||
old_entry_ids=[ | ||
UUID(entries[idx - 1]["entry_id"], version=4) | ||
for idx in msg["summarizes"] | ||
], | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,57 +4,57 @@ | |
from temporalio import activity | ||
|
||
from agents_api.autogen.openapi_model import Entry | ||
|
||
# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query | ||
|
||
# TODO: Reimplement truncation queries | ||
# SCRUM-5 | ||
from agents_api.models.entry.delete_entries import delete_entries | ||
from agents_api.models.entry.entries_summarization import get_toplevel_entries_query | ||
|
||
|
||
def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]: | ||
raise NotImplementedError() | ||
result: list[UUID] = [] | ||
|
||
if not len(messages): | ||
return messages | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function |
||
|
||
_token_cnt, _offset = 0, 0 | ||
# if messages[0].role == Role.system: | ||
# token_cnt, offset = messages[0].token_count, 1 | ||
token_cnt, offset = 0, 0 | ||
if messages[0].role == "system": | ||
token_cnt, offset = messages[0].token_count, 1 | ||
|
||
# for m in reversed(messages[offset:]): | ||
# token_cnt += m.token_count | ||
# if token_cnt < token_count_threshold: | ||
# continue | ||
# else: | ||
# result.append(m.id) | ||
for m in reversed(messages[offset:]): | ||
token_cnt += m.token_count | ||
if token_cnt >= token_count_threshold: | ||
result.append(m.id) | ||
|
||
# return result | ||
return result | ||
|
||
|
||
# TODO: Reimplement truncation activities | ||
# SCRUM-6 | ||
@activity.defn | ||
@beartype | ||
async def truncation(session_id: str, token_count_threshold: int) -> None: | ||
async def truncation( | ||
developer_id: str, session_id: str, token_count_threshold: int | ||
) -> None: | ||
session_id = UUID(session_id) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding error handling for UUID conversion to handle invalid UUID strings gracefully. |
||
|
||
# delete_entries( | ||
# get_extra_entries( | ||
# [ | ||
# Entry( | ||
# entry_id=row["entry_id"], | ||
# session_id=session_id, | ||
# source=row["source"], | ||
# role=Role(row["role"]), | ||
# name=row["name"], | ||
# content=row["content"], | ||
# created_at=row["created_at"], | ||
# timestamp=row["timestamp"], | ||
# ) | ||
# for _, row in get_toplevel_entries_query( | ||
# session_id=session_id | ||
# ).iterrows() | ||
# ], | ||
# token_count_threshold, | ||
# ), | ||
# ) | ||
developer_id = UUID(developer_id) | ||
|
||
delete_entries( | ||
developer_id=developer_id, | ||
session_id=session_id, | ||
entry_ids=get_extra_entries( | ||
[ | ||
Entry( | ||
id=row["entry_id"], | ||
session_id=session_id, | ||
source=row["source"], | ||
role=row["role"], | ||
name=row["name"], | ||
content=row["content"], | ||
created_at=row["created_at"], | ||
timestamp=row["timestamp"], | ||
tokenizer=row["tokenizer"], | ||
token_count=row["token_count"], | ||
) | ||
for _, row in get_toplevel_entries_query( | ||
session_id=session_id | ||
).iterrows() | ||
], | ||
token_count_threshold, | ||
), | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,3 +70,31 @@ async def get_workflow_handle( | |
) | ||
|
||
return handle | ||
|
||
|
||
async def run_truncation_task( | ||
token_count_threshold: int, developer_id: UUID, session_id: UUID, job_id: UUID | ||
): | ||
from ..workflows.truncation import TruncationWorkflow | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider moving the import statement for |
||
|
||
client = await get_client() | ||
|
||
await client.start_workflow( | ||
TruncationWorkflow.run, | ||
args=[str(developer_id), str(session_id), token_count_threshold], | ||
task_queue="memory-task-queue", | ||
id=str(job_id), | ||
) | ||
|
||
|
||
async def run_summarization_task(session_id: UUID, job_id: UUID): | ||
from ..workflows.summarization import SummarizationWorkflow | ||
|
||
client = await get_client() | ||
|
||
await client.start_workflow( | ||
SummarizationWorkflow.run, | ||
args=[str(session_id)], | ||
task_queue="memory-task-queue", | ||
id=str(job_id), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ensure
entries
has at least two elements before calculatingts_delta
to avoid potentialIndexError
.