Skip to content

Commit

Permalink
Add annotated_domains and use it to reorder the domains of noisy_rela…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
ThomasColthurst committed Jul 31, 2024
1 parent 993044e commit 5cea10e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 18 deletions.
52 changes: 42 additions & 10 deletions cxx/pclean/schema_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,34 @@ void PCleanSchemaHelper::compute_class_name_cache() {
void PCleanSchemaHelper::compute_domains_cache() {
for (const auto& c: schema.classes) {
if (!domains.contains(c.name)) {
domains[c.name] = compute_domains_for(c.name);
compute_domains_for(c.name);
}
}
}

std::vector<std::string> PCleanSchemaHelper::compute_domains_for(
const std::string& name) {
void PCleanSchemaHelper::compute_domains_for(const std::string& name) {
std::vector<std::string> ds;
std::vector<std::string> annotated_ds;
ds.push_back(name);
annotated_ds.push_back(name);
PCleanClass c = get_class_by_name(name);

for (const auto& v: c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(v.spec))) {
if (!domains.contains(cv->class_name)) {
domains[cv->class_name] = compute_domains_for(cv->class_name);
compute_domains_for(cv->class_name);
}
for (const std::string& s : domains[cv->class_name]) {
ds.push_back(v.name + ':' + s);
ds.push_back(s);
}
for (const std::string& s : annotated_domains[cv->class_name]) {
annotated_ds.push_back(v.name + ':' + s);
}
}
}

return ds;
domains[name] = ds;
annotated_domains[name] = annotated_ds;
}

