diff --git a/memonto/core/query.py b/memonto/core/query.py index 6214a2c..2722f48 100644 --- a/memonto/core/query.py +++ b/memonto/core/query.py @@ -1,4 +1,4 @@ -from rdflib import URIRef, Graph, Namespace +from rdflib import URIRef, Graph from memonto.stores.triple.base_store import TripleStoreModel @@ -9,7 +9,6 @@ def query_memory_data( id: str, uri: URIRef, query: str, - debug: bool, ) -> list: if query: return store.query(query=query) @@ -18,5 +17,4 @@ def query_memory_data( ontology=ontology, id=id, uri=uri, - debug=debug, ) diff --git a/memonto/core/remember.py b/memonto/core/remember.py index 790457d..3aed450 100644 --- a/memonto/core/remember.py +++ b/memonto/core/remember.py @@ -5,8 +5,7 @@ def load_memory( namespaces: dict[str, Namespace], - store: TripleStoreModel, + triple_store: TripleStoreModel, id: str, - debug: bool, ) -> Graph: - return store.load(namespaces=namespaces, id=id, debug=debug) + return triple_store.load(namespaces=namespaces, id=id) diff --git a/memonto/core/retain.py b/memonto/core/retain.py index fcceee2..4f14f47 100644 --- a/memonto/core/retain.py +++ b/memonto/core/retain.py @@ -3,6 +3,7 @@ from memonto.llms.base_llm import LLMModel from memonto.stores.triple.base_store import TripleStoreModel from memonto.stores.vector.base_store import VectorStoreModel +from memonto.utils.logger import logger def run_script( @@ -14,18 +15,14 @@ def run_script( llm: LLMModel, max_retries: int = 1, initial_temperature: float = 0.2, - debug: bool = False, ) -> Graph: attempt = 0 while attempt < max_retries: try: exec(script, exec_ctx) - if debug: - print(f"Script executed successfully on attempt {attempt + 1}") except Exception as e: - if debug: - print(f"Attempt {attempt + 1} to commit memory failed with error: {e}") + logger.debug(f"Run Script (Attempt {attempt + 1}) Failed\n{e}\n") temperature = initial_temperature * (2**attempt) temperature = min(temperature, 1.0) @@ -39,8 +36,7 @@ def run_script( user_message=message, ) - if debug: - print(f"Generated script on attempt {attempt + 1}:\n{script}") + logger.debug(f"Fixed Script (Attempt {attempt + 1})\n{script}\n") attempt += 1 @@ -51,7 +47,6 @@ def expand_ontology( ontology: Graph, llm: LLMModel, message: str, - debug: bool, ) -> Graph: script = llm.prompt( prompt_name="expand_ontology", @@ -60,9 +55,13 @@ def expand_ontology( user_message=message, ) + logger.debug(f"Expand Script\n{script}\n") + # TODO: handle exceptions just like in run_script exec(script, {"ontology": ontology}) + logger.debug(f"Ontology Graph\n{ontology.serialize(format='turtle')}\n") + return ontology @@ -76,14 +75,12 @@ def retain_memory( message: str, id: str, auto_expand: bool, - debug: bool, -): +) -> None: if auto_expand: ontology = expand_ontology( ontology=ontology, llm=llm, message=message, - debug=debug, ) str_ontology = ontology.serialize(format="turtle") @@ -95,8 +92,7 @@ def retain_memory( user_message=message, ) - if debug: - print(f"script:\n{script}\n") + logger.debug(f"Retain Script\n{script}\n") data = run_script( script=script, @@ -105,11 +101,9 @@ def retain_memory( ontology=str_ontology, data=data, llm=llm, - debug=debug, ) - if debug: - print(f"data:\n{data.serialize(format='turtle')}\n") + logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n") triple_store.save(ontology=ontology, data=data, id=id) vector_store.save(g=data, id=id) diff --git a/memonto/core/retrieve.py b/memonto/core/retrieve.py index 423475b..b541c46 100644 --- a/memonto/core/retrieve.py +++ b/memonto/core/retrieve.py @@ -69,7 +69,7 @@ def _find_adjacent_triples( return triple_store.query(query=query, format="turtle") -def _find_all(triple_store: TripleStoreModel, id: str = None) -> str: +def _find_all(triple_store: TripleStoreModel) -> str: return triple_store.query( query="CONSTRUCT {?s ?p ?o .} WHERE { GRAPH ?g { ?s ?p ?o . }}", format="turtle", @@ -82,9 +82,9 @@ def recall_memory( triple_store: TripleStoreModel, message: str, id: str, - debug: bool, ) -> str: if vector_store is None: + logger.error("Vector store is not configured.") raise Exception("Vector store is not configured.") if message: @@ -111,4 +111,6 @@ def recall_memory( memory=contextual_memory, ) + logger.debug(f"Summarized Memory\n{summarized_memory}\n") + return summarized_memory diff --git a/memonto/memonto.py b/memonto/memonto.py index 8990ace..b5f8808 100644 --- a/memonto/memonto.py +++ b/memonto/memonto.py @@ -92,7 +92,6 @@ def retain(self, message: str) -> None: vector_store=self.vector_store, message=message, id=self.id, - debug=self.debug, auto_expand=self.auto_expand, ) @@ -108,7 +107,6 @@ def recall(self, message: str = None) -> str: vector_store=self.vector_store, message=message, id=self.id, - debug=self.debug, ) def remember(self) -> None: @@ -121,9 +119,8 @@ def remember(self) -> None: """ self.ontology, self.data = load_memory( namespaces=self.namespaces, - store=self.triple_store, + triple_store=self.triple_store, id=self.id, - debug=self.debug, ) def forget(self): @@ -148,7 +145,6 @@ def query(self, uri: URIRef = None, query: str = None) -> list: id=self.id, uri=uri, query=query, - debug=self.debug, ) def render(self, format: str = "turtle") -> Union[str, dict]: diff --git a/memonto/stores/triple/jena.py b/memonto/stores/triple/jena.py index d54b840..7cd433a 100644 --- a/memonto/stores/triple/jena.py +++ b/memonto/stores/triple/jena.py @@ -1,9 +1,10 @@ -from rdflib import Graph, Literal, Namespace, URIRef, RDF, RDFS, OWL +from rdflib import Graph, Literal, Namespace, URIRef from SPARQLWrapper import SPARQLWrapper, GET, POST, TURTLE, JSON from SPARQLWrapper.SPARQLExceptions import SPARQLWrapperException from typing import Tuple from memonto.stores.triple.base_store import TripleStoreModel +from memonto.utils.logger import logger class ApacheJena(TripleStoreModel): @@ -18,7 +19,6 @@ def _query( method: Literal, query: str, format: str = TURTLE, - debug: bool = False, ) -> SPARQLWrapper: sparql = SPARQLWrapper(url) sparql.setQuery(query) @@ -28,23 +28,24 @@ def _query( if self.username and self.password: sparql.setCredentials(self.username, self.password) - if debug: - print(f"Query:\n{query}\n") + logger.debug(f"SPARQL Query\n{query}\n") try: response = sparql.query() content_type = response.info()["Content-Type"] if "html" in content_type: - return response.response.read().decode("utf-8") + res = response.response.read().decode("utf-8") + logger.debug(f"SPARQL Query Result\n{res}\n") + return res else: - return response.convert() + res = response.convert() + logger.debug(f"SPARQL Query Result\n{res}\n") + return res except SPARQLWrapperException as e: - if 1: - print(f"SPARQL query error:\n{e}\n") + logger.error(f"SPARQL Query Error\n{e}\n") except Exception as e: - if debug: - print(f"Generic query error:\n{e}\n") + logger.error(f"Generic Query Error\n{e}\n") def _get_prefixes(self, g: Graph) -> list[str]: gt = g.serialize(format="turtle") @@ -55,7 +56,6 @@ def _load( g: Graph, namespaces: dict[str, Namespace], id: str, - debug: bool, ) -> Graph: query = f"CONSTRUCT {{ ?s ?p ?o }} WHERE {{ GRAPH <{id}> {{ ?s ?p ?o }} }}" @@ -63,7 +63,6 @@ def _load( url=f"{self.connection_url}/sparql", method=POST, query=query, - debug=debug, ) g.parse(data=response, format="turtle") @@ -78,7 +77,6 @@ def save( ontology: Graph, data: Graph, id: str = None, - debug: bool = False, ) -> None: o_triples = ontology.serialize(format="nt") d_triples = data.serialize(format="nt") @@ -104,14 +102,12 @@ def save( url=f"{self.connection_url}/update", method=POST, query=query, - debug=debug, ) def load( self, namespaces: dict[str, Namespace], id: str = None, - debug: bool = False, ) -> Tuple[Graph, Graph]: ontology_id = f"ontology-{id}" if id else "ontology" data_id = f"data-{id}" if id else "data" @@ -123,18 +119,15 @@ def load( g=ontology, namespaces=namespaces, id=ontology_id, - debug=debug, ) data = self._load( g=data, namespaces=namespaces, id=data_id, - debug=debug, ) - if debug: - print(f"Loaded ontology:\n{ontology.serialize(format='turtle')}\n") - print(f"Loaded data:\n{data.serialize(format='turtle')}\n") + logger.debug(f"Loaded Ontology Graph\n{ontology.serialize(format='turtle')}\n") + logger.debug(f"Loaded Data Graph\n{data.serialize(format='turtle')}\n") return ontology, data @@ -143,7 +136,6 @@ def get( ontology: Graph, id: str, uri: URIRef, - debug: bool = False, ) -> list: prefixes = self._get_prefixes(ontology) prefix_block = ( @@ -164,7 +156,6 @@ def get( method=GET, query=query, format=JSON, - debug=debug, ) return result["results"]["bindings"] diff --git a/tests/core/test_remember.py b/tests/core/test_remember.py index fe556ee..602c0e9 100644 --- a/tests/core/test_remember.py +++ b/tests/core/test_remember.py @@ -17,7 +17,7 @@ def mock_store(): def test_load_memory(mock_store, id): load_memory( - store=mock_store, + triple_store=mock_store, id=id, debug=False, )