Skip to content

Commit

Permalink
feat: add ephemeral mode (#9)
Browse files Browse the repository at this point in the history
* add ephemeral mode

* fix and add tests

* format
  • Loading branch information
shihanwan authored Sep 30, 2024
1 parent fd2fe0f commit 3d7989a
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 25 deletions.
7 changes: 7 additions & 0 deletions memonto/core/forget.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from rdflib import Graph

from memonto.stores.triple.base_store import TripleStoreModel
from memonto.stores.vector.base_store import VectorStoreModel
from memonto.utils.logger import logger


def _forget(
data: Graph,
id: str,
triple_store: TripleStoreModel,
vector_store: VectorStoreModel,
ephemeral: bool,
) -> None:
if ephemeral:
data.remove((None, None, None))

try:
if vector_store:
vector_store.delete(id)
Expand Down
7 changes: 6 additions & 1 deletion memonto/core/recall.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from rdflib import Graph

from memonto.llms.base_llm import LLMModel
from memonto.stores.triple.base_store import TripleStoreModel
Expand Down Expand Up @@ -77,13 +78,17 @@ def _find_all(triple_store: TripleStoreModel) -> str:


def _recall(
data: Graph,
llm: LLMModel,
vector_store: VectorStoreModel,
triple_store: TripleStoreModel,
message: str,
id: str,
ephemeral: bool,
) -> str:
if message:
if ephemeral:
contextual_memory = data.serialize(format="turtle")
elif message:
try:
matched_triples = vector_store.search(message=message, id=id)
triples = _hydrate_triples(
Expand Down
8 changes: 5 additions & 3 deletions memonto/core/retain.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def _retain(
message: str,
id: str,
auto_expand: bool,
ephemeral: bool,
) -> None:
if auto_expand:
ontology = expand_ontology(
Expand Down Expand Up @@ -105,6 +106,7 @@ def _retain(

logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n")

triple_store.save(ontology=ontology, data=data, id=id)
if vector_store:
vector_store.save(g=data, id=id)
if not ephemeral:
triple_store.save(ontology=ontology, data=data, id=id)
if vector_store:
vector_store.save(g=data, id=id)
22 changes: 21 additions & 1 deletion memonto/core/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,34 @@
from memonto.stores.triple.base_store import TripleStoreModel


def get_triples_with_uri(g: Graph, uri: str) -> list[dict]:
uri_ref = URIRef(uri)
triples = []

for s, p, o in g.triples((uri_ref, None, None)):
triples.append({"s": s, "p": p, "o": o})

for s, p, o in g.triples((None, uri_ref, None)):
triples.append({"s": s, "p": p, "o": o})

for s, p, o in g.triples((None, None, uri_ref)):
triples.append({"s": s, "p": p, "o": o})

return triples


def _retrieve(
ontology: Graph,
data: Graph,
triple_store: TripleStoreModel,
id: str,
uri: URIRef,
query: str,
ephemeral: bool,
) -> list:
if query:
if ephemeral:
return get_triples_with_uri(g=data, uri=uri)
elif query:
return triple_store.query(query=query)
else:
return triple_store.get(ontology=ontology, uri=uri, id=id)
43 changes: 25 additions & 18 deletions memonto/memonto.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,17 @@


class Memonto(BaseModel):
id: Optional[str] = Field(None, description="Unique identifier for a memory group.")
ontology: Graph = Field(..., description="Schema describing the memory ontology.")
namespaces: dict[str, Namespace] = Field(
..., description="Namespaces used in the memory ontology."
)
data: Graph = Field(
default_factory=Graph, description="Data graph containing the actual memories."
)
llm: Optional[LLMModel] = Field(None, description="LLM model instance.")
triple_store: Optional[TripleStoreModel] = Field(None, description="Store triples.")
vector_store: Optional[VectorStoreModel] = Field(None, description="Store vectors.")
debug: Optional[bool] = Field(False, description="Enable debug mode.")
auto_expand: Optional[bool] = Field(
False, description="Enable automatic expansion of the ontology."
)
auto_forget: Optional[bool] = Field(
False, description="Enable automatic forgetting of memories."
)
id: Optional[str] = None
ontology: Graph = ...
namespaces: dict[str, Namespace] = ...
data: Graph = Field(..., default_factory=Graph)
llm: Optional[LLMModel] = None
triple_store: Optional[TripleStoreModel] = None
vector_store: Optional[VectorStoreModel] = None
auto_expand: Optional[bool] = False
auto_forget: Optional[bool] = False
ephemeral: Optional[bool] = False
debug: Optional[bool] = False
model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="after")
Expand Down Expand Up @@ -97,6 +90,7 @@ def retain(self, message: str) -> None:
message=message,
id=self.id,
auto_expand=self.auto_expand,
ephemeral=self.ephemeral,
)

@require_config("llm", "triple_store")
Expand All @@ -112,6 +106,7 @@ async def aretain(self, message: str) -> None:
message=message,
id=self.id,
auto_expand=self.auto_expand,
ephemeral=self.ephemeral,
)

@require_config("llm", "triple_store", "vector_store")
Expand All @@ -122,22 +117,26 @@ def recall(self, message: str = None) -> str:
:return: A text summary of the entire current memory.
"""
return _recall(
data=self.data,
llm=self.llm,
triple_store=self.triple_store,
vector_store=self.vector_store,
message=message,
id=self.id,
ephemeral=self.ephemeral,
)

@require_config("llm", "triple_store", "vector_store")
async def arecall(self, message: str = None) -> str:
return await asyncio.to_thread(
_recall,
data=self.data,
llm=self.llm,
triple_store=self.triple_store,
vector_store=self.vector_store,
message=message,
id=self.id,
ephemeral=self.ephemeral,
)

@require_config("triple_store")
Expand All @@ -153,39 +152,47 @@ def retrieve(self, uri: URIRef = None, query: str = None) -> list:
"""
return _retrieve(
ontology=self.ontology,
data=self.data,
triple_store=self.triple_store,
id=self.id,
uri=uri,
query=query,
ephemeral=self.ephemeral,
)

@require_config("triple_store")
async def aretrieve(self, uri: URIRef = None, query: str = None) -> list:
return await asyncio.to_thread(
_retrieve,
ontology=self.ontology,
data=self.data,
triple_store=self.triple_store,
id=self.id,
uri=uri,
query=query,
ephemeral=self.ephemeral,
)

def forget(self) -> None:
"""
Remove memories from the memory store.
"""
return _forget(
data=self.data,
id=self.id,
triple_store=self.triple_store,
vector_store=self.vector_store,
ephemeral=self.ephemeral,
)

async def aforget(self) -> None:
await asyncio.to_thread(
_forget,
data=self.data,
id=self.id,
triple_store=self.triple_store,
vector_store=self.vector_store,
ephemeral=self.ephemeral,
)

# TODO: no longer needed, can be deprecated or removed
Expand Down
3 changes: 3 additions & 0 deletions memonto/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ def require_config(*config_names):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if getattr(self, "ephemeral", False):
return func(self, *args, **kwargs)

for config_name in config_names:
store = getattr(self, config_name, None)

Expand Down
53 changes: 51 additions & 2 deletions tests/core/test_recall.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from rdflib import Graph
from rdflib import Graph, Literal, URIRef
from unittest.mock import ANY, MagicMock, patch

from memonto.core.recall import _recall
Expand Down Expand Up @@ -33,17 +33,48 @@ def mock_store():
return mock_store


@pytest.fixture
def data_graph():
g = Graph()

g.add(
(
URIRef("http://example.org/test#subject1"),
URIRef("http://example.org/test#predicate1"),
Literal("object1"),
)
)
g.add(
(
URIRef("http://example.org/test#subject2"),
URIRef("http://example.org/test#predicate2"),
Literal("object2"),
)
)
g.add(
(
URIRef("http://example.org/test#subject3"),
URIRef("http://example.org/test#predicate3"),
Literal("object3"),
)
)

return g


@patch("memonto.core.recall._find_all")
def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id):
def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id, data_graph):
all_memory = "all memory"
mock_find_all.return_value = all_memory

