Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visualization of binary relations #203

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions py/hirm_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

# A cluster of relations, identified by cluster_id, together with the
# entity clusters (DomainCluster instances) associated with it.
Cluster = collections.namedtuple(
    "Cluster", ["cluster_id", "relations", "domain_clusters"])

# A group of entities from a single domain, identified by cluster_id.
DomainCluster = collections.namedtuple(
    "DomainCluster", ["cluster_id", "domain", "entities"])

# Legacy spelling kept as an alias so existing callers keep working while
# both names refer to one and the same type.
Domain_Cluster = DomainCluster

def load_observations(path):
"""Load a dataset from path, and return it as an array of Observations."""
Expand All @@ -24,7 +24,7 @@ def load_observations(path):
return data

def load_clusters(path):
"""Load the hirm clusters output from path as an HIRM_Clusters."""
"""Load the hirm clusters output from path as a list of Cluster's."""
id_to_relations = {}
id_to_clusters = {}

Expand Down Expand Up @@ -55,9 +55,9 @@ def load_clusters(path):
continue

domain_clusters.append(
Domain_Cluster(cluster_id=fields[1],
domain=fields[0],
entities=fields[2:]))
DomainCluster(cluster_id=fields[1],
domain=fields[0],
entities=fields[2:]))

id_to_clusters[cluster_id] = domain_clusters

Expand Down
12 changes: 10 additions & 2 deletions py/visualize_unary.py → py/make_plots.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
#!/usr/bin/python3

# Generates an html file containing plots that visualize the clusters
# produced by an HIRM. Currently it only supports datasets with a single
# domain, and only visualizes the unary and binary relations over that
# domain.

# Example usage:
# ./make_plots.py --observations=../cxx/assets/animals.unary.obs --clusters=../cxx/assets/animals.unary.hirm --output=/tmp/vis.html

import argparse
import hirm_io
import visualize

def main():
parser = argparse.ArgumentParser(
description="Generate matrix plots for a collection of unary relations")
description="Generate matrix plots for an HIRM's output")
parser.add_argument(
"--observations", required=True, type=str,
help="Path to file containing observations")
Expand All @@ -25,7 +33,7 @@ def main():
clusters = hirm_io.load_clusters(args.clusters)

print(f"Writing html to {args.output} ...")
visualize.plot_unary_matrices(clusters, obs, args.output)
visualize.make_plots(clusters, obs, args.output)


if __name__ == "__main__":
Expand Down
134 changes: 105 additions & 29 deletions py/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Apache License, Version 2.0, refer to LICENSE.txt

import base64
import collections
import copy
import io
import sys
Expand All @@ -19,9 +20,9 @@ def figure_to_html(fig) -> str:
return html


def make_numpy_matrix(obs, entities, relations):
"""Return a numpy array containing the observations."""
m = np.empty(shape=(len(entities), len(relations)))
def make_unary_matrix(obs, entities, relations):
"""Return a numpy array containing the observations for unary relations."""
m = np.full((len(entities), len(relations)), np.nan)
for ob in obs:
val = ob.value
ent_index = entities.index(ob.items[0])
Expand All @@ -31,19 +32,19 @@ def make_numpy_matrix(obs, entities, relations):
return m


def normalize_matrix(m):
    """Linearly map matrix values to be in [0, 1].

    NaN entries are ignored when computing the range and stay NaN in the
    result.  A matrix whose non-NaN values are all equal has no range to
    stretch, so those values map to 0 instead of dividing by zero.

    Args:
      m: Array-like of numbers.

    Returns:
      A new float array with the same shape as m.
    """
    m = np.asarray(m, dtype=float)
    if m.size == 0 or np.all(np.isnan(m)):
        # Nothing to scale; hand back a copy so callers never alias the input.
        return m.copy()

    lowest = np.nanmin(m)
    span = np.nanmax(m) - lowest
    if span == 0:
        return m - lowest
    return (m - lowest) / span


