From bf3921bf882cddadccb16ddee833171b4beb7be4 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Thu, 13 Jun 2024 17:45:38 +0000 Subject: [PATCH] Change BetaBernoulli SampleType from double to bool. --- cxx/distributions/beta_bernoulli.cc | 10 +++++----- cxx/distributions/beta_bernoulli.hh | 10 +++++----- cxx/tests/test_hirm_animals.cc | 16 ++++++++-------- cxx/tests/test_irm_two_relations.cc | 14 +++++++------- cxx/tests/test_misc.cc | 2 +- cxx/util_distribution_variant.cc | 2 +- cxx/util_distribution_variant.hh | 2 +- 7 files changed, 28 insertions(+), 28 deletions(-) diff --git a/cxx/distributions/beta_bernoulli.cc b/cxx/distributions/beta_bernoulli.cc index 032e09c..be490ad 100644 --- a/cxx/distributions/beta_bernoulli.cc +++ b/cxx/distributions/beta_bernoulli.cc @@ -7,13 +7,13 @@ #include "util_math.hh" -void BetaBernoulli::incorporate(const double& x) { +void BetaBernoulli::incorporate(const bool& x) { assert(x == 0 || x == 1); ++N; s += x; } -void BetaBernoulli::unincorporate(const double& x) { +void BetaBernoulli::unincorporate(const bool& x) { assert(x == 0 || x == 1); --N; s -= x; @@ -21,7 +21,7 @@ void BetaBernoulli::unincorporate(const double& x) { assert(0 <= N); } -double BetaBernoulli::logp(const double& x) const { +double BetaBernoulli::logp(const bool& x) const { assert(x == 0 || x == 1); double log_denom = log(N + alpha + beta); double log_numer = x ? log(s + alpha) : log(N - s + beta); @@ -34,9 +34,9 @@ double BetaBernoulli::logp_score() const { return v1 - v2; } -double BetaBernoulli::sample() { +bool BetaBernoulli::sample() { double p = exp(logp(1)); - std::vector items{0, 1}; + std::vector items{false, true}; std::vector weights{1 - p, p}; int idx = choice(weights, prng); return items[idx]; diff --git a/cxx/distributions/beta_bernoulli.hh b/cxx/distributions/beta_bernoulli.hh index 45bef74..bf9c972 100644 --- a/cxx/distributions/beta_bernoulli.hh +++ b/cxx/distributions/beta_bernoulli.hh @@ -9,7 +9,7 @@ // TODO(thomaswc, emilyaf): Change BetaBernoulli to use bool instead of // double. -class BetaBernoulli : public Distribution { +class BetaBernoulli : public Distribution { public: double alpha = 1; // hyperparameter double beta = 1; // hyperparameter @@ -26,15 +26,15 @@ class BetaBernoulli : public Distribution { beta_grid = log_linspace(1e-4, 1e4, 10, true); } - void incorporate(const double& x); + void incorporate(const bool& x); - void unincorporate(const double& x); + void unincorporate(const bool& x); - double logp(const double& x) const; + double logp(const bool& x) const; double logp_score() const; - double sample(); + bool sample(); void transition_hyperparameters(); }; diff --git a/cxx/tests/test_hirm_animals.cc b/cxx/tests/test_hirm_animals.cc index 47d8fdd..e5c5872 100644 --- a/cxx/tests/test_hirm_animals.cc +++ b/cxx/tests/test_hirm_animals.cc @@ -75,25 +75,25 @@ int main(int argc, char** argv) { // Marginally normalized. int persiancat = enc["animal"]["persiancat"]; - auto p0_black_persiancat = hirm.logp({{"black", {persiancat}, 0.}}); - auto p1_black_persiancat = hirm.logp({{"black", {persiancat}, 1.}}); + auto p0_black_persiancat = hirm.logp({{"black", {persiancat}, false}}); + auto p1_black_persiancat = hirm.logp({{"black", {persiancat}, true}}); assert(abs(logsumexp({p0_black_persiancat, p1_black_persiancat})) < 1e-10); // Marginally normalized. int sheep = enc["animal"]["sheep"]; - auto p0_solitary_sheep = hirm.logp({{"solitary", {sheep}, 0.}}); - auto p1_solitary_sheep = hirm.logp({{"solitary", {sheep}, 1.}}); + auto p0_solitary_sheep = hirm.logp({{"solitary", {sheep}, false}}); + auto p1_solitary_sheep = hirm.logp({{"solitary", {sheep}, true}}); assert(abs(logsumexp({p0_solitary_sheep, p1_solitary_sheep})) < 1e-10); // Jointly normalized. auto p00_black_persiancat_solitary_sheep = - hirm.logp({{"black", {persiancat}, 0.}, {"solitary", {sheep}, 0.}}); + hirm.logp({{"black", {persiancat}, false}, {"solitary", {sheep}, false}}); auto p01_black_persiancat_solitary_sheep = - hirm.logp({{"black", {persiancat}, 0.}, {"solitary", {sheep}, 1.}}); + hirm.logp({{"black", {persiancat}, false}, {"solitary", {sheep}, true}}); auto p10_black_persiancat_solitary_sheep = - hirm.logp({{"black", {persiancat}, 1.}, {"solitary", {sheep}, 0.}}); + hirm.logp({{"black", {persiancat}, true}, {"solitary", {sheep}, false}}); auto p11_black_persiancat_solitary_sheep = - hirm.logp({{"black", {persiancat}, 1.}, {"solitary", {sheep}, 1.}}); + hirm.logp({{"black", {persiancat}, true}, {"solitary", {sheep}, true}}); auto Z = logsumexp({ p00_black_persiancat_solitary_sheep, p01_black_persiancat_solitary_sheep, diff --git a/cxx/tests/test_irm_two_relations.cc b/cxx/tests/test_irm_two_relations.cc index ff0ce9f..08e4a09 100644 --- a/cxx/tests/test_irm_two_relations.cc +++ b/cxx/tests/test_irm_two_relations.cc @@ -80,10 +80,10 @@ int main(int argc, char** argv) { assert(l.size() == 2); auto x1 = l.at(0); auto x2 = l.at(1); - auto p0 = std::get(irm.relations.at("R1"))->logp({x1, x2}, 0.); - auto p0_irm = irm.logp({{"R1", {x1, x2}, 0.}}); + auto p0 = std::get(irm.relations.at("R1"))->logp({x1, x2}, false); + auto p0_irm = irm.logp({{"R1", {x1, x2}, false}}); assert(abs(p0 - p0_irm) < 1e-10); - auto p1 = std::get(irm.relations.at("R1"))->logp({x1, x2}, 1.); + auto p1 = std::get(irm.relations.at("R1"))->logp({x1, x2}, true); auto Z = logsumexp({p0, p1}); assert(abs(Z) < 1e-10); assert(abs(exp(p0) - expected_p0[x1].at(x2)) < .1); @@ -94,10 +94,10 @@ int main(int argc, char** argv) { auto x1 = l.at(0); auto x2 = l.at(1); auto x3 = l.at(2); - auto p00 = irm.logp({{"R1", {x1, x2}, 0.}, {"R1", {x1, x3}, 0.}}); - auto p01 = irm.logp({{"R1", {x1, x2}, 0.}, {"R1", {x1, x3}, 1.}}); - auto p10 = irm.logp({{"R1", {x1, x2}, 1.}, {"R1", {x1, x3}, 0.}}); - auto p11 = irm.logp({{"R1", {x1, x2}, 1.}, {"R1", {x1, x3}, 1.}}); + auto p00 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, false}}); + auto p01 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, true}}); + auto p10 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, false}}); + auto p11 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, true}}); auto Z = logsumexp({p00, p01, p10, p11}); assert(abs(Z) < 1e-10); } diff --git a/cxx/tests/test_misc.cc b/cxx/tests/test_misc.cc index 455e382..17315d2 100644 --- a/cxx/tests/test_misc.cc +++ b/cxx/tests/test_misc.cc @@ -75,7 +75,7 @@ int main(int argc, char** argv) { auto value = std::get<2>(i); auto item = std::get<1>(i); printf("incorporating %s ", relation.c_str()); - printf("%1.f ", std::get(value)); + printf("%d ", std::get(value)); int counter = 0; T_items items_code; for (auto const& item : std::get<1>(i)) { diff --git a/cxx/util_distribution_variant.cc b/cxx/util_distribution_variant.cc index 3e6023f..49d8e3f 100644 --- a/cxx/util_distribution_variant.cc +++ b/cxx/util_distribution_variant.cc @@ -20,7 +20,7 @@ ObservationVariant observation_string_to_value( case DistributionEnum::normal: return std::stod(value_str); case DistributionEnum::bernoulli: - return std::stod(value_str); + return static_cast(std::stoi(value_str)); case DistributionEnum::categorical: return std::stoi(value_str); case DistributionEnum::bigram: diff --git a/cxx/util_distribution_variant.hh b/cxx/util_distribution_variant.hh index 0386689..01ca568 100644 --- a/cxx/util_distribution_variant.hh +++ b/cxx/util_distribution_variant.hh @@ -28,7 +28,7 @@ template class Relation; // Set of all distribution sample types. -using ObservationVariant = std::variant; +using ObservationVariant = std::variant; using DistributionVariant = std::variant;