diff --git a/cxx/emissions/BUILD b/cxx/emissions/BUILD index 69a71be..1fe690a 100644 --- a/cxx/emissions/BUILD +++ b/cxx/emissions/BUILD @@ -42,6 +42,16 @@ cc_library( deps = [":base"], ) +cc_library( + name = "categorical", + srcs = ["categorical.hh"], + visibility = ["//:__subpackages__"], + deps = [ + ":base", + "//distributions:dirichlet_categorical", + ], +) + cc_library( name = "gaussian", srcs = ["gaussian.hh"], @@ -92,6 +102,15 @@ cc_test( ], ) +cc_test( + name = "categorical_test", + srcs = ["categorical_test.cc"], + deps = [ + ":categorical", + "@boost//:algorithm", + "@boost//:test", + ], +) cc_test( name = "gaussian_test", srcs = ["gaussian_test.cc"], diff --git a/cxx/emissions/categorical.hh b/cxx/emissions/categorical.hh new file mode 100644 index 0000000..477bf1a --- /dev/null +++ b/cxx/emissions/categorical.hh @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include + +#include "distributions/dirichlet_categorical.hh" +#include "emissions/base.hh" + +// A "bigram" emission model that tracks separate emission distributions per +// clean categorical state. +class CategoricalEmission : public Emission { + public: + mutable std::vector emission_dists; + + CategoricalEmission(int num_states) { + emission_dists.reserve(num_states); + for (int i = 0; i < num_states; ++i) { + emission_dists.emplace_back(num_states); + } + }; + + void incorporate(const std::pair& x) { + ++N; + emission_dists[x.first].incorporate(x.second); + } + + void unincorporate(const std::pair& x) { + --N; + emission_dists[x.first].unincorporate(x.second); + } + + double logp(const std::pair& x) const { + return emission_dists[x.first].logp(x.second); + } + + double logp_score() const { + double lp = 0.0; + for (const auto& e : emission_dists) { + lp += e.logp_score(); + } + return lp; + } + + void transition_hyperparameters(std::mt19937* prng) { + for (auto& e : emission_dists) { + e.transition_hyperparameters(prng); + } + } + + int sample_corrupted(const int& clean, std::mt19937* prng) { + return emission_dists[clean].sample(prng); + } + + int propose_clean(const std::vector& corrupted, + std::mt19937* unused_prng) { + // Brute force; compute log prob over all possible clean states. + int best_clean; + double best_clean_logp = std::numeric_limits::lowest(); + for (size_t i = 0; i < emission_dists.size(); ++i) { + double lp = 0.0; + for (const auto& c : corrupted) { + lp += emission_dists[i].logp(c); + } + if (lp > best_clean_logp) { + best_clean = i; + best_clean_logp = lp; + } + } + return best_clean; + } + +}; diff --git a/cxx/emissions/categorical_test.cc b/cxx/emissions/categorical_test.cc new file mode 100644 index 0000000..5015f04 --- /dev/null +++ b/cxx/emissions/categorical_test.cc @@ -0,0 +1,36 @@ +// Apache License, Version 2.0, refer to LICENSE.txt + +#define BOOST_TEST_MODULE test CategoricalEmission + +#include "emissions/categorical.hh" + +#include +#include +namespace tt = boost::test_tools; + +BOOST_AUTO_TEST_CASE(test_simple) { + CategoricalEmission ce(5); + + BOOST_TEST(ce.logp_score() == 0.0); + BOOST_TEST(ce.N == 0); + ce.incorporate(std::make_pair(0, 2)); + BOOST_TEST(ce.N == 1); + BOOST_TEST(ce.logp_score() == -1.6094379124341001, tt::tolerance(1e-6)); + ce.unincorporate(std::make_pair(0, 2)); + BOOST_TEST(ce.N == 0); + ce.incorporate(std::make_pair(3, 3)); + ce.incorporate(std::make_pair(4, 4)); + BOOST_TEST(ce.N == 2); + + BOOST_TEST(ce.logp(std::make_pair(2, 2)) == -1.6094379124341003, + tt::tolerance(1e-6)); + + std::mt19937 prng; + int s = ce.sample_corrupted(1, &prng); + BOOST_TEST(s < 5); + BOOST_TEST(s >= 0); + + int clean = ce.propose_clean({1, 1, 3, 4}, &prng); + BOOST_TEST(clean < 5); + BOOST_TEST(clean >= 0); +}