Skip to content

Commit

Permalink
Re-enable gaussian_test plus fixes to SimpleStringEmission
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasColthurst committed Jun 13, 2024
1 parent 355cb72 commit f72023c
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 35 deletions.
21 changes: 10 additions & 11 deletions cxx/emissions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ cc_test(
],
)

cc_test(
name = "gaussian_test",
srcs = ["gaussian_test.cc"],
deps = [
":gaussian",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "simple_string_test",
srcs = ["simple_string_test.cc"],
Expand All @@ -77,14 +87,3 @@ cc_test(
"@boost//:test",
],
)

# TODO(thomaswc): Fix and re-enable.
#cc_test(
# name = "gaussian_test",
# srcs = ["gaussian_test.cc"],
# deps = [
# ":gaussian",
# "@boost//:algorithm",
# "@boost//:test",
# ],
#)
7 changes: 2 additions & 5 deletions cxx/emissions/base.hh
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
#pragma once

#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <random>
#include <utility>

#include "distributions/base.hh"

template <typename SampleType = double>
class Emission : public Distribution<std::pair<SampleType, SampleType>> {
public:
virtual std::pair<SampleType, SampleType> sample() {
printf("sample() should never be called on an Emission\n");
std::abort();
assert(false && "sample() should never be called on an Emission\n");
}

// Return a stochastically corrupted version of clean.
Expand Down
10 changes: 7 additions & 3 deletions cxx/emissions/gaussian.hh
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

class GaussianEmission : public Emission<double> {
public:
ZeroMeanNormal zmn(nullptr);
ZeroMeanNormal zmn;

GaussianEmission() {}
GaussianEmission() : zmn(nullptr) {}

void incorporate(const std::pair<double, double>& x) {
++N;
Expand All @@ -27,6 +27,10 @@ class GaussianEmission : public Emission<double> {
return zmn.logp_score();
}

void transition_hyperparameters() {
zmn.transition_hyperparameters();
}

double sample_corrupted(const double& clean, std::mt19937* prng) {
zmn.prng = prng;
return clean + zmn.sample();
Expand All @@ -45,4 +49,4 @@ class GaussianEmission : public Emission<double> {
}
return mean;
}
}
};
17 changes: 16 additions & 1 deletion cxx/emissions/gaussian_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,24 @@ namespace tt = boost::test_tools;
BOOST_AUTO_TEST_CASE(simple) {
GaussianEmission ge;

BOOST_TEST(ge.logp_score() == -0.69314718055994529, tt::tolerance(1e-6));
ge.incorporate(std::make_pair(1.0, 1.1));
BOOST_TEST(ge.N == 1);
ge.incorporate(std::make_pair(5.0, 4.9));
BOOST_TEST(ge.N == 2);
ge.unincorporate(std::make_pair(1.0, 1.1));
BOOST_TEST(ge.N == 1);

BOOST_TEST(ge.logp_score() == 0.0, tt::tolerance(1e-6));
BOOST_TEST(ge.logp_score() == -2.0831868619777163, tt::tolerance(1e-6));

BOOST_TEST(ge.logp(std::make_pair(2.0, 2.001)) == -0.80313245109477638,
tt::tolerance(1e-6));

std::mt19937 prng;
double dirty = ge.sample_corrupted(5.0, &prng);
BOOST_TEST(dirty < 6.0);
BOOST_TEST(dirty > 4.0);

double clean = ge.propose_clean({-5.0, -4.9, -5.1, -5.2, -4.8}, &prng);
BOOST_TEST(clean == -5.0);
}
42 changes: 27 additions & 15 deletions cxx/emissions/simple_string.hh
Original file line number Diff line number Diff line change
Expand Up @@ -25,41 +25,53 @@ class SimpleStringEmission : public Emission<std::string> {
}

void unincorporate(const std::pair<std::string, std::string>& x) {
++N;
--N;
corporate(x.first, x.second, false);
}

void corporate(const std::string& clean, const std::string& dirty, bool d) {
if (clean.empty()) {
// All of dirty must be insertions.
for (size_t i = 0; i < dirty.length(); ++i) {
insertion.incorporate(d);
d ? insertion.incorporate(1) : insertion.unincorporate(1);
}
return;
}

if (dirty.empty()) {
// All of clean must have be deleted.
for (size_t i = 0; i < dirty.length(); ++i) {
deletion.incorporate(d);
d ? deletion.incorporate(1) : deletion.unincorporate(1);
}
return;
}

if (clean[0] == dirty[0]) {
substitution.incorporate(!d);
insertion.incorporate(!d);
deletion.incorporate(!d);
if (d) {
substitution.incorporate(0);
insertion.incorporate(0);
deletion.incorporate(0);
} else {
substitution.unincorporate(0);
insertion.unincorporate(0);
deletion.unincorporate(0);
}
corporate(clean.substr(1, std::string::npos),
dirty.substr(1, std::string::npos),
d);
return;
}

if (clean.back() == dirty.back()) {
substitution.incorporate(!d);
insertion.incorporate(!d);
deletion.incorporate(!d);
if (d) {
substitution.incorporate(0);
insertion.incorporate(0);
deletion.incorporate(0);
} else {
substitution.unincorporate(0);
insertion.unincorporate(0);
deletion.unincorporate(0);
}
corporate(clean.substr(0, clean.length() - 1),
dirty.substr(0, dirty.length() - 1),
d);
Expand All @@ -74,29 +86,29 @@ class SimpleStringEmission : public Emission<std::string> {
// So instead, we just guess based on the std::string lengths.
if (clean.length() < dirty.length()) {
// Probably an insertion.
insertion.incorporate(d);
d ? insertion.incorporate(1) : insertion.unincorporate(1);
corporate(clean, dirty.substr(1, std::string::npos), d);
return;
}

if (clean.length() > dirty.length()) {
// Probably a deletion.
deletion.incorporate(d);
d ? deletion.incorporate(1) : deletion.unincorporate(1);
corporate(clean.substr(1, std::string::npos), dirty, d);
return;
}

// Probably a substitution.
substitution.incorporate(d);
d ? substitution.incorporate(1) : substitution.unincorporate(1);
corporate(clean.substr(1, std::string::npos),
dirty.substr(1, std::string::npos),
d);
}

double logp(const std::pair<std::string, std::string>& x) const {
incorporate(x);
const_cast<SimpleStringEmission*>(this)->incorporate(x);
double lp = logp_score();
unincorporate(x);
const_cast<SimpleStringEmission*>(this)->unincorporate(x);
return lp - logp_score();
}

Expand Down Expand Up @@ -169,4 +181,4 @@ class SimpleStringEmission : public Emission<std::string> {
}
}

}
};

0 comments on commit f72023c

Please sign in to comment.