diff --git a/cxx/emissions/BUILD b/cxx/emissions/BUILD index 83dd6de..160934b 100644 --- a/cxx/emissions/BUILD +++ b/cxx/emissions/BUILD @@ -9,6 +9,19 @@ cc_library( ], ) +cc_library( + name = "get_emission", + srcs = ["get_emission.cc"], + hdrs = ["get_emission.hh"], + visibility = ["//:__subpackages__"], + deps = [ + ":bitflip", + ":gaussian", + ":simple_string", + ":sometimes", + ], +) + cc_library( name = "bitflip", srcs = ["bitflip.hh"], @@ -46,6 +59,16 @@ cc_library( ], ) +cc_test( + name = "get_emission_test", + srcs = ["get_emission_test.cc"], + deps = [ + ":get_emission", + "@boost//:algorithm", + "@boost//:test", + ], +) + cc_test( name = "bitflip_test", srcs = ["bitflip_test.cc"], diff --git a/cxx/emissions/get_emission.cc b/cxx/emissions/get_emission.cc new file mode 100644 index 0000000..361b40c --- /dev/null +++ b/cxx/emissions/get_emission.cc @@ -0,0 +1,24 @@ +// Copyright 2024 +// See LICENSE.txt + +#include "emissions/get_emission.hh" + +#include + +#include "emissions/bitflip.hh" +#include "emissions/gaussian.hh" +#include "emissions/simple_string.hh" + +EmissionVariant get_emission(const std::string& emission_name) { + if (emission_name == "gaussian") { + return new GaussianEmission(); + } else if (emission_name == "simple_string") { + return new SimpleStringEmission(); + } else if (emission_name == "sometimes_gaussian") { + return new SometimesGaussian(); + } else if (emission_name == "sometimes_bitflip") { + return new SometimesBitFlip(); + } + printf("Unknown emission name %s\n", emission_name.c_str()); + assert(false); +} diff --git a/cxx/emissions/get_emission.hh b/cxx/emissions/get_emission.hh new file mode 100644 index 0000000..73684f8 --- /dev/null +++ b/cxx/emissions/get_emission.hh @@ -0,0 +1,20 @@ +// Copyright 2024 +// See LICENSE.txt + +#pragma once + +#include +#include + +#include "emissions/sometimes.hh" + +class BitFlip; +class GaussianEmission; +class SimpleStringEmission; +using SometimesBitFlip = Sometimes; +using SometimesGaussian = Sometimes; + +using EmissionVariant = std::variant; + +EmissionVariant get_emission(const std::string& emission_name); diff --git a/cxx/emissions/get_emission_test.cc b/cxx/emissions/get_emission_test.cc new file mode 100644 index 0000000..dcaa640 --- /dev/null +++ b/cxx/emissions/get_emission_test.cc @@ -0,0 +1,41 @@ +#define BOOST_TEST_MODULE test get_emission + +#include "emissions/get_emission.hh" + +#include + +#include "emissions/bitflip.hh" +#include "emissions/gaussian.hh" +#include "emissions/simple_string.hh" + +BOOST_AUTO_TEST_CASE(test_get_emission_gaussian) { + EmissionVariant ev = get_emission("gaussian"); + GaussianEmission* ge = std::get(ev); + + ge->incorporate(std::make_pair(2.0, 2.1)); + BOOST_TEST(ge->N == 1); +} + +BOOST_AUTO_TEST_CASE(test_get_emission_simple_string) { + EmissionVariant ev = get_emission("simple_string"); + SimpleStringEmission* sse = std::get(ev); + + sse->incorporate(std::make_pair("hello", "hi")); + BOOST_TEST(sse->N == 1); +} + +BOOST_AUTO_TEST_CASE(test_get_emission_sometimes_gaussian) { + EmissionVariant ev = get_emission("sometimes_gaussian"); + SometimesGaussian* sg = std::get(ev); + + sg->incorporate(std::make_pair(2.0, 2.1)); + BOOST_TEST(sg->N == 1); +} + +BOOST_AUTO_TEST_CASE(test_get_emission_sometimes_bitflip) { + EmissionVariant ev = get_emission("sometimes_bitflip"); + SometimesBitFlip* sbf = std::get(ev); + + sbf->incorporate(std::make_pair(true, true)); + BOOST_TEST(sbf->N == 1); +}