From 943338c56e5e2873d3d68a93e5d36e1b3049fb38 Mon Sep 17 00:00:00 2001
From: Srinivas Vasudevan
Date: Thu, 6 Jun 2024 17:39:47 -0700
Subject: [PATCH 1/2] Add likelihood and posterior predictive tests for dirichletcategorical

---
 .../dirichlet_categorical_test.cc             | 91 +++++++++++++++++--
 1 file changed, 82 insertions(+), 9 deletions(-)

diff --git a/cxx/distributions/dirichlet_categorical_test.cc b/cxx/distributions/dirichlet_categorical_test.cc
index 9132afd..f55ab8d 100644
--- a/cxx/distributions/dirichlet_categorical_test.cc
+++ b/cxx/distributions/dirichlet_categorical_test.cc
@@ -31,20 +31,93 @@ BOOST_AUTO_TEST_CASE(test_matches_beta_bernoulli) {
   BOOST_TEST(dc.logp_score() == bb.logp_score(), tt::tolerance(1e-6));
 }
 
-BOOST_AUTO_TEST_CASE(test_simple) {
+BOOST_AUTO_TEST_CASE(test_logp_score) {
+  // Sample from a Dirichlet and use that to create a MC estimate of the log
+  // prob.
   std::mt19937 prng;
-  DirichletCategorical dc(&prng, 10);
+  std::mt19937 prng2;
 
-  for (int i = 0; i < 10; ++i) {
-    dc.incorporate(i);
+  DirichletCategorical dc(&prng, 5);
+
+  int num_samples = 1000;
+  std::vector<std::vector<double>> dirichlet_samples;
+  std::gamma_distribution<> gamma_dist(dc.alpha, 1.);
+  for (int i = 0; i < num_samples; ++i) {
+    std::vector<double> sample;
+    for (int j = 0; j < 5; ++j) {
+      sample.emplace_back(gamma_dist(prng2));
+    }
+    double sum_of_elements = std::accumulate(sample.begin(), sample.end(), 0.);
+    for (int j = 0; j < 5; ++j) {
+      sample[j] /= sum_of_elements;
+    }
+    dirichlet_samples.emplace_back(sample);
   }
-  for (int i = 0; i < 10; i += 2) {
-    dc.unincorporate(i);
+
+  for (int i = 0; i < 20; ++i) {
+    dc.incorporate(i % 5);
+    double average_prob = 0;
+    for (const auto& sample : dirichlet_samples) {
+      double prob = 1;
+      for (int j = 0; j <= i; ++j) {
+        prob *= sample[j % 5];
+      }
+      average_prob += prob;
+    }
+    average_prob /= num_samples;
+    BOOST_TEST(dc.logp_score() == log(average_prob), tt::tolerance(8e-3));
   }
+}
 
-  BOOST_TEST(dc.N == 5);
-  BOOST_TEST(dc.logp(1) == -2.0149030205422647, tt::tolerance(1e-6));
-  BOOST_TEST(dc.logp_score() == -12.389393702657209, tt::tolerance(1e-6));
+BOOST_AUTO_TEST_CASE(test_logp) {
+  // Sample from a Dirichlet and use that to create a MC estimate of the log
+  // prob.
+  std::mt19937 prng;
+  std::mt19937 prng2;
+
+  DirichletCategorical dc(&prng, 5);
+
+  // We'll use the fact that the posterior distribution of a
+  // DirichletCategorical is a Dirichlet.
+  // Thus we only need to compute an expectation with respect to
+  // this Dirichlet posterior for the posterior predictive.
+  std::vector<double> effective_concentration;
+  for (int i = 0; i < 5; ++i) {
+    effective_concentration.emplace_back(dc.alpha);
+  }
+  int num_samples = 5000;
+
+  for (int i = 0; i < 20; ++i) {
+    dc.incorporate(i % 5);
+    ++effective_concentration[i % 5];
+
+    int test_data = (i * i) % 5;
+    double average_prob = 0;
+
+    std::vector<std::gamma_distribution<>> gamma_dists;
+    for (int j = 0; j < 5; ++j) {
+      gamma_dists.emplace_back(
+          std::gamma_distribution<>(effective_concentration[j], 1.));
+    }
+
+    // Create samples with these concentration parameters.
+    for (int j = 0; j < num_samples; ++j) {
+      std::vector<double> sample;
+      for (int k = 0; k < 5; ++k) {
+        sample.emplace_back(gamma_dists[k](prng2));
+      }
+      double sum_of_elements =
+          std::accumulate(sample.begin(), sample.end(), 0.);
+      for (int k = 0; k < 5; ++k) {
+        sample[k] /= sum_of_elements;
+      }
+      // For this posterior sample, compute an estimate of the posterior
+      // predictive on the test data point.
+      average_prob += sample[test_data];
+    }
+    average_prob /= num_samples;
+    BOOST_TEST(dc.logp(test_data) == log(average_prob), tt::tolerance(2e-2));
+  }
 }
 
 BOOST_AUTO_TEST_CASE(test_transition_hyperparameters) {

From cc5a04165aea2f250dc571d1ebb247fe216d7484 Mon Sep 17 00:00:00 2001
From: Srinivas Vasudevan
Date: Thu, 6 Jun 2024 18:25:11 -0700
Subject: [PATCH 2/2] Updated to address PR comments

---
 .../dirichlet_categorical_test.cc             | 21 ++++++++++---------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/cxx/distributions/dirichlet_categorical_test.cc b/cxx/distributions/dirichlet_categorical_test.cc
index f55ab8d..6512a46 100644
--- a/cxx/distributions/dirichlet_categorical_test.cc
+++ b/cxx/distributions/dirichlet_categorical_test.cc
@@ -74,28 +74,29 @@ BOOST_AUTO_TEST_CASE(test_logp) {
   // prob.
   std::mt19937 prng;
   std::mt19937 prng2;
+  int num_categories = 5;
 
-  DirichletCategorical dc(&prng, 5);
+  DirichletCategorical dc(&prng, num_categories);
 
   // We'll use the fact that the posterior distribution of a
   // DirichletCategorical is a Dirichlet.
   // Thus we only need to compute an expectation with respect to
   // this Dirichlet posterior for the posterior predictive.
   std::vector<double> effective_concentration;
-  for (int i = 0; i < 5; ++i) {
+  for (int i = 0; i < num_categories; ++i) {
     effective_concentration.emplace_back(dc.alpha);
   }
-  int num_samples = 5000;
+  int num_samples = 10000;
 
   for (int i = 0; i < 20; ++i) {
-    dc.incorporate(i % 5);
-    ++effective_concentration[i % 5];
+    dc.incorporate(i % num_categories);
+    ++effective_concentration[i % num_categories];
 
-    int test_data = (i * i) % 5;
+    int test_data = (i * i) % num_categories;
     double average_prob = 0;
 
     std::vector<std::gamma_distribution<>> gamma_dists;
-    for (int j = 0; j < 5; ++j) {
+    for (int j = 0; j < num_categories; ++j) {
       gamma_dists.emplace_back(
           std::gamma_distribution<>(effective_concentration[j], 1.));
     }
@@ -103,12 +104,12 @@ BOOST_AUTO_TEST_CASE(test_logp) {
     // Create samples with these concentration parameters.
     for (int j = 0; j < num_samples; ++j) {
       std::vector<double> sample;
-      for (int k = 0; k < 5; ++k) {
+      for (int k = 0; k < num_categories; ++k) {
         sample.emplace_back(gamma_dists[k](prng2));
       }
       double sum_of_elements =
           std::accumulate(sample.begin(), sample.end(), 0.);
-      for (int k = 0; k < 5; ++k) {
+      for (int k = 0; k < num_categories; ++k) {
         sample[k] /= sum_of_elements;
       }
       // For this posterior sample, compute an estimate of the posterior
@@ -116,7 +117,7 @@ BOOST_AUTO_TEST_CASE(test_logp) {
       average_prob += sample[test_data];
     }
     average_prob /= num_samples;
-    BOOST_TEST(dc.logp(test_data) == log(average_prob), tt::tolerance(2e-2));
+    BOOST_TEST(dc.logp(test_data) == log(average_prob), tt::tolerance(8e-3));
   }
 }
 