diff --git a/doc/source/etl_audio.rst b/doc/source/etl_audio.rst index 06e84cdc..477eb03d 100644 --- a/doc/source/etl_audio.rst +++ b/doc/source/etl_audio.rst @@ -88,7 +88,7 @@ You can configure the audio processing pipeline from python using a dictionary l .. code-block:: python - audio_config = dict(sampling_freq=16000, + audio_config = dict(sample_freq_hz=16000, max_duration="3 seconds", frame_length="256 samples", frame_stride="128 samples", @@ -164,7 +164,8 @@ Transcription provisioning can be configured using the following parameters: :escape: ~ alphabet (string)| *Required* | A string of symbols to be included in the target output - max_length (uint_32t) | *Required* | Maximum number of symbols in a target + max_length (uint32_t) | *Required* | Maximum number of symbols in a target + unknown_value (uint8_t) | 0 | Integer value to give to unknown characters. 0 causes them to be discarded. Value should be between ``len(alphabet)`` and 255. pack_for_ctc (bool) | False | Packs the output buffer to be passed to the `warp CTC`_ objective function output_type (string) | ~"uint8_t~" | transcript data type diff --git a/loader/src/etl_char_map.cpp b/loader/src/etl_char_map.cpp index ddcff3d7..1c90cce8 100644 --- a/loader/src/etl_char_map.cpp +++ b/loader/src/etl_char_map.cpp @@ -24,11 +24,20 @@ std::shared_ptr char_map::extractor::extract(const char* in_a string transcript(in_array, nvalid); vector char_ints((vector::size_type) _max_length, (uint8_t) 0); + uint32_t j = 0; for (uint32_t i=0; isecond : UINT8_MAX; - char_ints[i] = v; + if (l == _cmap.end()) { + if (_unknown_value > 0) { + char_ints[j++] = _unknown_value; + continue; + } + else { + continue; + } + } + char_ints[j++] = l->second; } auto rc = make_shared(char_ints, nvalid); return rc; diff --git a/loader/src/etl_char_map.hpp b/loader/src/etl_char_map.hpp index 9d4e0a37..066c5b40 100644 --- a/loader/src/etl_char_map.hpp +++ b/loader/src/etl_char_map.hpp @@ -34,9 +34,17 @@ namespace nervana { class char_map::config : public interface::config { friend class extractor; public: + /** Maximum length of each transcript. Samples with longer transcripts + * will be truncated */ uint32_t max_length; + /** Character map alphabet */ std::string alphabet; + /** Integer value to give to unknown characters. 0 causes them to be + * discarded.*/ + uint8_t unknown_value = 0; + /** Pack the output buffer for use in CTC. This places them end to end */ bool pack_for_ctc = false; + /** Output data type. Currently only uint8_t is supported */ std::string output_type{"uint8_t"}; config(nlohmann::json js) { @@ -67,6 +75,7 @@ namespace nervana { std::vector> config_list = { ADD_SCALAR(max_length, mode::REQUIRED), ADD_SCALAR(alphabet, mode::REQUIRED), + ADD_SCALAR(unknown_value, mode::OPTIONAL), ADD_SCALAR(pack_for_ctc, mode::OPTIONAL), ADD_SCALAR(output_type, mode::OPTIONAL, [](const std::string& v){ return output_type::is_valid_type(v); }) }; @@ -80,6 +89,9 @@ namespace nervana { if (!unique_chars(alphabet)) { throw std::runtime_error("alphabet does not consist of unique chars " + alphabet); } + if (unknown_value > 0 && unknown_value < alphabet.size()) { + throw std::runtime_error("unknown_value should be >= alphabet length and <= 255"); + } } bool unique_chars(std::string test_string) @@ -121,13 +133,14 @@ namespace nervana { class char_map::extractor : public interface::extractor { public: extractor( const char_map::config& cfg) - : _cmap{cfg.get_cmap()}, _max_length{cfg.max_length} + : _cmap{cfg.get_cmap()}, _max_length{cfg.max_length}, _unknown_value{cfg.unknown_value} {} virtual ~extractor(){} virtual std::shared_ptr extract(const char*, int) override; private: const std::unordered_map& _cmap; // This comes from config uint32_t _max_length; + const uint8_t _unknown_value; }; class char_map::loader : public interface::loader { diff --git a/loader/test/test_char_map.cpp b/loader/test/test_char_map.cpp index 6dfd70d8..05743fbb 100644 --- a/loader/test/test_char_map.cpp +++ b/loader/test/test_char_map.cpp @@ -36,51 +36,81 @@ TEST(char_map, bad) { TEST(char_map, test) { { - nlohmann::json js = {{"alphabet", "ABCDEFGHIJKLMNOPQRSTUVWXYZ .,()"}, - {"max_length", 20}}; + string alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ .,()"; + string transcript = "The quick brown fox jumps over the lazy dog"; + uint8_t max_length = transcript.size() + 5; + + nlohmann::json js = {{"alphabet", alphabet}, + {"max_length", max_length}, + {"unknown_value", 0}}; char_map::config cfg{js}; char_map::extractor extractor(cfg); - auto data = cfg.get_cmap(); + char_map::loader loader(cfg); + + // Ensure cmap is set up properly + std::unordered_map data = cfg.get_cmap(); EXPECT_EQ(2, data['C']); + EXPECT_EQ(26, data[' ']); - // handle mapping of unknown character + // Make sure mapping is correct and extra characters are mapped to 0 { - string t1 = "The quick brown -fox jump over the lazy dog"; - auto extracted = extractor.extract(&t1[0], t1.size()); - EXPECT_EQ(UINT8_MAX, extracted->get_data()[16]); + vector expected = {19, 7, 4, 26, 16, 20, 8, 2, 10, 26, 1, 17, 14, 22, 13, 26, 5, 14, 23, 26, 9, 20, 12, 15, 18, 26, 14, 21, 4, 17, 26, 19, 7, 4, 26, 11, 0, 25, 24, 26, 3, 14, 6, 0, 0, 0, 0, 0}; + auto decoded = extractor.extract(&transcript[0], transcript.size()); + // decoded exists + ASSERT_NE(nullptr, decoded); + // has the right length + EXPECT_EQ(expected.size(), max_length); + // and the right values + for( int i=0; iget_data()[i]) << "at index " << i; + } } - + // handle mapping of unknown characters { - string t1 = "The quick brOwn"; - vector expected = {19, 7, 4, 26, 16, 20, 8, 2, 10, 26, 1, 17, 14, 22, 13, - 0, 0, 0, 0, 0}; - auto decoded = extractor.extract(&t1[0], t1.size()); - ASSERT_NE(nullptr, decoded); - ASSERT_EQ(expected.size(),decoded->get_data().size()); + // Skip unknown characters + string unknown = "The0:3 ?q!uick brown"; + string discarded = "The quick brown"; + auto unk_dec = extractor.extract(&unknown[0], unknown.size()); + auto exp_dec = extractor.extract(&discarded[0], discarded.size()); + + for (int i = 0; i < discarded.size(); i++) { + EXPECT_EQ(exp_dec->get_data()[i], + unk_dec->get_data()[i]); + } + + // Unknown characters should be given value of UINT8_MAX + nlohmann::json js = {{"alphabet", alphabet}, + {"max_length", max_length}, + {"unknown_value", 255}}; + char_map::config unk_cfg{js}; + char_map::extractor unk_extractor(unk_cfg); + vector expected = {19, 7, 4, 255, 255, 255, 26, 255, 16, 255, 20, 8, 2, 10, 26, 1, 17, 14, 22, 13}; + unk_dec = unk_extractor.extract(&unknown[0], unknown.size()); for( int i=0; iget_data()[i]) << "at index " << i; + EXPECT_EQ(expected[i], unk_dec->get_data()[i]) << "at index " << i; } + } - char_map::loader loader(cfg); - int max_length = js["max_length"]; - char outbuf[max_length]; // Now check max length truncation + char outbuf[max_length]; { - string t1 = "now is the winter of our discontent"; - auto decoded = extractor.extract(&t1[0], t1.size()); + string long_str = "This is a really long transcript that should overflow the buffer at the letter e in overflow"; + auto decoded = extractor.extract(&long_str[0], long_str.size()); loader.load({outbuf}, decoded); - ASSERT_EQ(outbuf[max_length - 1], 5); + ASSERT_EQ(outbuf[max_length - 1], 4); } // Check zero padding { - string t1 = "now"; - auto decoded = extractor.extract(&t1[0], t1.size()); + string short_str = "now"; + auto decoded = extractor.extract(&short_str[0], short_str.size()); loader.load({outbuf}, decoded); - ASSERT_EQ(outbuf[max_length - 1], 0); + for (int i = 3; i < max_length; i++) { + ASSERT_EQ(outbuf[i], 0); + } } }