Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
Added configurable unkown value for char_map ETL
Browse files Browse the repository at this point in the history
  • Loading branch information
tyler-nervana authored and apark263 committed Nov 2, 2016
1 parent aa4bda7 commit 7160cd6
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 29 deletions.
5 changes: 3 additions & 2 deletions doc/source/etl_audio.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ You can configure the audio processing pipeline from python using a dictionary l

.. code-block:: python
audio_config = dict(sampling_freq=16000,
audio_config = dict(sample_freq_hz=16000,
max_duration="3 seconds",
frame_length="256 samples",
frame_stride="128 samples",
Expand Down Expand Up @@ -164,7 +164,8 @@ Transcription provisioning can be configured using the following parameters:
:escape: ~

alphabet (string)| *Required* | A string of symbols to be included in the target output
max_length (uint_32t) | *Required* | Maximum number of symbols in a target
max_length (uint32_t) | *Required* | Maximum number of symbols in a target
unknown_value (uint8_t) | 0 | Integer value to give to unknown characters. 0 causes them to be discarded. Value should be between ``len(alphabet)`` and 255.
pack_for_ctc (bool) | False | Packs the output buffer to be passed to the `warp CTC`_ objective function
output_type (string) | ~"uint8_t~" | transcript data type

Expand Down
13 changes: 11 additions & 2 deletions loader/src/etl_char_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,20 @@ std::shared_ptr<char_map::decoded> char_map::extractor::extract(const char* in_a
string transcript(in_array, nvalid);
vector<uint8_t> char_ints((vector<uint8_t>::size_type) _max_length, (uint8_t) 0);

uint32_t j = 0;
for (uint32_t i=0; i<nvalid; i++)
{
auto l = _cmap.find(std::toupper(transcript[i]));
uint8_t v = (l != _cmap.end()) ? l->second : UINT8_MAX;
char_ints[i] = v;
if (l == _cmap.end()) {
if (_unknown_value > 0) {
char_ints[j++] = _unknown_value;
continue;
}
else {
continue;
}
}
char_ints[j++] = l->second;
}
auto rc = make_shared<char_map::decoded>(char_ints, nvalid);
return rc;
Expand Down
15 changes: 14 additions & 1 deletion loader/src/etl_char_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,17 @@ namespace nervana {
class char_map::config : public interface::config {
friend class extractor;
public:
/** Maximum length of each transcript. Samples with longer transcripts
* will be truncated */
uint32_t max_length;
/** Character map alphabet */
std::string alphabet;
/** Integer value to give to unknown characters. 0 causes them to be
* discarded.*/
uint8_t unknown_value = 0;
/** Pack the output buffer for use in CTC. This places them end to end */
bool pack_for_ctc = false;
/** Output data type. Currently only uint8_t is supported */
std::string output_type{"uint8_t"};

config(nlohmann::json js) {
Expand Down Expand Up @@ -67,6 +75,7 @@ namespace nervana {
std::vector<std::shared_ptr<interface::config_info_interface>> config_list = {
ADD_SCALAR(max_length, mode::REQUIRED),
ADD_SCALAR(alphabet, mode::REQUIRED),
ADD_SCALAR(unknown_value, mode::OPTIONAL),
ADD_SCALAR(pack_for_ctc, mode::OPTIONAL),
ADD_SCALAR(output_type, mode::OPTIONAL, [](const std::string& v){ return output_type::is_valid_type(v); })
};
Expand All @@ -80,6 +89,9 @@ namespace nervana {
if (!unique_chars(alphabet)) {
throw std::runtime_error("alphabet does not consist of unique chars " + alphabet);
}
if (unknown_value > 0 && unknown_value < alphabet.size()) {
throw std::runtime_error("unknown_value should be >= alphabet length and <= 255");
}
}

bool unique_chars(std::string test_string)
Expand Down Expand Up @@ -121,13 +133,14 @@ namespace nervana {
class char_map::extractor : public interface::extractor<char_map::decoded> {
public:
extractor( const char_map::config& cfg)
: _cmap{cfg.get_cmap()}, _max_length{cfg.max_length}
: _cmap{cfg.get_cmap()}, _max_length{cfg.max_length}, _unknown_value{cfg.unknown_value}
{}
virtual ~extractor(){}
virtual std::shared_ptr<char_map::decoded> extract(const char*, int) override;
private:
const std::unordered_map<char, uint8_t>& _cmap; // This comes from config
uint32_t _max_length;
const uint8_t _unknown_value;
};

class char_map::loader : public interface::loader<char_map::decoded> {
Expand Down
78 changes: 54 additions & 24 deletions loader/test/test_char_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,51 +36,81 @@ TEST(char_map, bad) {

TEST(char_map, test) {
{
nlohmann::json js = {{"alphabet", "ABCDEFGHIJKLMNOPQRSTUVWXYZ .,()"},
{"max_length", 20}};
string alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ .,()";
string transcript = "The quick brown fox jumps over the lazy dog";
uint8_t max_length = transcript.size() + 5;

nlohmann::json js = {{"alphabet", alphabet},
{"max_length", max_length},
{"unknown_value", 0}};
char_map::config cfg{js};
char_map::extractor extractor(cfg);
auto data = cfg.get_cmap();
char_map::loader loader(cfg);

// Ensure cmap is set up properly
std::unordered_map<char, uint8_t> data = cfg.get_cmap();
EXPECT_EQ(2, data['C']);
EXPECT_EQ(26, data[' ']);

// handle mapping of unknown character
// Make sure mapping is correct and extra characters are mapped to 0
{
string t1 = "The quick brown -fox jump over the lazy dog";
auto extracted = extractor.extract(&t1[0], t1.size());
EXPECT_EQ(UINT8_MAX, extracted->get_data()[16]);
vector<int> expected = {19, 7, 4, 26, 16, 20, 8, 2, 10, 26, 1, 17, 14, 22, 13, 26, 5, 14, 23, 26, 9, 20, 12, 15, 18, 26, 14, 21, 4, 17, 26, 19, 7, 4, 26, 11, 0, 25, 24, 26, 3, 14, 6, 0, 0, 0, 0, 0};
auto decoded = extractor.extract(&transcript[0], transcript.size());
// decoded exists
ASSERT_NE(nullptr, decoded);
// has the right length
EXPECT_EQ(expected.size(), max_length);
// and the right values
for( int i=0; i<expected.size(); i++ ) {
EXPECT_EQ(expected[i], decoded->get_data()[i]) << "at index " << i;
}
}


// handle mapping of unknown characters
{
string t1 = "The quick brOwn";
vector<int> expected = {19, 7, 4, 26, 16, 20, 8, 2, 10, 26, 1, 17, 14, 22, 13,
0, 0, 0, 0, 0};
auto decoded = extractor.extract(&t1[0], t1.size());
ASSERT_NE(nullptr, decoded);
ASSERT_EQ(expected.size(),decoded->get_data().size());
// Skip unknown characters
string unknown = "The0:3 ?q!uick brown";
string discarded = "The quick brown";
auto unk_dec = extractor.extract(&unknown[0], unknown.size());
auto exp_dec = extractor.extract(&discarded[0], discarded.size());

for (int i = 0; i < discarded.size(); i++) {
EXPECT_EQ(exp_dec->get_data()[i],
unk_dec->get_data()[i]);
}

// Unknown characters should be given value of UINT8_MAX
nlohmann::json js = {{"alphabet", alphabet},
{"max_length", max_length},
{"unknown_value", 255}};
char_map::config unk_cfg{js};
char_map::extractor unk_extractor(unk_cfg);
vector<int> expected = {19, 7, 4, 255, 255, 255, 26, 255, 16, 255, 20, 8, 2, 10, 26, 1, 17, 14, 22, 13};
unk_dec = unk_extractor.extract(&unknown[0], unknown.size());
for( int i=0; i<expected.size(); i++ ) {
EXPECT_EQ(expected[i], decoded->get_data()[i]) << "at index " << i;
EXPECT_EQ(expected[i], unk_dec->get_data()[i]) << "at index " << i;
}

}

char_map::loader loader(cfg);
int max_length = js["max_length"];
char outbuf[max_length];
// Now check max length truncation
char outbuf[max_length];
{
string t1 = "now is the winter of our discontent";
auto decoded = extractor.extract(&t1[0], t1.size());
string long_str = "This is a really long transcript that should overflow the buffer at the letter e in overflow";
auto decoded = extractor.extract(&long_str[0], long_str.size());
loader.load({outbuf}, decoded);

ASSERT_EQ(outbuf[max_length - 1], 5);
ASSERT_EQ(outbuf[max_length - 1], 4);
}

// Check zero padding
{
string t1 = "now";
auto decoded = extractor.extract(&t1[0], t1.size());
string short_str = "now";
auto decoded = extractor.extract(&short_str[0], short_str.size());
loader.load({outbuf}, decoded);
ASSERT_EQ(outbuf[max_length - 1], 0);
for (int i = 3; i < max_length; i++) {
ASSERT_EQ(outbuf[i], 0);
}
}

}
Expand Down

0 comments on commit 7160cd6

Please sign in to comment.