Skip to content

Commit

Permalink
Fix #187, eliminate nested defaultdict from vscgpm.
Browse files Browse the repository at this point in the history
  • Loading branch information
Feras A Saad committed Jan 17, 2018
1 parent 71bb54f commit 48385ba
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 40 deletions.
67 changes: 28 additions & 39 deletions src/venturescript/vscgpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,31 +74,32 @@ def __init__(self, outputs, inputs, rng=None, sp=None, **kwargs):
raise ValueError('source.inputs list disagrees with inputs.')
self.inputs = inputs
self.input_mapping = self._get_input_mapping(self.inputs)
# Check overriden observers.
# Check custom observers.
num_observers = self._get_num_observers()
self.obs_override = num_observers is not None
if self.obs_override and len(self.outputs) != num_observers:
self.observe_custom = num_observers is not None
if self.observe_custom and len(self.outputs) != num_observers:
raise ValueError('source.observers list disagrees with outputs.')
# XXX Eliminate this nested defaultdict
# Inputs and labels for incorporate/unincorporate.
self.obs = defaultdict(lambda: defaultdict(dict))
# Entry labels[rowid][query] is label used to observe output cell.
self.labels = dict()

def incorporate(self, rowid, observation, inputs=None):
    """Observe the cells of one row.

    The arguments are first checked by ``_validate_incorporate``; the
    input mapping it returns is what gets observed as input cells.  Every
    entry of ``observation`` is then observed as an output cell.  An empty
    label table is created for ``rowid`` in ``self.labels`` when the row
    does not have one yet.
    """
    fresh_inputs = self._validate_incorporate(rowid, observation, inputs)
    # Ensure the row has a label table before any output cell is observed.
    self.labels.setdefault(rowid, dict())
    for input_idx, input_value in fresh_inputs.iteritems():
        self._observe_input_cell(rowid, input_idx, input_value)
    for output_idx, output_value in observation.iteritems():
        self._observe_output_cell(rowid, output_idx, output_value)

def unincorporate(self, rowid):
if rowid not in self.obs:
if rowid not in self.labels:
raise ValueError('Never incorporated: %d' % rowid)
for q in self.outputs:
self._forget_output_cell(rowid, q)
for i in self.inputs:
self._forget_input_cell(rowid, i)
assert len(self.obs[rowid]['labels']) == 0
del self.obs[rowid]
assert len(self.labels[rowid]) == 0
del self.labels[rowid]

def logpdf(self, rowid, targets, constraints=None, inputs=None):
return 0
Expand Down Expand Up @@ -147,7 +148,7 @@ def to_metadata(self):
metadata['mode'] = self.mode
metadata['plugins'] = self.plugins
# Save the observations. We need to convert integer keys to strings.
metadata['obs'] = VsCGpm._obs_to_json(copy.deepcopy(self.obs))
metadata['labels'] = VsCGpm.convert_key_int_to_str(self.labels)
metadata['binary'] = base64.b64encode(self.ripl.saves())
metadata['factory'] = ('cgpm.venturescript.vscgpm', 'VsCGpm')
return metadata
Expand All @@ -167,11 +168,9 @@ def from_metadata(cls, metadata, rng=None):
rng=rng,
)
# Restore the observations. We need to convert string keys to integers.
# XXX Eliminate this terrible defaultdict hack. See Github #187.
obs_converted = VsCGpm._obs_from_json(metadata['obs'])
cgpm.obs = defaultdict(lambda: defaultdict(dict))
for key, value in obs_converted.iteritems():
cgpm.obs[key] = defaultdict(dict, value)
labels = VsCGpm.convert_key_str_to_int(metadata['labels'])
for rowid, mapping in labels.iteritems():
cgpm.labels[rowid] = mapping
return cgpm

# --------------------------------------------------------------------------
Expand All @@ -187,14 +186,14 @@ def _observe_output_cell(self, rowid, query, value):
output_idx = self.outputs.index(query)
label = self._gen_label()
sp_rowid = '(atom %d)' % (rowid,)
if not self.obs_override:
if not self.observe_custom:
self.ripl.observe('((lookup outputs %i) %s)'
% (output_idx, sp_rowid), value, label=label)
else:
obs_args = '%s %s (quote %s)' % (sp_rowid, value, label)
self.ripl.evaluate('((lookup observers %i) %s)'
% (output_idx, obs_args))
self.obs[rowid]['labels'][query] = label
self.labels[rowid][query] = label

