diff --git a/cxx/BUILD b/cxx/BUILD index f6ca671..93accde 100644 --- a/cxx/BUILD +++ b/cxx/BUILD @@ -1,14 +1,5 @@ licenses(["notice"]) -cc_library( - name = "headers", - hdrs = glob( - ["*.hh"], - allow_empty = False, - ), - visibility = [":__subpackages__"], -) - cc_library( name = "cxxopts", srcs = ["cxxopts.hpp"], @@ -24,12 +15,34 @@ cc_library( ], ) +cc_library( + name = "irm", + hdrs = ["irm.hh"], + srcs = ["irm.cc"], + visibility = [":__subpackages__"], + deps = [ + ":relation", + ":relation_variant", + ":util_distribution_variant", + ], +) + +cc_library( + name = "hirm_lib", + hdrs = ["hirm.hh"], + srcs = ["hirm.cc"], + visibility = [":__subpackages__"], + deps = [ + ":irm", + ], +) + cc_binary( name = "hirm", - srcs = ["hirm.cc"], + srcs = ["hirm_main.cc"], deps = [ ":cxxopts", - ":headers", + ":hirm_lib", ":util_distribution_variant", ":util_hash", ":util_io", @@ -41,6 +54,7 @@ cc_binary( cc_library( name = "relation", hdrs = ["relation.hh"], + visibility = [":__subpackages__"], deps = [ ":domain", ":util_distribution_variant", @@ -50,12 +64,26 @@ cc_library( ], ) +cc_library( + name = "relation_variant", + hdrs = ["relation_variant.hh"], + srcs = ["relation_variant.cc"], + visibility = [":__subpackages__"], + deps = [ + ":domain", + ":relation", + ":util_distribution_variant", + "//distributions", + ], +) + cc_library( name = "util_distribution_variant", srcs = ["util_distribution_variant.cc"], visibility = [":__subpackages__"], hdrs = ["util_distribution_variant.hh"], deps = [ + ":domain", "//distributions", ], ) @@ -72,11 +100,10 @@ cc_library( srcs = ["util_io.cc"], visibility = [":__subpackages__"], hdrs = [ - "hirm.hh", "util_io.hh", ], deps = [ - ":headers", + ":hirm_lib", "//distributions", ], ) @@ -86,7 +113,7 @@ cc_library( srcs = ["util_math.cc"], hdrs = ["util_math.hh"], visibility = [":__subpackages__"], - deps = [":headers"], + deps = [], ) cc_test( @@ -98,6 +125,14 @@ cc_test( ], ) +cc_test( + name = "irm_test", + srcs = ["irm_test.cc"], + deps = [ + ":irm", + "@boost//:test", + ], +) cc_test( name = "relation_test", srcs = ["relation_test.cc"], diff --git a/cxx/distributions/BUILD b/cxx/distributions/BUILD index b61b57f..3f5730f 100644 --- a/cxx/distributions/BUILD +++ b/cxx/distributions/BUILD @@ -26,7 +26,6 @@ cc_library( visibility = ["//:__subpackages__"], deps = [ ":base", - "//:headers", "//:util_math", ], ) @@ -39,7 +38,6 @@ cc_library( deps = [ ":base", ":dirichlet_categorical", - "//:headers", "//:util_math", ], ) @@ -50,7 +48,6 @@ cc_library( hdrs = ["crp.hh"], visibility = ["//:__subpackages__"], deps = [ - "//:headers", "//:util_math", ], ) diff --git a/cxx/distributions/base.hh b/cxx/distributions/base.hh index 5e7054d..c629235 100644 --- a/cxx/distributions/base.hh +++ b/cxx/distributions/base.hh @@ -5,7 +5,8 @@ template class Distribution { // Abstract base class for probability distributions in HIRM. // New distribution subclasses need to be added to - // `util_distribution_variant` to be used in the (H)IRM models. + // `relation_variant` and `util_distribution_variant` to be used in the + // (H)IRM models. public: typedef T SampleType; // N is the number of incorporated observations. diff --git a/cxx/hirm.cc b/cxx/hirm.cc index f9a4a1e..47158b2 100644 --- a/cxx/hirm.cc +++ b/cxx/hirm.cc @@ -3,202 +3,239 @@ #include "hirm.hh" -#include -#include -#include - -#include "cxxopts.hpp" -#include "util_io.hh" - -#define GET_ELAPSED(t) double(clock() - t) / CLOCKS_PER_SEC - -// TODO(emilyaf): Refactor as a function for readibility/maintainability. -#define CHECK_TIMEOUT(timeout, t_begin) \ - if (timeout) { \ - auto elapsed = GET_ELAPSED(t_begin); \ - if (timeout < elapsed) { \ - printf("timeout after %1.2fs \n", elapsed); \ - break; \ - } \ +HIRM::HIRM(const T_schema& schema, std::mt19937* prng) { + for (const auto& [name, relation] : schema) { + this->add_relation(prng, name, relation); } +} + +void HIRM::incorporate(std::mt19937* prng, const std::string& r, + const T_items& items, const ObservationVariant& value) { + IRM* irm = relation_to_irm(r); + irm->incorporate(prng, r, items, value); +} + +void HIRM::unincorporate(const std::string& r, const T_items& items) { + IRM* irm = relation_to_irm(r); + irm->unincorporate(r, items); +} + +int HIRM::relation_to_table(const std::string& r) { + int rc = relation_to_code.at(r); + return crp.assignments.at(rc); +} + +IRM* HIRM::relation_to_irm(const std::string& r) { + int rc = relation_to_code.at(r); + int table = crp.assignments.at(rc); + return irms.at(table); +} + +RelationVariant HIRM::get_relation(const std::string& r) { + IRM* irm = relation_to_irm(r); + return irm->relations.at(r); +} -// TODO(emilyaf): Refactor as a function for readibility/maintainability. -#define REPORT_SCORE(var_verbose, var_t, var_t_total, var_model) \ - if (var_verbose) { \ - auto t_delta = GET_ELAPSED(var_t); \ - var_t_total += t_delta; \ - double x = var_model->logp_score(); \ - printf("%f %f\n", var_t_total, x); \ - fflush(stdout); \ +void HIRM::transition_cluster_assignments_all(std::mt19937* prng) { + for (const auto& [r, rc] : relation_to_code) { + transition_cluster_assignment_relation(prng, r); } +} + +void HIRM::transition_cluster_assignments(std::mt19937* prng, + const std::vector& rs) { + for (const auto& r : rs) { + transition_cluster_assignment_relation(prng, r); + } +} -void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total, - bool verbose) { - // TRANSITION ASSIGNMENTS. - for (const auto& [d, domain] : irm->domains) { - for (const auto item : domain->items) { - clock_t t = clock(); - irm->transition_cluster_assignment_item(prng, d, item); - REPORT_SCORE(verbose, t, t_total, irm); +void HIRM::transition_cluster_assignment_relation(std::mt19937* prng, + const std::string& r) { + int rc = relation_to_code.at(r); + int table_current = crp.assignments.at(rc); + RelationVariant relation = get_relation(r); + T_relation t_relation = + std::visit([](auto rel) { return rel->trel; }, relation); + auto crp_dist = crp.tables_weights_gibbs(table_current); + std::vector tables; + std::vector logps; + int* table_aux = nullptr; + IRM* irm_aux = nullptr; + // Compute probabilities of each table. + for (const auto& [table, n_customers] : crp_dist) { + IRM* irm; + if (!irms.contains(table)) { + irm = new IRM({}); + assert(table_aux == nullptr); + assert(irm_aux == nullptr); + table_aux = (int*)malloc(sizeof(*table_aux)); + *table_aux = table; + irm_aux = irm; + } else { + irm = irms.at(table); + } + if (table != table_current) { + irm->add_relation(r, t_relation); + std::visit( + [&](auto rel) { + for (const auto& [items, value] : rel->data) { + irm->incorporate(prng, r, items, value); + } + }, + relation); } + RelationVariant rel_r = irm->relations.at(r); + double lp_data = + std::visit([](auto rel) { return rel->logp_score(); }, rel_r); + double lp_crp = log(n_customers); + logps.push_back(lp_crp + lp_data); + tables.push_back(table); } - // TRANSITION DISTRIBUTION HYPERPARAMETERS. - for (const auto& [r, relation] : irm->relations) { - std::visit( - [&](auto r) { - for (const auto& [c, distribution] : r->clusters) { - clock_t t = clock(); - distribution->transition_hyperparameters(prng); - REPORT_SCORE(verbose, t, t_total, irm); - } - }, - relation); + // Sample new table. + int idx = log_choice(logps, prng); + T_item choice = tables[idx]; + + // Remove relation from all other tables. + for (const auto& [table, customers] : crp.tables) { + IRM* irm = irms.at(table); + if (table != choice) { + assert(irm->relations.count(r) == 1); + irm->remove_relation(r); + } + if (irm->relations.empty()) { + assert(crp.tables[table].size() == 1); + assert(table == table_current); + irms.erase(table); + delete irm; + } } - // TRANSITION ALPHA. - for (const auto& [d, domain] : irm->domains) { - clock_t t = clock(); - domain->crp.transition_alpha(prng); - REPORT_SCORE(verbose, t, t_total, irm); + // Add auxiliary table if necessary. + if ((table_aux != nullptr) && (choice == *table_aux)) { + assert(irm_aux != nullptr); + irms[choice] = irm_aux; + } else { + delete irm_aux; } -} - -void inference_irm(std::mt19937* prng, IRM* irm, int iters, int timeout, - bool verbose) { - clock_t t_begin = clock(); - double t_total = 0; - for (int i = 0; i < iters; ++i) { - CHECK_TIMEOUT(timeout, t_begin); - single_step_irm_inference(prng, irm, t_total, verbose); + free(table_aux); + // Update the CRP. + crp.unincorporate(rc); + crp.incorporate(rc, choice); + assert(irms.size() == crp.tables.size()); + for (const auto& [table, irm] : irms) { + assert(crp.tables.contains(table)); } } -void inference_hirm(std::mt19937* prng, HIRM* hirm, int iters, int timeout, - bool verbose) { - clock_t t_begin = clock(); - double t_total = 0; - for (int i = 0; i < iters; ++i) { - CHECK_TIMEOUT(timeout, t_begin); - // TRANSITION RELATIONS. - for (const auto& [r, rc] : hirm->relation_to_code) { - clock_t t = clock(); - hirm->transition_cluster_assignment_relation(prng, r); - REPORT_SCORE(verbose, t, t_total, hirm); +void HIRM::set_cluster_assignment_gibbs( + std::mt19937* prng, const std::string& r, int table) { + assert(irms.size() == crp.tables.size()); + int rc = relation_to_code.at(r); + int table_current = crp.assignments.at(rc); + RelationVariant relation = get_relation(r); + auto f_obs = [&](auto rel) { + T_relation trel = rel->trel; + IRM* irm = relation_to_irm(r); + auto observations = rel->data; + // Remove from current IRM. + irm->remove_relation(r); + if (irm->relations.empty()) { + irms.erase(table_current); + delete irm; + } + // Add to target IRM. + if (!irms.contains(table)) { + irm = new IRM({}); + irms[table] = irm; } - // TRANSITION IRMs. - for (const auto& [t, irm] : hirm->irms) { - single_step_irm_inference(prng, irm, t_total, verbose); + irm = irms.at(table); + irm->add_relation(r, trel); + for (const auto& [items, value] : observations) { + irm->incorporate(prng, r, items, value); } + }; + std::visit(f_obs, relation); + // Update CRP. + crp.unincorporate(rc); + crp.incorporate(rc, table); + assert(irms.size() == crp.tables.size()); + for (const auto& [table, irm] : irms) { + assert(crp.tables.contains(table)); } } -int main(int argc, char** argv) { - cxxopts::Options options("hirm", - "Run a hierarchical infinite relational model."); - options.add_options()("help", "show help message")( - "mode", "options are {irm, hirm}", - cxxopts::value()->default_value("hirm"))( - "seed", "random seed", cxxopts::value()->default_value("10"))( - "iters", "number of inference iterations", - cxxopts::value()->default_value("10"))( - "verbose", "report results to terminal", - cxxopts::value()->default_value("false"))( - "timeout", "number of seconds of inference", - cxxopts::value()->default_value("0"))( - "load", "path to .[h]irm file with initial clusters", - cxxopts::value()->default_value(""))( - "path", "base name of the .schema file", cxxopts::value())( - "rest", "rest", - cxxopts::value>()->default_value({})); - options.parse_positional({"path", "rest"}); - options.positional_help(""); - - auto result = options.parse(argc, argv); - if (result.count("help")) { - std::cout << options.help() << std::endl; - return 0; - } - if (result.count("path") == 0) { - std::cout << options.help() << std::endl; - return 1; +void HIRM::add_relation(std::mt19937* prng, const std::string& name, + const T_relation& rel) { + assert(!schema.contains(name)); + schema[name] = rel; + int offset = + (code_to_relation.empty()) + ? 0 + : std::max_element(code_to_relation.begin(), code_to_relation.end()) + ->first; + int rc = 1 + offset; + int table = crp.sample(prng); + crp.incorporate(rc, table); + if (irms.count(table) == 1) { + irms.at(table)->add_relation(name, rel); + } else { + irms[table] = new IRM({{name, rel}}); } + assert(!relation_to_code.contains(name)); + assert(!code_to_relation.contains(rc)); + relation_to_code[name] = rc; + code_to_relation[rc] = name; +} - std::string path_base = result["path"].as(); - int seed = result["seed"].as(); - int iters = result["iters"].as(); - int timeout = result["timeout"].as(); - bool verbose = result["verbose"].as(); - std::string path_clusters = result["load"].as(); - std::string mode = result["mode"].as(); - - if (mode != "hirm" && mode != "irm") { - std::cout << options.help() << std::endl; - std::cout << "unknown mode " << mode << std::endl; - return 1; +void HIRM::remove_relation(const std::string& name) { + schema.erase(name); + int rc = relation_to_code.at(name); + int table = crp.assignments.at(rc); + bool singleton = crp.tables.at(table).size() == 1; + crp.unincorporate(rc); + irms.at(table)->remove_relation(name); + if (singleton) { + IRM* irm = irms.at(table); + assert(irm->relations.empty()); + irms.erase(table); + delete irm; } + relation_to_code.erase(name); + code_to_relation.erase(rc); +} - std::string path_obs = path_base + ".obs"; - std::string path_schema = path_base + ".schema"; - std::string path_save = path_base + "." + std::to_string(seed); - - printf("setting seed to %d\n", seed); - std::mt19937 prng(seed); - - std::cout << "loading schema from " << path_schema << std::endl; - auto schema = load_schema(path_schema); - - std::cout << "loading observations from " << path_obs << std::endl; - auto observations = load_observations(path_obs, schema); - auto encoding = encode_observations(schema, observations); - - if (mode == "irm") { - std::cout << "selected model is IRM" << std::endl; - IRM* irm; - // Load - if (path_clusters.empty()) { - irm = new IRM(schema); - std::cout << "incorporating observations" << std::endl; - incorporate_observations(&prng, *irm, encoding, observations); - } else { - irm = new IRM({}); - std::cout << "loading clusters from " << path_clusters << std::endl; - from_txt(&prng, irm, path_schema, path_obs, path_clusters); +double HIRM::logp( + const std::vector>& + observations) { + std::unordered_map< + int, std::vector>> + obs_dict; + for (const auto& [r, items, value] : observations) { + int rc = relation_to_code.at(r); + int table = crp.assignments.at(rc); + if (!obs_dict.contains(table)) { + obs_dict[table] = {}; } - // Infer - std::cout << "inferring " << iters << " iters; timeout " << timeout - << std::endl; - inference_irm(&prng, irm, iters, timeout, verbose); - // Save - path_save += ".irm"; - std::cout << "saving to " << path_save << std::endl; - to_txt(path_save, *irm, encoding); - // Free - free(irm); - return 0; + obs_dict.at(table).push_back({r, items, value}); + } + double logp = 0.0; + for (const auto& [t, o] : obs_dict) { + logp += irms.at(t)->logp(o); } + return logp; +} - if (mode == "hirm") { - std::cout << "selected model is HIRM" << std::endl; - HIRM* hirm; - // Load - if (path_clusters.empty()) { - hirm = new HIRM(schema, &prng); - std::cout << "incorporating observations" << std::endl; - incorporate_observations(&prng, *hirm, encoding, observations); - } else { - hirm = new HIRM({}, &prng); - std::cout << "loading clusters from " << path_clusters << std::endl; - from_txt(&prng, hirm, path_schema, path_obs, path_clusters); - } - // Infer - std::cout << "inferring " << iters << " iters; timeout " << timeout - << std::endl; - inference_hirm(&prng, hirm, iters, timeout, verbose); - // Save - path_save += ".hirm"; - std::cout << "saving to " << path_save << std::endl; - to_txt(path_save, *hirm, encoding); - // Free - free(hirm); - return 0; +double HIRM::logp_score() const { + double logp_score_crp = crp.logp_score(); + double logp_score_irms = 0.0; + for (const auto& [table, irm] : irms) { + logp_score_irms += irm->logp_score(); } + return logp_score_crp + logp_score_irms; } + +HIRM::~HIRM() { + for (const auto& [table, irm] : irms) { + delete irm; + } +} + diff --git a/cxx/hirm.hh b/cxx/hirm.hh index 6bcfc69..faf0dc6 100644 --- a/cxx/hirm.hh +++ b/cxx/hirm.hh @@ -8,282 +8,10 @@ #include #include +#include "irm.hh" #include "relation.hh" #include "util_distribution_variant.hh" -// Map from names to T_relation's. -typedef std::map T_schema; - -class IRM { - public: - T_schema schema; // schema of relations - std::unordered_map domains; // map from name to Domain - std::unordered_map - relations; // map from name to Relation - std::unordered_map> - domain_to_relations; // reverse map - - IRM(const T_schema& schema) { - for (const auto& [name, relation] : schema) { - this->add_relation(name, relation); - } - } - - ~IRM() { - for (auto [d, domain] : domains) { - delete domain; - } - for (auto [r, relation] : relations) { - std::visit([](auto rel) { delete rel; }, relation); - } - } - - void incorporate(std::mt19937* prng, const std::string& r, - const T_items& items, ObservationVariant value) { - std::visit( - [&](auto rel) { - auto v = std::get< - typename std::remove_reference_t::ValueType>( - value); - rel->incorporate(prng, items, v); - }, - relations.at(r)); - } - - void unincorporate(const std::string& r, const T_items& items) { - std::visit([&](auto rel) { rel->unincorporate(items); }, relations.at(r)); - } - - void transition_cluster_assignments_all(std::mt19937* prng) { - for (const auto& [d, domain] : domains) { - for (const T_item item : domain->items) { - transition_cluster_assignment_item(prng, d, item); - } - } - } - - void transition_cluster_assignments(std::mt19937* prng, - const std::vector& ds) { - for (const std::string& d : ds) { - for (const T_item item : domains.at(d)->items) { - transition_cluster_assignment_item(prng, d, item); - } - } - } - - void transition_cluster_assignment_item(std::mt19937* prng, - const std::string& d, - const T_item& item) { - Domain* domain = domains.at(d); - auto crp_dist = domain->tables_weights_gibbs(item); - // Compute probability of each table. - std::vector tables; - std::vector logps; - tables.reserve(crp_dist.size()); - logps.reserve(crp_dist.size()); - for (const auto& [table, n_customers] : crp_dist) { - tables.push_back(table); - logps.push_back(log(n_customers)); - } - auto accumulate_logps = [&](auto rel) { - if (rel->has_observation(*domain, item)) { - std::vector lp_relation = - rel->logp_gibbs_exact(*domain, item, tables); - assert(lp_relation.size() == tables.size()); - assert(lp_relation.size() == logps.size()); - for (int i = 0; i < std::ssize(logps); ++i) { - logps[i] += lp_relation[i]; - } - } - }; - for (const auto& r : domain_to_relations.at(d)) { - std::visit(accumulate_logps, relations.at(r)); - } - // Sample new table. - assert(tables.size() == logps.size()); - int idx = log_choice(logps, prng); - T_item choice = tables[idx]; - // Move to new table (if necessary). - if (choice != domain->get_cluster_assignment(item)) { - auto set_cluster_r = [&](auto rel) { - if (rel->has_observation(*domain, item)) { - rel->set_cluster_assignment_gibbs(*domain, item, choice); - } - }; - for (const std::string& r : domain_to_relations.at(d)) { - std::visit(set_cluster_r, relations.at(r)); - } - domain->set_cluster_assignment_gibbs(item, choice); - } - } - - double logp( - const std::vector>& - observations) { - std::unordered_map> - relation_items_seen; - std::unordered_map> - domain_item_seen; - std::vector> item_universe; - std::vector> index_universe; - std::vector> weight_universe; - std::unordered_map< - std::string, - std::unordered_map>>> - cluster_universe; - // Compute all cluster combinations. - for (const auto& [r, items, value] : observations) { - // Assert observation is unique. - assert(!relation_items_seen[r].contains(items)); - relation_items_seen[r].insert(items); - // Process each (domain, item) in the observations. - RelationVariant relation = relations.at(r); - int arity = - std::visit([](auto rel) { return rel->domains.size(); }, relation); - assert(std::ssize(items) == arity); - for (int i = 0; i < arity; ++i) { - // Skip if (domain, item) processed. - Domain* domain = - std::visit([&](auto rel) { return rel->domains.at(i); }, relation); - T_item item = items.at(i); - if (domain_item_seen[domain->name].contains(item)) { - assert(cluster_universe[domain->name].contains(item)); - continue; - } - domain_item_seen[domain->name].insert(item); - // Obtain tables, weights, indexes for this item. - std::vector t_list; - std::vector w_list; - std::vector i_list; - size_t n_tables = domain->tables_weights().size() + 1; - t_list.reserve(n_tables); - w_list.reserve(n_tables); - i_list.reserve(n_tables); - if (domain->items.contains(item)) { - int z = domain->get_cluster_assignment(item); - t_list = {z}; - w_list = {0.0}; - i_list = {0}; - } else { - auto tables_weights = domain->tables_weights(); - double Z = log(domain->crp.alpha + domain->crp.N); - size_t idx = 0; - for (const auto& [t, w] : tables_weights) { - t_list.push_back(t); - w_list.push_back(log(w) - Z); - i_list.push_back(idx++); - } - assert(idx == t_list.size()); - } - // Add to universe. - item_universe.push_back({domain->name, item}); - index_universe.push_back(i_list); - weight_universe.push_back(w_list); - int loc = index_universe.size() - 1; - cluster_universe[domain->name][item] = {loc, t_list}; - } - } - assert(item_universe.size() == index_universe.size()); - assert(item_universe.size() == weight_universe.size()); - // Compute data probability given cluster combinations. - std::vector items_product = product(index_universe); - std::vector logps; // reserve size - logps.reserve(index_universe.size()); - for (const T_items& indexes : items_product) { - double logp_indexes = 0; - // Compute weight of cluster assignments. - double weight = 0.0; - for (int i = 0; i < std::ssize(indexes); ++i) { - weight += weight_universe.at(i).at(indexes[i]); - } - logp_indexes += weight; - // Compute weight of data given cluster assignments. - auto f_logp = [&](auto rel, const T_items& items, - const ObservationVariant& value) -> double { - std::vector z; - z.reserve(domains.size()); - for (int i = 0; i < std::ssize(rel->domains); ++i) { - Domain* domain = rel->domains.at(i); - T_item item = items.at(i); - auto& [loc, t_list] = cluster_universe.at(domain->name).at(item); - T_item t = t_list.at(indexes.at(loc)); - z.push_back(t); - } - auto v = std::get< - typename std::remove_reference_t::ValueType>(value); - auto prior = - std::get::DType*>( - cluster_prior_from_spec(rel->dist_spec)); - return rel->clusters.contains(z) ? rel->clusters.at(z)->logp(v) - : prior->logp(v); - }; - for (const auto& [r, items, value] : observations) { - auto g = std::bind(f_logp, std::placeholders::_1, items, value); - double logp_obs = std::visit(g, relations.at(r)); - logp_indexes += logp_obs; - }; - logps.push_back(logp_indexes); - } - return logsumexp(logps); - } - - double logp_score() const { - double logp_score_crp = 0.0; - for (const auto& [d, domain] : domains) { - logp_score_crp += domain->crp.logp_score(); - } - double logp_score_relation = 0.0; - for (const auto& [r, relation] : relations) { - double logp_rel = - std::visit([](auto rel) { return rel->logp_score(); }, relation); - logp_score_relation += logp_rel; - } - return logp_score_crp + logp_score_relation; - } - - void add_relation(const std::string& name, const T_relation& relation) { - assert(!schema.contains(name)); - assert(!relations.contains(name)); - std::vector doms; - for (const auto& d : relation.domains) { - if (domains.count(d) == 0) { - assert(domain_to_relations.count(d) == 0); - domains[d] = new Domain(d); - domain_to_relations[d] = std::unordered_set(); - } - domain_to_relations.at(d).insert(name); - doms.push_back(domains.at(d)); - } - relations[name] = - relation_from_spec(name, relation.distribution_spec, doms); - schema[name] = relation; - } - - void remove_relation(const std::string& name) { - std::unordered_set ds; - auto rel_domains = - std::visit([](auto r) { return r->domains; }, relations.at(name)); - for (const Domain* const domain : rel_domains) { - ds.insert(domain->name); - } - for (const auto& d : ds) { - domain_to_relations.at(d).erase(name); - // TODO: Remove r from domains.at(d)->items - if (domain_to_relations.at(d).empty()) { - domain_to_relations.erase(d); - delete domains.at(d); - domains.erase(d); - } - } - std::visit([](auto r) { delete r; }, relations.at(name)); - relations.erase(name); - schema.erase(name); - } - - // Disable copying. - IRM& operator=(const IRM&) = delete; - IRM(const IRM&) = delete; -}; class HIRM { public: @@ -295,235 +23,37 @@ class HIRM { code_to_relation; // map from code to relation CRP crp; // clustering model for relations - HIRM(const T_schema& schema, std::mt19937* prng) { - for (const auto& [name, relation] : schema) { - this->add_relation(prng, name, relation); - } - } + HIRM(const T_schema& schema, std::mt19937* prng); void incorporate(std::mt19937* prng, const std::string& r, - const T_items& items, const ObservationVariant& value) { - IRM* irm = relation_to_irm(r); - irm->incorporate(prng, r, items, value); - } - void unincorporate(const std::string& r, const T_items& items) { - IRM* irm = relation_to_irm(r); - irm->unincorporate(r, items); - } + const T_items& items, const ObservationVariant& value); + void unincorporate(const std::string& r, const T_items& items); - int relation_to_table(const std::string& r) { - int rc = relation_to_code.at(r); - return crp.assignments.at(rc); - } - IRM* relation_to_irm(const std::string& r) { - int rc = relation_to_code.at(r); - int table = crp.assignments.at(rc); - return irms.at(table); - } - RelationVariant get_relation(const std::string& r) { - IRM* irm = relation_to_irm(r); - return irm->relations.at(r); - } + int relation_to_table(const std::string& r); + IRM* relation_to_irm(const std::string& r); + RelationVariant get_relation(const std::string& r); - void transition_cluster_assignments_all(std::mt19937* prng) { - for (const auto& [r, rc] : relation_to_code) { - transition_cluster_assignment_relation(prng, r); - } - } + void transition_cluster_assignments_all(std::mt19937* prng); void transition_cluster_assignments(std::mt19937* prng, - const std::vector& rs) { - for (const auto& r : rs) { - transition_cluster_assignment_relation(prng, r); - } - } + const std::vector& rs); void transition_cluster_assignment_relation(std::mt19937* prng, - const std::string& r) { - int rc = relation_to_code.at(r); - int table_current = crp.assignments.at(rc); - RelationVariant relation = get_relation(r); - T_relation t_relation = - std::visit([](auto rel) { return rel->trel; }, relation); - auto crp_dist = crp.tables_weights_gibbs(table_current); - std::vector tables; - std::vector logps; - int* table_aux = nullptr; - IRM* irm_aux = nullptr; - // Compute probabilities of each table. - for (const auto& [table, n_customers] : crp_dist) { - IRM* irm; - if (!irms.contains(table)) { - irm = new IRM({}); - assert(table_aux == nullptr); - assert(irm_aux == nullptr); - table_aux = (int*)malloc(sizeof(*table_aux)); - *table_aux = table; - irm_aux = irm; - } else { - irm = irms.at(table); - } - if (table != table_current) { - irm->add_relation(r, t_relation); - std::visit( - [&](auto rel) { - for (const auto& [items, value] : rel->data) { - irm->incorporate(prng, r, items, value); - } - }, - relation); - } - RelationVariant rel_r = irm->relations.at(r); - double lp_data = - std::visit([](auto rel) { return rel->logp_score(); }, rel_r); - double lp_crp = log(n_customers); - logps.push_back(lp_crp + lp_data); - tables.push_back(table); - } - // Sample new table. - int idx = log_choice(logps, prng); - T_item choice = tables[idx]; - - // Remove relation from all other tables. - for (const auto& [table, customers] : crp.tables) { - IRM* irm = irms.at(table); - if (table != choice) { - assert(irm->relations.count(r) == 1); - irm->remove_relation(r); - } - if (irm->relations.empty()) { - assert(crp.tables[table].size() == 1); - assert(table == table_current); - irms.erase(table); - delete irm; - } - } - // Add auxiliary table if necessary. - if ((table_aux != nullptr) && (choice == *table_aux)) { - assert(irm_aux != nullptr); - irms[choice] = irm_aux; - } else { - delete irm_aux; - } - free(table_aux); - // Update the CRP. - crp.unincorporate(rc); - crp.incorporate(rc, choice); - assert(irms.size() == crp.tables.size()); - for (const auto& [table, irm] : irms) { - assert(crp.tables.contains(table)); - } - } + const std::string& r); void set_cluster_assignment_gibbs(std::mt19937* prng, const std::string& r, - int table) { - assert(irms.size() == crp.tables.size()); - int rc = relation_to_code.at(r); - int table_current = crp.assignments.at(rc); - RelationVariant relation = get_relation(r); - auto f_obs = [&](auto rel) { - T_relation trel = rel->trel; - IRM* irm = relation_to_irm(r); - auto observations = rel->data; - // Remove from current IRM. - irm->remove_relation(r); - if (irm->relations.empty()) { - irms.erase(table_current); - delete irm; - } - // Add to target IRM. - if (!irms.contains(table)) { - irm = new IRM({}); - irms[table] = irm; - } - irm = irms.at(table); - irm->add_relation(r, trel); - for (const auto& [items, value] : observations) { - irm->incorporate(prng, r, items, value); - } - }; - std::visit(f_obs, relation); - // Update CRP. - crp.unincorporate(rc); - crp.incorporate(rc, table); - assert(irms.size() == crp.tables.size()); - for (const auto& [table, irm] : irms) { - assert(crp.tables.contains(table)); - } - } + int table); void add_relation(std::mt19937* prng, const std::string& name, - const T_relation& rel) { - assert(!schema.contains(name)); - schema[name] = rel; - int offset = - (code_to_relation.empty()) - ? 0 - : std::max_element(code_to_relation.begin(), code_to_relation.end()) - ->first; - int rc = 1 + offset; - int table = crp.sample(prng); - crp.incorporate(rc, table); - if (irms.count(table) == 1) { - irms.at(table)->add_relation(name, rel); - } else { - irms[table] = new IRM({{name, rel}}); - } - assert(!relation_to_code.contains(name)); - assert(!code_to_relation.contains(rc)); - relation_to_code[name] = rc; - code_to_relation[rc] = name; - } - void remove_relation(const std::string& name) { - schema.erase(name); - int rc = relation_to_code.at(name); - int table = crp.assignments.at(rc); - bool singleton = crp.tables.at(table).size() == 1; - crp.unincorporate(rc); - irms.at(table)->remove_relation(name); - if (singleton) { - IRM* irm = irms.at(table); - assert(irm->relations.empty()); - irms.erase(table); - delete irm; - } - relation_to_code.erase(name); - code_to_relation.erase(rc); - } + const T_relation& rel); + + void remove_relation(const std::string& name); double logp( const std::vector>& - observations) { - std::unordered_map< - int, std::vector>> - obs_dict; - for (const auto& [r, items, value] : observations) { - int rc = relation_to_code.at(r); - int table = crp.assignments.at(rc); - if (!obs_dict.contains(table)) { - obs_dict[table] = {}; - } - obs_dict.at(table).push_back({r, items, value}); - } - double logp = 0.0; - for (const auto& [t, o] : obs_dict) { - logp += irms.at(t)->logp(o); - } - return logp; - } + observations); - double logp_score() { - double logp_score_crp = crp.logp_score(); - double logp_score_irms = 0.0; - for (const auto& [table, irm] : irms) { - logp_score_irms += irm->logp_score(); - } - return logp_score_crp + logp_score_irms; - } + double logp_score() const; - ~HIRM() { - for (const auto& [table, irm] : irms) { - delete irm; - } - } + ~HIRM(); // Disable copying. HIRM& operator=(const HIRM&) = delete; diff --git a/cxx/hirm_main.cc b/cxx/hirm_main.cc new file mode 100644 index 0000000..f829596 --- /dev/null +++ b/cxx/hirm_main.cc @@ -0,0 +1,175 @@ +// Copyright 2021 MIT Probabilistic Computing Project +// Apache License, Version 2.0, refer to LICENSE.txt + +#include "hirm.hh" + +#include +#include +#include + +#include "cxxopts.hpp" +#include "irm.hh" +#include "util_io.hh" + +#define GET_ELAPSED(t) double(clock() - t) / CLOCKS_PER_SEC + +// TODO(emilyaf): Refactor as a function for readibility/maintainability. +#define CHECK_TIMEOUT(timeout, t_begin) \ + if (timeout) { \ + auto elapsed = GET_ELAPSED(t_begin); \ + if (timeout < elapsed) { \ + printf("timeout after %1.2fs \n", elapsed); \ + break; \ + } \ + } + +// TODO(emilyaf): Refactor as a function for readibility/maintainability. +#define REPORT_SCORE(var_verbose, var_t, var_t_total, var_model) \ + if (var_verbose) { \ + auto t_delta = GET_ELAPSED(var_t); \ + var_t_total += t_delta; \ + double x = var_model->logp_score(); \ + printf("%f %f\n", var_t_total, x); \ + fflush(stdout); \ + } + +void inference_irm(std::mt19937* prng, IRM* irm, int iters, int timeout, + bool verbose) { + clock_t t_begin = clock(); + double t_total = 0; + for (int i = 0; i < iters; ++i) { + CHECK_TIMEOUT(timeout, t_begin); + single_step_irm_inference(prng, irm, t_total, verbose); + } +} + +void inference_hirm(std::mt19937* prng, HIRM* hirm, int iters, int timeout, + bool verbose) { + clock_t t_begin = clock(); + double t_total = 0; + for (int i = 0; i < iters; ++i) { + CHECK_TIMEOUT(timeout, t_begin); + // TRANSITION RELATIONS. + for (const auto& [r, rc] : hirm->relation_to_code) { + clock_t t = clock(); + hirm->transition_cluster_assignment_relation(prng, r); + REPORT_SCORE(verbose, t, t_total, hirm); + } + // TRANSITION IRMs. + for (const auto& [t, irm] : hirm->irms) { + single_step_irm_inference(prng, irm, t_total, verbose); + } + } +} + +int main(int argc, char** argv) { + cxxopts::Options options("hirm", + "Run a hierarchical infinite relational model."); + options.add_options()("help", "show help message")( + "mode", "options are {irm, hirm}", + cxxopts::value()->default_value("hirm"))( + "seed", "random seed", cxxopts::value()->default_value("10"))( + "iters", "number of inference iterations", + cxxopts::value()->default_value("10"))( + "verbose", "report results to terminal", + cxxopts::value()->default_value("false"))( + "timeout", "number of seconds of inference", + cxxopts::value()->default_value("0"))( + "load", "path to .[h]irm file with initial clusters", + cxxopts::value()->default_value(""))( + "path", "base name of the .schema file", cxxopts::value())( + "rest", "rest", + cxxopts::value>()->default_value({})); + options.parse_positional({"path", "rest"}); + options.positional_help(""); + + auto result = options.parse(argc, argv); + if (result.count("help")) { + std::cout << options.help() << std::endl; + return 0; + } + if (result.count("path") == 0) { + std::cout << options.help() << std::endl; + return 1; + } + + std::string path_base = result["path"].as(); + int seed = result["seed"].as(); + int iters = result["iters"].as(); + int timeout = result["timeout"].as(); + bool verbose = result["verbose"].as(); + std::string path_clusters = result["load"].as(); + std::string mode = result["mode"].as(); + + if (mode != "hirm" && mode != "irm") { + std::cout << options.help() << std::endl; + std::cout << "unknown mode " << mode << std::endl; + return 1; + } + + std::string path_obs = path_base + ".obs"; + std::string path_schema = path_base + ".schema"; + std::string path_save = path_base + "." + std::to_string(seed); + + printf("setting seed to %d\n", seed); + std::mt19937 prng(seed); + + std::cout << "loading schema from " << path_schema << std::endl; + auto schema = load_schema(path_schema); + + std::cout << "loading observations from " << path_obs << std::endl; + auto observations = load_observations(path_obs, schema); + auto encoding = encode_observations(schema, observations); + + if (mode == "irm") { + std::cout << "selected model is IRM" << std::endl; + IRM* irm; + // Load + if (path_clusters.empty()) { + irm = new IRM(schema); + std::cout << "incorporating observations" << std::endl; + incorporate_observations(&prng, *irm, encoding, observations); + } else { + irm = new IRM({}); + std::cout << "loading clusters from " << path_clusters << std::endl; + from_txt(&prng, irm, path_schema, path_obs, path_clusters); + } + // Infer + std::cout << "inferring " << iters << " iters; timeout " << timeout + << std::endl; + inference_irm(&prng, irm, iters, timeout, verbose); + // Save + path_save += ".irm"; + std::cout << "saving to " << path_save << std::endl; + to_txt(path_save, *irm, encoding); + // Free + free(irm); + return 0; + } + + if (mode == "hirm") { + std::cout << "selected model is HIRM" << std::endl; + HIRM* hirm; + // Load + if (path_clusters.empty()) { + hirm = new HIRM(schema, &prng); + std::cout << "incorporating observations" << std::endl; + incorporate_observations(&prng, *hirm, encoding, observations); + } else { + hirm = new HIRM({}, &prng); + std::cout << "loading clusters from " << path_clusters << std::endl; + from_txt(&prng, hirm, path_schema, path_obs, path_clusters); + } + // Infer + std::cout << "inferring " << iters << " iters; timeout " << timeout + << std::endl; + inference_hirm(&prng, hirm, iters, timeout, verbose); + // Save + path_save += ".hirm"; + std::cout << "saving to " << path_save << std::endl; + to_txt(path_save, *hirm, encoding); + // Free + free(hirm); + return 0; + } +} diff --git a/cxx/irm.cc b/cxx/irm.cc new file mode 100644 index 0000000..5e60606 --- /dev/null +++ b/cxx/irm.cc @@ -0,0 +1,321 @@ +// Copyright 2021 MIT Probabilistic Computing Project +// Apache License, Version 2.0, refer to LICENSE.txt + +#include "irm.hh" + +#include +#include +#include +#include +#include + +IRM::IRM(const T_schema& schema) { + for (const auto& [name, relation] : schema) { + this->add_relation(name, relation); + } +} + +IRM::~IRM() { + for (auto [d, domain] : domains) { + delete domain; + } + for (auto [r, relation] : relations) { + std::visit([](auto rel) { delete rel; }, relation); + } +} + +void IRM::incorporate(std::mt19937* prng, const std::string& r, + const T_items& items, ObservationVariant value) { + std::visit( + [&](auto rel) { + auto v = std::get< + typename std::remove_reference_t::ValueType>( + value); + rel->incorporate(prng, items, v); + }, + relations.at(r)); +} + +void IRM::unincorporate(const std::string& r, const T_items& items) { + std::visit([&](auto rel) { rel->unincorporate(items); }, relations.at(r)); +} + +void IRM::transition_cluster_assignments_all(std::mt19937* prng) { + for (const auto& [d, domain] : domains) { + for (const T_item item : domain->items) { + transition_cluster_assignment_item(prng, d, item); + } + } +} + +void IRM::transition_cluster_assignments(std::mt19937* prng, + const std::vector& ds) { + for (const std::string& d : ds) { + for (const T_item item : domains.at(d)->items) { + transition_cluster_assignment_item(prng, d, item); + } + } +} + +void IRM::transition_cluster_assignment_item(std::mt19937* prng, + const std::string& d, + const T_item& item) { + Domain* domain = domains.at(d); + auto crp_dist = domain->tables_weights_gibbs(item); + // Compute probability of each table. + std::vector tables; + std::vector logps; + tables.reserve(crp_dist.size()); + logps.reserve(crp_dist.size()); + for (const auto& [table, n_customers] : crp_dist) { + tables.push_back(table); + logps.push_back(log(n_customers)); + } + auto accumulate_logps = [&](auto rel) { + if (rel->has_observation(*domain, item)) { + std::vector lp_relation = + rel->logp_gibbs_exact(*domain, item, tables); + assert(lp_relation.size() == tables.size()); + assert(lp_relation.size() == logps.size()); + for (int i = 0; i < std::ssize(logps); ++i) { + logps[i] += lp_relation[i]; + } + } + }; + for (const auto& r : domain_to_relations.at(d)) { + std::visit(accumulate_logps, relations.at(r)); + } + // Sample new table. + assert(tables.size() == logps.size()); + int idx = log_choice(logps, prng); + T_item choice = tables[idx]; + // Move to new table (if necessary). + if (choice != domain->get_cluster_assignment(item)) { + auto set_cluster_r = [&](auto rel) { + if (rel->has_observation(*domain, item)) { + rel->set_cluster_assignment_gibbs(*domain, item, choice); + } + }; + for (const std::string& r : domain_to_relations.at(d)) { + std::visit(set_cluster_r, relations.at(r)); + } + domain->set_cluster_assignment_gibbs(item, choice); + } +} + +double IRM::logp( + const std::vector>& + observations) { + std::unordered_map> + relation_items_seen; + std::unordered_map> + domain_item_seen; + std::vector> item_universe; + std::vector> index_universe; + std::vector> weight_universe; + std::unordered_map< + std::string, + std::unordered_map>>> + cluster_universe; + // Compute all cluster combinations. + for (const auto& [r, items, value] : observations) { + // Assert observation is unique. + assert(!relation_items_seen[r].contains(items)); + relation_items_seen[r].insert(items); + // Process each (domain, item) in the observations. + RelationVariant relation = relations.at(r); + int arity = + std::visit([](auto rel) { return rel->domains.size(); }, relation); + assert(std::ssize(items) == arity); + for (int i = 0; i < arity; ++i) { + // Skip if (domain, item) processed. + Domain* domain = + std::visit([&](auto rel) { return rel->domains.at(i); }, relation); + T_item item = items.at(i); + if (domain_item_seen[domain->name].contains(item)) { + assert(cluster_universe[domain->name].contains(item)); + continue; + } + domain_item_seen[domain->name].insert(item); + // Obtain tables, weights, indexes for this item. + std::vector t_list; + std::vector w_list; + std::vector i_list; + size_t n_tables = domain->tables_weights().size() + 1; + t_list.reserve(n_tables); + w_list.reserve(n_tables); + i_list.reserve(n_tables); + if (domain->items.contains(item)) { + int z = domain->get_cluster_assignment(item); + t_list = {z}; + w_list = {0.0}; + i_list = {0}; + } else { + auto tables_weights = domain->tables_weights(); + double Z = log(domain->crp.alpha + domain->crp.N); + size_t idx = 0; + for (const auto& [t, w] : tables_weights) { + t_list.push_back(t); + w_list.push_back(log(w) - Z); + i_list.push_back(idx++); + } + assert(idx == t_list.size()); + } + // Add to universe. + item_universe.push_back({domain->name, item}); + index_universe.push_back(i_list); + weight_universe.push_back(w_list); + int loc = index_universe.size() - 1; + cluster_universe[domain->name][item] = {loc, t_list}; + } + } + assert(item_universe.size() == index_universe.size()); + assert(item_universe.size() == weight_universe.size()); + // Compute data probability given cluster combinations. + std::vector items_product = product(index_universe); + std::vector logps; // reserve size + logps.reserve(index_universe.size()); + for (const T_items& indexes : items_product) { + double logp_indexes = 0; + // Compute weight of cluster assignments. + double weight = 0.0; + for (int i = 0; i < std::ssize(indexes); ++i) { + weight += weight_universe.at(i).at(indexes[i]); + } + logp_indexes += weight; + // Compute weight of data given cluster assignments. + auto f_logp = [&](auto rel, const T_items& items, + const ObservationVariant& value) -> double { + std::vector z; + z.reserve(domains.size()); + for (int i = 0; i < std::ssize(rel->domains); ++i) { + Domain* domain = rel->domains.at(i); + T_item item = items.at(i); + auto& [loc, t_list] = cluster_universe.at(domain->name).at(item); + T_item t = t_list.at(indexes.at(loc)); + z.push_back(t); + } + auto v = std::get< + typename std::remove_reference_t::ValueType>(value); + auto prior = + std::get::DType*>( + cluster_prior_from_spec(rel->dist_spec)); + return rel->clusters.contains(z) ? rel->clusters.at(z)->logp(v) + : prior->logp(v); + }; + for (const auto& [r, items, value] : observations) { + auto g = std::bind(f_logp, std::placeholders::_1, items, value); + double logp_obs = std::visit(g, relations.at(r)); + logp_indexes += logp_obs; + }; + logps.push_back(logp_indexes); + } + return logsumexp(logps); +} + +double IRM::logp_score() const { + double logp_score_crp = 0.0; + for (const auto& [d, domain] : domains) { + logp_score_crp += domain->crp.logp_score(); + } + double logp_score_relation = 0.0; + for (const auto& [r, relation] : relations) { + double logp_rel = + std::visit([](auto rel) { return rel->logp_score(); }, relation); + logp_score_relation += logp_rel; + } + return logp_score_crp + logp_score_relation; +} + +void IRM::add_relation(const std::string& name, const T_relation& relation) { + assert(!schema.contains(name)); + assert(!relations.contains(name)); + std::vector doms; + for (const auto& d : relation.domains) { + if (domains.count(d) == 0) { + assert(domain_to_relations.count(d) == 0); + domains[d] = new Domain(d); + domain_to_relations[d] = std::unordered_set(); + } + domain_to_relations.at(d).insert(name); + doms.push_back(domains.at(d)); + } + relations[name] = + relation_from_spec(name, relation.distribution_spec, doms); + schema[name] = relation; +} + +void IRM::remove_relation(const std::string& name) { + std::unordered_set ds; + auto rel_domains = + std::visit([](auto r) { return r->domains; }, relations.at(name)); + for (const Domain* const domain : rel_domains) { + ds.insert(domain->name); + } + for (const auto& d : ds) { + domain_to_relations.at(d).erase(name); + // TODO: Remove r from domains.at(d)->items + if (domain_to_relations.at(d).empty()) { + domain_to_relations.erase(d); + delete domains.at(d); + domains.erase(d); + } + } + std::visit([](auto r) { delete r; }, relations.at(name)); + relations.erase(name); + schema.erase(name); +} + + +#define GET_ELAPSED(t) double(clock() - t) / CLOCKS_PER_SEC + +// TODO(emilyaf): Refactor as a function for readibility/maintainability. +#define CHECK_TIMEOUT(timeout, t_begin) \ + if (timeout) { \ + auto elapsed = GET_ELAPSED(t_begin); \ + if (timeout < elapsed) { \ + printf("timeout after %1.2fs \n", elapsed); \ + break; \ + } \ + } + +// TODO(emilyaf): Refactor as a function for readibility/maintainability. +#define REPORT_SCORE(var_verbose, var_t, var_t_total, var_model) \ + if (var_verbose) { \ + auto t_delta = GET_ELAPSED(var_t); \ + var_t_total += t_delta; \ + double x = var_model->logp_score(); \ + printf("%f %f\n", var_t_total, x); \ + fflush(stdout); \ + } + +void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total, + bool verbose) { + // TRANSITION ASSIGNMENTS. + for (const auto& [d, domain] : irm->domains) { + for (const auto item : domain->items) { + clock_t t = clock(); + irm->transition_cluster_assignment_item(prng, d, item); + REPORT_SCORE(verbose, t, t_total, irm); + } + } + // TRANSITION DISTRIBUTION HYPERPARAMETERS. + for (const auto& [r, relation] : irm->relations) { + std::visit( + [&](auto r) { + for (const auto& [c, distribution] : r->clusters) { + clock_t t = clock(); + distribution->transition_hyperparameters(prng); + REPORT_SCORE(verbose, t, t_total, irm); + } + }, + relation); + } + // TRANSITION ALPHA. + for (const auto& [d, domain] : irm->domains) { + clock_t t = clock(); + domain->crp.transition_alpha(prng); + REPORT_SCORE(verbose, t, t_total, irm); + } +} + diff --git a/cxx/irm.hh b/cxx/irm.hh new file mode 100644 index 0000000..56e99d3 --- /dev/null +++ b/cxx/irm.hh @@ -0,0 +1,60 @@ +// Copyright 2020 +// See LICENSE.txt + +#pragma once +#include +#include +#include + +#include "relation.hh" +#include "relation_variant.hh" +#include "util_distribution_variant.hh" + +// Map from names to T_relation's. +typedef std::map T_schema; + +class IRM { + public: + T_schema schema; // schema of relations + std::unordered_map domains; // map from name to Domain + std::unordered_map + relations; // map from name to Relation + std::unordered_map> + domain_to_relations; // reverse map + + IRM(const T_schema& schema); + + ~IRM(); + + void incorporate(std::mt19937* prng, const std::string& r, + const T_items& items, ObservationVariant value); + + void unincorporate(const std::string& r, const T_items& items); + + void transition_cluster_assignments_all(std::mt19937* prng); + + void transition_cluster_assignments(std::mt19937* prng, + const std::vector& ds); + + void transition_cluster_assignment_item(std::mt19937* prng, + const std::string& d, + const T_item& item); + double logp( + const std::vector>& + observations); + + double logp_score() const; + + void add_relation(const std::string& name, const T_relation& relation); + + void remove_relation(const std::string& name); + + // Disable copying. + IRM& operator=(const IRM&) = delete; + IRM(const IRM&) = delete; +}; + + +// Run a single step of inference on an IRM model. +void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total, + bool verbose); diff --git a/cxx/irm_test.cc b/cxx/irm_test.cc new file mode 100644 index 0000000..64e0b93 --- /dev/null +++ b/cxx/irm_test.cc @@ -0,0 +1,47 @@ +// Apache License, Version 2.0, refer to LICENSE.txt + +#define BOOST_TEST_MODULE test IRM + +#include "irm.hh" +#include "util_distribution_variant.hh" + +#include +namespace tt = boost::test_tools; + +BOOST_AUTO_TEST_CASE(test_irm) { + std::map schema1{ + {"R1", T_relation{{"D1", "D1"}, DistributionSpec {DistributionEnum::bernoulli}}}, + {"R2", T_relation{{"D1", "D2"}, DistributionSpec {DistributionEnum::normal}}}, + {"R3", T_relation{{"D3", "D1"}, DistributionSpec {DistributionEnum::bigram}}} + }; + IRM irm(schema1); + + BOOST_TEST(irm.logp_score() == 0.0); + + std::mt19937 prng; + irm.transition_cluster_assignments_all(&prng); + BOOST_TEST(irm.logp_score() == 0.0); + + irm.remove_relation("R3"); + + auto obs0 = observation_string_to_value("0", DistributionEnum::bernoulli); + + double logp_x = irm.logp({{"R1", {1, 2}, obs0}}); + + irm.incorporate(&prng, "R1", {1, 2}, obs0); + double one_obs_score = irm.logp_score(); + BOOST_TEST(one_obs_score < 0.0); + + // TODO(thomaswc): Figure out why the next test doesn't pass. + // BOOST_TEST(one_obs_score == logp_x); + + // Transitioning clusters shouldn't change the score with only one + // observation. + irm.transition_cluster_assignments_all(&prng); + BOOST_TEST(irm.logp_score() == one_obs_score); + + // TODO(thomaswc): Uncomment below when relation::unincorporate is + // implemented. + // irm.unincorporate("R1", {1, 2}); + // BOOST_TEST(irm.logp_score() == 0.0); +} diff --git a/cxx/relation_test.cc b/cxx/relation_test.cc index 55da7a1..c1ff13d 100644 --- a/cxx/relation_test.cc +++ b/cxx/relation_test.cc @@ -42,7 +42,8 @@ BOOST_AUTO_TEST_CASE(test_relation) { BOOST_TEST(z2[1] == 191); BOOST_TEST(z2[2] == 0); - double lpg = R1.logp_gibbs_approx(D1, 0, 1); + double lpg __attribute__ ((unused)); + lpg = R1.logp_gibbs_approx(D1, 0, 1); lpg = R1.logp_gibbs_approx(D1, 0, 0); lpg = R1.logp_gibbs_approx(D1, 0, 10); R1.set_cluster_assignment_gibbs(D1, 0, 1); @@ -57,4 +58,4 @@ BOOST_AUTO_TEST_CASE(test_relation) { lpg = R2.logp_gibbs_approx(D2, 2, 0); R2.set_cluster_assignment_gibbs(D3, 3, 1); D1.set_cluster_assignment_gibbs(0, 1); -} \ No newline at end of file +} diff --git a/cxx/relation_variant.cc b/cxx/relation_variant.cc new file mode 100644 index 0000000..8595709 --- /dev/null +++ b/cxx/relation_variant.cc @@ -0,0 +1,30 @@ +// Copyright 2024 +// See LICENSE.txt + +#include "relation_variant.hh" + +#include + +#include "distributions/beta_bernoulli.hh" +#include "distributions/bigram.hh" +#include "distributions/dirichlet_categorical.hh" +#include "distributions/normal.hh" +#include "domain.hh" +#include "relation.hh" + +RelationVariant relation_from_spec(const std::string& name, + const DistributionSpec& dist_spec, + std::vector& domains) { + switch (dist_spec.distribution) { + case DistributionEnum::bernoulli: + return new Relation(name, dist_spec, domains); + case DistributionEnum::bigram: + return new Relation(name, dist_spec, domains); + case DistributionEnum::categorical: + return new Relation(name, dist_spec, domains); + case DistributionEnum::normal: + return new Relation(name, dist_spec, domains); + default: + assert(false && "Unsupported distribution enum value."); + } +} diff --git a/cxx/relation_variant.hh b/cxx/relation_variant.hh new file mode 100644 index 0000000..080e78e --- /dev/null +++ b/cxx/relation_variant.hh @@ -0,0 +1,26 @@ +// Copyright 2024 +// See LICENSE.txt + +#pragma once + +#include +#include +#include + +#include "util_distribution_variant.hh" + +class BetaBernoulli; +class Bigram; +class DirichletCategorical; +class Normal; +class Domain; +template +class Relation; + +using RelationVariant = + std::variant*, Relation*, + Relation*, Relation*>; + +RelationVariant relation_from_spec(const std::string& name, + const DistributionSpec& dist_spec, + std::vector& domains); diff --git a/cxx/tests/BUILD b/cxx/tests/BUILD index 7c67b87..caa4768 100644 --- a/cxx/tests/BUILD +++ b/cxx/tests/BUILD @@ -2,7 +2,7 @@ cc_binary( name = "test_hirm_animals", srcs = ["test_hirm_animals.cc"], deps = [ - "//:headers", + "//:hirm_lib", "//:util_distribution_variant", "//:util_io", "//distributions", @@ -13,7 +13,7 @@ cc_binary( name = "test_irm_two_relations", srcs = ["test_irm_two_relations.cc"], deps = [ - "//:headers", + "//:irm", "//:util_distribution_variant", "//:util_io", "//distributions", @@ -24,9 +24,10 @@ cc_binary( name = "test_misc", srcs = ["test_misc.cc"], deps = [ - "//:headers", + "//:hirm_lib", + "//:irm", "//:util_distribution_variant", "//:util_io", "//distributions", ], -) \ No newline at end of file +) diff --git a/cxx/tests/test_irm_two_relations.cc b/cxx/tests/test_irm_two_relations.cc index e3e8509..eb2033b 100644 --- a/cxx/tests/test_irm_two_relations.cc +++ b/cxx/tests/test_irm_two_relations.cc @@ -11,7 +11,7 @@ #include #include -#include "hirm.hh" +#include "irm.hh" #include "util_io.hh" #include "util_math.hh" @@ -44,13 +44,10 @@ int main(int argc, char** argv) { IRM irm(schema); incorporate_observations(&prng, irm, encoding, observations); printf("running for %d iterations\n", iters); + double t_total = 0.0; for (int i = 0; i < iters; i++) { - irm.transition_cluster_assignments_all(&prng); - for (auto const& [d, domain] : irm.domains) { - domain->crp.transition_alpha(&prng); - } - double x = irm.logp_score(); - printf("iter %d, score %f\n", i, x); + single_step_irm_inference(&prng, &irm, t_total, true); + printf("iter %d, score %f\n", i, irm.logp_score()); } std::string path_clusters = path_base + ".irm"; @@ -110,6 +107,18 @@ int main(int argc, char** argv) { auto dx = irx.domains.at(d); dx->crp.alpha = dm->crp.alpha; } + // They shouldn't agree yet because irx's hyperparameters haven't been + // transitioned. + assert(abs(irx.logp_score() - irm.logp_score()) > 1e-8); + for (const auto& r : {"R1", "R2"}) { + auto r1m = std::get*>(irm.relations.at(r)); + auto r1x = std::get*>(irx.relations.at(r)); + for (const auto& [c, distribution] : r1m->clusters) { + auto dx = r1x->clusters.at(c); + dx->alpha = distribution->alpha; + dx->beta = distribution->beta; + } + } assert(abs(irx.logp_score() - irm.logp_score()) < 1e-8); // Check domains agree. for (const auto& d : {"D1", "D2"}) { diff --git a/cxx/util_distribution_variant.cc b/cxx/util_distribution_variant.cc index f3b4ec2..337cd3b 100644 --- a/cxx/util_distribution_variant.cc +++ b/cxx/util_distribution_variant.cc @@ -11,8 +11,6 @@ #include "distributions/crp.hh" #include "distributions/dirichlet_categorical.hh" #include "distributions/normal.hh" -#include "domain.hh" -#include "relation.hh" ObservationVariant observation_string_to_value( const std::string& value_str, const DistributionEnum& distribution) { @@ -77,20 +75,3 @@ DistributionVariant cluster_prior_from_spec(const DistributionSpec& spec) { assert(false && "Unsupported distribution enum value."); } } - -RelationVariant relation_from_spec(const std::string& name, - const DistributionSpec& dist_spec, - std::vector& domains) { - switch (dist_spec.distribution) { - case DistributionEnum::bernoulli: - return new Relation(name, dist_spec, domains); - case DistributionEnum::bigram: - return new Relation(name, dist_spec, domains); - case DistributionEnum::categorical: - return new Relation(name, dist_spec, domains); - case DistributionEnum::normal: - return new Relation(name, dist_spec, domains); - default: - assert(false && "Unsupported distribution enum value."); - } -} diff --git a/cxx/util_distribution_variant.hh b/cxx/util_distribution_variant.hh index 5acae20..2a0d14c 100644 --- a/cxx/util_distribution_variant.hh +++ b/cxx/util_distribution_variant.hh @@ -1,8 +1,9 @@ // Copyright 2024 // See LICENSE.txt -// This file collects classes/functions that depend on the set of distribution -// subclasses and should be updated when a new subclass is added. +// Classes and functions for dealing with Distributions and their values in a +// generic manner. When a new subclass is added, this file and +// relation_variant.{hh,cc} will need to be updated. #pragma once @@ -11,6 +12,11 @@ #include #include +class BetaBernoulli; +class Bigram; +class DirichletCategorical; +class Normal; + enum class DistributionEnum { bernoulli, bigram, categorical, normal }; struct DistributionSpec { @@ -18,22 +24,11 @@ struct DistributionSpec { std::map distribution_args = {}; }; -class BetaBernoulli; -class Bigram; -class DirichletCategorical; -class Normal; -class Domain; -template -class Relation; - // Set of all distribution sample types. using ObservationVariant = std::variant; using DistributionVariant = std::variant; -using RelationVariant = - std::variant*, Relation*, - Relation*, Relation*>; ObservationVariant observation_string_to_value( const std::string& value_str, const DistributionEnum& distribution); @@ -41,7 +36,3 @@ ObservationVariant observation_string_to_value( DistributionSpec parse_distribution_spec(const std::string& dist_str); DistributionVariant cluster_prior_from_spec(const DistributionSpec& spec); - -RelationVariant relation_from_spec(const std::string& name, - const DistributionSpec& dist_spec, - std::vector& domains);