Add transition_theta to outer inference code #71

Merged (5 commits), Jun 26, 2024
2 changes: 2 additions & 0 deletions cxx/distributions/BUILD
@@ -10,6 +10,7 @@ cc_library(
":crp",
":dirichlet_categorical",
":normal",
":skellam",
],
)

@@ -85,6 +86,7 @@ cc_library(

cc_library(
name = "skellam",
srcs = ["skellam.cc"],
hdrs = ["skellam.hh"],
deps = [
":nonconjugate",
2 changes: 1 addition & 1 deletion cxx/distributions/nonconjugate.hh
@@ -24,7 +24,7 @@ class NonconjugateDistribution : public Distribution<T> {
virtual void init_theta(std::mt19937* prng) = 0;

// Return the current latent values as a vector.
virtual std::vector<double> store_latents() = 0;
virtual std::vector<double> store_latents() const = 0;

// Set the current latent values from a vector.
virtual void set_latents(const std::vector<double>& v) = 0;
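The const added to store_latents() above matches its role as a read-only snapshot of the latent state. transition_theta itself is not shown in this diff, but the store/restore contract these two virtuals define is exactly what a Metropolis-Hastings move needs. A minimal sketch, assuming the base class exposes a logp_score() that scores the incorporated observations together with the current latents (an assumption here, not something shown in this diff), and using an arbitrary proposal scale:

// Sketch only: a random-walk Metropolis-Hastings move written against the
// store_latents()/set_latents() interface. logp_score() and the proposal
// scale of 0.1 are assumptions; a real move must also keep proposals inside
// the latent support (e.g., by walking in log space for positive rates).
#include <cmath>
#include <random>
#include <vector>
#include "distributions/nonconjugate.hh"

template <typename T>
void mh_step(NonconjugateDistribution<T>* dist, std::mt19937* prng) {
  std::vector<double> old_latents = dist->store_latents();
  double old_score = dist->logp_score();
  std::vector<double> proposal = old_latents;
  std::normal_distribution<double> step(0.0, 0.1);
  for (double& v : proposal) {
    v += step(*prng);
  }
  dist->set_latents(proposal);
  double new_score = dist->logp_score();
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  if (std::log(unif(*prng)) >= new_score - old_score) {
    // Reject: restore the snapshot taken before the proposal.
    dist->set_latents(old_latents);
  }
}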
66 changes: 66 additions & 0 deletions cxx/distributions/skellam.cc
@@ -0,0 +1,66 @@
#include "distributions/skellam.hh"

#include <cassert>
#include <cmath>
#include <numbers>
#include <random>
#include "util_math.hh"

double lognormal_logp(double x, double mean, double stddev) {
double y = (std::log(x) - mean) / stddev;
return - y*y / 2.0
- std::log(x * stddev) - 0.5 * std::log(2.0 * std::numbers::pi);
}

double Skellam::logp(const int& x) const {
return -mu1 - mu2 + (x / 2.0) * std::log(mu1 / mu2)
// TODO(thomaswc): Replace this with something more numerically stable.
+ std::log(std::cyl_bessel_i(x, 2.0 * std::sqrt(mu1 * mu2)));
}

int Skellam::sample(std::mt19937* prng) {
std::poisson_distribution<int> d1(mu1);
std::poisson_distribution<int> d2(mu2);
return d1(*prng) - d2(*prng);
}

void Skellam::transition_hyperparameters(std::mt19937* prng) {
std::vector<double> logps;
std::vector<std::tuple<double, double, double, double>> hypers;
for (double tmean1 : MEAN_GRID) {
for (double tstddev1 : STDDEV_GRID) {
for (double tmean2 : MEAN_GRID) {
for (double tstddev2 : STDDEV_GRID) {
double lp = lognormal_logp(mu1, tmean1, tstddev1)
+ lognormal_logp(mu2, tmean2, tstddev2);
logps.push_back(lp);
hypers.push_back(
std::make_tuple(tmean1, tstddev1, tmean2, tstddev2));
}
}
}
}
int i = sample_from_logps(logps, prng);
mean1 = std::get<0>(hypers[i]);
stddev1 = std::get<1>(hypers[i]);
mean2 = std::get<2>(hypers[i]);
stddev2 = std::get<3>(hypers[i]);
}

void Skellam::init_theta(std::mt19937* prng) {
Review thread (on init_theta):

Collaborator: Out of curiosity, is there a need for Skellam to be initialized randomly vs. statically like the other distributions? Or do we see other distributions initializing their parameters from their hyperprior? Mainly asking since it seems a little odd that Skellam's path is different from other distributions.

Collaborator (author): Yeah, all the NonconjugateDistributions will need random initializations. The reason is that the NonconjugateDistributions can't efficiently marginalize over the latent parameters, so instead they store current values of their latents, which evolve whenever transition_theta is called. And I believe the overall learning procedure works better when the latents are randomly initialized, though I could be wrong about that. I personally always use random initialization when doing Metropolis-Hastings, but some people always use the origin or the mean of the sampling distribution or something like that.

Anyway, we could definitely raise this issue on the Slack channel if you like, but the above is my current understanding. Oh, and another thing: the GenDB doc says that for "distributions that explicitly represent their latents", "the code to initialize the model state will need to include code for sampling initial parameters from the parameter prior". So this is that.

std::normal_distribution<double> d1(mean1, stddev1);
std::normal_distribution<double> d2(mean2, stddev2);
mu1 = std::exp(d1(*prng));
mu2 = std::exp(d2(*prng));
}

std::vector<double> Skellam::store_latents() const {
std::vector<double> v;
v.push_back(mu1);
v.push_back(mu2);
return v;
}

void Skellam::set_latents(const std::vector<double>& v) {
assert(v.size() == 2);
mu1 = v[0];
mu2 = v[1];
}
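The TODO in Skellam::logp above is real: the Skellam log pmf is -mu1 - mu2 + (x/2) log(mu1/mu2) + log I_x(2 sqrt(mu1 mu2)), and since I_v(z) grows like exp(z)/sqrt(2 pi z), std::cyl_bessel_i overflows for large arguments long before the log probability itself leaves double range. One possible direction, sketched with an untuned cutoff and keeping only the leading asymptotic term (so it drops the order-dependent corrections and is only accurate when z is much larger than v*v):

// Sketch: log(I_v(z)) without overflow for large z. Keeps only the leading
// term of I_v(z) ~ exp(z)/sqrt(2*pi*z) * (1 - (4*v*v - 1)/(8*z) + ...),
// so the v-dependent corrections are dropped. The cutoff is a guess.
#include <cmath>
#include <numbers>

double log_bessel_i(double v, double z) {
  if (z < 500.0) {
    return std::log(std::cyl_bessel_i(v, z));
  }
  return z - 0.5 * std::log(2.0 * std::numbers::pi * z);
}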
66 changes: 6 additions & 60 deletions cxx/distributions/skellam.hh
@@ -1,20 +1,10 @@
#pragma once

#include <cassert>
#include <cmath>

#include "distributions/nonconjugate.hh"
#include "util_math.hh"

#define MEAN_GRID { -10.0, 0.0, 10.0 }
#define STDDEV_GRID { 0.1, 1.0, 10.0 }

double lognormal_logp(double x, double mean, double stddev) {
double y = (std::log(x) - mean) / stddev;
return - y*y / 2.0
- std::log(x * stddev) - 0.5 * std::log(2.0 * std::numbers::pi);
}

class Skellam : public NonconjugateDistribution<int> {
public:
// Skellam distribution with log Normal hyperprior of latent rates.
@@ -25,59 +15,15 @@ class Skellam : public NonconjugateDistribution<int> {
Skellam(): mean1(0.0), mean2(0.0), stddev1(1.0), stddev2(1.0),
mu1(1.0), mu2(1.0) {}

double logp(const int& x) const {
return -mu1 - mu2 + (x / 2.0) * std::log(mu1 / mu2)
// TODO(thomaswc): Replace this with something more numerically stable.
+ std::log(std::cyl_bessel_i(x, 2.0 * std::sqrt(mu1 * mu2)));
}

int sample(std::mt19937* prng) {
std::poisson_distribution<int> d1(mu1);
std::poisson_distribution<int> d2(mu2);
return d1(*prng) - d2(*prng);
}
double logp(const int& x) const;

void transition_hyperparameters(std::mt19937* prng) {
std::vector<double> logps;
std::vector<std::tuple<double, double, double, double>> hypers;
for (double tmean1 : MEAN_GRID) {
for (double tstddev1 : STDDEV_GRID) {
for (double tmean2 : MEAN_GRID) {
for (double tstddev2 : STDDEV_GRID) {
double lp = lognormal_logp(mu1, tmean1, tstddev1)
+ lognormal_logp(mu2, tmean2, tstddev2);
logps.push_back(lp);
hypers.push_back(
std::make_tuple(tmean1, tstddev1, tmean2, tstddev2));
}
}
}
}
int i = sample_from_logps(logps, prng);
mean1 = std::get<0>(hypers[i]);
stddev1 = std::get<1>(hypers[i]);
mean2 = std::get<2>(hypers[i]);
stddev2 = std::get<3>(hypers[i]);
}
int sample(std::mt19937* prng);

void init_theta(std::mt19937* prng) {
std::normal_distribution<double> d1(mean1, stddev1);
std::normal_distribution<double> d2(mean2, stddev2);
mu1 = std::exp(d1(*prng));
mu2 = std::exp(d2(*prng));
}
void transition_hyperparameters(std::mt19937* prng);

std::vector<double> store_latents() {
std::vector<double> v;
v.push_back(mu1);
v.push_back(mu2);
return v;
}
void init_theta(std::mt19937* prng);

void set_latents(const std::vector<double>& v) {
assert(v.size() == 2);
mu1 = v[0];
mu2 = v[1];
}
std::vector<double> store_latents() const;

void set_latents(const std::vector<double>& v);
};
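For reference, the lifecycle the header and skellam.cc define, as a minimal usage sketch (the seed and printing are illustrative only):

// Sketch: construct a Skellam, draw its latent rates from the log-normal
// hyperprior, then sample an observation and score it.
#include <iostream>
#include <random>
#include "distributions/skellam.hh"

int main() {
  std::mt19937 prng(42);
  Skellam sk;                // defaults: mean1 = mean2 = 0, stddev1 = stddev2 = 1
  sk.init_theta(&prng);      // mu1, mu2 ~ exp(Normal(mean_i, stddev_i))
  int x = sk.sample(&prng);  // difference of two Poisson draws
  std::cout << "x = " << x << ", logp = " << sk.logp(x) << "\n";
  return 0;
}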
4 changes: 2 additions & 2 deletions cxx/hirm.cc
@@ -205,7 +205,7 @@ void HIRM::remove_relation(const std::string& name) {

double HIRM::logp(
const std::vector<std::tuple<std::string, T_items, ObservationVariant>>&
observations) {
observations, std::mt19937* prng) {
std::unordered_map<
int, std::vector<std::tuple<std::string, T_items, ObservationVariant>>>
obs_dict;
@@ -219,7 +219,7 @@ double HIRM::logp(
}
double logp = 0.0;
for (const auto& [t, o] : obs_dict) {
logp += irms.at(t)->logp(o);
logp += irms.at(t)->logp(o, prng);
}
return logp;
}
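Why logp now takes a prng: as the irm.cc hunk below shows, scoring an observation that would land in a previously unseen cluster means instantiating a fresh cluster prior via cluster_prior_from_spec, which now receives the prng (presumably so nonconjugate priors can sample their initial latents), so HIRM::logp simply threads the generator through to each component IRM.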
2 changes: 1 addition & 1 deletion cxx/hirm.hh
@@ -49,7 +49,7 @@ class HIRM {

double logp(
const std::vector<std::tuple<std::string, T_items, ObservationVariant>>&
observations);
observations, std::mt19937* prng);

double logp_score() const;

13 changes: 8 additions & 5 deletions cxx/irm.cc
@@ -74,7 +74,7 @@ void IRM::transition_cluster_assignment_item(std::mt19937* prng,
auto accumulate_logps = [&](auto rel) {
if (rel->has_observation(*domain, item)) {
std::vector<double> lp_relation =
rel->logp_gibbs_exact(*domain, item, tables);
rel->logp_gibbs_exact(*domain, item, tables, prng);
assert(lp_relation.size() == tables.size());
assert(lp_relation.size() == logps.size());
for (int i = 0; i < std::ssize(logps); ++i) {
@@ -93,7 +93,7 @@
if (choice != domain->get_cluster_assignment(item)) {
auto set_cluster_r = [&](auto rel) {
if (rel->has_observation(*domain, item)) {
rel->set_cluster_assignment_gibbs(*domain, item, choice);
rel->set_cluster_assignment_gibbs(*domain, item, choice, prng);
}
};
for (const std::string& r : domain_to_relations.at(d)) {
@@ -105,7 +105,7 @@

double IRM::logp(
const std::vector<std::tuple<std::string, T_items, ObservationVariant>>&
observations) {
observations, std::mt19937* prng) {
std::unordered_map<std::string, std::unordered_set<T_items, H_items>>
relation_items_seen;
std::unordered_map<std::string, std::unordered_set<T_item>>
@@ -200,7 +200,7 @@ double IRM::logp(
if (rel->clusters.contains(z)) {
return rel->clusters.at(z)->logp(v);
}
DistributionVariant prior = cluster_prior_from_spec(rel->dist_spec);
DistributionVariant prior = cluster_prior_from_spec(rel->dist_spec, prng);
return std::visit(
[&](const auto& dist_variant) {
auto v2 = std::get<
@@ -295,7 +295,7 @@ void IRM::remove_relation(const std::string& name) {
}

void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total,
bool verbose) {
bool verbose, int num_theta_steps) {
// TRANSITION ASSIGNMENTS.
for (const auto& [d, domain] : irm->domains) {
for (const auto item : domain->items) {
@@ -310,6 +310,9 @@ void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total,
[&](auto r) {
for (const auto& [c, distribution] : r->clusters) {
clock_t t = clock();
for (int i = 0; i < num_theta_steps; ++i) {
distribution->transition_theta(prng);
}
distribution->transition_hyperparameters(prng);
REPORT_SCORE(verbose, t, t_total, irm);
}
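The loop added above is the point of the PR: each inference sweep now applies num_theta_steps transition_theta moves to every cluster's distribution before resampling that distribution's hyperparameters, so the latents of nonconjugate distributions actually get updated during inference. A caller sketch, assuming the IRM has been constructed and populated with observations elsewhere:

// Sketch: run 100 sweeps with 20 theta moves per cluster per sweep
// instead of the default 10. The IRM setup is assumed to happen elsewhere.
#include <random>
#include "irm.hh"

void run_inference(IRM* irm) {
  std::mt19937 prng(10);
  double t_total = 0.0;
  for (int step = 0; step < 100; ++step) {
    single_step_irm_inference(&prng, irm, t_total,
                              /*verbose=*/false, /*num_theta_steps=*/20);
  }
}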
4 changes: 2 additions & 2 deletions cxx/irm.hh
@@ -41,7 +41,7 @@ class IRM {
const T_item& item);
double logp(
const std::vector<std::tuple<std::string, T_items, ObservationVariant>>&
observations);
observations, std::mt19937* prng);

double logp_score() const;

@@ -57,4 +57,4 @@

// Run a single step of inference on an IRM model.
void single_step_irm_inference(std::mt19937* prng, IRM* irm, double& t_total,
bool verbose);
bool verbose, int num_theta_steps = 10);
2 changes: 1 addition & 1 deletion cxx/irm_test.cc
@@ -26,7 +26,7 @@ BOOST_AUTO_TEST_CASE(test_irm) {

auto obs0 = observation_string_to_value("0", DistributionEnum::bernoulli);

double logp_x = irm.logp({{"R1", {1, 2}, obs0}});
double logp_x = irm.logp({{"R1", {1, 2}, obs0}}, &prng);
BOOST_TEST(logp_x < 0.0);

irm.incorporate(&prng, "R1", {1, 2}, obs0);