diff --git a/cxx/BUILD b/cxx/BUILD index e4c53f7..4d76723 100644 --- a/cxx/BUILD +++ b/cxx/BUILD @@ -55,7 +55,7 @@ cc_library( ) cc_library( - name = "gendb_lib", + name = "gendb", hdrs = ["gendb.hh"], srcs = ["gendb.cc"], visibility = [":__subpackages__"], @@ -219,7 +219,7 @@ cc_test( name = "gendb_test", srcs = ["gendb_test.cc"], deps = [ - ":gendb_lib", + ":gendb", "@boost//:test", ], ) diff --git a/cxx/gendb.cc b/cxx/gendb.cc index 0206f37..2adb090 100644 --- a/cxx/gendb.cc +++ b/cxx/gendb.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include "distributions/crp.hh" #include "hirm.hh" @@ -42,7 +43,7 @@ void GenDB::incorporate( // Sample a set of items to be incorporated into the query relation. const std::vector& class_path = schema.query.fields.at(query_rel).class_path; - std::vector items = + T_items items = sample_entities_relation(prng, schema.query.record_class, class_path.cbegin(), class_path.cend(), id); @@ -53,51 +54,58 @@ void GenDB::incorporate( // This function walks the class_path of the query, populates the global // reference_values table if necessary, and returns a sampled set of items -// for the query relation. -std::vector GenDB::sample_entities_relation( +// for the query relation that corresponds to the class path. class_path_start +// is an attribute of the Class named class_name. +T_items GenDB::sample_entities_relation( std::mt19937* prng, const std::string& class_name, std::vector::const_iterator class_path_start, std::vector::const_iterator class_path_end, int class_item) { if (class_path_end - class_path_start == 1) { - // These are domains and we need to DFS-traverse the class's + // The last item in class_path is the class from which the queried attribute + // is observed (for which there's a corresponding clean relation, observing + // the attribute from the class). We need to DFS-traverse the class's // parents, similar to PCleanSchemaHelper::compute_domains_for. return sample_class_ancestors(prng, class_name, class_item); - } else { - // These are noisy relation domains along the path from the latent cleanly- - // observed class to the record class. - std::string ref_field = *class_path_start; - - // If the reference field isn't populated, sample a value from a CRP and - // add it to reference_values. - std::string ref_class = - std::get( - schema.classes.at(class_name).vars.at(ref_field).spec) - .class_name; - if (!reference_values[class_name].contains(class_item)) { - sample_and_incorporate_reference(prng, class_name, class_item, ref_field, - ref_class); - } - std::vector items = sample_entities_relation( - prng, ref_class, ++class_path_start, class_path_end, - reference_values[class_name][class_item][ref_field]); - items.push_back(class_item); - return items; } + + // These are noisy relation domains along the path from the latent cleanly- + // observed class to the record class. + std::string ref_field = *class_path_start; + + // If the reference field isn't populated, sample a value from a CRP and + // add it to reference_values. + std::string ref_class = + std::get(schema.classes.at(class_name).vars.at(ref_field).spec) + .class_name; + std::tuple ref_key = {class_name, ref_field, + class_item}; + if (!reference_values.contains(ref_key)) { + sample_and_incorporate_reference(prng, ref_key, ref_class); + } + T_items items = + sample_entities_relation(prng, ref_class, ++class_path_start, + class_path_end, reference_values.at(ref_key)); + // The order of the items corresponds to the order of the relation's domains, + // with the class (domain) corresponding to the primary key placed last on the + // list. + items.push_back(class_item); + return items; } -void GenDB::sample_and_incorporate_reference(std::mt19937* prng, - const std::string& class_name, - int class_item, - const std::string& ref_field, - const std::string& ref_class) { +void GenDB::sample_and_incorporate_reference( + std::mt19937* prng, + const std::tuple& ref_key, + const std::string& ref_class) { + auto [class_name, ref_field, class_item] = ref_key; int new_val = domain_crps[ref_class].sample(prng); // Generate a unique ID for the sample and incorporate it into the // domain CRP. std::stringstream new_id_str; - new_id_str << class_name << class_item << ref_field; + std::string sep = " "; // Spaces are disallowed in class/variable names. + new_id_str << class_name << sep << class_item << sep << ref_field; int new_id = std::hash{}(new_id_str.str()); - reference_values[class_name][class_item][ref_field] = new_val; + reference_values[ref_key] = new_val; domain_crps[ref_class].incorporate(new_id, new_val); } @@ -106,33 +114,35 @@ void GenDB::incorporate_query_relation(std::mt19937* prng, const std::string& query_rel_name, const T_items& items, const ObservationVariant& value) { - RelationVariant query_rel = hirm->get_relation(query_rel_name); - T_items base_items = std::visit( - [&](auto nr) { - using T = typename std::remove_pointer_t::ValueType; - auto noisy_rel = reinterpret_cast*>(nr); - return noisy_rel->get_base_items(items); - }, - query_rel); - - T_noisy_relation t_query_rel = - std::get(hirm->schema.at(query_rel_name)); - bool base_contains_items = std::visit( - [&](auto rel) { return rel->get_data().contains(base_items); }, - hirm->get_relation(t_query_rel.base_relation)); - if (!base_contains_items) { - hirm->sample_and_incorporate_relation(prng, t_query_rel.base_relation, - base_items); + if (const T_noisy_relation* t_query_rel = + std::get_if(&hirm->schema.at(query_rel_name))) { + RelationVariant query_rel = hirm->get_relation(query_rel_name); + T_items base_items = std::visit( + [&](auto nr) { + using T = typename std::remove_pointer_t::ValueType; + auto noisy_rel = reinterpret_cast*>(nr); + return noisy_rel->get_base_items(items); + }, + query_rel); + + bool base_contains_items = std::visit( + [&](auto rel) { return rel->get_data().contains(base_items); }, + hirm->get_relation(t_query_rel->base_relation)); + if (!base_contains_items) { + hirm->sample_and_incorporate_relation(prng, t_query_rel->base_relation, + base_items); + } } hirm->incorporate(prng, query_rel_name, items, value); } -// Generates a vector of items from the clean relation domains, with the +// Generates a vector of items from the class' ancestors, with the // primary key (final item) equal to class_item. Items are looked up in the // global reference_values table or sampled from CRPs (and added to the // reference_values table/entity CRPs) if necessary. -std::vector GenDB::sample_class_ancestors( - std::mt19937* prng, const std::string& class_name, int class_item) { +T_items GenDB::sample_class_ancestors(std::mt19937* prng, + const std::string& class_name, + int class_item) { T_items items; PCleanClass c = schema.classes.at(class_name); @@ -140,12 +150,13 @@ std::vector GenDB::sample_class_ancestors( if (const ClassVar* cv = std::get_if(&(var.spec))) { // If the reference field isn't populated, sample a value from a CRP and // add it to reference_values. - if (!reference_values[class_name][class_item].contains(name)) { - sample_and_incorporate_reference(prng, class_name, class_item, name, - cv->class_name); + std::tuple ref_key = {class_name, name, + class_item}; + if (!reference_values.contains(ref_key)) { + sample_and_incorporate_reference(prng, ref_key, cv->class_name); } - T_items ref_items = sample_class_ancestors( - prng, cv->class_name, reference_values[class_name][class_item][name]); + T_items ref_items = sample_class_ancestors(prng, cv->class_name, + reference_values.at(ref_key)); items.insert(items.end(), ref_items.begin(), ref_items.end()); } } diff --git a/cxx/gendb.hh b/cxx/gendb.hh index ccffbb2..879f826 100644 --- a/cxx/gendb.hh +++ b/cxx/gendb.hh @@ -20,42 +20,48 @@ class GenDB { PCleanSchemaHelper schema_helper; // This data structure contains entity sets and linkages. Semantics are - // map>>, + // map ref_val>>, // where primary_key and ref_val are (integer) entity IDs. - std::map>> - reference_values; + std::map, int> reference_values; - HIRM* hirm; + HIRM* hirm; // Owned by the GenDB instance. - // CRPs for latent entities, where the "tables" are entity IDs and the - // "customers" are unique identifiers of observations of that class. Map - // keys are class names. + // Map keys are class names. Values are CRPs for latent entities, where the + // "tables" are entity IDs and the "customers" are unique identifiers of + // observations of that class. std::map domain_crps; GenDB(std::mt19937* prng, const PCleanSchema& schema, bool _only_final_emissions = false, bool _record_class_is_clean = true); + // Incorporates a row of observed data into the GenDB instance. void incorporate( std::mt19937* prng, const std::pair>& row); - void incorporate_query_relation(std::mt19937* prng, + // Incorporates a single element of a row of observed data. + void incorporate_query_relation(std::mt19937* prng, const std::string& query_rel, const T_items& items, const ObservationVariant& value); + // Samples a reference value and stores it in reference_values and the + // relevant domain CRP. void sample_and_incorporate_reference( - std::mt19937* prng, const std::string& class_name, int class_item, - const std::string& ref_field, const std::string& ref_class); + std::mt19937* prng, + const std::tuple& ref_key, + const std::string& ref_class); - std::vector sample_entities_relation( + // Samples a set of entities in the domains of the relation corresponding to + // class_path. + T_items sample_entities_relation( std::mt19937* prng, const std::string& class_name, std::vector::const_iterator class_path_start, - std::vector::const_iterator class_path_end, - int class_item); + std::vector::const_iterator class_path_end, int class_item); - std::vector sample_class_ancestors( - std::mt19937* prng, const std::string& class_name, int class_item); + // Sample items from a class' ancestors (recursive reference fields). + T_items sample_class_ancestors(std::mt19937* prng, + const std::string& class_name, int class_item); ~GenDB(); diff --git a/cxx/gendb_test.cc b/cxx/gendb_test.cc index 6a86375..fb433c8 100644 --- a/cxx/gendb_test.cc +++ b/cxx/gendb_test.cc @@ -70,18 +70,18 @@ BOOST_AUTO_TEST_CASE(test_gendb) { // Check that the structure of reference_values is as expected. // School and City are not contained in reference_values because they // have no reference fields. - BOOST_TEST(gendb.reference_values.size() == 3); - BOOST_TEST(gendb.reference_values.contains("Physician")); - BOOST_TEST(gendb.reference_values.contains("Practice")); - BOOST_TEST(gendb.reference_values.contains("Record")); - - BOOST_TEST(gendb.reference_values.at("Physician").at(0).size() == 1); - BOOST_TEST(gendb.reference_values.at("Physician").at(0).contains("school")); - BOOST_TEST(gendb.reference_values.at("Practice").at(0).size() == 1); - BOOST_TEST(gendb.reference_values.at("Practice").at(0).contains("city")); - BOOST_TEST(gendb.reference_values.at("Record").at(0).size() == 2); - BOOST_TEST(gendb.reference_values.at("Record").at(0).contains("physician")); - BOOST_TEST(gendb.reference_values.at("Record").at(0).contains("location")); + BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 0})); + BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 1})); + BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 2})); + BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 3})); + BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 4})); + BOOST_TEST(gendb.reference_values.contains({"Record", "location", 0})); + BOOST_TEST(gendb.reference_values.contains({"Record", "location", 1})); + BOOST_TEST(gendb.reference_values.contains({"Record", "location", 2})); + BOOST_TEST(gendb.reference_values.contains({"Record", "location", 3})); + BOOST_TEST(gendb.reference_values.contains({"Record", "location", 4})); + BOOST_TEST(gendb.reference_values.contains({"Physician", "school", 0})); + BOOST_TEST(gendb.reference_values.contains({"Practice", "city", 0})); auto get_relation_items = [&](auto rel) { std::unordered_set items; @@ -100,10 +100,10 @@ BOOST_AUTO_TEST_CASE(test_gendb) { BOOST_TEST(i.size() == 3); int index = i[2]; int expected_location = - gendb.reference_values.at("Record").at(index).at("location"); + gendb.reference_values.at({"Record", "location", index}); BOOST_TEST(expected_location == i[1]); int expected_city = - gendb.reference_values.at("Practice").at(expected_location).at("city"); + gendb.reference_values.at({"Practice", "city", expected_location}); BOOST_TEST(expected_city == i[0]); } @@ -115,7 +115,7 @@ BOOST_AUTO_TEST_CASE(test_gendb) { BOOST_TEST(i.size() == 2); int index = i[1]; int expected_school = - gendb.reference_values.at("Physician").at(index).at("school"); + gendb.reference_values.at({"Physician", "school", index}); BOOST_TEST(expected_school == i[0]); } @@ -127,7 +127,7 @@ BOOST_AUTO_TEST_CASE(test_gendb) { BOOST_TEST(i.size() == 2); int index = i[1]; int expected_school = - gendb.reference_values.at("Physician").at(index).at("school"); + gendb.reference_values.at({"Physician", "school", index}); BOOST_TEST(expected_school == i[0]); }