PCleanClass PCleanSchemaHelper::get_class_by_name(const std::string& name) {
Expand All @@ -47,18 +52,20 @@ PCleanClass PCleanSchemaHelper::get_class_by_name(const std::string& name) {
PCleanVariable PCleanSchemaHelper::get_scalarvar_from_path(
const PCleanClass& base_class,
std::vector<std::string>::const_iterator path_iterator,
std::string* final_class_name) {
std::string* final_class_name,
std::string* path_prefix) {
const std::string& s = *path_iterator;
for (const PCleanVariable& v : base_class.vars) {
if (v.name == s) {
if (std::holds_alternative<ScalarVar>(v.spec)) {
*final_class_name = base_class.name;
return v;
}
path_prefix->append(v.name + ":");
const PCleanClass& next_class = get_class_by_name(
std::get<ClassVar>(v.spec).class_name);
PCleanVariable sv = get_scalarvar_from_path(
next_class, ++path_iterator, final_class_name);
next_class, ++path_iterator, final_class_name, path_prefix);
return sv;
}
}
Expand All @@ -67,6 +74,26 @@ PCleanVariable PCleanSchemaHelper::get_scalarvar_from_path(
assert(false);
}

// Returns original_domains, but with the elements corresponding to
// annotated_ds elements that start with prefix moved to the front.
std::vector<std::string> reorder_domains(
const std::vector<std::string>& original_domains,
const std::vector<std::string>& annotated_ds,
const std::string& prefix) {
std::vector<std::string> output_domains;
for (size_t i = 0; i < original_domains.size(); ++i) {
if (annotated_ds[i].starts_with(prefix)) {
output_domains.push_back(original_domains[i]);
}
}
for (size_t i = 0; i < original_domains.size(); ++i) {
if (!annotated_ds[i].starts_with(prefix)) {
output_domains.push_back(original_domains[i]);
}
}
return output_domains;
}

T_schema PCleanSchemaHelper::make_hirm_schema() {
T_schema tschema;
for (const auto& c : schema.classes) {
Expand All @@ -81,11 +108,16 @@ T_schema PCleanSchemaHelper::make_hirm_schema() {
const PCleanClass query_class = get_class_by_name(schema.query.record_class);
for (const auto& f : schema.query.fields) {
std::string final_class_name;
std::string path_prefix;
const PCleanVariable sv = get_scalarvar_from_path(
query_class, f.class_path.cbegin(), &final_class_name);
query_class, f.class_path.cbegin(), &final_class_name, &path_prefix);
std::string base_relation = final_class_name + ':' + sv.name;
std::vector<std::string> reordered_domains = reorder_domains(
domains[query_class.name],
annotated_domains[query_class.name],
path_prefix);
tschema[f.name] = get_emission_relation(
std::get<ScalarVar>(sv.spec), domains[query_class.name], base_relation);
std::get<ScalarVar>(sv.spec), reordered_domains, base_relation);
}

return tschema;
Expand Down
6 changes: 4 additions & 2 deletions cxx/pclean/schema_helper.hh
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ class PCleanSchemaHelper {
void compute_class_name_cache();
void compute_domains_cache();

std::vector<std::string> compute_domains_for(const std::string& name);
void compute_domains_for(const std::string& name);

PCleanVariable get_scalarvar_from_path(
const PCleanClass& base_class,
std::vector<std::string>::const_iterator path_iterator,
std::string* final_class_name);
std::string* final_class_name,
std::string* path_prefix);

PCleanSchema schema;
std::map<std::string, int> class_name_to_index;
std::map<std::string, std::vector<std::string>> domains;
std::map<std::string, std::vector<std::string>> annotated_domains;
};
36 changes: 30 additions & 6 deletions cxx/pclean/schema_helper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,32 @@ BOOST_AUTO_TEST_CASE(test_domains_cache) {
PCleanSchemaHelper schema_helper(schema);

std::vector<std::string> expected_domains = {"School"};
std::vector<std::string> expected_annotated_domains = {"School"};
BOOST_TEST(schema_helper.domains["School"] == expected_domains);
BOOST_TEST(schema_helper.annotated_domains["School"] == expected_annotated_domains);

expected_domains = {"Physician", "school:School"};
expected_domains = {"Physician", "School"};
expected_annotated_domains = {"Physician", "school:School"};
BOOST_TEST(schema_helper.domains["Physician"] == expected_domains);
BOOST_TEST(schema_helper.annotated_domains["Physician"] == expected_annotated_domains);

expected_domains = {"City"};
expected_annotated_domains = {"City"};
BOOST_TEST(schema_helper.domains["City"] == expected_domains);
BOOST_TEST(schema_helper.annotated_domains["City"] == expected_annotated_domains);

expected_domains = {"Practice", "city:City"};
expected_domains = {"Practice", "City"};
expected_annotated_domains = {"Practice", "city:City"};
BOOST_TEST(schema_helper.domains["Practice"] == expected_domains);
BOOST_TEST(schema_helper.annotated_domains["Practice"] == expected_annotated_domains);

expected_domains = {
"Record", "Physician", "School", "Practice", "City"};
expected_annotated_domains = {
"Record", "physician:Physician", "physician:school:School",
"location:Practice", "location:city:City"};
BOOST_TEST(schema_helper.domains["Record"] == expected_domains);
BOOST_TEST(schema_helper.annotated_domains["Record"] == expected_annotated_domains);
}

BOOST_AUTO_TEST_CASE(test_domains_cache_two_paths_same_source) {
Expand All @@ -89,8 +100,11 @@ class Person
PCleanSchemaHelper schema_helper(schema);

std::vector<std::string> expected_domains = {
"Person", "City", "City"};
std::vector<std::string> expected_annotated_domains = {
"Person", "birth_city:City", "home_city:City"};
BOOST_TEST(schema_helper.domains["Person"] == expected_domains);
BOOST_TEST(schema_helper.annotated_domains["Person"] == expected_annotated_domains);
}

BOOST_AUTO_TEST_CASE(test_domains_cache_diamond) {
Expand All @@ -113,9 +127,12 @@ class Physician
PCleanSchemaHelper schema_helper(schema);

std::vector<std::string> expected_domains = {
"Physician", "Practice", "City", "School", "City"};
std::vector<std::string> expected_annotated_domains = {
"Physician", "practice:Practice", "practice:location:City",
"school:School", "school:location:City"};
BOOST_TEST(schema_helper.domains["Physician"] == expected_domains);
BOOST_TEST(schema_helper.annotated_domains["Physician"] == expected_annotated_domains);
}

BOOST_AUTO_TEST_CASE(test_make_hirm_schmea) {
Expand All @@ -138,7 +155,7 @@ BOOST_AUTO_TEST_CASE(test_make_hirm_schmea) {
BOOST_TEST(tschema.contains("Physician:degree"));
T_clean_relation cr3 = std::get<T_clean_relation>(tschema["Physician:degree"]);
BOOST_TEST((cr3.distribution_spec.distribution == DistributionEnum::stringcat));
std::vector<std::string> expected_domains2 = {"Physician", "school:School"};
std::vector<std::string> expected_domains2 = {"Physician", "School"};
BOOST_TEST(cr3.domains == expected_domains2);

BOOST_TEST(tschema.contains("Physician:specialty"));
Expand All @@ -154,33 +171,40 @@ BOOST_AUTO_TEST_CASE(test_make_hirm_schmea) {
T_noisy_relation nr1 = std::get<T_noisy_relation>(tschema["Specialty"]);
BOOST_TEST(!nr1.is_observed);
BOOST_TEST((nr1.emission_spec.emission == EmissionEnum::bigram_string));
expected_domains = {
"Record", "physician:Physician", "physician:school:School",
"location:Practice", "location:city:City"};
// "Physician", "School" moved to the front of the list.
expected_domains = {"Physician", "School", "Record", "Practice", "City"};
BOOST_TEST(nr1.domains == expected_domains);

BOOST_TEST(tschema.contains("School"));
T_noisy_relation nr2 = std::get<T_noisy_relation>(tschema["School"]);
BOOST_TEST(!nr2.is_observed);
BOOST_TEST((nr2.emission_spec.emission == EmissionEnum::bigram_string));
// "School" moved to the front of the list.
expected_domains = {"School", "Record", "Physician", "Practice", "City"};
BOOST_TEST(nr2.domains == expected_domains);

BOOST_TEST(tschema.contains("Degree"));
T_noisy_relation nr3 = std::get<T_noisy_relation>(tschema["Degree"]);
BOOST_TEST(!nr3.is_observed);
BOOST_TEST((nr3.emission_spec.emission == EmissionEnum::bigram_string));
// "Physician", "School" moved to the front of the list.
expected_domains = {"Physician", "School", "Record", "Practice", "City"};
BOOST_TEST(nr3.domains == expected_domains);

BOOST_TEST(tschema.contains("City"));
T_noisy_relation nr4 = std::get<T_noisy_relation>(tschema["City"]);
BOOST_TEST(!nr4.is_observed);
BOOST_TEST((nr4.emission_spec.emission == EmissionEnum::bigram_string));
// "City" moved to the front of the list.
expected_domains = {"City", "Record", "Physician", "School", "Practice"};
BOOST_TEST(nr4.domains == expected_domains);

BOOST_TEST(tschema.contains("State"));
T_noisy_relation nr5 = std::get<T_noisy_relation>(tschema["State"]);
BOOST_TEST(!nr5.is_observed);
BOOST_TEST((nr5.emission_spec.emission == EmissionEnum::bigram_string));
// "City" moved to the front of the list.
expected_domains = {"City", "Record", "Physician", "School", "Practice"};
BOOST_TEST(nr5.domains == expected_domains);
}

Expand Down

0 comments on commit 5cea10e

Please sign in to comment.