forked from probsys/hierarchical-irm
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #77 from probcomp/070824-thomaswc-catemission
Add emission class for Categorical distributions
- Loading branch information
Showing
3 changed files
with
128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#pragma once | ||
|
||
#include <cassert> | ||
#include <limits> | ||
#include <utility> | ||
|
||
#include "distributions/dirichlet_categorical.hh" | ||
#include "emissions/base.hh" | ||
|
||
// A "bigram" emission model that tracks separate emission distributions per | ||
// clean categorical state. | ||
class CategoricalEmission : public Emission<int> { | ||
public: | ||
mutable std::vector<DirichletCategorical> emission_dists; | ||
|
||
CategoricalEmission(int num_states) { | ||
emission_dists.reserve(num_states); | ||
for (int i = 0; i < num_states; ++i) { | ||
emission_dists.emplace_back(num_states); | ||
} | ||
}; | ||
|
||
void incorporate(const std::pair<int, int>& x) { | ||
++N; | ||
emission_dists[x.first].incorporate(x.second); | ||
} | ||
|
||
void unincorporate(const std::pair<int, int>& x) { | ||
--N; | ||
emission_dists[x.first].unincorporate(x.second); | ||
} | ||
|
||
double logp(const std::pair<int, int>& x) const { | ||
return emission_dists[x.first].logp(x.second); | ||
} | ||
|
||
double logp_score() const { | ||
double lp = 0.0; | ||
for (const auto& e : emission_dists) { | ||
lp += e.logp_score(); | ||
} | ||
return lp; | ||
} | ||
|
||
void transition_hyperparameters(std::mt19937* prng) { | ||
for (auto& e : emission_dists) { | ||
e.transition_hyperparameters(prng); | ||
} | ||
} | ||
|
||
int sample_corrupted(const int& clean, std::mt19937* prng) { | ||
return emission_dists[clean].sample(prng); | ||
} | ||
|
||
int propose_clean(const std::vector<int>& corrupted, | ||
std::mt19937* unused_prng) { | ||
// Brute force; compute log prob over all possible clean states. | ||
int best_clean; | ||
double best_clean_logp = std::numeric_limits<double>::lowest(); | ||
for (size_t i = 0; i < emission_dists.size(); ++i) { | ||
double lp = 0.0; | ||
for (const auto& c : corrupted) { | ||
lp += emission_dists[i].logp(c); | ||
} | ||
if (lp > best_clean_logp) { | ||
best_clean = i; | ||
best_clean_logp = lp; | ||
} | ||
} | ||
return best_clean; | ||
} | ||
|
||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test CategoricalEmission | ||
|
||
#include "emissions/categorical.hh" | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
#include <random> | ||
namespace tt = boost::test_tools; | ||
|
||
BOOST_AUTO_TEST_CASE(test_simple) { | ||
CategoricalEmission ce(5); | ||
|
||
BOOST_TEST(ce.logp_score() == 0.0); | ||
BOOST_TEST(ce.N == 0); | ||
ce.incorporate(std::make_pair<int, int>(0, 2)); | ||
BOOST_TEST(ce.N == 1); | ||
BOOST_TEST(ce.logp_score() == -1.6094379124341001, tt::tolerance(1e-6)); | ||
ce.unincorporate(std::make_pair<int, int>(0, 2)); | ||
BOOST_TEST(ce.N == 0); | ||
ce.incorporate(std::make_pair<int, int>(3, 3)); | ||
ce.incorporate(std::make_pair<int, int>(4, 4)); | ||
BOOST_TEST(ce.N == 2); | ||
|
||
BOOST_TEST(ce.logp(std::make_pair<int, int>(2, 2)) == -1.6094379124341003, | ||
tt::tolerance(1e-6)); | ||
|
||
std::mt19937 prng; | ||
int s = ce.sample_corrupted(1, &prng); | ||
BOOST_TEST(s < 5); | ||
BOOST_TEST(s >= 0); | ||
|
||
int clean = ce.propose_clean({1, 1, 3, 4}, &prng); | ||
BOOST_TEST(clean < 5); | ||
BOOST_TEST(clean >= 0); | ||
} |