diff --git a/cxx/distributions/BUILD b/cxx/distributions/BUILD index cb5f564..b5852ac 100644 --- a/cxx/distributions/BUILD +++ b/cxx/distributions/BUILD @@ -117,6 +117,7 @@ cc_library( deps = [ ":base", ":dirichlet_categorical", + "//emissions:string_alignment", ], ) @@ -232,6 +233,15 @@ cc_test( ], ) +cc_test( + name = "string_nat_test", + srcs = ["string_nat_test.cc"], + deps = [ + ":string_nat", + "@boost//:test", + ], +) + cc_test( name = "zero_mean_normal_test", srcs = ["zero_mean_normal_test.cc"], diff --git a/cxx/distributions/adapter.hh b/cxx/distributions/adapter.hh index b50bc7f..cb9594f 100644 --- a/cxx/distributions/adapter.hh +++ b/cxx/distributions/adapter.hh @@ -57,5 +57,8 @@ class DistributionAdapter : public Distribution { void init_theta(std::mt19937* prng) { d->init_theta(prng); } void transition_theta(std::mt19937* prng) { d->transition_theta(prng); } + // TODO(thomaswc): Define nearest methods for the DistributionAdapter + // instantiations we use. + ~DistributionAdapter() { delete d; } }; diff --git a/cxx/distributions/base.hh b/cxx/distributions/base.hh index ffd19ab..32cff7c 100644 --- a/cxx/distributions/base.hh +++ b/cxx/distributions/base.hh @@ -48,5 +48,11 @@ class Distribution { // NonconjugateDistribution need define this. virtual void transition_theta(std::mt19937* prng) {}; + // Return the value nearest to x that is given non-zero probability by + // this distribution. + virtual T nearest(const T& x) const { + return x; + } + virtual ~Distribution() = default; }; diff --git a/cxx/distributions/dirichlet_categorical.cc b/cxx/distributions/dirichlet_categorical.cc index 1eb5f2f..7cb683b 100644 --- a/cxx/distributions/dirichlet_categorical.cc +++ b/cxx/distributions/dirichlet_categorical.cc @@ -59,3 +59,14 @@ void DirichletCategorical::transition_hyperparameters(std::mt19937* prng) { alpha = alphas[i]; } } + +int DirichletCategorical::nearest(const int& x) const { + if (x < 0) { + return 0; + } + // x can't be negative here, so safe to cast to size_t. + if (size_t(x) >= counts.size()) { + return counts.size() - 1; + } + return x; +} diff --git a/cxx/distributions/dirichlet_categorical.hh b/cxx/distributions/dirichlet_categorical.hh index cea7bca..6595913 100644 --- a/cxx/distributions/dirichlet_categorical.hh +++ b/cxx/distributions/dirichlet_categorical.hh @@ -27,4 +27,6 @@ class DirichletCategorical : public Distribution { int sample(std::mt19937* prng); void transition_hyperparameters(std::mt19937* prng); + + int nearest(const int& x) const; }; diff --git a/cxx/distributions/dirichlet_categorical_test.cc b/cxx/distributions/dirichlet_categorical_test.cc index f78e96a..2ca12bb 100644 --- a/cxx/distributions/dirichlet_categorical_test.cc +++ b/cxx/distributions/dirichlet_categorical_test.cc @@ -162,3 +162,11 @@ BOOST_AUTO_TEST_CASE(test_sample_and_log_prob) { BOOST_TEST(abs(probs[i] - approx_p) <= 3 * stddev); } } + +BOOST_AUTO_TEST_CASE(test_nearest) { + DirichletCategorical dc(12); + + BOOST_TEST(dc.nearest(-5) == 0); + BOOST_TEST(dc.nearest(99) == 11); + BOOST_TEST(dc.nearest(7) == 7); +} diff --git a/cxx/distributions/string_nat.hh b/cxx/distributions/string_nat.hh index 79f5349..5acd2d9 100644 --- a/cxx/distributions/string_nat.hh +++ b/cxx/distributions/string_nat.hh @@ -3,6 +3,8 @@ #pragma once +#include + #include "distributions/bigram.hh" // A distribution over natural numbers represented as strings of digits. @@ -12,4 +14,14 @@ class StringNat : public Bigram { public: StringNat(size_t _max_length = 20): Bigram(_max_length, '0', '9') {} + + std::string nearest(const std::string& x) const { + std::string s; + for (const char& c : x) { + if (std::isdigit(c)) { + s += c; + } + } + return s; + } }; diff --git a/cxx/distributions/string_nat_test.cc b/cxx/distributions/string_nat_test.cc new file mode 100644 index 0000000..1e8284c --- /dev/null +++ b/cxx/distributions/string_nat_test.cc @@ -0,0 +1,24 @@ +// Apache License, Version 2.0, refer to LICENSE.txt + +#define BOOST_TEST_MODULE test StringNat + +#include "distributions/string_nat.hh" + +#include + +BOOST_AUTO_TEST_CASE(test_simple) { + StringNat sn; + + sn.incorporate("42"); + sn.incorporate("0"); + BOOST_TEST(sn.N == 2); + sn.unincorporate("0"); + BOOST_TEST(sn.N == 1); +} + +BOOST_AUTO_TEST_CASE(test_nearest) { + StringNat sn; + + BOOST_TEST(sn.nearest("1234") == "1234"); + BOOST_TEST(sn.nearest("a77z99") == "7799"); +} diff --git a/cxx/distributions/stringcat.cc b/cxx/distributions/stringcat.cc index 1e199ed..38200d9 100644 --- a/cxx/distributions/stringcat.cc +++ b/cxx/distributions/stringcat.cc @@ -3,7 +3,9 @@ #include #include +#include #include "distributions/stringcat.hh" +#include "emissions/string_alignment.hh" int StringCat::string_to_index(const std::string& s) const { auto it = std::find(strings.begin(), strings.end(), s); @@ -33,3 +35,23 @@ std::string StringCat::sample(std::mt19937* prng) { void StringCat::transition_hyperparameters(std::mt19937* prng) { dc.transition_hyperparameters(prng); } + +std::string StringCat::nearest(const std::string& x) const { + if (std::find(strings.begin(), strings.end(), x) != strings.end()) { + return x; + } + + const std::string *nearest = &(strings[0]); + double lowest_distance = std::numeric_limits::max(); + for (const std::string& s : strings) { + std::vector alignments; + topk_alignments(1, s, x, edit_distance, &alignments); + double d = alignments[0].cost; + if (d < lowest_distance) { + lowest_distance = d; + nearest = &s; + } + } + + return *nearest; +} diff --git a/cxx/distributions/stringcat.hh b/cxx/distributions/stringcat.hh index 8fdec44..9aa2ced 100644 --- a/cxx/distributions/stringcat.hh +++ b/cxx/distributions/stringcat.hh @@ -31,4 +31,6 @@ class StringCat : public Distribution { void set_alpha(double alphat); void transition_hyperparameters(std::mt19937* prng); + + std::string nearest(const std::string& x) const; }; diff --git a/cxx/distributions/stringcat_test.cc b/cxx/distributions/stringcat_test.cc index c23f71c..9e56b90 100644 --- a/cxx/distributions/stringcat_test.cc +++ b/cxx/distributions/stringcat_test.cc @@ -30,4 +30,7 @@ BOOST_AUTO_TEST_CASE(test_simple) { auto it = std::find(strings.begin(), strings.end(), samp); bool found = (it != strings.end()); BOOST_TEST(found); + + BOOST_TEST(sc.nearest("test") == "test"); + BOOST_TEST(sc.nearest("otter") == "other"); }