Skip to content

Commit

Permalink
Add nearest method for distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasColthurst committed Oct 1, 2024
1 parent 6bf7918 commit 91fddc7
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 0 deletions.
10 changes: 10 additions & 0 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ cc_library(
deps = [
":base",
":dirichlet_categorical",
"//emissions:string_alignment",
],
)

Expand Down Expand Up @@ -232,6 +233,15 @@ cc_test(
],
)

# Test target for string_nat_test.cc; links the :string_nat library under
# test and the Boost test dependency.
cc_test(
name = "string_nat_test",
srcs = ["string_nat_test.cc"],
deps = [
":string_nat",
"@boost//:test",
],
)

cc_test(
name = "zero_mean_normal_test",
srcs = ["zero_mean_normal_test.cc"],
Expand Down
3 changes: 3 additions & 0 deletions cxx/distributions/adapter.hh
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,8 @@ class DistributionAdapter : public Distribution<std::string> {
// Forwards init_theta to the wrapped distribution d, passing prng through.
void init_theta(std::mt19937* prng) { d->init_theta(prng); }
// Forwards transition_theta to the wrapped distribution d, passing prng
// through.
void transition_theta(std::mt19937* prng) { d->transition_theta(prng); }

// TODO(thomaswc): Define nearest methods for the DistributionAdapter
// instantiations we use.

// Deletes the wrapped distribution d.
// NOTE(review): this implies the adapter owns d -- confirm where d is
// allocated and that the adapter is never copied (a copy would delete d
// twice).
~DistributionAdapter() { delete d; }
};
6 changes: 6 additions & 0 deletions cxx/distributions/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,11 @@ class Distribution {
// NonconjugateDistribution need define this.
virtual void transition_theta(std::mt19937* prng) {};

// Return the value nearest to x that is given non-zero probability by
// this distribution.
// This default implementation returns x unchanged.  It is virtual so that
// distributions which cannot produce every value of T can override it and
// map x onto a value they can produce.
virtual T nearest(const T& x) const {
return x;
}

virtual ~Distribution() = default;
};
11 changes: 11 additions & 0 deletions cxx/distributions/dirichlet_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,14 @@ void DirichletCategorical::transition_hyperparameters(std::mt19937* prng) {
alpha = alphas[i];
}
}

// Clamps x to the index range covered by counts: negative values map to 0,
// values at or beyond counts.size() map to the last index, and anything
// already in range is returned unchanged.
int DirichletCategorical::nearest(const int& x) const {
  if (x >= 0) {
    // x is known to be non-negative here, so the conversion to size_t
    // cannot wrap around.
    const size_t index = static_cast<size_t>(x);
    return index < counts.size() ? x : static_cast<int>(counts.size() - 1);
  }
  return 0;
}
2 changes: 2 additions & 0 deletions cxx/distributions/dirichlet_categorical.hh
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ class DirichletCategorical : public Distribution<int> {
int sample(std::mt19937* prng);

void transition_hyperparameters(std::mt19937* prng);

int nearest(const int& x) const;
};
8 changes: 8 additions & 0 deletions cxx/distributions/dirichlet_categorical_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,11 @@ BOOST_AUTO_TEST_CASE(test_sample_and_log_prob) {
BOOST_TEST(abs(probs[i] - approx_p) <= 3 * stddev);
}
}

// nearest() should clamp out-of-range values to the closest end of the
// range and leave in-range values alone.
BOOST_AUTO_TEST_CASE(test_nearest) {
  DirichletCategorical dc(12);

  // The three queries are independent; nearest() is a const method.
  const int below_range = dc.nearest(-5);
  const int above_range = dc.nearest(99);
  const int in_range = dc.nearest(7);

  BOOST_TEST(below_range == 0);
  BOOST_TEST(above_range == 11);
  BOOST_TEST(in_range == 7);
}
12 changes: 12 additions & 0 deletions cxx/distributions/string_nat.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <cctype>

#include "distributions/bigram.hh"

