-
Notifications
You must be signed in to change notification settings - Fork 10
/
render.py
86 lines (79 loc) · 3 KB
/
render.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt
import itertools
import os
import tempfile
import time
from math import exp

import graphviz
import networkx as nx

from sppl.spe import AtomicLeaf
from sppl.spe import NominalLeaf
from sppl.spe import ProductSPE
from sppl.spe import RealLeaf
from sppl.spe import SumSPE
# Source of fresh ids for ``gensym``; every call consumes one value. A counter
# (rather than the wall clock) cannot repeat, even for calls made within the
# same clock tick, so distinct graph nodes never collide on the same id.
_gensym_counter = itertools.count()


def gensym():
    """Return a new node identifier of the form ``'r<N>'``.

    Each call returns a string that differs from every id previously
    returned in this process.
    """
    return 'r%d' % (next(_gensym_counter),)
def render_networkx_graph(spe):
    """Return a ``networkx.DiGraph`` that draws the SPE ``spe``.

    Every node is keyed by a fresh id from ``gensym`` and carries a ``label``
    attribute. The node for ``spe`` itself is inserted first, followed by the
    drawing of each child in order.

    Drawing by type:

    - ``NominalLeaf``: label ``'<symbol token>\\nNominal'``.
    - ``AtomicLeaf``: label ``'<symbol token>\\nAtomic(<value>)'``.
    - ``RealLeaf``: label with the symbol token, the distribution name and,
      when ``spe.dist.kwds`` is non-empty, the tuple of its values. When
      ``spe.env`` has more than one entry, every entry whose key is not the
      leaf's own symbol becomes a filled node joined by a dashed edge
      labelled with that key.
    - ``SumSPE``: a plus-sign node with an edge to each child, labelled with
      ``exp(weight)`` to three decimals.
    - ``ProductSPE``: a multiplication-sign node with an unlabelled edge to
      each child.

    Raises
    ------
    TypeError
        If ``spe`` or any descendant is none of the types above.
    """
    G = nx.DiGraph()
    _add_spe_to_graph(G, spe)
    return G


def _add_spe_to_graph(G, spe):
    """Draw ``spe`` into ``G`` and return the id of the node created for it."""
    # Children are written directly into the shared graph and report their
    # own root id. This avoids copying a sub-graph and topologically sorting
    # it at every level of the tree.
    if isinstance(spe, NominalLeaf):
        root = gensym()
        G.add_node(root, label='%s\n%s' % (spe.symbol.token, 'Nominal'))
        return root
    if isinstance(spe, AtomicLeaf):
        root = gensym()
        label = '%s\n%s(%s)' % (spe.symbol.token, 'Atomic', str(spe.value))
        G.add_node(root, label=label)
        return root
    if isinstance(spe, RealLeaf):
        root = gensym()
        kwds = '\n%s' % (tuple(spe.dist.kwds.values()),) if spe.dist.kwds else ''
        label = '%s\n%s%s' % (spe.symbol.token, spe.dist.dist.name, kwds)
        G.add_node(root, label=label)
        if len(spe.env) > 1:
            # Show every env entry other than the leaf's own symbol as a
            # filled node hanging off the leaf by a dashed, key-labelled edge.
            for k, v in spe.env.items():
                if k != spe.symbol:
                    env_node = gensym()
                    G.add_node(env_node, label=str(v), style='filled')
                    G.add_edge(root, env_node, label=' %s' % (str(k),),
                        style='dashed')
        return root
    if isinstance(spe, SumSPE):
        root = gensym()
        G.add_node(root, label='\N{PLUS SIGN}')
        for child, weight in zip(spe.children, spe.weights):
            subroot = _add_spe_to_graph(G, child)
            # The label shows exp(weight), i.e. weights are treated as
            # log-weights.
            G.add_edge(root, subroot, label='%1.3f' % (exp(weight),))
        return root
    if isinstance(spe, ProductSPE):
        root = gensym()
        G.add_node(root, label='\N{MULTIPLICATION SIGN}')
        for child in spe.children:
            G.add_edge(root, _add_spe_to_graph(G, child))
        return root
    # An explicit raise (not ``assert``) so the check survives ``python -O``.
    raise TypeError('Unknown SPE type: %s' % (spe,))
def render_graphviz(spe, filename=None, ext=None, show=None):
    """Render ``spe`` with Graphviz and return the ``graphviz.Source``.

    The graph from ``render_networkx_graph`` is written as DOT to
    ``<fname>.dot``, loaded with ``graphviz.Source.from_file`` and rendered.
    The file that ``Source.render`` saves at ``<fname>`` is then deleted.

    Parameters
    ----------
    spe : SPE
        Expression to draw.
    filename : str, optional
        Output path stem (``<fname>`` above). When omitted, a fresh path
        from ``tempfile.NamedTemporaryFile`` is used.
    ext : {'png', 'pdf'}, optional
        Output format passed to Graphviz; defaults to ``'png'``.
    show : bool, optional
        Forwarded as ``view=`` to ``graphviz.Source.render``.

    Returns
    -------
    graphviz.Source
        The source object that was rendered.

    Raises
    ------
    ValueError
        If ``ext`` is neither ``'png'`` nor ``'pdf'``.
    """
    ext = ext or 'png'
    # Validate before touching the filesystem so a bad ``ext`` leaves no
    # stray temporary file, and use a real exception so ``-O`` keeps the check.
    if ext not in ('png', 'pdf'):
        raise ValueError('Extension must be .pdf or .png')
    # Build the graph before allocating a temp path, so a failure here
    # leaves nothing on disk.
    G = render_networkx_graph(spe)
    fname = filename
    if fname is None:
        # Only a unique path is needed. Close the handle immediately so the
        # descriptor is not leaked and the file can be rewritten below (an
        # open NamedTemporaryFile cannot be reopened on Windows).
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            fname = tmp.name
    fname_dot = '%s.dot' % (fname,)
    nx.nx_agraph.write_dot(G, fname_dot)
    source = graphviz.Source.from_file(fname_dot, format=ext)
    source.render(filename=fname, view=show)
    # The copy of the source saved at ``fname`` duplicates ``fname_dot``.
    os.unlink(fname)
    return source