diff --git a/memonto/core/forget.py b/memonto/core/forget.py new file mode 100644 index 0000000..091dae2 --- /dev/null +++ b/memonto/core/forget.py @@ -0,0 +1,14 @@ +from memonto.stores.triple.base_store import TripleStoreModel +from memonto.stores.vector.base_store import VectorStoreModel + + +def forget_memory( + id: str, + triple_store: TripleStoreModel, + vector_store: VectorStoreModel, +) -> None: + if vector_store: + vector_store.delete(id) + + if triple_store: + triple_store.delete(id) diff --git a/memonto/core/recall.py b/memonto/core/recall.py index e0a186a..bf334d3 100644 --- a/memonto/core/recall.py +++ b/memonto/core/recall.py @@ -88,19 +88,23 @@ def recall_memory( id: str, ) -> str: if message: - matched_triples = vector_store.search(message=message, id=id) - triples = _hydrate_triples( - triples=matched_triples, - triple_store=triple_store, - id=id, - ) - contextual_memory = _find_adjacent_triples( - triples=triples, - triple_store=triple_store, - id=id, - ) - - logger.debug(f"Matched Triples\n{json.dumps(triples, indent=2)}\n") + try: + matched_triples = vector_store.search(message=message, id=id) + triples = _hydrate_triples( + triples=matched_triples, + triple_store=triple_store, + id=id, + ) + contextual_memory = _find_adjacent_triples( + triples=triples, + triple_store=triple_store, + id=id, + ) + + logger.debug(f"Matched Triples\n{json.dumps(triples, indent=2)}\n") + except ValueError as e: + logger.debug(f"Recall Exception\n{e}\n") + contextual_memory = "" else: contextual_memory = _find_all(triple_store=triple_store, id=id) diff --git a/memonto/memonto.py b/memonto/memonto.py index db61e5f..29c24b6 100644 --- a/memonto/memonto.py +++ b/memonto/memonto.py @@ -4,6 +4,7 @@ from memonto.core.configure import configure from memonto.core.init import init +from memonto.core.forget import forget_memory from memonto.core.query import query_memory_data from memonto.core.recall import recall_memory from memonto.core.remember import load_memory @@ -109,6 +110,7 @@ def recall(self, message: str = None) -> str: id=self.id, ) + # TODO: no longer needed, can be deprecated or removed def remember(self) -> None: """ Load existing memories from the memory store to a memonto instance. @@ -123,11 +125,15 @@ def remember(self) -> None: id=self.id, ) - def forget(self): + def forget(self) -> None: """ Remove memories from the memory store. """ - pass + forget_memory( + id=self.id, + triple_store=self.triple_store, + vector_store=self.vector_store, + ) def query(self, uri: URIRef = None, query: str = None) -> list: """ diff --git a/memonto/prompts/summarize_memory.prompt b/memonto/prompts/summarize_memory.prompt index e3a9bc5..509bfcc 100644 --- a/memonto/prompts/summarize_memory.prompt +++ b/memonto/prompts/summarize_memory.prompt @@ -9,4 +9,5 @@ Describe the RDF graph in one paragraph and make sure to follow these rules: - FOCUS on the telling a story about the who, what, where, how, etc. - LEAVE OUT anything not explicitly defined, do not make assumptions. - DO NOT describe the RDF graph schema and DO NOT mention the RDF graph at all. +- If the RDF graph is empty then just return that there are currently no stored memory. - Make sure to use plain and simple English. \ No newline at end of file diff --git a/memonto/stores/triple/jena.py b/memonto/stores/triple/jena.py index 1031ff9..3c08b4b 100644 --- a/memonto/stores/triple/jena.py +++ b/memonto/stores/triple/jena.py @@ -160,6 +160,15 @@ def get( return result["results"]["bindings"] + def delete(self, id: str = None) -> None: + query = f"""DROP GRAPH ; DROP GRAPH ;""" + + self._query( + url=f"{self.connection_url}/update", + method=POST, + query=query, + ) + def query(self, query: str, method: str = GET, format: str = JSON) -> list: result = self._query( url=f"{self.connection_url}/sparql", diff --git a/memonto/stores/vector/chroma.py b/memonto/stores/vector/chroma.py index 89dae6f..c906981 100644 --- a/memonto/stores/vector/chroma.py +++ b/memonto/stores/vector/chroma.py @@ -78,3 +78,6 @@ def search(self, message: str, id: str = None, k: int = 3) -> list[dict]: ) return [json.loads(t.get("triple", "{}")) for t in matched["metadatas"][0]] + + def delete(self, id: str) -> None: + self.client.delete_collection(id)