From 206d58b498f882a3eae0e79716b65669e90c194f Mon Sep 17 00:00:00 2001 From: ShiHan Wan Date: Sun, 13 Oct 2024 22:23:06 -0400 Subject: [PATCH] chore: add ids to triples for easier manipulation (#21) * add internal namespaces * add id to graph * add unit tests * format * save triple ids to vector * format * remove print --- memonto/core/recall.py | 29 ++++++++++++----- memonto/core/retain.py | 7 ++-- memonto/stores/vector/chroma.py | 14 ++++++-- memonto/utils/namespaces.py | 3 ++ memonto/utils/rdf.py | 49 ++++++++++++++++++++++++---- tests/utils/test_rdf.py | 58 +++++++++++++++++++++++++++++++++ 6 files changed, 140 insertions(+), 20 deletions(-) create mode 100644 memonto/utils/namespaces.py create mode 100644 tests/utils/test_rdf.py diff --git a/memonto/core/recall.py b/memonto/core/recall.py index 488dcb1..a6ff735 100644 --- a/memonto/core/recall.py +++ b/memonto/core/recall.py @@ -4,6 +4,8 @@ from memonto.stores.triple.base_store import TripleStoreModel from memonto.stores.vector.base_store import VectorStoreModel from memonto.utils.logger import logger +from memonto.utils.namespaces import TRIPLE_PROP +from memonto.utils.rdf import serialize_graph_without_ids def _hydrate_triples( @@ -11,21 +13,23 @@ def _hydrate_triples( triple_store: VectorStoreModel, id: str = None, ) -> Graph: - triple_values = " ".join( - f"(<{triple['s']}> <{triple['p']}> \"{triple['o']}\")" for triple in triples - ) + triple_ids = " ".join(f'("{triple_id}")' for triple_id in triples) graph_id = f"data-{id}" if id else "data" query = f""" + PREFIX rdf: + CONSTRUCT {{ ?s ?p ?o . }} WHERE {{ GRAPH <{graph_id}> {{ - VALUES (?s ?p ?o_string) {{ {triple_values} }} - ?s ?p ?o . - FILTER (STR(?s) = STR(?s) && STR(?p) = STR(?p) && STR(?o) = ?o_string) + VALUES (?uuid) {{ {triple_ids} }} + ?triple_node <{TRIPLE_PROP.uuid}> ?uuid . + ?triple_node rdf:subject ?s ; + rdf:predicate ?p ; + rdf:object ?o . 
}} }} """ @@ -121,7 +125,16 @@ def _find_adjacent_triples( def _find_all(triple_store: TripleStoreModel, id: str) -> str: result = triple_store.query( - query=f"CONSTRUCT {{?s ?p ?o .}} WHERE {{ GRAPH {{ ?s ?p ?o . }} }}", + query=f""" + CONSTRUCT {{ + ?s ?p ?o . + }} WHERE {{ + GRAPH {{ + ?s ?p ?o . + FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }} + }} + }} + """, format="turtle", ) @@ -145,7 +158,7 @@ def get_contextual_memory( memory = "" if ephemeral: - memory = data.serialize(format="turtle") + memory = serialize_graph_without_ids(data) elif context: try: matched_triples = vector_store.search(message=context, id=id) diff --git a/memonto/core/retain.py b/memonto/core/retain.py index 2d5e8a6..96a25b3 100644 --- a/memonto/core/retain.py +++ b/memonto/core/retain.py @@ -4,7 +4,7 @@ from memonto.stores.triple.base_store import TripleStoreModel from memonto.stores.vector.base_store import VectorStoreModel from memonto.utils.logger import logger -from memonto.utils.rdf import _render +from memonto.utils.rdf import _render, hydrate_graph_with_ids def run_script( @@ -108,9 +108,12 @@ def _retain( logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n") if not ephemeral: + hydrate_graph_with_ids(data) + triple_store.save(ontology=ontology, data=data, id=id) + if vector_store: vector_store.save(g=data, id=id) - # _render(g=data, format="image") + # print(_render(g=data, format="image")) data.remove((None, None, None)) diff --git a/memonto/stores/vector/chroma.py b/memonto/stores/vector/chroma.py index c57f7d7..2331827 100644 --- a/memonto/stores/vector/chroma.py +++ b/memonto/stores/vector/chroma.py @@ -2,12 +2,13 @@ import json from chromadb.config import Settings from pydantic import model_validator -from rdflib import Graph +from rdflib import Graph, RDF, BNode from typing import Literal from memonto.stores.vector.base_store import VectorStoreModel from memonto.utils.logger import logger from memonto.utils.rdf import is_rdf_schema, remove_namespace 
def _triple_id_nodes(g: Graph) -> set:
    """Return the blank nodes of ``g`` that have a ``TRIPLE_PROP.uuid`` value."""
    return {
        node
        for node in g.subjects(TRIPLE_PROP.uuid, None)
        if isinstance(node, BNode)
    }


def serialize_graph_without_ids(g: Graph, format: str = "turtle") -> str:
    """Serialize ``g`` with the triple-id bookkeeping statements left out.

    Every statement whose subject is a blank node carrying a
    ``TRIPLE_PROP.uuid`` value is dropped. All remaining statements are
    copied into a fresh graph, which is then serialized; ``g`` itself is not
    modified.

    Args:
        g: The graph to serialize.
        format: Serialization format name passed through to
            ``Graph.serialize`` (the name mirrors that keyword argument).

    Returns:
        Whatever ``Graph.serialize`` returns for the filtered copy when no
        destination is given (text in current rdflib releases).
    """
    # Collected once up front instead of one graph lookup per statement.
    id_nodes = _triple_id_nodes(g)

    visible = Graph()
    for s, p, o in g:
        if s in id_nodes:
            continue
        visible.add((s, p, o))

    return visible.serialize(format=format)


def hydrate_graph_with_ids(g: Graph) -> Graph:
    """Give every triple in ``g`` a UUID by attaching a reification node.

    For each triple a new blank node is added that references the triple's
    parts through ``rdf:subject``, ``rdf:predicate`` and ``rdf:object`` and
    stores a random UUID string literal under ``TRIPLE_PROP.uuid``.

    Statements that belong to an existing id node, and triples that an
    existing id node already describes, are left alone, so calling this
    function again on the same graph does not pile up duplicate ids.

    Args:
        g: The graph to annotate. It is modified in place.

    Returns:
        The same graph object that was passed in.
    """
    id_nodes = _triple_id_nodes(g)
    already_identified = {
        (
            g.value(node, RDF.subject),
            g.value(node, RDF.predicate),
            g.value(node, RDF.object),
        )
        for node in id_nodes
    }

    # Iterate over a snapshot: statements are added to ``g`` inside the loop,
    # and walking a graph while it grows depends on store internals.
    for s, p, o in list(g):
        if s in id_nodes or (s, p, o) in already_identified:
            continue

        triple_node = BNode()
        triple_id = str(uuid.uuid4())

        g.add((triple_node, RDF.subject, s))
        g.add((triple_node, RDF.predicate, p))
        g.add((triple_node, RDF.object, o))
        g.add((triple_node, TRIPLE_PROP.uuid, Literal(triple_id)))

    return g
@pytest.fixture
def graph():
    """Graph holding two triples made only of IRIs."""
    g = Graph()

    for suffix in ("1", "2"):
        g.add(
            (
                URIRef(f"http://example.org/s{suffix}"),
                URIRef(f"http://example.org/p{suffix}"),
                URIRef(f"http://example.org/o{suffix}"),
            )
        )

    return g


@pytest.fixture
def bnode_graph():
    """Graph with one ordinary triple plus one blank node carrying a uuid."""
    g = Graph()

    g.add(
        (
            URIRef("http://example.org/s"),
            URIRef("http://example.org/p"),
            URIRef("http://example.org/o"),
        )
    )
    g.add((BNode(), TRIPLE_PROP.uuid, Literal("12345")))

    return g


def test_serialize_graph(bnode_graph):
    """The uuid statement is dropped while the ordinary triple is kept."""
    default_output = serialize_graph_without_ids(bnode_graph)
    assert "12345" not in default_output

    # N-Triples writes full IRIs, so the surviving triple can be checked
    # precisely instead of looking for a single letter.
    nt_output = serialize_graph_without_ids(bnode_graph, format="nt")
    assert "12345" not in nt_output
    assert "<http://example.org/s>" in nt_output
    assert "<http://example.org/p>" in nt_output
    assert "<http://example.org/o>" in nt_output


def test_hydrate_graph_with_ids(graph):
    """Each original triple gets exactly one distinct uuid node describing it."""
    from rdflib.namespace import RDF

    original_triples = set(graph)

    g = hydrate_graph_with_ids(graph)

    uuid_triples = list(g.triples((None, TRIPLE_PROP.uuid, None)))
    assert len(uuid_triples) == 2
    assert all(isinstance(value, Literal) for _, _, value in uuid_triples)
    assert len({str(value) for _, _, value in uuid_triples}) == 2

    described = {
        (
            g.value(node, RDF.subject),
            g.value(node, RDF.predicate),
            g.value(node, RDF.object),
        )
        for node, _, _ in uuid_triples
    }
    assert described == original_triples