Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support distribution constructor args and enable DirichletCategorical. #45

Merged
merged 2 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cc_binary(
deps = [
":cxxopts",
":headers",
":util_distribution_variant",
":util_hash",
":util_io",
":util_math",
Expand All @@ -42,12 +43,24 @@ cc_library(
hdrs = ["relation.hh"],
deps = [
":domain",
":util_distribution_variant",
":util_hash",
":util_math",
"//distributions:base"
],
)

cc_library(
name = "util_distribution_variant",
srcs = ["util_distribution_variant.cc"],
visibility = [":__subpackages__"],
hdrs = ["util_distribution_variant.hh"],
deps = [
"//distributions",
],
)


cc_library(
name = "util_hash",
hdrs = ["util_hash.hh"],
Expand Down Expand Up @@ -96,6 +109,16 @@ cc_test(
],
)

cc_test(
name = "util_distribution_variant_test",
srcs = ["util_distribution_variant_test.cc"],
deps = [
":util_distribution_variant",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "util_math_test",
srcs = ["util_math_test.cc"],
Expand Down
21 changes: 0 additions & 21 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ cc_library(
name = "distributions",
visibility = ["//:__subpackages__"],
deps = [
":adapter",
":base",
":beta_bernoulli",
":bigram",
Expand All @@ -14,15 +13,6 @@ cc_library(
],
)

cc_library(
name = "adapter",
hdrs = ["adapter.hh"],
visibility = ["//:__subpackages__"],
deps = [
":base",
],
)

cc_library(
name = "base",
hdrs = ["base.hh"],
Expand Down Expand Up @@ -87,17 +77,6 @@ cc_library(
],
)

cc_test(
name = "adapter_test",
srcs = ["adapter_test.cc"],
deps = [
":adapter",
":normal",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "beta_bernoulli_test",
srcs = ["beta_bernoulli_test.cc"],
Expand Down
62 changes: 0 additions & 62 deletions cxx/distributions/adapter.hh

This file was deleted.

33 changes: 0 additions & 33 deletions cxx/distributions/adapter_test.cc

This file was deleted.

2 changes: 2 additions & 0 deletions cxx/distributions/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
template <typename T>
class Distribution {
// Abstract base class for probability distributions in HIRM.
// New distribution subclasses need to be added to
// `util_distribution_variant` to be used in the (H)IRM models.
public:
typedef T SampleType;
// N is the number of incorporated observations.
Expand Down
2 changes: 2 additions & 0 deletions cxx/distributions/beta_bernoulli.hh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "distributions/base.hh"
#include "util_math.hh"

// TODO(thomaswc, emilyaf): Change BetaBernoulli to use bool instead of
// double.
class BetaBernoulli : public Distribution<double> {
public:
double alpha = 1; // hyperparameter
Expand Down
3 changes: 1 addition & 2 deletions cxx/distributions/crp.hh
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
#include <unordered_map>
#include <unordered_set>


typedef int T_item;

// TODO(emilyaf): Make this a distribution subclass.
class CRP {
public:
double alpha = 1.; // concentration parameter
int N = 0; // number of customers
int N = 0; // number of customers
std::unordered_map<int, std::unordered_set<T_item>>
tables; // map from table id to set of customers
std::unordered_map<T_item, int> assignments; // map from customer to table id
Expand Down
46 changes: 11 additions & 35 deletions cxx/hirm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,12 @@
#include <unordered_set>
#include <variant>

#include "distributions/base.hh"
#include "distributions/beta_bernoulli.hh"
#include "distributions/bigram.hh"
#include "distributions/crp.hh"
#include "distributions/dirichlet_categorical.hh"
#include "distributions/normal.hh"
#include "relation.hh"

#include "util_distribution_variant.hh"

// Map from names to T_relation's.
typedef std::map<std::string, T_relation> T_schema;

using ObservationVariant = std::variant<double, int, std::string>;

using RelationVariant =
std::variant<Relation<BetaBernoulli>*, Relation<Bigram>*,
// Relation<DirichletCategorical>*,
Relation<Normal>*>;

class IRM {
public:
T_schema schema; // schema of relations
Expand Down Expand Up @@ -222,11 +209,13 @@ class IRM {
T_item t = t_list.at(indexes.at(loc));
z.push_back(t);
}
typename std::remove_reference_t<decltype(*rel)>::DType aux(prng);
auto cluster = rel->clusters.contains(z) ? rel->clusters.at(z) : &aux;
auto v = std::get<
typename std::remove_reference_t<decltype(*rel)>::ValueType>(value);
return cluster->logp(v);
auto prior =
std::get<typename std::remove_reference_t<decltype(*rel)>::DType*>(
cluster_prior_from_spec(rel->dist_spec, prng));
return rel->clusters.contains(z) ? rel->clusters.at(z)->logp(v)
: prior->logp(v);
};
for (const auto& [r, items, value] : observations) {
auto g = std::bind(f_logp, std::placeholders::_1, items, value);
Expand Down Expand Up @@ -265,21 +254,8 @@ class IRM {
domain_to_relations.at(d).insert(name);
doms.push_back(domains.at(d));
}
if (relation.distribution == "normal") {
relations[name] =
new Relation<Normal>(name, relation.distribution, doms, prng);
} else if (relation.distribution == "bernoulli") {
relations[name] =
new Relation<BetaBernoulli>(name, relation.distribution, doms, prng);
} else if (relation.distribution == "bigram") {
relations[name] =
new Relation<Bigram>(name, relation.distribution, doms, prng);
} else {
assert(false);
}
// TODO(emilyaf): Enable Categorical. Maybe have "categorical5" e.g. for 5
// categories. relations[name] = new Relation<DirichletCategorical>(name,
// relation.distribution, doms, prng);
relations[name] =
relation_from_spec(name, relation.distribution_spec, doms, prng);
schema[name] = relation;
}

Expand Down Expand Up @@ -327,7 +303,7 @@ class HIRM {
}
}

void incorporate(const std::string& r, const T_items& items,
void incorporate(const std::string& r, const T_items& items,
const ObservationVariant& value) {
IRM* irm = relation_to_irm(r);
irm->incorporate(r, items, value);
Expand Down Expand Up @@ -366,7 +342,7 @@ class HIRM {
int table_current = crp.assignments.at(rc);
RelationVariant relation = get_relation(r);
T_relation t_relation =
std::visit([](auto rel) { return rel->get_T_relation(); }, relation);
std::visit([](auto rel) { return rel->trel; }, relation);
auto crp_dist = crp.tables_weights_gibbs(table_current);
std::vector<int> tables;
std::vector<double> logps;
Expand Down Expand Up @@ -443,7 +419,7 @@ class HIRM {
int table_current = crp.assignments.at(rc);
RelationVariant relation = get_relation(r);
auto f_obs = [&](auto rel) {
T_relation trel = rel->get_T_relation();
T_relation trel = rel->trel;
IRM* irm = relation_to_irm(r);
auto observations = rel->data;
// Remove from current IRM.
Expand Down
Loading