-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add StringCat distribution for a categorical distribution over a finite set of strings #73
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,6 +93,7 @@ cc_library( | |
deps = [ | ||
":domain", | ||
"//distributions", | ||
"@boost//:algorithm", | ||
], | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright 2024 | ||
// See LICENSE.txt | ||
|
||
#include <algorithm> | ||
#include <cassert> | ||
#include "distributions/stringcat.hh" | ||
|
||
int StringCat::string_to_index(const std::string& s) const { | ||
auto it = std::find(strings.begin(), strings.end(), s); | ||
if (it == strings.end()) { | ||
assert(false); | ||
} | ||
return it - strings.begin(); | ||
} | ||
|
||
void StringCat::incorporate(const std::string& s) { | ||
dc.incorporate(string_to_index(s)); | ||
++N; | ||
} | ||
|
||
void StringCat::unincorporate(const std::string& s) { | ||
dc.unincorporate(string_to_index(s)); | ||
--N; | ||
} | ||
|
||
double StringCat::logp(const std::string& s) const { | ||
return dc.logp(string_to_index(s)); | ||
} | ||
|
||
double StringCat::logp_score() const { | ||
return dc.logp_score(); | ||
} | ||
|
||
std::string StringCat::sample(std::mt19937* prng) { | ||
return strings[dc.sample(prng)]; | ||
} | ||
|
||
void StringCat::transition_hyperparameters(std::mt19937* prng) { | ||
dc.transition_hyperparameters(prng); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright 2024 | ||
// See LICENSE.txt | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "distributions/base.hh" | ||
#include "distributions/dirichlet_categorical.hh" | ||
|
||
// A distribution over a finite set of strings. | ||
class StringCat : public Distribution<std::string> { | ||
public: | ||
std::vector<std::string> strings; | ||
DirichletCategorical dc; | ||
|
||
// Each element of vs should be distinct. | ||
StringCat(const std::vector<std::string> &vs) : strings(vs), dc(vs.size()) {}; | ||
|
||
int string_to_index(const std::string& s) const; | ||
|
||
void incorporate(const std::string& s); | ||
|
||
void unincorporate(const std::string& s); | ||
|
||
double logp(const std::string& s) const; | ||
|
||
double logp_score() const; | ||
|
||
std::string sample(std::mt19937* prng); | ||
|
||
void set_alpha(double alphat); | ||
|
||
void transition_hyperparameters(std::mt19937* prng); | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test StringCat | ||
|
||
#include "distributions/stringcat.hh" | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
namespace tt = boost::test_tools; | ||
|
||
BOOST_AUTO_TEST_CASE(test_simple) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
std::vector<std::string> strings = { | ||
"hello", "world", "train", "test", "other"}; | ||
StringCat sc(strings); | ||
|
||
sc.incorporate("hello"); | ||
sc.incorporate("world"); | ||
BOOST_TEST(sc.N == 2); | ||
sc.unincorporate("hello"); | ||
BOOST_TEST(sc.N == 1); | ||
sc.incorporate("train"); | ||
sc.unincorporate("world"); | ||
BOOST_TEST(sc.N == 1); | ||
|
||
BOOST_TEST(sc.logp("test") == -1.791759469228055, tt::tolerance(1e-6)); | ||
BOOST_TEST(sc.logp_score() == -1.6094379124341001, tt::tolerance(1e-6)); | ||
|
||
std::mt19937 prng; | ||
std::string samp = sc.sample(&prng); | ||
|
||
auto it = std::find(strings.begin(), strings.end(), samp); | ||
bool found = (it != strings.end()); | ||
BOOST_TEST(found); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,17 @@ | ||
// Copyright 2024 | ||
// See LICENSE.txt | ||
|
||
#include "util_distribution_variant.hh" | ||
|
||
#include <cassert> | ||
#include <sstream> | ||
|
||
#include <boost/algorithm/string.hpp> | ||
#include "util_distribution_variant.hh" | ||
#include "distributions/beta_bernoulli.hh" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I looked this up the other day and the style guide says to include these: https://engdoc.corp.google.com/eng/doc/devguide/cpp/styleguide.md?cl=head#Include_What_You_Use (IMO we might as well follow that but I don't feel strongly) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
#include "distributions/bigram.hh" | ||
#include "distributions/crp.hh" | ||
#include "distributions/dirichlet_categorical.hh" | ||
#include "distributions/normal.hh" | ||
#include "distributions/skellam.hh" | ||
#include "distributions/stringcat.hh" | ||
|
||
|
||
ObservationVariant observation_string_to_value( | ||
const std::string& value_str, const DistributionEnum& distribution) { | ||
|
@@ -24,6 +24,7 @@ ObservationVariant observation_string_to_value( | |
case DistributionEnum::skellam: | ||
return std::stoi(value_str); | ||
case DistributionEnum::bigram: | ||
case DistributionEnum::stringcat: | ||
return value_str; | ||
default: | ||
assert(false && "Unsupported distribution enum value."); | ||
|
@@ -36,7 +37,8 @@ DistributionSpec parse_distribution_spec(const std::string& dist_str) { | |
{"bigram", DistributionEnum::bigram}, | ||
{"categorical", DistributionEnum::categorical}, | ||
{"normal", DistributionEnum::normal}, | ||
{"skellam", DistributionEnum::skellam} | ||
{"skellam", DistributionEnum::skellam}, | ||
{"stringcat", DistributionEnum::stringcat} | ||
}; | ||
std::string dist_name = dist_str.substr(0, dist_str.find('(')); | ||
DistributionEnum dist = dist_name_to_enum.at(dist_name); | ||
|
@@ -81,6 +83,18 @@ DistributionVariant cluster_prior_from_spec( | |
s->init_theta(prng); | ||
return s; | ||
} | ||
case DistributionEnum::stringcat: { | ||
std::string delim = " "; // Default deliminator | ||
auto it = spec.distribution_args.find("delim"); | ||
if (it != spec.distribution_args.end()) { | ||
delim = it->second; | ||
assert(delim.length() == 1); | ||
} | ||
std::vector<std::string> strings; | ||
boost::split(strings, spec.distribution_args.at("strings"), | ||
boost::is_any_of(delim)); | ||
return new StringCat(strings); | ||
} | ||
default: | ||
assert(false && "Unsupported distribution enum value."); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT of instead building a
map<string, int>
in the ctor for faster lookup?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about this, but my expectation is that most of the time, the number of strings in the class will be small enough that the speed difference will be minimal. And saving space by not creating a map isn't entirely inconsequential when running on large datasets. (Keeping in mind that the number of distinct values in a column isn't the same as the number of rows -- I'm saying that the first will probably be small, but the second might be large. And because we cluster rows and create Distributions per cluster, that drives the number of instances of this class that get instantiated. In fact, we might want to consider designs where this class doesn't store its own copy of the vector of strings, but that's a pull request for another day.)