Skip to content

Commit

Permalink
feat: Summarize messages recursively
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed May 24, 2024
1 parent f3f67b7 commit 88aa9b2
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 182 deletions.
8 changes: 8 additions & 0 deletions agents-api/agents_api/activities/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import logging


logger = logging.getLogger(__name__)
h = logging.StreamHandler()
fmt = logging.Formatter("[%(asctime)s/%(levelname)s] - %(message)s")
h.setFormatter(fmt)
logger.addHandler(h)
102 changes: 82 additions & 20 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
#!/usr/bin/env python3

import asyncio
from pycozo.client import QueryException
from uuid import UUID
from typing import Callable
from textwrap import dedent
from temporalio import activity
from litellm import acompletion
from agents_api.models.entry.add_entries import add_entries_query
from agents_api.models.entry.delete_entries import delete_entries_by_ids_query
from agents_api.models.entry.entries_summarization import (
get_toplevel_entries_query,
entries_summarization_query,
)
from agents_api.common.protocol.entries import Entry
from ..model_registry import JULEP_MODELS
from ..env import summarization_model_name, model_inference_url, model_api_key
from ..env import model_inference_url, model_api_key
from agents_api.rec_sum.entities import get_entities
from agents_api.rec_sum.summarize import summarize_messages
from agents_api.rec_sum.trim import trim_messages
from agents_api.activities.logger import logger


example_previous_memory = """
Expand Down Expand Up @@ -157,31 +165,85 @@ async def run_prompt(
return parser(content.strip() if content is not None else "")


# @activity.defn
# async def summarization(session_id: str) -> None:
# session_id = UUID(session_id)
# entries = [
# Entry(**row)
# for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
# ]

# assert len(entries) > 0, "no need to summarize on empty entries list"

# response = await run_prompt(
# dialog=entries, previous_memories=[], model=summarization_model_name
# )

# new_entry = Entry(
# session_id=session_id,
# source="summarizer",
# role="system",
# name="information",
# content=response,
# timestamp=entries[-1].timestamp + 0.01,
# )

# entries_summarization_query(
# session_id=session_id,
# new_entry=new_entry,
# old_entry_ids=[e.id for e in entries],
# )


@activity.defn
async def summarization(session_id: str) -> None:
session_id = UUID(session_id)
entries = [
Entry(**row)
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows()
]
entries = []
entities_entry_ids = []
for _, row in get_toplevel_entries_query(session_id=session_id).iterrows():
if row["role"] == "system" and row.get("name") == "entities":
entities_entry_ids.append(row["entry_id"])
else:
entries.append(row)

assert len(entries) > 0, "no need to summarize on empty entries list"

response = await run_prompt(
dialog=entries, previous_memories=[], model=f"openai/{summarization_model_name}"
trimmed_messages, entities = await asyncio.gather(
trim_messages(entries),
get_entities(entries),
)

new_entry = Entry(
session_id=session_id,
source="summarizer",
role="system",
name="information",
content=response,
timestamp=entries[-1].timestamp + 0.01,
summarized = await summarize_messages(trimmed_messages)

ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2

add_entries_query(
Entry(
session_id=session_id,
source="summarizer",
role="system",
name="entities",
content=entities["content"],
timestamp=entries[0]["timestamp"] + ts_delta,
)
)

entries_summarization_query(
session_id=session_id,
new_entry=new_entry,
old_entry_ids=[e.id for e in entries],
)
try:
delete_entries_by_ids_query(entry_ids=entities_entry_ids)
except QueryException as e:
logger.exception(e)

for msg in summarized:
new_entry = Entry(
session_id=session_id,
source="summarizer",
role="system",
name="information",
content=msg["content"],
timestamp=entries[-1]["timestamp"] + 0.01,
)

entries_summarization_query(
session_id=session_id,
new_entry=new_entry,
old_entry_ids=[entries[idx]["entry_id"] for idx in msg["summarizes"]],
)
47 changes: 47 additions & 0 deletions agents-api/agents_api/models/entry/delete_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,50 @@ def delete_entries_query(session_id: UUID) -> tuple[str, dict]:
}"""

return (query, {"session_id": str(session_id)})


@cozo_query
def delete_entries_by_ids_query(entry_ids: list[UUID]) -> tuple[str, dict]:
entry_ids = [f'to_uuid("{id}")' for id in entry_ids]

query = """
{
input[entry_id] <- $entry_ids
?[
session_id,
entry_id,
role,
name,
content,
source,
token_count,
created_at,
timestamp,
] := input[entry_id],
*entries{
session_id,
entry_id,
role,
name,
content,
source,
token_count,
created_at,
timestamp,
}
:delete entries {
session_id,
entry_id,
role,
name,
content,
source,
token_count,
created_at,
timestamp,
}
}"""

