// Patch summary: adds StringCat, a distribution over a finite set of strings.
// StringCat maps each string to its position in `strings` (string_to_index)
// and forwards incorporate/unincorporate/logp/logp_score/sample/
// transition_hyperparameters to a DirichletCategorical member `dc` built with
// vs.size() categories.
//
// Files touched by the hunks below:
//   - cxx/BUILD: adds the "@boost//:algorithm" dep to one cc_library.
//   - cxx/distributions/BUILD: adds the ":stringcat" cc_library and the
//     "stringcat_test" cc_test, and lists ":stringcat" in an existing deps list.
//   - cxx/distributions/stringcat.{cc,hh}, stringcat_test.cc: new files.
//   - cxx/util_distribution_variant.{cc,hh} and its test: add
//     DistributionEnum::stringcat, a StringCat* alternative in
//     DistributionVariant, and spec handling of the form
//     stringcat(strings=a b c d) or stringcat(strings=yes:no,delim=:), where
//     "delim" is asserted to be one character and defaults to " ".
//
// NOTE(review): string_to_index() guards unknown strings only with
// assert(false). If that is the standard assert macro and NDEBUG is defined,
// the function falls through and returns strings.size() -- confirm callers
// never pass a string outside `strings`.
// NOTE(review): stringcat.hh declares set_alpha(double), but the 40-line
// stringcat.cc added here contains no definition for it.
// NOTE(review): the util_distribution_variant.cc hunk also removes the
// "distributions/crp.hh" include and moves the file's own header below the
// system includes -- confirm that is intentional.
// NOTE(review): in this copy of the patch the line breaks are collapsed and
// angle-bracket text (#include targets, template arguments) appears to be
// missing; regenerate from the original commit before applying.
diff --git a/cxx/BUILD b/cxx/BUILD index ad77bc9..72e5aef 100644 --- a/cxx/BUILD +++ b/cxx/BUILD @@ -93,6 +93,7 @@ cc_library( deps = [ ":domain", "//distributions", + "@boost//:algorithm", ], ) diff --git a/cxx/distributions/BUILD b/cxx/distributions/BUILD index caa384d..6d8bf24 100644 --- a/cxx/distributions/BUILD +++ b/cxx/distributions/BUILD @@ -11,6 +11,7 @@ cc_library( ":dirichlet_categorical", ":normal", ":skellam", + ":stringcat", ], ) @@ -94,6 +95,16 @@ cc_library( ], ) +cc_library( + name = "stringcat", + srcs = ["stringcat.cc"], + hdrs = ["stringcat.hh"], + deps = [ + ":base", + ":dirichlet_categorical", + ], +) + cc_library( name = "zero_mean_normal", srcs = ["zero_mean_normal.cc"], @@ -169,6 +180,15 @@ cc_test( ], ) +cc_test( + name = "stringcat_test", + srcs = ["stringcat_test.cc"], + deps = [ + ":stringcat", + "@boost//:test", + ], +) + cc_test( name = "zero_mean_normal_test", srcs = ["zero_mean_normal_test.cc"], diff --git a/cxx/distributions/stringcat.cc b/cxx/distributions/stringcat.cc new file mode 100644 index 0000000..132f8fa --- /dev/null +++ b/cxx/distributions/stringcat.cc @@ -0,0 +1,40 @@ +// Copyright 2024 +// See LICENSE.txt + +#include +#include +#include "distributions/stringcat.hh" + +int StringCat::string_to_index(const std::string& s) const { + auto it = std::find(strings.begin(), strings.end(), s); + if (it == strings.end()) { + assert(false); + } + return it - strings.begin(); +} + +void StringCat::incorporate(const std::string& s) { + dc.incorporate(string_to_index(s)); + ++N; +} + +void StringCat::unincorporate(const std::string& s) { + dc.unincorporate(string_to_index(s)); + --N; +} + +double StringCat::logp(const std::string& s) const { + return dc.logp(string_to_index(s)); +} + +double StringCat::logp_score() const { + return dc.logp_score(); +} + +std::string StringCat::sample(std::mt19937* prng) { + return strings[dc.sample(prng)]; +} + +void StringCat::transition_hyperparameters(std::mt19937* prng) { + 
dc.transition_hyperparameters(prng); +} diff --git a/cxx/distributions/stringcat.hh b/cxx/distributions/stringcat.hh new file mode 100644 index 0000000..8fe1401 --- /dev/null +++ b/cxx/distributions/stringcat.hh @@ -0,0 +1,36 @@ +// Copyright 2024 +// See LICENSE.txt + +#pragma once + +#include +#include + +#include "distributions/base.hh" +#include "distributions/dirichlet_categorical.hh" + +// A distribution over a finite set of strings. +class StringCat : public Distribution { + public: + std::vector strings; + DirichletCategorical dc; + + // Each element of vs should be distinct. + StringCat(const std::vector &vs) : strings(vs), dc(vs.size()) {}; + + int string_to_index(const std::string& s) const; + + void incorporate(const std::string& s); + + void unincorporate(const std::string& s); + + double logp(const std::string& s) const; + + double logp_score() const; + + std::string sample(std::mt19937* prng); + + void set_alpha(double alphat); + + void transition_hyperparameters(std::mt19937* prng); +}; diff --git a/cxx/distributions/stringcat_test.cc b/cxx/distributions/stringcat_test.cc new file mode 100644 index 0000000..c23f71c --- /dev/null +++ b/cxx/distributions/stringcat_test.cc @@ -0,0 +1,33 @@ +// Apache License, Version 2.0, refer to LICENSE.txt + +#define BOOST_TEST_MODULE test StringCat + +#include "distributions/stringcat.hh" + +#include +namespace tt = boost::test_tools; + +BOOST_AUTO_TEST_CASE(test_simple) { + std::vector strings = { + "hello", "world", "train", "test", "other"}; + StringCat sc(strings); + + sc.incorporate("hello"); + sc.incorporate("world"); + BOOST_TEST(sc.N == 2); + sc.unincorporate("hello"); + BOOST_TEST(sc.N == 1); + sc.incorporate("train"); + sc.unincorporate("world"); + BOOST_TEST(sc.N == 1); + + BOOST_TEST(sc.logp("test") == -1.791759469228055, tt::tolerance(1e-6)); + BOOST_TEST(sc.logp_score() == -1.6094379124341001, tt::tolerance(1e-6)); + + std::mt19937 prng; + std::string samp = sc.sample(&prng); + + auto it = 
std::find(strings.begin(), strings.end(), samp); + bool found = (it != strings.end()); + BOOST_TEST(found); +} diff --git a/cxx/util_distribution_variant.cc b/cxx/util_distribution_variant.cc index 1f6b8d3..fe4155e 100644 --- a/cxx/util_distribution_variant.cc +++ b/cxx/util_distribution_variant.cc @@ -1,17 +1,17 @@ // Copyright 2024 // See LICENSE.txt -#include "util_distribution_variant.hh" - #include #include - +#include +#include "util_distribution_variant.hh" #include "distributions/beta_bernoulli.hh" #include "distributions/bigram.hh" -#include "distributions/crp.hh" #include "distributions/dirichlet_categorical.hh" #include "distributions/normal.hh" #include "distributions/skellam.hh" +#include "distributions/stringcat.hh" + ObservationVariant observation_string_to_value( const std::string& value_str, const DistributionEnum& distribution) { @@ -24,6 +24,7 @@ ObservationVariant observation_string_to_value( case DistributionEnum::skellam: return std::stoi(value_str); case DistributionEnum::bigram: + case DistributionEnum::stringcat: return value_str; default: assert(false && "Unsupported distribution enum value."); @@ -36,7 +37,8 @@ DistributionSpec parse_distribution_spec(const std::string& dist_str) { {"bigram", DistributionEnum::bigram}, {"categorical", DistributionEnum::categorical}, {"normal", DistributionEnum::normal}, - {"skellam", DistributionEnum::skellam} + {"skellam", DistributionEnum::skellam}, + {"stringcat", DistributionEnum::stringcat} }; std::string dist_name = dist_str.substr(0, dist_str.find('(')); DistributionEnum dist = dist_name_to_enum.at(dist_name); @@ -81,6 +83,18 @@ DistributionVariant cluster_prior_from_spec( s->init_theta(prng); return s; } + case DistributionEnum::stringcat: { + std::string delim = " "; // Default delimiter + auto it = spec.distribution_args.find("delim"); + if (it != spec.distribution_args.end()) { + delim = it->second; + assert(delim.length() == 1); + } + std::vector strings; + boost::split(strings, 
spec.distribution_args.at("strings"), + boost::is_any_of(delim)); + return new StringCat(strings); + } default: assert(false && "Unsupported distribution enum value."); } diff --git a/cxx/util_distribution_variant.hh b/cxx/util_distribution_variant.hh index 7f50bc1..3ee3e5e 100644 --- a/cxx/util_distribution_variant.hh +++ b/cxx/util_distribution_variant.hh @@ -17,8 +17,10 @@ #include "distributions/dirichlet_categorical.hh" #include "distributions/normal.hh" #include "distributions/skellam.hh" +#include "distributions/stringcat.hh" -enum class DistributionEnum { bernoulli, bigram, categorical, normal, skellam }; +enum class DistributionEnum { + bernoulli, bigram, categorical, normal, skellam, stringcat }; struct DistributionSpec { DistributionEnum distribution; @@ -30,7 +32,7 @@ using ObservationVariant = std::variant; using DistributionVariant = std::variant; + Skellam*, StringCat*>; ObservationVariant observation_string_to_value( const std::string& value_str, const DistributionEnum& distribution); diff --git a/cxx/util_distribution_variant_test.cc b/cxx/util_distribution_variant_test.cc index 5615039..556ce48 100644 --- a/cxx/util_distribution_variant_test.cc +++ b/cxx/util_distribution_variant_test.cc @@ -15,6 +15,8 @@ namespace tt = boost::test_tools; BOOST_AUTO_TEST_CASE(test_parse_distribution_spec) { + std::mt19937 prng; + DistributionSpec dbb = parse_distribution_spec("bernoulli"); BOOST_TEST((dbb.distribution == DistributionEnum::bernoulli)); BOOST_TEST(dbb.distribution_args.empty()); @@ -36,6 +38,21 @@ BOOST_AUTO_TEST_CASE(test_parse_distribution_spec) { BOOST_TEST((dc.distribution_args.size() == 1)); std::string expected = "6"; BOOST_CHECK_EQUAL(dc.distribution_args.at("k"), expected); + + DistributionSpec dsc = parse_distribution_spec("stringcat(strings=a b c d)"); + BOOST_TEST((dsc.distribution == DistributionEnum::stringcat)); + BOOST_TEST((dsc.distribution_args.size() == 1)); + BOOST_CHECK_EQUAL(dsc.distribution_args.at("strings"), "a b c d"); + 
DistributionVariant dv = cluster_prior_from_spec(dsc, &prng); + BOOST_TEST(std::get(dv)->strings.size() == 4); + + DistributionSpec dsc2 = parse_distribution_spec( + "stringcat(strings=yes:no,delim=:)"); + BOOST_TEST((dsc2.distribution == DistributionEnum::stringcat)); + BOOST_TEST((dsc2.distribution_args.size() == 2)); + BOOST_CHECK_EQUAL(dsc2.distribution_args.at("strings"), "yes:no"); + DistributionVariant dv2 = cluster_prior_from_spec(dsc2, &prng); + BOOST_TEST(std::get(dv2)->strings.size() == 2); } BOOST_AUTO_TEST_CASE(test_cluster_prior_from_spec) {