Skip to content

Commit

Permalink
Merge pull request #182 from probcomp/082324-thomaswc-pclean_entities
Browse files Browse the repository at this point in the history
Fix translate_observation's entity assignments
  • Loading branch information
ThomasColthurst authored Aug 23, 2024
2 parents f7a39b8 + e511359 commit 3d61190
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 20 deletions.
7 changes: 5 additions & 2 deletions cxx/pclean/pclean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ int main(int argc, char** argv) {
result["only_final_emissions"].as<bool>(),
result["record_class_is_clean"].as<bool>());
std::cout << "Translating schema ...\n";
T_schema hirm_schema = schema_helper.make_hirm_schema();
std::map<std::string, std::vector<std::string>> annotated_domains_for_relations;
T_schema hirm_schema = schema_helper.make_hirm_schema(
&annotated_domains_for_relations);

// Read observations
std::cout << "Reading observations ...\n";
Expand All @@ -87,7 +89,8 @@ int main(int argc, char** argv) {

// Incorporate observations.
std::cout << "Translating observations ...\n";
T_observations observations = translate_observations(df, hirm_schema);
T_observations observations = translate_observations(
df, hirm_schema, annotated_domains_for_relations);
std::cout << "Encoding observations ...\n";
T_encoding encoding = encode_observations(hirm_schema, observations);
std::cout << "Incorporating observations ...\n";
Expand Down
9 changes: 5 additions & 4 deletions cxx/pclean/pclean_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include "pclean/pclean_lib.hh"

T_observations translate_observations(
const DataFrame& df, const T_schema &schema) {
const DataFrame& df, const T_schema &schema,
const std::map<std::string, std::vector<std::string>>
&annotated_domains_for_relations) {
T_observations obs;

for (const auto& col : df.data) {
Expand Down Expand Up @@ -44,12 +46,11 @@ T_observations translate_observations(
std::vector<std::string> entities;
for (size_t j = 0; j < num_domains; ++j) {
// Give every row its own universe of unique ids.
// TODO(thomaswc): Correctly handle the case when a row makes
// references to two or more different entities of the same type.
// TODO(thomaswc): Discuss other options for handling this, such
// as sampling the non-index domains from a CRP prior or specifying
// additional CSV columns to use as foreign keys.
entities.push_back(std::to_string(i));
entities.push_back(annotated_domains_for_relations.at(col_name)[j]
+ ":" + std::to_string(i));
}
obs[col_name].push_back(std::make_tuple(entities, val));
}
Expand Down
4 changes: 3 additions & 1 deletion cxx/pclean/pclean_lib.hh
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@
// is used as the relation name, and each entity in each domain is given
// its own unique value.
T_observations translate_observations(
const DataFrame& df, const T_schema &schema);
const DataFrame& df, const T_schema &schema,
const std::map<std::string, std::vector<std::string>>
&annotated_domains_for_relation);
9 changes: 8 additions & 1 deletion cxx/pclean/pclean_lib_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@ BOOST_AUTO_TEST_CASE(test_translate_observations) {
{"State",
T_noisy_relation{{"dCounty", "dObs"}, true, EmissionSpec("bigram"), "County:state"}}};

T_observations obs = translate_observations(df, schema);
std::map<std::string, std::vector<std::string>> annotated_domains_for_relations;
annotated_domains_for_relations["Room Type"] = {"county:County", "Obs"};
annotated_domains_for_relations["Monthly Rent"] = {"county:County", "Obs"};
annotated_domains_for_relations["County"] = {"county:County", "Obs"};
annotated_domains_for_relations["State"] = {"county:County", "Obs"};

T_observations obs = translate_observations(
df, schema, annotated_domains_for_relations);

// Relations not corresponding to columns should be un-observed.
BOOST_TEST(!obs.contains("County:name"));
Expand Down
26 changes: 23 additions & 3 deletions cxx/pclean/schema_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ std::string make_prefix_path(
}

void PCleanSchemaHelper::make_relations_for_queryfield(
const QueryField& f, const PCleanClass& record_class, T_schema* tschema) {
const QueryField& f, const PCleanClass& record_class, T_schema* tschema,
std::map<std::string, std::vector<std::string>>
*annotated_domains_for_relation) {
// First, find all the vars and classes specified in f.class_path.
std::vector<std::string> var_names;
std::vector<std::string> class_names;
Expand Down Expand Up @@ -88,13 +90,17 @@ void PCleanSchemaHelper::make_relations_for_queryfield(
cr.is_observed = true;
(*tschema)[f.name] = cr;
tschema->erase(base_relation_name);
(*annotated_domains_for_relation)[f.name] = annotated_domains[
record_class.name];
} else {
T_noisy_relation tnr = get_emission_relation(
std::get<ScalarVar>(last_var.spec),
domains[record_class.name],
base_relation_name);
tnr.is_observed = true;
(*tschema)[f.name] = tnr;
(*annotated_domains_for_relation)[f.name] = annotated_domains[
record_class.name];
}
return;
}
Expand All @@ -118,6 +124,11 @@ void PCleanSchemaHelper::make_relations_for_queryfield(
base_relation_name);
tnr.is_observed = true;
(*tschema)[f.name] = tnr;
std::vector<std::string> reordered_annotated_domains = reorder_domains(
annotated_domains[record_class.name],
annotated_domains[record_class.name],
path_prefix);
(*annotated_domains_for_relation)[f.name] = reordered_annotated_domains;
return;
}

Expand Down Expand Up @@ -148,6 +159,11 @@ void PCleanSchemaHelper::make_relations_for_queryfield(
(*tschema)[rel_name] = tnr;
}
previous_relation = rel_name;
std::vector<std::string> reordered_annotated_domains = reorder_domains(
annotated_domains[class_names[i]],
annotated_domains[class_names[i]],
path_prefix);
(*annotated_domains_for_relation)[rel_name] = reordered_annotated_domains;
}
}

