Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: improve kg storage and retrieval from vector #23

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion memonto/core/forget.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _forget(
vector_store.delete(id)

if triple_store:
triple_store.delete(id)
triple_store.delete_all(id)
except ValueError as e:
logger.warning(e)
except Exception as e:
Expand Down
156 changes: 3 additions & 153 deletions memonto/core/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,146 +8,6 @@
from memonto.utils.rdf import serialize_graph_without_ids


def _hydrate_triples(
    matched: dict,
    triple_store: TripleStoreModel,
    id: str | None = None,
) -> Graph:
    """
    Rebuild the triples referenced by a vector-store search result.

    The keys of ``matched`` are looked up as values of the
    ``TRIPLE_PROP.uuid`` property on reified statement nodes
    (``rdf:subject`` / ``rdf:predicate`` / ``rdf:object``) and the
    corresponding ``?s ?p ?o`` triples are reconstructed.

    :param matched: Mapping whose keys are the uuids of the matched triples.
    :param triple_store: Store used to run the SPARQL CONSTRUCT query.
    :param id: Optional memory id; selects the ``data-<id>`` graph, otherwise
        the shared ``data`` graph is queried.

    :return: A graph holding the reconstructed triples (empty if nothing
        matched).
    """
    g = Graph()

    # An empty VALUES block yields no solutions, so the result would be an
    # empty graph anyway; skip the round-trip to the store.
    if not matched:
        return g

    triple_ids = " ".join(f'("{triple_id}")' for triple_id in matched)
    graph_id = f"data-{id}" if id else "data"

    query = f"""
    PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>

    CONSTRUCT {{
        ?s ?p ?o .
    }}
    WHERE {{
        GRAPH <{graph_id}> {{
            VALUES (?uuid) {{ {triple_ids} }}
            ?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
            ?triple_node rdf:subject ?s ;
                         rdf:predicate ?p ;
                         rdf:object ?o .
        }}
    }}
    """

    result = triple_store.query(query=query, format="turtle")
    g.parse(data=result, format="turtle")

    return g


def _get_formatted_node(node: URIRef | Literal | BNode) -> str:
    """
    Render an RDF term as text that can be embedded in a SPARQL query.

    :param node: The term to render.

    :return: ``<iri>`` for a URIRef, ``_:label`` for a BNode and a quoted,
        escaped literal for everything else.
    """
    if isinstance(node, URIRef):
        return f"<{node}>"

    if isinstance(node, BNode):
        return f"_:{node}"

    if isinstance(node, Literal):
        # n3() escapes quotes/backslashes/newlines and keeps the datatype or
        # language tag, so the rendered term is valid SPARQL and can equal
        # the stored literal. Interpolating str(node) between bare quotes
        # broke the query for values containing '"' and never matched typed
        # literals.
        return node.n3()

    # Unknown term type: fall back to an escaped plain string literal.
    escaped = (
        str(node)
        .replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )
    return f'"{escaped}"'


def _find_adjacent_triples(
    triples: Graph,
    triple_store: TripleStoreModel,
    id: str | None = None,
    depth: int = 1,
) -> str:
    """
    Fetch the triples that touch the subjects/objects of ``triples``.

    Starting from every subject and object in ``triples``, a CONSTRUCT query
    is issued per hop for all triples whose subject or object is in the
    current frontier; nodes seen for the first time form the next frontier.

    :param triples: Seed triples whose subjects and objects start the search.
    :param triple_store: Store used to run the SPARQL queries.
    :param id: Optional memory id; selects the ``data-<id>`` graph, otherwise
        the shared ``data`` graph is queried.
    :param depth: Maximum number of hops to expand.

    :return: The store's turtle response for the last hop that was queried,
        or ``""`` when no query was run (no seed nodes or ``depth <= 0``).

    :raises ValueError: If a query fails or the store returns ``None``.
    """
    frontier = {
        _get_formatted_node(term)
        for s, _, o in triples
        for term in (s, o)
    }
    explored = set(frontier)

    # Loop-invariant, so computed once.
    graph_id = f"data-{id}" if id else "data"

    # Response of the most recent hop; stays None if no hop is executed.
    # NOTE(review): for depth > 1 only the last hop's triples are returned,
    # same as before -- confirm whether results should be accumulated.
    result_triples = None

    for _ in range(depth):
        if not frontier:
            break

        node_list = ", ".join(frontier)

        query = f"""
        CONSTRUCT {{
            ?s ?p ?o .
        }}
        WHERE {{
            GRAPH <{graph_id}> {{
                ?s ?p ?o .
                FILTER (?s IN ({node_list}) || ?o IN ({node_list}))
            }}
        }}
        """

        logger.debug(f"Find adjacent triples SPARQL query\n{query}\n")

        try:
            result_triples = triple_store.query(query=query, format="turtle")
        except Exception as e:
            # Keep the original cause attached for debugging.
            raise ValueError(f"SPARQL Query Error: {e}") from e

        if result_triples is None:
            raise ValueError("SPARQL query returned no results")

        graph = Graph()
        graph.parse(data=result_triples, format="turtle")

        next_frontier = set()
        for s, _, o in graph:
            for term in (s, o):
                formatted = _get_formatted_node(term)
                if formatted not in explored:
                    explored.add(formatted)
                    next_frontier.add(formatted)

        frontier = next_frontier

    if result_triples is None:
        return ""

    # Previously the last query was executed a second time here; its response
    # is already in hand, so return it directly.
    return result_triples


def _find_all(triple_store: TripleStoreModel, id: str | None) -> str:
    """
    Return every data triple stored for a memory id as turtle text.

    Triples whose subject carries the ``TRIPLE_PROP.uuid`` property are
    filtered out of the result.

    :param triple_store: Store used to run the SPARQL CONSTRUCT query.
    :param id: Memory id; selects the ``data-<id>`` graph. When falsy the
        shared ``data`` graph is used instead of a literal ``data-None``.

    :return: The turtle serialization, or ``""`` if the store returned
        nothing.
    """
    # Same graph-naming fallback the other recall helpers use.
    graph_id = f"data-{id}" if id else "data"

    result = triple_store.query(
        query=f"""
        CONSTRUCT {{
            ?s ?p ?o .
        }} WHERE {{
            GRAPH <{graph_id}> {{
                ?s ?p ?o .
                FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
            }}
        }}
        """,
        format="turtle",
    )

    if isinstance(result, bytes):
        result = result.decode("utf-8")

    if not result:
        return ""

    return str(result)


def get_contextual_memory(
data: Graph,
vector_store: VectorStoreModel,
Expand All @@ -165,25 +25,15 @@ def get_contextual_memory(
matched = vector_store.search(message=context, id=id)
logger.debug(f"Matched Triples Raw\n{matched}\n")

matched_graph = _hydrate_triples(
memory = triple_store.get_context(
matched=matched,
triple_store=triple_store,
id=id,
)
matched_triples = matched_graph.serialize(format="turtle")
logger.debug(f"Matched Triples\n{matched_triples}\n")

memory = _find_adjacent_triples(
triples=matched_graph,
triple_store=triple_store,
id=id,
graph_id=id,
depth=1,
)
logger.debug(f"Adjacent Triples\n{memory}\n")
except ValueError as e:
logger.debug(f"Recall Exception\n{e}\n")
else:
memory = _find_all(triple_store=triple_store, id=id)
memory = triple_store.get_all(graph_id=id)

logger.debug(f"Contextual Memory\n{memory}\n")
return memory
Expand Down
4 changes: 2 additions & 2 deletions memonto/core/retain.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,14 @@ def save_memory(
logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n")

# debug
# _render(g=data, format="image")
# _render(g=data, ns=namespaces, format="image")

if not ephemeral:
hydrate_graph_with_ids(data)
triple_store.save(ontology=ontology, data=data, id=id)

if vector_store:
vector_store.save(g=data, id=id)
vector_store.save(g=data, ns=namespaces, id=id)

data.remove((None, None, None))

Expand Down
11 changes: 11 additions & 0 deletions memonto/stores/triple/base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ def get(self):
"""
pass

@abstractmethod
def get_all(self, graph_id: str = None) -> str:
    """
    Fetch the complete memory data held in the datastore.

    :param graph_id: Id of the graph whose memory data should be returned.

    :return: The memory data rendered as a string.
    """

@abstractmethod
def query(self):
"""
Expand Down
151 changes: 149 additions & 2 deletions memonto/stores/triple/jena.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from memonto.stores.triple.base_store import TripleStoreModel
from memonto.utils.logger import logger
from memonto.utils.namespaces import TRIPLE_PROP
from memonto.utils.rdf import format_node


class ApacheJena(TripleStoreModel):
Expand Down Expand Up @@ -53,6 +54,43 @@ def _get_prefixes(self, g: Graph) -> list[str]:
gt = g.serialize(format="turtle")
return [line for line in gt.splitlines() if line.startswith("@prefix")]

def _hydrate_triples(
    self,
    matched: dict[str, dict],
    graph_id: str | None = None,
) -> Graph:
    """
    Rebuild the triples referenced by a vector-store search result.

    The keys of ``matched`` are looked up as values of the
    ``TRIPLE_PROP.uuid`` property on reified statement nodes
    (``rdf:subject`` / ``rdf:predicate`` / ``rdf:object``) and the
    corresponding ``?s ?p ?o`` triples are reconstructed.

    :param matched: Mapping whose keys are the uuids of the matched triples.
    :param graph_id: Optional id; selects the ``data-<graph_id>`` graph,
        otherwise the shared ``data`` graph is queried.

    :return: A graph holding the reconstructed triples (empty if nothing
        matched).
    """
    g = Graph()

    # An empty VALUES block yields no solutions, so the result would be an
    # empty graph anyway; skip the HTTP round-trip.
    if not matched:
        return g

    triple_ids = " ".join(f'("{triple_id}")' for triple_id in matched)
    g_id = f"data-{graph_id}" if graph_id else "data"

    query = f"""
    PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>

    CONSTRUCT {{
        ?s ?p ?o .
    }}
    WHERE {{
        GRAPH <{g_id}> {{
            VALUES (?uuid) {{ {triple_ids} }}
            ?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
            ?triple_node rdf:subject ?s ;
                         rdf:predicate ?p ;
                         rdf:object ?o .
        }}
    }}
    """

    result = self._query(
        url=f"{self.connection_url}/sparql",
        method=GET,
        query=query,
    )

    g.parse(data=result, format="turtle")

    return g

def _load(
self,
g: Graph,
Expand Down Expand Up @@ -162,8 +200,117 @@ def get(

return result["results"]["bindings"]

def delete(self, id: str = None) -> None:
query = f"""DROP GRAPH <ontology-{id}> ; DROP GRAPH <data-{id}> ;"""
def get_all(self, graph_id: str = None) -> str:
    """
    Return every data triple of a graph as text.

    Triples whose subject carries the ``TRIPLE_PROP.uuid`` property are
    filtered out of the result.

    :param graph_id: Optional id; selects the ``data-<graph_id>`` graph,
        otherwise the shared ``data`` graph is queried.

    :return: The store's response as a string, or ``""`` when it is empty.
    """
    target_graph = "data" if not graph_id else f"data-{graph_id}"

    # Query text is reproduced exactly as it appears in the source.
    sparql = f"""
CONSTRUCT {{
?s ?p ?o .
}} WHERE {{
GRAPH <{target_graph}> {{
?s ?p ?o .
FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
}}
}}
"""

    raw = self._query(
        url=f"{self.connection_url}/sparql",
        method=GET,
        query=sparql,
    )

    # Normalise a bytes payload to text before the emptiness check.
    text = raw.decode("utf-8") if isinstance(raw, bytes) else raw

    return str(text) if text else ""

def get_context(
    self, matched: dict[str, dict], graph_id: str, depth: int = 1
) -> str:
    """
    Return the triples surrounding a set of vector-store matches.

    The matched uuids are first turned back into triples; starting from
    their subjects and objects, a CONSTRUCT query is issued per hop for all
    triples whose subject or object is in the current frontier (excluding
    subjects that carry the ``TRIPLE_PROP.uuid`` property). Nodes seen for
    the first time form the next frontier.

    :param matched: Mapping whose keys are the uuids of the matched triples.
    :param graph_id: Id selecting the ``data-<graph_id>`` graph; when falsy
        the shared ``data`` graph is queried.
    :param depth: Maximum number of hops to expand.

    :return: The store's turtle response for the last hop that was queried,
        or ``""`` when no query was run (no seed nodes or ``depth <= 0``).

    :raises ValueError: If a query fails or the store returns ``None``.
    """
    g_id = f"data-{graph_id}" if graph_id else "data"

    matched_graph = self._hydrate_triples(
        matched=matched,
        graph_id=graph_id,
    )
    logger.debug(f"Matched Triples\n{matched_graph.serialize(format='turtle')}\n")

    frontier = {
        format_node(term)
        for s, _, o in matched_graph
        for term in (s, o)
    }
    explored = set(frontier)

    # Response of the most recent hop; stays None if no hop is executed.
    # NOTE(review): for depth > 1 only the last hop's triples are returned,
    # same as before -- confirm whether results should be accumulated.
    result = None

    for _ in range(depth):
        if not frontier:
            break

        node_list = ", ".join(frontier)

        query = f"""
        CONSTRUCT {{
            ?s ?p ?o .
        }}
        WHERE {{
            GRAPH <{g_id}> {{
                ?s ?p ?o .
                FILTER (
                    (?s IN ({node_list}) || ?o IN ({node_list})) &&
                    NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
                )
            }}
        }}
        """

        logger.debug(f"Find adjacent triples SPARQL query\n{query}\n")

        try:
            result = self.query(query=query, format="turtle")
        except Exception as e:
            # Keep the original cause attached for debugging.
            raise ValueError(f"SPARQL Query Error: {e}") from e

        if result is None:
            raise ValueError("SPARQL query returned no results")

        graph = Graph()
        graph.parse(data=result, format="turtle")

        next_frontier = set()
        for s, _, o in graph:
            for term in (s, o):
                formatted = format_node(term)
                if formatted not in explored:
                    explored.add(formatted)
                    next_frontier.add(formatted)

        frontier = next_frontier

    if result is None:
        return ""

    # Previously the last query was executed a second time here; its response
    # is already in hand, so log and return it directly.
    logger.debug(f"Adjacent Triples\n{result}\n")

    return result

def delete_all(self, graph_id: str = None) -> None:
d_id = f"data-{graph_id}" if graph_id else "data"
o_id = f"ontology-{graph_id}" if graph_id else "ontology"

query = f"""DROP GRAPH <{o_id}> ; DROP GRAPH <{d_id}> ;"""

self._query(
url=f"{self.connection_url}/update",
Expand Down
Loading
Loading