Skip to content

Commit

Permalink
Merge pull request #45 from probcomp/061024-emilyaf-dirichlet-in-hirm
Browse files Browse the repository at this point in the history
Support distribution constructor args and enable DirichletCategorical.
  • Loading branch information
emilyfertig authored Jun 12, 2024
2 parents d9f36cb + 9ae3928 commit c029948
Show file tree
Hide file tree
Showing 17 changed files with 278 additions and 206 deletions.
23 changes: 23 additions & 0 deletions cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cc_binary(
deps = [
":cxxopts",
":headers",
":util_distribution_variant",
":util_hash",
":util_io",
":util_math",
Expand All @@ -42,12 +43,24 @@ cc_library(
hdrs = ["relation.hh"],
deps = [
":domain",
":util_distribution_variant",
":util_hash",
":util_math",
"//distributions:base"
],
)

cc_library(
name = "util_distribution_variant",
srcs = ["util_distribution_variant.cc"],
visibility = [":__subpackages__"],
hdrs = ["util_distribution_variant.hh"],
deps = [
"//distributions",
],
)


cc_library(
name = "util_hash",
hdrs = ["util_hash.hh"],
Expand Down Expand Up @@ -96,6 +109,16 @@ cc_test(
],
)

cc_test(
name = "util_distribution_variant_test",
srcs = ["util_distribution_variant_test.cc"],
deps = [
":util_distribution_variant",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "util_math_test",
srcs = ["util_math_test.cc"],
Expand Down
21 changes: 0 additions & 21 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ cc_library(
name = "distributions",
visibility = ["//:__subpackages__"],
deps = [
":adapter",
":base",
":beta_bernoulli",
":bigram",
Expand All @@ -14,15 +13,6 @@ cc_library(
],
)

cc_library(
name = "adapter",
hdrs = ["adapter.hh"],
visibility = ["//:__subpackages__"],
deps = [
":base",
],
)

cc_library(
name = "base",
hdrs = ["base.hh"],
Expand Down Expand Up @@ -97,17 +87,6 @@ cc_library(
],
)

cc_test(
name = "adapter_test",
srcs = ["adapter_test.cc"],
deps = [
":adapter",
":normal",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "beta_bernoulli_test",
srcs = ["beta_bernoulli_test.cc"],
Expand Down
62 changes: 0 additions & 62 deletions cxx/distributions/adapter.hh

This file was deleted.

33 changes: 0 additions & 33 deletions cxx/distributions/adapter_test.cc

This file was deleted.

2 changes: 2 additions & 0 deletions cxx/distributions/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
template <typename T>
class Distribution {
// Abstract base class for probability distributions in HIRM.
// New distribution subclasses need to be added to
// `util_distribution_variant` to be used in the (H)IRM models.
public:
typedef T SampleType;
// N is the number of incorporated observations.
Expand Down
2 changes: 2 additions & 0 deletions cxx/distributions/beta_bernoulli.hh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "distributions/base.hh"
#include "util_math.hh"

// TODO(thomaswc, emilyaf): Change BetaBernoulli to use bool instead of
// double.
class BetaBernoulli : public Distribution<double> {
public:
double alpha = 1; // hyperparameter
Expand Down
3 changes: 1 addition & 2 deletions cxx/distributions/crp.hh
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
#include <unordered_map>
#include <unordered_set>


typedef int T_item;

// TODO(emilyaf): Make this a distribution subclass.
class CRP {
public:
double alpha = 1.; // concentration parameter
int N = 0; // number of customers
int N = 0; // number of customers
std::unordered_map<int, std::unordered_set<T_item>>
tables; // map from table id to set of customers
std::unordered_map<T_item, int> assignments; // map from customer to table id
Expand Down
46 changes: 11 additions & 35 deletions cxx/hirm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,12 @@
#include <unordered_set>
#include <variant>

#include "distributions/base.hh"
#include "distributions/beta_bernoulli.hh"
#include "distributions/bigram.hh"
#include "distributions/crp.hh"
#include "distributions/dirichlet_categorical.hh"
#include "distributions/normal.hh"
#include "relation.hh"

#include "util_distribution_variant.hh"

// Map from names to T_relation's.
typedef std::map<std::string, T_relation> T_schema;

using ObservationVariant = std::variant<double, int, std::string>;

using RelationVariant =
std::variant<Relation<BetaBernoulli>*, Relation<Bigram>*,
// Relation<DirichletCategorical>*,
Relation<Normal>*>;

class IRM {
public:
T_schema schema; // schema of relations
Expand Down Expand Up @@ -222,11 +209,13 @@ class IRM {
T_item t = t_list.at(indexes.at(loc));
z.push_back(t);
}
typename std::remove_reference_t<decltype(*rel)>::DType aux(prng);
auto cluster = rel->clusters.contains(z) ? rel->clusters.at(z) : &aux;
auto v = std::get<
typename std::remove_reference_t<decltype(*rel)>::ValueType>(value);
return cluster->logp(v);
auto prior =
std::get<typename std::remove_reference_t<decltype(*rel)>::DType*>(
cluster_prior_from_spec(rel->dist_spec, prng));
return rel->clusters.contains(z) ? rel->clusters.at(z)->logp(v)
: prior->logp(v);
};
for (const auto& [r, items, value] : observations) {
auto g = std::bind(f_logp, std::placeholders::_1, items, value);
Expand Down Expand Up @@ -265,21 +254,8 @@ class IRM {
domain_to_relations.at(d).insert(name);
doms.push_back(domains.at(d));
}
if (relation.distribution == "normal") {
relations[name] =
new Relation<Normal>(name, relation.distribution, doms, prng);
} else if (relation.distribution == "bernoulli") {
relations[name] =
new Relation<BetaBernoulli>(name, relation.distribution, doms, prng);
} else if (relation.distribution == "bigram") {
relations[name] =
new Relation<Bigram>(name, relation.distribution, doms, prng);
} else {
assert(false);
}
// TODO(emilyaf): Enable Categorical. Maybe have "categorical5" e.g. for 5
// categories. relations[name] = new Relation<DirichletCategorical>(name,
// relation.distribution, doms, prng);
relations[name] =
relation_from_spec(name, relation.distribution_spec, doms, prng);
schema[name] = relation;
}

Expand Down Expand Up @@ -327,7 +303,7 @@ class HIRM {
}
}

void incorporate(const std::string& r, const T_items& items,
void incorporate(const std::string& r, const T_items& items,
const ObservationVariant& value) {
IRM* irm = relation_to_irm(r);
irm->incorporate(r, items, value);
Expand Down Expand Up @@ -366,7 +342,7 @@ class HIRM {
int table_current = crp.assignments.at(rc);
RelationVariant relation = get_relation(r);
T_relation t_relation =
std::visit([](auto rel) { return rel->get_T_relation(); }, relation);
std::visit([](auto rel) { return rel->trel; }, relation);
auto crp_dist = crp.tables_weights_gibbs(table_current);
std::vector<int> tables;
std::vector<double> logps;
Expand Down Expand Up @@ -443,7 +419,7 @@ class HIRM {
int table_current = crp.assignments.at(rc);
RelationVariant relation = get_relation(r);
auto f_obs = [&](auto rel) {
T_relation trel = rel->get_T_relation();
T_relation trel = rel->trel;
IRM* irm = relation_to_irm(r);
auto observations = rel->data;
// Remove from current IRM.
Expand Down
Loading

0 comments on commit c029948

Please sign in to comment.