From 92516a412a0b6e74f8dc77250290b131d81d9c8c Mon Sep 17 00:00:00 2001 From: Thomas Colthurst Date: Thu, 30 May 2024 19:51:25 +0000 Subject: [PATCH] Fix two big bugs in Normal --- cxx/distributions/normal.hh | 9 +++++++-- cxx/distributions/normal_test.cc | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/cxx/distributions/normal.hh b/cxx/distributions/normal.hh index 5ba51f6..c9568b7 100644 --- a/cxx/distributions/normal.hh +++ b/cxx/distributions/normal.hh @@ -33,8 +33,8 @@ class Normal : public Distribution { // We use Welford's algorithm for computing the mean and variance // of streaming data in a numerically stable way. See Knuth's // Art of Computer Programming vol. 2, 3rd edition, page 232. - int mean = 0; // Mean of observed values - int var = 0; // Variance of observed values + double mean = 0.0; // Mean of observed values + double var = 0.0; // Variance of observed values std::mt19937 *prng; @@ -51,6 +51,11 @@ class Normal : public Distribution { void unincorporate(const double &x) { int old_N = N; --N; + if (N == 0) { + mean = 0.0; + var = 0.0; + return; + } double old_mean = mean; mean = (mean * old_N - x) / N; var -= (x - mean) * (x - old_mean); diff --git a/cxx/distributions/normal_test.cc b/cxx/distributions/normal_test.cc index de5ddbe..d6f6ae7 100644 --- a/cxx/distributions/normal_test.cc +++ b/cxx/distributions/normal_test.cc @@ -17,8 +17,19 @@ BOOST_AUTO_TEST_CASE(simple) { nd.incorporate(7.0); nd.unincorporate(-2.0); - BOOST_TEST(nd.logp(6.0) == -3.1331256657870137, tt::tolerance(1e-6)); - BOOST_TEST(nd.logp_score() == -4.7494000141508543, tt::tolerance(1e-6)); + BOOST_TEST(nd.logp(6.0) == -2.7673076255063034, tt::tolerance(1e-6)); + BOOST_TEST(nd.logp_score() == -4.7299819282937534, tt::tolerance(1e-6)); +} + +BOOST_AUTO_TEST_CASE(no_nan_after_incorporate_unincorporate) { + std::mt19937 prng; + Normal nd(&prng); + + nd.incorporate(10.0); + nd.unincorporate(10.0); + + BOOST_TEST(!std::isnan(nd.mean)); + BOOST_TEST(!std::isnan(nd.var)); } BOOST_AUTO_TEST_CASE(logp_before_incorporate) {