Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Relation<DistributionType> to Relation<ValueType> #70

Merged
merged 5 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ cc_binary(
],
)

cc_binary(
name = "typename_playground",
srcs = ["typename_playground.cc"],
deps = [
":relation_variant",
":util_distribution_variant",
],
)

cc_library(
name = "relation",
hdrs = ["relation.hh"],
Expand All @@ -73,7 +82,6 @@ cc_library(
":domain",
":relation",
":util_distribution_variant",
"//distributions",
],
)

Expand Down Expand Up @@ -133,6 +141,7 @@ cc_test(
"@boost//:test",
],
)

cc_test(
name = "relation_test",
srcs = ["relation_test.cc"],
Expand All @@ -144,6 +153,15 @@ cc_test(
],
)

cc_test(
name = "relation_variant_test",
srcs = ["relation_variant_test.cc"],
deps = [
":relation_variant",
"@boost//:test",
],
)

cc_test(
name = "util_distribution_variant_test",
srcs = ["util_distribution_variant_test.cc"],
Expand Down
3 changes: 1 addition & 2 deletions cxx/distributions/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ template <typename T>
class Distribution {
// Abstract base class for probability distributions in HIRM.
// New distribution subclasses need to be added to
// `relation_variant` and `util_distribution_variant` to be used in the
// (H)IRM models.
// `util_distribution_variant` to be used in the (H)IRM models.
public:
typedef T SampleType;
// N is the number of incorporated observations.
Expand Down
15 changes: 10 additions & 5 deletions cxx/irm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,16 @@ double IRM::logp(
}
auto v = std::get<
typename std::remove_reference_t<decltype(*rel)>::ValueType>(value);
auto prior =
std::get<typename std::remove_reference_t<decltype(*rel)>::DType*>(
cluster_prior_from_spec(rel->dist_spec));
return rel->clusters.contains(z) ? rel->clusters.at(z)->logp(v)
: prior->logp(v);
if (rel->clusters.contains(z)) {
return rel->clusters.at(z)->logp(v);
}
DistributionVariant prior = cluster_prior_from_spec(rel->dist_spec);
return std::visit(
[&](const auto& dist_variant) {
auto v2 = std::get<
typename std::remove_reference_t<decltype(
*dist_variant)>::SampleType>(value);
return dist_variant->logp(v2); }, prior);
};
for (const auto& [r, items, value] : observations) {
auto g = std::bind(f_logp, std::placeholders::_1, items, value);
Expand Down
1 change: 1 addition & 0 deletions cxx/irm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ BOOST_AUTO_TEST_CASE(test_irm) {
auto obs0 = observation_string_to_value("0", DistributionEnum::bernoulli);

double logp_x = irm.logp({{"R1", {1, 2}, obs0}});
BOOST_TEST(logp_x < 0.0);

irm.incorporate(&prng, "R1", {1, 2}, obs0);
double one_obs_score = irm.logp_score();
Expand Down
54 changes: 25 additions & 29 deletions cxx/relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
#include <vector>

#include "distributions/base.hh"
#include "distributions/beta_bernoulli.hh"
#include "distributions/bigram.hh"
#include "distributions/crp.hh"
#include "distributions/dirichlet_categorical.hh"
#include "distributions/normal.hh"
#include "domain.hh"
#include "util_distribution_variant.hh"
#include "util_hash.hh"
Expand All @@ -33,13 +28,11 @@ class T_relation {
DistributionSpec distribution_spec;
};

template <typename DistributionType>
template <typename T>
class Relation {
public:
using ValueType = typename DistributionType::SampleType;
using DType = DistributionType;
static_assert(std::is_base_of<Distribution<ValueType>, DType>::value,
"DistributionType must inherit from Distribution.");
typedef T ValueType;

// human-readable name
const std::string name;
// Relation spec.
Expand All @@ -49,8 +42,8 @@ class Relation {
// list of domain pointers
const std::vector<Domain*> domains;
// map from cluster multi-index to Distribution pointer
std::unordered_map<const std::vector<int>, DistributionType*, VectorIntHash>
clusters;
std::unordered_map<
const std::vector<int>, Distribution<ValueType>*, VectorIntHash> clusters;
// map from item to observed data
std::unordered_map<const T_items, ValueType, H_items> data;
// map from domain name to reverse map from item to
Expand Down Expand Up @@ -82,6 +75,15 @@ class Relation {
}
}

Distribution<ValueType>* make_new_distribution() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the answer is no, but is there any chance this could cause problems once DistributionVariant contains Distribution subclasses with the same ValueType?

Copy link
Collaborator Author

@ThomasColthurst ThomasColthurst Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. The only way the code here could go wrong is if relation_from_spec somehow returned the wrong type of Relation for a DistributionSpec -- for example, if it returned a Relation<bool> for a dist_spec that when passed to cluster_prior_from_spec returned a Distribution<double>.

But if that happened, lots of other things would go wrong, too.

return std::visit([&](auto dist_variant) {
// In practice, the DistributionVariant returned by
// cluster_prior_from_spec will always be of type
// Distribution<ValueType>*, so this reinterpret_cast is a no-op.
return reinterpret_cast<Distribution<ValueType>*>(dist_variant);
}, cluster_prior_from_spec(dist_spec));
}

void incorporate(std::mt19937* prng, const T_items& items, ValueType value) {
assert(!data.contains(items));
data[items] = value;
Expand All @@ -95,12 +97,7 @@ class Relation {
}
T_items z = get_cluster_assignment(items);
if (!clusters.contains(z)) {
// Invalid discussion as using pointers now;
// Cannot use clusters[z] because BetaBernoulli
// does not have a default constructor, whereas operator[]
// calls default constructor when the key does not exist.
clusters[z] =
std::get<DistributionType*>(cluster_prior_from_spec(dist_spec));
clusters[z] = make_new_distribution();
}
clusters.at(z)->incorporate(value);
}
Expand Down Expand Up @@ -180,9 +177,9 @@ class Relation {
T_items z = get_cluster_assignment_gibbs(items, domain, item, table);
double lp;
if (!clusters.contains(z)) {
DistributionType* cluster =
std::get<DistributionType*>(cluster_prior_from_spec(dist_spec));
lp = cluster->logp(x);
Distribution<ValueType>* tmp_dist = make_new_distribution();
lp = tmp_dist->logp(x);
delete tmp_dist;
} else {
lp = clusters.at(z)->logp(x);
}
Expand Down Expand Up @@ -240,9 +237,8 @@ class Relation {
T_items z =
get_cluster_assignment_gibbs(items_list[0], domain, item, table);

DistributionType* prior =
std::get<DistributionType*>(cluster_prior_from_spec(dist_spec));
DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : prior;
Distribution<ValueType>* prior = make_new_distribution();
Distribution<ValueType>* cluster = clusters.contains(z) ? clusters.at(z) : prior;
double logp0 = cluster->logp_score();
for (const T_items& items : items_list) {
// assert(z == get_cluster_assignment_gibbs(items, domain, item, table));
Expand All @@ -255,6 +251,7 @@ class Relation {
cluster->unincorporate(x);
}
assert(cluster->logp_score() == logp0);
delete prior;
return logp1 - logp0;
}

Expand Down Expand Up @@ -324,12 +321,12 @@ class Relation {
z.push_back(zi);
logp_w += wi;
}
DistributionType* prior =
std::get<DistributionType*>(cluster_prior_from_spec(dist_spec));
DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : prior;
Distribution<ValueType>* prior = make_new_distribution();
Distribution<ValueType>* cluster = clusters.contains(z) ? clusters.at(z) : prior;
double logp_z = cluster->logp(value);
double logp_zw = logp_z + logp_w;
logps.push_back(logp_zw);
delete prior;
}
return logsumexp(logps);
}
Expand Down Expand Up @@ -360,8 +357,7 @@ class Relation {
T_items z_new = get_cluster_assignment_gibbs(items, domain, item, table);
if (!clusters.contains(z_new)) {
// Move to fresh cluster.
clusters[z_new] =
std::get<DistributionType*>(cluster_prior_from_spec(dist_spec));
clusters[z_new] = make_new_distribution();
clusters.at(z_new)->incorporate(x);
} else {
// Move to existing cluster.
Expand Down
14 changes: 12 additions & 2 deletions cxx/relation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ BOOST_AUTO_TEST_CASE(test_relation) {
D2.incorporate(&prng, 1);
D3.incorporate(&prng, 3);
DistributionSpec spec = DistributionSpec{DistributionEnum::bernoulli};
Relation<BetaBernoulli> R1("R1", spec, {&D1, &D2, &D3});
Relation<bool> R1("R1", spec, {&D1, &D2, &D3});
R1.incorporate(&prng, {0, 1, 3}, 1);
R1.incorporate(&prng, {1, 1, 3}, 1);
R1.incorporate(&prng, {3, 1, 3}, 1);
Expand All @@ -48,8 +48,13 @@ BOOST_AUTO_TEST_CASE(test_relation) {
lpg = R1.logp_gibbs_approx(D1, 0, 10);
R1.set_cluster_assignment_gibbs(D1, 0, 1);

Distribution<bool>* db = R1.make_new_distribution();
BOOST_TEST(db->N == 0);
db->incorporate(false);
BOOST_TEST(db->N == 1);

DistributionSpec bigram_spec = DistributionSpec{DistributionEnum::bigram};
Relation<Bigram> R2("R1", bigram_spec, {&D2, &D3});
Relation<std::string> R2("R1", bigram_spec, {&D2, &D3});
R2.incorporate(&prng, {1, 3}, "cat");
R2.incorporate(&prng, {1, 2}, "dog");
R2.incorporate(&prng, {1, 4}, "catt");
Expand All @@ -58,4 +63,9 @@ BOOST_AUTO_TEST_CASE(test_relation) {
lpg = R2.logp_gibbs_approx(D2, 2, 0);
R2.set_cluster_assignment_gibbs(D3, 3, 1);
D1.set_cluster_assignment_gibbs(0, 1);

Distribution<std::string>* db2 = R2.make_new_distribution();
BOOST_TEST(db2->N == 0);
db2->incorporate("hello");
BOOST_TEST(db2->N == 1);
}
46 changes: 28 additions & 18 deletions cxx/relation_variant.cc
Original file line number Diff line number Diff line change
@@ -1,30 +1,40 @@
// Copyright 2024
// See LICENSE.txt

#include "relation_variant.hh"

#include <cassert>
#include <type_traits>

#include "distributions/beta_bernoulli.hh"
#include "distributions/bigram.hh"
#include "distributions/dirichlet_categorical.hh"
#include "distributions/normal.hh"
#include "domain.hh"
#include "relation.hh"
#include "relation_variant.hh"


RelationVariant relation_from_spec(const std::string& name,
const DistributionSpec& dist_spec,
std::vector<Domain*>& domains) {
switch (dist_spec.distribution) {
case DistributionEnum::bernoulli:
return new Relation<BetaBernoulli>(name, dist_spec, domains);
case DistributionEnum::bigram:
return new Relation<Bigram>(name, dist_spec, domains);
case DistributionEnum::categorical:
return new Relation<DirichletCategorical>(name, dist_spec, domains);
case DistributionEnum::normal:
return new Relation<Normal>(name, dist_spec, domains);
default:
assert(false && "Unsupported distribution enum value.");
}
DistributionVariant dv = cluster_prior_from_spec(dist_spec);

RelationVariant rv;

// We want to go from dv to its SampleType. This only takes five steps:
// 1. To go from the DistributionVariant dv to the underlying
// Distribution pointer, we use a std::visit.
// 2. To get the type of the Distribution, we use decltype(*v).
// 3. But that turns out to be of type DistributionName&, so we need to use
// a std::remove_reference_t to get rid of the &.
// 4. You would think that we could now just access ::SampleType and be done,
// but no -- that expression is sufficiently complicated that the C++
// parser gets confused and throws an error about "dependent-name ... is
// parsed as a non-type". Luckily the same error message also gives the
// fix: just add a typename to the beginning.
// 5. With that, we can finally access the SampleType and use it to construct
// the right kind of Relation.
std::visit(
[&](const auto& v) {
rv = new Relation<typename
std::remove_reference_t<decltype(*v)>::SampleType>(
name, dist_spec, domains);
}, dv);

return rv;
}
14 changes: 4 additions & 10 deletions cxx/relation_variant.hh
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,13 @@
#include <variant>
#include <vector>

#include "domain.hh"
#include "relation.hh"
#include "util_distribution_variant.hh"

class BetaBernoulli;
class Bigram;
class DirichletCategorical;
class Normal;
class Domain;
template <typename DistributionType>
class Relation;

using RelationVariant =
std::variant<Relation<BetaBernoulli>*, Relation<Bigram>*,
Relation<DirichletCategorical>*, Relation<Normal>*>;
std::variant<Relation<std::string>*, Relation<double>*,
Relation<int>*, Relation<bool>*>;

RelationVariant relation_from_spec(const std::string& name,
const DistributionSpec& dist_spec,
Expand Down
17 changes: 17 additions & 0 deletions cxx/relation_variant_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test relation_variant

#include "relation_variant.hh"

#include <boost/test/included/unit_test.hpp>
namespace tt = boost::test_tools;

BOOST_AUTO_TEST_CASE(test_relation_variant) {
std::vector<Domain *> domains;
domains.push_back(new Domain("D1"));
RelationVariant rv = relation_from_spec(
"r1", parse_distribution_spec("bernoulli"), domains);
Relation<bool>* rb = std::get<Relation<bool>*>(rv);
BOOST_TEST(rb->name == "r1");
}
4 changes: 2 additions & 2 deletions cxx/tests/test_hirm_animals.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ int main(int argc, char** argv) {
}
// Check relations agree.
for (const auto& [r, rm_var] : irm->relations) {
auto rx = std::get<Relation<BetaBernoulli>*>(irx->relations.at(r));
auto rm = std::get<Relation<BetaBernoulli>*>(rm_var);
auto rx = std::get<Relation<bool>*>(irx->relations.at(r));
auto rm = std::get<Relation<bool>*>(rm_var);
assert(rm->data == rx->data);
assert(rm->data_r == rx->data_r);
assert(rm->clusters.size() == rx->clusters.size());
Expand Down
13 changes: 7 additions & 6 deletions cxx/tests/test_irm_two_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "util_io.hh"
#include "util_math.hh"

using T_r = Relation<BetaBernoulli>*;
using T_r = Relation<bool>*;

int main(int argc, char** argv) {
std::string path_base = "assets/two_relations";
Expand Down Expand Up @@ -111,12 +111,13 @@ int main(int argc, char** argv) {
// transitioned.
assert(abs(irx.logp_score() - irm.logp_score()) > 1e-8);
for (const auto& r : {"R1", "R2"}) {
auto r1m = std::get<Relation<BetaBernoulli>*>(irm.relations.at(r));
auto r1x = std::get<Relation<BetaBernoulli>*>(irx.relations.at(r));
auto r1m = std::get<Relation<bool>*>(irm.relations.at(r));
auto r1x = std::get<Relation<bool>*>(irx.relations.at(r));
for (const auto& [c, distribution] : r1m->clusters) {
auto dx = r1x->clusters.at(c);
dx->alpha = distribution->alpha;
dx->beta = distribution->beta;
auto dx = reinterpret_cast<BetaBernoulli*>(r1x->clusters.at(c));
auto dy = reinterpret_cast<BetaBernoulli*>(distribution);
dx->alpha = dy->alpha;
dx->beta = dy->beta;
}
}
assert(abs(irx.logp_score() - irm.logp_score()) < 1e-8);
Expand Down
Loading