diff --git a/cxx/distributions/zero_mean_normal_test.cc b/cxx/distributions/zero_mean_normal_test.cc index fddfc03..5cfb35c 100644 --- a/cxx/distributions/zero_mean_normal_test.cc +++ b/cxx/distributions/zero_mean_normal_test.cc @@ -4,12 +4,55 @@ #include "distributions/zero_mean_normal.hh" +#include +#include +#include #include #include "util_math.hh" namespace bm = boost::math; namespace tt = boost::test_tools; +BOOST_AUTO_TEST_CASE(test_log_prob) { + std::mt19937 prng; + ZeroMeanNormal nd(&prng); + + double nd_v = 2.0 * nd.alpha; + double nd_s = 2.0 * nd.beta; + double nd_m = 0.0; + double inv_nd_r = 0.0; + bm::inverse_gamma_distribution inv_gamma_dist(nd_v / 2., nd_s / 2.); + auto quad = bm::quadrature::gauss_kronrod(); + + for (int i = 0; i < 10; ++i) { + nd.incorporate(i); + + auto integrand1 = [&nd_m, &inv_nd_r, &inv_gamma_dist, &i](double n, double ig) { + bm::normal_distribution normal_prior_dist(nd_m, sqrt(ig * inv_nd_r)); + bm::normal_distribution normal_dist(n, sqrt(ig)); + double result = + bm::pdf(normal_prior_dist, n) * bm::pdf(inv_gamma_dist, ig); + for (int j = 0; j <= i; ++j) { + result *= bm::pdf(normal_dist, j); + } + return result; + }; + + auto integrand2 = [&quad, &integrand1](double ig) { + auto f = [&](double n) { return integrand1(n, ig); }; + return quad.integrate(f, -std::numeric_limits::infinity(), + std::numeric_limits::infinity()); + }; + + double result = + quad.integrate(integrand2, 0., std::numeric_limits::infinity()); + + BOOST_TEST(nd.logp_score() == log(result), tt::tolerance(1e-4)); + } + BOOST_TEST(nd.N == 10); +} + + BOOST_AUTO_TEST_CASE(simple) { std::mt19937 prng; ZeroMeanNormal nd(&prng);