From 6d7c22a3ffd29ab574d4181a24b5f098a8a623df Mon Sep 17 00:00:00 2001 From: Thomas Colthurst Date: Wed, 18 Sep 2024 20:49:09 +0000 Subject: [PATCH 1/3] Rename to make_plots; add binary relation visualization --- py/{visualize_unary.py => make_plots.py} | 12 ++- py/visualize.py | 122 +++++++++++++++++++---- 2 files changed, 111 insertions(+), 23 deletions(-) rename py/{visualize_unary.py => make_plots.py} (62%) diff --git a/py/visualize_unary.py b/py/make_plots.py similarity index 62% rename from py/visualize_unary.py rename to py/make_plots.py index 5291a33..b00d9ff 100755 --- a/py/visualize_unary.py +++ b/py/make_plots.py @@ -1,12 +1,20 @@ #!/usr/bin/python3 +# Generates an html file containing plots that visualize the clusters +# produced by an HIRM. Currently it only supports datasets with a single +# domain, and only visualizes the unary and binary relations over that +# domain. + +# Example usage: +# ./make_plots.py --observations=../cxx/assets/animals.unary.obs --clusters=../cxx/assets/animals.unary.hirm --output=/tmp/vis.html + import argparse import hirm_io import visualize def main(): parser = argparse.ArgumentParser( - description="Generate matrix plots for a collection of unary relations") + description="Generate matrix plots for an HIRM's output") parser.add_argument( "--observations", required=True, type=str, help="Path to file containing observations") @@ -25,7 +33,7 @@ def main(): clusters = hirm_io.load_clusters(args.clusters) print(f"Writing html to {args.output} ...") - visualize.plot_unary_matrices(clusters, obs, args.output) + visualize.make_plots(clusters, obs, args.output) if __name__ == "__main__": diff --git a/py/visualize.py b/py/visualize.py index 3f95bca..552b6db 100644 --- a/py/visualize.py +++ b/py/visualize.py @@ -2,6 +2,7 @@ # Apache License, Version 2.0, refer to LICENSE.txt import base64 +import collections import copy import io import sys @@ -19,8 +20,8 @@ def figure_to_html(fig) -> str: return html -def make_numpy_matrix(obs, entities, relations): - """Return a numpy array containing the observations.""" +def make_unary_matrix(obs, entities, relations): + """Return a numpy array containing the observations for unary relations.""" m = np.empty(shape=(len(entities), len(relations))) for ob in obs: val = ob.value @@ -31,19 +32,24 @@ def make_numpy_matrix(obs, entities, relations): return m +def make_binary_matrix(obs, entities): + """Return array containing the observations for a single binary relation.""" + m = np.empty(shape=(len(entities), len(entities))) + for ob in obs: + val = ob.value + index1 = entities.index(ob.items[0]) + index2 = entities.index(ob.items[1]) + m[index1][index2] = val + return m + + def normalize_matrix(m): """Linearly map matrix values to be in [0, 1].""" return (m - np.min(m)) / np.ptp(m) -def unary_matrix(cluster, obs, clusters): - """Plot a matrix visualization for the cluster of unary relations.""" - fontsize = 12 - - relations = sorted(cluster.relations) - num_columns = len(relations) - longest_relation_length = max(len(r) for r in relations) - +def get_all_entities(cluster): + """Return the entities in the order the cluster prefers.""" all_entities = [] domain = "" for dc in cluster.domain_clusters: @@ -55,12 +61,37 @@ def unary_matrix(cluster, obs, clusters): all_entities.extend(sorted(dc.entities)) + return all_entities + + +def add_redlines(ax, cluster, horizontal=True, vertical=False): + """Add redlines between entity clusters.""" + n = 0 + for dc in cluster.domain_clusters: + if n > 0: + if horizontal: + ax.axhline(n - 0.5, color='r', linewidth=2) + if vertical: + ax.axvline(n - 0.5, color='r', linewidth=2) + + n += len(dc.entities) + + +def unary_matrix(cluster, obs, clusters): + """Plot a matrix visualization for the cluster of unary relations.""" + fontsize = 12 + + relations = sorted(cluster.relations) + num_columns = len(relations) + longest_relation_length = max(len(r) for r in relations) + + all_entities = get_all_entities(cluster) num_rows = len(all_entities) longest_entity_length = max(len(e) for e in all_entities) width = (num_columns + longest_entity_length) * fontsize / 72.0 + 0.2 # Fontsize is in points, 1 pt = 1/72in height = (num_rows + longest_relation_length) * fontsize / 72.0 + 0.2 - fig, ax = plt.subplots(figsize=(width,height)) + fig, ax = plt.subplots(figsize=(width, height)) ax.xaxis.tick_top() ax.set_xticks(np.arange(num_columns), labels=relations, @@ -68,28 +99,77 @@ def unary_matrix(cluster, obs, clusters): ax.set_yticks(np.arange(num_rows), labels=all_entities, fontsize=fontsize) # Matrix of data - m = make_numpy_matrix(obs, all_entities, relations) + m = make_unary_matrix(obs, all_entities, relations) cmap = copy.copy(plt.get_cmap('Greys')) cmap.set_bad(color='gray') ax.imshow(normalize_matrix(m)) - # Red lines between rows (entities) - n = 0 - for dc in cluster.domain_clusters: - if n > 0: - ax.axhline(n - 0.5, color='r', linewidth=2) - n += len(dc.entities) + add_redlines(ax, cluster) + + fig.tight_layout() + return fig + + +def binary_matrix(cluster, obs, clusters): + """Plot a matrix visualization for a single binary relation.""" + fontsize = 12 + all_entities = get_all_entities(cluster) + n = len(all_entities) + longest_entity_length = max(len(e) for e in all_entities) + width = (n + longest_entity_length) * fontsize / 72.0 + 0.2 + fig, ax = plt.subplots(figsize=(width, width)) + + ax.xaxis.tick_top() + ax.set_xticks(np.arange(n), labels=all_entities, rotation=90, fontsize=fontsize) + ax.set_yticks(np.arange(n), labels=all_entities, fontsize=fontsize) + + m = make_binary_matrix(obs, all_entities) + cmap = copy.copy(plt.get_cmap('Greys')) + cmap.set_bad(color='gray') + ax.imshow(normalize_matrix(m)) + + add_redlines(ax, cluster, vertical=True) fig.tight_layout() return fig -def plot_unary_matrices(clusters, obs, output): +def collate_observations(obs): + """Separate observations into unary and binary.""" + unary_obs = [] + binary_obs = collections.defaultdict(list) + for ob in obs: + if len(ob.items) == 1: + unary_obs.append(ob) + break + + if len(ob.items) == 2: + binary_obs[ob.relation].append(ob) + + return unary_obs, binary_obs + + +def html_for_cluster(cluster, obs, clusters): + """Return the html for visualizing a single cluster.""" + html = '' + unary_obs, binary_obs = collate_observations(obs) + if unary_obs: + fig = unary_matrix(cluster, obs, clusters) + html += figure_to_html(fig) + "\n" + + if binary_obs: + for rel, rel_obs in binary_obs.items(): + html += f"

