Skip to content

Commit

Permalink
perf: improve recall summarize prompt (#15)
Browse files Browse the repository at this point in the history
* optimize memory summary prompt

* fix tests

* change recall to be able to go n levels deep

* fix test

* format

* remove unused import
  • Loading branch information
shihanwan authored Oct 4, 2024
1 parent 549ee21 commit 165d71c
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 52 deletions.
133 changes: 91 additions & 42 deletions memonto/core/recall.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from rdflib import Graph
from rdflib import Graph, URIRef, Literal, BNode

from memonto.llms.base_llm import LLMModel
from memonto.stores.triple.base_store import TripleStoreModel
Expand All @@ -11,61 +10,105 @@ def _hydrate_triples(
triples: list,
triple_store: VectorStoreModel,
id: str = None,
) -> list:
) -> Graph:
triple_values = " ".join(
f"(<{triple['s']}> <{triple['p']}> \"{triple['o']}\")" for triple in triples
)

graph_id = f"data-{id}" if id else "data"

query = f"""
SELECT ?s ?p ?o
CONSTRUCT {{
?s ?p ?o .
}}
WHERE {{
GRAPH <{graph_id}> {{
VALUES (?s ?p ?o) {{ {triple_values} }}
?s ?p ?o .
}}
}}
"""

return triple_store.query(query=query)
result = triple_store.query(query=query, format="turtle")

g = Graph()
g.parse(data=result, format="turtle")

return g


def _get_formatted_node(node: URIRef | Literal | BNode) -> str:
if isinstance(node, URIRef):
return f"<{str(node)}>"
elif isinstance(node, Literal):
return f'"{str(node)}"'
elif isinstance(node, BNode):
return f"_:{str(node)}"
else:
return f'"{str(node)}"'


def _find_adjacent_triples(
triples: list,
triples: Graph,
triple_store: VectorStoreModel,
id: str = None,
depth: int = 1,
) -> str:
nodes_set = set()

for t in triples:
for key in ["s", "o"]:
node = t[key]
node_value = node["value"]
node_type = node["type"]
for s, p, o in triples:
nodes_set.add(_get_formatted_node(s))
nodes_set.add(_get_formatted_node(o))

if node_type == "uri":
formatted_node = f"<{node_value}>"
elif node_type == "literal":
formatted_node = f'"{node_value}"'
else:
formatted_node = f'"{node_value}"'
explored_nodes = set(nodes_set)
new_nodes_set = nodes_set.copy()

nodes_set.add(formatted_node)
for _ in range(depth):
if not new_nodes_set:
break

node_list = ", ".join(nodes_set)
graph_id = f"data-{id}" if id else "data"
node_list = ", ".join(new_nodes_set)
graph_id = f"data-{id}" if id else "data"

query = f"""
CONSTRUCT {{
?s ?p ?o .
}}
WHERE {{
GRAPH <{graph_id}> {{
query = f"""
CONSTRUCT {{
?s ?p ?o .
FILTER (?s IN ({node_list}) || ?o IN ({node_list}))
}}
}}
"""
WHERE {{
GRAPH <{graph_id}> {{
?s ?p ?o .
FILTER (?s IN ({node_list}) || ?o IN ({node_list}))
}}
}}
"""

logger.debug(f"Find adjacent triples SPARQL query\n{query}\n")

try:
result_triples = triple_store.query(query=query, format="turtle")
except Exception as e:
raise ValueError(f"SPARQL Query Error: {e}")

if result_triples is None:
raise ValueError("SPARQL query returned no results")

graph = Graph()
graph.parse(data=result_triples, format="turtle")

temp_new_nodes_set = set()
for s, p, o in graph:
formatted_subject = _get_formatted_node(s)
formatted_object = _get_formatted_node(o)

if formatted_subject not in explored_nodes:
temp_new_nodes_set.add(formatted_subject)
explored_nodes.add(formatted_subject)

if formatted_object not in explored_nodes:
temp_new_nodes_set.add(formatted_object)
explored_nodes.add(formatted_object)

new_nodes_set = temp_new_nodes_set

return triple_store.query(query=query, format="turtle")

