Skip to content

Commit

Permalink
Fix build errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasColthurst committed Sep 3, 2024
1 parent 189bad0 commit a87a5d1
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 23 deletions.
1 change: 1 addition & 0 deletions cxx/hirm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,4 @@ class HIRM {
HIRM& operator=(const HIRM&) = delete;
HIRM(const HIRM&) = delete;
};

19 changes: 0 additions & 19 deletions cxx/hirm_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,6 @@
#include "irm.hh"
#include "util_io.hh"

double logp(std::mt19937* prng, std::variant<IRM*, HIRM*> h_irm,
const T_encoding& encoding,
const T_observations& observations) {
T_encoded_observations encoded_obs = encode_observations(
observations, encoding, h_irm);
std::vector<std::tuple<std::string, T_items, ObservationVariant>> logp_obs;
for (const auto& [relation, obs_for_rel]: encoded_obs) {
RelationVariant rel_var =
std::visit([&](auto m) { return m->get_relation(relation); }, h_irm);
for (const auto& [items, value]: obs_for_rel) {
ObservationVariant ov;
std::visit([&](const auto &r) {ov = r->from_string(value); }, rel_var);
logp_obs.push_back(make_tuple(relation, items, ov));
}
}
return std::visit(
[&](const auto& m) { return m->logp(logp_obs, prng); }, h_irm);
}

int main(int argc, char** argv) {
cxxopts::Options options("hirm",
"Run a hierarchical infinite relational model.");
Expand Down
8 changes: 4 additions & 4 deletions cxx/pclean/pclean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ int main(int argc, char** argv) {
if (heldout_fn.empty()) {
encoding_observations = observations;
} else {
std::cout << "Loading held out observations from " << held_out << std::endl;
std::cout << "Loading held out observations from " << heldout_fn << std::endl;
DataFrame heldout_df = DataFrame::from_csv(heldout_fn);
heldout_obs = translate_observations(
heldout_df, hirm_schema, annotated_domains_for_relations);
encoding_observations = merge_obserations(observations, heldout_obs);
encoding_observations = merge_observations(observations, heldout_obs);
}

std::cout << "Encoding observations ...\n";
T_encoding encoding = calculate_encoding(schema, encoding_observations);
T_encoding encoding = calculate_encoding(hirm_schema, encoding_observations);

std::cout << "Incorporating observations ...\n";
incorporate_observations(&prng, &hirm, encoding, observations);
Expand All @@ -128,7 +128,7 @@ int main(int argc, char** argv) {
}

if (!heldout_fn.empty()) {
double lp = logp(&prng, hirm, encoding, heldout_obs);
double lp = logp(&prng, &hirm, encoding, heldout_obs);
std::cout << "Log likelihood of held out data is " << lp << std::endl;
}

Expand Down
19 changes: 19 additions & 0 deletions cxx/util_io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,22 @@ T_observations merge_observations(const T_observations& obs1,
}
return merged;
}

double logp(std::mt19937* prng, std::variant<IRM*, HIRM*> h_irm,
const T_encoding& encoding,
const T_observations& observations) {
T_encoded_observations encoded_obs = encode_observations(
observations, encoding, h_irm);
std::vector<std::tuple<std::string, T_items, ObservationVariant>> logp_obs;
for (const auto& [relation, obs_for_rel]: encoded_obs) {
RelationVariant rel_var =
std::visit([&](auto m) { return m->get_relation(relation); }, h_irm);
for (const auto& [items, value]: obs_for_rel) {
ObservationVariant ov;
std::visit([&](const auto &r) {ov = r->from_string(value); }, rel_var);
logp_obs.push_back(make_tuple(relation, items, ov));
}
}
return std::visit(
[&](const auto& m) { return m->logp(logp_obs, prng); }, h_irm);
}
6 changes: 6 additions & 0 deletions cxx/util_io.hh
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,9 @@ void from_txt(std::mt19937* prng, HIRM* const irm,

T_observations merge_observations(const T_observations& obs1,
const T_observations& obs2);

// Return the log probability of the observations given the model h_irm.
double logp(std::mt19937* prng, std::variant<IRM*, HIRM*> h_irm,
const T_encoding& encoding,
const T_observations& observations);

0 comments on commit a87a5d1

Please sign in to comment.