Skip to content

Commit

Permalink
switch to logger
Browse files Browse the repository at this point in the history
  • Loading branch information
shihanwan committed Sep 26, 2024
1 parent b69e816 commit 7e9bcba
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 52 deletions.
4 changes: 1 addition & 3 deletions memonto/core/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from rdflib import URIRef, Graph, Namespace
from rdflib import URIRef, Graph

from memonto.stores.triple.base_store import TripleStoreModel

Expand All @@ -9,7 +9,6 @@ def query_memory_data(
id: str,
uri: URIRef,
query: str,
debug: bool,
) -> list:
if query:
return store.query(query=query)
Expand All @@ -18,5 +17,4 @@ def query_memory_data(
ontology=ontology,
id=id,
uri=uri,
debug=debug,
)
5 changes: 2 additions & 3 deletions memonto/core/remember.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

def load_memory(
namespaces: dict[str, Namespace],
store: TripleStoreModel,
triple_store: TripleStoreModel,
id: str,
debug: bool,
) -> Graph:
return store.load(namespaces=namespaces, id=id, debug=debug)
return triple_store.load(namespaces=namespaces, id=id)
26 changes: 10 additions & 16 deletions memonto/core/retain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from memonto.llms.base_llm import LLMModel
from memonto.stores.triple.base_store import TripleStoreModel
from memonto.stores.vector.base_store import VectorStoreModel
from memonto.utils.logger import logger


