Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add likelihood and posterior predictive tests for DirichletCategorical. #39

Merged
merged 2 commits (branch names not captured in this extraction)
Jun 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 83 additions & 9 deletions cxx/distributions/dirichlet_categorical_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,94 @@ BOOST_AUTO_TEST_CASE(test_matches_beta_bernoulli) {
BOOST_TEST(dc.logp_score() == bb.logp_score(), tt::tolerance(1e-6));
}

BOOST_AUTO_TEST_CASE(test_simple) {
BOOST_AUTO_TEST_CASE(test_logp_score) {
// Sample from a Dirichlet and use that to create a MC estimate of the log
// prob.
std::mt19937 prng;
DirichletCategorical dc(&prng, 10);
std::mt19937 prng2;

for (int i = 0; i < 10; ++i) {
dc.incorporate(i);
DirichletCategorical dc(&prng, 5);
srvasude marked this conversation as resolved.
Show resolved Hide resolved

int num_samples = 1000;
std::vector<std::vector<double>> dirichlet_samples;
std::gamma_distribution<> gamma_dist(dc.alpha, 1.);
for (int i = 0; i < num_samples; ++i) {
std::vector<double> sample;
for (int j = 0; j < 5; ++j) {
sample.emplace_back(gamma_dist(prng2));
}
double sum_of_elements = std::accumulate(sample.begin(), sample.end(), 0.);
for (int j = 0; j < 5; ++j) {
sample[j] /= sum_of_elements;
}
dirichlet_samples.emplace_back(sample);
}
for (int i = 0; i < 10; i += 2) {
dc.unincorporate(i);

for (int i = 0; i < 20; ++i) {
dc.incorporate(i % 5);
double average_prob = 0;
for (const auto& sample : dirichlet_samples) {
double prob = 1;
for (int j = 0; j <= i; ++j) {
prob *= sample[j % 5];
}
average_prob += prob;
}
average_prob /= num_samples;
BOOST_TEST(dc.logp_score() == log(average_prob), tt::tolerance(8e-3));
}
}

BOOST_TEST(dc.N == 5);
BOOST_TEST(dc.logp(1) == -2.0149030205422647, tt::tolerance(1e-6));
BOOST_TEST(dc.logp_score() == -12.389393702657209, tt::tolerance(1e-6));
BOOST_AUTO_TEST_CASE(test_logp) {
  // Monte Carlo check of DirichletCategorical::logp: compare it against a
  // sampled estimate of the posterior predictive probability.
  std::mt19937 prng;
  std::mt19937 prng2;
  int num_categories = 5;

  DirichletCategorical dc(&prng, num_categories);

  // The posterior of a DirichletCategorical is itself a Dirichlet whose
  // concentration parameters are the prior alphas plus the observed counts,
  // so the posterior predictive is just an expectation under that posterior
  // Dirichlet -- no extra level of sampling is needed.
  std::vector<double> posterior_alpha;
  for (int c = 0; c < num_categories; ++c) {
    posterior_alpha.emplace_back(dc.alpha);
  }
  int num_samples = 10000;

  for (int step = 0; step < 20; ++step) {
    dc.incorporate(step % num_categories);
    ++posterior_alpha[step % num_categories];

    // (step * step) % num_categories gives a test point that is not simply
    // the most recently incorporated observation.
    int test_data = (step * step) % num_categories;
    double prob_sum = 0;

    // One Gamma(alpha_c, 1) distribution per category; normalizing
    // independent Gamma draws yields a Dirichlet sample.
    std::vector<std::gamma_distribution<>> gammas;
    for (int c = 0; c < num_categories; ++c) {
      gammas.emplace_back(std::gamma_distribution<>(posterior_alpha[c], 1.));
    }

    for (int s = 0; s < num_samples; ++s) {
      std::vector<double> draw;
      for (int c = 0; c < num_categories; ++c) {
        draw.emplace_back(gammas[c](prng2));
      }
      double total = std::accumulate(draw.begin(), draw.end(), 0.);
      for (int c = 0; c < num_categories; ++c) {
        draw[c] /= total;
      }
      // Each normalized draw is one posterior sample; accumulate the
      // probability it assigns to the test point.
      prob_sum += draw[test_data];
    }
    // Average over samples, then compare in log space against logp.
    BOOST_TEST(dc.logp(test_data) == log(prob_sum / num_samples),
               tt::tolerance(8e-3));
  }
}

BOOST_AUTO_TEST_CASE(test_transition_hyperparameters) {
Expand Down