Skip to content

Commit

Permalink
chore: add ids to triples for easier manipulation (#21)
Browse files Browse the repository at this point in the history
* add internal namespaces

* add id to graph

* add unit tests

* format

* save triple ids to vector

* format

* remove print
  • Loading branch information
shihanwan authored Oct 14, 2024
1 parent a0531bb commit 206d58b
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 20 deletions.
29 changes: 21 additions & 8 deletions memonto/core/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,32 @@
from memonto.stores.triple.base_store import TripleStoreModel
from memonto.stores.vector.base_store import VectorStoreModel
from memonto.utils.logger import logger
from memonto.utils.namespaces import TRIPLE_PROP
from memonto.utils.rdf import serialize_graph_without_ids


def _hydrate_triples(
triples: list,
triple_store: VectorStoreModel,
id: str = None,
) -> Graph:
triple_values = " ".join(
f"(<{triple['s']}> <{triple['p']}> \"{triple['o']}\")" for triple in triples
)
triple_ids = " ".join(f'("{triple_id}")' for triple_id in triples)

graph_id = f"data-{id}" if id else "data"

query = f"""
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
CONSTRUCT {{
?s ?p ?o .
}}
WHERE {{
GRAPH <{graph_id}> {{
VALUES (?s ?p ?o_string) {{ {triple_values} }}
?s ?p ?o .
FILTER (STR(?s) = STR(?s) && STR(?p) = STR(?p) && STR(?o) = ?o_string)
VALUES (?uuid) {{ {triple_ids} }}
?triple_node <{TRIPLE_PROP.uuid}> ?uuid .
?triple_node rdf:subject ?s ;
rdf:predicate ?p ;
rdf:object ?o .
}}
}}
"""
Expand Down Expand Up @@ -121,7 +125,16 @@ def _find_adjacent_triples(

def _find_all(triple_store: TripleStoreModel, id: str) -> str:
result = triple_store.query(
query=f"CONSTRUCT {{?s ?p ?o .}} WHERE {{ GRAPH <data-{id}> {{ ?s ?p ?o . }} }}",
query=f"""
CONSTRUCT {{
?s ?p ?o .
}} WHERE {{
GRAPH <data-{id}> {{
?s ?p ?o .
FILTER NOT EXISTS {{ ?s <{TRIPLE_PROP.uuid}> ?uuid }}
}}
}}
""",
format="turtle",
)

Expand All @@ -145,7 +158,7 @@ def get_contextual_memory(
memory = ""

if ephemeral:
memory = data.serialize(format="turtle")
memory = serialize_graph_without_ids(data)
elif context:
try:
matched_triples = vector_store.search(message=context, id=id)
Expand Down
7 changes: 5 additions & 2 deletions memonto/core/retain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from memonto.stores.triple.base_store import TripleStoreModel
from memonto.stores.vector.base_store import VectorStoreModel
from memonto.utils.logger import logger
from memonto.utils.rdf import _render
from memonto.utils.rdf import _render, hydrate_graph_with_ids


def run_script(
Expand Down Expand Up @@ -108,9 +108,12 @@ def _retain(
logger.debug(f"Data Graph\n{data.serialize(format='turtle')}\n")

if not ephemeral:
hydrate_graph_with_ids(data)

triple_store.save(ontology=ontology, data=data, id=id)

if vector_store:
vector_store.save(g=data, id=id)

# _render(g=data, format="image")
# print(_render(g=data, format="image"))
data.remove((None, None, None))
14 changes: 11 additions & 3 deletions memonto/stores/vector/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import json
from chromadb.config import Settings
from pydantic import model_validator
from rdflib import Graph
from rdflib import Graph, RDF, BNode
from typing import Literal

from memonto.stores.vector.base_store import VectorStoreModel
from memonto.utils.logger import logger
from memonto.utils.rdf import is_rdf_schema, remove_namespace
from memonto.utils.namespaces import TRIPLE_PROP


class Chroma(VectorStoreModel):
Expand Down Expand Up @@ -59,16 +60,23 @@ def save(self, g: Graph, id: str = None) -> None:
for s, p, o in g:
if is_rdf_schema(p):
continue
if isinstance(s, BNode) and (s, TRIPLE_PROP.uuid, None) in g:
continue

_s = remove_namespace(str(s))
_p = remove_namespace(str(p))
_o = remove_namespace(str(o))

id = ""
for bnode in g.subjects(RDF.subject, s):
if (bnode, RDF.predicate, p) in g and (bnode, RDF.object, o) in g:
id = g.value(bnode, TRIPLE_PROP.uuid)

documents.append(f"{_s} {_p} {_o}")
metadatas.append(
{"triple": json.dumps({"s": str(s), "p": str(p), "o": str(o)})}
)
ids.append(f"{s}-{p}-{o}")
ids.append(f"{id}")

if documents:
try:
Expand All @@ -87,7 +95,7 @@ def search(self, message: str, id: str = None, k: int = 3) -> list[dict]:
except Exception as e:
logger.error(f"Chroma Search\n{e}\n")

return [json.loads(t.get("triple", "{}")) for t in matched["metadatas"][0]]
return matched.get("ids", [])[0]

def delete(self, id: str) -> None:
try:
Expand Down
3 changes: 3 additions & 0 deletions memonto/utils/namespaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from rdflib import Namespace

TRIPLE_PROP = Namespace("triple:property:")
49 changes: 42 additions & 7 deletions memonto/utils/rdf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import datetime
import graphviz
import os
from rdflib import Graph
import uuid
from collections import defaultdict
from rdflib import Graph, Literal, BNode
from rdflib.namespace import RDF, RDFS, OWL
from typing import Union

from memonto.utils.namespaces import TRIPLE_PROP


def is_rdf_schema(p) -> Graph:
return p.startswith(RDFS) or p.startswith(OWL) or p.startswith(RDF)
Expand All @@ -18,6 +22,33 @@ def remove_namespace(c: str) -> str:
return c.split("/")[-1].split("#")[-1].split(":")[-1]


def serialize_graph_without_ids(g: Graph, format: str = "turtle") -> Graph:
graph = Graph()

for s, p, o in g:
if isinstance(s, BNode) and (s, TRIPLE_PROP.uuid, None) in g:
continue

graph.add((s, p, o))

return graph.serialize(format=format)


def hydrate_graph_with_ids(g: Graph) -> Graph:
for s, p, o in g:
id = str(uuid.uuid4())

triple_node = BNode()

g.add((triple_node, RDF.subject, s))
g.add((triple_node, RDF.predicate, p))
g.add((triple_node, RDF.object, o))

g.add((triple_node, TRIPLE_PROP.uuid, Literal(id)))

return g


def generate_image(g: Graph, path: str = None) -> None:
if not path:
current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
Expand All @@ -27,13 +58,17 @@ def generate_image(g: Graph, path: str = None) -> None:

dot = graphviz.Digraph()

bnode_labels = defaultdict(lambda: f"BNode{len(bnode_labels) + 1}")

for s, p, o in g:
if is_rdf_schema(p):
if isinstance(s, BNode) and (s, TRIPLE_PROP.uuid, None) in g:
continue
if isinstance(o, BNode) and (o, TRIPLE_PROP.uuid, None) in g:
continue

s_label = sanitize_label(str(s))
s_label = bnode_labels[s] if isinstance(s, BNode) else sanitize_label(str(s))
o_label = bnode_labels[o] if isinstance(o, BNode) else sanitize_label(str(o))
p_label = sanitize_label(str(p))
o_label = sanitize_label(str(o))

dot.node(s_label, s_label)
dot.node(o_label, o_label)
Expand Down Expand Up @@ -78,11 +113,11 @@ def _render(
- "image" format returns a string with the path to the png image.
"""
if format == "turtle":
return g.serialize(format="turtle")
return serialize_graph_without_ids(g=g, format="turtle")
elif format == "json":
return g.serialize(format="json-ld")
return serialize_graph_without_ids(g=g, format="json-ld")
elif format == "triples":
return g.serialize(format="nt")
return serialize_graph_without_ids(g=g, format="nt")
elif format == "text":
return generate_text(g)
elif format == "image":
Expand Down
58 changes: 58 additions & 0 deletions tests/utils/test_rdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
from rdflib import Graph, URIRef, Literal, BNode

from memonto.utils.namespaces import TRIPLE_PROP
from memonto.utils.rdf import serialize_graph_without_ids, hydrate_graph_with_ids


@pytest.fixture
def graph():
g = Graph()

g.add(
(
URIRef("http://example.org/s1"),
URIRef("http://example.org/p1"),
URIRef("http://example.org/o1"),
)
)
g.add(
(
URIRef("http://example.org/s2"),
URIRef("http://example.org/p2"),
URIRef("http://example.org/o2"),
)
)

return g


@pytest.fixture
def bnode_graph():
g = Graph()

g.add(
(
URIRef("http://example.org/s"),
URIRef("http://example.org/p"),
URIRef("http://example.org/o"),
)
)
g.add((BNode(), TRIPLE_PROP.uuid, Literal("12345")))

return g


def test_serialize_graph(bnode_graph):
g = serialize_graph_without_ids(bnode_graph)

assert "12345" not in g
assert "s" in g


def test_hydrate_graph_with_ids(graph):
g = hydrate_graph_with_ids(graph)

uuid_triples = [t for t in g if t[1] == TRIPLE_PROP.uuid]
assert len(uuid_triples) == 2
assert all(isinstance(t[2], Literal) for t in uuid_triples)

0 comments on commit 206d58b

Please sign in to comment.