Skip to content

Commit

Permalink
Address reviewer comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyfertig committed Jun 12, 2024
1 parent 67c7528 commit 9ae3928
Show file tree
Hide file tree
Showing 17 changed files with 39 additions and 198 deletions.
21 changes: 0 additions & 21 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ cc_library(
name = "distributions",
visibility = ["//:__subpackages__"],
deps = [
":adapter",
":base",
":beta_bernoulli",
":bigram",
Expand All @@ -14,15 +13,6 @@ cc_library(
],
)

cc_library(
name = "adapter",
hdrs = ["adapter.hh"],
visibility = ["//:__subpackages__"],
deps = [
":base",
],
)

cc_library(
name = "base",
hdrs = ["base.hh"],
Expand Down Expand Up @@ -87,17 +77,6 @@ cc_library(
],
)

cc_test(
name = "adapter_test",
srcs = ["adapter_test.cc"],
deps = [
":adapter",
":normal",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "beta_bernoulli_test",
srcs = ["beta_bernoulli_test.cc"],
Expand Down
66 changes: 0 additions & 66 deletions cxx/distributions/adapter.hh

This file was deleted.

33 changes: 0 additions & 33 deletions cxx/distributions/adapter_test.cc

This file was deleted.

5 changes: 1 addition & 4 deletions cxx/distributions/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
template <typename T>
class Distribution {
// Abstract base class for probability distributions in HIRM.
// New distribution subclasses need to be added to
// New distribution subclasses need to be added to
// `util_distribution_variant` to be used in the (H)IRM models.
public:
typedef T SampleType;
Expand Down Expand Up @@ -39,8 +39,5 @@ class Distribution {
// e^logp_score() under those hyperparameters.
virtual void transition_hyperparameters() = 0;

// Return a copy of the distribution with no observed data.
virtual Distribution<T>* prior() const = 0;

virtual ~Distribution() = default;
};
4 changes: 0 additions & 4 deletions cxx/distributions/beta_bernoulli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,3 @@ void BetaBernoulli::transition_hyperparameters() {
alpha = hypers[i].first;
beta = hypers[i].second;
}

BetaBernoulli* BetaBernoulli::prior() const {
return new BetaBernoulli (prng);
}
4 changes: 2 additions & 2 deletions cxx/distributions/beta_bernoulli.hh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "distributions/base.hh"
#include "util_math.hh"

// TODO(thomaswc, emilyaf): Change BetaBernoulli to use bool instead of
// double.
class BetaBernoulli : public Distribution<double> {
public:
double alpha = 1; // hyperparameter
Expand Down Expand Up @@ -35,6 +37,4 @@ class BetaBernoulli : public Distribution<double> {
double sample();

void transition_hyperparameters();

BetaBernoulli* prior() const;
};
4 changes: 0 additions & 4 deletions cxx/distributions/bigram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,3 @@ void Bigram::transition_hyperparameters() {
int i = sample_from_logps(logps, prng);
set_alpha(alphas[i]);
}

Bigram* Bigram::prior() const {
return new Bigram (prng);
}
2 changes: 0 additions & 2 deletions cxx/distributions/bigram.hh
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,4 @@ class Bigram : public Distribution<std::string> {
void set_alpha(double alphat);

void transition_hyperparameters();

Bigram* prior() const;
};
4 changes: 0 additions & 4 deletions cxx/distributions/crp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,3 @@ void CRP::transition_alpha() {
int idx = log_choice(logps, prng);
this->alpha = grid[idx];
}

CRP* CRP::prior() const {
return new CRP(prng);
}
5 changes: 1 addition & 4 deletions cxx/distributions/crp.hh
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
#include <unordered_map>
#include <unordered_set>


typedef int T_item;

// TODO(emilyaf): Make this a distribution subclass.
class CRP {
public:
double alpha = 1.; // concentration parameter
int N = 0; // number of customers
int N = 0; // number of customers
std::unordered_map<int, std::unordered_set<T_item>>
tables; // map from table id to set of customers
std::unordered_map<T_item, int> assignments; // map from customer to table id
Expand All @@ -36,6 +35,4 @@ class CRP {
std::unordered_map<int, double> tables_weights_gibbs(int table) const;

void transition_alpha();

CRP* prior() const;
};
4 changes: 0 additions & 4 deletions cxx/distributions/dirichlet_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,3 @@ void DirichletCategorical::transition_hyperparameters() {
int i = sample_from_logps(logps, prng);
alpha = alphas[i];
}

DirichletCategorical* DirichletCategorical::prior() const {
return new DirichletCategorical(prng, counts.size());
}
2 changes: 0 additions & 2 deletions cxx/distributions/dirichlet_categorical.hh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,4 @@ class DirichletCategorical : public Distribution<double> {
double sample();

void transition_hyperparameters();

DirichletCategorical* prior() const;
};
4 changes: 0 additions & 4 deletions cxx/distributions/normal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,3 @@ void Normal::transition_hyperparameters() {
m = std::get<2>(hypers[i]);
s = std::get<3>(hypers[i]);
}

Normal* Normal::prior() const {
return new Normal(prng);
}
2 changes: 0 additions & 2 deletions cxx/distributions/normal.hh
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ class Normal : public Distribution<double> {

void transition_hyperparameters();

Normal* prior() const;

// Disable copying.
Normal& operator=(const Normal&) = delete;
Normal(const Normal&) = delete;
Expand Down
13 changes: 8 additions & 5 deletions cxx/hirm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "relation.hh"
#include "util_distribution_variant.hh"


// Map from names to T_relation's.
typedef std::map<std::string, T_relation> T_schema;

Expand Down Expand Up @@ -212,8 +211,11 @@ class IRM {
}
auto v = std::get<
typename std::remove_reference_t<decltype(*rel)>::ValueType>(value);
return rel->clusters.contains(z) ? rel->clusters.at(z)->logp(v)
: rel->cluster_prior->logp(v);
auto prior =
std::get<typename std::remove_reference_t<decltype(*rel)>::DType*>(
cluster_prior_from_spec(rel->dist_spec, prng));
return rel->clusters.contains(z) ? rel->clusters.at(z)->logp(v)
: prior->logp(v);
};
for (const auto& [r, items, value] : observations) {
auto g = std::bind(f_logp, std::placeholders::_1, items, value);
Expand Down Expand Up @@ -252,7 +254,8 @@ class IRM {
domain_to_relations.at(d).insert(name);
doms.push_back(domains.at(d));
}
relations[name] = relation_from_spec(name, relation.distribution_spec, doms, prng);
relations[name] =
relation_from_spec(name, relation.distribution_spec, doms, prng);
schema[name] = relation;
}

Expand Down Expand Up @@ -300,7 +303,7 @@ class HIRM {
}
}

void incorporate(const std::string& r, const T_items& items,
void incorporate(const std::string& r, const T_items& items,
const ObservationVariant& value) {
IRM* irm = relation_to_irm(r);
irm->incorporate(r, items, value);
Expand Down
26 changes: 14 additions & 12 deletions cxx/relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class Relation {
const std::string name;
// Relation spec.
T_relation trel;
// Distribution prior over the relation's codomain.
const DistributionType* cluster_prior;
// Distribution spec over the relation's codomain.
const DistributionSpec dist_spec;
// list of domain pointers
const std::vector<Domain*> domains;
// map from cluster multi-index to Distribution pointer
Expand All @@ -63,13 +63,10 @@ class Relation {

Relation(const std::string& name, const DistributionSpec& dist_spec,
const std::vector<Domain*>& domains, std::mt19937* prng)
: name(name), domains(domains) {
: name(name), dist_spec(dist_spec), domains(domains) {
assert(!domains.empty());
assert(!name.empty());
this->prng = prng;
DistributionVariant cluster_prior_var =
cluster_prior_from_spec(dist_spec, prng);
cluster_prior = std::get<DistributionType*>(cluster_prior_var);
for (const Domain* const d : domains) {
this->data_r[d->name] =
std::unordered_map<T_item, std::unordered_set<T_items, H_items>>();
Expand Down Expand Up @@ -104,7 +101,8 @@ class Relation {
// Cannot use clusters[z] because BetaBernoulli
// does not have a default constructor, whereas operator[]
// calls default constructor when the key does not exist.
clusters[z] = cluster_prior->prior();
clusters[z] =
std::get<DistributionType*>(cluster_prior_from_spec(dist_spec, prng));
}
clusters.at(z)->incorporate(value);
}
Expand Down Expand Up @@ -243,8 +241,9 @@ class Relation {
T_items z =
get_cluster_assignment_gibbs(items_list[0], domain, item, table);

DistributionType* cluster =
clusters.contains(z) ? clusters.at(z) : cluster_prior->prior();
DistributionType* prior =
std::get<DistributionType*>(cluster_prior_from_spec(dist_spec, prng));
DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : prior;
double logp0 = cluster->logp_score();
for (const T_items& items : items_list) {
// assert(z == get_cluster_assignment_gibbs(items, domain, item, table));
Expand Down Expand Up @@ -326,8 +325,10 @@ class Relation {
z.push_back(zi);
logp_w += wi;
}
double logp_z = clusters.contains(z) ? clusters.at(z)->logp(value)
: cluster_prior->logp(value);
DistributionType* prior =
std::get<DistributionType*>(cluster_prior_from_spec(dist_spec, prng));
DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : prior;
double logp_z = cluster->logp(value);
double logp_zw = logp_z + logp_w;
logps.push_back(logp_zw);
}
Expand Down Expand Up @@ -360,7 +361,8 @@ class Relation {
T_items z_new = get_cluster_assignment_gibbs(items, domain, item, table);
if (!clusters.contains(z_new)) {
// Move to fresh cluster.
clusters[z_new] = cluster_prior->prior();
clusters[z_new] = std::get<DistributionType*>(
cluster_prior_from_spec(dist_spec, prng));
clusters.at(z_new)->incorporate(x);
} else {
// Move to existing cluster.
Expand Down
Loading

0 comments on commit 9ae3928

Please sign in to comment.