From d4d6012712da78cf8c9b13e348c22c8ed36dc8f1 Mon Sep 17 00:00:00 2001 From: Thomas Colthurst Date: Mon, 8 Jul 2024 19:38:22 +0000 Subject: [PATCH] Fix categorical logp to be conditional probability --- cxx/emissions/categorical.hh | 10 +--------- cxx/emissions/categorical_test.cc | 2 +- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/cxx/emissions/categorical.hh b/cxx/emissions/categorical.hh index 40bf4ec..477bf1a 100644 --- a/cxx/emissions/categorical.hh +++ b/cxx/emissions/categorical.hh @@ -31,15 +31,7 @@ class CategoricalEmission : public Emission { } double logp(const std::pair& x) const { - double lp; - for (size_t i = 0; i < emission_dists.size(); ++i) { - if (std::cmp_equal(i, x.first)) { - lp += emission_dists[i].logp(x.second); - } else { - lp += emission_dists[i].logp_score(); - } - } - return lp; + return emission_dists[x.first].logp(x.second); } double logp_score() const { diff --git a/cxx/emissions/categorical_test.cc b/cxx/emissions/categorical_test.cc index faa8074..5015f04 100644 --- a/cxx/emissions/categorical_test.cc +++ b/cxx/emissions/categorical_test.cc @@ -22,7 +22,7 @@ BOOST_AUTO_TEST_CASE(test_simple) { ce.incorporate(std::make_pair(4, 4)); BOOST_TEST(ce.N == 2); - BOOST_TEST(ce.logp(std::make_pair(2, 2)) == -4.8283137373023006, + BOOST_TEST(ce.logp(std::make_pair(2, 2)) == -1.6094379124341003, tt::tolerance(1e-6)); std::mt19937 prng;