Skip to content

Commit

Permalink
Respond to reviewer comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyfertig committed Sep 18, 2024
1 parent 5fb5c96 commit 4028c89
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 89 deletions.
4 changes: 2 additions & 2 deletions cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ cc_library(
)

cc_library(
name = "gendb_lib",
name = "gendb",
hdrs = ["gendb.hh"],
srcs = ["gendb.cc"],
visibility = [":__subpackages__"],
Expand Down Expand Up @@ -219,7 +219,7 @@ cc_test(
name = "gendb_test",
srcs = ["gendb_test.cc"],
deps = [
":gendb_lib",
":gendb",
"@boost//:test",
],
)
Expand Down
123 changes: 67 additions & 56 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <map>
#include <random>
#include <string>
#include <variant>

#include "distributions/crp.hh"
#include "hirm.hh"
Expand Down Expand Up @@ -42,7 +43,7 @@ void GenDB::incorporate(
// Sample a set of items to be incorporated into the query relation.
const std::vector<std::string>& class_path =
schema.query.fields.at(query_rel).class_path;
std::vector<int> items =
T_items items =
sample_entities_relation(prng, schema.query.record_class,
class_path.cbegin(), class_path.cend(), id);

Expand All @@ -53,51 +54,58 @@ void GenDB::incorporate(

// This function walks the class_path of the query, populates the global
// reference_values table if necessary, and returns a sampled set of items
// for the query relation.
std::vector<int> GenDB::sample_entities_relation(
// for the query relation that corresponds to the class path. class_path_start
// points to the name of an attribute of the class named class_name.
T_items GenDB::sample_entities_relation(
std::mt19937* prng, const std::string& class_name,
std::vector<std::string>::const_iterator class_path_start,
std::vector<std::string>::const_iterator class_path_end, int class_item) {
if (class_path_end - class_path_start == 1) {
// These are domains and we need to DFS-traverse the class's
// The last item in class_path is the class from which the queried attribute
// is observed (for which there's a corresponding clean relation, observing
// the attribute from the class). We need to DFS-traverse the class's
// parents, similar to PCleanSchemaHelper::compute_domains_for.
return sample_class_ancestors(prng, class_name, class_item);
} else {
// These are noisy relation domains along the path from the latent cleanly-
// observed class to the record class.
std::string ref_field = *class_path_start;

// If the reference field isn't populated, sample a value from a CRP and
// add it to reference_values.
std::string ref_class =
std::get<ClassVar>(
schema.classes.at(class_name).vars.at(ref_field).spec)
.class_name;
if (!reference_values[class_name].contains(class_item)) {
sample_and_incorporate_reference(prng, class_name, class_item, ref_field,
ref_class);
}
std::vector<int> items = sample_entities_relation(
prng, ref_class, ++class_path_start, class_path_end,
reference_values[class_name][class_item][ref_field]);
items.push_back(class_item);
return items;
}

// These are noisy relation domains along the path from the latent cleanly-
// observed class to the record class.
std::string ref_field = *class_path_start;

// If the reference field isn't populated, sample a value from a CRP and
// add it to reference_values.
std::string ref_class =
std::get<ClassVar>(schema.classes.at(class_name).vars.at(ref_field).spec)
.class_name;
std::tuple<std::string, std::string, int> ref_key = {class_name, ref_field,
class_item};
if (!reference_values.contains(ref_key)) {
sample_and_incorporate_reference(prng, ref_key, ref_class);
}
T_items items =
sample_entities_relation(prng, ref_class, ++class_path_start,
class_path_end, reference_values.at(ref_key));
// The order of the items corresponds to the order of the relation's domains,
// with the class (domain) corresponding to the primary key placed last in the
// list.
items.push_back(class_item);
return items;
}

void GenDB::sample_and_incorporate_reference(std::mt19937* prng,
const std::string& class_name,
int class_item,
const std::string& ref_field,
const std::string& ref_class) {
void GenDB::sample_and_incorporate_reference(
std::mt19937* prng,
const std::tuple<std::string, std::string, int>& ref_key,
const std::string& ref_class) {
auto [class_name, ref_field, class_item] = ref_key;
int new_val = domain_crps[ref_class].sample(prng);

// Generate a unique ID for the sample and incorporate it into the
// domain CRP.
std::stringstream new_id_str;
new_id_str << class_name << class_item << ref_field;
std::string sep = " "; // Spaces are disallowed in class/variable names.
new_id_str << class_name << sep << class_item << sep << ref_field;
int new_id = std::hash<std::string>{}(new_id_str.str());
reference_values[class_name][class_item][ref_field] = new_val;
reference_values[ref_key] = new_val;
domain_crps[ref_class].incorporate(new_id, new_val);
}

Expand All @@ -106,46 +114,49 @@ void GenDB::incorporate_query_relation(std::mt19937* prng,
const std::string& query_rel_name,
const T_items& items,
const ObservationVariant& value) {
RelationVariant query_rel = hirm->get_relation(query_rel_name);
T_items base_items = std::visit(
[&](auto nr) {
using T = typename std::remove_pointer_t<decltype(nr)>::ValueType;
auto noisy_rel = reinterpret_cast<NoisyRelation<T>*>(nr);
return noisy_rel->get_base_items(items);
},
query_rel);

T_noisy_relation t_query_rel =
std::get<T_noisy_relation>(hirm->schema.at(query_rel_name));
bool base_contains_items = std::visit(
[&](auto rel) { return rel->get_data().contains(base_items); },
hirm->get_relation(t_query_rel.base_relation));
if (!base_contains_items) {
hirm->sample_and_incorporate_relation(prng, t_query_rel.base_relation,
base_items);
if (const T_noisy_relation* t_query_rel =
std::get_if<T_noisy_relation>(&hirm->schema.at(query_rel_name))) {
RelationVariant query_rel = hirm->get_relation(query_rel_name);
T_items base_items = std::visit(
[&](auto nr) {
using T = typename std::remove_pointer_t<decltype(nr)>::ValueType;
auto noisy_rel = reinterpret_cast<NoisyRelation<T>*>(nr);
return noisy_rel->get_base_items(items);
},
query_rel);

bool base_contains_items = std::visit(
[&](auto rel) { return rel->get_data().contains(base_items); },
hirm->get_relation(t_query_rel->base_relation));
if (!base_contains_items) {
hirm->sample_and_incorporate_relation(prng, t_query_rel->base_relation,
base_items);
}
}
hirm->incorporate(prng, query_rel_name, items, value);
}

// Generates a vector of items from the clean relation domains, with the
// Generates a vector of items from the class' ancestors, with the
// primary key (final item) equal to class_item. Items are looked up in the
// global reference_values table or sampled from CRPs (and added to the
// reference_values table/entity CRPs) if necessary.
std::vector<int> GenDB::sample_class_ancestors(
std::mt19937* prng, const std::string& class_name, int class_item) {
T_items GenDB::sample_class_ancestors(std::mt19937* prng,
const std::string& class_name,
int class_item) {
T_items items;
PCleanClass c = schema.classes.at(class_name);

for (const auto& [name, var] : c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(var.spec))) {
// If the reference field isn't populated, sample a value from a CRP and
// add it to reference_values.
if (!reference_values[class_name][class_item].contains(name)) {
sample_and_incorporate_reference(prng, class_name, class_item, name,
cv->class_name);
std::tuple<std::string, std::string, int> ref_key = {class_name, name,
class_item};
if (!reference_values.contains(ref_key)) {
sample_and_incorporate_reference(prng, ref_key, cv->class_name);
}
T_items ref_items = sample_class_ancestors(
prng, cv->class_name, reference_values[class_name][class_item][name]);
T_items ref_items = sample_class_ancestors(prng, cv->class_name,
reference_values.at(ref_key));
items.insert(items.end(), ref_items.begin(), ref_items.end());
}
}
Expand Down
36 changes: 21 additions & 15 deletions cxx/gendb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,48 @@ class GenDB {
PCleanSchemaHelper schema_helper;

// This data structure contains entity sets and linkages. Semantics are
// map<class_name, map<primary_key, map<reference_field_name, ref_val>>>,
// map<tuple<class_name, reference_field_name, class_primary_key>, ref_val>,
// where class_primary_key and ref_val are (integer) entity IDs.
std::map<std::string, std::map<int, std::map<std::string, int>>>
reference_values;
std::map<std::tuple<std::string, std::string, int>, int> reference_values;

HIRM* hirm;
HIRM* hirm; // Owned by the GenDB instance.

// CRPs for latent entities, where the "tables" are entity IDs and the
// "customers" are unique identifiers of observations of that class. Map
// keys are class names.
// Map keys are class names. Values are CRPs for latent entities, where the
// "tables" are entity IDs and the "customers" are unique identifiers of
// observations of that class.
std::map<std::string, CRP> domain_crps;

GenDB(std::mt19937* prng, const PCleanSchema& schema,
bool _only_final_emissions = false, bool _record_class_is_clean = true);

// Incorporates a row of observed data into the GenDB instance.
void incorporate(
std::mt19937* prng,
const std::pair<int, std::map<std::string, ObservationVariant>>& row);

void incorporate_query_relation(std::mt19937* prng,
// Incorporates a single element of a row of observed data.
void incorporate_query_relation(std::mt19937* prng,
const std::string& query_rel,
const T_items& items,
const ObservationVariant& value);

// Samples a reference value and stores it in reference_values and the
// relevant domain CRP.
void sample_and_incorporate_reference(
std::mt19937* prng, const std::string& class_name, int class_item,
const std::string& ref_field, const std::string& ref_class);
std::mt19937* prng,
const std::tuple<std::string, std::string, int>& ref_key,
const std::string& ref_class);

std::vector<int> sample_entities_relation(
// Samples a set of entities in the domains of the relation corresponding to
// class_path.
T_items sample_entities_relation(
std::mt19937* prng, const std::string& class_name,
std::vector<std::string>::const_iterator class_path_start,
std::vector<std::string>::const_iterator class_path_end,
int class_item);
std::vector<std::string>::const_iterator class_path_end, int class_item);

std::vector<int> sample_class_ancestors(
std::mt19937* prng, const std::string& class_name, int class_item);
// Sample items from a class' ancestors (recursive reference fields).
T_items sample_class_ancestors(std::mt19937* prng,
const std::string& class_name, int class_item);

~GenDB();

Expand Down
32 changes: 16 additions & 16 deletions cxx/gendb_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ BOOST_AUTO_TEST_CASE(test_gendb) {
// Check that the structure of reference_values is as expected.
// School and City are not contained in reference_values because they
// have no reference fields.
BOOST_TEST(gendb.reference_values.size() == 3);
BOOST_TEST(gendb.reference_values.contains("Physician"));
BOOST_TEST(gendb.reference_values.contains("Practice"));
BOOST_TEST(gendb.reference_values.contains("Record"));

BOOST_TEST(gendb.reference_values.at("Physician").at(0).size() == 1);
BOOST_TEST(gendb.reference_values.at("Physician").at(0).contains("school"));
BOOST_TEST(gendb.reference_values.at("Practice").at(0).size() == 1);
BOOST_TEST(gendb.reference_values.at("Practice").at(0).contains("city"));
BOOST_TEST(gendb.reference_values.at("Record").at(0).size() == 2);
BOOST_TEST(gendb.reference_values.at("Record").at(0).contains("physician"));
BOOST_TEST(gendb.reference_values.at("Record").at(0).contains("location"));
BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 0}));
BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 1}));
BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 2}));
BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 3}));
BOOST_TEST(gendb.reference_values.contains({"Record", "physician", 4}));
BOOST_TEST(gendb.reference_values.contains({"Record", "location", 0}));
BOOST_TEST(gendb.reference_values.contains({"Record", "location", 1}));
BOOST_TEST(gendb.reference_values.contains({"Record", "location", 2}));
BOOST_TEST(gendb.reference_values.contains({"Record", "location", 3}));
BOOST_TEST(gendb.reference_values.contains({"Record", "location", 4}));
BOOST_TEST(gendb.reference_values.contains({"Physician", "school", 0}));
BOOST_TEST(gendb.reference_values.contains({"Practice", "city", 0}));

