diff --git a/py/hirm_io.py b/py/hirm_io.py index b1515f5..98d0626 100644 --- a/py/hirm_io.py +++ b/py/hirm_io.py @@ -8,8 +8,8 @@ Cluster = collections.namedtuple( "Cluster", ["cluster_id", "relations", "domain_clusters"]) -Domain_Cluster = collections.namedtuple( - "Domain_Cluster", ["cluster_id", "domain", "entities"]) +DomainCluster = collections.namedtuple( + "DomainCluster", ["cluster_id", "domain", "entities"]) def load_observations(path): """Load a dataset from path, and return it as an array of Observations.""" @@ -24,7 +24,7 @@ def load_observations(path): return data def load_clusters(path): - """Load the hirm clusters output from path as an HIRM_Clusters.""" + """Load the hirm clusters output from path as a list of Cluster's.""" id_to_relations = {} id_to_clusters = {} @@ -55,9 +55,9 @@ def load_clusters(path): continue domain_clusters.append( - Domain_Cluster(cluster_id=fields[1], - domain=fields[0], - entities=fields[2:])) + DomainCluster(cluster_id=fields[1], + domain=fields[0], + entities=fields[2:])) id_to_clusters[cluster_id] = domain_clusters diff --git a/py/visualize_unary.py b/py/make_plots.py similarity index 62% rename from py/visualize_unary.py rename to py/make_plots.py index 5291a33..b00d9ff 100755 --- a/py/visualize_unary.py +++ b/py/make_plots.py @@ -1,12 +1,20 @@ #!/usr/bin/python3 +# Generates an html file containing plots that visualize the clusters +# produced by an HIRM. Currently it only supports datasets with a single +# domain, and only visualizes the unary and binary relations over that +# domain. + +# Example usage: +# ./make_plots.py --observations=../cxx/assets/animals.unary.obs --clusters=../cxx/assets/animals.unary.hirm --output=/tmp/vis.html + import argparse import hirm_io import visualize def main(): parser = argparse.ArgumentParser( - description="Generate matrix plots for a collection of unary relations") + description="Generate matrix plots for an HIRM's output") parser.add_argument( "--observations", required=True, type=str, help="Path to file containing observations") @@ -25,7 +33,7 @@ def main(): clusters = hirm_io.load_clusters(args.clusters) print(f"Writing html to {args.output} ...") - visualize.plot_unary_matrices(clusters, obs, args.output) + visualize.make_plots(clusters, obs, args.output) if __name__ == "__main__": diff --git a/py/visualize.py b/py/visualize.py index 3f95bca..84f6c10 100644 --- a/py/visualize.py +++ b/py/visualize.py @@ -2,6 +2,7 @@ # Apache License, Version 2.0, refer to LICENSE.txt import base64 +import collections import copy import io import sys @@ -19,9 +20,9 @@ def figure_to_html(fig) -> str: return html -def make_numpy_matrix(obs, entities, relations): - """Return a numpy array containing the observations.""" - m = np.empty(shape=(len(entities), len(relations))) +def make_unary_matrix(obs, entities, relations): + """Return a numpy array containing the observations for unary relations.""" + m = np.full((len(entities), len(relations)), np.nan) for ob in obs: val = ob.value ent_index = entities.index(ob.items[0]) @@ -31,19 +32,19 @@ def make_numpy_matrix(obs, entities, relations): return m -def normalize_matrix(m): - """Linearly map matrix values to be in [0, 1].""" - return (m - np.min(m)) / np.ptp(m) - - -def unary_matrix(cluster, obs, clusters): - """Plot a matrix visualization for the cluster of unary relations.""" - fontsize = 12 +def make_binary_matrix(obs, entities): + """Return array containing the observations for a single binary relation.""" + m = np.full((len(entities), len(entities)), np.nan) + for ob in obs: + val = ob.value + index1 = entities.index(ob.items[0]) + index2 = entities.index(ob.items[1]) + m[index1][index2] = val + return m - relations = sorted(cluster.relations) - num_columns = len(relations) - longest_relation_length = max(len(r) for r in relations) +def get_all_entities(cluster): + """Return the entities in the order the cluster prefers.""" all_entities = [] domain = "" for dc in cluster.domain_clusters: @@ -55,12 +56,37 @@ def unary_matrix(cluster, obs, clusters): all_entities.extend(sorted(dc.entities)) + return all_entities + + +def add_redlines(ax, cluster, horizontal=True, vertical=False): + """Add redlines between entity clusters.""" + n = 0 + for dc in cluster.domain_clusters: + if n > 0: + if horizontal: + ax.axhline(n - 0.5, color='r', linewidth=2) + if vertical: + ax.axvline(n - 0.5, color='r', linewidth=2) + + n += len(dc.entities) + + +def unary_matrix(cluster, obs, clusters): + """Plot a matrix visualization for the cluster of unary relations.""" + fontsize = 12 + + relations = sorted(cluster.relations) + num_columns = len(relations) + longest_relation_length = max(len(r) for r in relations) + + all_entities = get_all_entities(cluster) num_rows = len(all_entities) longest_entity_length = max(len(e) for e in all_entities) width = (num_columns + longest_entity_length) * fontsize / 72.0 + 0.2 # Fontsize is in points, 1 pt = 1/72in height = (num_rows + longest_relation_length) * fontsize / 72.0 + 0.2 - fig, ax = plt.subplots(figsize=(width,height)) + fig, ax = plt.subplots(figsize=(width, height)) ax.xaxis.tick_top() ax.set_xticks(np.arange(num_columns), labels=relations, @@ -68,28 +94,78 @@ def unary_matrix(cluster, obs, clusters): ax.set_yticks(np.arange(num_rows), labels=all_entities, fontsize=fontsize) # Matrix of data - m = make_numpy_matrix(obs, all_entities, relations) - cmap = copy.copy(plt.get_cmap('Greys')) - cmap.set_bad(color='gray') - ax.imshow(normalize_matrix(m)) + m = make_unary_matrix(obs, all_entities, relations) + cmap = copy.copy(plt.get_cmap()) + cmap.set_bad(color='white') + ax.imshow(m, cmap=cmap) - # Red lines between rows (entities) - n = 0 - for dc in cluster.domain_clusters: - if n > 0: - ax.axhline(n - 0.5, color='r', linewidth=2) - n += len(dc.entities) + add_redlines(ax, cluster) fig.tight_layout() return fig -def plot_unary_matrices(clusters, obs, output): +def binary_matrix(cluster, obs, clusters): + """Plot a matrix visualization for a single binary relation.""" + fontsize = 12 + all_entities = get_all_entities(cluster) + n = len(all_entities) + longest_entity_length = max(len(e) for e in all_entities) + width = (n + longest_entity_length) * fontsize / 72.0 + 0.2 + fig, ax = plt.subplots(figsize=(width, width)) + + ax.xaxis.tick_top() + ax.set_xticks(np.arange(n), labels=all_entities, rotation=90, fontsize=fontsize) + ax.set_yticks(np.arange(n), labels=all_entities, fontsize=fontsize) + + m = make_binary_matrix(obs, all_entities) + cmap = copy.copy(plt.get_cmap()) + cmap.set_bad(color='white') + ax.imshow(m, cmap=cmap) + + add_redlines(ax, cluster, vertical=True) + + fig.tight_layout() + return fig + + +def collate_observations(obs): + """Separate observations into unary and binary.""" + unary_obs = [] + binary_obs = collections.defaultdict(list) + for ob in obs: + if len(ob.items) == 1: + unary_obs.append(ob) + + if len(ob.items) == 2: + binary_obs[ob.relation].append(ob) + + return unary_obs, binary_obs + + +def html_for_cluster(cluster, obs, clusters): + """Return the html for visualizing a single cluster.""" + html = '' + unary_obs, binary_obs = collate_observations(obs) + if unary_obs: + print(f"For irm #{cluster.cluster_id}, building unary matrix based on {len(unary_obs)} observations") + fig = unary_matrix(cluster, unary_obs, clusters) + html += figure_to_html(fig) + "\n" + + if binary_obs: + for rel, rel_obs in binary_obs.items(): + print(f"For irm #{cluster.cluster_id}, building binary matrix for {rel} based on {len(rel_obs)} observations") + html += f"

{rel}

\n" + fig = binary_matrix(cluster, rel_obs, clusters) + html += figure_to_html(fig) + "\n" + return html + + +def make_plots(clusters, obs, output): """Write a matrix visualization for each cluster in clusters.""" with open(output, 'w') as f: f.write("\n\n") for cluster in clusters: - f.write("

IRM #" + cluster.cluster_id + "\n") - fig = unary_matrix(cluster, obs, clusters) - f.write(figure_to_html(fig) + "\n\n") + f.write(f"

IRM #{cluster.cluster_id}

\n") + f.write(html_for_cluster(cluster, obs, clusters) + "\n") f.write("\n") diff --git a/py/visualize_test.py b/py/visualize_test.py index fe6fc96..3e8c794 100755 --- a/py/visualize_test.py +++ b/py/visualize_test.py @@ -17,21 +17,38 @@ def test_figure_to_html(self): html = visualize.figure_to_html(fig) self.assertEqual(html[:9], '