From 48385ba65f9550cac605702301e0516b26790986 Mon Sep 17 00:00:00 2001 From: Feras A Saad Date: Tue, 16 Jan 2018 22:15:12 -0500 Subject: [PATCH] Fix #187, eliminate nested defaultdict from vscgpm. --- src/venturescript/vscgpm.py | 67 ++++++++++++++++--------------------- tests/test_vscgpm.py | 2 +- 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/src/venturescript/vscgpm.py b/src/venturescript/vscgpm.py index e4cca406..3c87fc24 100644 --- a/src/venturescript/vscgpm.py +++ b/src/venturescript/vscgpm.py @@ -74,31 +74,32 @@ def __init__(self, outputs, inputs, rng=None, sp=None, **kwargs): raise ValueError('source.inputs list disagrees with inputs.') self.inputs = inputs self.input_mapping = self._get_input_mapping(self.inputs) - # Check overriden observers. + # Check custom observers. num_observers = self._get_num_observers() - self.obs_override = num_observers is not None - if self.obs_override and len(self.outputs) != num_observers: + self.observe_custom = num_observers is not None + if self.observe_custom and len(self.outputs) != num_observers: raise ValueError('source.observers list disagrees with outputs.') - # XXX Eliminate this nested defaultdict - # Inputs and labels for incorporate/unincorporate. - self.obs = defaultdict(lambda: defaultdict(dict)) + # Entry labels[rowid][query] is label used to observe output cell. + self.labels = dict() def incorporate(self, rowid, observation, inputs=None): inputs2 = self._validate_incorporate(rowid, observation, inputs) + if rowid not in self.labels: + self.labels[rowid] = dict() for i, value in inputs2.iteritems(): self._observe_input_cell(rowid, i, value) for t, value in observation.iteritems(): self._observe_output_cell(rowid, t, value) def unincorporate(self, rowid): - if rowid not in self.obs: + if rowid not in self.labels: raise ValueError('Never incorporated: %d' % rowid) for q in self.outputs: self._forget_output_cell(rowid, q) for i in self.inputs: self._forget_input_cell(rowid, i) - assert len(self.obs[rowid]['labels']) == 0 - del self.obs[rowid] + assert len(self.labels[rowid]) == 0 + del self.labels[rowid] def logpdf(self, rowid, targets, constraints=None, inputs=None): return 0 @@ -147,7 +148,7 @@ def to_metadata(self): metadata['mode'] = self.mode metadata['plugins'] = self.plugins # Save the observations. We need to convert integer keys to strings. - metadata['obs'] = VsCGpm._obs_to_json(copy.deepcopy(self.obs)) + metadata['labels'] = VsCGpm.convert_key_int_to_str(self.labels) metadata['binary'] = base64.b64encode(self.ripl.saves()) metadata['factory'] = ('cgpm.venturescript.vscgpm', 'VsCGpm') return metadata @@ -167,11 +168,9 @@ def from_metadata(cls, metadata, rng=None): rng=rng, ) # Restore the observations. We need to convert string keys to integers. - # XXX Eliminate this terrible defaultdict hack. See Github #187. - obs_converted = VsCGpm._obs_from_json(metadata['obs']) - cgpm.obs = defaultdict(lambda: defaultdict(dict)) - for key, value in obs_converted.iteritems(): - cgpm.obs[key] = defaultdict(dict, value) + labels = VsCGpm.convert_key_str_to_int(metadata['labels']) + for rowid, mapping in labels.iteritems(): + cgpm.labels[rowid] = mapping return cgpm # -------------------------------------------------------------------------- @@ -187,14 +186,14 @@ def _observe_output_cell(self, rowid, query, value): output_idx = self.outputs.index(query) label = self._gen_label() sp_rowid = '(atom %d)' % (rowid,) - if not self.obs_override: + if not self.observe_custom: self.ripl.observe('((lookup outputs %i) %s)' % (output_idx, sp_rowid), value, label=label) else: obs_args = '%s %s (quote %s)' % (sp_rowid, value, label) self.ripl.evaluate('((lookup observers %i) %s)' % (output_idx, obs_args)) - self.obs[rowid]['labels'][query] = label + self.labels[rowid][query] = label def _observe_input_cell(self, rowid, idx, value): input_name = self.input_mapping[idx] @@ -206,9 +205,9 @@ def _observe_input_cell(self, rowid, idx, value): def _forget_output_cell(self, rowid, query): if self._is_observed_output_cell(rowid, query): - label = self.obs[rowid]['labels'][query] + label = self.labels[rowid][query] self.ripl.forget(label) - del self.obs[rowid]['labels'][query] + del self.labels[rowid][query] def _forget_input_cell(self, rowid, idx): if self._is_observed_input_cell(rowid, idx): @@ -222,7 +221,7 @@ def _forget_input_cell(self, rowid, idx): self.ripl.forget(input_cell_name) def _is_observed_output_cell(self, rowid, query): - return query in self.obs[rowid]['labels'] + return rowid in self.labels and query in self.labels[rowid] def _is_observed_input_cell(self, rowid, idx): input_name = self.input_mapping[idx] @@ -257,8 +256,8 @@ def _validate_incorporate(self, rowid, observation, inputs=None): raise ValueError('Nan inputs: %s' % inputs) if any(math.isnan(observation[i]) for i in observation): raise ValueError('Nan observation: %s' % (observation,)) - if rowid in self.obs \ - and any(q in self.obs[rowid]['labels'] for q in observation): + if rowid in self.labels \ + and any(q in self.labels[rowid] for q in observation): raise ValueError('Observation exists: %d %s' % (rowid, observation)) return self._check_input_args(rowid, inputs) @@ -307,31 +306,21 @@ def _check_input_args(self, rowid, inputs): return {i : inputs[i] for i in inputs if i not in inputs_obs} def _check_constraints_args(self, rowid, constraints): - constraints_obs = [q for q in constraints if rowid in self.obs and + constraints_obs = [q for q in constraints if rowid in self.labels and self._is_observed_output_cell(rowid, q)] if constraints_obs: raise ValueError('Constrained observations exists: %d, %s, %s' % (rowid, constraints, constraints_obs)) @staticmethod - def _obs_to_json(obs): - def convert_key_int_to_str(d): - assert all(isinstance(c, int) for c in d) - return {str(c): v for c, v in d.iteritems()} - obs2 = convert_key_int_to_str(obs) - for r in obs2: - obs2[r]['labels'] = convert_key_int_to_str(obs2[r]['labels']) - return obs2 + def convert_key_int_to_str(d): + assert all(isinstance(c, int) for c in d) + return {str(c): v for c, v in d.iteritems()} @staticmethod - def _obs_from_json(obs): - def convert_key_str_to_int(d): - assert all(isinstance(c, (str, unicode)) for c in d) - return {int(c): v for c, v in d.iteritems()} - obs2 = convert_key_str_to_int(obs) - for r in obs2: - obs2[r]['labels'] = convert_key_str_to_int(obs2[r]['labels']) - return obs2 + def convert_key_str_to_int(d): + assert all(isinstance(c, (str, unicode)) for c in d) + return {int(c): v for c, v in d.iteritems()} @staticmethod def _load_helpers(ripl): diff --git a/tests/test_vscgpm.py b/tests/test_vscgpm.py index c98f8385..4b0c4371 100644 --- a/tests/test_vscgpm.py +++ b/tests/test_vscgpm.py @@ -249,7 +249,7 @@ def test_serialize(case): assert cgpm.outputs == cgpm_test.outputs assert cgpm.inputs == cgpm_test.inputs assert cgpm.source == cgpm_test.source - assert cgpm.obs == cgpm_test.obs + assert cgpm.labels == cgpm_test.labels sample = cgpm_test.simulate(0, [0,1]) assert sample[0] == 1