def unary_matrix(cluster, obs, clusters):
"""Plot a matrix visualization for the cluster of unary relations."""
fontsize = 12
def make_binary_matrix(obs, entities):
    """Return array containing the observations for a single binary relation.

    Args:
      obs: Iterable of observations.  Each must expose `items` (whose first
        two elements are entity names) and `value`.
      entities: Sequence of entity names; an entity's position in this
        sequence is its row/column index in the result.

    Returns:
      A square float array m where m[i, j] holds the value observed for
      (entities[i], entities[j]), and NaN where nothing was observed.

    Raises:
      ValueError: If an observation mentions an entity not in `entities`.
    """
    # Build the name -> index map once instead of scanning the list for
    # every observation.  setdefault keeps the first position of a repeated
    # name, matching what list.index would return.
    index_of = {}
    for position, entity in enumerate(entities):
        index_of.setdefault(entity, position)

    m = np.full((len(entities), len(entities)), np.nan)
    for ob in obs:
        try:
            row = index_of[ob.items[0]]
            col = index_of[ob.items[1]]
        except KeyError as e:
            # Preserve the exception type that list.index raised.
            raise ValueError(f"{e.args[0]!r} is not in list") from None
        m[row, col] = ob.value
    return m

relations = sorted(cluster.relations)
num_columns = len(relations)
longest_relation_length = max(len(r) for r in relations)

def get_all_entities(cluster):
"""Return the entities in the order the cluster prefers."""
all_entities = []
domain = ""
for dc in cluster.domain_clusters:
Expand All @@ -55,41 +56,116 @@ def unary_matrix(cluster, obs, clusters):

all_entities.extend(sorted(dc.entities))

return all_entities


def add_redlines(ax, cluster, horizontal=True, vertical=False):
    """Draw red separator lines at the boundaries between entity clusters.

    A boundary sits half a cell before the first entity of every domain
    cluster that has entities ahead of it.

    Args:
      ax: Axes-like object providing axhline and axvline.
      cluster: Object whose domain_clusters each expose `entities`.
      horizontal: Whether to draw a horizontal line at each boundary.
      vertical: Whether to draw a vertical line at each boundary.
    """
    entities_so_far = 0
    for domain_cluster in cluster.domain_clusters:
        if entities_so_far > 0:
            boundary = entities_so_far - 0.5
            if horizontal:
                ax.axhline(boundary, color='r', linewidth=2)
            if vertical:
                ax.axvline(boundary, color='r', linewidth=2)
        entities_so_far += len(domain_cluster.entities)


def unary_matrix(cluster, obs, clusters):
    """Plot a matrix visualization for the cluster of unary relations.

    Rows are the cluster's entities (in the order get_all_entities returns)
    and columns are the cluster's relations in sorted order.  Cells without
    an observation are NaN and are drawn in white.

    Args:
      cluster: The cluster to plot; provides relations and domain_clusters.
      obs: Unary observations used to fill the matrix.
      clusters: Not used by this function.

    Returns:
      The matplotlib figure containing the plot.
    """
    fontsize = 12

    relations = sorted(cluster.relations)
    num_columns = len(relations)
    longest_relation_length = max(len(r) for r in relations)

    all_entities = get_all_entities(cluster)
    num_rows = len(all_entities)
    longest_entity_length = max(len(e) for e in all_entities)

    # Size the figure so each cell and each label character gets roughly one
    # font-height of space.  Fontsize is in points, 1 pt = 1/72in.
    width = (num_columns + longest_entity_length) * fontsize / 72.0 + 0.2
    height = (num_rows + longest_relation_length) * fontsize / 72.0 + 0.2
    fig, ax = plt.subplots(figsize=(width, height))

    ax.xaxis.tick_top()
    ax.set_xticks(np.arange(num_columns), labels=relations,
                  rotation=90, fontsize=fontsize)
    ax.set_yticks(np.arange(num_rows), labels=all_entities, fontsize=fontsize)

    # Matrix of data.  The colormap is copied before set_bad so the change
    # is applied to a private copy.
    m = make_unary_matrix(obs, all_entities, relations)
    cmap = copy.copy(plt.get_cmap())
    cmap.set_bad(color='white')
    ax.imshow(m, cmap=cmap)

    # Red lines between rows (entities)
    add_redlines(ax, cluster)

    fig.tight_layout()
    return fig


