Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Model6 #172

Merged
merged 15 commits into from
Aug 21, 2024
Merged
11 changes: 8 additions & 3 deletions cxx/distributions/beta_bernoulli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ void BetaBernoulli::transition_hyperparameters(std::mt19937* prng) {
}
}
}
int i = sample_from_logps(logps, prng);
alpha = hypers[i].first;
beta = hypers[i].second;
if (logps.empty()) {
printf("Warning! All hyperparameters for BetaBernoulli give nans!\n");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess practically inference gets stuck, because the hyperparameters don't move, but there might be some hope for other parameters to move and get unstuck? My thinking is these should be asserts since it would be hard to get out of here (you need the observations to change, so the cluster assignments to change) and highlights some suboptimality of inference that should be fixed somewhere else (either numerical stability, bad preconditioning of some sort, etc).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

assert(false);
} else {
int i = sample_from_logps(logps, prng);
alpha = hypers[i].first;
beta = hypers[i].second;
}
}
9 changes: 7 additions & 2 deletions cxx/distributions/bigram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ void Bigram::transition_hyperparameters(std::mt19937* prng) {
alphas.push_back(alphat);
}
}
int i = sample_from_logps(logps, prng);
set_alpha(alphas[i]);
if (logps.empty()) {
printf("Warning! All hyperparameters for Bigram give nans!\n");
assert(false);
} else {
int i = sample_from_logps(logps, prng);
set_alpha(alphas[i]);
}
}
9 changes: 7 additions & 2 deletions cxx/distributions/dirichlet_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ void DirichletCategorical::transition_hyperparameters(std::mt19937* prng) {
alphas.push_back(alpha);
}
}
int i = sample_from_logps(logps, prng);
alpha = alphas[i];
if (alphas.empty()) {
printf("Warning: all Dirichlet hyperparameters give nans!\n");
assert(false);
} else {
int i = sample_from_logps(logps, prng);
alpha = alphas[i];
}
}
16 changes: 11 additions & 5 deletions cxx/distributions/normal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "normal.hh"

#include <cassert>
#include <cmath>
#include <numbers>

Expand Down Expand Up @@ -91,9 +92,14 @@ void Normal::transition_hyperparameters(std::mt19937* prng) {
}
}

int i = sample_from_logps(logps, prng);
r = std::get<0>(hypers[i]);
v = std::get<1>(hypers[i]);
m = std::get<2>(hypers[i]);
s = std::get<3>(hypers[i]);
if (logps.empty()) {
printf("Warning! All hyperparameters for Normal give nans!\n");
assert(false);
} else {
int i = sample_from_logps(logps, prng);
r = std::get<0>(hypers[i]);
v = std::get<1>(hypers[i]);
m = std::get<2>(hypers[i]);
s = std::get<3>(hypers[i]);
}
}
16 changes: 11 additions & 5 deletions cxx/distributions/skellam.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,17 @@ void Skellam::transition_hyperparameters(std::mt19937* prng) {
}
}
}
int i = sample_from_logps(logps, prng);
mean1 = std::get<0>(hypers[i]);
stddev1 = std::get<1>(hypers[i]);
mean2 = std::get<2>(hypers[i]);
stddev2 = std::get<3>(hypers[i]);

if (logps.empty()) {
printf("Warning! All hyperparameters for Skellam gave nans!\n");
assert(false);
} else {
int i = sample_from_logps(logps, prng);
mean1 = std::get<0>(hypers[i]);
stddev1 = std::get<1>(hypers[i]);
mean2 = std::get<2>(hypers[i]);
stddev2 = std::get<3>(hypers[i]);
}
}

void Skellam::init_theta(std::mt19937* prng) {
Expand Down
12 changes: 9 additions & 3 deletions cxx/distributions/zero_mean_normal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "zero_mean_normal.hh"

#include <cassert>
#include <cmath>
#include <numbers>

Expand Down Expand Up @@ -73,7 +74,12 @@ void ZeroMeanNormal::transition_hyperparameters(std::mt19937* prng) {
}
}

int i = sample_from_logps(logps, prng);
alpha = std::get<0>(hypers[i]);
beta = std::get<1>(hypers[i]);
if (logps.empty()) {
printf("Warning! All hyperparameters for ZeroMeanNormal gave nans!\n");
assert(false);
} else {
int i = sample_from_logps(logps, prng);
alpha = std::get<0>(hypers[i]);
beta = std::get<1>(hypers[i]);
}
}
4 changes: 2 additions & 2 deletions cxx/integration_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ startt2=$(date +%s)
endt2=$(date +%s)
startt3=$(date +%s)
./bazel-bin/pclean/pclean --schema=assets/flights.schema --obs=assets/flights_dirty.10.csv --iters=5
./bazel-bin/pclean/pclean --schema=assets/hospitals.schema --obs=assets/hospital_dirty.10.csv --iters=5
./bazel-bin/pclean/pclean --schema=assets/rents.schema --obs=assets/rents_dirty.10.csv --iters=5
./bazel-bin/pclean/pclean --schema=assets/hospitals.schema --obs=assets/hospital_dirty.10.csv --iters=5 --only_final_emissions
./bazel-bin/pclean/pclean --schema=assets/rents.schema --obs=assets/rents_dirty.10.csv --iters=5 --record_class_is_clean
endt3=$(date +%s)
echo "Integration tests in /tests ran in $(($endt1-$startt1)) seconds"
echo "hirm integration tests ran in $(($endt2-$startt2)) seconds"
Expand Down
4 changes: 2 additions & 2 deletions cxx/pclean/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ bool read_class(std::istream& is, PCleanClass* pclass) {
return false;
}

