diff --git a/cxx/gendb.cc b/cxx/gendb.cc index 4970528..8bfc2e0 100644 --- a/cxx/gendb.cc +++ b/cxx/gendb.cc @@ -89,22 +89,16 @@ void GenDB::get_unique_entities_relation(const std::string& rel_name, if (ref_indices.contains(rel_name)) { if (ref_indices.at(rel_name).contains(ind)) { for (const auto& [rf_name, rf_ind] : ref_indices.at(rel_name).at(ind)) { - if (!reference_values.at(domains[ind]).contains({rf_name, class_item})) { - int new_val; + if (!reference_values.at(domains[ind]) + .contains({rf_name, class_item})) { const std::string& ref_class = domains.at(rf_ind); - if (domain_crps.at(ref_class).tables.size() == 0) { - new_val = 0; - } else { - auto it = domain_crps.at(ref_class).tables.rbegin(); - new_val = it->first + 1; - } + int new_val = entity_crps.at(ref_class).max_table() + 1; int new_id = get_reference_id(domains[ind], rf_name, class_item); - reference_values.at(domains[ind])[{rf_name, class_item}] = new_val; - domain_crps.at(ref_class).incorporate(new_id, new_val); + entity_crps.at(ref_class).incorporate(new_id, new_val); } - - int refval = reference_values.at(domains[ind]).at({rf_name, class_item}); + int refval = + reference_values.at(domains[ind]).at({rf_name, class_item}); get_unique_entities_relation(rel_name, rf_ind, refval, items); } } @@ -140,10 +134,9 @@ T_items GenDB::sample_entities_relation( if (!reference_values.at(class_name).contains(ref_key)) { sample_and_incorporate_reference(prng, class_name, ref_key, ref_class); } - T_items items = - sample_entities_relation( - prng, ref_class, ++class_path_start, class_path_end, - reference_values.at(class_name).at(ref_key)); + T_items items = sample_entities_relation( + prng, ref_class, ++class_path_start, class_path_end, + reference_values.at(class_name).at(ref_key)); // The order of the items corresponds to the order of the relation's domains, // with the class (domain) corresponding to the primary key placed last on the // list. @@ -166,16 +159,15 @@ int GenDB::get_reference_id(const std::string& class_name, // and stores the value in reference_values. void GenDB::sample_and_incorporate_reference( std::mt19937* prng, const std::string& class_name, - const std::pair& ref_key, - const std::string& ref_class) { + const std::pair& ref_key, const std::string& ref_class) { auto [ref_field, class_item] = ref_key; - int new_val = domain_crps[ref_class].sample(prng); + int new_val = entity_crps.at(ref_class).sample(prng); // Generate a unique ID for the sample and incorporate it into the // entity CRP. int new_id = get_reference_id(class_name, ref_field, class_item); reference_values.at(class_name)[ref_key] = new_val; - entity_crps[ref_class].incorporate(new_id, new_val); + entity_crps.at(ref_class).incorporate(new_id, new_val); } // Incorporates an observed value into a query relation. Recursively @@ -210,7 +202,7 @@ void GenDB::sample_and_incorporate_for_class(std::mt19937* prng, const std::string& class_name, const T_item& item) { for (const std::string& rel_name : class_to_relations.at(class_name)) { - sample_class_ancestors(prng, class_name, item, false); + sample_class_ancestors(prng, class_name, item); const std::vector& domains = std::visit( [&](auto tr) { return tr.domains; }, hirm->schema.at(rel_name)); T_items rel_items(domains.size()); @@ -258,8 +250,8 @@ T_items GenDB::sample_class_ancestors(std::mt19937* prng, std::pair ref_key = {name, class_item}; if (!reference_values.at(class_name).contains(ref_key)) { assert(prng != nullptr); - sample_and_incorporate_reference( - prng, class_name, ref_key, cv->class_name); + sample_and_incorporate_reference(prng, class_name, ref_key, + cv->class_name); } T_items ref_items = sample_class_ancestors( prng, cv->class_name, reference_values.at(class_name).at(ref_key)); @@ -637,7 +629,6 @@ double GenDB::unincorporate_singleton( double logp_refclass = 0.; int ref_val = reference_values.at(class_name).at({ref_field, class_item}); - T_items base_items = sample_class_ancestors(prng, ref_class, ref_val); logp_refclass += unincorporate_from_entity_cluster(class_name, ref_field, class_item, unincorporated_from_entity_crps, false); @@ -682,7 +673,6 @@ void GenDB::transition_reference(std::mt19937* prng, return; } - std::cerr << "just got gibbs probs" << std::endl; // For each relation, get the indices (in the items vector) of the reference // value being transitioned. std::map> domain_inds = diff --git a/cxx/gendb.hh b/cxx/gendb.hh index 09ae9b7..9485974 100644 --- a/cxx/gendb.hh +++ b/cxx/gendb.hh @@ -47,8 +47,6 @@ class GenDB { void get_unique_entities_relation(const std::string& rel_name, const int ind, const int class_item, T_items& items); - const std::pair& ref_key, const std::string& ref_class, - bool new_rows_have_unique_entities); // Samples a set of entities in the domains of the relation corresponding to // class_path. diff --git a/cxx/gendb_test.cc b/cxx/gendb_test.cc index b94bc24..35259e1 100644 --- a/cxx/gendb_test.cc +++ b/cxx/gendb_test.cc @@ -1032,11 +1032,4 @@ observe gendb.transition_reference_class_and_ancestors(&prng, "Record"); } -BOOST_AUTO_TEST_CASE(test_transition_reference_class) { - std::mt19937 prng; - GenDB gendb(&prng, schema); - setup_gendb(&prng, gendb, 20); - gendb.transition_reference_class_and_ancestors(&prng, "Record"); -} - BOOST_AUTO_TEST_SUITE_END()