Skip to content

Commit

Permalink
hidden vars
Browse files Browse the repository at this point in the history
  • Loading branch information
goncalorafaria committed Aug 16, 2020
1 parent ac35c74 commit 7035c09
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
Binary file modified cimg.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 20 additions & 8 deletions scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ def Variable(name, dist):
SCM.model.addVariable(arv)
return arv

def HiddenVariable(name,dist):
arv = AncestorRandomVariable(name, dist, observed=False)
SCM.model.addVariable(arv)
return arv

class SCM(Named):

def __init__(self,
Expand Down Expand Up @@ -47,13 +52,13 @@ def mark(self, name, rv):

rv._mark()
del self.nodes[rv.name]
rv.name = "RV/" + name
rv.name = "M/" + name
self.addAuxVariable(rv)

return rv

def draw_complete(self):
plt.figure()
plt.figure(figsize=(12,8))
plt.title(self.name)
G = nx.DiGraph()

Expand All @@ -64,7 +69,7 @@ def draw_complete(self):
for to in n.outbound :
G.add_edge(n.name,to.name)

nx.draw(G,with_labels=True, arrows=True, node_size=1400)
nx.draw(G,with_labels=True, arrows=True, node_size=1200)

plt.savefig("img.png")

Expand All @@ -80,10 +85,10 @@ def reach(rv):
return l

def draw(self):
plt.figure()
plt.figure(figsize=(12,8))
plt.title(self.uname())
G = self.build_causal_graph()
nx.draw(G,with_labels=True, arrows=True, node_size=1400)
nx.draw(G,with_labels=True, arrows=True, node_size=1200)
plt.savefig("cimg.png")


Expand All @@ -93,7 +98,9 @@ def build_causal_graph(self):
q = queue.Queue()

for n in self.ancestors:
q.put(self.nodes[n])
node = self.nodes[n]
if node.observed :
q.put(node)

for n in self.nodes.values():
if n.observed :
Expand Down Expand Up @@ -131,6 +138,10 @@ def addChildren(self, chlds):
def _mark(self):
self.observed = True

def mark(self):
self._mark()
return self

def reach(self):
l= []

Expand Down Expand Up @@ -158,9 +169,10 @@ def mark(self, name):
class AncestorRandomVariable(RandomVariable):
def __init__(self,
name,
sampler):
sampler,
observed=True):
super(AncestorRandomVariable,self).__init__(name,
True)
observed)
self.sampler=sampler

def sample(self, sample_shape=()):
Expand Down
6 changes: 4 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
model = SCM("Simple Causal Graph")

X = Variable("X", tfp.distributions.Normal(loc=0,scale=1))
Z = Variable("Z", tfp.distributions.Normal(loc=0,scale=1))
Ny = HiddenVariable("Ny", tfp.distributions.Normal(loc=0, scale=1))

Y = math.exp( math.square(X) ).mark("Y")
NyZ = math.multiply(Ny,Z)

Z = math.add(X, Y).mark("Z")
Y = math.add( NyZ, math.exp( math.square(X)) ).mark("Y")

model.draw()
model.draw_complete()

0 comments on commit 7035c09

Please sign in to comment.