-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add BitFlip and Sometimes Emission classes #46
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#pragma once | ||
|
||
#include <cassert> | ||
#include "emissions/base.hh" | ||
|
||
// A *deterministic* Emission class that always emits not(clean).
// Most users will want to combine this with Sometimes.
//
// Every (clean, dirty) pair handed to this class must satisfy
// clean != dirty; incorporate/unincorporate/logp assert that precondition.
class BitFlip : public Emission<bool> {
 public:
  BitFlip() {}

  // Records one observed (clean, dirty) pair by bumping the observation
  // count N (not declared here; presumably inherited from Emission<bool>).
  void incorporate(const std::pair<bool, bool>& x) {
    assert(x.first != x.second);
    ++N;
  }

  // Reverses a previous incorporate() of the same pair.
  void unincorporate(const std::pair<bool, bool>& x) {
    assert(x.first != x.second);
    --N;
  }

  // Log probability of a flipped pair.  The flip is deterministic, so any
  // valid pair has probability one, i.e. log probability 0.
  double logp(const std::pair<bool, bool>& x) const {
    assert(x.first != x.second);
    return 0.0;
  }

  // Log probability of everything incorporated so far; always 0 because
  // each incorporated pair contributes log(1).
  double logp_score() const {
    return 0.0;
  }

  // No hyperparameters to transition!
  void transition_hyperparameters() {}

  // Returns the corrupted version of `clean`, which is always its negation.
  // The random generator is not consulted.
  bool sample_corrupted(const bool& clean, std::mt19937* unused_prng) {
    return !clean;
  }

