Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add maxlength option to Bigram string distribution #170

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cxx/assets/flights.schema
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ class TrackingWebsite
name ~ stringcat(strings="aa airtravelcenter allegiantair boston businesstravellogue CO den dfw flightarrival flightaware flightexplorer flights flightstats flightview flightwise flylouisville flytecomm foxbusiness gofox helloflight iad ifly mco mia myrateplan mytripandmore orbitz ord panynj phl quicktrip sfo src travelocity ua usatoday weather world-flight-tracker wunderground")

class Time
time ~ string
time ~ string(maxlength=40)

class Flight
flight_id ~ string
flight_id ~ string(maxlength=20)
# These are all abbreviations for "scheduled/actual arrival/departure time"
sdt ~ Time
sat ~ Time
Expand Down
10 changes: 10 additions & 0 deletions cxx/assets/flights_dirty.10.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
tuple_id,src,flight,sched_dep_time,act_dep_time,sched_arr_time,act_arr_time
1,aa,AA-3859-IAH-ORD,7:10 a.m.,7:16 a.m.,9:40 a.m.,9:32 a.m.
2,aa,AA-1733-ORD-PHX,7:45 p.m.,7:58 p.m.,10:30 p.m.,
3,aa,AA-1640-MIA-MCO,6:30 p.m.,,7:25 p.m.,
4,aa,AA-518-MIA-JFK,6:40 a.m.,6:54 a.m.,9:25 a.m.,9:28 a.m.
5,aa,AA-3756-ORD-SLC,12:15 p.m.,12:41 p.m.,2:45 p.m.,2:50 p.m.
6,aa,AA-204-LAX-MCO,11:25 p.m.,,12/02/2011 6:55 a.m.,
7,aa,AA-3468-CVG-MIA,7:00 a.m.,7:25 a.m.,9:55 a.m.,9:45 a.m.
8,aa,AA-484-DFW-MIA,4:15 p.m.,4:29 p.m.,7:55 p.m.,7:39 p.m.
9,aa,AA-446-DFW-PHL,11:50 a.m.,12:12 p.m.,3:50 p.m.,4:09 p.m.
16 changes: 8 additions & 8 deletions cxx/assets/hospitals.schema
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@

class County
state ~ stringcat(strings="al ak az ar ca co ct de dc fl ga hi id il in ia ks ky la me md ma mi mn ms mo mt ne nv nh nj nm ny nc nd oh ok or pa ri sc sd tn tx ut vt va wa wv wi wy")
county ~ string
county ~ string(maxlength=25)

class Place
county ~ County
city ~ string
city ~ string(maxlength=25)

class Condition
desc ~ string
desc ~ string(maxlength=40)

class Measure
code ~ stringcat(strings="ami-1 ami-2 ami-3 ami-4 ami-5 ami-7a ami-8a ami-x amix1 amix2 amx-3 amx-4 axi-2 axi-4 cac-1 cac-2 cac-3 hf-1 hf-2 hf-3 hf-4 hf-x hfx1 hfx4 hx-1 hx-2 pn-2 pn-3b pn-4 pn-5c pn-6 pn-7 pn-x pnx5c pnx6 pn-xb px-4 px-5c scip-card-2 scip-inf-1 scip-inf-2 scip-inf-3 scip-inf-4 scip-inf-6 scip-inx-4 scip-vte-1 scip-vte-2 scip-vtx-1 scipxinfx1 scix-inf-2 scxp-xnf-3 sxip-vte-1 xax-1 xf-1")
name ~ string # TODO(thomaswc): Consider using stringcat instead.
name ~ string(maxlength=0) # TODO(thomaswc): Consider using stringcat instead
condition ~ Condition

class HospitalType
desc ~ string
desc ~ string(maxlength=40)

class Hospital
loc ~ Place
type ~ HospitalType
provider ~ typo_int
name ~ string
addr ~ string
phone ~ string
name ~ string(maxlength=50)
addr ~ string(maxlength=50)
phone ~ string(maxlength=15)
owner ~ stringcat(strings="government - federal:government - hospital district or authority:government - local:government - state:government - federal:proprietary:voluntary non-profit - church:voluntary non-profit - other:voluntary non-profit - private", delim=":")
zip ~ typo_int
service ~ stringcat(strings="no yes")
Expand Down
2 changes: 1 addition & 1 deletion cxx/assets/rents.schema
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Based on https://github.com/probcomp/PClean/blob/master/experiments/rents/run.jl

