diff --git a/cxx/BUILD b/cxx/BUILD index 811cc71..f6ca671 100644 --- a/cxx/BUILD +++ b/cxx/BUILD @@ -30,6 +30,7 @@ cc_binary( deps = [ ":cxxopts", ":headers", + ":util_distribution_variant", ":util_hash", ":util_io", ":util_math", @@ -42,12 +43,24 @@ cc_library( hdrs = ["relation.hh"], deps = [ ":domain", + ":util_distribution_variant", ":util_hash", ":util_math", "//distributions:base" ], ) +cc_library( + name = "util_distribution_variant", + srcs = ["util_distribution_variant.cc"], + visibility = [":__subpackages__"], + hdrs = ["util_distribution_variant.hh"], + deps = [ + "//distributions", + ], +) + + cc_library( name = "util_hash", hdrs = ["util_hash.hh"], @@ -96,6 +109,16 @@ cc_test( ], ) +cc_test( + name = "util_distribution_variant_test", + srcs = ["util_distribution_variant_test.cc"], + deps = [ + ":util_distribution_variant", + "@boost//:algorithm", + "@boost//:test", + ], +) + cc_test( name = "util_math_test", srcs = ["util_math_test.cc"], diff --git a/cxx/distributions/BUILD b/cxx/distributions/BUILD index 9841f15..a1a3fe7 100644 --- a/cxx/distributions/BUILD +++ b/cxx/distributions/BUILD @@ -4,7 +4,6 @@ cc_library( name = "distributions", visibility = ["//:__subpackages__"], deps = [ - ":adapter", ":base", ":beta_bernoulli", ":bigram", @@ -14,15 +13,6 @@ cc_library( ], ) -cc_library( - name = "adapter", - hdrs = ["adapter.hh"], - visibility = ["//:__subpackages__"], - deps = [ - ":base", - ], -) - cc_library( name = "base", hdrs = ["base.hh"], @@ -87,17 +77,6 @@ cc_library( ], ) -cc_test( - name = "adapter_test", - srcs = ["adapter_test.cc"], - deps = [ - ":adapter", - ":normal", - "@boost//:algorithm", - "@boost//:test", - ], -) - cc_test( name = "beta_bernoulli_test", srcs = ["beta_bernoulli_test.cc"], diff --git a/cxx/distributions/adapter.hh b/cxx/distributions/adapter.hh deleted file mode 100644 index b679990..0000000 --- a/cxx/distributions/adapter.hh +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2024 -// See LICENSE.txt - -// A class for turning Distribution's into -// Distribution's. - -#pragma once -#include -#include -#include - -#include "distributions/base.hh" - -template -class DistributionAdapter : public Distribution { - public: - // The underlying distribution that is being adapted. We own the - // underlying Distribution. - Distribution* d; - - DistributionAdapter(Distribution* dd) : d(dd) {}; - - S from_string(const std::string& x) const { - S s; - std::istringstream(x) >> s; - return s; - } - - std::string to_string(const S& s) const { - std::ostringstream os; - os << s; - return os.str(); - } - - void incorporate(const std::string& x) { - S s = from_string(x); - ++N; - d->incorporate(s); - } - - void unincorporate(const std::string& x) { - S s = from_string(x); - --N; - d->unincorporate(s); - } - - double logp(const std::string& x) const { - S s = from_string(x); - return d->logp(s); - } - - double logp_score() const { return d->logp_score(); } - - std::string sample() { - S s = d->sample(); - return to_string(s); - } - - void transition_hyperparameters() { d->transition_hyperparameters(); } - - ~DistributionAdapter() { delete d; } -}; diff --git a/cxx/distributions/adapter_test.cc b/cxx/distributions/adapter_test.cc deleted file mode 100644 index f86ce73..0000000 --- a/cxx/distributions/adapter_test.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 -// Refer to LICENSE.txt - -#define BOOST_TEST_MODULE test Normal - -#include "distributions/adapter.hh" - -#include - -#include "distributions/normal.hh" -namespace tt = boost::test_tools; - -BOOST_AUTO_TEST_CASE(adapt_normal) { - std::mt19937 prng; - Normal* n = new Normal(&prng); - DistributionAdapter ad(n); - - ad.incorporate("5.0"); - ad.incorporate("-2.0"); - BOOST_TEST(n->N == ad.N); - - ad.unincorporate("5.0"); - ad.incorporate("7.0"); - BOOST_TEST(n->N == ad.N); - - ad.unincorporate("-2.0"); - BOOST_TEST(n->N == ad.N); - - BOOST_TEST(ad.logp("6.0") == n->logp(6.), tt::tolerance(1e-6)); - BOOST_TEST(ad.logp_score() == n->logp_score(), tt::tolerance(1e-6)); - - std::string samp = ad.sample(); -} diff --git a/cxx/distributions/base.hh b/cxx/distributions/base.hh index 6423714..bd92dd3 100644 --- a/cxx/distributions/base.hh +++ b/cxx/distributions/base.hh @@ -3,6 +3,8 @@ template class Distribution { // Abstract base class for probability distributions in HIRM. + // New distribution subclasses need to be added to + // `util_distribution_variant` to be used in the (H)IRM models. public: typedef T SampleType; // N is the number of incorporated observations. diff --git a/cxx/distributions/beta_bernoulli.hh b/cxx/distributions/beta_bernoulli.hh index d0b0ce7..45bef74 100644 --- a/cxx/distributions/beta_bernoulli.hh +++ b/cxx/distributions/beta_bernoulli.hh @@ -7,6 +7,8 @@ #include "distributions/base.hh" #include "util_math.hh" +// TODO(thomaswc, emilyaf): Change BetaBernoulli to use bool instead of +// double. class BetaBernoulli : public Distribution { public: double alpha = 1; // hyperparameter diff --git a/cxx/distributions/crp.hh b/cxx/distributions/crp.hh index 2fc9a56..f3e60f8 100644 --- a/cxx/distributions/crp.hh +++ b/cxx/distributions/crp.hh @@ -6,14 +6,13 @@ #include #include - typedef int T_item; // TODO(emilyaf): Make this a distribution subclass. class CRP { public: double alpha = 1.; // concentration parameter - int N = 0; // number of customers + int N = 0; // number of customers std::unordered_map> tables; // map from table id to set of customers std::unordered_map assignments; // map from customer to table id diff --git a/cxx/hirm.hh b/cxx/hirm.hh index 083f910..f86d324 100644 --- a/cxx/hirm.hh +++ b/cxx/hirm.hh @@ -8,25 +8,12 @@ #include #include -#include "distributions/base.hh" -#include "distributions/beta_bernoulli.hh" -#include "distributions/bigram.hh" -#include "distributions/crp.hh" -#include "distributions/dirichlet_categorical.hh" -#include "distributions/normal.hh" #include "relation.hh" - +#include "util_distribution_variant.hh" // Map from names to T_relation's. typedef std::map T_schema; -using ObservationVariant = std::variant; - -using RelationVariant = - std::variant*, Relation*, - // Relation*, - Relation*>; - class IRM { public: T_schema schema; // schema of relations @@ -222,11 +209,13 @@ class IRM { T_item t = t_list.at(indexes.at(loc)); z.push_back(t); } - typename std::remove_reference_t::DType aux(prng); - auto cluster = rel->clusters.contains(z) ? rel->clusters.at(z) : &aux; auto v = std::get< typename std::remove_reference_t::ValueType>(value); - return cluster->logp(v); + auto prior = + std::get::DType*>( + cluster_prior_from_spec(rel->dist_spec, prng)); + return rel->clusters.contains(z) ? rel->clusters.at(z)->logp(v) + : prior->logp(v); }; for (const auto& [r, items, value] : observations) { auto g = std::bind(f_logp, std::placeholders::_1, items, value); @@ -265,21 +254,8 @@ class IRM { domain_to_relations.at(d).insert(name); doms.push_back(domains.at(d)); } - if (relation.distribution == "normal") { - relations[name] = - new Relation(name, relation.distribution, doms, prng); - } else if (relation.distribution == "bernoulli") { - relations[name] = - new Relation(name, relation.distribution, doms, prng); - } else if (relation.distribution == "bigram") { - relations[name] = - new Relation(name, relation.distribution, doms, prng); - } else { - assert(false); - } - // TODO(emilyaf): Enable Categorical. Maybe have "categorical5" e.g. for 5 - // categories. relations[name] = new Relation(name, - // relation.distribution, doms, prng); + relations[name] = + relation_from_spec(name, relation.distribution_spec, doms, prng); schema[name] = relation; } @@ -327,7 +303,7 @@ class HIRM { } } - void incorporate(const std::string& r, const T_items& items, + void incorporate(const std::string& r, const T_items& items, const ObservationVariant& value) { IRM* irm = relation_to_irm(r); irm->incorporate(r, items, value); @@ -366,7 +342,7 @@ class HIRM { int table_current = crp.assignments.at(rc); RelationVariant relation = get_relation(r); T_relation t_relation = - std::visit([](auto rel) { return rel->get_T_relation(); }, relation); + std::visit([](auto rel) { return rel->trel; }, relation); auto crp_dist = crp.tables_weights_gibbs(table_current); std::vector tables; std::vector logps; @@ -443,7 +419,7 @@ class HIRM { int table_current = crp.assignments.at(rc); RelationVariant relation = get_relation(r); auto f_obs = [&](auto rel) { - T_relation trel = rel->get_T_relation(); + T_relation trel = rel->trel; IRM* irm = relation_to_irm(r); auto observations = rel->data; // Remove from current IRM. diff --git a/cxx/relation.hh b/cxx/relation.hh index 4127bd2..128d29c 100644 --- a/cxx/relation.hh +++ b/cxx/relation.hh @@ -3,12 +3,19 @@ #pragma once +#include #include #include #include #include "distributions/base.hh" +#include "distributions/beta_bernoulli.hh" +#include "distributions/bigram.hh" +#include "distributions/crp.hh" +#include "distributions/dirichlet_categorical.hh" +#include "distributions/normal.hh" #include "domain.hh" +#include "util_distribution_variant.hh" #include "util_hash.hh" #include "util_math.hh" @@ -23,8 +30,7 @@ class T_relation { // is a distribution over. std::vector domains; - // Must be the name of a distribution in distributions/. - std::string distribution; + DistributionSpec distribution_spec; }; template @@ -32,12 +38,14 @@ class Relation { public: using ValueType = typename DistributionType::SampleType; using DType = DistributionType; - static_assert(std::is_base_of, DType>::value, + static_assert(std::is_base_of, DType>::value, "DistributionType must inherit from Distribution."); // human-readable name const std::string name; - // Distribution over the relation's codomain. - const std::string distribution; + // Relation spec. + T_relation trel; + // Distribution spec over the relation's codomain. + const DistributionSpec dist_spec; // list of domain pointers const std::vector domains; // map from cluster multi-index to Distribution pointer @@ -53,9 +61,9 @@ class Relation { data_r; std::mt19937* prng; - Relation(const std::string& name, const std::string& distribution, + Relation(const std::string& name, const DistributionSpec& dist_spec, const std::vector& domains, std::mt19937* prng) - : name(name), distribution(distribution), domains(domains) { + : name(name), dist_spec(dist_spec), domains(domains) { assert(!domains.empty()); assert(!name.empty()); this->prng = prng; @@ -63,6 +71,11 @@ class Relation { this->data_r[d->name] = std::unordered_map>(); } + std::vector domain_names; + for (const auto& d : domains) { + domain_names.push_back(d->name); + } + trel = {domain_names, dist_spec}; } ~Relation() { @@ -71,15 +84,6 @@ class Relation { } } - T_relation get_T_relation() { - T_relation trel; - trel.distribution = distribution; - for (const auto& d : domains) { - trel.domains.push_back(d->name); - } - return trel; - } - void incorporate(const T_items& items, ValueType value) { assert(!data.contains(items)); data[items] = value; @@ -97,7 +101,8 @@ class Relation { // Cannot use clusters[z] because BetaBernoulli // does not have a default constructor, whereas operator[] // calls default constructor when the key does not exist. - clusters[z] = new DistributionType(prng); + clusters[z] = + std::get(cluster_prior_from_spec(dist_spec, prng)); } clusters.at(z)->incorporate(value); } @@ -236,9 +241,9 @@ class Relation { T_items z = get_cluster_assignment_gibbs(items_list[0], domain, item, table); - DistributionType aux(prng); - DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : &aux; - // auto cluster = self.clusters.get(z, self.aux()) + DistributionType* prior = + std::get(cluster_prior_from_spec(dist_spec, prng)); + DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : prior; double logp0 = cluster->logp_score(); for (const T_items& items : items_list) { // assert(z == get_cluster_assignment_gibbs(items, domain, item, table)); @@ -258,7 +263,7 @@ class Relation { std::vector tables) { auto cluster_to_items_list = get_cluster_to_items_list(domain, item); int table_current = domain.get_cluster_assignment(item); - std::vector logps; // size this? + std::vector logps; logps.reserve(tables.size()); double lp_cluster; for (const int& table : tables) { @@ -320,8 +325,9 @@ class Relation { z.push_back(zi); logp_w += wi; } - DistributionType aux(prng); - DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : &aux; + DistributionType* prior = + std::get(cluster_prior_from_spec(dist_spec, prng)); + DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : prior; double logp_z = cluster->logp(value); double logp_zw = logp_z + logp_w; logps.push_back(logp_zw); @@ -355,7 +361,8 @@ class Relation { T_items z_new = get_cluster_assignment_gibbs(items, domain, item, table); if (!clusters.contains(z_new)) { // Move to fresh cluster. - clusters[z_new] = new DistributionType(prng); + clusters[z_new] = std::get( + cluster_prior_from_spec(dist_spec, prng)); clusters.at(z_new)->incorporate(x); } else { // Move to existing cluster. diff --git a/cxx/relation_test.cc b/cxx/relation_test.cc index dbd8512..9fcbd1e 100644 --- a/cxx/relation_test.cc +++ b/cxx/relation_test.cc @@ -21,7 +21,8 @@ BOOST_AUTO_TEST_CASE(test_relation) { D1.incorporate(0); D2.incorporate(1); D3.incorporate(3); - Relation R1("R1", "bernoulli", {&D1, &D2, &D3}, &prng); + DistributionSpec spec = DistributionSpec{DistributionEnum::bernoulli}; + Relation R1("R1", spec, {&D1, &D2, &D3}, &prng); R1.incorporate({0, 1, 3}, 1); R1.incorporate({1, 1, 3}, 1); R1.incorporate({3, 1, 3}, 1); @@ -46,7 +47,8 @@ BOOST_AUTO_TEST_CASE(test_relation) { lpg = R1.logp_gibbs_approx(D1, 0, 10); R1.set_cluster_assignment_gibbs(D1, 0, 1); - Relation R2("R1", "bigram", {&D2, &D3}, &prng); + DistributionSpec bigram_spec = DistributionSpec{DistributionEnum::bigram}; + Relation R2("R1", bigram_spec, {&D2, &D3}, &prng); R2.incorporate({1, 3}, "cat"); R2.incorporate({1, 2}, "dog"); R2.incorporate({1, 4}, "catt"); @@ -55,5 +57,4 @@ BOOST_AUTO_TEST_CASE(test_relation) { lpg = R2.logp_gibbs_approx(D2, 2, 0); R2.set_cluster_assignment_gibbs(D3, 3, 1); D1.set_cluster_assignment_gibbs(0, 1); - } \ No newline at end of file diff --git a/cxx/tests/BUILD b/cxx/tests/BUILD index c93b3d1..7c67b87 100644 --- a/cxx/tests/BUILD +++ b/cxx/tests/BUILD @@ -3,6 +3,7 @@ cc_binary( srcs = ["test_hirm_animals.cc"], deps = [ "//:headers", + "//:util_distribution_variant", "//:util_io", "//distributions", ], @@ -13,6 +14,7 @@ cc_binary( srcs = ["test_irm_two_relations.cc"], deps = [ "//:headers", + "//:util_distribution_variant", "//:util_io", "//distributions", ], @@ -23,6 +25,7 @@ cc_binary( srcs = ["test_misc.cc"], deps = [ "//:headers", + "//:util_distribution_variant", "//:util_io", "//distributions", ], diff --git a/cxx/tests/test_irm_two_relations.cc b/cxx/tests/test_irm_two_relations.cc index 8e8865b..ff0ce9f 100644 --- a/cxx/tests/test_irm_two_relations.cc +++ b/cxx/tests/test_irm_two_relations.cc @@ -29,7 +29,6 @@ int main(int argc, char** argv) { auto schema = load_schema(path_schema); for (auto const& [relation_name, relation] : schema) { printf("relation: %s, ", relation_name.c_str()); - printf("distribution: %s, ", relation.distribution.c_str()); printf("domains: "); for (auto const& domain : relation.domains) { printf("%s ", domain.c_str()); diff --git a/cxx/tests/test_misc.cc b/cxx/tests/test_misc.cc index 0e664c0..455e382 100644 --- a/cxx/tests/test_misc.cc +++ b/cxx/tests/test_misc.cc @@ -34,9 +34,9 @@ int main(int argc, char** argv) { printf("===== IRM ====\n"); std::map schema1{ - {"R1", T_relation{{"D1", "D1"}, "bernoulli"}}, - {"R2", T_relation{{"D1", "D2"}, "normal"}}, - {"R3", T_relation{{"D3", "D1"}, "bigram"}}, + {"R1", T_relation{{"D1", "D1"}, DistributionSpec {DistributionEnum::bernoulli}}}, + {"R2", T_relation{{"D1", "D2"}, DistributionSpec {DistributionEnum::normal}}}, + {"R3", T_relation{{"D3", "D1"}, DistributionSpec {DistributionEnum::bigram}}} }; IRM irm(schema1, &prng); @@ -59,7 +59,6 @@ int main(int argc, char** argv) { auto schema = load_schema("assets/animals.binary.schema"); for (auto const& i : schema) { printf("relation: %s\n", i.first.c_str()); - printf("distribution: %s\n", i.second.distribution.c_str()); printf("domains: "); for (auto const& j : i.second.domains) { printf("%s ", j.c_str()); diff --git a/cxx/util_distribution_variant.cc b/cxx/util_distribution_variant.cc new file mode 100644 index 0000000..3e6023f --- /dev/null +++ b/cxx/util_distribution_variant.cc @@ -0,0 +1,98 @@ +// Copyright 2024 +// See LICENSE.txt + +#include "util_distribution_variant.hh" + +#include +#include + +#include "distributions/beta_bernoulli.hh" +#include "distributions/bigram.hh" +#include "distributions/crp.hh" +#include "distributions/dirichlet_categorical.hh" +#include "distributions/normal.hh" +#include "domain.hh" +#include "relation.hh" + +ObservationVariant observation_string_to_value( + const std::string& value_str, const DistributionEnum& distribution) { + switch (distribution) { + case DistributionEnum::normal: + return std::stod(value_str); + case DistributionEnum::bernoulli: + return std::stod(value_str); + case DistributionEnum::categorical: + return std::stoi(value_str); + case DistributionEnum::bigram: + return value_str; + default: + assert(false && "Unsupported distribution enum value."); + } +} + +DistributionSpec parse_distribution_spec(const std::string& dist_str) { + std::map dist_name_to_enum = { + {"bernoulli", DistributionEnum::bernoulli}, + {"bigram", DistributionEnum::bigram}, + {"categorical", DistributionEnum::categorical}, + {"normal", DistributionEnum::normal}}; + std::string dist_name = dist_str.substr(0, dist_str.find('(')); + DistributionEnum dist = dist_name_to_enum.at(dist_name); + + std::string args_str = dist_str.substr(dist_name.length()); + if (args_str.empty()) { + return DistributionSpec{dist}; + } else { + assert(args_str[0] == '('); + assert(args_str.back() == ')'); + args_str = args_str.substr(1, args_str.size() - 2); + + std::string part; + std::istringstream iss{args_str}; + std::map dist_args; + while (std::getline(iss, part, ',')) { + assert(part.find('=') != std::string::npos); + std::string arg_name = part.substr(0, part.find('=')); + std::string arg_val = part.substr(part.find('=') + 1); + dist_args[arg_name] = arg_val; + } + return DistributionSpec{dist, dist_args}; + } +} + +DistributionVariant cluster_prior_from_spec(const DistributionSpec& spec, + std::mt19937* prng) { + switch (spec.distribution) { + case DistributionEnum::bernoulli: + return new BetaBernoulli(prng); + case DistributionEnum::bigram: + return new Bigram(prng); + case DistributionEnum::categorical: { + assert(spec.distribution_args.size() == 1); + int num_categories = std::stoi(spec.distribution_args.at("k")); + return new DirichletCategorical(prng, num_categories); + } + case DistributionEnum::normal: + return new Normal(prng); + default: + assert(false && "Unsupported distribution enum value."); + } +} + +RelationVariant relation_from_spec(const std::string& name, + const DistributionSpec& dist_spec, + std::vector& domains, + std::mt19937* prng) { + switch (dist_spec.distribution) { + case DistributionEnum::bernoulli: + return new Relation(name, dist_spec, domains, prng); + case DistributionEnum::bigram: + return new Relation(name, dist_spec, domains, prng); + case DistributionEnum::categorical: + return new Relation(name, dist_spec, domains, prng); + case DistributionEnum::normal: + return new Relation(name, dist_spec, domains, prng); + default: + assert(false && "Unsupported distribution enum value."); + } +} diff --git a/cxx/util_distribution_variant.hh b/cxx/util_distribution_variant.hh new file mode 100644 index 0000000..0386689 --- /dev/null +++ b/cxx/util_distribution_variant.hh @@ -0,0 +1,50 @@ +// Copyright 2024 +// See LICENSE.txt + +// This file collects classes/functions that depend on the set of distribution +// subclasses and should be updated when a new subclass is added. + +#pragma once + +#include +#include +#include +#include +#include + +enum class DistributionEnum { bernoulli, bigram, categorical, normal }; + +struct DistributionSpec { + DistributionEnum distribution; + std::map distribution_args = {}; +}; + +class BetaBernoulli; +class Bigram; +class DirichletCategorical; +class Normal; +class Domain; +template +class Relation; + +// Set of all distribution sample types. +using ObservationVariant = std::variant; + +using DistributionVariant = + std::variant; +using RelationVariant = + std::variant*, Relation*, + Relation*, Relation*>; + +ObservationVariant observation_string_to_value( + const std::string& value_str, const DistributionEnum& distribution); + +DistributionSpec parse_distribution_spec(const std::string& dist_str); + +DistributionVariant cluster_prior_from_spec(const DistributionSpec& spec, + std::mt19937* prng); + +RelationVariant relation_from_spec(const std::string& name, + const DistributionSpec& dist_spec, + std::vector& domains, + std::mt19937* prng); diff --git a/cxx/util_distribution_variant_test.cc b/cxx/util_distribution_variant_test.cc new file mode 100644 index 0000000..6a4f812 --- /dev/null +++ b/cxx/util_distribution_variant_test.cc @@ -0,0 +1,44 @@ +// Apache License, Version 2.0, refer to LICENSE.txt + +#define BOOST_TEST_MODULE test UtilDistributionVariant + +#include "util_distribution_variant.hh" + +#include +#include +#include + +#include "distributions/beta_bernoulli.hh" +#include "distributions/bigram.hh" +#include "distributions/dirichlet_categorical.hh" +#include "domain.hh" + +namespace tt = boost::test_tools; + +BOOST_AUTO_TEST_CASE(test_parse_distribution_spec) { + DistributionSpec dbb = parse_distribution_spec("bernoulli"); + BOOST_TEST((dbb.distribution == DistributionEnum::bernoulli)); + BOOST_TEST(dbb.distribution_args.empty()); + + DistributionSpec dbg = parse_distribution_spec("bigram"); + BOOST_TEST((dbg.distribution == DistributionEnum::bigram)); + BOOST_TEST(dbg.distribution_args.empty()); + + DistributionSpec dn = parse_distribution_spec("normal"); + BOOST_TEST((dn.distribution == DistributionEnum::normal)); + BOOST_TEST(dn.distribution_args.empty()); + + DistributionSpec dc = parse_distribution_spec("categorical(k=6)"); + BOOST_TEST((dc.distribution == DistributionEnum::categorical)); + BOOST_TEST((dc.distribution_args.size() == 1)); + std::string expected = "6"; + BOOST_CHECK_EQUAL(dc.distribution_args.at("k"), expected); +} + +BOOST_AUTO_TEST_CASE(test_cluster_prior_from_spec) { + std::mt19937 prng; + DistributionSpec dc = {DistributionEnum::categorical, {{"k", "4"}}}; + DistributionVariant dcdv = cluster_prior_from_spec(dc, &prng); + DirichletCategorical* dcd = std::get(dcdv); + BOOST_TEST(dcd->counts.size() == 4); +} \ No newline at end of file diff --git a/cxx/util_io.cc b/cxx/util_io.cc index d3fc9d0..e5738c1 100644 --- a/cxx/util_io.cc +++ b/cxx/util_io.cc @@ -5,7 +5,6 @@ #include #include -#include #include #include #include @@ -23,33 +22,21 @@ T_schema load_schema(const std::string& path) { T_relation relation; std::string relname; + std::string distribution_spec_str; - stream >> relation.distribution; + stream >> distribution_spec_str; stream >> relname; for (std::string w; stream >> w;) { relation.domains.push_back(w); } assert(relation.domains.size() > 0); + relation.distribution_spec = parse_distribution_spec(distribution_spec_str); schema[relname] = relation; } fp.close(); return schema; } -ObservationVariant observation_string_to_value(const std::string& value_str, - const std::string& distribution) { - if (distribution == "normal" || distribution == "bernoulli") { - return std::stod(value_str); - } else if (distribution == "categorical") { - return std::stoi(value_str); - } else if (distribution == "bigram") { - return value_str; - } else { - // Unrecognized distribution name. - assert(false); - } -} - T_observations load_observations(const std::string& path, const T_schema& schema) { std::ifstream fp(path, std::ifstream::in); @@ -67,7 +54,7 @@ T_observations load_observations(const std::string& path, stream >> value_str; stream >> relname; ObservationVariant value = observation_string_to_value( - value_str, schema.at(relname).distribution); + value_str, schema.at(relname).distribution_spec.distribution); for (std::string w; stream >> w;) { items.push_back(w); } @@ -143,10 +130,8 @@ void incorporate_observations(HIRM& hirm, const T_encoding& encoding, int code = item_to_code.at(domain).at(item); items_e.push_back(code); } - std::visit( - [&](const auto& v) {hirm.incorporate(relation, items_e, v);}, - value - ); + std::visit([&](const auto& v) { hirm.incorporate(relation, items_e, v); }, + value); } }