Skip to content

Commit

Permalink
Merge pull request #154 from probcomp/080124-emilyaf-sample-hirm
Browse files Browse the repository at this point in the history
Add a sampling method to HIRM.
  • Loading branch information
emilyfertig authored Aug 13, 2024
2 parents d389339 + b018f27 commit 5eb0ffd
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 25 deletions.
21 changes: 17 additions & 4 deletions cxx/clean_relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
#include <unordered_map>
#include <vector>

#include "distributions/get_distribution.hh"
#include "domain.hh"
#include "emissions/get_emission.hh"
#include "relation.hh"
#include "util_hash.hh"
#include "util_math.hh"
#include "distributions/get_distribution.hh"
#include "emissions/get_emission.hh"

// T_clean_relation is the text we get from reading a line of the schema
// file; CleanRelation is the object that does the work.
Expand Down Expand Up @@ -83,9 +83,9 @@ class CleanRelation : public Relation<T> {
return std::visit(spec_to_dist, prior_spec);
}

void incorporate(std::mt19937* prng, const T_items& items, ValueType value) {
// Incorporates a new vector of items and returns their cluster assignments.
T_items incorporate_items(std::mt19937* prng, const T_items& items) {
assert(!data.contains(items));
data[items] = value;
for (int i = 0; i < std::ssize(domains); ++i) {
domains[i]->incorporate(prng, items[i]);
if (!data_r.at(domains[i]->name).contains(items[i])) {
Expand All @@ -98,6 +98,19 @@ class CleanRelation : public Relation<T> {
if (!clusters.contains(z)) {
clusters[z] = make_new_distribution(prng);
}
return z;
}

void incorporate(std::mt19937* prng, const T_items& items, ValueType value) {
T_items z = incorporate_items(prng, items);
data[items] = value;
clusters.at(z)->incorporate(value);
}

void incorporate_sample(std::mt19937* prng, const T_items& items) {
T_items z = incorporate_items(prng, items);
ValueType value = sample_at_items(prng, items);
data[items] = value;
clusters.at(z)->incorporate(value);
}

Expand Down
13 changes: 13 additions & 0 deletions cxx/clean_relation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,16 @@ BOOST_AUTO_TEST_CASE(test_from_string) {
std::string s = R2.from_string("hello world");
BOOST_TEST(s == "hello world");
}

// Sampling into a clean relation at three distinct item tuples should leave
// exactly three observations stored in the relation.
BOOST_AUTO_TEST_CASE(test_incorporate_sample) {
  std::mt19937 prng;
  Domain first_domain("D1");
  Domain second_domain("D2");
  DistributionSpec normal_spec("normal");
  CleanRelation<double> relation("R1", normal_spec,
                                 {&first_domain, &second_domain});

  // Three different item tuples; two of them share an entity in each domain.
  relation.incorporate_sample(&prng, {0, 1});
  relation.incorporate_sample(&prng, {0, 2});
  relation.incorporate_sample(&prng, {5, 2});

  BOOST_TEST(relation.data.size() == 3);
}
51 changes: 51 additions & 0 deletions cxx/hirm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,57 @@ double HIRM::logp_score() const {
return logp_score_crp + logp_score_irms;
}

// Samples a value for `items` in relation `r` and incorporates it.
//
// If the schema declares `r` as a noisy relation and its base relation holds
// no observation at the corresponding base items, a value is first sampled and
// incorporated into the base relation (recursively, since the base relation
// may itself be noisy). Only then is the sample for `r` drawn.
void HIRM::sample_and_incorporate_relation(std::mt19937* prng,
                                           const std::string& r,
                                           T_items& items) {
  if (T_noisy_relation* trel = std::get_if<T_noisy_relation>(&schema.at(r))) {
    std::visit(
        [&](auto nr) {
          using T = typename std::remove_pointer_t<decltype(nr)>::ValueType;
          // The schema says `r` is noisy, so the object behind the
          // Relation<T>* is expected to be a NoisyRelation<T>. static_cast is
          // the proper named cast for a base-to-derived downcast; unlike
          // reinterpret_cast it is type-checked against the class hierarchy.
          NoisyRelation<T>* noisy_rel = static_cast<NoisyRelation<T>*>(nr);
          T_items base_items = noisy_rel->get_base_items(items);
          if (!noisy_rel->base_relation->get_data().contains(base_items)) {
            sample_and_incorporate_relation(prng, trel->base_relation,
                                            base_items);
          }
        },
        get_relation(r));
  }
  std::visit([&](auto rel) { rel->incorporate_sample(prng, items); },
             get_relation(r));
}

void HIRM::sample_and_incorporate(std::mt19937* prng, int n) {
std::map<std::string, CRP> domain_crps;
for (const auto& [r, spec] : schema) {
// If the relation is a leaf, sample n observations of it.
if (!base_to_noisy_relations.contains(r)) {
const std::vector<std::string>& r_domains =
std::visit([](auto trel) { return trel.domains; }, spec);
int num_samples = 0;
while (num_samples < n) {
std::vector<int> entities;
entities.reserve(r_domains.size());
for (auto it = r_domains.cbegin(); it != r_domains.cend(); ++it) {
int entity = domain_crps[*it].sample(prng);
int crp_item = domain_crps[*it].assignments.size();
domain_crps[*it].incorporate(crp_item, entity);
entities.push_back(entity);
}
bool r_contains_items = std::visit(
[&](auto rel) { return rel->get_data().contains(entities); },
get_relation(r));
if (!r_contains_items) {
sample_and_incorporate_relation(prng, r, entities);
++num_samples;
}
}
}
}
}

HIRM::~HIRM() {
for (const auto& [table, irm] : irms) {
delete irm;
Expand Down
15 changes: 14 additions & 1 deletion cxx/hirm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include <unordered_set>
#include <variant>

#include "distributions/get_distribution.hh"
#include "irm.hh"
#include "relation.hh"
#include "transition_latent_value.hh"
#include "distributions/get_distribution.hh"

class HIRM {
public:
Expand Down Expand Up @@ -65,6 +65,19 @@ class HIRM {

double logp_score() const;

// Samples `n` values from each leaf relation (i.e. each relation that is not
// the base relation of a different relation). Recursively samples some number
// of values from non-leaf relations as needed to get to `n` leaf samples.
//
// Entities for each domain are drawn from a CRP kept per domain name for the
// duration of the call. A drawn item tuple that is already present in the
// relation is discarded and redrawn, so exactly `n` new tuples are
// incorporated into each leaf relation.
//
// Beware: since `n` unique values are sampled from CRPs, if `n` is too high
// relative to the CRP `alpha`s, this function might take a very long time.
void sample_and_incorporate(std::mt19937* prng, int n);

// Incorporates a sample into relation `r` at `items`. If `r` is a noisy
// relation, this function recursively incorporates a sample into the base
// relation, if necessary (i.e. when the base relation has no observation yet
// at the base items corresponding to `items`).
// NOTE(review): callers appear to be expected to pass `items` that `r` does
// not already contain (CleanRelation::incorporate_items asserts this) --
// confirm for noisy relations.
void sample_and_incorporate_relation(std::mt19937* prng, const std::string& r,
T_items& items);

~HIRM();

// Disable copying.
Expand Down
37 changes: 37 additions & 0 deletions cxx/hirm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <boost/test/included/unit_test.hpp>
#include <random>

#include "distributions/get_distribution.hh"

namespace tt = boost::test_tools;
Expand Down Expand Up @@ -71,6 +72,42 @@ BOOST_AUTO_TEST_CASE(test_hirm) {
BOOST_TEST(R2->get_data().at({0, 3}) != 0.5);
}

// Samples 20 observations for every leaf relation of a small schema and checks
// the number of observations each relation ends up with.
BOOST_AUTO_TEST_CASE(test_hirm_sample) {
  // R1 and R4 are clean relations. R2 and R3 are noisy relations based on R1,
  // and R5 is a noisy relation based on R4.
  std::map<std::string, T_relation> schema{
      {"R2", T_noisy_relation{{"D1", "D1", "D2"},
                              true,
                              EmissionSpec("sometimes_bitflip"),
                              "R1"}},
      {"R1",
       T_clean_relation{{"D1", "D1"}, false, DistributionSpec("bernoulli")}},
      {"R3", T_noisy_relation{{"D1", "D1", "D5"},
                              true,
                              EmissionSpec("sometimes_bitflip"),
                              "R1"}},
      {"R4", T_clean_relation{{"D1", "D3"}, false, DistributionSpec("normal")}},
      {"R5", T_noisy_relation{{"D1", "D3", "D4"},
                              true,
                              EmissionSpec("sometimes_gaussian"),
                              "R4"}}};

  std::mt19937 prng;
  HIRM hirm(schema, &prng);
  hirm.sample_and_incorporate(&prng, 20);

  // Each relation that is sampled directly receives exactly 20 observations.
  const auto r2_count =
      std::get<Relation<bool>*>(hirm.get_relation("R2"))->get_data().size();
  BOOST_TEST(r2_count == 20);
  const auto r3_count =
      std::get<Relation<bool>*>(hirm.get_relation("R3"))->get_data().size();
  BOOST_TEST(r3_count == 20);
  const auto r5_count =
      std::get<Relation<double>*>(hirm.get_relation("R5"))->get_data().size();
  BOOST_TEST(r5_count == 20);

  // R4 is only sampled as the base of R5, so it holds between 1 and 20
  // observations.
  int nobs_R4 =
      std::get<Relation<double>*>(hirm.get_relation("R4"))->get_data().size();
  BOOST_TEST(nobs_R4 > 0);
  BOOST_TEST(nobs_R4 <= 20);

  BOOST_TEST(hirm.logp_score() < 0.0);
}

BOOST_AUTO_TEST_CASE(test_hirm_relation_names) {
std::mt19937 prng;
std::map<std::string, T_relation> schema1{
Expand Down
14 changes: 11 additions & 3 deletions cxx/noisy_relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
#include <vector>

#include "clean_relation.hh"
#include "distributions/get_distribution.hh"
#include "domain.hh"
#include "emissions/base.hh"
#include "relation.hh"
#include "util_hash.hh"
#include "util_math.hh"
#include "distributions/get_distribution.hh"
#include "emissions/base.hh"

// T_noisy_relation is the text we get from reading a line of the schema file;
// NoisyRelation is the object that does the work.
Expand Down Expand Up @@ -68,6 +68,15 @@ class NoisyRelation : public Relation<T> {
base_to_noisy_items[base_items].insert(items);
}

void incorporate_sample(std::mt19937* prng, const T_items& items) {
T_items z = emission_relation.incorporate_items(prng, items);
const ValueType& base_val = get_base_value(items);
ValueType value = sample_at_items(prng, items);
data[items] = value;
emission_relation.clusters.at(z)->incorporate(
std::make_pair(base_val, value));
}

// incorporate_to_cluster and unincorporate_from_cluster should be used with
// care, since they mutate the clusters only and not the relation. In
// particular, for every call to unincorporate_from_cluster, there must be a
Expand Down Expand Up @@ -186,7 +195,6 @@ class NoisyRelation : public Relation<T> {
}

ValueType sample_at_items(std::mt19937* prng, const T_items& items) const {
// TODO(emilyaf): Maybe take a sample if there is no base value.
const ValueType& base_value = get_base_value(items);
if (emission_relation.clusters.contains(items)) {
return reinterpret_cast<Emission<ValueType>*>(
Expand Down
20 changes: 20 additions & 0 deletions cxx/noisy_relation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,23 @@ BOOST_AUTO_TEST_CASE(test_cluster_logp_sample) {
double lp = NR1.cluster_or_prior_logp(&prng, {0, 1, 2}, sample);
BOOST_TEST(lp < 0.0);
}

// Sampling three noisy observations that share one base observation should add
// three entries to the noisy relation and leave the base relation untouched.
BOOST_AUTO_TEST_CASE(test_incorporate_sample) {
  std::mt19937 prng;
  Domain first_domain("D1");
  Domain second_domain("D2");
  Domain third_domain("D3");

  // Base relation with a single observation at items {0, 1}.
  DistributionSpec base_spec("normal");
  CleanRelation<double> base("R1", base_spec, {&first_domain, &second_domain});
  base.incorporate(&prng, {0, 1}, 3.);

  EmissionSpec emission_spec("sometimes_gaussian");
  NoisyRelation<double> noisy("NR1", emission_spec,
                              {&first_domain, &second_domain, &third_domain},
                              &base);

  // All three item tuples start with {0, 1}, the base items observed above.
  noisy.incorporate_sample(&prng, {0, 1, 1});
  noisy.incorporate_sample(&prng, {0, 1, 5});
  noisy.incorporate_sample(&prng, {0, 1, 2});

  BOOST_TEST(noisy.data.size() == 3);
  BOOST_TEST(base.data.size() == 1);
}
48 changes: 31 additions & 17 deletions cxx/relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#include <unordered_map>
#include <vector>

#include "distributions/get_distribution.hh"
#include "domain.hh"
#include "util_hash.hh"
#include "distributions/get_distribution.hh"

typedef std::vector<T_item> T_items;
typedef VectorIntHash H_items;
Expand All @@ -21,44 +21,59 @@ class Relation {
public:
typedef T ValueType;

virtual void incorporate(std::mt19937* prng, const T_items& items, ValueType value) = 0;
virtual void incorporate(std::mt19937* prng, const T_items& items,
ValueType value) = 0;

virtual void unincorporate(const T_items& items) = 0;

virtual double logp(const T_items& items, ValueType value, std::mt19937* prng) = 0;
virtual double logp(const T_items& items, ValueType value,
std::mt19937* prng) = 0;

virtual double logp_score() const = 0;

virtual double cluster_or_prior_logp(std::mt19937* prng, const T_items& items, const ValueType& value) const = 0;
virtual double cluster_or_prior_logp(std::mt19937* prng, const T_items& items,
const ValueType& value) const = 0;

virtual ValueType sample_at_items(std::mt19937* prng, const T_items& items) const = 0;
virtual ValueType sample_at_items(std::mt19937* prng,
const T_items& items) const = 0;

virtual void incorporate_to_cluster(const T_items& items, const ValueType& value) = 0;
// Takes a sample from the cluster containing `items` and incorporates it.
virtual void incorporate_sample(std::mt19937* prng, const T_items& items) = 0;

virtual void incorporate_to_cluster(const T_items& items,
const ValueType& value) = 0;

virtual void unincorporate_from_cluster(const T_items& items) = 0;

// TODO(emilyaf): Standardize passing PRNG first or last.
virtual std::vector<double> logp_gibbs_exact(
const Domain& domain, const T_item& item, std::vector<int> tables,
std::mt19937* prng) = 0;
virtual std::vector<double> logp_gibbs_exact(const Domain& domain,
const T_item& item,
std::vector<int> tables,
std::mt19937* prng) = 0;

virtual void set_cluster_assignment_gibbs(const Domain& domain, const T_item& item,
int table, std::mt19937* prng) = 0;
virtual void set_cluster_assignment_gibbs(const Domain& domain,
const T_item& item, int table,
std::mt19937* prng) = 0;

virtual void transition_cluster_hparams(std::mt19937* prng, int num_theta_steps) = 0;
virtual void transition_cluster_hparams(std::mt19937* prng,
int num_theta_steps) = 0;

// Accessor/convenience methods, mostly for subclass members that can't be accessed through the base class.
// Accessor/convenience methods, mostly for subclass members that can't be
// accessed through the base class.
virtual const std::vector<Domain*>& get_domains() const = 0;

virtual const ValueType& get_value(const T_items& items) const = 0;

virtual const std::unordered_map<const T_items, ValueType, H_items>& get_data() const = 0;
virtual const std::unordered_map<const T_items, ValueType, H_items>&
get_data() const = 0;

virtual void update_value(const T_items& items, const ValueType& value) = 0;

virtual std::vector<int> get_cluster_assignment(const T_items& items) const = 0;
virtual std::vector<int> get_cluster_assignment(
const T_items& items) const = 0;

virtual bool has_observation(const Domain& domain, const T_item& item) const = 0;
virtual bool has_observation(const Domain& domain,
const T_item& item) const = 0;

// Convert a string to ValueType.
ValueType from_string(const std::string& s) {
Expand All @@ -69,7 +84,6 @@ class Relation {
};

virtual ~Relation() = default;

};


Expand Down

0 comments on commit 5eb0ffd

Please sign in to comment.