_recall(
data=data_graph,
llm=mock_llm,
vector_store=mock_store,
triple_store=mock_store,
message=None,
id=id,
ephemeral=False,
)

mock_llm.prompt.assert_called_once_with(
Expand All @@ -61,20 +92,38 @@ def test_fetch_some_memory(
mock_store,
user_query,
id,
data_graph,
):
some_memory = "some memory"
mock_find_adjacent_triples.return_value = some_memory
mock_hydrate_triples.return_value = []

_recall(
data=data_graph,
llm=mock_llm,
vector_store=mock_store,
triple_store=mock_store,
message=user_query,
id=id,
ephemeral=False,
)

mock_llm.prompt.assert_called_once_with(
prompt_name="summarize_memory",
memory=some_memory,
)


def test_fetch_some_memory_ephemeral(mock_llm, data_graph):

mem = _recall(
data=data_graph,
llm=mock_llm,
vector_store=None,
triple_store=None,
message=None,
id=None,
ephemeral=True,
)

assert mem == "some summary"
3 changes: 3 additions & 0 deletions tests/core/test_retain.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_commit_memory(
message=user_query,
id=id,
auto_expand=False,
ephemeral=False,
)

ctm_prompt = call(
Expand Down Expand Up @@ -90,6 +91,7 @@ def test_commit_memory_with_exception(
message=user_query,
id=id,
auto_expand=False,
ephemeral=False,
)

ctm_prompt = call(
Expand Down Expand Up @@ -130,6 +132,7 @@ def test_commit_memory_auto_expand(
message=user_query,
id=id,
auto_expand=True,
ephemeral=False,
)

ctm_prompt = call(
Expand Down
Loading

0 comments on commit 3d7989a

Please sign in to comment.