forked from probsys/hierarchical-irm
Merge pull request #38 from probcomp/060624-srvasude-move_cc
Move distribution implementations to cc files
Showing 9 changed files with 293 additions and 214 deletions.
distributions/beta_bernoulli.cc
@@ -0,0 +1,63 @@
// Copyright 2024
// See LICENSE.txt

#include "distributions/beta_bernoulli.hh"

#include <cassert>

#include "util_math.hh"

void BetaBernoulli::incorporate(const double& x) {
  assert(x == 0 || x == 1);
  ++N;
  s += x;
}

void BetaBernoulli::unincorporate(const double& x) {
  assert(x == 0 || x == 1);
  --N;
  s -= x;
  assert(0 <= s);
  assert(0 <= N);
}

// Posterior predictive log-probability of a single observation x,
// given the data incorporated so far.
double BetaBernoulli::logp(const double& x) const {
  assert(x == 0 || x == 1);
  double log_denom = log(N + alpha + beta);
  double log_numer = x ? log(s + alpha) : log(N - s + beta);
  return log_numer - log_denom;
}

// Log marginal likelihood of the incorporated data under the
// Beta(alpha, beta) prior.
double BetaBernoulli::logp_score() const {
  double v1 = lbeta(s + alpha, N - s + beta);
  double v2 = lbeta(alpha, beta);
  return v1 - v2;
}

double BetaBernoulli::sample() {
  double p = exp(logp(1));
  std::vector<int> items{0, 1};
  std::vector<double> weights{1 - p, p};
  int idx = choice(weights, prng);
  return items[idx];
}

// Resample (alpha, beta) from a grid, weighted by the marginal likelihood.
void BetaBernoulli::transition_hyperparameters() {
  std::vector<double> logps;
  std::vector<std::pair<double, double>> hypers;
  // C++ doesn't yet allow range for-loops over existing variables. Sigh.
  for (double alphat : alpha_grid) {
    for (double betat : beta_grid) {
      alpha = alphat;
      beta = betat;
      double lp = logp_score();
      if (!std::isnan(lp)) {
        logps.push_back(lp);
        hypers.push_back(std::make_pair(alpha, beta));
      }
    }
  }
  int i = sample_from_logps(logps, prng);
  alpha = hypers[i].first;
  beta = hypers[i].second;
}
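Note: BetaBernoulli::logp implements the conjugate Beta-Bernoulli posterior predictive, p(x = 1 | data) = (s + alpha) / (N + alpha + beta). The standalone sketch below is not part of this commit; the hyperparameter and count values are made-up assumptions used only to check the formula in isolation.

// Standalone sketch (not part of this commit): checks the conjugate
// Beta-Bernoulli posterior predictive used in BetaBernoulli::logp.
// The hyperparameters and counts below are illustrative assumptions.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  double alpha = 1.0, beta = 1.0;  // Beta prior pseudo-counts
  double N = 5.0;                  // number of incorporated observations
  double s = 3.0;                  // number of ones among them
  double logp1 = std::log(s + alpha) - std::log(N + alpha + beta);
  double logp0 = std::log(N - s + beta) - std::log(N + alpha + beta);
  // The two predictive probabilities must sum to one.
  assert(std::fabs(std::exp(logp1) + std::exp(logp0) - 1.0) < 1e-12);
  std::printf("logp(1) = %f, logp(0) = %f\n", logp1, logp0);
  return 0;
}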
distributions/bigram.cc
@@ -0,0 +1,120 @@
// Copyright 2024
// See LICENSE.txt

#include "distributions/bigram.hh"

#include <cassert>

#include "distributions/base.hh"

void Bigram::assert_valid_char(const char c) const {
  assert(c >= ' ' && c <= '~');
}

size_t Bigram::char_to_index(const char c) const {
  assert_valid_char(c);
  return c - ' ';
}

char Bigram::index_to_char(const size_t i) const {
  const char c = i + ' ';
  assert_valid_char(c);
  return c;
}

std::vector<size_t> Bigram::string_to_indices(const std::string& str) const {
  // Convert the string to a vector of indices between 0 and `num_chars`,
  // with a start/stop symbol at the beginning/end.
  std::vector<size_t> inds = {num_chars};
  for (const char& c : str) {
    inds.push_back(char_to_index(c));
  }
  inds.push_back(num_chars);
  return inds;
}

void Bigram::incorporate(const std::string& x) {
  const std::vector<size_t> indices = string_to_indices(x);
  for (size_t i = 0; i != indices.size() - 1; ++i) {
    transition_dists[indices[i]].incorporate(indices[i + 1]);
  }
  ++N;
}

void Bigram::unincorporate(const std::string& s) {
  const std::vector<size_t> indices = string_to_indices(s);
  for (size_t i = 0; i != indices.size() - 1; ++i) {
    transition_dists[indices[i]].unincorporate(indices[i + 1]);
  }
  --N;
}

// Log probability of the whole string via the chain rule over
// character transitions.
double Bigram::logp(const std::string& s) const {
  const std::vector<size_t> indices = string_to_indices(s);
  double total_logp = 0.0;
  for (size_t i = 0; i != indices.size() - 1; ++i) {
    total_logp += transition_dists[indices[i]].logp(indices[i + 1]);
    // Incorporate each value so that subsequent probabilities are
    // conditioned on it.
    transition_dists[indices[i]].incorporate(indices[i + 1]);
  }
  for (size_t i = 0; i != indices.size() - 1; ++i) {
    transition_dists[indices[i]].unincorporate(indices[i + 1]);
  }
  return total_logp;
}

double Bigram::logp_score() const {
  double logp = 0;
  for (const auto& d : transition_dists) {
    logp += d.logp_score();
  }
  return logp;
}

std::string Bigram::sample() {
  std::string sampled_string;
  // TODO(emilyaf): Reconsider the reserved length and maybe enforce a
  // max length.
  sampled_string.reserve(30);
  // Sample the first character conditioned on the start/stop symbol.
  size_t current_ind = num_chars;
  size_t next_ind = transition_dists[current_ind].sample();
  transition_dists[current_ind].incorporate(next_ind);
  current_ind = next_ind;

  // Sample additional characters until the start/stop symbol is sampled.
  // Incorporate the sampled character at each loop iteration so that
  // subsequent samples are conditioned on its observation.
  while (current_ind != num_chars) {
    sampled_string += index_to_char(current_ind);
    next_ind = transition_dists[current_ind].sample();
    transition_dists[current_ind].incorporate(next_ind);
    current_ind = next_ind;
  }
  unincorporate(sampled_string);
  return sampled_string;
}

// Set the shared concentration parameter on every transition distribution.
void Bigram::set_alpha(double alphat) {
  alpha = alphat;
  for (auto& trans_dist : transition_dists) {
    trans_dist.alpha = alpha;
  }
}

// Resample alpha from ALPHA_GRID, weighted by the model score.
void Bigram::transition_hyperparameters() {
  std::vector<double> logps;
  std::vector<double> alphas;
  // C++ doesn't yet allow range for-loops over existing variables. Sigh.
  for (double alphat : ALPHA_GRID) {
    set_alpha(alphat);
    double lp = logp_score();
    if (!std::isnan(lp)) {
      logps.push_back(lp);
      alphas.push_back(alphat);
    }
  }
  int i = sample_from_logps(logps, prng);
  set_alpha(alphas[i]);
}
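Note: Bigram maps the printable ASCII characters ' ' through '~' to indices 0..94 and brackets every string with a start/stop symbol at index num_chars. The exact value of num_chars lives in bigram.hh, which is not part of this diff; the standalone sketch below assumes it equals 95 (the size of the printable range) and mirrors string_to_indices for illustration only.

// Standalone sketch (not part of this commit): mirrors Bigram's character
// indexing. Assumes num_chars == 95, i.e. the printable ASCII range.
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

static const size_t kNumChars = '~' - ' ' + 1;  // 95 printable characters

std::vector<size_t> string_to_indices(const std::string& str) {
  std::vector<size_t> inds = {kNumChars};  // leading start/stop symbol
  for (char c : str) {
    assert(c >= ' ' && c <= '~');
    inds.push_back(static_cast<size_t>(c - ' '));
  }
  inds.push_back(kNumChars);  // trailing start/stop symbol
  return inds;
}

int main() {
  std::vector<size_t> inds = string_to_indices("ab");
  // "ab" -> {95, 65, 66, 95}: 'a' - ' ' == 65, 'b' - ' ' == 66.
  assert(inds.size() == 4);
  assert(inds[0] == kNumChars && inds[1] == 65 &&
         inds[2] == 66 && inds[3] == kNumChars);
  return 0;
}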