diff --git a/cxx/pclean/schema_helper.cc b/cxx/pclean/schema_helper.cc index 7b56b38..cd45f45 100644 --- a/cxx/pclean/schema_helper.cc +++ b/cxx/pclean/schema_helper.cc @@ -15,29 +15,34 @@ void PCleanSchemaHelper::compute_class_name_cache() { void PCleanSchemaHelper::compute_domains_cache() { for (const auto& c: schema.classes) { if (!domains.contains(c.name)) { - domains[c.name] = compute_domains_for(c.name); + compute_domains_for(c.name); } } } -std::vector PCleanSchemaHelper::compute_domains_for( - const std::string& name) { +void PCleanSchemaHelper::compute_domains_for(const std::string& name) { std::vector ds; + std::vector annotated_ds; ds.push_back(name); + annotated_ds.push_back(name); PCleanClass c = get_class_by_name(name); for (const auto& v: c.vars) { if (const ClassVar* cv = std::get_if(&(v.spec))) { if (!domains.contains(cv->class_name)) { - domains[cv->class_name] = compute_domains_for(cv->class_name); + compute_domains_for(cv->class_name); } for (const std::string& s : domains[cv->class_name]) { - ds.push_back(v.name + ':' + s); + ds.push_back(s); + } + for (const std::string& s : annotated_domains[cv->class_name]) { + annotated_ds.push_back(v.name + ':' + s); } } } - return ds; + domains[name] = ds; + annotated_domains[name] = annotated_ds; } PCleanClass PCleanSchemaHelper::get_class_by_name(const std::string& name) { @@ -47,7 +52,8 @@ PCleanClass PCleanSchemaHelper::get_class_by_name(const std::string& name) { PCleanVariable PCleanSchemaHelper::get_scalarvar_from_path( const PCleanClass& base_class, std::vector::const_iterator path_iterator, - std::string* final_class_name) { + std::string* final_class_name, + std::string* path_prefix) { const std::string& s = *path_iterator; for (const PCleanVariable& v : base_class.vars) { if (v.name == s) { @@ -55,10 +61,11 @@ PCleanVariable PCleanSchemaHelper::get_scalarvar_from_path( *final_class_name = base_class.name; return v; } + path_prefix->append(v.name + ":"); const PCleanClass& next_class = get_class_by_name( std::get(v.spec).class_name); PCleanVariable sv = get_scalarvar_from_path( - next_class, ++path_iterator, final_class_name); + next_class, ++path_iterator, final_class_name, path_prefix); return sv; } } @@ -67,6 +74,26 @@ PCleanVariable PCleanSchemaHelper::get_scalarvar_from_path( assert(false); } +// Returns original_domains, but with the elements corresponding to +// annotated_ds elements that start with prefix moved to the front. +std::vector reorder_domains( + const std::vector& original_domains, + const std::vector& annotated_ds, + const std::string& prefix) { + std::vector output_domains; + for (size_t i = 0; i < original_domains.size(); ++i) { + if (annotated_ds[i].starts_with(prefix)) { + output_domains.push_back(original_domains[i]); + } + } + for (size_t i = 0; i < original_domains.size(); ++i) { + if (!annotated_ds[i].starts_with(prefix)) { + output_domains.push_back(original_domains[i]); + } + } + return output_domains; +} + T_schema PCleanSchemaHelper::make_hirm_schema() { T_schema tschema; for (const auto& c : schema.classes) { @@ -81,11 +108,16 @@ T_schema PCleanSchemaHelper::make_hirm_schema() { const PCleanClass query_class = get_class_by_name(schema.query.record_class); for (const auto& f : schema.query.fields) { std::string final_class_name; + std::string path_prefix; const PCleanVariable sv = get_scalarvar_from_path( - query_class, f.class_path.cbegin(), &final_class_name); + query_class, f.class_path.cbegin(), &final_class_name, &path_prefix); std::string base_relation = final_class_name + ':' + sv.name; + std::vector reordered_domains = reorder_domains( + domains[query_class.name], + annotated_domains[query_class.name], + path_prefix); tschema[f.name] = get_emission_relation( - std::get(sv.spec), domains[query_class.name], base_relation); + std::get(sv.spec), reordered_domains, base_relation); } return tschema; diff --git a/cxx/pclean/schema_helper.hh b/cxx/pclean/schema_helper.hh index b7529cb..11a750d 100644 --- a/cxx/pclean/schema_helper.hh +++ b/cxx/pclean/schema_helper.hh @@ -25,14 +25,16 @@ class PCleanSchemaHelper { void compute_class_name_cache(); void compute_domains_cache(); - std::vector compute_domains_for(const std::string& name); + void compute_domains_for(const std::string& name); PCleanVariable get_scalarvar_from_path( const PCleanClass& base_class, std::vector::const_iterator path_iterator, - std::string* final_class_name); + std::string* final_class_name, + std::string* path_prefix); PCleanSchema schema; std::map class_name_to_index; std::map> domains; + std::map> annotated_domains; }; diff --git a/cxx/pclean/schema_helper_test.cc b/cxx/pclean/schema_helper_test.cc index 65fc1a9..747d060 100644 --- a/cxx/pclean/schema_helper_test.cc +++ b/cxx/pclean/schema_helper_test.cc @@ -58,21 +58,32 @@ BOOST_AUTO_TEST_CASE(test_domains_cache) { PCleanSchemaHelper schema_helper(schema); std::vector expected_domains = {"School"}; + std::vector expected_annotated_domains = {"School"}; BOOST_TEST(schema_helper.domains["School"] == expected_domains); + BOOST_TEST(schema_helper.annotated_domains["School"] == expected_annotated_domains); - expected_domains = {"Physician", "school:School"}; + expected_domains = {"Physician", "School"}; + expected_annotated_domains = {"Physician", "school:School"}; BOOST_TEST(schema_helper.domains["Physician"] == expected_domains); + BOOST_TEST(schema_helper.annotated_domains["Physician"] == expected_annotated_domains); expected_domains = {"City"}; + expected_annotated_domains = {"City"}; BOOST_TEST(schema_helper.domains["City"] == expected_domains); + BOOST_TEST(schema_helper.annotated_domains["City"] == expected_annotated_domains); - expected_domains = {"Practice", "city:City"}; + expected_domains = {"Practice", "City"}; + expected_annotated_domains = {"Practice", "city:City"}; BOOST_TEST(schema_helper.domains["Practice"] == expected_domains); + BOOST_TEST(schema_helper.annotated_domains["Practice"] == expected_annotated_domains); expected_domains = { + "Record", "Physician", "School", "Practice", "City"}; + expected_annotated_domains = { "Record", "physician:Physician", "physician:school:School", "location:Practice", "location:city:City"}; BOOST_TEST(schema_helper.domains["Record"] == expected_domains); + BOOST_TEST(schema_helper.annotated_domains["Record"] == expected_annotated_domains); } BOOST_AUTO_TEST_CASE(test_domains_cache_two_paths_same_source) { @@ -89,8 +100,11 @@ class Person PCleanSchemaHelper schema_helper(schema); std::vector expected_domains = { + "Person", "City", "City"}; + std::vector expected_annotated_domains = { "Person", "birth_city:City", "home_city:City"}; BOOST_TEST(schema_helper.domains["Person"] == expected_domains); + BOOST_TEST(schema_helper.annotated_domains["Person"] == expected_annotated_domains); } BOOST_AUTO_TEST_CASE(test_domains_cache_diamond) { @@ -113,9 +127,12 @@ class Physician PCleanSchemaHelper schema_helper(schema); std::vector expected_domains = { + "Physician", "Practice", "City", "School", "City"}; + std::vector expected_annotated_domains = { "Physician", "practice:Practice", "practice:location:City", "school:School", "school:location:City"}; BOOST_TEST(schema_helper.domains["Physician"] == expected_domains); + BOOST_TEST(schema_helper.annotated_domains["Physician"] == expected_annotated_domains); } BOOST_AUTO_TEST_CASE(test_make_hirm_schmea) { @@ -138,7 +155,7 @@ BOOST_AUTO_TEST_CASE(test_make_hirm_schmea) { BOOST_TEST(tschema.contains("Physician:degree")); T_clean_relation cr3 = std::get(tschema["Physician:degree"]); BOOST_TEST((cr3.distribution_spec.distribution == DistributionEnum::stringcat)); - std::vector expected_domains2 = {"Physician", "school:School"}; + std::vector expected_domains2 = {"Physician", "School"}; BOOST_TEST(cr3.domains == expected_domains2); BOOST_TEST(tschema.contains("Physician:specialty")); @@ -154,33 +171,40 @@ BOOST_AUTO_TEST_CASE(test_make_hirm_schmea) { T_noisy_relation nr1 = std::get(tschema["Specialty"]); BOOST_TEST(!nr1.is_observed); BOOST_TEST((nr1.emission_spec.emission == EmissionEnum::bigram_string)); - expected_domains = { - "Record", "physician:Physician", "physician:school:School", - "location:Practice", "location:city:City"}; + // "Physician", "School" moved to the front of the list. + expected_domains = {"Physician", "School", "Record", "Practice", "City"}; BOOST_TEST(nr1.domains == expected_domains); BOOST_TEST(tschema.contains("School")); T_noisy_relation nr2 = std::get(tschema["School"]); BOOST_TEST(!nr2.is_observed); BOOST_TEST((nr2.emission_spec.emission == EmissionEnum::bigram_string)); + // "School" moved to the front of the list. + expected_domains = {"School", "Record", "Physician", "Practice", "City"}; BOOST_TEST(nr2.domains == expected_domains); BOOST_TEST(tschema.contains("Degree")); T_noisy_relation nr3 = std::get(tschema["Degree"]); BOOST_TEST(!nr3.is_observed); BOOST_TEST((nr3.emission_spec.emission == EmissionEnum::bigram_string)); + // "Physician", "School" moved to the front of the list. + expected_domains = {"Physician", "School", "Record", "Practice", "City"}; BOOST_TEST(nr3.domains == expected_domains); BOOST_TEST(tschema.contains("City")); T_noisy_relation nr4 = std::get(tschema["City"]); BOOST_TEST(!nr4.is_observed); BOOST_TEST((nr4.emission_spec.emission == EmissionEnum::bigram_string)); + // "City" moved to the front of the list. + expected_domains = {"City", "Record", "Physician", "School", "Practice"}; BOOST_TEST(nr4.domains == expected_domains); BOOST_TEST(tschema.contains("State")); T_noisy_relation nr5 = std::get(tschema["State"]); BOOST_TEST(!nr5.is_observed); BOOST_TEST((nr5.emission_spec.emission == EmissionEnum::bigram_string)); + // "City" moved to the front of the list. + expected_domains = {"City", "Record", "Physician", "School", "Practice"}; BOOST_TEST(nr5.domains == expected_domains); }