Skip to content

Commit

Permalink
Merge pull request #333 from nanglo123/add_dkg_viz
Browse files Browse the repository at this point in the history
Add visualization for DKG relations
  • Loading branch information
bgyori authored Jun 3, 2024
2 parents 4003ed8 + 0f400ca commit bb2f2c8
Show file tree
Hide file tree
Showing 4 changed files with 317 additions and 3 deletions.
31 changes: 31 additions & 0 deletions mira/dkg/viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import networkx as nx


def draw_relations(records, fname, is_full=False):
"""Draw a graph of some DKG records queried from the /relations endpoint."""

graph = nx.DiGraph()
graph.graph["rankdir"] = "LR"

for relation in records:
if is_full:
subject_curie = relation['subject']['id']
subject_name = relation['subject']['name']
object_curie = relation['object']['id']
object_name = relation['object']['name']
predicate_name = relation['predicate'].get('type')
predicate_curie = relation['predicate']['pred']

subject_node = f"{subject_name} ({subject_curie})"
predicate_edge = f"{predicate_name} ({predicate_curie})" \
if predicate_name else predicate_curie
object_node = f"{object_name} ({object_curie})"

graph.add_edge(subject_node, object_node, label=predicate_edge,
color="red", weight=2)
else:
graph.add_edge(relation['subject'], relation['object'],
label=relation['predicate'],
color="red", weight=2)
agraph = nx.nx_agraph.to_agraph(graph)
agraph.draw(path=fname, prog="dot", format="png")
2 changes: 1 addition & 1 deletion mira/modeling/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, model: Model, initialized: bool):
def get_interpretable_kinetics(self):
# Return kinetics but with y and p substituted
# based on vmap and pmap
subs = {self.y[v]: sympy.Symbol(k) for k, v in self.vmap.items()}
subs = {self.y[v]: sympy.Symbol(k) if isinstance(k, str) else k[0] for k, v in self.vmap.items()}
subs.update({self.p[p]: sympy.Symbol(k) for k, p in self.pmap.items()})
return sympy.Matrix([
k.subs(subs) for k in self.kinetics
Expand Down
272 changes: 272 additions & 0 deletions notebooks/DKG_Viz_Demo.ipynb

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions tests/test_dkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from mira.dkg.api import get_relations
from mira.dkg.client import AskemEntity, Entity, METAREGISTRY_BASE
from mira.dkg.utils import MiraState
from mira.dkg.viz import draw_relations

MIRA_NEO4J_URL = pystow.get_config("mira", "neo4j_url") or os.getenv("MIRA_NEO4J_URL")


@unittest.skipIf(not MIRA_NEO4J_URL, reason="Missing neo4j connection configuration")
class TestDKG(unittest.TestCase):
"""Test the DKG."""
Expand Down Expand Up @@ -83,12 +83,23 @@ def test_get_relations(self):
spec = inspect.signature(get_relations)
relation_query_default = spec.parameters["relation_query"].default
self.assertIsInstance(relation_query_default, fastapi.params.Body)

for key, data in relation_query_default.examples.items():
with self.subTest(key=key):
response = self.client.post("/api/relations", json=data["value"])
self.assertEqual(200, response.status_code, msg=response.content)

def test_get_relations_graph(self):
"Test getting graph output of relations."""
spec = inspect.signature(get_relations)
relation_query_default = spec.parameters["relation_query"].default
self.assertIsInstance(relation_query_default, fastapi.params.Body)
for key, data in relation_query_default.examples.items():
with self.subTest(key=key):
response = self.client.post("/api/relations", json=data["value"])
is_full = data['value'].get('full', False)
draw_relations(response.json(), f"test_{key}.png",
is_full=is_full)

def test_search(self):
"""Test search functionality."""
res1 = self.client.get("/api/search", params={
Expand Down

0 comments on commit bb2f2c8

Please sign in to comment.