def plot_unary_matrices(clusters, obs, output):
def binary_matrix(cluster, obs, clusters):
    """Plot a matrix visualization for a single binary relation.

    The same entity list labels both axes, so the figure is square.  Cells
    without an observation are NaN and are drawn in white.  Red lines are
    drawn both horizontally and vertically between entity clusters.

    Args:
      cluster: The cluster whose entity ordering and boundaries are used.
      obs: Observations of one binary relation.
      clusters: Not used by this function.

    Returns:
      The matplotlib figure containing the plot.
    """
    fontsize = 12
    entities = get_all_entities(cluster)
    count = len(entities)
    label_width = max(len(name) for name in entities)

    # One side length serves for both dimensions of the square figure.
    side = (count + label_width) * fontsize / 72.0 + 0.2
    fig, ax = plt.subplots(figsize=(side, side))

    ticks = np.arange(count)
    ax.xaxis.tick_top()
    ax.set_xticks(ticks, labels=entities, rotation=90, fontsize=fontsize)
    ax.set_yticks(np.arange(count), labels=entities, fontsize=fontsize)

    data = make_binary_matrix(obs, entities)
    colormap = copy.copy(plt.get_cmap())
    colormap.set_bad(color='white')
    ax.imshow(data, cmap=colormap)

    add_redlines(ax, cluster, vertical=True)

    fig.tight_layout()
    return fig


def collate_observations(obs):
    """Split observations by arity.

    Args:
      obs: Iterable of observations exposing `items` and `relation`.

    Returns:
      A pair (unary, binary): unary is the list of observations with exactly
      one item; binary is a defaultdict mapping relation name to the list of
      its observations with exactly two items.  Observations of any other
      arity are dropped.
    """
    singles = []
    pairs_by_relation = collections.defaultdict(list)
    for ob in obs:
        arity = len(ob.items)
        if arity == 1:
            singles.append(ob)
        elif arity == 2:
            pairs_by_relation[ob.relation].append(ob)
    return singles, pairs_by_relation


def html_for_cluster(cluster, obs, clusters):
    """Return the html for visualizing a single cluster.

    Emits one plot covering all unary observations (if any), followed by an
    <h2> heading and a plot for every binary relation found in obs.

    Args:
      cluster: The cluster being visualized.
      obs: All observations; they are split into unary and binary here.
      clusters: Passed through to the plotting functions.

    Returns:
      The html fragment as a string ('' when obs has no unary or binary
      observations).
    """
    # NOTE(review): every relation found in obs is plotted for this cluster,
    # whether or not it appears in cluster.relations -- confirm this is
    # intended for multi-cluster outputs.
    parts = []
    unary_obs, binary_obs = collate_observations(obs)
    if unary_obs:
        print(f"For irm #{cluster.cluster_id}, building unary matrix based on {len(unary_obs)} observations")
        fig = unary_matrix(cluster, unary_obs, clusters)
        parts.append(figure_to_html(fig) + "\n")
        # pyplot keeps every figure it creates alive until it is closed;
        # release each one as soon as it has been serialized.
        plt.close(fig)

    for rel, rel_obs in binary_obs.items():
        print(f"For irm #{cluster.cluster_id}, building binary matrix for {rel} based on {len(rel_obs)} observations")
        parts.append(f"<h2>{rel}</h2>\n")
        fig = binary_matrix(cluster, rel_obs, clusters)
        parts.append(figure_to_html(fig) + "\n")
        plt.close(fig)
    return ''.join(parts)


