Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore pclean's ability to output its clusters #223

Merged
merged 4 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cxx/pclean/pclean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ int main(int argc, char** argv) {
if (result.count("output") > 0) {
std::string out_fn = result["output"].as<std::string>();
std::cout << "Savings results to " << out_fn << "\n";
// TODO(thomaswc): Fix this.
// to_txt(out_fn, gendb.hirm, encoding);
T_encoding encoding = make_dummy_encoding_from_gendb(gendb);
to_txt(out_fn, *(gendb.hirm), encoding);
}

std::string heldout_fn = result["heldout"].as<std::string>();
Expand Down
14 changes: 14 additions & 0 deletions cxx/pclean/pclean_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,17 @@ DataFrame make_pclean_samples(int num_samples, int start_row, GenDB *gendb,
return df;
}

// Builds a placeholder T_encoding for the entities currently in `gendb`.
//
// For each (domain, crp) entry of gendb.domain_crps, every index i in
// [0, crp.max_table()] receives the synthetic name "<domain>:<i>" in the
// reverse (code -> item) map.  The forward (item -> code) map is never
// populated here and is returned empty.  A domain only appears in the
// reverse map once it has at least one index to name.
T_encoding make_dummy_encoding_from_gendb(const GenDB& gendb) {
  T_encoding_f item_to_code;
  T_encoding_r code_to_item;

  for (const auto& [domain, crp] : gendb.domain_crps) {
    // Both values are loop-invariant for a given domain.
    const std::string prefix = domain + ":";
    const auto last_table = crp.max_table();

    for (int i = 0; i <= last_table; ++i) {
      // TODO: Make the auto-generated string include the row number
      // and CSV field name, for ease in debugging and visualizations.
      code_to_item[domain][i] = prefix + std::to_string(i);
    }
  }

  // Move the maps into the result; passing them as lvalues would deep-copy
  // every generated name.
  return std::make_pair(std::move(item_to_code), std::move(code_to_item));
}
4 changes: 4 additions & 0 deletions cxx/pclean/pclean_lib.hh
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ void incorporate_observations(std::mt19937* prng,
// All existing rows added to gendb should have ids < start_row.
DataFrame make_pclean_samples(int num_samples, int start_row, GenDB *gendb,
std::mt19937* prng);

// Makes a placeholder encoding from a GenDB.  For every domain in
// gendb.domain_crps, the i-th entity (0 <= i <= max_table() of that domain's
// CRP) is given the name "domain:i" in the code -> item half of the result
// (the second element of the returned T_encoding).  The item -> code half
// (the first element) is left empty.
T_encoding make_dummy_encoding_from_gendb(const GenDB& gendb);
82 changes: 82 additions & 0 deletions cxx/pclean/pclean_lib_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,85 @@ observe
BOOST_TEST(samples.data["City"].size() == 10);
BOOST_TEST(samples.data["State"].size() == 10);
}

// Exercises make_dummy_encoding_from_gendb on a GenDB as it grows from zero
// incorporated rows, to one row, to six rows.
BOOST_AUTO_TEST_CASE(test_make_dummy_encoding_from_gendb) {
  std::mt19937 prng;

  std::stringstream schema_text(R"""(
class School
name ~ string
degree_dist ~ categorical(k=100)

class Physician
school ~ School
degree ~ stringcat(strings="MD PT NP DO PHD")
specialty ~ stringcat(strings="Family Med:Internal Med:Physical Therapy", delim=":")
# observed_degree ~ maybe_swap(degree)

class City
name ~ string
state ~ stringcat(strings="AL AK AZ AR CA CO CT DE DC FL GA HI ID IL IN IA KS KY LA ME MD MA MI MN MS MO MT NE NV NH NJ NM NY NC ND OH OK OR PA RI SC SD TN TX UT VT VA WA WV WI WY")

class Practice
city ~ City

class Record
physician ~ Physician
location ~ Practice

observe
physician.specialty as Specialty
physician.school.name as School
physician.degree as Degree
location.city.name as City
location.city.state as State
from Record
)""");

  PCleanSchema schema;
  BOOST_TEST(read_schema(schema_text, &schema));

  GenDB gendb(&prng, schema);

  // Nothing has been incorporated yet, so no domain gets any names.
  T_encoding before_any_rows = make_dummy_encoding_from_gendb(gendb);
  BOOST_TEST(before_any_rows.second.empty());

  std::map<std::string, ObservationVariant> row = {
      {"Specialty", "Internal Med"},
      {"School", "Harvard"},
      {"Degree", "MD"},
      {"City", "Cambridge"},
      {"State", "MA"}};

  // After a single row, each domain checked below has an entity 0.
  gendb.incorporate(&prng, {0, row}, true);
  T_encoding after_one_row = make_dummy_encoding_from_gendb(gendb);

  BOOST_TEST(after_one_row.second["School"][0] == "School:0");
  BOOST_TEST(after_one_row.second["Physician"][0] == "Physician:0");
  BOOST_TEST(after_one_row.second["City"][0] == "City:0");
  BOOST_TEST(after_one_row.second["Practice"][0] == "Practice:0");

  BOOST_TEST(after_one_row.second["School"].size() == 1);

  // Add five more rows (ids 1 through 5) carrying the same observation.
  for (int row_id = 1; row_id < 6; ++row_id) {
    gendb.incorporate(&prng, {row_id, row}, true);
  }

  T_encoding after_six_rows = make_dummy_encoding_from_gendb(gendb);
  BOOST_TEST(after_six_rows.second["School"].size() == 6);
  for (int code = 0; code < 6; ++code) {
    const std::string expected_name = "School:" + std::to_string(code);
    BOOST_TEST(after_six_rows.second["School"][code] == expected_name);
  }

  // Test that we got all the entities.
  for (const auto& [domain, crp] : gendb.domain_crps) {
    for (int table = 0; table <= crp.max_table(); ++table) {
      BOOST_TEST(after_six_rows.second[domain].count(table) > 0);
    }
  }
}
2 changes: 1 addition & 1 deletion cxx/util_io.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

// Forward encoding: for each outer string key, maps an item's string name to
// its T_item code.
typedef std::map<std::string, std::map<std::string, T_item>> T_encoding_f;
// Reverse encoding: for each outer string key, maps a T_item code back to the
// item's string name.
typedef std::map<std::string, std::map<T_item, std::string>> T_encoding_r;
// NOTE(review): T_encoding is defined twice below with different types. This
// looks like the removed (std::tuple) and added (std::pair) lines of a diff
// shown together -- confirm that only the std::pair form is in the real file.
typedef std::tuple<T_encoding_f, T_encoding_r> T_encoding;
// Bundles the forward encoding (first) with the reverse encoding (second).
typedef std::pair<T_encoding_f, T_encoding_r> T_encoding;

// Load the schema file from path. Exits if the schema file can't be parsed.
T_schema load_schema(const std::string& path);
Expand Down