Expand All @@ -169,7 +185,9 @@ std::vector<std::string> reorder_domains(
return output_domains;
}

T_schema PCleanSchemaHelper::make_hirm_schema() {
T_schema PCleanSchemaHelper::make_hirm_schema(
std::map<std::string, std::vector<std::string>>
*annotated_domains_for_relation) {
T_schema tschema;

// For every scalar variable, make a clean relation with the name
Expand All @@ -179,6 +197,7 @@ T_schema PCleanSchemaHelper::make_hirm_schema() {
std::string rel_name = c.first + ':' + v.first;
if (const ScalarVar* dv = std::get_if<ScalarVar>(&(v.second.spec))) {
tschema[rel_name] = get_distribution_relation(*dv, domains[c.first]);
(*annotated_domains_for_relation)[rel_name] = annotated_domains[c.first];
}
}
}
Expand All @@ -188,7 +207,8 @@ T_schema PCleanSchemaHelper::make_hirm_schema() {
// to the name of the QueryField.
const PCleanClass record_class = schema.classes[schema.query.record_class];
for (const QueryField& f : schema.query.fields) {
make_relations_for_queryfield(f, record_class, &tschema);
make_relations_for_queryfield(f, record_class, &tschema,
annotated_domains_for_relation);
}

return tschema;
Expand Down
14 changes: 11 additions & 3 deletions cxx/pclean/schema_helper.hh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ class PCleanSchemaHelper {
bool _only_final_emissions = false,
bool _record_class_is_clean = true);

T_schema make_hirm_schema();
// Translate the PCleanSchema into an HIRM T_schema.
// Also, fill annotated_domains_for_relation[r] with the vector of
// annotated domains for the relation r.
T_schema make_hirm_schema(
std::map<std::string, std::vector<std::string>>
*annotated_domains_for_relation);

// The rest of these methods are conceptually private, but actually
// public for testing.
Expand All @@ -27,8 +32,11 @@ class PCleanSchemaHelper {
void compute_domains_for(const std::string& name);

void make_relations_for_queryfield(
const QueryField& f, const PCleanClass& c,
T_schema* schema);
const QueryField& f,
const PCleanClass& c,
T_schema* schema,
std::map<std::string, std::vector<std::string>>
*annotated_domains_for_relation);

PCleanSchema schema;
bool only_final_emissions;
Expand Down
33 changes: 27 additions & 6 deletions cxx/pclean/schema_helper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,31 +146,43 @@ BOOST_AUTO_TEST_CASE(test_make_relations_for_queryfield) {
T_schema tschema;

PCleanClass query_class = schema.classes[schema.query.record_class];
std::map<std::string, std::vector<std::string>> annotated_domains_for_relation;
schema_helper.make_relations_for_queryfield(
schema.query.fields[1], query_class, &tschema);
schema.query.fields[1], query_class, &tschema,
&annotated_domains_for_relation);

BOOST_TEST(tschema.size() == 2);
BOOST_TEST(tschema.contains("School"));
BOOST_TEST(tschema.contains("Physician:school::School:name"));
BOOST_TEST(std::get<T_noisy_relation>(tschema["School"]).is_observed);
BOOST_TEST(!std::get<T_noisy_relation>(tschema["Physician:school::School:name"]).is_observed);

std::vector<std::string> expected_adfr = {
"physician:school:School", "location:city:City",
"location:Practice", "physician:Physician", "Record"};
BOOST_TEST(annotated_domains_for_relation["School"] == expected_adfr,
tt::per_element());
}

BOOST_AUTO_TEST_CASE(test_make_relations_for_queryfield_only_final_emissions) {
PCleanSchemaHelper schema_helper(schema, true);
T_schema tschema;

PCleanClass query_class = schema.classes[schema.query.record_class];
std::map<std::string, std::vector<std::string>> annotated_domains_for_relation;
schema_helper.make_relations_for_queryfield(
schema.query.fields[1], query_class, &tschema);
schema.query.fields[1], query_class, &tschema,
&annotated_domains_for_relation);

BOOST_TEST(tschema.size() == 1);
BOOST_TEST(tschema.contains("School"));
}

