Skip to content

Commit

Permalink
Fix transition_hyperparameter bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasColthurst committed Aug 16, 2024
1 parent 8acf4b7 commit ca24eaf
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 20 deletions.
10 changes: 7 additions & 3 deletions cxx/distributions/beta_bernoulli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ void BetaBernoulli::transition_hyperparameters(std::mt19937* prng) {
}
}
}
int i = sample_from_logps(logps, prng);
alpha = hypers[i].first;
beta = hypers[i].second;
if (logps.empty()) {
printf("Warning! All hyperparamters for BetaBernoulli give nans!\n");
} else {
int i = sample_from_logps(logps, prng);
alpha = hypers[i].first;
beta = hypers[i].second;
}
}
8 changes: 6 additions & 2 deletions cxx/distributions/bigram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ void Bigram::transition_hyperparameters(std::mt19937* prng) {
alphas.push_back(alphat);
}
}
int i = sample_from_logps(logps, prng);
set_alpha(alphas[i]);
if (logps.empty()) {
printf("Warning! All hyperparameters for Bigram give nans!\n");
} else {
int i = sample_from_logps(logps, prng);
set_alpha(alphas[i]);
}
}
8 changes: 6 additions & 2 deletions cxx/distributions/dirichlet_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ void DirichletCategorical::transition_hyperparameters(std::mt19937* prng) {
alphas.push_back(alpha);
}
}
int i = sample_from_logps(logps, prng);
alpha = alphas[i];
if (alphas.empty()) {
printf("Warning: all Dirichlet hyperparameters give nans!\n");
} else {
int i = sample_from_logps(logps, prng);
alpha = alphas[i];
}
}
14 changes: 9 additions & 5 deletions cxx/distributions/normal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ void Normal::transition_hyperparameters(std::mt19937* prng) {
}
}

int i = sample_from_logps(logps, prng);
r = std::get<0>(hypers[i]);
v = std::get<1>(hypers[i]);
m = std::get<2>(hypers[i]);
s = std::get<3>(hypers[i]);
if (logps.empty()) {
printf("Warning! All hyperparameters for Normal give nans!\n");
} else {
int i = sample_from_logps(logps, prng);
r = std::get<0>(hypers[i]);
v = std::get<1>(hypers[i]);
m = std::get<2>(hypers[i]);
s = std::get<3>(hypers[i]);
}
}
15 changes: 10 additions & 5 deletions cxx/distributions/skellam.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@ void Skellam::transition_hyperparameters(std::mt19937* prng) {
}
}
}
int i = sample_from_logps(logps, prng);
mean1 = std::get<0>(hypers[i]);
stddev1 = std::get<1>(hypers[i]);
mean2 = std::get<2>(hypers[i]);
stddev2 = std::get<3>(hypers[i]);

if (logps.empty()) {
printf("Warning! All hyperparameters for Skellam gave nans!\n");
} else {
int i = sample_from_logps(logps, prng);
mean1 = std::get<0>(hypers[i]);
stddev1 = std::get<1>(hypers[i]);
mean2 = std::get<2>(hypers[i]);
stddev2 = std::get<3>(hypers[i]);
}
}

void Skellam::init_theta(std::mt19937* prng) {
Expand Down
10 changes: 7 additions & 3 deletions cxx/distributions/zero_mean_normal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ void ZeroMeanNormal::transition_hyperparameters(std::mt19937* prng) {
}
}

int i = sample_from_logps(logps, prng);
alpha = std::get<0>(hypers[i]);
beta = std::get<1>(hypers[i]);
if (logps.empty()) {
printf("Warning! All hyperparameters for ZeroMeanNormal gave nans!\n");
} else {
int i = sample_from_logps(logps, prng);
alpha = std::get<0>(hypers[i]);
beta = std::get<1>(hypers[i]);
}
}

0 comments on commit ca24eaf

Please sign in to comment.