Skip to content

Commit

Permalink
Merge pull request #212 from probcomp/240924-thomaswc-merge_gendb
Browse files Browse the repository at this point in the history
Merge GenDB and SchemaHelper; use GenDB in pclean binary
  • Loading branch information
ThomasColthurst authored Oct 2, 2024
2 parents c87569a + 57ef714 commit fca1b5b
Show file tree
Hide file tree
Showing 14 changed files with 901 additions and 1,057 deletions.
2 changes: 1 addition & 1 deletion cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ cc_library(
":irm",
":observations",
"//distributions:crp",
"//pclean:get_joint_relations",
"//pclean:io",
"//pclean:schema",
"//pclean:schema_helper",
],
)

Expand Down
6 changes: 5 additions & 1 deletion cxx/clean_relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <cstdlib>
#include <random>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -159,7 +160,10 @@ class CleanRelation : public Relation<T> {
}

std::vector<int> get_cluster_assignment(const T_items& items) const {
assert(items.size() == domains.size());
if (items.size() != domains.size()) {
printf("Warning: for relation %s, items.size=%ld and domains.size()=%ld\n", name.c_str(), items.size(), domains.size());
std::exit(1);
}
std::vector<int> z(domains.size());
for (int i = 0; i < std::ssize(domains); ++i) {
z[i] = domains[i]->get_cluster_assignment(items[i]);
Expand Down
4 changes: 3 additions & 1 deletion cxx/distributions/stringcat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// See LICENSE.txt

#include <algorithm>
#include <cstdlib>
#include <cassert>
#include <limits>
#include "distributions/stringcat.hh"
Expand All @@ -10,7 +11,8 @@
int StringCat::string_to_index(const std::string& s) const {
auto it = std::find(strings.begin(), strings.end(), s);
if (it == strings.end()) {
assert(false);
printf("String %s not in StringCat's list of strings\n", s.c_str());
std::exit(1);
}
return it - strings.begin();
}
Expand Down
238 changes: 229 additions & 9 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
#include "hirm.hh"
#include "irm.hh"
#include "observations.hh"
#include "pclean/get_joint_relations.hh"
#include "pclean/schema.hh"
#include "pclean/schema_helper.hh"

GenDB::GenDB(std::mt19937* prng, const PCleanSchema& schema_,
bool _only_final_emissions, bool _record_class_is_clean)
: schema(schema_),
schema_helper(schema_, _only_final_emissions, _record_class_is_clean) {
std::map<std::string, std::vector<std::string>>
annotated_domains_for_relation;
T_schema hirm_schema =
schema_helper.make_hirm_schema(&annotated_domains_for_relation);
: schema(schema_), only_final_emissions(_only_final_emissions),
record_class_is_clean(_record_class_is_clean) {
// Note that the domains cache must be populated before the reference
// indices.
compute_domains_cache();
compute_reference_indices_cache();

T_schema hirm_schema = make_hirm_schema();
hirm = new HIRM(hirm_schema, prng);

for (const auto& [class_name, unused_class] : schema.classes) {
Expand Down Expand Up @@ -171,6 +173,7 @@ T_items GenDB::sample_class_ancestors(std::mt19937* prng,
const std::string& class_name,
int class_item, bool new_rows_have_unique_entities) {
T_items items;
assert(schema.classes.contains(class_name));
PCleanClass c = schema.classes.at(class_name);

for (const auto& [name, var] : c.vars) {
Expand Down Expand Up @@ -203,7 +206,7 @@ void GenDB::get_relation_items(const std::string& rel_name, const int ind,
const std::vector<std::string>& domains = std::visit(
[&](auto tr) { return tr.domains; }, hirm->schema.at(rel_name));
items[ind] = class_item;
auto& ref_indices = schema_helper.relation_reference_indices;
auto& ref_indices = relation_reference_indices;
if (ref_indices.contains(rel_name)) {
if (ref_indices.at(rel_name).contains(ind)) {
for (const auto& [rf_name, rf_ind] : ref_indices.at(rel_name).at(ind)) {
Expand Down Expand Up @@ -238,7 +241,7 @@ GenDB::unincorporate_reference(const std::string& class_name,
std::vector<size_t> domain_inds;
for (size_t i = 0; i < domains.size(); ++i) {
if (domains[i] == class_name &&
schema_helper.relation_reference_indices.at(rel_name).at(i).contains(
relation_reference_indices.at(rel_name).at(i).contains(
ref_field)) {
domain_inds.push_back(i);
}
Expand Down Expand Up @@ -392,3 +395,220 @@ void GenDB::incorporate_reference_relation(
}

GenDB::~GenDB() { delete hirm; }

void GenDB::compute_domains_cache() {
for (const auto& c : schema.classes) {
if (!domains.contains(c.first)) {
compute_domains_for(c.first);
}
}
}

void GenDB::compute_reference_indices_cache() {
for (const auto& c : schema.classes) {
if (!class_reference_indices.contains(c.first)) {
compute_reference_indices_for(c.first);
}
}
}

void GenDB::compute_domains_for(const std::string& name) {
std::vector<std::string> ds;
assert(schema.classes.contains(name));
PCleanClass c = schema.classes.at(name);

for (const auto& v : c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(v.second.spec))) {
if (!domains.contains(cv->class_name)) {
compute_domains_for(cv->class_name);
}
for (const std::string& s : domains[cv->class_name]) {
ds.push_back(s);
}
}
}

// Put the "primary" domain last, so that it survives reordering.
ds.push_back(name);

domains[name] = ds;
}

void GenDB::compute_reference_indices_for(
const std::string& name) {
std::vector<std::string> ds;
int total_offset = 0;
assert(schema.classes.contains(name));
PCleanClass c = schema.classes.at(name);

// Recursively maps the indices of class "name" (and ancestors) in relation
// items to the names and indices (in items) of their parents (reference
// fields).
std::map<int, std::map<std::string, int>> ref_indices;

// Temporarily stores reference fields and indices for class "name";
std::map<std::string, int> class_ref_indices;
for (const auto& v : c.vars) {
if (const ClassVar* cv = std::get_if<ClassVar>(&(v.second.spec))) {
if (!class_reference_indices.contains(cv->class_name)) {
compute_reference_indices_for(cv->class_name);
}
// Indices for foreign-key domains are generated by adding an offset
// to their indices in the respective class.
const int offset = total_offset;
total_offset += domains.at(cv->class_name).size();
class_ref_indices[v.first] = total_offset - 1;
std::map<std::string, int> child_class_indices;
if (class_reference_indices.contains(cv->class_name)) {
for (const auto& [ind, ref] :
class_reference_indices.at(cv->class_name)) {
std::map<std::string, int> class_ref_indices;
for (const auto& [field_name, ref_ind] : ref) {
child_class_indices[field_name] = ref_ind + offset;
}
ref_indices[ind + offset] = child_class_indices;
}
}
}
}

// Do not store a `class_reference_indices` entry for classes
// with no reference fields.
if (class_ref_indices.size() > 0) {
ref_indices[total_offset] = class_ref_indices;
class_reference_indices[name] = ref_indices;
}
}

void GenDB::make_relations_for_queryfield(
const QueryField& f, const PCleanClass& record_class, T_schema* tschema) {

// First, find all the vars and classes specified in f.class_path.
std::vector<std::string> var_names;
std::vector<std::string> class_names;
PCleanVariable last_var;
PCleanClass last_class = record_class;
class_names.push_back(record_class.name);
for (size_t i = 0; i < f.class_path.size(); ++i) {
const PCleanVariable& v = last_class.vars[f.class_path[i]];
last_var = v;
var_names.push_back(v.name);
if (i < f.class_path.size() - 1) {
class_names.push_back(std::get<ClassVar>(v.spec).class_name);
last_class = schema.classes.at(class_names.back());
}
}
// Remove the last var_name because it isn't used in making the path_prefix.
var_names.pop_back();

// Get the base relation from the last class and variable name.
std::string base_relation_name = class_names.back() + ":" + last_var.name;

// Handle queries of the record class specially.
if (f.class_path.size() == 1) {
if (record_class_is_clean) {
// Just rename the existing clean relation and set it to be observed.
T_clean_relation cr =
std::get<T_clean_relation>(tschema->at(base_relation_name));
cr.is_observed = true;
(*tschema)[f.name] = cr;
tschema->erase(base_relation_name);
} else {
T_noisy_relation tnr =
get_emission_relation(std::get<ScalarVar>(last_var.spec),
domains[record_class.name], base_relation_name);
tnr.is_observed = true;
(*tschema)[f.name] = tnr;
// If the record class is the only class in the schema, there will be
// no entries in `relation_reference_indices`.
if (class_reference_indices.contains(record_class.name)) {
relation_reference_indices[f.name] =
class_reference_indices.at(record_class.name);
}
}
return;
}

// Handle only_final_emissions == true.
if (only_final_emissions) {
std::vector<std::string> noisy_domains = domains[class_names.back()];
for (int i = class_names.size() - 2; i >= 0; --i) {
noisy_domains.push_back(class_names[i]);
relation_reference_indices[f.name][noisy_domains.size() - 1]
[var_names[i]] = noisy_domains.size() - 2;
}
T_noisy_relation tnr = get_emission_relation(
std::get<ScalarVar>(last_var.spec), noisy_domains, base_relation_name);
tnr.is_observed = true;
(*tschema)[f.name] = tnr;
// If the record class is the only class in the schema, there will be
// no entries in `relation_reference_indices`.
if (relation_reference_indices.contains(base_relation_name)) {
relation_reference_indices[f.name] =
relation_reference_indices.at(base_relation_name);
}
return;
}

// Handle only_final_emissions == false.
std::string& previous_relation = base_relation_name;
std::vector<std::string> current_domains = domains[class_names.back()];
std::map<int, std::map<std::string, int>> ref_indices;
for (int i = f.class_path.size() - 2; i >= 0; --i) {
current_domains.push_back(class_names[i]);
ref_indices[current_domains.size() - 1][var_names[i]] =
current_domains.size() - 2;
T_noisy_relation tnr = get_emission_relation(
std::get<ScalarVar>(last_var.spec), current_domains, previous_relation);
std::string rel_name;
if (i == 0) {
rel_name = f.name;
tnr.is_observed = true;
} else {
// Intermediate emissions have a name of the form
// "[Observing Class]::[QueryFieldName]"
rel_name = class_names[i] + "::" + f.name;
tnr.is_observed = false;
}
(*tschema)[rel_name] = tnr;
// Since noisy relations have the leftmost domains in common with their base
// relations, they share the reference indices with their base relations as
// well.
if (relation_reference_indices.contains(previous_relation)) {
relation_reference_indices[rel_name] =
relation_reference_indices.at(previous_relation);
}
relation_reference_indices[rel_name].merge(ref_indices);
previous_relation = rel_name;
}
}

T_schema GenDB::make_hirm_schema() {
T_schema tschema;

// For every scalar variable, make a clean relation with the name
// "[ClassName]:[VariableName]".
for (const auto& c : schema.classes) {
for (const auto& v : c.second.vars) {
std::string rel_name = c.first + ':' + v.first;
if (const ScalarVar* dv = std::get_if<ScalarVar>(&(v.second.spec))) {
tschema[rel_name] = get_distribution_relation(*dv, domains[c.first]);
if (class_reference_indices.contains(c.first)) {
relation_reference_indices[rel_name] =
class_reference_indices.at(c.first);
}
}
}
}

// For every query field, make one or more relations by walking up
// the class_path. At least one of those relations will have name equal
// to the name of the QueryField.
const PCleanClass record_class = schema.classes.at(schema.query.record_class);
for (const auto& [unused_name, f] : schema.query.fields) {
make_relations_for_queryfield(f, record_class, &tschema);
}

return tschema;
}

Loading

0 comments on commit fca1b5b

Please sign in to comment.