Skip to content

Commit

Permalink
Resolve merge conflicts.
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyfertig committed Oct 4, 2024
1 parent a72aeac commit 7aa4920
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 34 deletions.
40 changes: 15 additions & 25 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,16 @@ void GenDB::get_unique_entities_relation(const std::string& rel_name,
if (ref_indices.contains(rel_name)) {
if (ref_indices.at(rel_name).contains(ind)) {
for (const auto& [rf_name, rf_ind] : ref_indices.at(rel_name).at(ind)) {
if (!reference_values.at(domains[ind]).contains({rf_name, class_item})) {
int new_val;
if (!reference_values.at(domains[ind])
.contains({rf_name, class_item})) {
const std::string& ref_class = domains.at(rf_ind);
if (domain_crps.at(ref_class).tables.size() == 0) {
new_val = 0;
} else {
auto it = domain_crps.at(ref_class).tables.rbegin();
new_val = it->first + 1;
}
int new_val = entity_crps.at(ref_class).max_table() + 1;
int new_id = get_reference_id(domains[ind], rf_name, class_item);

reference_values.at(domains[ind])[{rf_name, class_item}] = new_val;
domain_crps.at(ref_class).incorporate(new_id, new_val);
entity_crps.at(ref_class).incorporate(new_id, new_val);
}

int refval = reference_values.at(domains[ind]).at({rf_name, class_item});
int refval =
reference_values.at(domains[ind]).at({rf_name, class_item});
get_unique_entities_relation(rel_name, rf_ind, refval, items);
}
}
Expand Down Expand Up @@ -140,10 +134,9 @@ T_items GenDB::sample_entities_relation(
if (!reference_values.at(class_name).contains(ref_key)) {
sample_and_incorporate_reference(prng, class_name, ref_key, ref_class);
}
T_items items =
sample_entities_relation(
prng, ref_class, ++class_path_start, class_path_end,
reference_values.at(class_name).at(ref_key));
T_items items = sample_entities_relation(
prng, ref_class, ++class_path_start, class_path_end,
reference_values.at(class_name).at(ref_key));
// The order of the items corresponds to the order of the relation's domains,
// with the class (domain) corresponding to the primary key placed last on the
// list.
Expand All @@ -166,16 +159,15 @@ int GenDB::get_reference_id(const std::string& class_name,
// and stores the value in reference_values.
void GenDB::sample_and_incorporate_reference(
std::mt19937* prng, const std::string& class_name,
const std::pair<std::string, int>& ref_key,
const std::string& ref_class) {
const std::pair<std::string, int>& ref_key, const std::string& ref_class) {
auto [ref_field, class_item] = ref_key;
int new_val = domain_crps[ref_class].sample(prng);
int new_val = entity_crps.at(ref_class).sample(prng);

// Generate a unique ID for the sample and incorporate it into the
// entity CRP.
int new_id = get_reference_id(class_name, ref_field, class_item);
reference_values.at(class_name)[ref_key] = new_val;
entity_crps[ref_class].incorporate(new_id, new_val);
entity_crps.at(ref_class).incorporate(new_id, new_val);
}

// Incorporates an observed value into a query relation. Recursively
Expand Down Expand Up @@ -210,7 +202,7 @@ void GenDB::sample_and_incorporate_for_class(std::mt19937* prng,
const std::string& class_name,
const T_item& item) {
for (const std::string& rel_name : class_to_relations.at(class_name)) {
sample_class_ancestors(prng, class_name, item, false);
sample_class_ancestors(prng, class_name, item);
const std::vector<std::string>& domains = std::visit(
[&](auto tr) { return tr.domains; }, hirm->schema.at(rel_name));
T_items rel_items(domains.size());
Expand Down Expand Up @@ -258,8 +250,8 @@ T_items GenDB::sample_class_ancestors(std::mt19937* prng,
std::pair<std::string, int> ref_key = {name, class_item};
if (!reference_values.at(class_name).contains(ref_key)) {
assert(prng != nullptr);
sample_and_incorporate_reference(
prng, class_name, ref_key, cv->class_name);
sample_and_incorporate_reference(prng, class_name, ref_key,
cv->class_name);
}
T_items ref_items = sample_class_ancestors(
prng, cv->class_name, reference_values.at(class_name).at(ref_key));
Expand Down Expand Up @@ -637,7 +629,6 @@ double GenDB::unincorporate_singleton(
double logp_refclass = 0.;

int ref_val = reference_values.at(class_name).at({ref_field, class_item});
T_items base_items = sample_class_ancestors(prng, ref_class, ref_val);
logp_refclass +=
unincorporate_from_entity_cluster(class_name, ref_field, class_item,
unincorporated_from_entity_crps, false);
Expand Down Expand Up @@ -682,7 +673,6 @@ void GenDB::transition_reference(std::mt19937* prng,
return;
}

std::cerr << "just got gibbs probs" << std::endl;
// For each relation, get the indices (in the items vector) of the reference
// value being transitioned.
std::map<std::string, std::vector<size_t>> domain_inds =
Expand Down
2 changes: 0 additions & 2 deletions cxx/gendb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class GenDB {

void get_unique_entities_relation(const std::string& rel_name, const int ind,
const int class_item, T_items& items);
const std::pair<std::string, int>& ref_key, const std::string& ref_class,
bool new_rows_have_unique_entities);

// Samples a set of entities in the domains of the relation corresponding to
// class_path.
Expand Down
7 changes: 0 additions & 7 deletions cxx/gendb_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1032,11 +1032,4 @@ observe
gendb.transition_reference_class_and_ancestors(&prng, "Record");
}

BOOST_AUTO_TEST_CASE(test_transition_reference_class) {
std::mt19937 prng;
GenDB gendb(&prng, schema);
setup_gendb(&prng, gendb, 20);
gendb.transition_reference_class_and_ancestors(&prng, "Record");
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 7aa4920

Please sign in to comment.