Skip to content

Commit

Permalink
Rename NonNoisyRelation to CleanRelation.
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyfertig committed Jul 3, 2024
1 parent e630c27 commit 8c8cf9a
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 71 deletions.
60 changes: 30 additions & 30 deletions cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@ cc_library(
deps = [],
)

# Header-only library target: exposes clean_relation.hh (hdrs only, no srcs).
# Visible to this package's subpackages. Depends on the same-package targets
# listed below plus //distributions:base.
cc_library(
name = "clean_relation",
hdrs = ["clean_relation.hh"],
visibility = [":__subpackages__"],
deps = [
":domain",
":relation",
":util_distribution_variant",
":util_hash",
":util_math",
"//distributions:base"
],
)

cc_library(
name = "domain",
hdrs = ["domain.hh"],
Expand All @@ -21,7 +35,7 @@ cc_library(
srcs = ["irm.cc"],
visibility = [":__subpackages__"],
deps = [
":non_noisy_relation",
":clean_relation",
":relation_variant",
":util_distribution_variant",
],
Expand Down Expand Up @@ -66,21 +80,7 @@ cc_library(
visibility = [":__subpackages__"],
deps = [
":domain",
":non_noisy_relation",
":relation",
":util_distribution_variant",
":util_hash",
":util_math",
"//distributions:base"
],
)

cc_library(
name = "non_noisy_relation",
hdrs = ["non_noisy_relation.hh"],
visibility = [":__subpackages__"],
deps = [
":domain",
":clean_relation",
":relation",
":util_distribution_variant",
":util_hash",
Expand Down Expand Up @@ -109,7 +109,7 @@ cc_library(
visibility = [":__subpackages__"],
deps = [
":domain",
":non_noisy_relation",
":clean_relation",
":relation",
":util_distribution_variant",
],
Expand Down Expand Up @@ -156,6 +156,17 @@ cc_library(
deps = [],
)

# Test target built from clean_relation_test.cc. It depends on the
# clean_relation and domain targets, //distributions, and Boost.Test
# (@boost//:test).
cc_test(
name = "clean_relation_test",
srcs = ["clean_relation_test.cc"],
deps = [
":domain",
":clean_relation",
"//distributions",
"@boost//:test",
],
)

cc_test(
name = "domain_test",
srcs = ["domain_test.cc"],
Expand All @@ -179,29 +190,18 @@ cc_test(
srcs = ["noisy_relation_test.cc"],
deps = [
":domain",
":non_noisy_relation",
":clean_relation",
":noisy_relation",
"//distributions",
"@boost//:test",
],
)

cc_test(
name = "non_noisy_relation_test",
srcs = ["non_noisy_relation_test.cc"],
deps = [
":domain",
":non_noisy_relation",
"//distributions",
"@boost//:test",
],
)

cc_test(
name = "relation_variant_test",
srcs = ["relation_variant_test.cc"],
deps = [
":non_noisy_relation",
":clean_relation",
":relation_variant",
"@boost//:test",
],
Expand Down
21 changes: 10 additions & 11 deletions cxx/non_noisy_relation.hh → cxx/clean_relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include "util_hash.hh"
#include "util_math.hh"

// T_non_noisy_relation is the text we get from reading a line of the schema
// file; NonNoisyRelation is the object that does the work.
class T_non_noisy_relation {
// T_clean_relation is the text we get from reading a line of the schema
// file; CleanRelation is the object that does the work.
class T_clean_relation {
public:
// The relation is a map from the domains to the space that .distribution
// is a distribution over.
Expand All @@ -30,7 +30,7 @@ class T_non_noisy_relation {
};

template <typename T>
class NonNoisyRelation : public Relation<T> {
class CleanRelation : public Relation<T> {
public:
typedef T ValueType;

Expand All @@ -53,10 +53,9 @@ class NonNoisyRelation : public Relation<T> {
std::unordered_map<T_item, std::unordered_set<T_items, H_items>>>
data_r;

NonNoisyRelation(
const std::string& name,
const std::variant<DistributionSpec, EmissionSpec>& prior_spec,
const std::vector<Domain*>& domains)
CleanRelation(const std::string& name,
const std::variant<DistributionSpec, EmissionSpec>& prior_spec,
const std::vector<Domain*>& domains)
: name(name), domains(domains), prior_spec(prior_spec) {
assert(!domains.empty());
assert(!name.empty());
Expand All @@ -66,7 +65,7 @@ class NonNoisyRelation : public Relation<T> {
}
}

~NonNoisyRelation() {
~CleanRelation() {
for (auto [z, cluster] : clusters) {
delete cluster;
}
Expand Down Expand Up @@ -408,6 +407,6 @@ class NonNoisyRelation : public Relation<T> {
}

// Disable copying.
NonNoisyRelation& operator=(const NonNoisyRelation&) = delete;
NonNoisyRelation(const NonNoisyRelation&) = delete;
CleanRelation& operator=(const CleanRelation&) = delete;
CleanRelation(const CleanRelation&) = delete;
};
8 changes: 4 additions & 4 deletions cxx/non_noisy_relation_test.cc → cxx/clean_relation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#define BOOST_TEST_MODULE test Relation

#include "non_noisy_relation.hh"
#include "clean_relation.hh"

#include <boost/test/included/unit_test.hpp>
#include <random>
Expand All @@ -13,7 +13,7 @@

namespace tt = boost::test_tools;

BOOST_AUTO_TEST_CASE(test_relation) {
BOOST_AUTO_TEST_CASE(test_clean_relation) {
std::mt19937 prng;
Domain D1("D1");
Domain D2("D2");
Expand All @@ -22,7 +22,7 @@ BOOST_AUTO_TEST_CASE(test_relation) {
D2.incorporate(&prng, 1);
D3.incorporate(&prng, 3);
DistributionSpec spec = DistributionSpec{DistributionEnum::bernoulli};
NonNoisyRelation<bool> R1("R1", spec, {&D1, &D2, &D3});
CleanRelation<bool> R1("R1", spec, {&D1, &D2, &D3});
R1.incorporate(&prng, {0, 1, 3}, 1);
R1.incorporate(&prng, {1, 1, 3}, 1);
R1.incorporate(&prng, {3, 1, 3}, 1);
Expand Down Expand Up @@ -54,7 +54,7 @@ BOOST_AUTO_TEST_CASE(test_relation) {
BOOST_TEST(db->N == 1);

DistributionSpec bigram_spec = DistributionSpec{DistributionEnum::bigram};
NonNoisyRelation<std::string> R2("R1", bigram_spec, {&D2, &D3});
CleanRelation<std::string> R2("R1", bigram_spec, {&D2, &D3});
R2.incorporate(&prng, {1, 3}, "cat");
R2.incorporate(&prng, {1, 2}, "dog");
R2.incorporate(&prng, {1, 4}, "catt");
Expand Down
4 changes: 2 additions & 2 deletions cxx/irm.hh
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
#include <unordered_map>
#include <unordered_set>

#include "non_noisy_relation.hh"
#include "clean_relation.hh"
#include "relation_variant.hh"
#include "util_distribution_variant.hh"

// TODO(emilyaf): Support noisy relations.
using T_relation = T_non_noisy_relation;
using T_relation = T_clean_relation;

// Schema type: an ordered map keyed by relation name, where each value is
// that relation's T_relation entry.
using T_schema = std::map<std::string, T_relation>;
Expand Down
4 changes: 2 additions & 2 deletions cxx/noisy_relation.hh
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
#include <unordered_map>
#include <vector>

#include "clean_relation.hh"
#include "distributions/base.hh"
#include "domain.hh"
#include "emissions/base.hh"
#include "non_noisy_relation.hh"
#include "relation.hh"
#include "util_distribution_variant.hh"
#include "util_hash.hh"
Expand Down Expand Up @@ -50,7 +50,7 @@ class NoisyRelation : public Relation<T> {
// Base relation for "" values.
const Relation<ValueType>* base_relation;
// A Relation for the Emission that models noisy values given values.
NonNoisyRelation<std::pair<ValueType, ValueType>> emission_relation;
CleanRelation<std::pair<ValueType, ValueType>> emission_relation;

NoisyRelation(const std::string& name, const EmissionSpec& emission_spec,
const std::vector<Domain*>& domains, Relation<T>* base_relation)
Expand Down
6 changes: 3 additions & 3 deletions cxx/noisy_relation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
#include <iostream>
#include <random>

#include "clean_relation.hh"
#include "distributions/beta_bernoulli.hh"
#include "distributions/bigram.hh"
#include "domain.hh"
#include "non_noisy_relation.hh"

namespace tt = boost::test_tools;

Expand All @@ -24,7 +24,7 @@ BOOST_AUTO_TEST_CASE(test_noisy_relation) {
D2.incorporate(&prng, 1);
D3.incorporate(&prng, 3);
DistributionSpec spec = DistributionSpec{DistributionEnum::bernoulli};
NonNoisyRelation<bool> R1("R1", spec, {&D1, &D2});
CleanRelation<bool> R1("R1", spec, {&D1, &D2});
R1.incorporate(&prng, {0, 1}, 1);
R1.incorporate(&prng, {1, 1}, 1);
R1.incorporate(&prng, {3, 1}, 1);
Expand Down Expand Up @@ -53,7 +53,7 @@ BOOST_AUTO_TEST_CASE(test_noisy_relation) {
NR1.set_cluster_assignment_gibbs(D1, 0, 1, &prng);

DistributionSpec bigram_spec = DistributionSpec{DistributionEnum::bigram};
NonNoisyRelation<std::string> R2("R2", bigram_spec, {&D2, &D3});
CleanRelation<std::string> R2("R2", bigram_spec, {&D2, &D3});
EmissionSpec str_emspec = EmissionSpec(EmissionEnum::simple_string);
NoisyRelation<std::string> NR2("NR2", str_emspec, {&D2, &D3}, &R2);

Expand Down
4 changes: 2 additions & 2 deletions cxx/relation_variant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#include <random>
#include <type_traits>

#include "clean_relation.hh"
#include "domain.hh"
#include "non_noisy_relation.hh"

// TODO(emilyaf): Implement this for NoisyRelation.
RelationVariant relation_from_spec(const std::string& name,
Expand All @@ -34,7 +34,7 @@ RelationVariant relation_from_spec(const std::string& name,
// the right kind of Relation.
std::visit(
[&](const auto& v) {
rv = new NonNoisyRelation<
rv = new CleanRelation<
typename std::remove_reference_t<decltype(*v)>::SampleType>(
name, dist_spec, domains);
},
Expand Down
6 changes: 3 additions & 3 deletions cxx/relation_variant_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

#include <boost/test/included/unit_test.hpp>

#include "non_noisy_relation.hh"
#include "clean_relation.hh"
namespace tt = boost::test_tools;

// Builds a relation named "r1" over a single domain from a "bernoulli"
// distribution spec via relation_from_spec(), pulls the Relation<bool>*
// alternative out of the returned RelationVariant, and checks that the
// relation's name is "r1".
BOOST_AUTO_TEST_CASE(test_relation_variant) {
std::vector<Domain*> domains;
// NOTE(review): this Domain is allocated with `new` and is not deleted
// anywhere in this test case.
domains.push_back(new Domain("D1"));
RelationVariant rv =
relation_from_spec("r1", parse_distribution_spec("bernoulli"), domains);
// NOTE(review): two declarations of `rb` follow. They look like the removed
// (NonNoisyRelation) and added (CleanRelation) sides of the rename shown in
// this diff view -- confirm that only the CleanRelation one is in the file.
NonNoisyRelation<bool>* rb =
reinterpret_cast<NonNoisyRelation<bool>*>(std::get<Relation<bool>*>(rv));
// NOTE(review): CleanRelation<T> is declared as publicly deriving from
// Relation<T>, so static_cast would be the conventional way to spell this
// downcast instead of reinterpret_cast -- TODO confirm and consider changing.
CleanRelation<bool>* rb =
reinterpret_cast<CleanRelation<bool>*>(std::get<Relation<bool>*>(rv));
BOOST_TEST(rb->name == "r1");
}
4 changes: 2 additions & 2 deletions cxx/tests/test_hirm_animals.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ int main(int argc, char** argv) {
}
// Check relations agree.
for (const auto& [r, rm_var] : irm->relations) {
auto rx = reinterpret_cast<NonNoisyRelation<bool>*>(
auto rx = reinterpret_cast<CleanRelation<bool>*>(
std::get<Relation<bool>*>(irx->relations.at(r)));
auto rm = reinterpret_cast<NonNoisyRelation<bool>*>(
auto rm = reinterpret_cast<CleanRelation<bool>*>(
std::get<Relation<bool>*>(rm_var));
assert(rm->data == rx->data);
assert(rm->data_r == rx->data_r);
Expand Down
28 changes: 19 additions & 9 deletions cxx/tests/test_irm_two_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "util_io.hh"
#include "util_math.hh"

using T_r = NonNoisyRelation<bool>*;
using T_r = CleanRelation<bool>*;

int main(int argc, char** argv) {
std::string path_base = "assets/two_relations";
Expand Down Expand Up @@ -77,10 +77,14 @@ int main(int argc, char** argv) {
assert(l.size() == 2);
auto x1 = l.at(0);
auto x2 = l.at(1);
auto p0 = reinterpret_cast<T_r>(std::get<Relation<bool>*>(irm.relations.at("R1")))->logp({x1, x2}, false, &prng);
auto p0 =
reinterpret_cast<T_r>(std::get<Relation<bool>*>(irm.relations.at("R1")))
->logp({x1, x2}, false, &prng);
auto p0_irm = irm.logp({{"R1", {x1, x2}, false}}, &prng);
assert(abs(p0 - p0_irm) < 1e-10);
auto p1 = reinterpret_cast<T_r>(std::get<Relation<bool>*>(irm.relations.at("R1")))->logp({x1, x2}, true, &prng);
auto p1 =
reinterpret_cast<T_r>(std::get<Relation<bool>*>(irm.relations.at("R1")))
->logp({x1, x2}, true, &prng);
auto Z = logsumexp({p0, p1});
assert(abs(Z) < 1e-10);
assert(abs(exp(p0) - expected_p0[x1].at(x2)) < .1);
Expand All @@ -91,10 +95,14 @@ int main(int argc, char** argv) {
auto x1 = l.at(0);
auto x2 = l.at(1);
auto x3 = l.at(2);
auto p00 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, false}}, &prng);
auto p01 = irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, true}}, &prng);
auto p10 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, false}}, &prng);
auto p11 = irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, true}}, &prng);
auto p00 =
irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, false}}, &prng);
auto p01 =
irm.logp({{"R1", {x1, x2}, false}, {"R1", {x1, x3}, true}}, &prng);
auto p10 =
irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, false}}, &prng);
auto p11 =
irm.logp({{"R1", {x1, x2}, true}, {"R1", {x1, x3}, true}}, &prng);
auto Z = logsumexp({p00, p01, p10, p11});
assert(abs(Z) < 1e-10);
}
Expand All @@ -111,8 +119,10 @@ int main(int argc, char** argv) {
// transitioned.
assert(abs(irx.logp_score() - irm.logp_score()) > 1e-8);
for (const auto& r : {"R1", "R2"}) {
auto r1m = reinterpret_cast<T_r>(std::get<Relation<bool>*>(irm.relations.at(r)));
auto r1x = reinterpret_cast<T_r>(std::get<Relation<bool>*>(irx.relations.at(r)));
auto r1m =
reinterpret_cast<T_r>(std::get<Relation<bool>*>(irm.relations.at(r)));
auto r1x =
reinterpret_cast<T_r>(std::get<Relation<bool>*>(irx.relations.at(r)));
for (const auto& [c, distribution] : r1m->clusters) {
auto dx = reinterpret_cast<BetaBernoulli*>(r1x->clusters.at(c));
auto dy = reinterpret_cast<BetaBernoulli*>(distribution);
Expand Down
6 changes: 3 additions & 3 deletions cxx/tests/test_misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ int main(int argc, char** argv) {
std::string path_clusters = "assets/animals.binary.irm";
to_txt(path_clusters, irm3, encoding);

auto rel = reinterpret_cast<NonNoisyRelation<bool>*>(
auto rel = reinterpret_cast<CleanRelation<bool>*>(
std::get<Relation<bool>*>(irm3.relations.at("has")));
auto& enc = std::get<0>(encoding);
auto lp0 = rel->logp({enc["animal"]["tail"], enc["animal"]["bat"]}, 0, &prng);
Expand All @@ -132,9 +132,9 @@ int main(int argc, char** argv) {
assert(d3->crp.alpha == d4->crp.alpha);
}
for (const auto& r : {"has"}) {
auto r3 = reinterpret_cast<NonNoisyRelation<bool>*>(
auto r3 = reinterpret_cast<CleanRelation<bool>*>(
std::get<Relation<bool>*>(irm3.relations.at(r)));
auto r4 = reinterpret_cast<NonNoisyRelation<bool>*>(
auto r4 = reinterpret_cast<CleanRelation<bool>*>(
std::get<Relation<bool>*>(irm4.relations.at(r)));
assert(r3->data == r4->data);
assert(r3->data_r == r4->data_r);
Expand Down

0 comments on commit 8c8cf9a

Please sign in to comment.