From c6524be2ea8fa96276b161fd14d9fed8fb199a86 Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Thu, 11 Jul 2024 14:47:53 -0700 Subject: [PATCH 1/3] Add unincorporate for clean_relation and domain --- cxx/clean_relation.hh | 41 +++++++++++++++++++------------------ cxx/clean_relation_test.cc | 42 +++++++++++++++++++++++++++++++++++--- cxx/domain.hh | 12 +++-------- cxx/domain_test.cc | 5 +++++ 4 files changed, 68 insertions(+), 32 deletions(-) diff --git a/cxx/clean_relation.hh b/cxx/clean_relation.hh index cd98625..671251a 100644 --- a/cxx/clean_relation.hh +++ b/cxx/clean_relation.hh @@ -103,26 +103,27 @@ class CleanRelation : public Relation { } void unincorporate(const T_items& items) { - printf("Not implemented\n"); - exit(EXIT_FAILURE); - // auto x = data.at(items); - // auto z = get_cluster_assignment(items); - // clusters.at(z)->unincorporate(x); - // if (clusters.at(z)->N == 0) { - // delete clusters.at(z); - // clusters.erase(z); - // } - // for (int i = 0; i < domains.size(); i++) { - // const std::string &n = domains[i]->name; - // if (data_r.at(n).count(items[i]) > 0) { - // data_r.at(n).at(items[i]).erase(items); - // if (data_r.at(n).at(items[i]).size() == 0) { - // data_r.at(n).erase(items[i]); - // domains[i]->unincorporate(name, items[i]); - // } - // } - // } - // data.erase(items); + assert(data.count(items) == 1); + auto value = data.at(items); + auto z = get_cluster_assignment(items); + clusters.at(z)->unincorporate(value); + if (clusters.at(z)->N == 0) { + delete clusters.at(z); + clusters.erase(z); + } + for (int i = 0; i < std::ssize(domains); ++i) { + const std::string& name = domains[i]->name; + if (data_r.at(name).contains(items[i])) { + data_r.at(name).at(items[i]).erase(items); + if (data_r.at(name).at(items[i]).size() == 0) { + // It's safe to unincorporate this element since no other data point + // refers to it. 
+ data_r.at(name).erase(items[i]); + domains[i]->unincorporate(items[i]); + } + } + } + data.erase(items); } std::vector get_cluster_assignment(const T_items& items) const { diff --git a/cxx/clean_relation_test.cc b/cxx/clean_relation_test.cc index d0d52dc..9ca8aa6 100644 --- a/cxx/clean_relation_test.cc +++ b/cxx/clean_relation_test.cc @@ -52,20 +52,56 @@ BOOST_AUTO_TEST_CASE(test_clean_relation) { BOOST_TEST(db->N == 0); db->incorporate(false); BOOST_TEST(db->N == 1); +} + +BOOST_AUTO_TEST_CASE(test_string_relation) { + std::mt19937 prng; + Domain D1("D1"); + Domain D2("D2"); DistributionSpec bigram_spec = DistributionSpec("bigram"); - CleanRelation R2("R1", bigram_spec, {&D2, &D3}); + CleanRelation R2("R2", bigram_spec, {&D1, &D2}); R2.incorporate(&prng, {1, 3}, "cat"); R2.incorporate(&prng, {1, 2}, "dog"); R2.incorporate(&prng, {1, 4}, "catt"); R2.incorporate(&prng, {2, 6}, "fish"); + double lpg __attribute__((unused)); lpg = R2.logp_gibbs_approx(D2, 2, 0, &prng); - R2.set_cluster_assignment_gibbs(D3, 3, 1, &prng); - D1.set_cluster_assignment_gibbs(0, 1); + R2.set_cluster_assignment_gibbs(D2, 3, 1, &prng); + D1.set_cluster_assignment_gibbs(1, 1); Distribution* db2 = R2.make_new_distribution(&prng); BOOST_TEST(db2->N == 0); db2->incorporate("hello"); BOOST_TEST(db2->N == 1); } + +BOOST_AUTO_TEST_CASE(test_unincorporate) { + std::mt19937 prng; + Domain D1("D1"); + Domain D2("D2"); + DistributionSpec spec = DistributionSpec("bernoulli"); + CleanRelation R1("R1", spec, {&D1, &D2}); + R1.incorporate(&prng, {0, 1}, 1); + R1.incorporate(&prng, {0, 2}, 1); + R1.incorporate(&prng, {3, 0}, 1); + R1.incorporate(&prng, {3, 1}, 1); + + R1.unincorporate({3, 1}); + BOOST_TEST(R1.data.size() == 3); + // Expect that these are still in the domain since the data points {3, 0} and + // {0, 1} refer to them. 
+ BOOST_TEST(D1.items.contains(3)); + BOOST_TEST(D2.items.contains(1)); + + R1.unincorporate({0, 2}); + BOOST_TEST(R1.data.size() == 2); + BOOST_TEST(D1.items.contains(0)); + BOOST_TEST(!D2.items.contains(2)); + + R1.unincorporate({0, 1}); + BOOST_TEST(R1.data.size() == 1); + BOOST_TEST(!D1.items.contains(0)); + BOOST_TEST(!D2.items.contains(1)); +} diff --git a/cxx/domain.hh b/cxx/domain.hh index 4db73d7..a6f0c70 100644 --- a/cxx/domain.hh +++ b/cxx/domain.hh @@ -27,15 +27,9 @@ class Domain { } } void unincorporate(const T_item& item) { - printf("Not implemented\n"); - exit(EXIT_FAILURE); - // assert(items.count(item) == 1); - // assert(items.at(item).count(relation) == 1); - // items.at(item).erase(relation); - // if (items.at(item).size() == 0) { - // crp.unincorporate(item); - // items.erase(item); - // } + assert(items.count(item) == 1); + crp.unincorporate(item); + items.erase(item); } int get_cluster_assignment(const T_item& item) const { assert(items.contains(item)); diff --git a/cxx/domain_test.cc b/cxx/domain_test.cc index c8ebcbf..4b3e3ee 100644 --- a/cxx/domain_test.cc +++ b/cxx/domain_test.cc @@ -26,4 +26,9 @@ BOOST_AUTO_TEST_CASE(test_domain) { int ca = d.get_cluster_assignment(apple); BOOST_TEST(ca == 5); BOOST_TEST(cb == 12); + + d.unincorporate(banana); + BOOST_TEST(!d.items.contains(banana)); + BOOST_TEST(d.items.contains(apple)); + BOOST_TEST(d.items.size() == 1); } From 1751fb1168a0f55b7a308da52126b1b870360876 Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Thu, 11 Jul 2024 14:48:45 -0700 Subject: [PATCH 2/3] Update irm_test to check unincorporate --- cxx/irm_test.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cxx/irm_test.cc b/cxx/irm_test.cc index 7a172ac..11700f2 100644 --- a/cxx/irm_test.cc +++ b/cxx/irm_test.cc @@ -44,10 +44,8 @@ BOOST_AUTO_TEST_CASE(test_irm) { irm.transition_cluster_assignments_all(&prng); BOOST_TEST(irm.logp_score() == one_obs_score); - // TODO(thomaswc): Uncomment below when 
relation::unincorporate is - // implemented. - // irm.unincorporate("R1", {1, 2}); - // BOOST_TEST(irm.logp_score() == 0.0); + irm.unincorporate("R1", {1, 2}); + BOOST_TEST(irm.logp_score() == 0.0); irm.incorporate(&prng, "R2", {0, 3}, 1.); irm.incorporate(&prng, "R4", {0, 3, 1}, 1.2); From 57d10d984e3adb9f90bbc1d0ebe2325423480214 Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Thu, 11 Jul 2024 15:29:30 -0700 Subject: [PATCH 3/3] Add noisy_relation testing --- cxx/clean_relation.hh | 6 +++--- cxx/noisy_relation.hh | 2 ++ cxx/noisy_relation_test.cc | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/cxx/clean_relation.hh b/cxx/clean_relation.hh index 671251a..69892ca 100644 --- a/cxx/clean_relation.hh +++ b/cxx/clean_relation.hh @@ -103,9 +103,9 @@ class CleanRelation : public Relation { } void unincorporate(const T_items& items) { - assert(data.count(items) == 1); - auto value = data.at(items); - auto z = get_cluster_assignment(items); + assert(data.contains(items)); + ValueType value = data.at(items); + std::vector z = get_cluster_assignment(items); clusters.at(z)->unincorporate(value); if (clusters.at(z)->N == 0) { delete clusters.at(z); diff --git a/cxx/noisy_relation.hh b/cxx/noisy_relation.hh index ce5d0ff..ffff587 100644 --- a/cxx/noisy_relation.hh +++ b/cxx/noisy_relation.hh @@ -66,7 +66,9 @@ class NoisyRelation : public Relation { } void unincorporate(const T_items& items) { + assert(data.contains(items)); emission_relation.unincorporate(items); + data.erase(items); } double logp_gibbs_approx(const Domain& domain, const T_item& item, int table, diff --git a/cxx/noisy_relation_test.cc b/cxx/noisy_relation_test.cc index 3aad85f..6f37e28 100644 --- a/cxx/noisy_relation_test.cc +++ b/cxx/noisy_relation_test.cc @@ -72,3 +72,41 @@ BOOST_AUTO_TEST_CASE(test_noisy_relation) { NR2.set_cluster_assignment_gibbs(D3, 3, 1, &prng); D1.set_cluster_assignment_gibbs(0, 1); } + 
+BOOST_AUTO_TEST_CASE(test_unincorporate) {
+  // NoisyRelation::unincorporate must drop the point but leave base R1 intact.
+  std::mt19937 prng;
+  Domain D1("D1");
+  Domain D2("D2");
+  DistributionSpec spec = DistributionSpec("bernoulli");
+  CleanRelation R1("R1", spec, {&D1, &D2});
+  R1.incorporate(&prng, {0, 1}, 1);
+  R1.incorporate(&prng, {0, 2}, 1);
+  R1.incorporate(&prng, {3, 0}, 1);
+  R1.incorporate(&prng, {3, 1}, 1);
+
+  EmissionSpec em_spec = EmissionSpec("sometimes_bitflip");
+  NoisyRelation NR1("NR1", em_spec, {&D1, &D2}, &R1);
+  NR1.incorporate(&prng, {0, 1}, 0);
+  NR1.incorporate(&prng, {0, 2}, 1);
+  NR1.incorporate(&prng, {3, 0}, 0);
+  NR1.incorporate(&prng, {3, 1}, 1);
+
+  NR1.unincorporate({3, 1});
+  BOOST_TEST(NR1.data.size() == 3);
+  BOOST_TEST(R1.data.size() == 4);  // Only the noisy observation is removed.
+  // Expect that these are still in the domain since the data points {3, 0} and
+  // {0, 1} refer to them.
+  BOOST_TEST(D1.items.contains(3));
+  BOOST_TEST(D2.items.contains(1));
+
+  NR1.unincorporate({0, 2});
+  BOOST_TEST(NR1.data.size() == 2);
+  BOOST_TEST(D1.items.contains(0));
+  BOOST_TEST(!D2.items.contains(2));
+
+  NR1.unincorporate({0, 1});
+  BOOST_TEST(NR1.data.size() == 1);
+  BOOST_TEST(!D1.items.contains(0));
+  BOOST_TEST(!D2.items.contains(1));
+}