Skip to content

Commit

Permalink
Merge pull request #77 from probcomp/070824-thomaswc-catemission
Browse files Browse the repository at this point in the history
Add emission class for Categorical distributions
  • Loading branch information
ThomasColthurst authored Jul 8, 2024
2 parents d6061f3 + d4d6012 commit dc81c54
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 0 deletions.
19 changes: 19 additions & 0 deletions cxx/emissions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ cc_library(
deps = [":base"],
)

# Header-only library providing CategoricalEmission (categorical.hh).
# The header is listed under `hdrs` because it is the public interface that
# dependent targets (e.g. :categorical_test) include directly; headers placed
# in `srcs` are treated by Bazel as private to the library.
cc_library(
    name = "categorical",
    hdrs = ["categorical.hh"],
    visibility = ["//:__subpackages__"],
    deps = [
        ":base",
        "//distributions:dirichlet_categorical",
    ],
)

cc_library(
name = "gaussian",
srcs = ["gaussian.hh"],
Expand Down Expand Up @@ -92,6 +102,15 @@ cc_test(
],
)

# Boost.Test unit test for CategoricalEmission, built from categorical_test.cc.
# NOTE(review): the test source only includes Boost.Test headers directly;
# confirm whether "@boost//:algorithm" is actually required here.
cc_test(
name = "categorical_test",
srcs = ["categorical_test.cc"],
deps = [
":categorical",
"@boost//:algorithm",
"@boost//:test",
],
)
cc_test(
name = "gaussian_test",
srcs = ["gaussian_test.cc"],
Expand Down
73 changes: 73 additions & 0 deletions cxx/emissions/categorical.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#pragma once

#include <cassert>
#include <limits>
#include <random>
#include <utility>
#include <vector>

#include "distributions/dirichlet_categorical.hh"
#include "emissions/base.hh"

// A "bigram" emission model that tracks a separate emission distribution per
// clean categorical state.
//
// Observations are (clean, corrupted) pairs: `first` selects which per-state
// DirichletCategorical is used and `second` is the value passed to it.
// No bounds checking is performed on either component; callers must supply
// indices in [0, num_states).
class CategoricalEmission : public Emission<int> {
 public:
  // emission_dists[c] is the distribution over corrupted values used when the
  // clean state is c.
  // NOTE(review): declared `mutable` presumably so that const members such as
  // logp() can call non-const DirichletCategorical methods -- confirm.
  mutable std::vector<DirichletCategorical> emission_dists;

  // Builds `num_states` distributions, each constructed with `num_states` as
  // its argument (so clean and corrupted values share the same state count).
  CategoricalEmission(int num_states) {
    emission_dists.reserve(num_states);
    for (int i = 0; i < num_states; ++i) {
      emission_dists.emplace_back(num_states);
    }
  }

  // Adds the observation x = (clean, corrupted) to the distribution for the
  // clean state and bumps the observation counter N.
  // NOTE(review): N is not declared in this class; presumably it is inherited
  // from Emission<int> -- confirm.
  void incorporate(const std::pair<int, int>& x) {
    ++N;
    emission_dists[x.first].incorporate(x.second);
  }

  // Reverses incorporate(x): decrements N and removes the observation from
  // the distribution for the clean state x.first.
  void unincorporate(const std::pair<int, int>& x) {
    --N;
    emission_dists[x.first].unincorporate(x.second);
  }

  // Returns the log probability of x.second under the distribution for the
  // clean state x.first.
  double logp(const std::pair<int, int>& x) const {
    return emission_dists[x.first].logp(x.second);
  }

  // Returns the sum of logp_score() over all per-state distributions
  // (0.0 when there are none).
  double logp_score() const {
    double lp = 0.0;
    for (const auto& e : emission_dists) {
      lp += e.logp_score();
    }
    return lp;
  }

  // Forwards transition_hyperparameters(prng) to every per-state
  // distribution, in state order.
  void transition_hyperparameters(std::mt19937* prng) {
    for (auto& e : emission_dists) {
      e.transition_hyperparameters(prng);
    }
  }

  // Samples a corrupted value from the distribution for clean state `clean`.
  int sample_corrupted(const int& clean, std::mt19937* prng) {
    return emission_dists[clean].sample(prng);
  }

  // Returns the clean state whose distribution assigns the highest total log
  // probability to all values in `corrupted`.  Ties go to the lowest state
  // index.  The PRNG argument is not used.
  //
  // If no state scores strictly above numeric_limits<double>::lowest() (for
  // example when there are no states, or every score is -inf or NaN), state 0
  // is returned; previously the result was an uninitialized int in that case.
  int propose_clean(const std::vector<int>& corrupted,
                    std::mt19937* unused_prng) {
    // Brute force; compute log prob over all possible clean states.
    int best_clean = 0;
    double best_clean_logp = std::numeric_limits<double>::lowest();
    const int num_states = static_cast<int>(emission_dists.size());
    for (int state = 0; state < num_states; ++state) {
      double lp = 0.0;
      for (const int c : corrupted) {
        lp += emission_dists[state].logp(c);
      }
      if (lp > best_clean_logp) {
        best_clean = state;
        best_clean_logp = lp;
      }
    }
    return best_clean;
  }
};
36 changes: 36 additions & 0 deletions cxx/emissions/categorical_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test CategoricalEmission

#include "emissions/categorical.hh"

#include <boost/test/included/unit_test.hpp>
#include <random>
namespace tt = boost::test_tools;

// Smoke test for CategoricalEmission with 5 states: checks the observation
// counter N, the score/log-probability values for single observations, and
// that sampled / proposed states fall inside [0, 5).
BOOST_AUTO_TEST_CASE(test_simple) {
  // -ln(5): the value this test expects both for the score after a single
  // incorporated observation and for logp() of a pair whose clean state has
  // no incorporated data.
  const double kNegLogFive = -1.6094379124341003;

  CategoricalEmission ce(5);

  // A freshly constructed model has no observations and a zero score.
  BOOST_TEST(ce.logp_score() == 0.0);
  BOOST_TEST(ce.N == 0);

  // Incorporating and then unincorporating the same pair restores N.
  ce.incorporate(std::make_pair(0, 2));
  BOOST_TEST(ce.N == 1);
  BOOST_TEST(ce.logp_score() == kNegLogFive, tt::tolerance(1e-6));
  ce.unincorporate(std::make_pair(0, 2));
  BOOST_TEST(ce.N == 0);

  ce.incorporate(std::make_pair(3, 3));
  ce.incorporate(std::make_pair(4, 4));
  BOOST_TEST(ce.N == 2);

  // Clean state 2 has had nothing incorporated into it at this point.
  BOOST_TEST(ce.logp(std::make_pair(2, 2)) == kNegLogFive,
             tt::tolerance(1e-6));

  // Default-seeded generator; only the range of the results is checked.
  std::mt19937 prng;
  int s = ce.sample_corrupted(1, &prng);
  BOOST_TEST(s < 5);
  BOOST_TEST(s >= 0);

  int clean = ce.propose_clean({1, 1, 3, 4}, &prng);
  BOOST_TEST(clean < 5);
  BOOST_TEST(clean >= 0);
}

0 comments on commit dc81c54

Please sign in to comment.