// A distribution over natural numbers represented as strings of digits.
Expand All @@ -12,4 +14,14 @@
class StringNat : public Bigram {
 public:
  // Constructs the underlying Bigram with _max_length and the character
  // range '0' through '9'.
  StringNat(size_t _max_length = 20): Bigram(_max_length, '0', '9') {}

  // Returns x with every non-digit character removed; the remaining digits
  // keep their original order.  For example, "a77z99" becomes "7799", and a
  // string with no digits becomes "".
  // NOTE(review): the result is not shortened to the _max_length given to
  // the constructor -- confirm whether Bigram needs that.
  std::string nearest(const std::string& x) const {
    std::string digits;
    digits.reserve(x.size());
    for (const char c : x) {
      // std::isdigit has undefined behavior when its argument is neither
      // EOF nor representable as unsigned char.  A plain char holding a
      // non-ASCII byte can be negative, so convert it first.
      if (std::isdigit(static_cast<unsigned char>(c))) {
        digits += c;
      }
    }
    return digits;
  }
};
24 changes: 24 additions & 0 deletions cxx/distributions/string_nat_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test StringNat

#include "distributions/string_nat.hh"

#include <boost/test/included/unit_test.hpp>

// Checks that the N member follows incorporate/unincorporate calls: two
// incorporated strings give N == 2, and removing one of them gives N == 1.
BOOST_AUTO_TEST_CASE(test_simple) {
StringNat sn;

sn.incorporate("42");
sn.incorporate("0");
BOOST_TEST(sn.N == 2);
// Remove one of the two values incorporated above.
sn.unincorporate("0");
BOOST_TEST(sn.N == 1);
}

// nearest() should keep all-digit strings as they are and drop any
// non-digit characters from mixed strings.
BOOST_AUTO_TEST_CASE(test_nearest) {
  StringNat sn;

  const auto all_digits = sn.nearest("1234");
  const auto mixed = sn.nearest("a77z99");

  BOOST_TEST(all_digits == "1234");
  BOOST_TEST(mixed == "7799");
}
22 changes: 22 additions & 0 deletions cxx/distributions/stringcat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

#include <algorithm>
#include <cassert>
#include <limits>
#include "distributions/stringcat.hh"
#include "emissions/string_alignment.hh"

int StringCat::string_to_index(const std::string& s) const {
auto it = std::find(strings.begin(), strings.end(), s);
Expand Down Expand Up @@ -33,3 +35,23 @@ std::string StringCat::sample(std::mt19937* prng) {
// Forwards the hyperparameter transition to dc, passing prng through.
void StringCat::transition_hyperparameters(std::mt19937* prng) {
dc.transition_hyperparameters(prng);
}

std::string StringCat::nearest(const std::string& x) const {
if (std::find(strings.begin(), strings.end(), x) != strings.end()) {
return x;
}

const std::string *nearest = &(strings[0]);
double lowest_distance = std::numeric_limits<double>::max();
for (const std::string& s : strings) {
std::vector<StrAlignment> alignments;
topk_alignments(1, s, x, edit_distance, &alignments);
double d = alignments[0].cost;
if (d < lowest_distance) {
lowest_distance = d;
nearest = &s;
}
}

return *nearest;
}
2 changes: 2 additions & 0 deletions cxx/distributions/stringcat.hh
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ class StringCat : public Distribution<std::string> {
void set_alpha(double alphat);

void transition_hyperparameters(std::mt19937* prng);

std::string nearest(const std::string& x) const;
};
3 changes: 3 additions & 0 deletions cxx/distributions/stringcat_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,7 @@ BOOST_AUTO_TEST_CASE(test_simple) {
auto it = std::find(strings.begin(), strings.end(), samp);
bool found = (it != strings.end());
BOOST_TEST(found);

BOOST_TEST(sc.nearest("test") == "test");
BOOST_TEST(sc.nearest("otter") == "other");
}

0 comments on commit 91fddc7

Please sign in to comment.