def _observe_input_cell(self, rowid, idx, value):
input_name = self.input_mapping[idx]
Expand All @@ -206,9 +205,9 @@ def _observe_input_cell(self, rowid, idx, value):

def _forget_output_cell(self, rowid, query):
if self._is_observed_output_cell(rowid, query):
label = self.obs[rowid]['labels'][query]
label = self.labels[rowid][query]
self.ripl.forget(label)
del self.obs[rowid]['labels'][query]
del self.labels[rowid][query]

def _forget_input_cell(self, rowid, idx):
if self._is_observed_input_cell(rowid, idx):
Expand All @@ -222,7 +221,7 @@ def _forget_input_cell(self, rowid, idx):
self.ripl.forget(input_cell_name)

def _is_observed_output_cell(self, rowid, query):
return query in self.obs[rowid]['labels']
return rowid in self.labels and query in self.labels[rowid]

def _is_observed_input_cell(self, rowid, idx):
input_name = self.input_mapping[idx]
Expand Down Expand Up @@ -257,8 +256,8 @@ def _validate_incorporate(self, rowid, observation, inputs=None):
raise ValueError('Nan inputs: %s' % inputs)
if any(math.isnan(observation[i]) for i in observation):
raise ValueError('Nan observation: %s' % (observation,))
if rowid in self.obs \
and any(q in self.obs[rowid]['labels'] for q in observation):
if rowid in self.labels \
and any(q in self.labels[rowid] for q in observation):
raise ValueError('Observation exists: %d %s' % (rowid, observation))
return self._check_input_args(rowid, inputs)

Expand Down Expand Up @@ -307,31 +306,21 @@ def _check_input_args(self, rowid, inputs):
return {i : inputs[i] for i in inputs if i not in inputs_obs}

def _check_constraints_args(self, rowid, constraints):
constraints_obs = [q for q in constraints if rowid in self.obs and
constraints_obs = [q for q in constraints if rowid in self.labels and
self._is_observed_output_cell(rowid, q)]
if constraints_obs:
raise ValueError('Constrained observations exists: %d, %s, %s'
% (rowid, constraints, constraints_obs))

@staticmethod
def _obs_to_json(obs):
def convert_key_int_to_str(d):
assert all(isinstance(c, int) for c in d)
return {str(c): v for c, v in d.iteritems()}
obs2 = convert_key_int_to_str(obs)
for r in obs2:
obs2[r]['labels'] = convert_key_int_to_str(obs2[r]['labels'])
return obs2
def convert_key_int_to_str(d):
    """Return a new dict equal to `d` but with every key passed through str().

    Values are carried over unchanged.  Asserts that every key of `d` is
    an int.
    """
    assert all(isinstance(key, int) for key in d)
    converted = {}
    for key, val in d.iteritems():
        converted[str(key)] = val
    return converted

@staticmethod
def _obs_from_json(obs):
def convert_key_str_to_int(d):
assert all(isinstance(c, (str, unicode)) for c in d)
return {int(c): v for c, v in d.iteritems()}
obs2 = convert_key_str_to_int(obs)
for r in obs2:
obs2[r]['labels'] = convert_key_str_to_int(obs2[r]['labels'])
return obs2
def convert_key_str_to_int(d):
    """Return a new dict equal to `d` but with every key passed through int().

    Values are carried over unchanged.  Asserts that every key of `d` is
    a str or unicode string.
    """
    assert all(isinstance(key, (str, unicode)) for key in d)
    return dict((int(key), val) for key, val in d.iteritems())

@staticmethod
def _load_helpers(ripl):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vscgpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_serialize(case):
assert cgpm.outputs == cgpm_test.outputs
assert cgpm.inputs == cgpm_test.inputs
assert cgpm.source == cgpm_test.source
assert cgpm.obs == cgpm_test.obs
assert cgpm.labels == cgpm_test.labels

sample = cgpm_test.simulate(0, [0,1])
assert sample[0] == 1
Expand Down

0 comments on commit 48385ba

Please sign in to comment.