Skip to content

Commit

Permalink
Merge pull request #211 from probcomp/092324-emilyaf-model7-incref-part2
Browse files Browse the repository at this point in the history
Add incorporate_reference method.
  • Loading branch information
emilyfertig authored Sep 25, 2024
2 parents 50c346b + 60927ac commit 6bf7918
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 1 deletion.
50 changes: 50 additions & 0 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ GenDB::GenDB(std::mt19937* prng, const PCleanSchema& schema_,
}
}

double GenDB::logp_score() const {
double domain_crps_logp = 0;
for (const auto& [d, crp] : domain_crps) {
domain_crps_logp += crp.logp_score();
}
return domain_crps_logp + hirm->logp_score();
}

void GenDB::incorporate(
std::mt19937* prng,
const std::pair<int, std::map<std::string, ObservationVariant>>& row) {
Expand Down Expand Up @@ -320,4 +328,46 @@ GenDB::update_reference_items(
return new_stored_values;
}

// Incorporates the entries of stored_values into the HIRM's relations.
//
// Walks the keys of schema.query.fields, treating each key as a relation
// name. Relations without an entry in stored_values are skipped; for the
// others, the work is delegated to incorporate_reference_relation.
//
// to_cluster_only is forwarded unchanged to incorporate_reference_relation.
void GenDB::incorporate_reference(
    std::mt19937* prng,
    std::map<std::string,
             std::unordered_map<T_items, ObservationVariant, H_items>>&
        stored_values,
    const bool to_cluster_only) {
  for (const auto& [rel_name, query_field] : schema.query.fields) {
    // Nothing stored for this relation, so there is nothing to replay.
    if (!stored_values.contains(rel_name)) {
      continue;
    }
    // get_relation returns a variant; std::visit hands the held relation
    // pointer to the templated helper so it is instantiated for that type.
    std::visit(
        [&](auto rel) {
          incorporate_reference_relation(prng, rel, rel_name, stored_values,
                                         to_cluster_only);
        },
        hirm->get_relation(rel_name));
  }
}

// Incorporates the stored (items, value) entries of relation `rel_name` into
// `rel`.
//
// If `rel_name` is a T_noisy_relation in the HIRM schema and stored_values
// also holds entries for its base relation, the base relation is processed
// first (recursively); only afterwards are the entries of `rel` itself added.
//
// prng: random source forwarded to Relation::incorporate.
// rel: relation object corresponding to rel_name.
// rel_name: key into both hirm->schema and stored_values.
// stored_values: relation name -> (items -> value). Must contain rel_name;
//   std::map::at throws std::out_of_range otherwise. Each value must hold a
//   T, or std::get throws std::bad_variant_access.
// to_cluster_only: when true, entries are added with
//   Relation::incorporate_to_cluster (no prng involved); otherwise with
//   Relation::incorporate.
//
// NOTE(review): if several noisy relations present in stored_values share a
// base relation, this recursion reaches that base once per noisy relation --
// confirm that callers never pass such input, or that this is intended.
template <typename T>
void GenDB::incorporate_reference_relation(
    std::mt19937* prng, Relation<T>* rel, const std::string& rel_name,
    std::map<std::string,
             std::unordered_map<T_items, ObservationVariant, H_items>>&
        stored_values,
    const bool to_cluster_only) {
  if (const T_noisy_relation* trel =
          std::get_if<T_noisy_relation>(&hirm->schema.at(rel_name))) {
    if (stored_values.contains(trel->base_relation)) {
      // The schema says this relation is noisy, so `rel` points at a
      // NoisyRelation<T>. A downcast within a class hierarchy is the job of
      // static_cast, which adjusts the pointer correctly; reinterpret_cast
      // would merely reinterpret the address.
      NoisyRelation<T>* noisy_rel = static_cast<NoisyRelation<T>*>(rel);
      incorporate_reference_relation(prng, noisy_rel->base_relation,
                                     trel->base_relation, stored_values,
                                     to_cluster_only);
    }
  }
  for (const auto& [items, value] : stored_values.at(rel_name)) {
    if (to_cluster_only) {
      rel->incorporate_to_cluster(items, std::get<T>(value));
    } else {
      rel->incorporate(prng, items, std::get<T>(value));
    }
  }
}

// Frees the HIRM object that `hirm` points to.
GenDB::~GenDB() {
  delete hirm;
}
22 changes: 22 additions & 0 deletions cxx/gendb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class GenDB {
GenDB(std::mt19937* prng, const PCleanSchema& schema,
bool _only_final_emissions = false, bool _record_class_is_clean = true);

// Returns the log probability of the data incorporated into the GenDB so far:
// the sum of logp_score() over all domain CRPs plus the logp_score() of the
// HIRM.
double logp_score() const;

// Incorporates a row of observed data into the GenDB instance.
void incorporate(
std::mt19937* prng,
Expand Down Expand Up @@ -98,6 +101,25 @@ class GenDB {
const std::string& class_name, const std::string& ref_field,
const int class_item, const int new_ref_val);

// Incorporates the items and values from stored_values (generally an output
// of update_reference_items).
//
// stored_values maps a relation name to that relation's (items -> value)
// entries. The keys of the schema's query fields are used as relation names;
// each one that has an entry in stored_values is incorporated, together with
// its base relation when the relation is noisy and the base relation also has
// stored entries.
//
// If to_cluster_only is true, values are added with
// Relation::incorporate_to_cluster; otherwise with Relation::incorporate
// (which receives prng).
void incorporate_reference(
std::mt19937* prng,
std::map<std::string,
std::unordered_map<T_items, ObservationVariant, H_items>>&
stored_values,
const bool to_cluster_only = false);

// Recursively incorporates the items and values of stored_values for a single
// relation (and its base relations).
//
// rel is the relation object for rel_name, and stored_values must contain an
// entry for rel_name. When rel_name is a noisy relation whose base relation
// also appears in stored_values, the base relation is incorporated before
// rel's own entries.
//
// If to_cluster_only is true, values are added with
// Relation::incorporate_to_cluster; otherwise with Relation::incorporate
// (which receives prng).
template <typename T>
void incorporate_reference_relation(
std::mt19937* prng, Relation<T>* rel, const std::string& rel_name,
std::map<std::string,
std::unordered_map<T_items, ObservationVariant, H_items>>&
stored_values,
const bool to_cluster_only);

// Deletes the HIRM object pointed to by hirm.
~GenDB();

// Disable copying.
Expand Down
65 changes: 64 additions & 1 deletion cxx/gendb_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,22 @@ observe

void setup_gendb(std::mt19937* prng, GenDB& gendb) {
std::map<std::string, ObservationVariant> obs0 = {
{"School", "MIT"}, {"Degree", "PHD"}, {"City", "Cambrij"}};
{"School", "Massachusetts Institute of Technology"},
{"Degree", "PHD"},
{"City", "Cambrij"}};
std::map<std::string, ObservationVariant> obs1 = {
{"School", "MIT"}, {"Degree", "MD"}, {"City", "Cambridge"}};
std::map<std::string, ObservationVariant> obs2 = {
{"School", "Tufts"}, {"Degree", "PT"}, {"City", "Boston"}};
std::map<std::string, ObservationVariant> obs3 = {
{"School", "Boston University"}, {"Degree", "PhD"}, {"City", "Boston"}};

int i = 0;
while (i < 30) {
gendb.incorporate(prng, {i++, obs0});
gendb.incorporate(prng, {i++, obs1});
gendb.incorporate(prng, {i++, obs2});
gendb.incorporate(prng, {i++, obs3});
}
}

Expand Down Expand Up @@ -280,6 +285,13 @@ BOOST_AUTO_TEST_CASE(test_unincorporate_reference3) {
test_unincorporate_reference_helper(gendb, "Practice", "city", 0, false);
}

// Smoke test: once the fixture rows have been incorporated, the overall log
// score must be strictly negative.
BOOST_AUTO_TEST_CASE(test_logp_score) {
  std::mt19937 prng;
  GenDB gendb(&prng, schema);
  setup_gendb(&prng, gendb);

  const double score = gendb.logp_score();
  BOOST_TEST(score < 0.0);
}

BOOST_AUTO_TEST_CASE(test_update_reference_items) {
std::mt19937 prng;
GenDB gendb(&prng, schema);
Expand Down Expand Up @@ -310,4 +322,55 @@ BOOST_AUTO_TEST_CASE(test_update_reference_items) {
}
}

// Unincorporates one reference of a Record item, points it at a different
// reference value, incorporates the updated items again and checks that the
// log score has moved.
BOOST_AUTO_TEST_CASE(test_incorporate_stored_items) {
  std::mt19937 prng;
  GenDB gendb(&prng, schema);
  setup_gendb(&prng, gendb);

  const std::string record_class = "Record";
  const std::string location_field = "location";
  const int item = 1;

  const double logp_before = gendb.logp_score();
  auto removed_items =
      gendb.unincorporate_reference(record_class, location_field, item);

  // Pick any reference value other than the current one.
  const int current_ref =
      gendb.reference_values.at({record_class, location_field, item});
  int replacement_ref = current_ref - 1;
  if (current_ref == 0) {
    replacement_ref = 1;
  }
  auto rewritten_items = gendb.update_reference_items(
      removed_items, record_class, location_field, item, replacement_ref);

  gendb.incorporate_reference(&prng, rewritten_items);

  // Changing the reference value is expected to change logp_score. The
  // domain_crps are not updated here, so the total only differs when the old
  // and new reference values fall in different IRM clusters.
  const double logp_after = gendb.logp_score();
  BOOST_TEST(logp_after != logp_before, tt::tolerance(1e-6));
}

// Round trip: unincorporate one reference of a Record item, keep its
// reference value unchanged, and incorporate the items straight back into
// their clusters. The log score must end up where it started.
BOOST_AUTO_TEST_CASE(test_incorporate_stored_items_to_cluster) {
  std::mt19937 prng;
  GenDB gendb(&prng, schema);
  setup_gendb(&prng, gendb);

  const std::string record_class = "Record";
  const std::string location_field = "location";
  const int item = 1;

  const double logp_before = gendb.logp_score();
  auto removed_items =
      gendb.unincorporate_reference(record_class, location_field, item);

  // Reuse the existing reference value so the update changes nothing.
  const int same_ref =
      gendb.reference_values.at({record_class, location_field, item});
  auto rewritten_items = gendb.update_reference_items(
      removed_items, record_class, location_field, item, same_ref);

  // to_cluster_only = true: the values go back into the same clusters.
  gendb.incorporate_reference(&prng, rewritten_items, true);

  const double logp_after = gendb.logp_score();
  BOOST_TEST(logp_after == logp_before, tt::tolerance(1e-6));
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 6bf7918

Please sign in to comment.