def make_plots(clusters, obs, output):
    """Write a matrix visualization for each cluster in clusters.

    Args:
      clusters: Iterable of clusters; each gets an "IRM #<cluster_id>"
        heading followed by its plots.
      obs: All observations, passed to html_for_cluster for every cluster.
      output: Path of the html file to write (overwritten if it exists).
    """
    with open(output, 'w') as f:
        f.write("<html><body>\n\n")
        for cluster in clusters:
            # One properly closed heading per cluster; all plotting is
            # delegated to html_for_cluster.
            f.write(f"<h1>IRM #{cluster.cluster_id}</h1>\n")
            f.write(html_for_cluster(cluster, obs, clusters) + "\n")
        f.write("</body></html>\n")
33 changes: 25 additions & 8 deletions py/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,38 @@ def test_figure_to_html(self):
html = visualize.figure_to_html(fig)
self.assertEqual(html[:9], '<img src=')

def test_make_unary_matrix(self):
    """make_unary_matrix returns an (entities x relations) shaped array."""
    observations = [
        hirm_io.Observation(relation="R1", value="0", items=["A"]),
        hirm_io.Observation(relation="R1", value="1", items=["B"]),
        hirm_io.Observation(relation="R2", value="0", items=["A"]),
        hirm_io.Observation(relation="R2", value="1", items=["B"]),
    ]
    m = visualize.make_unary_matrix(observations, ["A", "B"], ["R1", "R2"])
    self.assertEqual(m.shape, (2, 2))

def test_normalize_matrix(self):
    """normalize_matrix rescales values linearly onto [0, 1]."""
    identity = np.identity(3)
    # Already in [0, 1]: unchanged.
    np.testing.assert_almost_equal(
        identity, visualize.normalize_matrix(identity))
    # A uniformly scaled copy maps back onto the same matrix.
    np.testing.assert_almost_equal(
        identity, visualize.normalize_matrix(3.0*identity))
    # Evenly spaced values stay evenly spaced.
    expected = [0, 0.25, 0.5, 0.75, 1.0]
    np.testing.assert_almost_equal(
        expected, visualize.normalize_matrix(np.arange(3, 8)))
def test_make_binary_matrix(self):
    """make_binary_matrix places each value at [first item, second item]."""
    observations = [
        hirm_io.Observation(relation="R1", value="0", items=["A", "A"]),
        hirm_io.Observation(relation="R1", value="1", items=["B", "A"]),
        hirm_io.Observation(relation="R1", value="0", items=["A", "B"]),
        hirm_io.Observation(relation="R1", value="1", items=["B", "B"]),
    ]
    m = visualize.make_binary_matrix(observations, ["A", "B"])
    self.assertEqual(m.shape, (2, 2))
    # Row index comes from items[0] and column index from items[1]; an
    # asymmetric expectation catches an accidentally transposed matrix.
    np.testing.assert_array_equal(m, [[0.0, 0.0], [1.0, 1.0]])

def test_get_all_entities(self):
    """Entities come back sorted within each domain cluster, in cluster order."""
    domain_clusters = [
        hirm_io.DomainCluster(cluster_id="0", domain="animals",
                              entities=["dog", "cat", "elephant"]),
        hirm_io.DomainCluster(cluster_id="1", domain="animals",
                              entities=["penguin"]),
        hirm_io.DomainCluster(cluster_id="3", domain="animals",
                              entities=["mongoose", "eel", "human"]),
    ]
    c = hirm_io.Cluster(
        cluster_id="1", relations=["R1"], domain_clusters=domain_clusters)
    expected = ["cat", "dog", "elephant", "penguin", "eel", "human", "mongoose"]
    self.assertEqual(expected, visualize.get_all_entities(c))


if __name__ == '__main__':
Expand Down