return (query, {"entry_ids": entry_ids})
3 changes: 0 additions & 3 deletions agents-api/agents_api/rec_sum/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .entities import get_entities
from .summarize import summarize_messages
from .trim import trim_messages
17 changes: 5 additions & 12 deletions agents-api/agents_api/rec_sum/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,21 @@
module_directory = Path(__file__).parent



with open(f"{module_directory}/entities_example_chat.json", 'r') as _f:
with open(f"{module_directory}/entities_example_chat.json", "r") as _f:
entities_example_chat = json.load(_f)



with open(f"{module_directory}/trim_example_chat.json", 'r') as _f:
with open(f"{module_directory}/trim_example_chat.json", "r") as _f:
trim_example_chat = json.load(_f)



with open(f"{module_directory}/trim_example_result.json", 'r') as _f:
with open(f"{module_directory}/trim_example_result.json", "r") as _f:
trim_example_result = json.load(_f)



with open(f"{module_directory}/summarize_example_chat.json", 'r') as _f:
with open(f"{module_directory}/summarize_example_chat.json", "r") as _f:
summarize_example_chat = json.load(_f)



with open(f"{module_directory}/summarize_example_result.json", 'r') as _f:
with open(f"{module_directory}/summarize_example_result.json", "r") as _f:
summarize_example_result = json.load(_f)


45 changes: 9 additions & 36 deletions agents-api/agents_api/rec_sum/entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from tenacity import retry, stop_after_attempt, wait_fixed
from tenacity import retry, stop_after_attempt

from .data import entities_example_chat
from .generate import generate
Expand Down Expand Up @@ -41,50 +41,23 @@
- See the example to get a better idea of the task."""


make_entities_prompt = lambda session, user="a user", assistant="gpt-4-turbo", **_: [f"""\
You are given a session history of a chat between {user or "a user"} and {assistant or "gpt-4-turbo"}. The session is formatted in the ChatML JSON format (from OpenAI).
{entities_instructions}
<ct:example-session>
{json.dumps(entities_example_chat, indent=2)}
</ct:example-session>
<ct:example-plan>
{entities_example_plan}
</ct:example-plan>
<ct:example-entities>
{entities_example_result}
</ct:example-entities>""",

f"""\
Begin! Write the entities as a Markdown formatted list. First write your plan inside <ct:plan></ct:plan> and then the extracted entities between <ct:entities></ct:entities>.
<ct:session>
{json.dumps(session, indent=2)}
</ct:session>"""]

def make_entities_prompt(session, user="a user", assistant="gpt-4-turbo", **_):
return [
f"You are given a session history of a chat between {user or 'a user'} and {assistant or 'gpt-4-turbo'}. The session is formatted in the ChatML JSON format (from OpenAI).\n\n{entities_instructions}\n\n<ct:example-session>\n{json.dumps(entities_example_chat, indent=2)}\n</ct:example-session>\n\n<ct:example-plan>\n{entities_example_plan}\n</ct:example-plan>\n\n<ct:example-entities>\n{entities_example_result}\n</ct:example-entities>",
f"Begin! Write the entities as a Markdown formatted list. First write your plan inside <ct:plan></ct:plan> and then the extracted entities between <ct:entities></ct:entities>.\n\n<ct:session>\n{json.dumps(session, indent=2)}\n\n</ct:session>",
]


@retry(stop=stop_after_attempt(2))
async def get_entities(
chat_session,
model="gpt-4-turbo",
stop=["</ct:entities"],
model="gpt-4-turbo",
stop=["</ct:entities"],
temperature=0.7,
**kwargs,
):
assert len(chat_session) > 2, "Session is too short"

# Remove the system prompt if present
if (
chat_session[0]["role"] == "system"
and chat_session[0].get("name") != "entities"
):
chat_session = chat_session[1:]

names = get_names_from_session(chat_session)
system_prompt, user_message = make_entities_prompt(chat_session, **names)
messages = [chatml.system(system_prompt), chatml.user(user_message)]
Expand All @@ -100,5 +73,5 @@ async def get_entities(
result["content"] = result["content"].split("<ct:entities>")[-1].strip()
result["role"] = "system"
result["name"] = "entities"

return chatml.make(**result)
9 changes: 4 additions & 5 deletions agents-api/agents_api/rec_sum/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@

@retry(wait=wait_fixed(2), stop=stop_after_attempt(5))
async def generate(
messages: list[dict],
client: AsyncClient=client,
model: str="gpt-4-turbo",
messages: list[dict],
client: AsyncClient = client,
model: str = "gpt-4-turbo",
**kwargs
) -> dict:
result = await client.chat.completions.create(
model=model, messages=messages, **kwargs
)

result = result.choices[0].message.__dict__

return result

Loading

0 comments on commit 88aa9b2

Please sign in to comment.