class County
name ~ string
name ~ string(maxlength=50)
state ~ stringcat(strings="AL AK AZ AR CA CO CT DE DC FL GA HI ID IL IN IA KS KY LA ME MD MA MI MN MS MO MT NE NV NH NJ NM NY NC ND OH OK OR PA RI SC SD TN TX UT VT VA WA WV WI WY")

class Obs
Expand Down
10 changes: 10 additions & 0 deletions cxx/assets/rents_dirty.10.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Column1,Room Type,Monthly Rent,County,State
0,studio,486.0,Mahoning County,OH
1,4br,2152.0,Clark County,NV
2,1br,1267.0,Gwinnett County,GA
3,3br,1180.0,Granville County,NC
4,,1436.0,Suffolk County,NY
5,2br,1768.0,Miami-Dade County,FL
6,,585.0,Sebastian County,AR
7,studio,599.0,Lapeer County,MI
8,3br,3056.0,Monterey County,CA
17 changes: 14 additions & 3 deletions cxx/distributions/bigram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "distributions/bigram.hh"

#include <cassert>
#include <cstdlib>

#include "distributions/base.hh"

Expand Down Expand Up @@ -34,6 +35,11 @@ std::vector<size_t> Bigram::string_to_indices(const std::string& str) const {
}

void Bigram::incorporate(const std::string& x, double weight) {
if ((max_length > 0) && (x.length() > max_length)) {
printf("String %s has length %ld, but max length is %ld.\n",
x.c_str(), x.length(), max_length);
std::exit(1);
}
const std::vector<size_t> indices = string_to_indices(x);
for (size_t i = 0; i != indices.size() - 1; ++i) {
transition_dists[indices[i]].incorporate(indices[i + 1], weight);
Expand Down Expand Up @@ -66,9 +72,11 @@ double Bigram::logp_score() const {

std::string Bigram::sample(std::mt19937* prng) {
std::string sampled_string;
// TODO(emilyaf): Reconsider the reserved length and maybe enforce a
// max length.
sampled_string.reserve(30);
if (max_length > 0) {
sampled_string.reserve(max_length);
} else {
sampled_string.reserve(2 * num_chars);
}
// Sample the first character conditioned on the stop/start symbol.
size_t current_ind = num_chars;
size_t next_ind = transition_dists[current_ind].sample(prng);
Expand All @@ -80,6 +88,9 @@ std::string Bigram::sample(std::mt19937* prng) {
// subsequent samples are conditioned on its observation.
while (current_ind != num_chars) {
sampled_string += index_to_char(current_ind);
if (sampled_string.length() == max_length) {
break;
}
next_ind = transition_dists[current_ind].sample(prng);
transition_dists[current_ind].incorporate(next_ind);
current_ind = next_ind;
Expand Down
3 changes: 2 additions & 1 deletion cxx/distributions/bigram.hh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class Bigram : public Distribution<std::string> {
double alpha = 1; // hyperparameter for all transition distributions.
size_t num_chars = '~' - ' ' + 1; // printable ASCII without DEL.
mutable std::vector<DirichletCategorical> transition_dists;
size_t max_length = 0; // 0 means no maximum length

Bigram() {
Bigram(size_t _max_length = 80): max_length(_max_length) {
const size_t total_chars = num_chars + 1; // Include a start/stop symbol.

// The distribution at index `i` represents `p(X_{j+1} | X_j == char_i)`.
Expand Down
21 changes: 21 additions & 0 deletions cxx/distributions/bigram_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#define BOOST_TEST_MODULE test Bigram

#include <algorithm>
#include "distributions/bigram.hh"

#include <boost/test/included/unit_test.hpp>
Expand All @@ -26,6 +27,26 @@ BOOST_AUTO_TEST_CASE(test_simple) {
BOOST_TEST(bg.N == 2.23);
}

// A Bigram constructed with a maximum length of 5 must never produce a
// sample longer than 5 characters.
BOOST_AUTO_TEST_CASE(test_max_length) {
  const size_t kMaxLength = 5;
  std::mt19937 prng;
  Bigram bg(kMaxLength);

  // Draw ten samples and check the length bound on each one.
  int remaining = 10;
  while (remaining-- > 0) {
    const std::string drawn = bg.sample(&prng);
    BOOST_TEST(drawn.length() <= kMaxLength);
  }
}

// A max_length of 0 means "no maximum length": the longest of ten samples
// is expected to be longer than the number of characters in the alphabet.
BOOST_AUTO_TEST_CASE(test_max_length0) {
  std::mt19937 prng;
  Bigram bg(0);

  // Track the longest sample seen across ten draws.
  size_t longest = 0;
  for (int draw = 0; draw != 10; ++draw) {
    const size_t len = bg.sample(&prng).length();
    if (len > longest) {
      longest = len;
    }
  }
  BOOST_TEST(longest > bg.num_chars);
}

BOOST_AUTO_TEST_CASE(test_set_alpha) {
Bigram bg;

Expand Down
9 changes: 7 additions & 2 deletions cxx/distributions/get_distribution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ DistributionVariant get_prior(const DistributionSpec& spec,
switch (spec.distribution) {
case DistributionEnum::bernoulli:
return new BetaBernoulli;
case DistributionEnum::bigram:
return new Bigram;
case DistributionEnum::bigram: {
size_t max_length = 80;
if (spec.distribution_args.contains("maxlength")) {
max_length = std::stoul(spec.distribution_args.at("maxlength"));
}
return new Bigram(max_length);
}
case DistributionEnum::categorical: {
assert(spec.distribution_args.size() == 1);
int num_categories = std::stoi(spec.distribution_args.at("k"));
Expand Down
20 changes: 20 additions & 0 deletions cxx/distributions/get_distribution_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ BOOST_AUTO_TEST_CASE(test_distribution_spec) {
BOOST_TEST((dbg.distribution == DistributionEnum::bigram));
BOOST_TEST(dbg.distribution_args.empty());

DistributionSpec dbg2 = DistributionSpec("bigram(maxlength=10)");
BOOST_TEST((dbg2.distribution == DistributionEnum::bigram));
BOOST_TEST((dbg2.distribution_args.size() == 1));
BOOST_TEST(dbg2.distribution_args.at("maxlength") == "10");

DistributionSpec dn = DistributionSpec("normal");
BOOST_TEST((dn.distribution == DistributionEnum::normal));
BOOST_TEST(dn.distribution_args.empty());
Expand Down Expand Up @@ -79,6 +84,21 @@ BOOST_AUTO_TEST_CASE(test_get_prior_bigram) {
BOOST_TEST(name.find("Bigram") != std::string::npos);
}

// get_prior() given a "bigram(maxlength=2)" spec must yield a Bigram whose
// samples are never longer than two characters.
BOOST_AUTO_TEST_CASE(test_get_prior_bigram2) {
  std::mt19937 prng;

  const DistributionSpec spec("bigram(maxlength=2)");
  DistributionVariant dv = get_prior(spec, &prng);

  // The variant should hold a string distribution whose dynamic type is
  // Bigram.
  auto* dist = std::get<Distribution<std::string>*>(dv);
  const std::string type_name = typeid(*dist).name();
  BOOST_TEST(type_name.find("Bigram") != std::string::npos);

  // Every one of ten draws must respect the requested maximum length.
  for (int draw = 0; draw != 10; ++draw) {
    const std::string drawn = dist->sample(&prng);
    BOOST_TEST(drawn.length() <= 2);
  }
}

BOOST_AUTO_TEST_CASE(test_get_prior_categorical) {
std::mt19937 prng;

Expand Down
4 changes: 2 additions & 2 deletions cxx/integration_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ startt2=$(date +%s)
./bazel-bin/hirm --iters=5 --load=assets/animals.unary.1.hirm assets/animals.unary
endt2=$(date +%s)
startt3=$(date +%s)
#./bazel-bin/pclean/pclean --schema=assets/flights.schema --obs=assets/flights_dirty.100.csv --iters=5
./bazel-bin/pclean/pclean --schema=assets/flights.schema --obs=assets/flights_dirty.10.csv --iters=5
./bazel-bin/pclean/pclean --schema=assets/hospitals.schema --obs=assets/hospital_dirty.10.csv --iters=5
#./bazel-bin/pclean/pclean --schema=assets/rents.schema --obs=assets/rents_dirty.100.csv --iters=5
./bazel-bin/pclean/pclean --schema=assets/rents.schema --obs=assets/rents_dirty.10.csv --iters=5
endt3=$(date +%s)
echo "Integration tests in /tests ran in $(($endt1-$startt1)) seconds"
echo "hirm integration tests ran in $(($endt2-$startt2)) seconds"
Expand Down
Loading