pclass->vars.push_back(v);
pclass->vars[v.name] = v;
}
return true;
}
Expand Down Expand Up @@ -204,7 +204,7 @@ bool read_schema(std::istream& is, PCleanSchema* schema) {
if (!success) {
return false;
}
schema->classes.push_back(pcc);
schema->classes[pcc.name] = pcc;
continue;
}

Expand Down
21 changes: 10 additions & 11 deletions cxx/pclean/io_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,15 @@ observe
tt::per_element());

BOOST_TEST(schema.classes.size() == 5);
BOOST_TEST(schema.classes[0].name == "School");
BOOST_TEST(schema.classes[0].vars.size() == 2);
BOOST_TEST(schema.classes[0].vars[0].name == "name");
BOOST_TEST(std::get<ScalarVar>(schema.classes[0].vars[0].spec).joint_name == "bigram");
BOOST_TEST(schema.classes[0].vars[1].name == "degree_dist");
BOOST_TEST(std::get<ScalarVar>(schema.classes[0].vars[1].spec).joint_name == "categorical");
BOOST_TEST(std::get<ScalarVar>(schema.classes[0].vars[1].spec).params["num_classes"] == "100");
BOOST_TEST(schema.classes["School"].vars.size() == 2);
BOOST_TEST(schema.classes["School"].vars.contains("name"));
BOOST_TEST(std::get<ScalarVar>(schema.classes["School"].vars["name"].spec).joint_name == "bigram");
BOOST_TEST(schema.classes["School"].vars.contains("degree_dist"));
BOOST_TEST(std::get<ScalarVar>(schema.classes["School"].vars["degree_dist"].spec).joint_name == "categorical");
BOOST_TEST(std::get<ScalarVar>(schema.classes["School"].vars["degree_dist"].spec).params["num_classes"] == "100");

BOOST_TEST(schema.classes[1].name == "Physician");
BOOST_TEST(schema.classes[1].vars.size() == 3);
BOOST_TEST(schema.classes[1].vars[0].name == "school");
BOOST_TEST(std::get<ClassVar>(schema.classes[1].vars[0].spec).class_name == "School");
BOOST_TEST(schema.classes.contains("Physician"));
BOOST_TEST(schema.classes["Physician"].vars.size() == 3);
BOOST_TEST(schema.classes["Physician"].vars.contains("school"));
BOOST_TEST(std::get<ClassVar>(schema.classes["Physician"].vars["school"].spec).class_name == "School");
}
10 changes: 9 additions & 1 deletion cxx/pclean/pclean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ int main(int argc, char** argv) {
("i,iters", "Number of inference iterations",
cxxopts::value<int>()->default_value("10"))
("seed", "Random seed", cxxopts::value<int>()->default_value("10"))
("only_final_emissions", "Only create one layer of emissions",
cxxopts::value<bool>()->default_value("false"))
("record_class_is_clean",
"If false, model queries of the query class with emissions noise.",
cxxopts::value<bool>()->default_value("true"))
("t,timeout", "Timeout in seconds for inference",
cxxopts::value<int>()->default_value("0"))
("v,verbose", "Verbose output",
Expand Down Expand Up @@ -63,7 +68,10 @@ int main(int argc, char** argv) {

// Translate schema
std::cout << "Making schema helper ...\n";
PCleanSchemaHelper schema_helper(pclean_schema);
PCleanSchemaHelper schema_helper(
pclean_schema,
result["only_final_emissions"].as<bool>(),
result["record_class_is_clean"].as<bool>());
std::cout << "Translating schema ...\n";
T_schema hirm_schema = schema_helper.make_hirm_schema();

Expand Down
6 changes: 4 additions & 2 deletions cxx/pclean/schema.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ struct PCleanVariable {

struct PCleanClass {
std::string name;
std::vector<PCleanVariable> vars;
// Map from variable name to variable.
std::map<std::string, PCleanVariable> vars;
// TODO(thomaswc): Figure out how to handle class level configurations.
};

Expand All @@ -43,6 +44,7 @@ struct PCleanQuery {
};

struct PCleanSchema {
std::vector<PCleanClass> classes;
// Map from class name to class.
std::map<std::string, PCleanClass> classes;
PCleanQuery query;
};
Loading