Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StringCat distribution for a categorical distribution over a finite set of strings #73

Merged
merged 4 commits into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ cc_library(
deps = [
":domain",
"//distributions",
"@boost//:algorithm",
],
)

Expand Down
20 changes: 20 additions & 0 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cc_library(
":dirichlet_categorical",
":normal",
":skellam",
":stringcat",
],
)

Expand Down Expand Up @@ -94,6 +95,16 @@ cc_library(
],
)

cc_library(
name = "stringcat",
srcs = ["stringcat.cc"],
hdrs = ["stringcat.hh"],
deps = [
":base",
":dirichlet_categorical",
],
)

cc_library(
name = "zero_mean_normal",
srcs = ["zero_mean_normal.cc"],
Expand Down Expand Up @@ -169,6 +180,15 @@ cc_test(
],
)

cc_test(
name = "stringcat_test",
srcs = ["stringcat_test.cc"],
deps = [
":stringcat",
"@boost//:test",
],
)

cc_test(
name = "zero_mean_normal_test",
srcs = ["zero_mean_normal_test.cc"],
Expand Down
40 changes: 40 additions & 0 deletions cxx/distributions/stringcat.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2024
// See LICENSE.txt

#include <algorithm>
#include <cassert>
#include "distributions/stringcat.hh"

int StringCat::string_to_index(const std::string& s) const {
auto it = std::find(strings.begin(), strings.end(), s);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT of instead building a map<string, int> in the ctor for faster lookup?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this, but my expectation is that most of the time, the number of strings in the class will be small enough that the speed difference will be minimal. And saving space by not creating a map isn't entirely inconsequential when running on large datasets. (Keeping in mind that the number of distinct values in a column isn't the same as the number of rows -- I'm saying that the first will probably be small, but the second might be large. And because we cluster rows and create Distributions per cluster, that drives the number of instances of this class that get instantiated. In fact, we might want to consider designs where this class doesn't store its own copy of the vector of strings, but that's a pull request for another day.)

if (it == strings.end()) {
assert(false);
}
return it - strings.begin();
}

void StringCat::incorporate(const std::string& s) {
dc.incorporate(string_to_index(s));
++N;
}

void StringCat::unincorporate(const std::string& s) {
dc.unincorporate(string_to_index(s));
--N;
}

double StringCat::logp(const std::string& s) const {
return dc.logp(string_to_index(s));
}

double StringCat::logp_score() const {
return dc.logp_score();
}

std::string StringCat::sample(std::mt19937* prng) {
return strings[dc.sample(prng)];
}

void StringCat::transition_hyperparameters(std::mt19937* prng) {
dc.transition_hyperparameters(prng);
}
36 changes: 36 additions & 0 deletions cxx/distributions/stringcat.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright 2024
// See LICENSE.txt

#pragma once

#include <string>
#include <vector>

#include "distributions/base.hh"
#include "distributions/dirichlet_categorical.hh"

// A distribution over a finite set of strings.
class StringCat : public Distribution<std::string> {
public:
std::vector<std::string> strings;
DirichletCategorical dc;

// Each element of vs should be distinct.
StringCat(const std::vector<std::string> &vs) : strings(vs), dc(vs.size()) {};

int string_to_index(const std::string& s) const;

void incorporate(const std::string& s);

void unincorporate(const std::string& s);

double logp(const std::string& s) const;

double logp_score() const;

std::string sample(std::mt19937* prng);

void set_alpha(double alphat);

void transition_hyperparameters(std::mt19937* prng);
};
33 changes: 33 additions & 0 deletions cxx/distributions/stringcat_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test StringCat

#include "distributions/stringcat.hh"

#include <boost/test/included/unit_test.hpp>
namespace tt = boost::test_tools;

BOOST_AUTO_TEST_CASE(test_simple) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you call sample in the test somewhere?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

std::vector<std::string> strings = {
"hello", "world", "train", "test", "other"};
StringCat sc(strings);

sc.incorporate("hello");
sc.incorporate("world");
BOOST_TEST(sc.N == 2);
sc.unincorporate("hello");
BOOST_TEST(sc.N == 1);
sc.incorporate("train");
sc.unincorporate("world");
BOOST_TEST(sc.N == 1);

BOOST_TEST(sc.logp("test") == -1.791759469228055, tt::tolerance(1e-6));
BOOST_TEST(sc.logp_score() == -1.6094379124341001, tt::tolerance(1e-6));

std::mt19937 prng;
std::string samp = sc.sample(&prng);

auto it = std::find(strings.begin(), strings.end(), samp);
bool found = (it != strings.end());
BOOST_TEST(found);
}
24 changes: 19 additions & 5 deletions cxx/util_distribution_variant.cc
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
// Copyright 2024
// See LICENSE.txt

#include "util_distribution_variant.hh"

#include <cassert>
#include <sstream>

