diff --git a/cxx/assets/flights.schema b/cxx/assets/flights.schema index 2780730..db4b4eb 100644 --- a/cxx/assets/flights.schema +++ b/cxx/assets/flights.schema @@ -5,10 +5,10 @@ class TrackingWebsite name ~ stringcat(strings="aa airtravelcenter allegiantair boston businesstravellogue CO den dfw flightarrival flightaware flightexplorer flights flightstats flightview flightwise flylouisville flytecomm foxbusiness gofox helloflight iad ifly mco mia myrateplan mytripandmore orbitz ord panynj phl quicktrip sfo src travelocity ua usatoday weather world-flight-tracker wunderground") class Time - time ~ string + time ~ string(maxlength=40) class Flight - flight_id ~ string + flight_id ~ string(maxlength=20) # These are all abbreviations for "scheduled/actual arrival/depature time" sdt ~ Time sat ~ Time diff --git a/cxx/assets/flights_dirty.10.csv b/cxx/assets/flights_dirty.10.csv new file mode 100644 index 0000000..183b4ca --- /dev/null +++ b/cxx/assets/flights_dirty.10.csv @@ -0,0 +1,10 @@ +tuple_id,src,flight,sched_dep_time,act_dep_time,sched_arr_time,act_arr_time +1,aa,AA-3859-IAH-ORD,7:10 a.m.,7:16 a.m.,9:40 a.m.,9:32 a.m. +2,aa,AA-1733-ORD-PHX,7:45 p.m.,7:58 p.m.,10:30 p.m., +3,aa,AA-1640-MIA-MCO,6:30 p.m.,,7:25 p.m., +4,aa,AA-518-MIA-JFK,6:40 a.m.,6:54 a.m.,9:25 a.m.,9:28 a.m. +5,aa,AA-3756-ORD-SLC,12:15 p.m.,12:41 p.m.,2:45 p.m.,2:50 p.m. +6,aa,AA-204-LAX-MCO,11:25 p.m.,,12/02/2011 6:55 a.m., +7,aa,AA-3468-CVG-MIA,7:00 a.m.,7:25 a.m.,9:55 a.m.,9:45 a.m. +8,aa,AA-484-DFW-MIA,4:15 p.m.,4:29 p.m.,7:55 p.m.,7:39 p.m. +9,aa,AA-446-DFW-PHL,11:50 a.m.,12:12 p.m.,3:50 p.m.,4:09 p.m. diff --git a/cxx/assets/hospitals.schema b/cxx/assets/hospitals.schema index 7113c7e..7e115dc 100644 --- a/cxx/assets/hospitals.schema +++ b/cxx/assets/hospitals.schema @@ -3,30 +3,30 @@ class County state ~ stringcat(strings="al ak az ar ca co ct de dc fl ga hi id il in ia ks ky la me md ma mi mn ms mo mt ne nv nh nj nm ny nc nd oh ok or pa ri sc sd tn tx ut vt va wa wv wi wy") - county ~ string + county ~ string(maxlength=25) class Place county ~ County - city ~ string + city ~ string(maxlength=25) class Condition - desc ~ string + desc ~ string(maxlength=40) class Measure code ~ stringcat(strings="ami-1 ami-2 ami-3 ami-4 ami-5 ami-7a ami-8a ami-x amix1 amix2 amx-3 amx-4 axi-2 axi-4 cac-1 cac-2 cac-3 hf-1 hf-2 hf-3 hf-4 hf-x hfx1 hfx4 hx-1 hx-2 pn-2 pn-3b pn-4 pn-5c pn-6 pn-7 pn-x pnx5c pnx6 pn-xb px-4 px-5c scip-card-2 scip-inf-1 scip-inf-2 scip-inf-3 scip-inf-4 scip-inf-6 scip-inx-4 scip-vte-1 scip-vte-2 scip-vtx-1 scipxinfx1 scix-inf-2 scxp-xnf-3 sxip-vte-1 xax-1 xf-1") - name ~ string # TODO(thomaswc): Consider using stringcat instead. + name ~ string(maxlength=0) # TODO(thomaswc): Consider using stringcat instead condition ~ Condition class HospitalType - desc ~ string + desc ~ string(maxlength=40) class Hospital loc ~ Place type ~ HospitalType provider ~ typo_int - name ~ string - addr ~ string - phone ~ string + name ~ string(maxlength=50) + addr ~ string(maxlength=50) + phone ~ string(maxlength=15) owner ~ stringcat(strings="government - federal:government - hospital district or authority:government - local:government - state:government - federal:proprietary:voluntary non-profit - church:voluntary non-profit - other:voluntary non-profit - private", delim=":") zip ~ typo_int service ~ stringcat(strings="no yes") diff --git a/cxx/assets/rents.schema b/cxx/assets/rents.schema index bf6fb38..1424dae 100644 --- a/cxx/assets/rents.schema +++ b/cxx/assets/rents.schema @@ -2,7 +2,7 @@ # Based on https://github.com/probcomp/PClean/blob/master/experiments/rents/run.jl class County - name ~ string + name ~ string(maxlength=50) state ~ stringcat(strings="AL AK AZ AR CA CO CT DE DC FL GA HI ID IL IN IA KS KY LA ME MD MA MI MN MS MO MT NE NV NH NJ NM NY NC ND OH OK OR PA RI SC SD TN TX UT VT VA WA WV WI WY") class Obs diff --git a/cxx/assets/rents_dirty.10.csv b/cxx/assets/rents_dirty.10.csv new file mode 100644 index 0000000..943c8cc --- /dev/null +++ b/cxx/assets/rents_dirty.10.csv @@ -0,0 +1,10 @@ +Column1,Room Type,Monthly Rent,County,State +0,studio,486.0,Mahoning County,OH +1,4br,2152.0,Clark County,NV +2,1br,1267.0,Gwinnett County,GA +3,3br,1180.0,Granville County,NC +4,,1436.0,Suffolk County,NY +5,2br,1768.0,Miami-Dade County,FL +6,,585.0,Sebastian County,AR +7,studio,599.0,Lapeer County,MI +8,3br,3056.0,Monterey County,CA diff --git a/cxx/distributions/bigram.cc b/cxx/distributions/bigram.cc index 8f70b70..d79b037 100644 --- a/cxx/distributions/bigram.cc +++ b/cxx/distributions/bigram.cc @@ -4,6 +4,7 @@ #include "distributions/bigram.hh" #include +#include #include "distributions/base.hh" @@ -34,6 +35,11 @@ std::vector Bigram::string_to_indices(const std::string& str) const { } void Bigram::incorporate(const std::string& x, double weight) { + if ((max_length > 0) && (x.length() > max_length)) { + printf("String %s has length %ld, but max length is %ld.\n", + x.c_str(), x.length(), max_length); + std::exit(1); + } const std::vector indices = string_to_indices(x); for (size_t i = 0; i != indices.size() - 1; ++i) { transition_dists[indices[i]].incorporate(indices[i + 1], weight); @@ -66,9 +72,11 @@ double Bigram::logp_score() const { std::string Bigram::sample(std::mt19937* prng) { std::string sampled_string; - // TODO(emilyaf): Reconsider the reserved length and maybe enforce a - // max length. - sampled_string.reserve(30); + if (max_length > 0) { + sampled_string.reserve(max_length); + } else { + sampled_string.reserve(2 * num_chars); + } // Sample the first character conditioned on the stop/start symbol. size_t current_ind = num_chars; size_t next_ind = transition_dists[current_ind].sample(prng); @@ -80,6 +88,9 @@ std::string Bigram::sample(std::mt19937* prng) { // subsequent samples are conditioned on its observation. while (current_ind != num_chars) { sampled_string += index_to_char(current_ind); + if (sampled_string.length() == max_length) { + break; + } next_ind = transition_dists[current_ind].sample(prng); transition_dists[current_ind].incorporate(next_ind); current_ind = next_ind; diff --git a/cxx/distributions/bigram.hh b/cxx/distributions/bigram.hh index 809e826..02ab2a0 100644 --- a/cxx/distributions/bigram.hh +++ b/cxx/distributions/bigram.hh @@ -20,8 +20,9 @@ class Bigram : public Distribution { double alpha = 1; // hyperparameter for all transition distributions. size_t num_chars = '~' - ' ' + 1; // printable ASCII without DEL. mutable std::vector transition_dists; + size_t max_length = 0; // 0 means no maximum length - Bigram() { + Bigram(size_t _max_length = 80): max_length(_max_length) { const size_t total_chars = num_chars + 1; // Include a start/stop symbol. // The distribution at index `i` represents `p(X_{j+1} | X_j == char_i)`. diff --git a/cxx/distributions/bigram_test.cc b/cxx/distributions/bigram_test.cc index ce41920..464c339 100644 --- a/cxx/distributions/bigram_test.cc +++ b/cxx/distributions/bigram_test.cc @@ -2,6 +2,7 @@ #define BOOST_TEST_MODULE test Bigram +#include #include "distributions/bigram.hh" #include @@ -26,6 +27,26 @@ BOOST_AUTO_TEST_CASE(test_simple) { BOOST_TEST(bg.N == 2.23); } +BOOST_AUTO_TEST_CASE(test_max_length) { + std::mt19937 prng; + Bigram bg(5); + + for (int i = 0; i < 10; ++i) { + BOOST_TEST(bg.sample(&prng).length() <= 5); + } +} + +BOOST_AUTO_TEST_CASE(test_max_length0) { + std::mt19937 prng; + Bigram bg(0); + + size_t ml = 0; + for (int i = 0; i < 10; ++i) { + ml = std::max(ml, bg.sample(&prng).length()); + } + BOOST_TEST(ml > bg.num_chars); +} + BOOST_AUTO_TEST_CASE(test_set_alpha) { Bigram bg; diff --git a/cxx/distributions/get_distribution.cc b/cxx/distributions/get_distribution.cc index 076ca59..f7f09d3 100644 --- a/cxx/distributions/get_distribution.cc +++ b/cxx/distributions/get_distribution.cc @@ -71,8 +71,13 @@ DistributionVariant get_prior(const DistributionSpec& spec, switch (spec.distribution) { case DistributionEnum::bernoulli: return new BetaBernoulli; - case DistributionEnum::bigram: - return new Bigram; + case DistributionEnum::bigram: { + size_t max_length = 80; + if (spec.distribution_args.contains("maxlength")) { + max_length = std::stoul(spec.distribution_args.at("maxlength")); + } + return new Bigram(max_length); + } case DistributionEnum::categorical: { assert(spec.distribution_args.size() == 1); int num_categories = std::stoi(spec.distribution_args.at("k")); diff --git a/cxx/distributions/get_distribution_test.cc b/cxx/distributions/get_distribution_test.cc index 2168368..7e7ed74 100644 --- a/cxx/distributions/get_distribution_test.cc +++ b/cxx/distributions/get_distribution_test.cc @@ -21,6 +21,11 @@ BOOST_AUTO_TEST_CASE(test_distribution_spec) { BOOST_TEST((dbg.distribution == DistributionEnum::bigram)); BOOST_TEST(dbg.distribution_args.empty()); + DistributionSpec dbg2 = DistributionSpec("bigram(maxlength=10)"); + BOOST_TEST((dbg2.distribution == DistributionEnum::bigram)); + BOOST_TEST((dbg2.distribution_args.size() == 1)); + BOOST_TEST(dbg2.distribution_args.at("maxlength") == "10"); + DistributionSpec dn = DistributionSpec("normal"); BOOST_TEST((dn.distribution == DistributionEnum::normal)); BOOST_TEST(dn.distribution_args.empty()); @@ -79,6 +84,21 @@ BOOST_AUTO_TEST_CASE(test_get_prior_bigram) { BOOST_TEST(name.find("Bigram") != std::string::npos); } +BOOST_AUTO_TEST_CASE(test_get_prior_bigram2) { + std::mt19937 prng; + + DistributionVariant dv = get_prior(DistributionSpec("bigram(maxlength=2)"), + &prng); + Distribution *d = std::get*>(dv); + std::string name = typeid(*d).name(); + BOOST_TEST(name.find("Bigram") != std::string::npos); + + for (int i = 0; i < 10; i++) { + std::string s = d->sample(&prng); + BOOST_TEST(s.length() <= 2); + } +} + BOOST_AUTO_TEST_CASE(test_get_prior_categorical) { std::mt19937 prng; diff --git a/cxx/integration_tests.sh b/cxx/integration_tests.sh index 9393fae..e65a703 100755 --- a/cxx/integration_tests.sh +++ b/cxx/integration_tests.sh @@ -19,9 +19,9 @@ startt2=$(date +%s) ./bazel-bin/hirm --iters=5 --load=assets/animals.unary.1.hirm assets/animals.unary endt2=$(date +%s) startt3=$(date +%s) -#./bazel-bin/pclean/pclean --schema=assets/flights.schema --obs=assets/flights_dirty.100.csv --iters=5 +./bazel-bin/pclean/pclean --schema=assets/flights.schema --obs=assets/flights_dirty.10.csv --iters=5 ./bazel-bin/pclean/pclean --schema=assets/hospitals.schema --obs=assets/hospital_dirty.10.csv --iters=5 -#./bazel-bin/pclean/pclean --schema=assets/rents.schema --obs=assets/rents_dirty.100.csv --iters=5 +./bazel-bin/pclean/pclean --schema=assets/rents.schema --obs=assets/rents_dirty.10.csv --iters=5 endt3=$(date +%s) echo "Integration tests in /tests ran in $(($endt1-$startt1)) seconds" echo "hirm integration tests ran in $(($endt2-$startt2)) seconds"