{rel}

\n" + fig = binary_matrix(cluster, rel_obs, clusters) + html += figure_to_html(fig) + "\n" + return html + + +def make_plots(clusters, obs, output): """Write a matrix visualization for each cluster in clusters.""" with open(output, 'w') as f: f.write("\n\n") for cluster in clusters: f.write("

IRM #" + cluster.cluster_id + "\n") - fig = unary_matrix(cluster, obs, clusters) - f.write(figure_to_html(fig) + "\n\n") + f.write(html_for_cluster(cluster, obs, clusters) + "\n) f.write("\n") From e1b45064f99036a2549a532f3d192bf232724c94 Mon Sep 17 00:00:00 2001 From: Thomas Colthurst Date: Wed, 18 Sep 2024 20:53:34 +0000 Subject: [PATCH 2/3] Bug fix --- py/visualize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/visualize.py b/py/visualize.py index 552b6db..afcb023 100644 --- a/py/visualize.py +++ b/py/visualize.py @@ -171,5 +171,5 @@ def make_plots(clusters, obs, output): f.write("\n\n") for cluster in clusters: f.write("

IRM #" + cluster.cluster_id + "\n") - f.write(html_for_cluster(cluster, obs, clusters) + "\n) + f.write(html_for_cluster(cluster, obs, clusters) + "\n") f.write("\n") From 50735ab47be75e9243b87bb29cd0cb0e57cf0c30 Mon Sep 17 00:00:00 2001 From: Thomas Colthurst Date: Thu, 19 Sep 2024 17:27:42 +0000 Subject: [PATCH 3/3] Fix nan-related bugs --- py/hirm_io.py | 12 ++++++------ py/visualize.py | 28 ++++++++++++---------------- py/visualize_test.py | 33 +++++++++++++++++++++++++-------- 3 files changed, 43 insertions(+), 30 deletions(-) diff --git a/py/hirm_io.py b/py/hirm_io.py index b1515f5..98d0626 100644 --- a/py/hirm_io.py +++ b/py/hirm_io.py @@ -8,8 +8,8 @@ Cluster = collections.namedtuple( "Cluster", ["cluster_id", "relations", "domain_clusters"]) -Domain_Cluster = collections.namedtuple( - "Domain_Cluster", ["cluster_id", "domain", "entities"]) +DomainCluster = collections.namedtuple( + "DomainCluster", ["cluster_id", "domain", "entities"]) def load_observations(path): """Load a dataset from path, and return it as an array of Observations.""" @@ -24,7 +24,7 @@ def load_observations(path): return data def load_clusters(path): - """Load the hirm clusters output from path as an HIRM_Clusters.""" + """Load the hirm clusters output from path as a list of Cluster's.""" id_to_relations = {} id_to_clusters = {} @@ -55,9 +55,9 @@ def load_clusters(path): continue domain_clusters.append( - Domain_Cluster(cluster_id=fields[1], - domain=fields[0], - entities=fields[2:])) + DomainCluster(cluster_id=fields[1], + domain=fields[0], + entities=fields[2:])) id_to_clusters[cluster_id] = domain_clusters diff --git a/py/visualize.py b/py/visualize.py index afcb023..84f6c10 100644 --- a/py/visualize.py +++ b/py/visualize.py @@ -22,7 +22,7 @@ def figure_to_html(fig) -> str: def make_unary_matrix(obs, entities, relations): """Return a numpy array containing the observations for unary relations.""" - m = np.empty(shape=(len(entities), len(relations))) + m = np.full((len(entities), len(relations)), np.nan) for ob in obs: val = ob.value ent_index = entities.index(ob.items[0]) @@ -34,7 +34,7 @@ def make_unary_matrix(obs, entities, relations): def make_binary_matrix(obs, entities): """Return array containing the observations for a single binary relation.""" - m = np.empty(shape=(len(entities), len(entities))) + m = np.full((len(entities), len(entities)), np.nan) for ob in obs: val = ob.value index1 = entities.index(ob.items[0]) @@ -43,11 +43,6 @@ def make_binary_matrix(obs, entities): return m -def normalize_matrix(m): - """Linearly map matrix values to be in [0, 1].""" - return (m - np.min(m)) / np.ptp(m) - - def get_all_entities(cluster): """Return the entities in the order the cluster prefers.""" all_entities = [] @@ -100,9 +95,9 @@ def unary_matrix(cluster, obs, clusters): # Matrix of data m = make_unary_matrix(obs, all_entities, relations) - cmap = copy.copy(plt.get_cmap('Greys')) - cmap.set_bad(color='gray') - ax.imshow(normalize_matrix(m)) + cmap = copy.copy(plt.get_cmap()) + cmap.set_bad(color='white') + ax.imshow(m, cmap=cmap) add_redlines(ax, cluster) @@ -124,9 +119,9 @@ def binary_matrix(cluster, obs, clusters): ax.set_yticks(np.arange(n), labels=all_entities, fontsize=fontsize) m = make_binary_matrix(obs, all_entities) - cmap = copy.copy(plt.get_cmap('Greys')) - cmap.set_bad(color='gray') - ax.imshow(normalize_matrix(m)) + cmap = copy.copy(plt.get_cmap()) + cmap.set_bad(color='white') + ax.imshow(m, cmap=cmap) add_redlines(ax, cluster, vertical=True) @@ -141,7 +136,6 @@ def collate_observations(obs): for ob in obs: if len(ob.items) == 1: unary_obs.append(ob) - break if len(ob.items) == 2: binary_obs[ob.relation].append(ob) @@ -154,11 +148,13 @@ def html_for_cluster(cluster, obs, clusters): html = '' unary_obs, binary_obs = collate_observations(obs) if unary_obs: - fig = unary_matrix(cluster, obs, clusters) + print(f"For irm #{cluster.cluster_id}, building unary matrix based on {len(unary_obs)} observations") + fig = unary_matrix(cluster, unary_obs, clusters) html += figure_to_html(fig) + "\n" if binary_obs: for rel, rel_obs in binary_obs.items(): + print(f"For irm #{cluster.cluster_id}, building binary matrix for {rel} based on {len(rel_obs)} observations") html += f"

{rel}

\n" fig = binary_matrix(cluster, rel_obs, clusters) html += figure_to_html(fig) + "\n" @@ -170,6 +166,6 @@ def make_plots(clusters, obs, output): with open(output, 'w') as f: f.write("\n\n") for cluster in clusters: - f.write("

IRM #" + cluster.cluster_id + "\n") + f.write(f"

IRM #{cluster.cluster_id}

\n") f.write(html_for_cluster(cluster, obs, clusters) + "\n") f.write("\n") diff --git a/py/visualize_test.py b/py/visualize_test.py index fe6fc96..3e8c794 100755 --- a/py/visualize_test.py +++ b/py/visualize_test.py @@ -17,21 +17,38 @@ def test_figure_to_html(self): html = visualize.figure_to_html(fig) self.assertEqual(html[:9], '