diff --git a/cxx/BUILD b/cxx/BUILD index 72e5aef..598628c 100644 --- a/cxx/BUILD +++ b/cxx/BUILD @@ -7,6 +7,20 @@ cc_library( deps = [], ) +cc_library( + name = "clean_relation", + hdrs = ["clean_relation.hh"], + visibility = [":__subpackages__"], + deps = [ + ":domain", + ":relation", + ":util_distribution_variant", + ":util_hash", + ":util_math", + "//distributions:base" + ], +) + cc_library( name = "domain", hdrs = ["domain.hh"], @@ -21,7 +35,7 @@ cc_library( srcs = ["irm.cc"], visibility = [":__subpackages__"], deps = [ - ":relation", + ":clean_relation", ":relation_variant", ":util_distribution_variant", ], @@ -60,6 +74,21 @@ cc_binary( ], ) +cc_library( + name = "noisy_relation", + hdrs = ["noisy_relation.hh"], + visibility = [":__subpackages__"], + deps = [ + ":domain", + ":clean_relation", + ":relation", + ":util_distribution_variant", + ":util_hash", + ":util_math", + "//distributions:base" + ], +) + cc_library( name = "relation", hdrs = ["relation.hh"], @@ -80,6 +109,7 @@ cc_library( visibility = [":__subpackages__"], deps = [ ":domain", + ":clean_relation", ":relation", ":util_distribution_variant", ], @@ -94,6 +124,7 @@ cc_library( ":domain", "//distributions", "@boost//:algorithm", + "//emissions", ], ) @@ -125,6 +156,17 @@ cc_library( deps = [], ) +cc_test( + name = "clean_relation_test", + srcs = ["clean_relation_test.cc"], + deps = [ + ":domain", + ":clean_relation", + "//distributions", + "@boost//:test", + ], +) + cc_test( name = "domain_test", srcs = ["domain_test.cc"], @@ -144,11 +186,12 @@ cc_test( ) cc_test( - name = "relation_test", - srcs = ["relation_test.cc"], + name = "noisy_relation_test", + srcs = ["noisy_relation_test.cc"], deps = [ ":domain", - ":relation", + ":clean_relation", + ":noisy_relation", "//distributions", "@boost//:test", ], @@ -158,6 +201,7 @@ cc_test( name = "relation_variant_test", srcs = ["relation_variant_test.cc"], deps = [ + ":clean_relation", ":relation_variant", "@boost//:test", ], diff --git a/cxx/clean_relation.hh b/cxx/clean_relation.hh new file mode 100644 index 0000000..e5001da --- /dev/null +++ b/cxx/clean_relation.hh @@ -0,0 +1,412 @@ +// Copyright 2020 +// See LICENSE.txt + +#pragma once + +#include +#include +#include +#include + +#include "distributions/base.hh" +#include "domain.hh" +#include "relation.hh" +#include "util_distribution_variant.hh" +#include "util_hash.hh" +#include "util_math.hh" + +// T_clean_relation is the text we get from reading a line of the schema +// file; CleanRelation is the object that does the work. +class T_clean_relation { + public: + // The relation is a map from the domains to the space .distribution + // is a distribution over. + std::vector domains; + + // TODO(emilyaf): Enable observed vs. latent. + // bool is_observed; + + DistributionSpec distribution_spec; +}; + +template +class CleanRelation : public Relation { + public: + typedef T ValueType; + + // human-readable name + const std::string name; + // list of domain pointers + const std::vector domains; + // Distribution or Emission spec over the relation's codomain. + const std::variant prior_spec; + // map from cluster multi-index to Distribution pointer + std::unordered_map, Distribution*, + VectorIntHash> + clusters; + // map from item to observed data + std::unordered_map data; + // map from domain name to reverse map from item to + // set of items that include that item + std::unordered_map< + std::string, + std::unordered_map>> + data_r; + + CleanRelation(const std::string& name, + const std::variant& prior_spec, + const std::vector& domains) + : name(name), domains(domains), prior_spec(prior_spec) { + assert(!domains.empty()); + assert(!name.empty()); + for (const Domain* const d : domains) { + this->data_r[d->name] = + std::unordered_map>(); + } + } + + ~CleanRelation() { + for (auto [z, cluster] : clusters) { + delete cluster; + } + } + + Distribution* make_new_distribution(std::mt19937* prng) const { + auto var_to_dist = [&](auto dist_variant) { + return std::visit( + [&](auto v) { return reinterpret_cast*>(v); }, + dist_variant); + }; + auto spec_to_dist = [&](auto spec) { + return var_to_dist(cluster_prior_from_spec(spec, prng)); + }; + return std::visit(spec_to_dist, prior_spec); + } + + void incorporate(std::mt19937* prng, const T_items& items, ValueType value) { + assert(!data.contains(items)); + data[items] = value; + for (int i = 0; i < std::ssize(domains); ++i) { + domains[i]->incorporate(prng, items[i]); + if (!data_r.at(domains[i]->name).contains(items[i])) { + data_r.at(domains[i]->name)[items[i]] = + std::unordered_set(); + } + data_r.at(domains[i]->name).at(items[i]).insert(items); + } + T_items z = get_cluster_assignment(items); + if (!clusters.contains(z)) { + clusters[z] = make_new_distribution(prng); + } + clusters.at(z)->incorporate(value); + } + + void unincorporate(const T_items& items) { + printf("Not implemented\n"); + exit(EXIT_FAILURE); + // auto x = data.at(items); + // auto z = get_cluster_assignment(items); + // clusters.at(z)->unincorporate(x); + // if (clusters.at(z)->N == 0) { + // delete clusters.at(z); + // clusters.erase(z); + // } + // for (int i = 0; i < domains.size(); i++) { + // const std::string &n = domains[i]->name; + // if (data_r.at(n).count(items[i]) > 0) { + // data_r.at(n).at(items[i]).erase(items); + // if (data_r.at(n).at(items[i]).size() == 0) { + // data_r.at(n).erase(items[i]); + // domains[i]->unincorporate(name, items[i]); + // } + // } + // } + // data.erase(items); + } + + std::vector get_cluster_assignment(const T_items& items) const { + assert(items.size() == domains.size()); + std::vector z(domains.size()); + for (int i = 0; i < std::ssize(domains); ++i) { + z[i] = domains[i]->get_cluster_assignment(items[i]); + } + return z; + } + + double cluster_or_prior_logp(std::mt19937* prng, const T_items& z, + const ValueType& value) const { + if (clusters.contains(z)) { + return clusters.at(z)->logp(value); + } + Distribution* prior = make_new_distribution(prng); + double prior_logp = prior->logp(value); + delete prior; + return prior_logp; + } + + std::vector get_cluster_assignment_gibbs(const T_items& items, + const Domain& domain, + const T_item& item, + int table) const { + assert(items.size() == domains.size()); + std::vector z(domains.size()); + int hits = 0; + for (int i = 0; i < std::ssize(domains); ++i) { + if ((domains[i]->name == domain.name) && (items[i] == item)) { + z[i] = table; + ++hits; + } else { + z[i] = domains[i]->get_cluster_assignment(items[i]); + } + } + assert(hits > 0); + return z; + } + + // Implementation of approximate Gibbs data probabilities (faster). + + double logp_gibbs_approx_current(const Domain& domain, const T_item& item) { + double logp = 0.; + for (const T_items& items : data_r.at(domain.name).at(item)) { + ValueType x = data.at(items); + T_items z = get_cluster_assignment(items); + auto cluster = clusters.at(z); + cluster->unincorporate(x); + double lp = cluster->logp(x); + cluster->incorporate(x); + logp += lp; + } + return logp; + } + + double logp_gibbs_approx_variant(const Domain& domain, const T_item& item, + int table, std::mt19937* prng) { + double logp = 0.; + for (const T_items& items : data_r.at(domain.name).at(item)) { + ValueType x = data.at(items); + T_items z = get_cluster_assignment_gibbs(items, domain, item, table); + double lp; + if (!clusters.contains(z)) { + Distribution* tmp_dist = make_new_distribution(prng); + lp = tmp_dist->logp(x); + delete tmp_dist; + } else { + lp = clusters.at(z)->logp(x); + } + logp += lp; + } + return logp; + } + + double logp_gibbs_approx(const Domain& domain, const T_item& item, int table, + std::mt19937* prng) { + int table_current = domain.get_cluster_assignment(item); + return table_current == table + ? logp_gibbs_approx_current(domain, item) + : logp_gibbs_approx_variant(domain, item, table, prng); + } + + // Implementation of exact Gibbs data probabilities. + + std::unordered_map const, std::vector, + VectorIntHash> + get_cluster_to_items_list(Domain const& domain, const T_item& item) { + std::unordered_map, std::vector, + VectorIntHash> + m; + for (const T_items& items : data_r.at(domain.name).at(item)) { + T_items z = get_cluster_assignment(items); + m[z].push_back(items); + } + return m; + } + + double logp_gibbs_exact_current(const std::vector& items_list) { + assert(!items_list.empty()); + T_items z = get_cluster_assignment(items_list[0]); + auto cluster = clusters.at(z); + double logp0 = cluster->logp_score(); + for (const T_items& items : items_list) { + ValueType x = data.at(items); + // assert(z == get_cluster_assignment(items)); + cluster->unincorporate(x); + } + double logp1 = cluster->logp_score(); + for (const T_items& items : items_list) { + ValueType x = data.at(items); + cluster->incorporate(x); + } + assert(cluster->logp_score() == logp0); + return logp0 - logp1; + } + + double logp_gibbs_exact_variant(const Domain& domain, const T_item& item, + int table, + const std::vector& items_list, + std::mt19937* prng) { + assert(!items_list.empty()); + T_items z = + get_cluster_assignment_gibbs(items_list[0], domain, item, table); + + Distribution* prior = make_new_distribution(prng); + Distribution* cluster = + clusters.contains(z) ? clusters.at(z) : prior; + double logp0 = cluster->logp_score(); + for (const T_items& items : items_list) { + // assert(z == get_cluster_assignment_gibbs(items, domain, item, table)); + ValueType x = data.at(items); + cluster->incorporate(x); + } + const double logp1 = cluster->logp_score(); + for (const T_items& items : items_list) { + ValueType x = data.at(items); + cluster->unincorporate(x); + } + assert(cluster->logp_score() == logp0); + delete prior; + return logp1 - logp0; + } + + std::vector logp_gibbs_exact(const Domain& domain, const T_item& item, + std::vector tables, + std::mt19937* prng) { + auto cluster_to_items_list = get_cluster_to_items_list(domain, item); + int table_current = domain.get_cluster_assignment(item); + std::vector logps; + logps.reserve(tables.size()); + double lp_cluster; + for (const int& table : tables) { + double lp_table = 0; + for (const auto& [z, items_list] : cluster_to_items_list) { + lp_cluster = (table == table_current) + ? logp_gibbs_exact_current(items_list) + : logp_gibbs_exact_variant(domain, item, table, + items_list, prng); + lp_table += lp_cluster; + } + logps.push_back(lp_table); + } + return logps; + } + + double logp(const T_items& items, ValueType value, std::mt19937* prng) { + // TODO: Falsely assumes cluster assignments of items + // from same domain are identical, see note in hirm.py + assert(items.size() == domains.size()); + std::vector> tabl_list; + std::vector> wght_list; + std::vector> indx_list; + for (int i = 0; i < std::ssize(domains); ++i) { + Domain* domain = domains.at(i); + T_item item = items.at(i); + std::vector t_list; + std::vector w_list; + std::vector i_list; + if (domain->items.contains(item)) { + int z = domain->get_cluster_assignment(item); + t_list = {z}; + w_list = {0}; + i_list = {0}; + } else { + auto tables_weights = domain->tables_weights(); + double Z = log(domain->crp.alpha + domain->crp.N); + int idx = 0; + for (const auto& [t, w] : tables_weights) { + t_list.push_back(t); + w_list.push_back(log(w) - Z); + i_list.push_back(idx++); + } + assert(idx == std::ssize(t_list)); + } + tabl_list.push_back(t_list); + wght_list.push_back(w_list); + indx_list.push_back(i_list); + } + std::vector logps; + for (const auto& indexes : product(indx_list)) { + assert(indexes.size() == domains.size()); + std::vector z; + z.reserve(domains.size()); + double logp_w = 0; + for (int i = 0; i < std::ssize(domains); ++i) { + T_item zi = tabl_list.at(i).at(indexes[i]); + double wi = wght_list.at(i).at(indexes[i]); + z.push_back(zi); + logp_w += wi; + } + Distribution* prior = make_new_distribution(prng); + Distribution* cluster = + clusters.contains(z) ? clusters.at(z) : prior; + double logp_z = cluster->logp(value); + double logp_zw = logp_z + logp_w; + logps.push_back(logp_zw); + delete prior; + } + return logsumexp(logps); + } + + double logp_score() const { + double logp = 0.0; + for (const auto& [_, cluster] : clusters) { + logp += cluster->logp_score(); + } + return logp; + } + + void set_cluster_assignment_gibbs(const Domain& domain, const T_item& item, + int table, std::mt19937* prng) { + int table_current = domain.get_cluster_assignment(item); + assert(table != table_current); + for (const T_items& items : data_r.at(domain.name).at(item)) { + ValueType x = data.at(items); + // Remove from current cluster. + T_items z_prev = get_cluster_assignment(items); + auto cluster_prev = clusters.at(z_prev); + cluster_prev->unincorporate(x); + if (cluster_prev->N == 0) { + delete clusters.at(z_prev); + clusters.erase(z_prev); + } + // Move to desired cluster. + T_items z_new = get_cluster_assignment_gibbs(items, domain, item, table); + if (!clusters.contains(z_new)) { + // Move to fresh cluster. + clusters[z_new] = make_new_distribution(prng); + clusters.at(z_new)->incorporate(x); + } else { + // Move to existing cluster. + assert((clusters.at(z_new)->N > 0)); + clusters.at(z_new)->incorporate(x); + } + } + // Caller should invoke domain.set_cluster_gibbs + } + + bool has_observation(const Domain& domain, const T_item& item) const { + return data_r.at(domain.name).contains(item); + } + + const std::vector& get_domains() const { return domains; } + + const ValueType& get_value(const T_items& items) const { + return data.at(items); + } + + const std::unordered_map& get_data() + const { + return data; + } + + void transition_cluster_hparams(std::mt19937* prng, int num_theta_steps) { + for (const auto& [c, distribution] : clusters) { + for (int i = 0; i < num_theta_steps; ++i) { + distribution->transition_theta(prng); + } + distribution->transition_hyperparameters(prng); + } + } + + // Disable copying. + CleanRelation& operator=(const CleanRelation&) = delete; + CleanRelation(const CleanRelation&) = delete; +}; diff --git a/cxx/relation_test.cc b/cxx/clean_relation_test.cc similarity index 89% rename from cxx/relation_test.cc rename to cxx/clean_relation_test.cc index 2331049..191ab9f 100644 --- a/cxx/relation_test.cc +++ b/cxx/clean_relation_test.cc @@ -2,7 +2,7 @@ #define BOOST_TEST_MODULE test Relation -#include "relation.hh" +#include "clean_relation.hh" #include #include @@ -13,7 +13,7 @@ namespace tt = boost::test_tools; -BOOST_AUTO_TEST_CASE(test_relation) { +BOOST_AUTO_TEST_CASE(test_clean_relation) { std::mt19937 prng; Domain D1("D1"); Domain D2("D2"); @@ -22,7 +22,7 @@ BOOST_AUTO_TEST_CASE(test_relation) { D2.incorporate(&prng, 1); D3.incorporate(&prng, 3); DistributionSpec spec = DistributionSpec{DistributionEnum::bernoulli}; - Relation R1("R1", spec, {&D1, &D2, &D3}); + CleanRelation R1("R1", spec, {&D1, &D2, &D3}); R1.incorporate(&prng, {0, 1, 3}, 1); R1.incorporate(&prng, {1, 1, 3}, 1); R1.incorporate(&prng, {3, 1, 3}, 1); @@ -42,7 +42,7 @@ BOOST_AUTO_TEST_CASE(test_relation) { BOOST_TEST(z2[1] == 191); BOOST_TEST(z2[2] == 0); - double lpg __attribute__ ((unused)); + double lpg __attribute__((unused)); lpg = R1.logp_gibbs_approx(D1, 0, 1, &prng); lpg = R1.logp_gibbs_approx(D1, 0, 0, &prng); lpg = R1.logp_gibbs_approx(D1, 0, 10, &prng); @@ -54,7 +54,7 @@ BOOST_AUTO_TEST_CASE(test_relation) { BOOST_TEST(db->N == 1); DistributionSpec bigram_spec = DistributionSpec{DistributionEnum::bigram}; - Relation R2("R1", bigram_spec, {&D2, &D3}); + CleanRelation R2("R1", bigram_spec, {&D2, &D3}); R2.incorporate(&prng, {1, 3}, "cat"); R2.incorporate(&prng, {1, 2}, "dog"); R2.incorporate(&prng, {1, 4}, "catt"); diff --git a/cxx/emissions/BUILD b/cxx/emissions/BUILD index 160934b..69a71be 100644 --- a/cxx/emissions/BUILD +++ b/cxx/emissions/BUILD @@ -1,5 +1,18 @@ licenses(["notice"]) +cc_library( + name = "emissions", + visibility = ["//:__subpackages__"], + deps = [ + ":base", + ":get_emission", + ":bitflip", + ":gaussian", + ":simple_string", + ":sometimes", + ], +) + cc_library( name = "base", hdrs = ["base.hh"], diff --git a/cxx/emissions/sometimes.hh b/cxx/emissions/sometimes.hh index 8c9f2e6..d7631f9 100644 --- a/cxx/emissions/sometimes.hh +++ b/cxx/emissions/sometimes.hh @@ -6,16 +6,18 @@ #include "emissions/base.hh" // An Emission class that sometimes applies BaseEmissor and sometimes doesn't. -// BaseEmissor must assign zero probability to pairs with +// BaseEmissor must assign zero probability to pairs with // clean == dirty. [For example, BitFlip and Gaussian both satisfy this]. template -class Sometimes : public Emission::type> { +class Sometimes : public Emission::type> { public: - using SampleType = typename std::tuple_element<0, typename BaseEmissor::SampleType>::type; + using SampleType = + typename std::tuple_element<0, typename BaseEmissor::SampleType>::type; BetaBernoulli bb; BaseEmissor be; - Sometimes() {}; + Sometimes(){}; void incorporate(const std::pair& x) { ++(this->N); @@ -34,7 +36,10 @@ class Sometimes : public Emission& x) const { - return bb.logp(x.first != x.second) + be.logp(x); + if (x.first != x.second) { + return bb.logp(true) + be.logp(x); + } + return bb.logp(false); } double logp_score() const { return bb.logp_score() + be.logp_score(); } diff --git a/cxx/emissions/sometimes_test.cc b/cxx/emissions/sometimes_test.cc index 09b34df..d5c4a4f 100644 --- a/cxx/emissions/sometimes_test.cc +++ b/cxx/emissions/sometimes_test.cc @@ -16,6 +16,7 @@ BOOST_AUTO_TEST_CASE(test_simple) { BOOST_TEST(sbf.N == 0); sbf.incorporate(std::make_pair(true, false)); BOOST_TEST(sbf.logp_score() < 0.0); + BOOST_TEST(exp(sbf.logp(std::make_pair(true, false))) + exp(sbf.logp(std::make_pair(true, true))) == 1.0); BOOST_TEST(sbf.N == 1); sbf.unincorporate(std::make_pair(true, false)); BOOST_TEST(sbf.logp_score() == orig_lp); diff --git a/cxx/hirm.cc b/cxx/hirm.cc index 838e356..7313f5d 100644 --- a/cxx/hirm.cc +++ b/cxx/hirm.cc @@ -54,8 +54,7 @@ void HIRM::transition_cluster_assignment_relation(std::mt19937* prng, int rc = relation_to_code.at(r); int table_current = crp.assignments.at(rc); RelationVariant relation = get_relation(r); - T_relation t_relation = - std::visit([](auto rel) { return rel->trel; }, relation); + T_relation t_relation = schema.at(r); auto crp_dist = crp.tables_weights_gibbs(table_current); std::vector tables; std::vector logps; @@ -78,7 +77,7 @@ void HIRM::transition_cluster_assignment_relation(std::mt19937* prng, irm->add_relation(r, t_relation); std::visit( [&](auto rel) { - for (const auto& [items, value] : rel->data) { + for (const auto& [items, value] : rel->get_data()) { irm->incorporate(prng, r, items, value); } }, @@ -133,9 +132,9 @@ void HIRM::set_cluster_assignment_gibbs( int table_current = crp.assignments.at(rc); RelationVariant relation = get_relation(r); auto f_obs = [&](auto rel) { - T_relation trel = rel->trel; + T_relation trel = schema.at(r); IRM* irm = relation_to_irm(r); - auto observations = rel->data; + auto observations = rel->get_data(); // Remove from current IRM. irm->remove_relation(r); if (irm->relations.empty()) { diff --git a/cxx/hirm.hh b/cxx/hirm.hh index 79b86b5..b854149 100644 --- a/cxx/hirm.hh +++ b/cxx/hirm.hh @@ -12,7 +12,6 @@ #include "relation.hh" #include "util_distribution_variant.hh" - class HIRM { public: T_schema schema; // schema of relations @@ -49,7 +48,8 @@ class HIRM { double logp( const std::vector>& - observations, std::mt19937* prng); + observations, + std::mt19937* prng); double logp_score() const; diff --git a/cxx/hirm_main.cc b/cxx/hirm_main.cc index f829596..ab6457a 100644 --- a/cxx/hirm_main.cc +++ b/cxx/hirm_main.cc @@ -1,13 +1,12 @@ // Copyright 2021 MIT Probabilistic Computing Project // Apache License, Version 2.0, refer to LICENSE.txt -#include "hirm.hh" - #include #include #include #include "cxxopts.hpp" +#include "hirm.hh" #include "irm.hh" #include "util_io.hh" diff --git a/cxx/irm.cc b/cxx/irm.cc index ea10ae9..529c02d 100644 --- a/cxx/irm.cc +++ b/cxx/irm.cc @@ -29,8 +29,7 @@ void IRM::incorporate(std::mt19937* prng, const std::string& r, std::visit( [&](auto rel) { auto v = std::get< - typename std::remove_reference_t::ValueType>( - value); + typename std::remove_reference_t::ValueType>(value); rel->incorporate(prng, items, v); }, relations.at(r)); @@ -105,18 +104,18 @@ void IRM::transition_cluster_assignment_item(std::mt19937* prng, double IRM::logp( const std::vector>& - observations, std::mt19937* prng) { + observations, + std::mt19937* prng) { std::unordered_map> relation_items_seen; - std::unordered_map> - domain_item_seen; + std::unordered_map> domain_item_seen; std::vector> item_universe; std::vector> index_universe; std::vector> weight_universe; std::unordered_map< std::string, std::unordered_map>>> - cluster_universe; + cluster_universe; // Compute all cluster combinations. for (const auto& [r, items, value] : observations) { // Assert observation is unique. @@ -124,13 +123,13 @@ double IRM::logp( relation_items_seen[r].insert(items); // Process each (domain, item) in the observations. RelationVariant relation = relations.at(r); - int arity = - std::visit([](auto rel) { return rel->domains.size(); }, relation); + int arity = std::visit([](auto rel) { return rel->get_domains().size(); }, + relation); assert(std::ssize(items) == arity); for (int i = 0; i < arity; ++i) { // Skip if (domain, item) processed. - Domain* domain = - std::visit([&](auto rel) { return rel->domains.at(i); }, relation); + Domain* domain = std::visit( + [&](auto rel) { return rel->get_domains().at(i); }, relation); T_item item = items.at(i); if (domain_item_seen[domain->name].contains(item)) { assert(cluster_universe[domain->name].contains(item)); @@ -188,25 +187,17 @@ double IRM::logp( const ObservationVariant& value) -> double { std::vector z; z.reserve(domains.size()); - for (int i = 0; i < std::ssize(rel->domains); ++i) { - Domain* domain = rel->domains.at(i); + for (int i = 0; i < std::ssize(rel->get_domains()); ++i) { + Domain* domain = rel->get_domains().at(i); T_item item = items.at(i); auto& [loc, t_list] = cluster_universe.at(domain->name).at(item); T_item t = t_list.at(indexes.at(loc)); z.push_back(t); } - auto v = std::get< - typename std::remove_reference_t::ValueType>(value); - if (rel->clusters.contains(z)) { - return rel->clusters.at(z)->logp(v); - } - DistributionVariant prior = cluster_prior_from_spec(rel->dist_spec, prng); - return std::visit( - [&](const auto& dist_variant) { - auto v2 = std::get< - typename std::remove_reference_t::SampleType>(value); - return dist_variant->logp(v2); }, prior); + auto v = + std::get::ValueType>( + value); + return rel->cluster_or_prior_logp(prng, z, v); }; for (const auto& [r, items, value] : observations) { auto g = std::bind(f_logp, std::placeholders::_1, items, value); @@ -245,15 +236,14 @@ void IRM::add_relation(const std::string& name, const T_relation& relation) { domain_to_relations.at(d).insert(name); doms.push_back(domains.at(d)); } - relations[name] = - relation_from_spec(name, relation.distribution_spec, doms); + relations[name] = relation_from_spec(name, relation.distribution_spec, doms); schema[name] = relation; } void IRM::remove_relation(const std::string& name) { std::unordered_set ds; auto rel_domains = - std::visit([](auto r) { return r->domains; }, relations.at(name)); + std::visit([](auto r) { return r->get_domains(); }, relations.at(name)); for (const Domain* const domain : rel_domains) { ds.insert(domain->name); } @@ -271,7 +261,6 @@ void IRM::remove_relation(const std::string& name) { schema.erase(name); } - #define GET_ELAPSED(t) double(clock() - t) / CLOCKS_PER_SEC // TODO(emilyaf): Refactor as a function for readibility/maintainability. @@ -308,14 +297,9 @@ void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total, for (const auto& [r, relation] : irm->relations) { std::visit( [&](auto r) { - for (const auto& [c, distribution] : r->clusters) { - clock_t t = clock(); - for (int i = 0; i < num_theta_steps; ++i ) { - distribution->transition_theta(prng); - } - distribution->transition_hyperparameters(prng); - REPORT_SCORE(verbose, t, t_total, irm); - } + clock_t t = clock(); + r->transition_cluster_hparams(prng, num_theta_steps); + REPORT_SCORE(verbose, t, t_total, irm); }, relation); } @@ -326,4 +310,3 @@ void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total, REPORT_SCORE(verbose, t, t_total, irm); } } - diff --git a/cxx/irm.hh b/cxx/irm.hh index 8fc6125..95811bb 100644 --- a/cxx/irm.hh +++ b/cxx/irm.hh @@ -6,10 +6,13 @@ #include #include -#include "relation.hh" +#include "clean_relation.hh" #include "relation_variant.hh" #include "util_distribution_variant.hh" +// TODO(emilyaf): Support noisy relations. +using T_relation = T_clean_relation; + // Map from names to T_relation's. typedef std::map T_schema; @@ -41,7 +44,8 @@ class IRM { const T_item& item); double logp( const std::vector>& - observations, std::mt19937* prng); + observations, + std::mt19937* prng); double logp_score() const; @@ -54,7 +58,6 @@ class IRM { IRM(const IRM&) = delete; }; - // Run a single step of inference on an IRM model. void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total, bool verbose, int num_theta_steps = 10); diff --git a/cxx/irm_test.cc b/cxx/irm_test.cc index e3c1d39..836c94e 100644 --- a/cxx/irm_test.cc +++ b/cxx/irm_test.cc @@ -3,17 +3,20 @@ #define BOOST_TEST_MODULE test IRM #include "irm.hh" -#include "util_distribution_variant.hh" #include + +#include "util_distribution_variant.hh" namespace tt = boost::test_tools; BOOST_AUTO_TEST_CASE(test_irm) { std::map schema1{ - {"R1", T_relation{{"D1", "D1"}, DistributionSpec {DistributionEnum::bernoulli}}}, - {"R2", T_relation{{"D1", "D2"}, DistributionSpec {DistributionEnum::normal}}}, - {"R3", T_relation{{"D3", "D1"}, DistributionSpec {DistributionEnum::bigram}}} - }; + {"R1", + T_relation{{"D1", "D1"}, DistributionSpec{DistributionEnum::bernoulli}}}, + {"R2", + T_relation{{"D1", "D2"}, DistributionSpec{DistributionEnum::normal}}}, + {"R3", + T_relation{{"D3", "D1"}, DistributionSpec{DistributionEnum::bigram}}}}; IRM irm(schema1); BOOST_TEST(irm.logp_score() == 0.0); diff --git a/cxx/noisy_relation.hh b/cxx/noisy_relation.hh new file mode 100644 index 0000000..e9ff1f6 --- /dev/null +++ b/cxx/noisy_relation.hh @@ -0,0 +1,141 @@ +// Copyright 2020 +// See LICENSE.txt + +#pragma once + +#include +#include +#include +#include +#include + +#include "clean_relation.hh" +#include "distributions/base.hh" +#include "domain.hh" +#include "emissions/base.hh" +#include "relation.hh" +#include "util_distribution_variant.hh" +#include "util_hash.hh" +#include "util_math.hh" + +// T_noisy_relation is the text we get from reading a line of the schema file; +// NoisyRelation is the object that does the work. +class T_noisy_relation { + public: + // The relation is a map from the domains to the space .distribution + // is a distribution over. + std::vector domains; + + // Name of the relation for the "true" values that the NoisyRelation observes. + std::string base_relation; + + // Indicates if the NoisyRelation's values are observed or latent. + bool is_observed; + + // Describes the Emission that models the noise. + EmissionSpec emission_spec; +}; + +template +class NoisyRelation : public Relation { + public: + typedef T ValueType; + + // human-readable name + const std::string name; + // list of domain pointers + const std::vector domains; + // map from item to observed data + std::unordered_map data; + // Base relation for "" values. + const Relation* base_relation; + // A Relation for the Emission that models noisy values given values. + CleanRelation> emission_relation; + + NoisyRelation(const std::string& name, const EmissionSpec& emission_spec, + const std::vector& domains, Relation* base_relation) + : name(name), + domains(domains), + base_relation(base_relation), + emission_relation(name + "_emission", emission_spec, domains) {} + + void incorporate(std::mt19937* prng, const T_items& items, ValueType value) { + data[items] = value; + const ValueType _val = get_base_value(items); + emission_relation.incorporate(prng, items, std::make_pair(_val, value)); + } + + void unincorporate(const T_items& items) { + emission_relation.unincorporate(items); + } + + double logp_gibbs_approx(const Domain& domain, const T_item& item, int table, + std::mt19937* prng) { + return emission_relation.logp_gibbs_approx(domain, item, table, prng); + } + + std::vector logp_gibbs_exact(const Domain& domain, const T_item& item, + std::vector tables, + std::mt19937* prng) { + return emission_relation.logp_gibbs_exact(domain, item, tables, prng); + } + + double logp(const T_items& items, ValueType value, std::mt19937* prng) { + const ValueType _val = get_base_value(items); + return emission_relation.logp(items, std::make_pair(_val, value), prng); + } + + double logp_score() const { return emission_relation.logp_score(); } + + void set_cluster_assignment_gibbs(const Domain& domain, const T_item& item, + int table, std::mt19937* prng) { + emission_relation.set_cluster_assignment_gibbs(domain, item, table, prng); + } + + bool has_observation(const Domain& domain, const T_item& item) const { + return emission_relation.has_observation(domain, item); + } + + const ValueType& get_value(const T_items& items) const { + return std::get<0>(emission_relation.get_value(items)); + } + + const std::unordered_map& get_data() + const { + return data; + } + + const std::vector& get_domains() const { return domains; } + + const ValueType get_base_value(const T_items& items) const { + size_t base_arity = base_relation->get_domains().size(); + T_items base_items(items.cbegin(), items.cbegin() + base_arity); + return base_relation->get_value(base_items); + } + + void transition_cluster_hparams(std::mt19937* prng, int num_theta_steps) { + emission_relation.transition_cluster_hparams(prng, num_theta_steps); + } + + std::vector get_cluster_assignment(const T_items& items) const { + return emission_relation.get_cluster_assignment(items); + } + + double cluster_or_prior_logp(std::mt19937* prng, const T_items& z, + const ValueType& value) const { + const ValueType base_value = get_base_value(z); + if (emission_relation.clusters.contains(z)) { + return emission_relation.clusters.at(z)->logp( + std::make_pair(base_value, value)); + } + auto emission_prior = emission_relation.make_new_distribution(prng); + double emission_logp = + emission_prior->logp(std::make_pair(base_value, value)); + delete emission_prior; + return emission_logp; + } + + // Disable copying. + NoisyRelation& operator=(const NoisyRelation&) = delete; + NoisyRelation(const NoisyRelation&) = delete; +}; diff --git a/cxx/noisy_relation_test.cc b/cxx/noisy_relation_test.cc new file mode 100644 index 0000000..12f0845 --- /dev/null +++ b/cxx/noisy_relation_test.cc @@ -0,0 +1,74 @@ +// Apache License, Version 2.0, refer to LICENSE.txt + +#define BOOST_TEST_MODULE test Relation + +#include "noisy_relation.hh" + +#include +#include +#include + +#include "clean_relation.hh" +#include "distributions/beta_bernoulli.hh" +#include "distributions/bigram.hh" +#include "domain.hh" + +namespace tt = boost::test_tools; + +BOOST_AUTO_TEST_CASE(test_noisy_relation) { + std::mt19937 prng; + Domain D1("D1"); + Domain D2("D2"); + Domain D3("D3"); + D1.incorporate(&prng, 0); + D2.incorporate(&prng, 1); + D3.incorporate(&prng, 3); + DistributionSpec spec = DistributionSpec{DistributionEnum::bernoulli}; + CleanRelation R1("R1", spec, {&D1, &D2}); + R1.incorporate(&prng, {0, 1}, 1); + R1.incorporate(&prng, {1, 1}, 1); + R1.incorporate(&prng, {3, 1}, 1); + R1.incorporate(&prng, {4, 1}, 1); + R1.incorporate(&prng, {5, 1}, 1); + + EmissionSpec em_spec = EmissionSpec(EmissionEnum::sometimes_bitflip); + NoisyRelation NR1("NR1", em_spec, {&D1, &D2, &D3}, &R1); + NR1.incorporate(&prng, {0, 1, 3}, 0); + NR1.incorporate(&prng, {1, 1, 3}, 1); + NR1.incorporate(&prng, {3, 1, 3}, 0); + NR1.incorporate(&prng, {4, 1, 3}, 1); + NR1.incorporate(&prng, {5, 1, 3}, 0); + NR1.incorporate(&prng, {0, 1, 4}, 1); + NR1.incorporate(&prng, {0, 1, 6}, 0); + auto z1 = NR1.get_cluster_assignment({0, 1, 3}); + BOOST_TEST(z1.size() == 3); + BOOST_TEST(z1[0] == 0); + BOOST_TEST(z1[1] == 0); + BOOST_TEST(z1[2] == 0); + + double lpg __attribute__((unused)); + lpg = NR1.logp_gibbs_approx(D1, 0, 1, &prng); + lpg = NR1.logp_gibbs_approx(D1, 0, 0, &prng); + lpg = NR1.logp_gibbs_approx(D1, 0, 10, &prng); + NR1.set_cluster_assignment_gibbs(D1, 0, 1, &prng); + + DistributionSpec bigram_spec = DistributionSpec{DistributionEnum::bigram}; + CleanRelation R2("R2", bigram_spec, {&D2, &D3}); + EmissionSpec str_emspec = EmissionSpec(EmissionEnum::simple_string); + NoisyRelation NR2("NR2", str_emspec, {&D2, &D3}, &R2); + + R2.incorporate(&prng, {1, 3}, "cat"); + R2.incorporate(&prng, {2, 3}, "cat"); + R2.incorporate(&prng, {1, 2}, "dog"); + R2.incorporate(&prng, {2, 6}, "fish"); + + NR2.incorporate(&prng, {1, 3}, "catt"); + NR2.incorporate(&prng, {2, 3}, "at"); + NR2.incorporate(&prng, {1, 2}, "doge"); + NR2.incorporate(&prng, {2, 6}, "fish"); + + NR2.transition_cluster_hparams(&prng, 4); + lpg = NR2.logp_gibbs_approx(D2, 2, 0, &prng); + NR2.set_cluster_assignment_gibbs(D3, 3, 1, &prng); + D1.set_cluster_assignment_gibbs(0, 1); +} diff --git a/cxx/relation.hh b/cxx/relation.hh index 9464413..ec8ae06 100644 --- a/cxx/relation.hh +++ b/cxx/relation.hh @@ -8,372 +8,48 @@ #include #include -#include "distributions/base.hh" #include "domain.hh" #include "util_distribution_variant.hh" #include "util_hash.hh" -#include "util_math.hh" typedef std::vector T_items; typedef VectorIntHash H_items; -// T_relation is the text we get from reading a line of the schema file; -// hirm.hh:Relation is the object that does the work. -class T_relation { - public: - // The relation is a map from the domains to the space .distribution - // is a distribution over. - std::vector domains; - - DistributionSpec distribution_spec; -}; - template class Relation { public: typedef T ValueType; - // human-readable name - const std::string name; - // Relation spec. - T_relation trel; - // Distribution spec over the relation's codomain. - const DistributionSpec dist_spec; - // list of domain pointers - const std::vector domains; - // map from cluster multi-index to Distribution pointer - std::unordered_map< - const std::vector, Distribution*, VectorIntHash> clusters; - // map from item to observed data - std::unordered_map data; - // map from domain name to reverse map from item to - // set of items that include that item - std::unordered_map< - std::string, - std::unordered_map>> - data_r; - - Relation(const std::string& name, const DistributionSpec& dist_spec, - const std::vector& domains) - : name(name), dist_spec(dist_spec), domains(domains) { - assert(!domains.empty()); - assert(!name.empty()); - for (const Domain* const d : domains) { - this->data_r[d->name] = - std::unordered_map>(); - } - std::vector domain_names; - for (const auto& d : domains) { - domain_names.push_back(d->name); - } - trel = {domain_names, dist_spec}; - } - - ~Relation() { - for (auto [z, cluster] : clusters) { - delete cluster; - } - } - - Distribution* make_new_distribution(std::mt19937* prng) { - return std::visit([&](auto dist_variant) { - // In practice, the DistributionVariant returned by - // cluster_prior_from_spec will always be of type - // Distribution*, so this reinterpret_cast is a no-op. - return reinterpret_cast*>(dist_variant); - }, cluster_prior_from_spec(dist_spec, prng)); - } - - void incorporate(std::mt19937* prng, const T_items& items, ValueType value) { - assert(!data.contains(items)); - data[items] = value; - for (int i = 0; i < std::ssize(domains); ++i) { - domains[i]->incorporate(prng, items[i]); - if (!data_r.at(domains[i]->name).contains(items[i])) { - data_r.at(domains[i]->name)[items[i]] = - std::unordered_set(); - } - data_r.at(domains[i]->name).at(items[i]).insert(items); - } - T_items z = get_cluster_assignment(items); - if (!clusters.contains(z)) { - clusters[z] = make_new_distribution(prng); - } - clusters.at(z)->incorporate(value); - } + virtual void incorporate(std::mt19937* prng, const T_items& items, ValueType value) = 0; + + virtual void unincorporate(const T_items& items) = 0; - void unincorporate(const T_items& items) { - printf("Not implemented\n"); - exit(EXIT_FAILURE); - // auto x = data.at(items); - // auto z = get_cluster_assignment(items); - // clusters.at(z)->unincorporate(x); - // if (clusters.at(z)->N == 0) { - // delete clusters.at(z); - // clusters.erase(z); - // } - // for (int i = 0; i < domains.size(); i++) { - // const std::string &n = domains[i]->name; - // if (data_r.at(n).count(items[i]) > 0) { - // data_r.at(n).at(items[i]).erase(items); - // if (data_r.at(n).at(items[i]).size() == 0) { - // data_r.at(n).erase(items[i]); - // domains[i]->unincorporate(name, items[i]); - // } - // } - // } - // data.erase(items); - } + virtual double logp(const T_items& items, ValueType value, std::mt19937* prng) = 0; - std::vector get_cluster_assignment(const T_items& items) const { - assert(items.size() == domains.size()); - std::vector z(domains.size()); - for (int i = 0; i < std::ssize(domains); ++i) { - z[i] = domains[i]->get_cluster_assignment(items[i]); - } - return z; - } + virtual double logp_score() const = 0; - std::vector get_cluster_assignment_gibbs(const T_items& items, - const Domain& domain, - const T_item& item, - int table) const { - assert(items.size() == domains.size()); - std::vector z(domains.size()); - int hits = 0; - for (int i = 0; i < std::ssize(domains); ++i) { - if ((domains[i]->name == domain.name) && (items[i] == item)) { - z[i] = table; - ++hits; - } else { - z[i] = domains[i]->get_cluster_assignment(items[i]); - } - } - assert(hits > 0); - return z; - } + virtual double cluster_or_prior_logp(std::mt19937* prng, const T_items& items, const ValueType& value) const = 0; - // Implementation of approximate Gibbs data probabilities (faster). - - double logp_gibbs_approx_current(const Domain& domain, const T_item& item) { - double logp = 0.; - for (const T_items& items : data_r.at(domain.name).at(item)) { - ValueType x = data.at(items); - T_items z = get_cluster_assignment(items); - auto cluster = clusters.at(z); - cluster->unincorporate(x); - double lp = cluster->logp(x); - cluster->incorporate(x); - logp += lp; - } - return logp; - } - - double logp_gibbs_approx_variant(const Domain& domain, const T_item& item, - int table, std::mt19937* prng) { - double logp = 0.; - for (const T_items& items : data_r.at(domain.name).at(item)) { - ValueType x = data.at(items); - T_items z = get_cluster_assignment_gibbs(items, domain, item, table); - double lp; - if (!clusters.contains(z)) { - Distribution* tmp_dist = make_new_distribution(prng); - lp = tmp_dist->logp(x); - delete tmp_dist; - } else { - lp = clusters.at(z)->logp(x); - } - logp += lp; - } - return logp; - } - - double logp_gibbs_approx(const Domain& domain, const T_item& item, - int table, std::mt19937* prng) { - int table_current = domain.get_cluster_assignment(item); - return table_current == table - ? logp_gibbs_approx_current(domain, item) - : logp_gibbs_approx_variant(domain, item, table, prng); - } - - // Implementation of exact Gibbs data probabilities. - - std::unordered_map const, std::vector, - VectorIntHash> - get_cluster_to_items_list(Domain const& domain, const T_item& item) { - std::unordered_map, std::vector, - VectorIntHash> - m; - for (const T_items& items : data_r.at(domain.name).at(item)) { - T_items z = get_cluster_assignment(items); - m[z].push_back(items); - } - return m; - } - - double logp_gibbs_exact_current(const std::vector& items_list) { - assert(!items_list.empty()); - T_items z = get_cluster_assignment(items_list[0]); - auto cluster = clusters.at(z); - double logp0 = cluster->logp_score(); - for (const T_items& items : items_list) { - ValueType x = data.at(items); - // assert(z == get_cluster_assignment(items)); - cluster->unincorporate(x); - } - double logp1 = cluster->logp_score(); - for (const T_items& items : items_list) { - ValueType x = data.at(items); - cluster->incorporate(x); - } - assert(cluster->logp_score() == logp0); - return logp0 - logp1; - } + virtual std::vector logp_gibbs_exact( + const Domain& domain, const T_item& item, std::vector tables, + std::mt19937* prng) = 0; - double logp_gibbs_exact_variant( - const Domain& domain, const T_item& item, int table, - const std::vector& items_list, std::mt19937* prng) { - assert(!items_list.empty()); - T_items z = - get_cluster_assignment_gibbs(items_list[0], domain, item, table); + virtual void set_cluster_assignment_gibbs(const Domain& domain, const T_item& item, + int table, std::mt19937* prng) = 0; - Distribution* prior = make_new_distribution(prng); - Distribution* cluster = clusters.contains(z) ? clusters.at(z) : prior; - double logp0 = cluster->logp_score(); - for (const T_items& items : items_list) { - // assert(z == get_cluster_assignment_gibbs(items, domain, item, table)); - ValueType x = data.at(items); - cluster->incorporate(x); - } - const double logp1 = cluster->logp_score(); - for (const T_items& items : items_list) { - ValueType x = data.at(items); - cluster->unincorporate(x); - } - assert(cluster->logp_score() == logp0); - delete prior; - return logp1 - logp0; - } + virtual void transition_cluster_hparams(std::mt19937* prng, int num_theta_steps) = 0; + + // Accessor/convenience methods, mostly for subclass members that can't be accessed through the base class. + virtual const std::vector& get_domains() const = 0; - std::vector logp_gibbs_exact( - const Domain& domain, const T_item& item, std::vector tables, - std::mt19937* prng) { - auto cluster_to_items_list = get_cluster_to_items_list(domain, item); - int table_current = domain.get_cluster_assignment(item); - std::vector logps; - logps.reserve(tables.size()); - double lp_cluster; - for (const int& table : tables) { - double lp_table = 0; - for (const auto& [z, items_list] : cluster_to_items_list) { - lp_cluster = - (table == table_current) - ? logp_gibbs_exact_current(items_list) - : logp_gibbs_exact_variant(domain, item, table, items_list, prng); - lp_table += lp_cluster; - } - logps.push_back(lp_table); - } - return logps; - } + virtual const ValueType& get_value(const T_items& items) const = 0; - double logp(const T_items& items, ValueType value, std::mt19937* prng) { - // TODO: Falsely assumes cluster assignments of items - // from same domain are identical, see note in hirm.py - assert(items.size() == domains.size()); - std::vector> tabl_list; - std::vector> wght_list; - std::vector> indx_list; - for (int i = 0; i < std::ssize(domains); ++i) { - Domain* domain = domains.at(i); - T_item item = items.at(i); - std::vector t_list; - std::vector w_list; - std::vector i_list; - if (domain->items.contains(item)) { - int z = domain->get_cluster_assignment(item); - t_list = {z}; - w_list = {0}; - i_list = {0}; - } else { - auto tables_weights = domain->tables_weights(); - double Z = log(domain->crp.alpha + domain->crp.N); - int idx = 0; - for (const auto& [t, w] : tables_weights) { - t_list.push_back(t); - w_list.push_back(log(w) - Z); - i_list.push_back(idx++); - } - assert(idx == std::ssize(t_list)); - } - tabl_list.push_back(t_list); - wght_list.push_back(w_list); - indx_list.push_back(i_list); - } - std::vector logps; - for (const auto& indexes : product(indx_list)) { - assert(indexes.size() == domains.size()); - std::vector z; - z.reserve(domains.size()); - double logp_w = 0; - for (int i = 0; i < std::ssize(domains); ++i) { - T_item zi = tabl_list.at(i).at(indexes[i]); - double wi = wght_list.at(i).at(indexes[i]); - z.push_back(zi); - logp_w += wi; - } - Distribution* prior = make_new_distribution(prng); - Distribution* cluster = clusters.contains(z) ? clusters.at(z) : prior; - double logp_z = cluster->logp(value); - double logp_zw = logp_z + logp_w; - logps.push_back(logp_zw); - delete prior; - } - return logsumexp(logps); - } + virtual const std::unordered_map& get_data() const = 0; - double logp_score() const { - double logp = 0.0; - for (const auto& [_, cluster] : clusters) { - logp += cluster->logp_score(); - } - return logp; - } + virtual std::vector get_cluster_assignment(const T_items& items) const = 0; - void set_cluster_assignment_gibbs(const Domain& domain, const T_item& item, - int table, std::mt19937* prng) { - int table_current = domain.get_cluster_assignment(item); - assert(table != table_current); - for (const T_items& items : data_r.at(domain.name).at(item)) { - ValueType x = data.at(items); - // Remove from current cluster. - T_items z_prev = get_cluster_assignment(items); - auto cluster_prev = clusters.at(z_prev); - cluster_prev->unincorporate(x); - if (cluster_prev->N == 0) { - delete clusters.at(z_prev); - clusters.erase(z_prev); - } - // Move to desired cluster. - T_items z_new = get_cluster_assignment_gibbs(items, domain, item, table); - if (!clusters.contains(z_new)) { - // Move to fresh cluster. - clusters[z_new] = make_new_distribution(prng); - clusters.at(z_new)->incorporate(x); - } else { - // Move to existing cluster. - assert((clusters.at(z_new)->N > 0)); - clusters.at(z_new)->incorporate(x); - } - } - // Caller should invoke domain.set_cluster_gibbs - } + virtual bool has_observation(const Domain& domain, const T_item& item) const = 0; - bool has_observation(const Domain& domain, const T_item& item) { - return data_r.at(domain.name).contains(item); - } + virtual ~Relation() = default; - // Disable copying. - Relation& operator=(const Relation&) = delete; - Relation(const Relation&) = delete; }; diff --git a/cxx/relation_variant.cc b/cxx/relation_variant.cc index 01a407e..b7aebab 100644 --- a/cxx/relation_variant.cc +++ b/cxx/relation_variant.cc @@ -1,15 +1,16 @@ // Copyright 2024 // See LICENSE.txt +#include "relation_variant.hh" + #include #include #include +#include "clean_relation.hh" #include "domain.hh" -#include "relation.hh" -#include "relation_variant.hh" - +// TODO(emilyaf): Implement this for NoisyRelation. RelationVariant relation_from_spec(const std::string& name, const DistributionSpec& dist_spec, std::vector& domains) { @@ -33,10 +34,11 @@ RelationVariant relation_from_spec(const std::string& name, // the right kind of Relation. std::visit( [&](const auto& v) { - rv = new Relation::SampleType>( - name, dist_spec, domains); - }, dv); + rv = new CleanRelation< + typename std::remove_reference_t::SampleType>( + name, dist_spec, domains); + }, + dv); return rv; } diff --git a/cxx/relation_variant.hh b/cxx/relation_variant.hh index cc61c75..2882354 100644 --- a/cxx/relation_variant.hh +++ b/cxx/relation_variant.hh @@ -11,9 +11,8 @@ #include "relation.hh" #include "util_distribution_variant.hh" -using RelationVariant = - std::variant*, Relation*, - Relation*, Relation*>; +using RelationVariant = std::variant*, Relation*, + Relation*, Relation*>; RelationVariant relation_from_spec(const std::string& name, const DistributionSpec& dist_spec, diff --git a/cxx/relation_variant_test.cc b/cxx/relation_variant_test.cc index 2204b28..0d593f0 100644 --- a/cxx/relation_variant_test.cc +++ b/cxx/relation_variant_test.cc @@ -5,13 +5,16 @@ #include "relation_variant.hh" #include + +#include "clean_relation.hh" namespace tt = boost::test_tools; BOOST_AUTO_TEST_CASE(test_relation_variant) { - std::vector domains; + std::vector domains; domains.push_back(new Domain("D1")); - RelationVariant rv = relation_from_spec( - "r1", parse_distribution_spec("bernoulli"), domains); - Relation* rb = std::get*>(rv); + RelationVariant rv = + relation_from_spec("r1", parse_distribution_spec("bernoulli"), domains); + CleanRelation* rb = + reinterpret_cast*>(std::get*>(rv)); BOOST_TEST(rb->name == "r1"); } diff --git a/cxx/tests/test_hirm_animals.cc b/cxx/tests/test_hirm_animals.cc index 0de1bcf..2bf0221 100644 --- a/cxx/tests/test_hirm_animals.cc +++ b/cxx/tests/test_hirm_animals.cc @@ -32,7 +32,8 @@ int main(int argc, char** argv) { size_t n_obs_unary = 0; for (const auto& [z, irm] : hirm.irms) { for (const auto& [r, relation] : irm->relations) { - n_obs_unary += std::visit([](const auto r) {return r->data.size();}, relation); + n_obs_unary += std::visit( + [](const auto r) { return r->get_data().size(); }, relation); } } assert(n_obs_unary == std::size(observations_unary)); @@ -86,14 +87,14 @@ int main(int argc, char** argv) { assert(abs(logsumexp({p0_solitary_sheep, p1_solitary_sheep})) < 1e-10); // Jointly normalized. - auto p00_black_persiancat_solitary_sheep = - hirm.logp({{"black", {persiancat}, false}, {"solitary", {sheep}, false}}, &prng); - auto p01_black_persiancat_solitary_sheep = - hirm.logp({{"black", {persiancat}, false}, {"solitary", {sheep}, true}}, &prng); - auto p10_black_persiancat_solitary_sheep = - hirm.logp({{"black", {persiancat}, true}, {"solitary", {sheep}, false}}, &prng); - auto p11_black_persiancat_solitary_sheep = - hirm.logp({{"black", {persiancat}, true}, {"solitary", {sheep}, true}}, &prng); + auto p00_black_persiancat_solitary_sheep = hirm.logp( + {{"black", {persiancat}, false}, {"solitary", {sheep}, false}}, &prng); + auto p01_black_persiancat_solitary_sheep = hirm.logp( + {{"black", {persiancat}, false}, {"solitary", {sheep}, true}}, &prng); + auto p10_black_persiancat_solitary_sheep = hirm.logp( + {{"black", {persiancat}, true}, {"solitary", {sheep}, false}}, &prng); + auto p11_black_persiancat_solitary_sheep = hirm.logp( + {{"black", {persiancat}, true}, {"solitary", {sheep}, true}}, &prng); auto Z = logsumexp({ p00_black_persiancat_solitary_sheep, p01_black_persiancat_solitary_sheep, @@ -137,8 +138,10 @@ int main(int argc, char** argv) { } // Check relations agree. for (const auto& [r, rm_var] : irm->relations) { - auto rx = std::get*>(irx->relations.at(r)); - auto rm = std::get*>(rm_var); + auto rx = reinterpret_cast*>( + std::get*>(irx->relations.at(r))); + auto rm = reinterpret_cast*>( + std::get*>(rm_var)); assert(rm->data == rx->data); assert(rm->data_r == rx->data_r); assert(rm->clusters.size() == rx->clusters.size()); diff --git a/cxx/tests/test_irm_two_relations.cc b/cxx/tests/test_irm_two_relations.cc index e89bbb0..267257b 100644 --- a/cxx/tests/test_irm_two_relations.cc +++ b/cxx/tests/test_irm_two_relations.cc @@ -15,7 +15,7 @@ #include "util_io.hh" #include "util_math.hh" -using T_r = Relation*; +using T_r = CleanRelation*; int main(int argc, char** argv) { std::string path_base = "assets/two_relations"; @@ -77,10 +77,14 @@ int main(int argc, char** argv) { assert(l.size() == 2); auto x1 = l.at(0); auto x2 = l.at(1); - auto p0 = std::get(irm.relations.at("R1"))->logp({x1, x2}, false, &prng); + auto p0 = + reinterpret_cast(std::get*>(irm.relations.at("R1"))) + ->logp({x1, x2}, false, &prng); auto p0_irm = irm.logp({{"R1", {x1, x2}, false}}, &prng); assert(abs(p0 - p0_irm) < 1e-10); - auto p1 = std::get(irm.relations.at("R1"))->logp({x1, x2}, true, &prng); + auto p1 = + reinterpret_cast(std::get*>(irm.relations.at("R1"))) + ->logp({x1, x2}, true, &prng); auto Z = logsumexp({p0, p1}); assert(abs(Z) < 1e-10); assert(abs(exp(p0) - expected_p0[x1].at(x2)) < .1); @@ -91,10 +95,14 @@ int main(int argc, char** argv) { auto x1 = l.at(0); auto x2 = l.at(1); auto x3 = l.at(2); - auto p00 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, false}}, &prng); - auto p01 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, true}}, &prng); - auto p10 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, false}}, &prng); - auto p11 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, true}}, &prng); + auto p00 = + irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, false}}, &prng); + auto p01 = + irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, true}}, &prng); + auto p10 = + irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, false}}, &prng); + auto p11 = + irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, true}}, &prng); auto Z = logsumexp({p00, p01, p10, p11}); assert(abs(Z) < 1e-10); } @@ -111,8 +119,10 @@ int main(int argc, char** argv) { // transitioned. assert(abs(irx.logp_score() - irm.logp_score()) > 1e-8); for (const auto& r : {"R1", "R2"}) { - auto r1m = std::get*>(irm.relations.at(r)); - auto r1x = std::get*>(irx.relations.at(r)); + auto r1m = + reinterpret_cast(std::get*>(irm.relations.at(r))); + auto r1x = + reinterpret_cast(std::get*>(irx.relations.at(r))); for (const auto& [c, distribution] : r1m->clusters) { auto dx = reinterpret_cast(r1x->clusters.at(c)); auto dy = reinterpret_cast(distribution); @@ -135,8 +145,8 @@ int main(int argc, char** argv) { for (const auto& r : {"R1", "R2"}) { auto rm_var = irm.relations.at(r); auto rx_var = irx.relations.at(r); - T_r rm = std::get(rm_var); - T_r rx = std::get(rx_var); + T_r rm = reinterpret_cast(std::get*>(rm_var)); + T_r rx = reinterpret_cast(std::get*>(rx_var)); assert(rm->data == rx->data); assert(rm->data_r == rx->data_r); assert(rm->clusters.size() == rx->clusters.size()); diff --git a/cxx/tests/test_misc.cc b/cxx/tests/test_misc.cc index 389182d..7216d84 100644 --- a/cxx/tests/test_misc.cc +++ b/cxx/tests/test_misc.cc @@ -34,22 +34,25 @@ int main(int argc, char** argv) { printf("===== IRM ====\n"); std::map schema1{ - {"R1", T_relation{{"D1", "D1"}, DistributionSpec {DistributionEnum::bernoulli}}}, - {"R2", T_relation{{"D1", "D2"}, DistributionSpec {DistributionEnum::normal}}}, - {"R3", T_relation{{"D3", "D1"}, DistributionSpec {DistributionEnum::bigram}}} - }; + {"R1", + T_relation{{"D1", "D1"}, DistributionSpec{DistributionEnum::bernoulli}}}, + {"R2", + T_relation{{"D1", "D2"}, DistributionSpec{DistributionEnum::normal}}}, + {"R3", + T_relation{{"D3", "D1"}, DistributionSpec{DistributionEnum::bigram}}}}; IRM irm(schema1); for (auto const& kv : irm.domains) { printf("%s %s; ", kv.first.c_str(), kv.second->name.c_str()); - for (auto const &r : irm.domain_to_relations.at(kv.first)) { + for (auto const& r : irm.domain_to_relations.at(kv.first)) { printf("%s ", r.c_str()); } printf("\n"); } for (auto const& kv : irm.relations) { printf("%s ", kv.first.c_str()); - for (auto const d : std::visit([&](auto r) {return r->domains;}, kv.second)) { + for (auto const d : + std::visit([&](auto r) { return r->get_domains(); }, kv.second)) { printf("%s ", d->name.c_str()); } printf("\n"); @@ -102,7 +105,8 @@ int main(int argc, char** argv) { std::string path_clusters = "assets/animals.binary.irm"; to_txt(path_clusters, irm3, encoding); - auto rel = std::get*>(irm3.relations.at("has")); + auto rel = reinterpret_cast*>( + std::get*>(irm3.relations.at("has"))); auto& enc = std::get<0>(encoding); auto lp0 = rel->logp({enc["animal"]["tail"], enc["animal"]["bat"]}, 0, &prng); auto lp1 = rel->logp({enc["animal"]["tail"], enc["animal"]["bat"]}, 1, &prng); @@ -113,8 +117,8 @@ int main(int argc, char** argv) { printf("logsumexp is %1.2f\n", lp_01); IRM irm4({}); - from_txt(&prng, &irm4, "assets/animals.binary.schema", "assets/animals.binary.obs", - path_clusters); + from_txt(&prng, &irm4, "assets/animals.binary.schema", + "assets/animals.binary.obs", path_clusters); irm4.domains.at("animal")->crp.alpha = irm3.domains.at("animal")->crp.alpha; irm4.domains.at("feature")->crp.alpha = irm3.domains.at("feature")->crp.alpha; assert(abs(irm3.logp_score() - irm4.logp_score()) < 1e-8); @@ -128,8 +132,10 @@ int main(int argc, char** argv) { assert(d3->crp.alpha == d4->crp.alpha); } for (const auto& r : {"has"}) { - auto r3 = std::get*>(irm3.relations.at(r)); - auto r4 = std::get*>(irm4.relations.at(r)); + auto r3 = reinterpret_cast*>( + std::get*>(irm3.relations.at(r))); + auto r4 = reinterpret_cast*>( + std::get*>(irm4.relations.at(r))); assert(r3->data == r4->data); assert(r3->data_r == r4->data_r); assert(r3->clusters.size() == r4->clusters.size()); diff --git a/cxx/util_distribution_variant.cc b/cxx/util_distribution_variant.cc index fe4155e..42338a6 100644 --- a/cxx/util_distribution_variant.cc +++ b/cxx/util_distribution_variant.cc @@ -1,17 +1,22 @@ // Copyright 2024 // See LICENSE.txt +#include "util_distribution_variant.hh" + +#include #include #include -#include -#include "util_distribution_variant.hh" + #include "distributions/beta_bernoulli.hh" #include "distributions/bigram.hh" #include "distributions/dirichlet_categorical.hh" #include "distributions/normal.hh" #include "distributions/skellam.hh" #include "distributions/stringcat.hh" - +#include "emissions/bitflip.hh" +#include "emissions/gaussian.hh" +#include "emissions/simple_string.hh" +#include "emissions/sometimes.hh" ObservationVariant observation_string_to_value( const std::string& value_str, const DistributionEnum& distribution) { @@ -38,8 +43,7 @@ DistributionSpec parse_distribution_spec(const std::string& dist_str) { {"categorical", DistributionEnum::categorical}, {"normal", DistributionEnum::normal}, {"skellam", DistributionEnum::skellam}, - {"stringcat", DistributionEnum::stringcat} - }; + {"stringcat", DistributionEnum::stringcat}}; std::string dist_name = dist_str.substr(0, dist_str.find('(')); DistributionEnum dist = dist_name_to_enum.at(dist_name); @@ -64,8 +68,8 @@ DistributionSpec parse_distribution_spec(const std::string& dist_str) { } } -DistributionVariant cluster_prior_from_spec( - const DistributionSpec& spec, std::mt19937* prng) { +DistributionVariant cluster_prior_from_spec(const DistributionSpec& spec, + std::mt19937* prng) { switch (spec.distribution) { case DistributionEnum::bernoulli: return new BetaBernoulli; @@ -99,3 +103,18 @@ DistributionVariant cluster_prior_from_spec( assert(false && "Unsupported distribution enum value."); } } + +EmissionVariant cluster_prior_from_spec(const EmissionSpec& spec, + std::mt19937* prng) { + switch (spec.emission) { + case EmissionEnum::sometimes_bitflip: + return new Sometimes; + case EmissionEnum::gaussian: + return new GaussianEmission; + case EmissionEnum::simple_string: { + return new SimpleStringEmission; + } + default: + assert(false && "Unsupported emission enum value."); + } +} \ No newline at end of file diff --git a/cxx/util_distribution_variant.hh b/cxx/util_distribution_variant.hh index 3ee3e5e..7f9761a 100644 --- a/cxx/util_distribution_variant.hh +++ b/cxx/util_distribution_variant.hh @@ -18,15 +18,31 @@ #include "distributions/normal.hh" #include "distributions/skellam.hh" #include "distributions/stringcat.hh" +#include "emissions/bitflip.hh" +#include "emissions/gaussian.hh" +#include "emissions/simple_string.hh" +#include "emissions/sometimes.hh" enum class DistributionEnum { - bernoulli, bigram, categorical, normal, skellam, stringcat }; + bernoulli, + bigram, + categorical, + normal, + skellam, + stringcat +}; + +enum class EmissionEnum { sometimes_bitflip, gaussian, simple_string }; struct DistributionSpec { DistributionEnum distribution; std::map distribution_args = {}; }; +struct EmissionSpec { + EmissionEnum emission; +}; + // Set of all distribution sample types. using ObservationVariant = std::variant; @@ -34,6 +50,9 @@ using DistributionVariant = std::variant; +using EmissionVariant = + std::variant*, GaussianEmission*, SimpleStringEmission*>; + ObservationVariant observation_string_to_value( const std::string& value_str, const DistributionEnum& distribution); @@ -41,3 +60,6 @@ DistributionSpec parse_distribution_spec(const std::string& dist_str); DistributionVariant cluster_prior_from_spec(const DistributionSpec& spec, std::mt19937* prng); + +EmissionVariant cluster_prior_from_spec(const EmissionSpec& spec, + std::mt19937* prng);