From 92fe22b61458fff6d5cbf4a94e5faef3562133a0 Mon Sep 17 00:00:00 2001 From: shihanwan Date: Sat, 12 Oct 2024 15:11:51 -0400 Subject: [PATCH] fix a few issues with recall --- memonto/core/recall.py | 43 +++++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/memonto/core/recall.py b/memonto/core/recall.py index 5863d9a..488dcb1 100644 --- a/memonto/core/recall.py +++ b/memonto/core/recall.py @@ -23,8 +23,9 @@ def _hydrate_triples( }} WHERE {{ GRAPH <{graph_id}> {{ - VALUES (?s ?p ?o) {{ {triple_values} }} + VALUES (?s ?p ?o_string) {{ {triple_values} }} ?s ?p ?o . + FILTER (STR(?s) = STR(?s) && STR(?p) = STR(?p) && STR(?o) = ?o_string) }} }} """ @@ -63,6 +64,8 @@ def _find_adjacent_triples( explored_nodes = set(nodes_set) new_nodes_set = nodes_set.copy() + query = None + for _ in range(depth): if not new_nodes_set: break @@ -110,6 +113,9 @@ def _find_adjacent_triples( new_nodes_set = temp_new_nodes_set + if query is None: + return "" + return triple_store.query(query=query, format="turtle") @@ -128,27 +134,28 @@ def _find_all(triple_store: TripleStoreModel, id: str) -> str: return str(result) -def _recall( +def get_contextual_memory( data: Graph, - llm: LLMModel, vector_store: VectorStoreModel, triple_store: TripleStoreModel, context: str, id: str, ephemeral: bool, -) -> str: +): + memory = "" + if ephemeral: memory = data.serialize(format="turtle") elif context: try: matched_triples = vector_store.search(message=context, id=id) + logger.debug(f"Matched Triples Raw\n{matched_triples}\n") matched_graph = _hydrate_triples( triples=matched_triples, triple_store=triple_store, id=id, ) - matched_triples = matched_graph.serialize(format="turtle") logger.debug(f"Matched Triples\n{matched_triples}\n") @@ -161,16 +168,34 @@ def _recall( logger.debug(f"Adjacent Triples\n{memory}\n") except ValueError as e: logger.debug(f"Recall Exception\n{e}\n") - memory = "" else: memory = _find_all(triple_store=triple_store, id=id) - context = "" - logger.debug(f"Contextual Triples\n{memory}\n") + logger.debug(f"Contextual Memory\n{memory}\n") + return memory + + +def _recall( + data: Graph, + llm: LLMModel, + vector_store: VectorStoreModel, + triple_store: TripleStoreModel, + context: str, + id: str, + ephemeral: bool, +) -> str: + memory = get_contextual_memory( + data=data, + vector_store=vector_store, + triple_store=triple_store, + context=context, + id=id, + ephemeral=ephemeral, + ) summarized_memory = llm.prompt( prompt_name="summarize_memory", - context=context, + context=context or "", memory=memory, )