Skip to content

Commit

Permalink
Merge pull request #50 from probcomp/061324-emilyaf-bernoulli-bool
Browse files Browse the repository at this point in the history
Change BetaBernoulli SampleType from double to bool.
  • Loading branch information
emilyfertig authored Jun 13, 2024
2 parents 4c8ebe2 + bf3921b commit a20574f
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 28 deletions.
10 changes: 5 additions & 5 deletions cxx/distributions/beta_bernoulli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@

#include "util_math.hh"

void BetaBernoulli::incorporate(const double& x) {
void BetaBernoulli::incorporate(const bool& x) {
assert(x == 0 || x == 1);
++N;
s += x;
}

void BetaBernoulli::unincorporate(const double& x) {
void BetaBernoulli::unincorporate(const bool& x) {
assert(x == 0 || x == 1);
--N;
s -= x;
assert(0 <= s);
assert(0 <= N);
}

double BetaBernoulli::logp(const double& x) const {
double BetaBernoulli::logp(const bool& x) const {
assert(x == 0 || x == 1);
double log_denom = log(N + alpha + beta);
double log_numer = x ? log(s + alpha) : log(N - s + beta);
Expand All @@ -34,9 +34,9 @@ double BetaBernoulli::logp_score() const {
return v1 - v2;
}

double BetaBernoulli::sample() {
bool BetaBernoulli::sample() {
double p = exp(logp(1));
std::vector<int> items{0, 1};
std::vector<bool> items{false, true};
std::vector<double> weights{1 - p, p};
int idx = choice(weights, prng);
return items[idx];
Expand Down
10 changes: 5 additions & 5 deletions cxx/distributions/beta_bernoulli.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

// TODO(thomaswc, emilyaf): Change BetaBernoulli to use bool instead of
// double.
class BetaBernoulli : public Distribution<double> {
class BetaBernoulli : public Distribution<bool> {
public:
double alpha = 1; // hyperparameter
double beta = 1; // hyperparameter
Expand All @@ -26,15 +26,15 @@ class BetaBernoulli : public Distribution<double> {
beta_grid = log_linspace(1e-4, 1e4, 10, true);
}

void incorporate(const double& x);
void incorporate(const bool& x);

void unincorporate(const double& x);
void unincorporate(const bool& x);

double logp(const double& x) const;
double logp(const bool& x) const;

double logp_score() const;

double sample();
bool sample();

void transition_hyperparameters();
};
16 changes: 8 additions & 8 deletions cxx/tests/test_hirm_animals.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,25 @@ int main(int argc, char** argv) {

// Marginally normalized.
int persiancat = enc["animal"]["persiancat"];
auto p0_black_persiancat = hirm.logp({{"black", {persiancat}, 0.}});
auto p1_black_persiancat = hirm.logp({{"black", {persiancat}, 1.}});
auto p0_black_persiancat = hirm.logp({{"black", {persiancat}, false}});
auto p1_black_persiancat = hirm.logp({{"black", {persiancat}, true}});
assert(abs(logsumexp({p0_black_persiancat, p1_black_persiancat})) < 1e-10);

// Marginally normalized.
int sheep = enc["animal"]["sheep"];
auto p0_solitary_sheep = hirm.logp({{"solitary", {sheep}, 0.}});
auto p1_solitary_sheep = hirm.logp({{"solitary", {sheep}, 1.}});
auto p0_solitary_sheep = hirm.logp({{"solitary", {sheep}, false}});
auto p1_solitary_sheep = hirm.logp({{"solitary", {sheep}, true}});
assert(abs(logsumexp({p0_solitary_sheep, p1_solitary_sheep})) < 1e-10);

// Jointly normalized.
auto p00_black_persiancat_solitary_sheep =
hirm.logp({{"black", {persiancat}, 0.}, {"solitary", {sheep}, 0.}});
hirm.logp({{"black", {persiancat}, false}, {"solitary", {sheep}, false}});
auto p01_black_persiancat_solitary_sheep =
hirm.logp({{"black", {persiancat}, 0.}, {"solitary", {sheep}, 1.}});
hirm.logp({{"black", {persiancat}, false}, {"solitary", {sheep}, true}});
auto p10_black_persiancat_solitary_sheep =
hirm.logp({{"black", {persiancat}, 1.}, {"solitary", {sheep}, 0.}});
hirm.logp({{"black", {persiancat}, true}, {"solitary", {sheep}, false}});
auto p11_black_persiancat_solitary_sheep =
hirm.logp({{"black", {persiancat}, 1.}, {"solitary", {sheep}, 1.}});
hirm.logp({{"black", {persiancat}, true}, {"solitary", {sheep}, true}});
auto Z = logsumexp({
p00_black_persiancat_solitary_sheep,
p01_black_persiancat_solitary_sheep,
Expand Down
14 changes: 7 additions & 7 deletions cxx/tests/test_irm_two_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ int main(int argc, char** argv) {
assert(l.size() == 2);
auto x1 = l.at(0);
auto x2 = l.at(1);
auto p0 = std::get<T_r>(irm.relations.at("R1"))->logp({x1, x2}, 0.);
auto p0_irm = irm.logp({{"R1", {x1, x2}, 0.}});
auto p0 = std::get<T_r>(irm.relations.at("R1"))->logp({x1, x2}, false);
auto p0_irm = irm.logp({{"R1", {x1, x2}, false}});
assert(abs(p0 - p0_irm) < 1e-10);
auto p1 = std::get<T_r>(irm.relations.at("R1"))->logp({x1, x2}, 1.);
auto p1 = std::get<T_r>(irm.relations.at("R1"))->logp({x1, x2}, true);
auto Z = logsumexp({p0, p1});
assert(abs(Z) < 1e-10);
assert(abs(exp(p0) - expected_p0[x1].at(x2)) < .1);
Expand All @@ -94,10 +94,10 @@ int main(int argc, char** argv) {
auto x1 = l.at(0);
auto x2 = l.at(1);
auto x3 = l.at(2);
auto p00 = irm.logp({{"R1", {x1, x2}, 0.}, {"R1", {x1, x3}, 0.}});
auto p01 = irm.logp({{"R1", {x1, x2}, 0.}, {"R1", {x1, x3}, 1.}});
auto p10 = irm.logp({{"R1", {x1, x2}, 1.}, {"R1", {x1, x3}, 0.}});
auto p11 = irm.logp({{"R1", {x1, x2}, 1.}, {"R1", {x1, x3}, 1.}});
auto p00 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, false}});
auto p01 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, true}});
auto p10 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, false}});
auto p11 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, true}});
auto Z = logsumexp({p00, p01, p10, p11});
assert(abs(Z) < 1e-10);
}
Expand Down
2 changes: 1 addition & 1 deletion cxx/tests/test_misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ int main(int argc, char** argv) {
auto value = std::get<2>(i);
auto item = std::get<1>(i);
printf("incorporating %s ", relation.c_str());
printf("%1.f ", std::get<double>(value));
printf("%d ", std::get<bool>(value));
int counter = 0;
T_items items_code;
for (auto const& item : std::get<1>(i)) {
Expand Down
2 changes: 1 addition & 1 deletion cxx/util_distribution_variant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ ObservationVariant observation_string_to_value(
case DistributionEnum::normal:
return std::stod(value_str);
case DistributionEnum::bernoulli:
return std::stod(value_str);
return static_cast<bool>(std::stoi(value_str));
case DistributionEnum::categorical:
return std::stoi(value_str);
case DistributionEnum::bigram:
Expand Down
2 changes: 1 addition & 1 deletion cxx/util_distribution_variant.hh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ template <typename DistributionType>
class Relation;

// Set of all distribution sample types.
using ObservationVariant = std::variant<double, int, std::string>;
using ObservationVariant = std::variant<double, int, bool, std::string>;

using DistributionVariant =
std::variant<BetaBernoulli*, Bigram*, DirichletCategorical*, Normal*>;
Expand Down

0 comments on commit a20574f

Please sign in to comment.