From 26b65eae915542b1336a85237bdc269129ae7f10 Mon Sep 17 00:00:00 2001 From: ShiHan Wan Date: Wed, 16 Oct 2024 17:48:11 -0400 Subject: [PATCH] feat: add relevant memories to retain so it's smarter (#24) * fix forgetting vector or triple stores in ephemeral mode * add relevant memories to retain * fix tests * remove debug --- memonto/core/forget.py | 20 +++++++++--------- memonto/core/retain.py | 27 +++++++++++++++++++++++++ memonto/prompts/commit_to_memory.prompt | 12 ++++++++--- tests/core/test_retain.py | 3 +++ 4 files changed, 49 insertions(+), 13 deletions(-) diff --git a/memonto/core/forget.py b/memonto/core/forget.py index f87dd4f..e5043b9 100644 --- a/memonto/core/forget.py +++ b/memonto/core/forget.py @@ -14,14 +14,14 @@ def _forget( ) -> None: if ephemeral: data.remove((None, None, None)) + else: + try: + if vector_store: + vector_store.delete(id) - try: - if vector_store: - vector_store.delete(id) - - if triple_store: - triple_store.delete_all(id) - except ValueError as e: - logger.warning(e) - except Exception as e: - logger.error(e) + if triple_store: + triple_store.delete_all(id) + except ValueError as e: + logger.warning(e) + except Exception as e: + logger.error(e) diff --git a/memonto/core/retain.py b/memonto/core/retain.py index e95e1ef..f170ce9 100644 --- a/memonto/core/retain.py +++ b/memonto/core/retain.py @@ -145,6 +145,24 @@ def update_memory( return str(updated_memory) +def find_relevant_memories( + data: Graph, + vector_store: VectorStoreModel, + message: str, + id: str, + ephemeral: bool, +) -> str: + relevant_memory = "" + + if ephemeral: + relevant_memory = str(data.serialize(format="turtle")) + else: + relevant_memory = str(vector_store.search(message=message, id=id, k=3)) + + logger.debug(f"relevant_memory\n{relevant_memory}\n") + return relevant_memory + + def save_memory( ontology: Graph, namespaces: dict[str, Namespace], @@ -158,12 +176,21 @@ def save_memory( str_ontology: str, updated_memory: str, ) -> None: + relevant_memory = find_relevant_memories( + data=data, + vector_store=vector_store, + message=message, + id=id, + ephemeral=ephemeral, + ) + script = llm.prompt( prompt_name="commit_to_memory", temperature=0.2, ontology=str_ontology, user_message=message, updated_memory=updated_memory, + relevant_memory=relevant_memory, ) logger.debug(f"Retain Script\n{script}\n") diff --git a/memonto/prompts/commit_to_memory.prompt b/memonto/prompts/commit_to_memory.prompt index c136bb7..7980f1f 100644 --- a/memonto/prompts/commit_to_memory.prompt +++ b/memonto/prompts/commit_to_memory.prompt @@ -10,14 +10,20 @@ And this user message: ${user_message} ``` -And these removed triples: +And these removed memories: ``` ${updated_memory} ``` +And these relevant memories: +``` +${relevant_memory} +``` + Analyze the user message to find AS MUCH new information AS POSSIBLE that could fit onto the above ontology while adhering to these rules: -- First find all the new information in the user message that maps onto BOTH the above ontology and ESPECIALLY the removed triples. -- Second apply the existing namespaces to the extracted information. +- First find all the new information in the user message that maps onto BOTH the above ontology and ESPECIALLY the removed memories. +- Second check if the relevant memories can help extract even more information from the user message that maps onto the ontology or removed memories. +- Third apply only the existing namespaces to the extracted information. - Finally create the script that will add the extracted information to an rdflib graph called `data`. - NEVER generate code that initializes new graphs, namespaces, classes, properties, etc. - GENERATE Python code to add the triples with the relevant information assuming rdflib Graph `data` and the newly added namespaces already exists. diff --git a/tests/core/test_retain.py b/tests/core/test_retain.py index d3507db..32e0a95 100644 --- a/tests/core/test_retain.py +++ b/tests/core/test_retain.py @@ -66,6 +66,7 @@ def test_commit_memory( ontology=ANY, user_message=user_query, updated_memory="", + relevant_memory=ANY, ) assert mock_llm.prompt.call_count == 1 @@ -103,6 +104,7 @@ def test_commit_memory_with_exception( ontology=ANY, user_message=user_query, updated_memory="", + relevant_memory=ANY, ) ctmeh_prompt = call( @@ -146,6 +148,7 @@ def test_commit_memory_auto_expand( ontology=ANY, user_message=user_query, updated_memory="", + relevant_memory=ANY, ) eo_prompt = call(