Expand All @@ -82,38 +125,44 @@ def _recall(
llm: LLMModel,
vector_store: VectorStoreModel,
triple_store: TripleStoreModel,
message: str,
context: str,
id: str,
ephemeral: bool,
) -> str:
if ephemeral:
contextual_memory = data.serialize(format="turtle")
elif message:
memory = data.serialize(format="turtle")
elif context:
try:
matched_triples = vector_store.search(message=message, id=id)
triples = _hydrate_triples(
matched_triples = vector_store.search(message=context, id=id)

matched_graph = _hydrate_triples(
triples=matched_triples,
triple_store=triple_store,
id=id,
)
contextual_memory = _find_adjacent_triples(
triples=triples,

matched_triples = matched_graph.serialize(format="turtle")
logger.debug(f"Matched Triples\n{matched_triples}\n")

memory = _find_adjacent_triples(
triples=matched_graph,
triple_store=triple_store,
id=id,
depth=1,
)

logger.debug(f"Matched Triples\n{json.dumps(triples, indent=2)}\n")
logger.debug(f"Adjacent Triples\n{memory}\n")
except ValueError as e:
logger.debug(f"Recall Exception\n{e}\n")
contextual_memory = ""
memory = ""
else:
contextual_memory = _find_all(triple_store=triple_store, id=id)
memory = _find_all(triple_store=triple_store, id=id)

logger.debug(f"Contextual Triples\n{contextual_memory}\n")
logger.debug(f"Contextual Triples\n{memory}\n")

summarized_memory = llm.prompt(
prompt_name="summarize_memory",
memory=contextual_memory,
context=context,
memory=memory,
)

logger.debug(f"Summarized Memory\n{summarized_memory}\n")
Expand Down
6 changes: 3 additions & 3 deletions memonto/memonto.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,20 @@ def recall(self, context: str = None) -> str:
llm=self.llm,
triple_store=self.triple_store,
vector_store=self.vector_store,
message=context,
context=context,
id=self.id,
ephemeral=self.ephemeral,
)

@require_config("llm", "triple_store", "vector_store")
async def arecall(self, message: str = None) -> str:
async def arecall(self, context: str = None) -> str:
return await asyncio.to_thread(
_recall,
data=self.data,
llm=self.llm,
triple_store=self.triple_store,
vector_store=self.vector_store,
message=message,
context=context,
id=self.id,
ephemeral=self.ephemeral,
)
Expand Down
11 changes: 8 additions & 3 deletions memonto/prompts/summarize_memory.prompt
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
You are trying to describe the following RDF graph in plain English.

Here is the RDF graph:
Here is the user message which serves as context:
```
${context}
```

And here is the RDF graph that contains information relevant to the context:
```
${memory}
```

Describe the RDF graph in one paragraph and make sure to follow these rules:
Summarize the user message and the RDF graph in one paragraph while following these rules:
- FOCUS on telling a story about the who, what, where, how, etc.
- LEAVE OUT anything not explicitly defined, do not make assumptions.
- DO NOT describe the RDF graph schema and DO NOT mention the RDF graph at all.
- DO NOT mention the RDF graph schema and DO NOT mention the RDF graph at all.
- If the RDF graph is empty then just return that there are currently no stored memory.
- Make sure to use plain and simple English.
10 changes: 6 additions & 4 deletions tests/core/test_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,14 @@ def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id, data_graph):
llm=mock_llm,
vector_store=mock_store,
triple_store=mock_store,
message=None,
context=None,
id=id,
ephemeral=False,
)

mock_llm.prompt.assert_called_once_with(
prompt_name="summarize_memory",
context=None,
memory=all_memory,
)

Expand All @@ -96,20 +97,21 @@ def test_fetch_some_memory(
):
some_memory = "some memory"
mock_find_adjacent_triples.return_value = some_memory
mock_hydrate_triples.return_value = []
mock_hydrate_triples.return_value = Graph()

_recall(
data=data_graph,
llm=mock_llm,
vector_store=mock_store,
triple_store=mock_store,
message=user_query,
context=user_query,
id=id,
ephemeral=False,
)

mock_llm.prompt.assert_called_once_with(
prompt_name="summarize_memory",
context=user_query,
memory=some_memory,
)

Expand All @@ -121,7 +123,7 @@ def test_fetch_some_memory_ephemeral(mock_llm, data_graph):
llm=mock_llm,
vector_store=None,
triple_store=None,
message=None,
context=None,
id=None,
ephemeral=True,
)
Expand Down

0 comments on commit 165d71c

Please sign in to comment.