Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change BetaBernoulli SampleType from double to bool. #50

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions cxx/distributions/beta_bernoulli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@

#include "util_math.hh"

void BetaBernoulli::incorporate(const double& x) {
void BetaBernoulli::incorporate(const bool& x) {
assert(x == 0 || x == 1);
++N;
s += x;
}

void BetaBernoulli::unincorporate(const double& x) {
void BetaBernoulli::unincorporate(const bool& x) {
assert(x == 0 || x == 1);
--N;
s -= x;
assert(0 <= s);
assert(0 <= N);
}

double BetaBernoulli::logp(const double& x) const {
double BetaBernoulli::logp(const bool& x) const {
assert(x == 0 || x == 1);
double log_denom = log(N + alpha + beta);
double log_numer = x ? log(s + alpha) : log(N - s + beta);
Expand All @@ -34,9 +34,9 @@ double BetaBernoulli::logp_score() const {
return v1 - v2;
}

double BetaBernoulli::sample() {
bool BetaBernoulli::sample() {
double p = exp(logp(1));
std::vector<int> items{0, 1};
std::vector<bool> items{false, true};
std::vector<double> weights{1 - p, p};
int idx = choice(weights, prng);
return items[idx];
Expand Down
10 changes: 5 additions & 5 deletions cxx/distributions/beta_bernoulli.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

// TODO(thomaswc, emilyaf): Change BetaBernoulli to use bool instead of
// double.
class BetaBernoulli : public Distribution<double> {
class BetaBernoulli : public Distribution<bool> {
public:
double alpha = 1; // hyperparameter
double beta = 1; // hyperparameter
Expand All @@ -26,15 +26,15 @@ class BetaBernoulli : public Distribution<double> {
beta_grid = log_linspace(1e-4, 1e4, 10, true);
}

void incorporate(const double& x);
void incorporate(const bool& x);

void unincorporate(const double& x);
void unincorporate(const bool& x);

double logp(const double& x) const;
double logp(const bool& x) const;

double logp_score() const;

double sample();
bool sample();

void transition_hyperparameters();
};
16 changes: 8 additions & 8 deletions cxx/tests/test_hirm_animals.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,25 @@ int main(int argc, char** argv) {

// Marginally normalized.
int persiancat = enc["animal"]["persiancat"];
auto p0_black_persiancat = hirm.logp({{"black", {persiancat}, 0.}});
auto p1_black_persiancat = hirm.logp({{"black", {persiancat}, 1.}});
auto p0_black_persiancat = hirm.logp({{"black", {persiancat}, false}});
auto p1_black_persiancat = hirm.logp({{"black", {persiancat}, true}});
assert(abs(logsumexp({p0_black_persiancat, p1_black_persiancat})) < 1e-10);

// Marginally normalized.
int sheep = enc["animal"]["sheep"];
auto p0_solitary_sheep = hirm.logp({{"solitary", {sheep}, 0.}});
auto p1_solitary_sheep = hirm.logp({{"solitary", {sheep}, 1.}});
auto p0_solitary_sheep = hirm.logp({{"solitary", {sheep}, false}});
auto p1_solitary_sheep = hirm.logp({{"solitary", {sheep}, true}});
assert(abs(logsumexp({p0_solitary_sheep, p1_solitary_sheep})) < 1e-10);

// Jointly normalized.
auto p00_black_persiancat_solitary_sheep =
hirm.logp({{"black", {persiancat}, 0.}, {"solitary", {sheep}, 0.}});
hirm.logp({{"black", {persiancat}, false}, {"solitary", {sheep}, false}});
auto p01_black_persiancat_solitary_sheep =
hirm.logp({{"black", {persiancat}, 0.}, {"solitary", {sheep}, 1.}});
hirm.logp({{"black", {persiancat}, false}, {"solitary", {sheep}, true}});
auto p10_black_persiancat_solitary_sheep =
hirm.logp({{"black", {persiancat}, 1.}, {"solitary", {sheep}, 0.}});
hirm.logp({{"black", {persiancat}, true}, {"solitary", {sheep}, false}});
auto p11_black_persiancat_solitary_sheep =
hirm.logp({{"black", {persiancat}, 1.}, {"solitary", {sheep}, 1.}});
hirm.logp({{"black", {persiancat}, true}, {"solitary", {sheep}, true}});
auto Z = logsumexp({
p00_black_persiancat_solitary_sheep,
p01_black_persiancat_solitary_sheep,
Expand Down
14 changes: 7 additions & 7 deletions cxx/tests/test_irm_two_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ int main(int argc, char** argv) {
assert(l.size() == 2);
auto x1 = l.at(0);
auto x2 = l.at(1);
auto p0 = std::get<T_r>(irm.relations.at("R1"))->logp({x1, x2}, 0.);
auto p0_irm = irm.logp({{"R1", {x1, x2}, 0.}});
auto p0 = std::get<T_r>(irm.relations.at("R1"))->logp({x1, x2}, false);
auto p0_irm = irm.logp({{"R1", {x1, x2}, false}});
assert(abs(p0 - p0_irm) < 1e-10);
auto p1 = std::get<T_r>(irm.relations.at("R1"))->logp({x1, x2}, 1.);
auto p1 = std::get<T_r>(irm.relations.at("R1"))->logp({x1, x2}, true);
auto Z = logsumexp({p0, p1});
assert(abs(Z) < 1e-10);
assert(abs(exp(p0) - expected_p0[x1].at(x2)) < .1);
Expand All @@ -94,10 +94,10 @@ int main(int argc, char** argv) {
auto x1 = l.at(0);
auto x2 = l.at(1);
auto x3 = l.at(2);
auto p00 = irm.logp({{"R1", {x1, x2}, 0.}, {"R1", {x1, x3}, 0.}});
auto p01 = irm.logp({{"R1", {x1, x2}, 0.}, {"R1", {x1, x3}, 1.}});
auto p10 = irm.logp({{"R1", {x1, x2}, 1.}, {"R1", {x1, x3}, 0.}});
auto p11 = irm.logp({{"R1", {x1, x2}, 1.}, {"R1", {x1, x3}, 1.}});
auto p00 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, false}});
auto p01 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, true}});
auto p10 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, false}});
auto p11 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, true}});
auto Z = logsumexp({p00, p01, p10, p11});
assert(abs(Z) < 1e-10);
}
Expand Down
2 changes: 1 addition & 1 deletion cxx/tests/test_misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ int main(int argc, char** argv) {
auto value = std::get<2>(i);
auto item = std::get<1>(i);
printf("incorporating %s ", relation.c_str());
printf("%1.f ", std::get<double>(value));
printf("%d ", std::get<bool>(value));
int counter = 0;
T_items items_code;
for (auto const& item : std::get<1>(i)) {
Expand Down
2 changes: 1 addition & 1 deletion cxx/util_distribution_variant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ ObservationVariant observation_string_to_value(
case DistributionEnum::normal:
return std::stod(value_str);
case DistributionEnum::bernoulli:
return std::stod(value_str);
return static_cast<bool>(std::stoi(value_str));
case DistributionEnum::categorical:
return std::stoi(value_str);
case DistributionEnum::bigram:
Expand Down
2 changes: 1 addition & 1 deletion cxx/util_distribution_variant.hh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ template <typename DistributionType>
class Relation;

// Set of all distribution sample types.
using ObservationVariant = std::variant<double, int, std::string>;
using ObservationVariant = std::variant<double, int, bool, std::string>;

using DistributionVariant =
std::variant<BetaBernoulli*, Bigram*, DirichletCategorical*, Normal*>;
Expand Down