Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasColthurst committed Jun 24, 2024
1 parent 00c3230 commit 3c498cc
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 15 deletions.
3 changes: 2 additions & 1 deletion cxx/distributions/nonconjugate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <map>
#include <random>
#include <vector>
#include "distributions/base.hh"

template <typename T>
Expand Down Expand Up @@ -52,7 +53,7 @@ class NonconjugateDistribution : public Distribution<T> {
virtual void transition_theta(std::mt19937* prng) {
std::vector<double> old_latents = store_latents();
double old_logp_score = logp_score();
init_theta();
init_theta(prng);
double new_logp_score = logp_score();
double threshold = std::exp(new_logp_score - old_logp_score);
std::uniform_real_distribution rnd(0.0, 1.0);
Expand Down
15 changes: 6 additions & 9 deletions cxx/distributions/skellam.hh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <cassert>
#include <cmath>

#include "distributions/nonconjugate.hh"
Expand Down Expand Up @@ -68,19 +69,15 @@ class Skellam : public NonconjugateDistribution<int> {

std::vector<double> store_latents() {
std::vector<double> v;
v.push_back(mean1);
v.push_back(mean2);
v.push_back(stddev1);
v.push_back(stddev2);
v.push_back(mu1);
v.push_back(mu2);
return v;
}

void set_latents(const std::vector<double>& v) {
assert(v.size() == 4);
mean1 = v[0];
mean2 = v[1];
stddev1 = v[2];
stddev2 = v[3];
assert(v.size() == 2);
mu1 = v[0];
mu2 = v[1];
}

};
8 changes: 3 additions & 5 deletions cxx/distributions/skellam_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,14 @@ BOOST_AUTO_TEST_CASE(set_and_store_latents) {
std::mt19937 prng;

sd.init_theta(&prng);

std::vector<double> v = sd.store_latents();

sd.init_theta(&prng);

BOOST_TEST(v != sd.store_latents());
BOOST_TEST(v != sd.store_latents(), tt::per_element());

sd.set_latents(v);

BOOST_TEST(v == sd.store_latents());
BOOST_TEST(v == sd.store_latents(), tt::per_element());
}

BOOST_AUTO_TEST_CASE(transition_theta) {
Expand All @@ -66,5 +64,5 @@ BOOST_AUTO_TEST_CASE(transition_theta) {
sd.transition_theta(&prng);
}

BOOST_TEST(sd.mean1 > sd.mean2);
BOOST_TEST(sd.mu1 > sd.mu2);
}

0 comments on commit 3c498cc

Please sign in to comment.