Skip to content

Commit

Permalink
Change data format to only use structured arrays to avoid data loader…
Browse files Browse the repository at this point in the history
… mess.
  • Loading branch information
kdund committed Jul 14, 2023
1 parent b3614da commit 71ed76b
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions alea/blueice_extended_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import scipy.stats as stats
from blueice.likelihood import LogAncillaryLikelihood
from blueice.likelihood import LogLikelihoodSum
from inference_interface import dict_to_structured_array, structured_array_to_dict


class BlueiceExtendedModel(StatisticalModel):
Expand Down Expand Up @@ -155,7 +156,7 @@ def _generate_data(self, **generate_values) -> list:
generate_values_anc = {k: v for k, v in generate_values.items() if k in ancillary_keys}
ancillary_measurements = self._generate_ancillary_measurements(
**generate_values_anc)
return science_data + [ancillary_measurements] + [generate_values]
return science_data + [ancillary_measurements] + [dict_to_structured_array(generate_values)]

def _generate_science_data(self, **generate_values) -> list:
science_data = [gen.simulate(**generate_values)
Expand All @@ -177,7 +178,7 @@ def _generate_ancillary_measurements(self, **generate_values) -> dict:
parameter_meas = param.fit_limits[1]
ancillary_measurements[name] = parameter_meas

return ancillary_measurements
return dict_to_structured_array(ancillary_measurements)


class CustomAncillaryLikelihood(LogAncillaryLikelihood):
Expand Down Expand Up @@ -215,16 +216,16 @@ def constraint_terms(self) -> dict:
"""
return {name: func.logpdf for name, func in self.constraint_functions.items()}

def set_data(self, d: dict):
def set_data(self, d: np.array):
"""
Set the data of the ancillary likelihood (ancillary measurements).
Args:
d (dict): Data in this case is a dict of ancillary measurements.
d (np.array): Data of ancillary measurements, stored as numpy array
"""
# This results in shifted constraint terms.
assert set(d.keys()) == set(self.parameters.names)
self.constraint_functions = self._get_constraint_functions(**d)
assert set(d.dtype.names) == set(self.parameters.names)
self.constraint_functions = self._get_constraint_functions(**structured_array_to_dict(d))

def ancillary_likelihood_sum(self, evaluate_at: dict) -> float:
"""Return the sum of all constraint terms.
Expand Down

0 comments on commit 71ed76b

Please sign in to comment.