def run_script(
Expand All @@ -14,18 +15,14 @@ def run_script(
llm: LLMModel,
max_retries: int = 1,
initial_temperature: float = 0.2,
debug: bool = False,
) -> Graph:
attempt = 0

while attempt < max_retries:
try:
exec(script, exec_ctx)
if debug:
print(f"Script executed successfully on attempt {attempt + 1}")
except Exception as e:
if debug:
print(f"Attempt {attempt + 1} to commit memory failed with error: {e}")
logger.debug(f"Run Script (Attempt {attempt + 1}) Failed\n{e}\n")

temperature = initial_temperature * (2**attempt)
temperature = min(temperature, 1.0)
Expand All @@ -39,8 +36,7 @@ def run_script(
user_message=message,
)

if debug:
print(f"Generated script on attempt {attempt + 1}:\n{script}")
logger.debug(f"Fixed Script (Attempt {attempt + 1})\n{script}\n")

attempt += 1

Expand All @@ -51,7 +47,6 @@ def expand_ontology(
ontology: Graph,
llm: LLMModel,
message: str,
debug: bool,
) -> Graph:
script = llm.prompt(
prompt_name="expand_ontology",
Expand All @@ -60,9 +55,13 @@ def expand_ontology(
user_message=message,
)

logger.debug(f"Expand Script\n{script}\n")

# TODO: handle exceptions just like in run_script
exec(script, {"ontology": ontology})

logger.debug(f"Ontology Graph\n{ontology.serialize(format='turtle')}\n")

return ontology


Expand All @@ -76,14 +75,12 @@ def retain_memory(
message: str,
id: str,
auto_expand: bool,
debug: bool,
):
) -> None:
if auto_expand:
ontology = expand_ontology(
ontology=ontology,
llm=llm,
message=message,
debug=debug,
)

str_ontology = ontology.serialize(format="turtle")
Expand All @@ -95,8 +92,7 @@ def retain_memory(
user_message=message,
)

if debug:
print(f"script:\n{script}\n")
logger.debug(f"Retain Script\n{script}\n")

data = run_script(
script=script,
Expand All @@ -105,11 +101,9 @@ def retain_memory(
ontology=str_ontology,
data=data,
llm=llm,
debug=debug,
)

if debug:
print(f"data:\n{data.serialize(format='turtle')}\n")
logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n")

triple_store.save(ontology=ontology, data=data, id=id)
vector_store.save(g=data, id=id)
6 changes: 4 additions & 2 deletions memonto/core/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _find_adjacent_triples(
return triple_store.query(query=query, format="turtle")


def _find_all(triple_store: TripleStoreModel, id: str = None) -> str:
def _find_all(triple_store: TripleStoreModel) -> str:
return triple_store.query(
query="CONSTRUCT {?s ?p ?o .} WHERE { GRAPH ?g { ?s ?p ?o . }}",
format="turtle",
Expand All @@ -82,9 +82,9 @@ def recall_memory(
triple_store: TripleStoreModel,
message: str,
id: str,
debug: bool,
) -> str:
if vector_store is None:
logger.error("Vector store is not configured.")
raise Exception("Vector store is not configured.")

if message:
Expand All @@ -111,4 +111,6 @@ def recall_memory(
memory=contextual_memory,
)

logger.debug(f"Summarized Memory\n{summarized_memory}\n")

return summarized_memory
6 changes: 1 addition & 5 deletions memonto/memonto.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def retain(self, message: str) -> None:
vector_store=self.vector_store,
message=message,
id=self.id,
debug=self.debug,
auto_expand=self.auto_expand,
)

Expand All @@ -108,7 +107,6 @@ def recall(self, message: str = None) -> str:
vector_store=self.vector_store,
message=message,
id=self.id,
debug=self.debug,
)

def remember(self) -> None:
Expand All @@ -121,9 +119,8 @@ def remember(self) -> None:
"""
self.ontology, self.data = load_memory(
namespaces=self.namespaces,
store=self.triple_store,
triple_store=self.triple_store,
id=self.id,
debug=self.debug,
)

def forget(self):
Expand All @@ -148,7 +145,6 @@ def query(self, uri: URIRef = None, query: str = None) -> list:
id=self.id,
uri=uri,
query=query,
debug=self.debug,
)

def render(self, format: str = "turtle") -> Union[str, dict]:
Expand Down
35 changes: 13 additions & 22 deletions memonto/stores/triple/jena.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from rdflib import Graph, Literal, Namespace, URIRef, RDF, RDFS, OWL
from rdflib import Graph, Literal, Namespace, URIRef
from SPARQLWrapper import SPARQLWrapper, GET, POST, TURTLE, JSON
from SPARQLWrapper.SPARQLExceptions import SPARQLWrapperException
from typing import Tuple

from memonto.stores.triple.base_store import TripleStoreModel
from memonto.utils.logger import logger


class ApacheJena(TripleStoreModel):
Expand All @@ -18,7 +19,6 @@ def _query(
method: Literal,
query: str,
format: str = TURTLE,
debug: bool = False,
) -> SPARQLWrapper:
sparql = SPARQLWrapper(url)
sparql.setQuery(query)
Expand All @@ -28,23 +28,24 @@ def _query(
if self.username and self.password:
sparql.setCredentials(self.username, self.password)

if debug:
print(f"Query:\n{query}\n")
logger.debug(f"SPARQL Query\n{query}\n")

try:
response = sparql.query()
content_type = response.info()["Content-Type"]

if "html" in content_type:
return response.response.read().decode("utf-8")
res = response.response.read().decode("utf-8")
logger.debug(f"SPARQL Query Result\n{res}\n")
return res
else:
return response.convert()
res = response.convert()
logger.debug(f"SPARQL Query Result\n{res}\n")
return res
except SPARQLWrapperException as e:
if 1:
print(f"SPARQL query error:\n{e}\n")
logger.error(f"SPARQL Query Error\n{e}\n")
except Exception as e:
if debug:
print(f"Generic query error:\n{e}\n")
logger.error(f"Generic Query Error\n{e}\n")

def _get_prefixes(self, g: Graph) -> list[str]:
gt = g.serialize(format="turtle")
Expand All @@ -55,15 +56,13 @@ def _load(
g: Graph,
namespaces: dict[str, Namespace],
id: str,
debug: bool,
) -> Graph:
query = f"CONSTRUCT {{ ?s ?p ?o }} WHERE {{ GRAPH <{id}> {{ ?s ?p ?o }} }}"

response = self._query(
url=f"{self.connection_url}/sparql",
method=POST,
query=query,
debug=debug,
)

g.parse(data=response, format="turtle")
Expand All @@ -78,7 +77,6 @@ def save(
ontology: Graph,
data: Graph,
id: str = None,
debug: bool = False,
) -> None:
o_triples = ontology.serialize(format="nt")
d_triples = data.serialize(format="nt")
Expand All @@ -104,14 +102,12 @@ def save(
url=f"{self.connection_url}/update",
method=POST,
query=query,
debug=debug,
)

def load(
self,
namespaces: dict[str, Namespace],
id: str = None,
debug: bool = False,
) -> Tuple[Graph, Graph]:
ontology_id = f"ontology-{id}" if id else "ontology"
data_id = f"data-{id}" if id else "data"
Expand All @@ -123,18 +119,15 @@ def load(
g=ontology,
namespaces=namespaces,
id=ontology_id,
debug=debug,
)
data = self._load(
g=data,
namespaces=namespaces,
id=data_id,
debug=debug,
)

if debug:
print(f"Loaded ontology:\n{ontology.serialize(format='turtle')}\n")
print(f"Loaded data:\n{data.serialize(format='turtle')}\n")
logger.debug(f"Loaded Ontology Graph\n{ontology.serialize(format='turtle')}\n")
logger.debug(f"Loaded Data Graph\n{data.serialize(format='turtle')}\n")

return ontology, data

Expand All @@ -143,7 +136,6 @@ def get(
ontology: Graph,
id: str,
uri: URIRef,
debug: bool = False,
) -> list:
prefixes = self._get_prefixes(ontology)
prefix_block = (
Expand All @@ -164,7 +156,6 @@ def get(
method=GET,
query=query,
format=JSON,
debug=debug,
)

return result["results"]["bindings"]
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_remember.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def mock_store():

def test_load_memory(mock_store, id):
load_memory(
store=mock_store,
triple_store=mock_store,
id=id,
debug=False,
)
Expand Down

0 comments on commit 7e9bcba

Please sign in to comment.