Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BitFlip and Sometimes Emission classes #46

Merged
merged 4 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions cxx/emissions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,51 @@ cc_library(
],
)

cc_library(
name = "bitflip",
srcs = ["bitflip.hh"],
visibility = ["//:__subpackages__"],
deps = [":base"],
)

cc_library(
name = "gaussian",
srcs = ["gaussian.hh"],
visibility = ["//:__subpackages__"],
deps = [":base"],
)

cc_library(
name = "sometimes",
srcs = ["sometimes.hh"],
visibility = ["//:__subpackages__"],
deps = [
":base",
"//distributions:beta_bernoulli",
],
)

cc_test(
name = "bitflip_test",
srcs = ["bitflip_test.cc"],
deps = [
":bitflip",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "sometimes_test",
srcs = ["sometimes_test.cc"],
deps = [
":bitflip",
":sometimes",
"@boost//:algorithm",
"@boost//:test",
],
)

# TODO(thomaswc): Fix and re-enable.
#cc_test(
# name = "gaussian_test",
Expand Down
42 changes: 42 additions & 0 deletions cxx/emissions/bitflip.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include <cassert>
#include "emissions/base.hh"

// A *deterministic* Emission class that always emits not(clean).
// Most users will want to combine this with Sometimes.
class BitFlip : public Emission<bool> {
public:
BitFlip() {};

void incorporate(const std::pair<bool, bool>& x) {
assert(x.first != x.second);
++N;
}

void unincorporate(const std::pair<bool, bool>& x) {
assert(x.first != x.second);
--N;
}

double logp(const std::pair<bool, bool>& x) const {
assert(x.first != x.second);
return 0.0;
}

double logp_score() const {
return 0.0;
}

// No hyperparameters to transition!
void transition_hyperparameters() {}

bool sample_corrupted(const bool& clean, std::mt19937* unused_prng) {
return !clean;
}

bool propose_clean(const std::vector<bool>& corrupted,
std::mt19937* unused_prng) {
return !corrupted[0];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't entirely understand how propose_clean will be used, but would it be better to choose a random element of corrupted to invert? Or is corrupted expected to contain all the same values?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

corrupted is expected to contain all the same values. Specifically, the contract for the corrupted vector is that it is the output of sample_corrupted over repeated calls with the same clean value.

}
};
33 changes: 33 additions & 0 deletions cxx/emissions/bitflip_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test BitFlip

#include <random>

#include "emissions/bitflip.hh"

#include <boost/test/included/unit_test.hpp>

BOOST_AUTO_TEST_CASE(test_simple) {
BitFlip bf;

BOOST_TEST(bf.logp_score() == 0.0);
BOOST_TEST(bf.N == 0);
bf.incorporate(std::make_pair<bool, bool>(true, false));
BOOST_TEST(bf.logp_score() == 0.0);
BOOST_TEST(bf.N == 1);
bf.unincorporate(std::make_pair<bool, bool>(true, false));
BOOST_TEST(bf.logp_score() == 0.0);
BOOST_TEST(bf.N == 0);
bf.incorporate(std::make_pair<bool, bool>(false, true));
bf.incorporate(std::make_pair<bool, bool>(false, true));
BOOST_TEST(bf.logp_score() == 0.0);
BOOST_TEST(bf.N == 2);

BOOST_TEST(bf.logp(std::make_pair<bool, bool>(true, false)) == 0.0);

std::mt19937 prng;
BOOST_TEST(bf.sample_corrupted(false, &prng));

BOOST_TEST(bf.propose_clean({false, false, false}, &prng));
}
75 changes: 75 additions & 0 deletions cxx/emissions/sometimes.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#pragma once

#include <unordered_map>

#include "distributions/beta_bernoulli.hh"
#include "emissions/base.hh"

// An Emission class that sometimes applies BaseEmissor and sometimes doesn't.
// BaseEmissor must (1) of type Emission<SampleType> and (2) assign zero
// probability to <clean, dirty> pairs with clean == dirty. [For example,
// BitFlip and Gaussian both satisfy #2].
template <typename BaseEmissor, typename SampleType = double>
class Sometimes : public Emission<SampleType> {
public:
BetaBernoulli bb;
BaseEmissor be;

Sometimes() : bb(nullptr) {};

void incorporate(const std::pair<SampleType, SampleType>& x) {
++(this->N);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own understanding, how come N is accessed by pointer here and not in the other Emission subclasses?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really really wish I knew. I originally had it as ++N; and the compiler complained that no variable named N was defined in scope.

bb.incorporate(x.first != x.second);
if (x.first != x.second) {
be.incorporate(x);
}
}

void unincorporate(const std::pair<SampleType, SampleType>& x) {
--(this->N);
bb.unincorporate(x.first != x.second);
if (x.first != x.second) {
be.unincorporate(x);
}
}

double logp(const std::pair<SampleType, SampleType>& x) const {
return bb.logp(x.first != x.second) + be.logp(x);
}

double logp_score() const {
return bb.logp_score() + be.logp_score();
}

void transition_hyperparameters() {
be.transition_hyperparameters();
bb.transition_hyperparameters();
}

SampleType sample_corrupted(const SampleType& clean, std::mt19937* prng) {
bb.prng = prng;
if (bb.sample()) {
return be.sample_corrupted(clean, prng);
}
return clean;
}

SampleType propose_clean(const std::vector<SampleType>& corrupted,
std::mt19937* prng) {
// We approximate the maximum likelihood estimate by taking the mode of
// corrupted. The full solution would construct BaseEmissor and
// BetaBernoulli instances for each choice of clean and picking the
// clean with the highest combined logp_score().
std::unordered_map<SampleType, int> counts;
SampleType mode;
int max_count = 0;
for (const SampleType& c: corrupted) {
++counts[c];
if (counts[c] > max_count) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional nit: it seems cleaner to me to do this after the loop terminates with std::max_element.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered that, but the code to use std::max_element over a map turns out to be many lines that look ugly to me:

auto max = std::max_element(
std::begin(counts), std::end(counts),
[] (const auto &p1, const auto &p2) { return p1.second < p2.second; }
)

return max->first;

max_count = counts[c];
mode = c;
}
}
return mode;
}
};
33 changes: 33 additions & 0 deletions cxx/emissions/sometimes_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test Sometimes

#include <random>

#include "emissions/bitflip.hh"
#include "emissions/sometimes.hh"

#include <boost/test/included/unit_test.hpp>

BOOST_AUTO_TEST_CASE(test_simple) {
Sometimes<BitFlip, bool> sbf;

double orig_lp = sbf.logp_score();
BOOST_TEST(sbf.N == 0);
sbf.incorporate(std::make_pair<bool, bool>(true, false));
BOOST_TEST(sbf.logp_score() < 0.0);
BOOST_TEST(sbf.N == 1);
sbf.unincorporate(std::make_pair<bool, bool>(true, false));
BOOST_TEST(sbf.logp_score() == orig_lp);
BOOST_TEST(sbf.N == 0);

sbf.incorporate(std::make_pair<bool, bool>(false, true));
sbf.incorporate(std::make_pair<bool, bool>(false, true));
BOOST_TEST(sbf.logp_score() < 0.0);
BOOST_TEST(sbf.N == 2);

BOOST_TEST(sbf.logp(std::make_pair<bool, bool>(true, false)) < 0.0);

std::mt19937 prng;
BOOST_TEST(sbf.propose_clean({true, true, false}, &prng));
}