  // Recovers the clean value from observations of its corrupted versions.
  // `corrupted` is expected to hold the results of repeated
  // sample_corrupted() calls on one clean value, so all entries agree and
  // the first one is representative.  `corrupted` must not be empty.
  bool propose_clean(const std::vector<bool>& corrupted,
                     std::mt19937* unused_prng) {
    // Indexing an empty vector is undefined behaviour; fail loudly instead.
    assert(!corrupted.empty());
    return !corrupted[0];
  }
};
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test BitFlip | ||
|
||
#include <random> | ||
|
||
#include "emissions/bitflip.hh" | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
|
||
// Exercises BitFlip's bookkeeping (N), its constant scores, and the
// deterministic sample/propose helpers.
BOOST_AUTO_TEST_CASE(test_simple) {
  // The two possible flipped (clean, dirty) observations.
  const std::pair<bool, bool> true_to_false(true, false);
  const std::pair<bool, bool> false_to_true(false, true);

  BitFlip flipper;

  // A fresh instance has no observations and a zero score.
  BOOST_TEST(flipper.logp_score() == 0.0);
  BOOST_TEST(flipper.N == 0);

  // Incorporating changes the count but never the score.
  flipper.incorporate(true_to_false);
  BOOST_TEST(flipper.logp_score() == 0.0);
  BOOST_TEST(flipper.N == 1);

  // Unincorporating the same pair restores the initial state.
  flipper.unincorporate(true_to_false);
  BOOST_TEST(flipper.logp_score() == 0.0);
  BOOST_TEST(flipper.N == 0);

  // Two observations of the other orientation.
  flipper.incorporate(false_to_true);
  flipper.incorporate(false_to_true);
  BOOST_TEST(flipper.logp_score() == 0.0);
  BOOST_TEST(flipper.N == 2);

  BOOST_TEST(flipper.logp(true_to_false) == 0.0);

  std::mt19937 prng;
  // Corrupting false yields true.
  BOOST_TEST(flipper.sample_corrupted(false, &prng));

  // All-false corrupted observations imply a clean value of true.
  const std::vector<bool> all_false(3, false);
  BOOST_TEST(flipper.propose_clean(all_false, &prng));
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#pragma once | ||
|
||
#include <unordered_map> | ||
|
||
#include "distributions/beta_bernoulli.hh" | ||
#include "emissions/base.hh" | ||
|
||
// An Emission class that sometimes applies BaseEmissor and sometimes doesn't. | ||
// BaseEmissor must (1) of type Emission<SampleType> and (2) assign zero | ||
// probability to <clean, dirty> pairs with clean == dirty. [For example, | ||
// BitFlip and Gaussian both satisfy #2]. | ||
template <typename BaseEmissor, typename SampleType = double> | ||
class Sometimes : public Emission<SampleType> { | ||
public: | ||
BetaBernoulli bb; | ||
BaseEmissor be; | ||
|
||
Sometimes() : bb(nullptr) {}; | ||
|
||
void incorporate(const std::pair<SampleType, SampleType>& x) { | ||
++(this->N); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For my own understanding, how come There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I really really wish I knew. I originally had it as ++N; and the compiler complained that no variable named N was defined in scope. |
||
bb.incorporate(x.first != x.second); | ||
if (x.first != x.second) { | ||
be.incorporate(x); | ||
} | ||
} | ||
|
||
void unincorporate(const std::pair<SampleType, SampleType>& x) { | ||
--(this->N); | ||
bb.unincorporate(x.first != x.second); | ||
if (x.first != x.second) { | ||
be.unincorporate(x); | ||
} | ||
} | ||
|
||
double logp(const std::pair<SampleType, SampleType>& x) const { | ||
return bb.logp(x.first != x.second) + be.logp(x); | ||
} | ||
|
||
double logp_score() const { | ||
return bb.logp_score() + be.logp_score(); | ||
} | ||
|
||
void transition_hyperparameters() { | ||
be.transition_hyperparameters(); | ||
bb.transition_hyperparameters(); | ||
} | ||
|
||
SampleType sample_corrupted(const SampleType& clean, std::mt19937* prng) { | ||
bb.prng = prng; | ||
if (bb.sample()) { | ||
return be.sample_corrupted(clean, prng); | ||
} | ||
return clean; | ||
} | ||
|
||
SampleType propose_clean(const std::vector<SampleType>& corrupted, | ||
std::mt19937* prng) { | ||
// We approximate the maximum likelihood estimate by taking the mode of | ||
// corrupted. The full solution would construct BaseEmissor and | ||
// BetaBernoulli instances for each choice of clean and picking the | ||
// clean with the highest combined logp_score(). | ||
std::unordered_map<SampleType, int> counts; | ||
SampleType mode; | ||
int max_count = 0; | ||
for (const SampleType& c: corrupted) { | ||
++counts[c]; | ||
if (counts[c] > max_count) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optional nit: it seems cleaner to me to do this after the loop terminates with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I considered that, but the code to use std::max_element over a map turns out to be many lines that look ugly to me: auto max = std::max_element( return max->first; |
||
max_count = counts[c]; | ||
mode = c; | ||
} | ||
} | ||
return mode; | ||
} | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test Sometimes | ||
|
||
#include <random> | ||
|
||
#include "emissions/bitflip.hh" | ||
#include "emissions/sometimes.hh" | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
|
||
// Exercises Sometimes<BitFlip, bool>: observation counting, score changes on
// incorporate/unincorporate, and mode-based propose_clean.
BOOST_AUTO_TEST_CASE(test_simple) {
  // The two possible corrupted (clean, dirty) observations.
  const std::pair<bool, bool> true_to_false(true, false);
  const std::pair<bool, bool> false_to_true(false, true);

  Sometimes<BitFlip, bool> emitter;

  // Baseline score of an empty model, used to check the round trip below.
  double initial_score = emitter.logp_score();
  BOOST_TEST(emitter.N == 0);

  emitter.incorporate(true_to_false);
  BOOST_TEST(emitter.logp_score() < 0.0);
  BOOST_TEST(emitter.N == 1);

  // Removing the only observation restores the baseline exactly.
  emitter.unincorporate(true_to_false);
  BOOST_TEST(emitter.logp_score() == initial_score);
  BOOST_TEST(emitter.N == 0);

  emitter.incorporate(false_to_true);
  emitter.incorporate(false_to_true);
  BOOST_TEST(emitter.logp_score() < 0.0);
  BOOST_TEST(emitter.N == 2);

  BOOST_TEST(emitter.logp(true_to_false) < 0.0);

  std::mt19937 prng;
  // The mode of {true, true, false} is true.
  const std::vector<bool> mostly_true{true, true, false};
  BOOST_TEST(emitter.propose_clean(mostly_true, &prng));
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't entirely understand how `propose_clean` will be used, but would it be better to choose a random element of `corrupted` to invert? Or is `corrupted` expected to contain all the same values?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
corrupted is expected to contain all the same values. Specifically, the contract for the corrupted vector is that it is the output of sample_corrupted over repeated calls with the same clean value.