From 3d7989a77d18251ba0eac03e0f4699b99be84bd2 Mon Sep 17 00:00:00 2001 From: ShiHan Wan Date: Sun, 29 Sep 2024 20:37:02 -0400 Subject: [PATCH] feat: add ephemeral mode (#9) * add ephemeral mode * fix and add tests * format --- memonto/core/forget.py | 7 +++++ memonto/core/recall.py | 7 ++++- memonto/core/retain.py | 8 +++--- memonto/core/retrieve.py | 22 ++++++++++++++- memonto/memonto.py | 43 ++++++++++++++++------------- memonto/utils/decorators.py | 3 +++ tests/core/test_recall.py | 53 ++++++++++++++++++++++++++++++++++-- tests/core/test_retain.py | 3 +++ tests/core/test_retrieve.py | 54 +++++++++++++++++++++++++++++++++++++ 9 files changed, 175 insertions(+), 25 deletions(-) create mode 100644 tests/core/test_retrieve.py diff --git a/memonto/core/forget.py b/memonto/core/forget.py index db42319..e8144b1 100644 --- a/memonto/core/forget.py +++ b/memonto/core/forget.py @@ -1,13 +1,20 @@ +from rdflib import Graph + from memonto.stores.triple.base_store import TripleStoreModel from memonto.stores.vector.base_store import VectorStoreModel from memonto.utils.logger import logger def _forget( + data: Graph, id: str, triple_store: TripleStoreModel, vector_store: VectorStoreModel, + ephemeral: bool, ) -> None: + if ephemeral: + data.remove((None, None, None)) + try: if vector_store: vector_store.delete(id) diff --git a/memonto/core/recall.py b/memonto/core/recall.py index 3df6510..61cd604 100644 --- a/memonto/core/recall.py +++ b/memonto/core/recall.py @@ -1,4 +1,5 @@ import json +from rdflib import Graph from memonto.llms.base_llm import LLMModel from memonto.stores.triple.base_store import TripleStoreModel @@ -77,13 +78,17 @@ def _find_all(triple_store: TripleStoreModel) -> str: def _recall( + data: Graph, llm: LLMModel, vector_store: VectorStoreModel, triple_store: TripleStoreModel, message: str, id: str, + ephemeral: bool, ) -> str: - if message: + if ephemeral: + contextual_memory = data.serialize(format="turtle") + elif message: try: matched_triples = vector_store.search(message=message, id=id) triples = _hydrate_triples( diff --git a/memonto/core/retain.py b/memonto/core/retain.py index 71d23f8..9a72008 100644 --- a/memonto/core/retain.py +++ b/memonto/core/retain.py @@ -75,6 +75,7 @@ def _retain( message: str, id: str, auto_expand: bool, + ephemeral: bool, ) -> None: if auto_expand: ontology = expand_ontology( @@ -105,6 +106,7 @@ def _retain( logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n") - triple_store.save(ontology=ontology, data=data, id=id) - if vector_store: - vector_store.save(g=data, id=id) + if not ephemeral: + triple_store.save(ontology=ontology, data=data, id=id) + if vector_store: + vector_store.save(g=data, id=id) diff --git a/memonto/core/retrieve.py b/memonto/core/retrieve.py index 41455fe..b7d67ec 100644 --- a/memonto/core/retrieve.py +++ b/memonto/core/retrieve.py @@ -3,14 +3,34 @@ from memonto.stores.triple.base_store import TripleStoreModel +def get_triples_with_uri(g: Graph, uri: str) -> list[dict]: + uri_ref = URIRef(uri) + triples = [] + + for s, p, o in g.triples((uri_ref, None, None)): + triples.append({"s": s, "p": p, "o": o}) + + for s, p, o in g.triples((None, uri_ref, None)): + triples.append({"s": s, "p": p, "o": o}) + + for s, p, o in g.triples((None, None, uri_ref)): + triples.append({"s": s, "p": p, "o": o}) + + return triples + + def _retrieve( ontology: Graph, + data: Graph, triple_store: TripleStoreModel, id: str, uri: URIRef, query: str, + ephemeral: bool, ) -> list: - if query: + if ephemeral: + return get_triples_with_uri(g=data, uri=uri) + elif query: return triple_store.query(query=query) else: return triple_store.get(ontology=ontology, uri=uri, id=id) diff --git a/memonto/memonto.py b/memonto/memonto.py index fb4aef4..761b6a5 100644 --- a/memonto/memonto.py +++ b/memonto/memonto.py @@ -18,24 +18,17 @@ class Memonto(BaseModel): - id: Optional[str] = Field(None, description="Unique identifier for a memory group.") - ontology: Graph = Field(..., description="Schema describing the memory ontology.") - namespaces: dict[str, Namespace] = Field( - ..., description="Namespaces used in the memory ontology." - ) - data: Graph = Field( - default_factory=Graph, description="Data graph containing the actual memories." - ) - llm: Optional[LLMModel] = Field(None, description="LLM model instance.") - triple_store: Optional[TripleStoreModel] = Field(None, description="Store triples.") - vector_store: Optional[VectorStoreModel] = Field(None, description="Store vectors.") - debug: Optional[bool] = Field(False, description="Enable debug mode.") - auto_expand: Optional[bool] = Field( - False, description="Enable automatic expansion of the ontology." - ) - auto_forget: Optional[bool] = Field( - False, description="Enable automatic forgetting of memories." - ) + id: Optional[str] = None + ontology: Graph = ... + namespaces: dict[str, Namespace] = ... + data: Graph = Field(..., default_factory=Graph) + llm: Optional[LLMModel] = None + triple_store: Optional[TripleStoreModel] = None + vector_store: Optional[VectorStoreModel] = None + auto_expand: Optional[bool] = False + auto_forget: Optional[bool] = False + ephemeral: Optional[bool] = False + debug: Optional[bool] = False model_config = ConfigDict(arbitrary_types_allowed=True) @model_validator(mode="after") @@ -97,6 +90,7 @@ def retain(self, message: str) -> None: message=message, id=self.id, auto_expand=self.auto_expand, + ephemeral=self.ephemeral, ) @require_config("llm", "triple_store") @@ -112,6 +106,7 @@ async def aretain(self, message: str) -> None: message=message, id=self.id, auto_expand=self.auto_expand, + ephemeral=self.ephemeral, ) @require_config("llm", "triple_store", "vector_store") @@ -122,22 +117,26 @@ def recall(self, message: str = None) -> str: :return: A text summary of the entire current memory. """ return _recall( + data=self.data, llm=self.llm, triple_store=self.triple_store, vector_store=self.vector_store, message=message, id=self.id, + ephemeral=self.ephemeral, ) @require_config("llm", "triple_store", "vector_store") async def arecall(self, message: str = None) -> str: return await asyncio.to_thread( _recall, + data=self.data, llm=self.llm, triple_store=self.triple_store, vector_store=self.vector_store, message=message, id=self.id, + ephemeral=self.ephemeral, ) @require_config("triple_store") @@ -153,10 +152,12 @@ def retrieve(self, uri: URIRef = None, query: str = None) -> list: """ return _retrieve( ontology=self.ontology, + data=self.data, triple_store=self.triple_store, id=self.id, uri=uri, query=query, + ephemeral=self.ephemeral, ) @require_config("triple_store") @@ -164,10 +165,12 @@ async def aretrieve(self, uri: URIRef = None, query: str = None) -> list: return await asyncio.to_thread( _retrieve, ontology=self.ontology, + data=self.data, triple_store=self.triple_store, id=self.id, uri=uri, query=query, + ephemeral=self.ephemeral, ) def forget(self) -> None: @@ -175,17 +178,21 @@ def forget(self) -> None: Remove memories from the memory store. """ return _forget( + data=self.data, id=self.id, triple_store=self.triple_store, vector_store=self.vector_store, + ephemeral=self.ephemeral, ) async def aforget(self) -> None: await asyncio.to_thread( _forget, + data=self.data, id=self.id, triple_store=self.triple_store, vector_store=self.vector_store, + ephemeral=self.ephemeral, ) # TODO: no longer needed, can be deprecated or removed diff --git a/memonto/utils/decorators.py b/memonto/utils/decorators.py index d41ff39..52b91c2 100644 --- a/memonto/utils/decorators.py +++ b/memonto/utils/decorators.py @@ -8,6 +8,9 @@ def require_config(*config_names): def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): + if getattr(self, "ephemeral", False): + return func(self, *args, **kwargs) + for config_name in config_names: store = getattr(self, config_name, None) diff --git a/tests/core/test_recall.py b/tests/core/test_recall.py index 570ed8b..050b60c 100644 --- a/tests/core/test_recall.py +++ b/tests/core/test_recall.py @@ -1,5 +1,5 @@ import pytest -from rdflib import Graph +from rdflib import Graph, Literal, URIRef from unittest.mock import ANY, MagicMock, patch from memonto.core.recall import _recall @@ -33,17 +33,48 @@ def mock_store(): return mock_store +@pytest.fixture +def data_graph(): + g = Graph() + + g.add( + ( + URIRef("http://example.org/test#subject1"), + URIRef("http://example.org/test#predicate1"), + Literal("object1"), + ) + ) + g.add( + ( + URIRef("http://example.org/test#subject2"), + URIRef("http://example.org/test#predicate2"), + Literal("object2"), + ) + ) + g.add( + ( + URIRef("http://example.org/test#subject3"), + URIRef("http://example.org/test#predicate3"), + Literal("object3"), + ) + ) + + return g + + @patch("memonto.core.recall._find_all") -def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id): +def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id, data_graph): all_memory = "all memory" mock_find_all.return_value = all_memory _recall( + data=data_graph, llm=mock_llm, vector_store=mock_store, triple_store=mock_store, message=None, id=id, + ephemeral=False, ) mock_llm.prompt.assert_called_once_with( @@ -61,20 +92,38 @@ def test_fetch_some_memory( mock_store, user_query, id, + data_graph, ): some_memory = "some memory" mock_find_adjacent_triples.return_value = some_memory mock_hydrate_triples.return_value = [] _recall( + data=data_graph, llm=mock_llm, vector_store=mock_store, triple_store=mock_store, message=user_query, id=id, + ephemeral=False, ) mock_llm.prompt.assert_called_once_with( prompt_name="summarize_memory", memory=some_memory, ) + + +def test_fetch_some_memory_ephemeral(mock_llm, data_graph): + + mem = _recall( + data=data_graph, + llm=mock_llm, + vector_store=None, + triple_store=None, + message=None, + id=None, + ephemeral=True, + ) + + assert mem == "some summary" diff --git a/tests/core/test_retain.py b/tests/core/test_retain.py index 543482b..dfa7426 100644 --- a/tests/core/test_retain.py +++ b/tests/core/test_retain.py @@ -56,6 +56,7 @@ def test_commit_memory( message=user_query, id=id, auto_expand=False, + ephemeral=False, ) ctm_prompt = call( @@ -90,6 +91,7 @@ def test_commit_memory_with_exception( message=user_query, id=id, auto_expand=False, + ephemeral=False, ) ctm_prompt = call( @@ -130,6 +132,7 @@ def test_commit_memory_auto_expand( message=user_query, id=id, auto_expand=True, + ephemeral=False, ) ctm_prompt = call( diff --git a/tests/core/test_retrieve.py b/tests/core/test_retrieve.py new file mode 100644 index 0000000..59042d9 --- /dev/null +++ b/tests/core/test_retrieve.py @@ -0,0 +1,54 @@ +import pytest +from rdflib import Graph, Literal, URIRef +from unittest.mock import ANY, MagicMock, Mock, patch + +from memonto.core.retrieve import _retrieve + + +@pytest.fixture +def data_graph(): + g = Graph() + + g.add( + ( + URIRef("http://example.org/test#subject1"), + URIRef("http://example.org/test#predicate1"), + Literal("object1"), + ) + ) + g.add( + ( + URIRef("http://example.org/test#subject2"), + URIRef("http://example.org/test#predicate2"), + Literal("object2"), + ) + ) + g.add( + ( + URIRef("http://example.org/test#subject3"), + URIRef("http://example.org/test#predicate3"), + Literal("object3"), + ) + ) + + return g + + +def test_retrive_memory_ephemeral(data_graph): + mem_data = _retrieve( + ontology=Mock(), + data=data_graph, + triple_store=None, + id=None, + uri=URIRef("http://example.org/test#subject1"), + query=None, + ephemeral=True, + ) + + assert mem_data == [ + { + "s": URIRef("http://example.org/test#subject1"), + "p": URIRef("http://example.org/test#predicate1"), + "o": Literal("object1"), + } + ]