auto get_relation_items = [&](auto rel) {
std::unordered_set<T_items, H_items> items;
Expand All @@ -100,10 +100,10 @@ BOOST_AUTO_TEST_CASE(test_gendb) {
BOOST_TEST(i.size() == 3);
int index = i[2];
int expected_location =
gendb.reference_values.at("Record").at(index).at("location");
gendb.reference_values.at({"Record", "location", index});
BOOST_TEST(expected_location == i[1]);
int expected_city =
gendb.reference_values.at("Practice").at(expected_location).at("city");
gendb.reference_values.at({"Practice", "city", expected_location});
BOOST_TEST(expected_city == i[0]);
}

Expand All @@ -115,7 +115,7 @@ BOOST_AUTO_TEST_CASE(test_gendb) {
BOOST_TEST(i.size() == 2);
int index = i[1];
int expected_school =
gendb.reference_values.at("Physician").at(index).at("school");
gendb.reference_values.at({"Physician", "school", index});
BOOST_TEST(expected_school == i[0]);
}

Expand All @@ -127,7 +127,7 @@ BOOST_AUTO_TEST_CASE(test_gendb) {
BOOST_TEST(i.size() == 2);
int index = i[1];
int expected_school =
gendb.reference_values.at("Physician").at(index).at("school");
gendb.reference_values.at({"Physician", "school", index});
BOOST_TEST(expected_school == i[0]);
}

Expand Down

0 comments on commit 4028c89

Please sign in to comment.