diff --git a/cxx/noisy_relation.hh b/cxx/noisy_relation.hh index 53ec1a9..408a5eb 100644 --- a/cxx/noisy_relation.hh +++ b/cxx/noisy_relation.hh @@ -71,8 +71,10 @@ class NoisyRelation : public Relation { ValueType sample_and_incorporate(std::mt19937* prng, const T_items& items) { T_items z = emission_relation.incorporate_items(prng, items); const ValueType& base_val = get_base_value(items); + T_items base_items = get_base_items(items); ValueType value = sample_at_items(prng, items); data[items] = value; + base_to_noisy_items[base_items].insert(items); emission_relation.clusters.at(z)->incorporate( std::make_pair(base_val, value)); emission_relation.data[items] = {base_val, value}; @@ -91,9 +93,18 @@ class NoisyRelation : public Relation { std::make_pair(base_val, value)); } + void cleanup_base_to_noisy(const T_items& items) { + T_items base_items = get_base_items(items); + base_to_noisy_items.at(base_items).erase(items); + if (base_to_noisy_items.at(base_items).empty()) { + base_to_noisy_items.erase(base_items); + } + } + void cleanup_data(const T_items& items) { emission_relation.cleanup_data(items); data.erase(items); + cleanup_base_to_noisy(items); } void cleanup_clusters() { emission_relation.cleanup_clusters(); } @@ -106,6 +117,7 @@ class NoisyRelation : public Relation { assert(data.contains(items)); emission_relation.unincorporate(items); data.erase(items); + cleanup_base_to_noisy(items); } double logp_gibbs_approx(const Domain& domain, const T_item& item, int table, diff --git a/cxx/noisy_relation_test.cc b/cxx/noisy_relation_test.cc index 628138d..4921662 100644 --- a/cxx/noisy_relation_test.cc +++ b/cxx/noisy_relation_test.cc @@ -219,3 +219,30 @@ BOOST_AUTO_TEST_CASE(test_sample_and_incorporate) { BOOST_TEST(NR1.data.size() == 3); BOOST_TEST(R1.data.size() == 1); } + +BOOST_AUTO_TEST_CASE(test_cleanup_data) { + std::mt19937 prng; + Domain D1("D1"); + Domain D2("D2"); + Domain D3("D3"); + + DistributionSpec spec("normal"); + CleanRelation R1("R1", spec, {&D1, &D2}); + + EmissionSpec em_spec("sometimes_gaussian"); + NoisyRelation NR1("NR1", em_spec, {&D1, &D2, &D3}, &R1); + + R1.incorporate(&prng, {0, 1}, 3.); + R1.incorporate(&prng, {0, 2}, 2.8); + NR1.incorporate(&prng, {0, 1, 2}, 2.); + NR1.incorporate(&prng, {0, 2, 3}, 2.1); + BOOST_TEST(R1.data.contains({0, 2})); + BOOST_TEST(NR1.data.contains({0, 2, 3})); + BOOST_TEST(NR1.base_to_noisy_items.at({0, 2}).contains({0, 2, 3})); + + NR1.unincorporate_from_cluster({0, 2, 3}); + NR1.cleanup_data({0, 2, 3}); + BOOST_TEST(R1.data.contains({0, 2})); + BOOST_TEST(!NR1.data.contains({0, 2, 3})); + BOOST_TEST(!NR1.base_to_noisy_items.contains({0, 2})); +}