From 4e7936f54a963de368cd739ee540f8b3730be5c9 Mon Sep 17 00:00:00 2001 From: ShiHan Wan Date: Tue, 15 Oct 2024 21:41:28 -0400 Subject: [PATCH] perf: improve kg storage and retrieval from vector (#23) * improve kg node value parsing * remove bnodes from recall * fix tests * formatting * remove debug code --- memonto/core/forget.py | 2 +- memonto/core/recall.py | 156 +--------------------------- memonto/core/retain.py | 4 +- memonto/stores/triple/base_store.py | 11 ++ memonto/stores/triple/jena.py | 151 ++++++++++++++++++++++++++- memonto/stores/vector/chroma.py | 16 ++- memonto/utils/rdf.py | 47 +++++++-- tests/core/test_recall.py | 30 +++--- 8 files changed, 226 insertions(+), 191 deletions(-) diff --git a/memonto/core/forget.py b/memonto/core/forget.py index e8144b1..f87dd4f 100644 --- a/memonto/core/forget.py +++ b/memonto/core/forget.py @@ -20,7 +20,7 @@ def _forget( vector_store.delete(id) if triple_store: - triple_store.delete(id) + triple_store.delete_all(id) except ValueError as e: logger.warning(e) except Exception as e: diff --git a/memonto/core/recall.py b/memonto/core/recall.py index 467f333..8834eb3 100644 --- a/memonto/core/recall.py +++ b/memonto/core/recall.py @@ -8,146 +8,6 @@ from memonto.utils.rdf import serialize_graph_without_ids -def _hydrate_triples( - matched: list, - triple_store: VectorStoreModel, - id: str = None, -) -> Graph: - matched_ids = matched.keys() - triple_ids = " ".join(f'("{id}")' for id in matched_ids) - - graph_id = f"data-{id}" if id else "data" - - query = f""" - PREFIX rdf: - - CONSTRUCT {{ - ?s ?p ?o . - }} - WHERE {{ - GRAPH <{graph_id}> {{ - VALUES (?uuid) {{ {triple_ids} }} - ?triple_node <{TRIPLE_PROP.uuid}> ?uuid . - ?triple_node rdf:subject ?s ; - rdf:predicate ?p ; - rdf:object ?o . - }} - }} - """ - - result = triple_store.query(query=query, format="turtle") - - g = Graph() - g.parse(data=result, format="turtle") - - return g - - -def _get_formatted_node(node: URIRef | Literal | BNode) -> str: - if isinstance(node, URIRef): - return f"<{str(node)}>" - elif isinstance(node, Literal): - return f'"{str(node)}"' - elif isinstance(node, BNode): - return f"_:{str(node)}" - else: - return f'"{str(node)}"' - - -def _find_adjacent_triples( - triples: Graph, - triple_store: VectorStoreModel, - id: str = None, - depth: int = 1, -) -> str: - nodes_set = set() - - for s, p, o in triples: - nodes_set.add(_get_formatted_node(s)) - nodes_set.add(_get_formatted_node(o)) - - explored_nodes = set(nodes_set) - new_nodes_set = nodes_set.copy() - - query = None - - for _ in range(depth): - if not new_nodes_set: - break - - node_list = ", ".join(new_nodes_set) - graph_id = f"data-{id}" if id else "data" - - query = f""" - CONSTRUCT {{ - ?s ?p ?o . - }} - WHERE {{ - GRAPH <{graph_id}> {{ - ?s ?p ?o . - FILTER (?s IN ({node_list}) || ?o IN ({node_list})) - }} - }} - """ - - logger.debug(f"Find adjacent triples SPARQL query\n{query}\n") - - try: - result_triples = triple_store.query(query=query, format="turtle") - except Exception as e: - raise ValueError(f"SPARQL Query Error: {e}") - - if result_triples is None: - raise ValueError("SPARQL query returned no results") - - graph = Graph() - graph.parse(data=result_triples, format="turtle") - - temp_new_nodes_set = set() - for s, p, o in graph: - formatted_subject = _get_formatted_node(s) - formatted_object = _get_formatted_node(o) - - if formatted_subject not in explored_nodes: - temp_new_nodes_set.add(formatted_subject) - explored_nodes.add(formatted_subject) - - if formatted_object not in explored_nodes: - temp_new_nodes_set.add(formatted_object) - explored_nodes.add(formatted_object) - - new_nodes_set = temp_new_nodes_set - - if query is None: - return "" - - return triple_store.query(query=query, format="turtle") - - -def _find_all(triple_store: TripleStoreModel, id: str) -> str: - result = triple_store.query( - query=f""" - CONSTRUCT {{ - ?s ?p ?o . - }} WHERE {{ - GRAPH {{ - ?s ?p ?o . - FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }} - }} - }} - """, - format="turtle", - ) - - if isinstance(result, bytes): - result = result.decode("utf-8") - - if not result: - return "" - - return str(result) - - def get_contextual_memory( data: Graph, vector_store: VectorStoreModel, @@ -165,25 +25,15 @@ def get_contextual_memory( matched = vector_store.search(message=context, id=id) logger.debug(f"Matched Triples Raw\n{matched}\n") - matched_graph = _hydrate_triples( + memory = triple_store.get_context( matched=matched, - triple_store=triple_store, - id=id, - ) - matched_triples = matched_graph.serialize(format="turtle") - logger.debug(f"Matched Triples\n{matched_triples}\n") - - memory = _find_adjacent_triples( - triples=matched_graph, - triple_store=triple_store, - id=id, + graph_id=id, depth=1, ) - logger.debug(f"Adjacent Triples\n{memory}\n") except ValueError as e: logger.debug(f"Recall Exception\n{e}\n") else: - memory = _find_all(triple_store=triple_store, id=id) + memory = triple_store.get_all(graph_id=id) logger.debug(f"Contextual Memory\n{memory}\n") return memory diff --git a/memonto/core/retain.py b/memonto/core/retain.py index 10c7da3..e95e1ef 100644 --- a/memonto/core/retain.py +++ b/memonto/core/retain.py @@ -180,14 +180,14 @@ def save_memory( logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n") # debug - # _render(g=data, format="image") + # _render(g=data, ns=namespaces, format="image") if not ephemeral: hydrate_graph_with_ids(data) triple_store.save(ontology=ontology, data=data, id=id) if vector_store: - vector_store.save(g=data, id=id) + vector_store.save(g=data, ns=namespaces, id=id) data.remove((None, None, None)) diff --git a/memonto/stores/triple/base_store.py b/memonto/stores/triple/base_store.py index 4decde8..53a4c6b 100644 --- a/memonto/stores/triple/base_store.py +++ b/memonto/stores/triple/base_store.py @@ -27,6 +27,17 @@ def get(self): """ pass + @abstractmethod + def get_all(self, graph_id: str = None) -> str: + """ + Get all memory data from the datastore. + + :param graph_id: The id of the graph to get all memory data from. + + :return: A string representation of the memory data. + """ + pass + @abstractmethod def query(self): """ diff --git a/memonto/stores/triple/jena.py b/memonto/stores/triple/jena.py index 827fa0e..0d6cb09 100644 --- a/memonto/stores/triple/jena.py +++ b/memonto/stores/triple/jena.py @@ -7,6 +7,7 @@ from memonto.stores.triple.base_store import TripleStoreModel from memonto.utils.logger import logger from memonto.utils.namespaces import TRIPLE_PROP +from memonto.utils.rdf import format_node class ApacheJena(TripleStoreModel): @@ -53,6 +54,43 @@ def _get_prefixes(self, g: Graph) -> list[str]: gt = g.serialize(format="turtle") return [line for line in gt.splitlines() if line.startswith("@prefix")] + def _hydrate_triples( + self, + matched: list, + graph_id: str = None, + ) -> Graph: + matched_ids = matched.keys() + triple_ids = " ".join(f'("{id}")' for id in matched_ids) + g_id = f"data-{graph_id}" if graph_id else "data" + + query = f""" + PREFIX rdf: + + CONSTRUCT {{ + ?s ?p ?o . + }} + WHERE {{ + GRAPH <{g_id}> {{ + VALUES (?uuid) {{ {triple_ids} }} + ?triple_node <{TRIPLE_PROP.uuid}> ?uuid . + ?triple_node rdf:subject ?s ; + rdf:predicate ?p ; + rdf:object ?o . + }} + }} + """ + + result = self._query( + url=f"{self.connection_url}/sparql", + method=GET, + query=query, + ) + + g = Graph() + g.parse(data=result, format="turtle") + + return g + def _load( self, g: Graph, @@ -162,8 +200,117 @@ def get( return result["results"]["bindings"] - def delete(self, id: str = None) -> None: - query = f"""DROP GRAPH ; DROP GRAPH ;""" + def get_all(self, graph_id: str = None) -> str: + g_id = f"data-{graph_id}" if graph_id else "data" + + query = f""" + CONSTRUCT {{ + ?s ?p ?o . + }} WHERE {{ + GRAPH <{g_id}> {{ + ?s ?p ?o . + FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }} + }} + }} + """ + + result = self._query( + url=f"{self.connection_url}/sparql", + method=GET, + query=query, + ) + + if isinstance(result, bytes): + result = result.decode("utf-8") + + if not result: + return "" + + return str(result) + + def get_context( + self, matched: dict[str, dict], graph_id: str, depth: int = 1 + ) -> str: + g_id = f"data-{graph_id}" if graph_id else "data" + nodes_set = set() + + matched_graph = self._hydrate_triples( + matched=matched, + graph_id=graph_id, + ) + logger.debug(f"Matched Triples\n{matched_graph.serialize(format='turtle')}\n") + + for s, p, o in matched_graph: + nodes_set.add(format_node(s)) + nodes_set.add(format_node(o)) + + explored_nodes = set(nodes_set) + new_nodes_set = nodes_set.copy() + + query = None + + for _ in range(depth): + if not new_nodes_set: + break + + node_list = ", ".join(new_nodes_set) + + query = f""" + CONSTRUCT {{ + ?s ?p ?o . + }} + WHERE {{ + GRAPH <{g_id}> {{ + ?s ?p ?o . + FILTER ( + (?s IN ({node_list}) || ?o IN ({node_list})) && + NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }} + ) + }} + }} + """ + + logger.debug(f"Find adjacent triples SPARQL query\n{query}\n") + + try: + result_triples = self.query(query=query, format="turtle") + except Exception as e: + raise ValueError(f"SPARQL Query Error: {e}") + + if result_triples is None: + raise ValueError("SPARQL query returned no results") + + graph = Graph() + graph.parse(data=result_triples, format="turtle") + + temp_new_nodes_set = set() + for s, p, o in graph: + formatted_subject = format_node(s) + formatted_object = format_node(o) + + if formatted_subject not in explored_nodes: + temp_new_nodes_set.add(formatted_subject) + explored_nodes.add(formatted_subject) + + if formatted_object not in explored_nodes: + temp_new_nodes_set.add(formatted_object) + explored_nodes.add(formatted_object) + + new_nodes_set = temp_new_nodes_set + + if query is None: + return "" + + result = self.query(query=query, format="turtle") + logger.debug(f"Adjacent Triples\n{result}\n") + + return result + + def delete_all(self, graph_id: str = None) -> None: + d_id = f"data-{graph_id}" if graph_id else "data" + o_id = f"ontology-{graph_id}" if graph_id else "ontology" + + query = f"""DROP GRAPH <{o_id}> ; DROP GRAPH <{d_id}> ;""" self._query( url=f"{self.connection_url}/update", diff --git a/memonto/stores/vector/chroma.py b/memonto/stores/vector/chroma.py index d7bd76c..6f82e3e 100644 --- a/memonto/stores/vector/chroma.py +++ b/memonto/stores/vector/chroma.py @@ -2,12 +2,12 @@ import json from chromadb.config import Settings from pydantic import model_validator -from rdflib import Graph, RDF, BNode +from rdflib import Graph, RDF, Namespace from typing import Literal from memonto.stores.vector.base_store import VectorStoreModel from memonto.utils.logger import logger -from memonto.utils.rdf import is_rdf_schema, remove_namespace +from memonto.utils.rdf import is_bnode_uuid, is_rdf_schema, to_human_readable from memonto.utils.namespaces import TRIPLE_PROP @@ -50,7 +50,7 @@ def init(self) -> "Chroma": return self - def save(self, g: Graph, id: str = None) -> None: + def save(self, g: Graph, ns: dict[str, Namespace], id: str = None) -> None: collection = self.client.get_or_create_collection(id or "default") documents = [] @@ -58,14 +58,12 @@ def save(self, g: Graph, id: str = None) -> None: ids = [] for s, p, o in g: - if is_rdf_schema(p): - continue - if isinstance(s, BNode) and (s, TRIPLE_PROP.uuid, None) in g: + if is_rdf_schema(p) or is_bnode_uuid(s, g): continue - _s = remove_namespace(str(s)) - _p = remove_namespace(str(p)) - _o = remove_namespace(str(o)) + _s = to_human_readable(str(s), ns) + _p = to_human_readable(str(p), ns) + _o = to_human_readable(str(o), ns) id = "" for bnode in g.subjects(RDF.subject, s): diff --git a/memonto/utils/rdf.py b/memonto/utils/rdf.py index c193bd1..ef5ea2e 100644 --- a/memonto/utils/rdf.py +++ b/memonto/utils/rdf.py @@ -1,25 +1,45 @@ import datetime import graphviz import os +import re import uuid from collections import defaultdict -from rdflib import Graph, Literal, BNode +from rdflib import Graph, Literal, BNode, Namespace, URIRef from rdflib.namespace import RDF, RDFS, OWL from typing import Union from memonto.utils.namespaces import TRIPLE_PROP -def is_rdf_schema(p) -> Graph: +def is_rdf_schema(p: str) -> Graph: return p.startswith(RDFS) or p.startswith(OWL) or p.startswith(RDF) -def sanitize_label(label: str) -> str: - return label.replace("-", "_").replace(":", "_").replace(" ", "_") +def is_bnode_uuid(s: str, g: Graph) -> bool: + return isinstance(s, BNode) and (s, TRIPLE_PROP.uuid, None) in g -def remove_namespace(c: str) -> str: - return c.split("/")[-1].split("#")[-1].split(":")[-1] +def to_human_readable(c: str, ns: dict[str, Namespace]) -> str: + for n in ns.values(): + if c.startswith(n): + c.replace(n, "") + + c = c.split("/")[-1].split("#")[-1].split(":")[-1] + c = c.replace("_", " ") + c = re.sub(r"(?<=[a-z])(?=[A-Z])", " ", c).lower() + + return c + + +def format_node(node: URIRef | Literal | BNode) -> str: + if isinstance(node, URIRef): + return f"<{str(node)}>" + elif isinstance(node, Literal): + return f'"{str(node)}"' + elif isinstance(node, BNode): + return f"_:{str(node)}" + else: + return f'"{str(node)}"' def serialize_graph_without_ids(g: Graph, format: str = "turtle") -> Graph: @@ -72,7 +92,7 @@ def is_updated(nt: dict, o: list[dict]) -> bool: return [nt for nt in n if is_updated(nt, o)] -def generate_image(g: Graph, path: str = None) -> None: +def generate_image(g: Graph, ns: dict[str, Namespace], path: str = None) -> None: if not path: current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") current_directory = os.getcwd() @@ -89,9 +109,13 @@ def generate_image(g: Graph, path: str = None) -> None: if isinstance(o, BNode) and (o, TRIPLE_PROP.uuid, None) in g: continue - s_label = bnode_labels[s] if isinstance(s, BNode) else sanitize_label(str(s)) - o_label = bnode_labels[o] if isinstance(o, BNode) else sanitize_label(str(o)) - p_label = sanitize_label(str(p)) + s_label = ( + bnode_labels[s] if isinstance(s, BNode) else to_human_readable(str(s), ns) + ) + o_label = ( + bnode_labels[o] if isinstance(o, BNode) else to_human_readable(str(o), ns) + ) + p_label = to_human_readable(str(p), ns) dot.node(s_label, s_label) dot.node(o_label, o_label) @@ -116,6 +140,7 @@ def generate_text(g: Graph) -> str: def _render( g: Graph, + ns: dict[str, Namespace] = None, format: str = "turtle", path: str = None, ) -> Union[str, dict]: @@ -144,6 +169,6 @@ def _render( elif format == "text": return generate_text(g) elif format == "image": - return generate_image(g=g, path=path) + return generate_image(g=g, ns=ns, path=path) else: raise ValueError(f"Unsupported type '{type}'.") diff --git a/tests/core/test_recall.py b/tests/core/test_recall.py index 2538cf7..983f7d7 100644 --- a/tests/core/test_recall.py +++ b/tests/core/test_recall.py @@ -1,8 +1,15 @@ import pytest from rdflib import Graph, Literal, URIRef -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import MagicMock, patch from memonto.core.recall import _recall +from memonto.memonto import Memonto +from memonto.stores.triple.jena import ApacheJena + + +@pytest.fixture +def jena(): + return ApacheJena(connection_url="http://localhost:8080/test-dataset") @pytest.fixture @@ -62,16 +69,16 @@ def data_graph(): return g -@patch("memonto.core.recall._find_all") -def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id, data_graph): +@patch("memonto.stores.triple.jena.ApacheJena.get_all") +def test_fetch_all_memory(mock_get_all, jena, mock_llm, mock_store, id, data_graph): all_memory = "all memory" - mock_find_all.return_value = all_memory + mock_get_all.return_value = all_memory _recall( data=data_graph, llm=mock_llm, vector_store=mock_store, - triple_store=mock_store, + triple_store=jena, context=None, id=id, ephemeral=False, @@ -84,11 +91,10 @@ def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id, data_graph): ) -@patch("memonto.core.recall._find_adjacent_triples") -@patch("memonto.core.recall._hydrate_triples") +@patch("memonto.stores.triple.jena.ApacheJena.get_context") def test_fetch_some_memory( - mock_hydrate_triples, - mock_find_adjacent_triples, + mock_get_context, + jena, mock_llm, mock_store, user_query, @@ -96,14 +102,13 @@ def test_fetch_some_memory( data_graph, ): some_memory = "some memory" - mock_find_adjacent_triples.return_value = some_memory - mock_hydrate_triples.return_value = Graph() + mock_get_context.return_value = some_memory _recall( data=data_graph, llm=mock_llm, vector_store=mock_store, - triple_store=mock_store, + triple_store=jena, context=user_query, id=id, ephemeral=False, @@ -117,7 +122,6 @@ def test_fetch_some_memory( def test_fetch_some_memory_ephemeral(mock_llm, data_graph): - mem = _recall( data=data_graph, llm=mock_llm,