BOOST_AUTO_TEST_CASE(test_make_hirm_schmea) {
PCleanSchemaHelper schema_helper(schema);
T_schema tschema = schema_helper.make_hirm_schema();
std::map<std::string, std::vector<std::string>> annotated_domains_for_relation;
T_schema tschema = schema_helper.make_hirm_schema(
&annotated_domains_for_relation);

BOOST_TEST(tschema.contains("School:name"));
T_clean_relation cr = std::get<T_clean_relation>(tschema["School:name"]);
Expand Down Expand Up @@ -247,7 +259,9 @@ BOOST_AUTO_TEST_CASE(test_make_hirm_schmea) {

BOOST_AUTO_TEST_CASE(test_make_hirm_schema_only_final_emissions) {
PCleanSchemaHelper schema_helper(schema, true);
T_schema tschema = schema_helper.make_hirm_schema();
std::map<std::string, std::vector<std::string>> annotated_domains_for_relation;
T_schema tschema = schema_helper.make_hirm_schema(
&annotated_domains_for_relation);

BOOST_TEST(tschema.contains("School:name"));
T_clean_relation cr = std::get<T_clean_relation>(tschema["School:name"]);
Expand Down Expand Up @@ -357,7 +371,9 @@ observe
assert(ok);

PCleanSchemaHelper schema_helper(schema2, false, true);
T_schema tschema = schema_helper.make_hirm_schema();
std::map<std::string, std::vector<std::string>> annotated_domains_for_relation;
T_schema tschema = schema_helper.make_hirm_schema(
&annotated_domains_for_relation);

BOOST_TEST(!tschema.contains("Record:rent"));
BOOST_TEST(tschema.contains("Rent"));
Expand All @@ -380,7 +396,9 @@ observe
assert(ok);

PCleanSchemaHelper schema_helper(schema2, false, false);
T_schema tschema = schema_helper.make_hirm_schema();
std::map<std::string, std::vector<std::string>> annotated_domains_for_relation;
T_schema tschema = schema_helper.make_hirm_schema(
&annotated_domains_for_relation);

BOOST_TEST(tschema.contains("Record:rent"));
BOOST_TEST(tschema.contains("Rent"));
Expand All @@ -389,6 +407,9 @@ observe
BOOST_TEST(!cr.is_observed);
T_noisy_relation nr = std::get<T_noisy_relation>(tschema["Rent"]);
BOOST_TEST(nr.is_observed);

std::vector<std::string> expected_adfr = {"Record"};
BOOST_TEST(annotated_domains_for_relation["Rent"] == expected_adfr);
}

BOOST_AUTO_TEST_SUITE_END()

0 comments on commit 3d61190

Please sign in to comment.