Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move distribution implementations to cc files #38

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
licenses(["notice"])


cc_library(
name = "distributions",
visibility = ["//:__subpackages__"],
deps = [
":adapter",
":base",
":bigram",
":beta_bernoulli",
":bigram",
":dirichlet_categorical",
":normal",
],
Expand All @@ -31,6 +30,7 @@ cc_library(

cc_library(
name = "beta_bernoulli",
srcs = ["beta_bernoulli.cc"],
hdrs = ["beta_bernoulli.hh"],
visibility = ["//:__subpackages__"],
deps = [
Expand All @@ -42,6 +42,7 @@ cc_library(

cc_library(
name = "bigram",
srcs = ["bigram.cc"],
hdrs = ["bigram.hh"],
visibility = ["//:__subpackages__"],
deps = [
Expand All @@ -54,7 +55,8 @@ cc_library(

cc_library(
name = "dirichlet_categorical",
srcs = ["dirichlet_categorical.hh"],
srcs = ["dirichlet_categorical.cc"],
hdrs = ["dirichlet_categorical.hh"],
visibility = ["//:__subpackages__"],
deps = [
":base",
Expand All @@ -64,8 +66,8 @@ cc_library(

cc_library(
name = "normal",
hdrs = ["normal.hh"],
srcs = ["normal.cc"],
hdrs = ["normal.hh"],
visibility = ["//:__subpackages__"],
deps = [
":base",
Expand Down Expand Up @@ -109,8 +111,8 @@ cc_test(
name = "dirichlet_categorical_test",
srcs = ["dirichlet_categorical_test.cc"],
deps = [
":dirichlet_categorical",
":beta_bernoulli",
":dirichlet_categorical",
"@boost//:algorithm",
"@boost//:test",
],
Expand Down
4 changes: 1 addition & 3 deletions cxx/distributions/adapter.hh
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ class DistributionAdapter : public Distribution<std::string> {
return to_string(s);
}

void transition_hyperparameters() {
d->transition_hyperparameters();
}
void transition_hyperparameters() { d->transition_hyperparameters(); }

~DistributionAdapter() { delete d; }
};
63 changes: 63 additions & 0 deletions cxx/distributions/beta_bernoulli.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2024
// See LICENSE.txt

#include "distributions/beta_bernoulli.hh"

#include <cassert>

#include "util_math.hh"

void BetaBernoulli::incorporate(const double& x) {
assert(x == 0 || x == 1);
++N;
s += x;
}

void BetaBernoulli::unincorporate(const double& x) {
assert(x == 0 || x == 1);
--N;
s -= x;
assert(0 <= s);
assert(0 <= N);
}

double BetaBernoulli::logp(const double& x) const {
assert(x == 0 || x == 1);
double log_denom = log(N + alpha + beta);
double log_numer = x ? log(s + alpha) : log(N - s + beta);
return log_numer - log_denom;
}

double BetaBernoulli::logp_score() const {
double v1 = lbeta(s + alpha, N - s + beta);
double v2 = lbeta(alpha, beta);
return v1 - v2;
}

double BetaBernoulli::sample() {
double p = exp(logp(1));
std::vector<int> items{0, 1};
std::vector<double> weights{1 - p, p};
int idx = choice(weights, prng);
return items[idx];
}

void BetaBernoulli::transition_hyperparameters() {
std::vector<double> logps;
std::vector<std::pair<double, double>> hypers;
// C++ doesn't yet allow range for-loops over existing variables. Sigh.
for (double alphat : alpha_grid) {
for (double betat : beta_grid) {
alpha = alphat;
beta = betat;
double lp = logp_score();
if (!std::isnan(lp)) {
logps.push_back(logp_score());
hypers.push_back(std::make_pair(alpha, beta));
}
}
}
int i = sample_from_logps(logps, prng);
alpha = hypers[i].first;
beta = hypers[i].second;
}
57 changes: 8 additions & 49 deletions cxx/distributions/beta_bernoulli.hh
Original file line number Diff line number Diff line change
Expand Up @@ -23,57 +23,16 @@ class BetaBernoulli : public Distribution<double> {
alpha_grid = log_linspace(1e-4, 1e4, 10, true);
beta_grid = log_linspace(1e-4, 1e4, 10, true);
}
void incorporate(const double& x) {
assert(x == 0 || x == 1);
++N;
s += x;
}
void unincorporate(const double& x) {
assert(x == 0 || x == 1);
--N;
s -= x;
assert(0 <= s);
assert(0 <= N);
}

double logp(const double& x) const {
assert(x == 0 || x == 1);
double log_denom = log(N + alpha + beta);
double log_numer = x ? log(s + alpha) : log(N - s + beta);
return log_numer - log_denom;
}
void incorporate(const double& x);

double logp_score() const {
double v1 = lbeta(s + alpha, N - s + beta);
double v2 = lbeta(alpha, beta);
return v1 - v2;
}
void unincorporate(const double& x);

double sample() {
double p = exp(logp(1));
std::vector<int> items{0, 1};
std::vector<double> weights{1 - p, p};
int idx = choice(weights, prng);
return items[idx];
}
double logp(const double& x) const;

void transition_hyperparameters() {
std::vector<double> logps;
std::vector<std::pair<double, double>> hypers;
// C++ doesn't yet allow range for-loops over existing variables. Sigh.
for (double alphat : alpha_grid) {
for (double betat : beta_grid) {
alpha = alphat;
beta = betat;
double lp = logp_score();
if (!std::isnan(lp)) {
logps.push_back(logp_score());
hypers.push_back(std::make_pair(alpha, beta));
}
}
}
int i = sample_from_logps(logps, prng);
alpha = hypers[i].first;
beta = hypers[i].second;
}
double logp_score() const;

double sample();

void transition_hyperparameters();
};
120 changes: 120 additions & 0 deletions cxx/distributions/bigram.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright 2024
// See LICENSE.txt

#include "distributions/bigram.hh"

#include <cassert>

#include "distributions/base.hh"

void Bigram::assert_valid_char(const char c) const {
assert(c >= ' ' && c <= '~');
}

size_t Bigram::char_to_index(const char c) const {
assert_valid_char(c);
return c - ' ';
}

char Bigram::index_to_char(const size_t i) const {
const char c = i + ' ';
assert_valid_char(c);
return c;
}

std::vector<size_t> Bigram::string_to_indices(const std::string& str) const {
// Convert the string to a vector of indices between 0 and `num_chars`,
// with a start/stop symbol at the beginning/end.
std::vector<size_t> inds = {num_chars};
for (const char& c : str) {
inds.push_back(char_to_index(c));
}
inds.push_back(num_chars);
return inds;
}

void Bigram::incorporate(const std::string& x) {
const std::vector<size_t> indices = string_to_indices(x);
for (size_t i = 0; i != indices.size() - 1; ++i) {
transition_dists[indices[i]].incorporate(indices[i + 1]);
}
++N;
}

void Bigram::unincorporate(const std::string& s) {
const std::vector<size_t> indices = string_to_indices(s);
for (size_t i = 0; i != indices.size() - 1; ++i) {
transition_dists[indices[i]].unincorporate(indices[i + 1]);
}
--N;
}

double Bigram::logp(const std::string& s) const {
const std::vector<size_t> indices = string_to_indices(s);
double total_logp = 0.0;
for (size_t i = 0; i != indices.size() - 1; ++i) {
total_logp += transition_dists[indices[i]].logp(indices[i + 1]);
// Incorporate each value so that subsequent probabilities are
// conditioned on it.
transition_dists[indices[i]].incorporate(indices[i + 1]);
}
for (size_t i = 0; i != indices.size() - 1; ++i) {
transition_dists[indices[i]].unincorporate(indices[i + 1]);
}
return total_logp;
}

double Bigram::logp_score() const {
double logp = 0;
for (const auto& d : transition_dists) {
logp += d.logp_score();
}
return logp;
}

std::string Bigram::sample() {
std::string sampled_string;
// TODO(emilyaf): Reconsider the reserved length and maybe enforce a
// max length.
sampled_string.reserve(30);
// Sample the first character conditioned on the stop/start symbol.
size_t current_ind = num_chars;
size_t next_ind = transition_dists[current_ind].sample();
transition_dists[current_ind].incorporate(next_ind);
current_ind = next_ind;

// Sample additional characters until the stop/start symbol is sampled.
// Incorporate the sampled character at each loop iteration so that
// subsequent samples are conditioned on its observation.
while (current_ind != num_chars) {
sampled_string += index_to_char(current_ind);
next_ind = transition_dists[current_ind].sample();
transition_dists[current_ind].incorporate(next_ind);
current_ind = next_ind;
}
unincorporate(sampled_string);
return sampled_string;
}

void Bigram::set_alpha(double alphat) {
alpha = alphat;
for (auto& trans_dist : transition_dists) {
trans_dist.alpha = alpha;
}
}

void Bigram::transition_hyperparameters() {
std::vector<double> logps;
std::vector<double> alphas;
// C++ doesn't yet allow range for-loops over existing variables. Sigh.
for (double alphat : ALPHA_GRID) {
set_alpha(alphat);
double lp = logp_score();
if (!std::isnan(lp)) {
logps.push_back(logp_score());
alphas.push_back(alphat);
}
}
int i = sample_from_logps(logps, prng);
set_alpha(alphas[i]);
}
Loading