perf: improve kg storage and retrieval from vector (#23)
* improve kg node value parsing

* remove bnodes from recall

* fix tests

* formatting

* remove debug code
shihanwan authored Oct 16, 2024
1 parent 9831c1f commit 4e7936f
Showing 8 changed files with 226 additions and 191 deletions.
2 changes: 1 addition & 1 deletion memonto/core/forget.py
@@ -20,7 +20,7 @@ def _forget(
vector_store.delete(id)

if triple_store:
triple_store.delete(id)
triple_store.delete_all(id)
except ValueError as e:
logger.warning(e)
except Exception as e:
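Not part of the diff: a minimal sketch of how the renamed call sits in the forget flow, assuming the guards and ValueError handling visible in the hunk above; the real _forget signature and its broader exception handling are not fully shown here.

from memonto.utils.logger import logger

def _forget_sketch(vector_store, triple_store, id=None):
    # Illustrative wiring only; parameter names beyond those in the hunk are assumptions.
    try:
        if vector_store:
            vector_store.delete(id)
        if triple_store:
            # renamed from delete(id); the Jena backend drops both the data and ontology graphs
            triple_store.delete_all(id)
    except ValueError as e:
        logger.warning(e)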
156 changes: 3 additions & 153 deletions memonto/core/recall.py
@@ -8,146 +8,6 @@
from memonto.utils.rdf import serialize_graph_without_ids


def _hydrate_triples(
matched: list,
triple_store: VectorStoreModel,
id: str = None,
) -> Graph:
matched_ids = matched.keys()
triple_ids = " ".join(f'("{id}")' for id in matched_ids)

graph_id = f"data-{id}" if id else "data"

query = f"""
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
CONSTRUCT {{
?s ?p ?o .
}}
WHERE {{
GRAPH <{graph_id}> {{
VALUES (?uuid) {{ {triple_ids} }}
?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
?triple_node rdf:subject ?s ;
rdf:predicate ?p ;
rdf:object ?o .
}}
}}
"""

result = triple_store.query(query=query, format="turtle")

g = Graph()
g.parse(data=result, format="turtle")

return g


def _get_formatted_node(node: URIRef | Literal | BNode) -> str:
if isinstance(node, URIRef):
return f"<{str(node)}>"
elif isinstance(node, Literal):
return f'"{str(node)}"'
elif isinstance(node, BNode):
return f"_:{str(node)}"
else:
return f'"{str(node)}"'


def _find_adjacent_triples(
triples: Graph,
triple_store: VectorStoreModel,
id: str = None,
depth: int = 1,
) -> str:
nodes_set = set()

for s, p, o in triples:
nodes_set.add(_get_formatted_node(s))
nodes_set.add(_get_formatted_node(o))

explored_nodes = set(nodes_set)
new_nodes_set = nodes_set.copy()

query = None

for _ in range(depth):
if not new_nodes_set:
break

node_list = ", ".join(new_nodes_set)
graph_id = f"data-{id}" if id else "data"

query = f"""
CONSTRUCT {{
?s ?p ?o .
}}
WHERE {{
GRAPH <{graph_id}> {{
?s ?p ?o .
FILTER (?s IN ({node_list}) || ?o IN ({node_list}))
}}
}}
"""

logger.debug(f"Find adjacent triples SPARQL query\n{query}\n")

try:
result_triples = triple_store.query(query=query, format="turtle")
except Exception as e:
raise ValueError(f"SPARQL Query Error: {e}")

if result_triples is None:
raise ValueError("SPARQL query returned no results")

graph = Graph()
graph.parse(data=result_triples, format="turtle")

temp_new_nodes_set = set()
for s, p, o in graph:
formatted_subject = _get_formatted_node(s)
formatted_object = _get_formatted_node(o)

if formatted_subject not in explored_nodes:
temp_new_nodes_set.add(formatted_subject)
explored_nodes.add(formatted_subject)

if formatted_object not in explored_nodes:
temp_new_nodes_set.add(formatted_object)
explored_nodes.add(formatted_object)

new_nodes_set = temp_new_nodes_set

if query is None:
return ""

return triple_store.query(query=query, format="turtle")


def _find_all(triple_store: TripleStoreModel, id: str) -> str:
result = triple_store.query(
query=f"""
CONSTRUCT {{
?s ?p ?o .
}} WHERE {{
GRAPH <data-{id}> {{
?s ?p ?o .
FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
}}
}}
""",
format="turtle",
)

if isinstance(result, bytes):
result = result.decode("utf-8")

if not result:
return ""

return str(result)


def get_contextual_memory(
data: Graph,
vector_store: VectorStoreModel,
@@ -165,25 +25,15 @@ def get_contextual_memory(
matched = vector_store.search(message=context, id=id)
logger.debug(f"Matched Triples Raw\n{matched}\n")

matched_graph = _hydrate_triples(
memory = triple_store.get_context(
matched=matched,
triple_store=triple_store,
id=id,
)
matched_triples = matched_graph.serialize(format="turtle")
logger.debug(f"Matched Triples\n{matched_triples}\n")

memory = _find_adjacent_triples(
triples=matched_graph,
triple_store=triple_store,
id=id,
graph_id=id,
depth=1,
)
logger.debug(f"Adjacent Triples\n{memory}\n")
except ValueError as e:
logger.debug(f"Recall Exception\n{e}\n")
else:
memory = _find_all(triple_store=triple_store, id=id)
memory = triple_store.get_all(graph_id=id)

logger.debug(f"Contextual Memory\n{memory}\n")
return memory
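Not part of the diff: the net effect in this file is that graph traversal moves behind the triple-store interface. A hedged reconstruction of the resulting recall path; the guard condition and the signature are assumptions, and only the vector search and the two store calls appear verbatim above.

from memonto.utils.logger import logger

def get_contextual_memory_sketch(vector_store, triple_store, context=None, id=None):
    memory = ""
    try:
        if context:  # assumed guard; the hunk shows only the branch bodies
            matched = vector_store.search(message=context, id=id)
            # traversal, uuid filtering, and depth handling now live in the store
            memory = triple_store.get_context(matched=matched, graph_id=id, depth=1)
        else:
            memory = triple_store.get_all(graph_id=id)
    except ValueError as e:
        logger.debug(f"Recall Exception\n{e}\n")
    return memory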
4 changes: 2 additions & 2 deletions memonto/core/retain.py
@@ -180,14 +180,14 @@ def save_memory(
logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n")

# debug
# _render(g=data, format="image")
# _render(g=data, ns=namespaces, format="image")

if not ephemeral:
hydrate_graph_with_ids(data)
triple_store.save(ontology=ontology, data=data, id=id)

if vector_store:
vector_store.save(g=data, id=id)
vector_store.save(g=data, ns=namespaces, id=id)

data.remove((None, None, None))

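Not part of the diff: the only behavioural change here is that vector_store.save now receives the namespace bindings alongside the graph. A call-site sketch, assuming namespaces is a prefix-to-Namespace mapping and vector_store is an already-configured vector store; neither shape is shown in this commit.

from rdflib import Graph, Namespace

ex = Namespace("http://example.org/")    # hypothetical prefix binding
namespaces = {"ex": ex}                  # assumed shape of the ns mapping

data = Graph()
data.bind("ex", ex)
data.add((ex.alice, ex.likes, ex.jazz))  # toy triple for illustration

# vector_store: a configured VectorStoreModel instance (construction not shown)
# new keyword: ns gives the store the prefixes it needs when formatting node values
vector_store.save(g=data, ns=namespaces, id="session-1")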
11 changes: 11 additions & 0 deletions memonto/stores/triple/base_store.py
@@ -27,6 +27,17 @@ def get(self):
"""
pass

@abstractmethod
def get_all(self, graph_id: str = None) -> str:
"""
Get all memory data from the datastore.
:param graph_id: The id of the graph to get all memory data from.
:return: A string representation of the memory data.
"""
pass

@abstractmethod
def query(self):
"""
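Not part of the diff: an illustrative caller of the new abstract method. dump_memory is hypothetical; the contract it relies on (a serialized string back, empty when nothing is stored) follows the docstring above and the Jena implementation below.

from memonto.stores.triple.base_store import TripleStoreModel

def dump_memory(triple_store: TripleStoreModel, session_id: str) -> str:
    # Every concrete store must now provide get_all(graph_id);
    # the Jena backend below returns Turtle, or "" for an empty graph.
    memory = triple_store.get_all(graph_id=session_id)
    return memory or "(no stored memory for this session)"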
151 changes: 149 additions & 2 deletions memonto/stores/triple/jena.py
@@ -7,6 +7,7 @@
from memonto.stores.triple.base_store import TripleStoreModel
from memonto.utils.logger import logger
from memonto.utils.namespaces import TRIPLE_PROP
from memonto.utils.rdf import format_node


class ApacheJena(TripleStoreModel):
@@ -53,6 +54,43 @@ def _get_prefixes(self, g: Graph) -> list[str]:
gt = g.serialize(format="turtle")
return [line for line in gt.splitlines() if line.startswith("@prefix")]

def _hydrate_triples(
self,
matched: list,
graph_id: str = None,
) -> Graph:
matched_ids = matched.keys()
triple_ids = " ".join(f'("{id}")' for id in matched_ids)
g_id = f"data-{graph_id}" if graph_id else "data"

query = f"""
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
CONSTRUCT {{
?s ?p ?o .
}}
WHERE {{
GRAPH <{g_id}> {{
VALUES (?uuid) {{ {triple_ids} }}
?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
?triple_node rdf:subject ?s ;
rdf:predicate ?p ;
rdf:object ?o .
}}
}}
"""

result = self._query(
url=f"{self.connection_url}/sparql",
method=GET,
query=query,
)

g = Graph()
g.parse(data=result, format="turtle")

return g

def _load(
self,
g: Graph,
@@ -162,8 +200,117 @@ def get(

return result["results"]["bindings"]

def delete(self, id: str = None) -> None:
query = f"""DROP GRAPH <ontology-{id}> ; DROP GRAPH <data-{id}> ;"""
def get_all(self, graph_id: str = None) -> str:
g_id = f"data-{graph_id}" if graph_id else "data"

query = f"""
CONSTRUCT {{
?s ?p ?o .
}} WHERE {{
GRAPH <{g_id}> {{
?s ?p ?o .
FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
}}
}}
"""

result = self._query(
url=f"{self.connection_url}/sparql",
method=GET,
query=query,
)

if isinstance(result, bytes):
result = result.decode("utf-8")

if not result:
return ""

return str(result)

def get_context(
self, matched: dict[str, dict], graph_id: str, depth: int = 1
) -> str:
g_id = f"data-{graph_id}" if graph_id else "data"
nodes_set = set()

matched_graph = self._hydrate_triples(
matched=matched,
graph_id=graph_id,
)
logger.debug(f"Matched Triples\n{matched_graph.serialize(format='turtle')}\n")

for s, p, o in matched_graph:
nodes_set.add(format_node(s))
nodes_set.add(format_node(o))

explored_nodes = set(nodes_set)
new_nodes_set = nodes_set.copy()

query = None

for _ in range(depth):
if not new_nodes_set:
break

node_list = ", ".join(new_nodes_set)

query = f"""
CONSTRUCT {{
?s ?p ?o .
}}
WHERE {{
GRAPH <{g_id}> {{
?s ?p ?o .
FILTER (
(?s IN ({node_list}) || ?o IN ({node_list})) &&
NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
)
}}
}}
"""

logger.debug(f"Find adjacent triples SPARQL query\n{query}\n")

try:
result_triples = self.query(query=query, format="turtle")
except Exception as e:
raise ValueError(f"SPARQL Query Error: {e}")

if result_triples is None:
raise ValueError("SPARQL query returned no results")

graph = Graph()
graph.parse(data=result_triples, format="turtle")

temp_new_nodes_set = set()
for s, p, o in graph:
formatted_subject = format_node(s)
formatted_object = format_node(o)

if formatted_subject not in explored_nodes:
temp_new_nodes_set.add(formatted_subject)
explored_nodes.add(formatted_subject)

if formatted_object not in explored_nodes:
temp_new_nodes_set.add(formatted_object)
explored_nodes.add(formatted_object)

new_nodes_set = temp_new_nodes_set

if query is None:
return ""

result = self.query(query=query, format="turtle")
logger.debug(f"Adjacent Triples\n{result}\n")

return result

def delete_all(self, graph_id: str = None) -> None:
d_id = f"data-{graph_id}" if graph_id else "data"
o_id = f"ontology-{graph_id}" if graph_id else "ontology"

query = f"""DROP GRAPH <{o_id}> ; DROP GRAPH <{d_id}> ;"""

self._query(
url=f"{self.connection_url}/update",
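Not part of the diff: a usage sketch of the methods added or renamed in this file. The constructor argument and the contents of matched are assumptions; only the method names and keyword arguments come from the code above.

from memonto.stores.triple.jena import ApacheJena

store = ApacheJena(connection_url="http://localhost:3030/memonto")  # assumed constructor kwarg

# keys are the uuids returned by the vector search; the values are not read here
matched = {"0f3c2d6e-1111-2222-3333-444455556666": {"score": 0.91}}

context = store.get_context(matched=matched, graph_id="alice", depth=1)
print(context)                                 # Turtle for the matched triples plus their neighbours

everything = store.get_all(graph_id="alice")   # all triples minus the uuid reification nodes

store.delete_all(graph_id="alice")             # drops both data-alice and ontology-alice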