-
Notifications
You must be signed in to change notification settings - Fork 508
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Refactor online recognizer. Make it easier to support other streaming models. Note that it is a breaking change for the Python API. `sherpa_onnx.OnlineRecognizer()` used before should be replaced by `sherpa_onnx.OnlineRecognizer.from_transducer()`.
- Loading branch information
1 parent
6061318
commit 79c2ce5
Showing
40 changed files
with
670 additions
and
480 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// sherpa-onnx/csrc/online-model-config.cc | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
#include "sherpa-onnx/csrc/online-model-config.h" | ||
|
||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/file-utils.h" | ||
#include "sherpa-onnx/csrc/macros.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
void OnlineModelConfig::Register(ParseOptions *po) { | ||
transducer.Register(po); | ||
|
||
po->Register("tokens", &tokens, "Path to tokens.txt"); | ||
|
||
po->Register("num-threads", &num_threads, | ||
"Number of threads to run the neural network"); | ||
|
||
po->Register("debug", &debug, | ||
"true to print model information while loading it."); | ||
|
||
po->Register("provider", &provider, | ||
"Specify a provider to use: cpu, cuda, coreml"); | ||
|
||
po->Register("model-type", &model_type, | ||
"Specify it to reduce model initialization time. " | ||
"Valid values are: conformer, lstm, zipformer, zipformer2." | ||
"All other values lead to loading the model twice."); | ||
} | ||
|
||
bool OnlineModelConfig::Validate() const { | ||
if (num_threads < 1) { | ||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); | ||
return false; | ||
} | ||
|
||
if (!FileExists(tokens)) { | ||
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str()); | ||
return false; | ||
} | ||
|
||
return transducer.Validate(); | ||
} | ||
|
||
std::string OnlineModelConfig::ToString() const { | ||
std::ostringstream os; | ||
|
||
os << "OnlineModelConfig("; | ||
os << "transducer=" << transducer.ToString() << ", "; | ||
os << "tokens=\"" << tokens << "\", "; | ||
os << "num_threads=" << num_threads << ", "; | ||
os << "debug=" << (debug ? "True" : "False") << ", "; | ||
os << "provider=\"" << provider << "\", "; | ||
os << "model_type=\"" << model_type << "\")"; | ||
|
||
return os.str(); | ||
} | ||
|
||
} // namespace sherpa_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// sherpa-onnx/csrc/online-model-config.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
#ifndef SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ | ||
|
||
#include <string> | ||
|
||
#include "sherpa-onnx/csrc/online-transducer-model-config.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
struct OnlineModelConfig { | ||
OnlineTransducerModelConfig transducer; | ||
std::string tokens; | ||
int32_t num_threads = 1; | ||
bool debug = false; | ||
std::string provider = "cpu"; | ||
|
||
// Valid values: | ||
// - conformer, conformer transducer from icefall | ||
// - lstm, lstm transducer from icefall | ||
// - zipformer, zipformer transducer from icefall | ||
// - zipformer2, zipformer2 transducer from icefall | ||
// | ||
// All other values are invalid and lead to loading the model twice. | ||
std::string model_type; | ||
|
||
OnlineModelConfig() = default; | ||
OnlineModelConfig(const OnlineTransducerModelConfig &transducer, | ||
const std::string &tokens, int32_t num_threads, bool debug, | ||
const std::string &provider, const std::string &model_type) | ||
: transducer(transducer), | ||
tokens(tokens), | ||
num_threads(num_threads), | ||
debug(debug), | ||
provider(provider), | ||
model_type(model_type) {} | ||
|
||
void Register(ParseOptions *po); | ||
bool Validate() const; | ||
|
||
std::string ToString() const; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ |
Oops, something went wrong.