#include <boost/algorithm/string.hpp>
#include "util_distribution_variant.hh"
#include "distributions/beta_bernoulli.hh"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked this up the other day and the style guide says to include these: https://engdoc.corp.google.com/eng/doc/devguide/cpp/styleguide.md?cl=head#Include_What_You_Use (IMO we might as well follow that but I don't feel strongly)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

#include "distributions/bigram.hh"
#include "distributions/crp.hh"
#include "distributions/dirichlet_categorical.hh"
#include "distributions/normal.hh"
#include "distributions/skellam.hh"
#include "distributions/stringcat.hh"


ObservationVariant observation_string_to_value(
const std::string& value_str, const DistributionEnum& distribution) {
Expand All @@ -24,6 +24,7 @@ ObservationVariant observation_string_to_value(
case DistributionEnum::skellam:
return std::stoi(value_str);
case DistributionEnum::bigram:
case DistributionEnum::stringcat:
return value_str;
default:
assert(false && "Unsupported distribution enum value.");
Expand All @@ -36,7 +37,8 @@ DistributionSpec parse_distribution_spec(const std::string& dist_str) {
{"bigram", DistributionEnum::bigram},
{"categorical", DistributionEnum::categorical},
{"normal", DistributionEnum::normal},
{"skellam", DistributionEnum::skellam}
{"skellam", DistributionEnum::skellam},
{"stringcat", DistributionEnum::stringcat}
};
std::string dist_name = dist_str.substr(0, dist_str.find('('));
DistributionEnum dist = dist_name_to_enum.at(dist_name);
Expand Down Expand Up @@ -81,6 +83,18 @@ DistributionVariant cluster_prior_from_spec(
s->init_theta(prng);
return s;
}
case DistributionEnum::stringcat: {
std::string delim = " "; // Default deliminator
auto it = spec.distribution_args.find("delim");
if (it != spec.distribution_args.end()) {
delim = it->second;
assert(delim.length() == 1);
}
std::vector<std::string> strings;
boost::split(strings, spec.distribution_args.at("strings"),
boost::is_any_of(delim));
return new StringCat(strings);
}
default:
assert(false && "Unsupported distribution enum value.");
}
Expand Down
6 changes: 4 additions & 2 deletions cxx/util_distribution_variant.hh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#include "distributions/dirichlet_categorical.hh"
#include "distributions/normal.hh"
#include "distributions/skellam.hh"
#include "distributions/stringcat.hh"

enum class DistributionEnum { bernoulli, bigram, categorical, normal, skellam };
enum class DistributionEnum {
bernoulli, bigram, categorical, normal, skellam, stringcat };

struct DistributionSpec {
DistributionEnum distribution;
Expand All @@ -30,7 +32,7 @@ using ObservationVariant = std::variant<double, int, bool, std::string>;

using DistributionVariant =
std::variant<BetaBernoulli*, Bigram*, DirichletCategorical*, Normal*,
Skellam*>;
Skellam*, StringCat*>;

ObservationVariant observation_string_to_value(
const std::string& value_str, const DistributionEnum& distribution);
Expand Down
17 changes: 17 additions & 0 deletions cxx/util_distribution_variant_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
namespace tt = boost::test_tools;

BOOST_AUTO_TEST_CASE(test_parse_distribution_spec) {
std::mt19937 prng;

DistributionSpec dbb = parse_distribution_spec("bernoulli");
BOOST_TEST((dbb.distribution == DistributionEnum::bernoulli));
BOOST_TEST(dbb.distribution_args.empty());
Expand All @@ -36,6 +38,21 @@ BOOST_AUTO_TEST_CASE(test_parse_distribution_spec) {
BOOST_TEST((dc.distribution_args.size() == 1));
std::string expected = "6";
BOOST_CHECK_EQUAL(dc.distribution_args.at("k"), expected);

DistributionSpec dsc = parse_distribution_spec("stringcat(strings=a b c d)");
BOOST_TEST((dsc.distribution == DistributionEnum::stringcat));
BOOST_TEST((dsc.distribution_args.size() == 1));
BOOST_CHECK_EQUAL(dsc.distribution_args.at("strings"), "a b c d");
DistributionVariant dv = cluster_prior_from_spec(dsc, &prng);
BOOST_TEST(std::get<StringCat*>(dv)->strings.size() == 4);

DistributionSpec dsc2 = parse_distribution_spec(
"stringcat(strings=yes:no,delim=:)");
BOOST_TEST((dsc2.distribution == DistributionEnum::stringcat));
BOOST_TEST((dsc2.distribution_args.size() == 2));
BOOST_CHECK_EQUAL(dsc2.distribution_args.at("strings"), "yes:no");
DistributionVariant dv2 = cluster_prior_from_spec(dsc2, &prng);
BOOST_TEST(std::get<StringCat*>(dv2)->strings.size() == 2);
}

BOOST_AUTO_TEST_CASE(test_cluster_prior_from_spec) {
Expand Down