From 79c2ce5dd44bd6d59082374dee5dd68d9f8b8c33 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Aug 2023 20:27:31 +0800 Subject: [PATCH] Refactor online recognizer (#250) * Refactor online recognizer. Make it easier to support other streaming models. Note that it is a breaking change for the Python API. `sherpa_onnx.OnlineRecognizer()` used before should be replaced by `sherpa_onnx.OnlineRecognizer.from_transducer()`. --- python-api-examples/online-decode-files.py | 2 +- ...from-microphone-with-endpoint-detection.py | 2 +- .../speech-recognition-from-microphone.py | 2 +- .../speech-recognition-from-url.py | 2 +- python-api-examples/streaming_server.py | 2 +- sherpa-onnx/c-api/c-api.cc | 12 +- sherpa-onnx/csrc/CMakeLists.txt | 2 + .../csrc/online-conformer-transducer-model.cc | 16 +- .../csrc/online-conformer-transducer-model.h | 9 +- .../csrc/online-lstm-transducer-model.cc | 16 +- .../csrc/online-lstm-transducer-model.h | 8 +- sherpa-onnx/csrc/online-model-config.cc | 61 +++++ sherpa-onnx/csrc/online-model-config.h | 48 ++++ sherpa-onnx/csrc/online-recognizer-impl.cc | 33 +++ sherpa-onnx/csrc/online-recognizer-impl.h | 52 ++++ .../csrc/online-recognizer-transducer-impl.h | 250 ++++++++++++++++++ sherpa-onnx/csrc/online-recognizer.cc | 234 +--------------- sherpa-onnx/csrc/online-recognizer.h | 10 +- .../csrc/online-transducer-model-config.cc | 51 +--- .../csrc/online-transducer-model-config.h | 40 +-- sherpa-onnx/csrc/online-transducer-model.cc | 8 +- sherpa-onnx/csrc/online-transducer-model.h | 11 +- .../csrc/online-zipformer-transducer-model.cc | 16 +- .../csrc/online-zipformer-transducer-model.h | 9 +- .../online-zipformer2-transducer-model.cc | 16 +- .../csrc/online-zipformer2-transducer-model.h | 9 +- sherpa-onnx/csrc/session.cc | 3 +- sherpa-onnx/csrc/session.h | 5 +- sherpa-onnx/csrc/sherpa-onnx-alsa.cc | 68 ++--- sherpa-onnx/csrc/text-utils.h | 1 + sherpa-onnx/jni/jni.cc | 28 +- sherpa-onnx/python/csrc/CMakeLists.txt | 1 + .../python/csrc/online-model-config.cc | 35 +++ sherpa-onnx/python/csrc/online-model-config.h | 16 ++ sherpa-onnx/python/csrc/online-recognizer.cc | 7 +- .../csrc/online-transducer-model-config.cc | 19 +- sherpa-onnx/python/csrc/sherpa-onnx.cc | 4 +- .../python/sherpa_onnx/online_recognizer.py | 20 +- .../python/tests/test_online_recognizer.py | 4 +- .../test_online_transducer_model_config.py | 18 +- 40 files changed, 670 insertions(+), 480 deletions(-) create mode 100644 sherpa-onnx/csrc/online-model-config.cc create mode 100644 sherpa-onnx/csrc/online-model-config.h create mode 100644 sherpa-onnx/csrc/online-recognizer-impl.cc create mode 100644 sherpa-onnx/csrc/online-recognizer-impl.h create mode 100644 sherpa-onnx/csrc/online-recognizer-transducer-impl.h create mode 100644 sherpa-onnx/python/csrc/online-model-config.cc create mode 100644 sherpa-onnx/python/csrc/online-model-config.h diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index 03d0e6f55..e2e1dc556 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -205,7 +205,7 @@ def main(): assert_file_exists(args.joiner) assert_file_exists(args.tokens) - recognizer = sherpa_onnx.OnlineRecognizer( + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( tokens=args.tokens, encoder=args.encoder, decoder=args.decoder, diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py index 45cdb1e6c..36f2b5481 100755 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -91,7 +91,7 @@ def create_recognizer(): # Please replace the model files if needed. # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # for download links. - recognizer = sherpa_onnx.OnlineRecognizer( + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( tokens=args.tokens, encoder=args.encoder, decoder=args.decoder, diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index 6edbb804a..9723230c5 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -145,7 +145,7 @@ def create_recognizer(): # Please replace the model files if needed. # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # for download links. - recognizer = sherpa_onnx.OnlineRecognizer( + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( tokens=args.tokens, encoder=args.encoder, decoder=args.decoder, diff --git a/python-api-examples/speech-recognition-from-url.py b/python-api-examples/speech-recognition-from-url.py index a2f61caa2..1c6c6a1f9 100755 --- a/python-api-examples/speech-recognition-from-url.py +++ b/python-api-examples/speech-recognition-from-url.py @@ -94,7 +94,7 @@ def create_recognizer(args): # Please replace the model files if needed. # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html # for download links. - recognizer = sherpa_onnx.OnlineRecognizer( + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( tokens=args.tokens, encoder=args.encoder, decoder=args.decoder, diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py index ea9d111f9..be79979a6 100755 --- a/python-api-examples/streaming_server.py +++ b/python-api-examples/streaming_server.py @@ -294,7 +294,7 @@ def get_args(): def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: - recognizer = sherpa_onnx.OnlineRecognizer( + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( tokens=args.tokens, encoder=args.encoder_model, decoder=args.decoder_model, diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 38ffadb04..7a2e0540f 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -38,11 +38,11 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); - recognizer_config.model_config.encoder_filename = + recognizer_config.model_config.transducer.encoder = SHERPA_ONNX_OR(config->model_config.encoder, ""); - recognizer_config.model_config.decoder_filename = + recognizer_config.model_config.transducer.decoder = SHERPA_ONNX_OR(config->model_config.decoder, ""); - recognizer_config.model_config.joiner_filename = + recognizer_config.model_config.transducer.joiner = SHERPA_ONNX_OR(config->model_config.joiner, ""); recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); @@ -143,7 +143,7 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( auto count = result.tokens.size(); if (count > 0) { size_t total_length = 0; - for (const auto& token : result.tokens) { + for (const auto &token : result.tokens) { // +1 for the null character at the end of each token total_length += token.size() + 1; } @@ -154,10 +154,10 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( memset(reinterpret_cast(const_cast(r->tokens)), 0, total_length); r->timestamps = new float[r->count]; - char **tokens_temp = new char*[r->count]; + char **tokens_temp = new char *[r->count]; int32_t pos = 0; for (int32_t i = 0; i < r->count; ++i) { - tokens_temp[i] = const_cast(r->tokens) + pos; + tokens_temp[i] = const_cast(r->tokens) + pos; memcpy(reinterpret_cast(const_cast(r->tokens + pos)), result.tokens[i].c_str(), result.tokens[i].size()); // +1 to move past the null character diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index f9befff77..3426d5659 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -43,6 +43,8 @@ set(sources online-lm-config.cc online-lm.cc online-lstm-transducer-model.cc + online-model-config.cc + online-recognizer-impl.cc online-recognizer.cc online-rnn-lm.cc online-stream.cc diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.cc b/sherpa-onnx/csrc/online-conformer-transducer-model.cc index 0d0ade3ae..58cbce01c 100644 --- a/sherpa-onnx/csrc/online-conformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-conformer-transducer-model.cc @@ -30,46 +30,46 @@ namespace sherpa_onnx { OnlineConformerTransducerModel::OnlineConformerTransducerModel( - const OnlineTransducerModelConfig &config) + const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { { - auto buf = ReadFile(config.encoder_filename); + auto buf = ReadFile(config.transducer.encoder); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.decoder_filename); + auto buf = ReadFile(config.transducer.decoder); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.joiner_filename); + auto buf = ReadFile(config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } } #if __ANDROID_API__ >= 9 OnlineConformerTransducerModel::OnlineConformerTransducerModel( - AAssetManager *mgr, const OnlineTransducerModelConfig &config) + AAssetManager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { { - auto buf = ReadFile(mgr, config.encoder_filename); + auto buf = ReadFile(mgr, config.transducer.encoder); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(mgr, config.decoder_filename); + auto buf = ReadFile(mgr, config.transducer.decoder); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(mgr, config.joiner_filename); + auto buf = ReadFile(mgr, config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.h b/sherpa-onnx/csrc/online-conformer-transducer-model.h index f60ed53c1..bcf9e6eda 100644 --- a/sherpa-onnx/csrc/online-conformer-transducer-model.h +++ b/sherpa-onnx/csrc/online-conformer-transducer-model.h @@ -16,19 +16,18 @@ #endif #include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" namespace sherpa_onnx { class OnlineConformerTransducerModel : public OnlineTransducerModel { public: - explicit OnlineConformerTransducerModel( - const OnlineTransducerModelConfig &config); + explicit OnlineConformerTransducerModel(const OnlineModelConfig &config); #if __ANDROID_API__ >= 9 OnlineConformerTransducerModel(AAssetManager *mgr, - const OnlineTransducerModelConfig &config); + const OnlineModelConfig &config); #endif std::vector StackStates( @@ -88,7 +87,7 @@ class OnlineConformerTransducerModel : public OnlineTransducerModel { std::vector joiner_output_names_; std::vector joiner_output_names_ptr_; - OnlineTransducerModelConfig config_; + OnlineModelConfig config_; int32_t num_encoder_layers_ = 0; int32_t T_ = 0; diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 3419cfc08..4a0e838da 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -28,46 +28,46 @@ namespace sherpa_onnx { OnlineLstmTransducerModel::OnlineLstmTransducerModel( - const OnlineTransducerModelConfig &config) + const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { { - auto buf = ReadFile(config.encoder_filename); + auto buf = ReadFile(config.transducer.encoder); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.decoder_filename); + auto buf = ReadFile(config.transducer.decoder); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.joiner_filename); + auto buf = ReadFile(config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } } #if __ANDROID_API__ >= 9 OnlineLstmTransducerModel::OnlineLstmTransducerModel( - AAssetManager *mgr, const OnlineTransducerModelConfig &config) + AAssetManager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { { - auto buf = ReadFile(mgr, config.encoder_filename); + auto buf = ReadFile(mgr, config.transducer.encoder); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(mgr, config.decoder_filename); + auto buf = ReadFile(mgr, config.transducer.decoder); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(mgr, config.joiner_filename); + auto buf = ReadFile(mgr, config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.h b/sherpa-onnx/csrc/online-lstm-transducer-model.h index ab673a197..24119f240 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.h +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.h @@ -15,18 +15,18 @@ #endif #include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" namespace sherpa_onnx { class OnlineLstmTransducerModel : public OnlineTransducerModel { public: - explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config); + explicit OnlineLstmTransducerModel(const OnlineModelConfig &config); #if __ANDROID_API__ >= 9 OnlineLstmTransducerModel(AAssetManager *mgr, - const OnlineTransducerModelConfig &config); + const OnlineModelConfig &config); #endif std::vector StackStates( @@ -86,7 +86,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { std::vector joiner_output_names_; std::vector joiner_output_names_ptr_; - OnlineTransducerModelConfig config_; + OnlineModelConfig config_; int32_t num_encoder_layers_ = 0; int32_t T_ = 0; diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc new file mode 100644 index 000000000..7a4416b54 --- /dev/null +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -0,0 +1,61 @@ +// sherpa-onnx/csrc/online-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-onnx/csrc/online-model-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OnlineModelConfig::Register(ParseOptions *po) { + transducer.Register(po); + + po->Register("tokens", &tokens, "Path to tokens.txt"); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); + + po->Register("model-type", &model_type, + "Specify it to reduce model initialization time. " + "Valid values are: conformer, lstm, zipformer, zipformer2." + "All other values lead to loading the model twice."); +} + +bool OnlineModelConfig::Validate() const { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str()); + return false; + } + + return transducer.Validate(); +} + +std::string OnlineModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineModelConfig("; + os << "transducer=" << transducer.ToString() << ", "; + os << "tokens=\"" << tokens << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ", "; + os << "provider=\"" << provider << "\", "; + os << "model_type=\"" << model_type << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h new file mode 100644 index 000000000..34e7b1e40 --- /dev/null +++ b/sherpa-onnx/csrc/online-model-config.h @@ -0,0 +1,48 @@ +// sherpa-onnx/csrc/online-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/online-transducer-model-config.h" + +namespace sherpa_onnx { + +struct OnlineModelConfig { + OnlineTransducerModelConfig transducer; + std::string tokens; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + // Valid values: + // - conformer, conformer transducer from icefall + // - lstm, lstm transducer from icefall + // - zipformer, zipformer transducer from icefall + // - zipformer2, zipformer2 transducer from icefall + // + // All other values are invalid and lead to loading the model twice. + std::string model_type; + + OnlineModelConfig() = default; + OnlineModelConfig(const OnlineTransducerModelConfig &transducer, + const std::string &tokens, int32_t num_threads, bool debug, + const std::string &provider, const std::string &model_type) + : transducer(transducer), + tokens(tokens), + num_threads(num_threads), + debug(debug), + provider(provider), + model_type(model_type) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc new file mode 100644 index 000000000..a9e545dd2 --- /dev/null +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/online-recognizer-impl.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-recognizer-impl.h" + +#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" + +namespace sherpa_onnx { + +std::unique_ptr OnlineRecognizerImpl::Create( + const OnlineRecognizerConfig &config) { + if (!config.model_config.transducer.encoder.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE("Please specify a model"); + exit(-1); +} + +#if __ANDROID_API__ >= 9 +std::unique_ptr OnlineRecognizerImpl::Create( + AAssetManager *mgr, const OnlineRecognizerConfig &config) { + if (!config.model_config.transducer.encoder.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE("Please specify a model"); + exit(-1); +} +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-recognizer-impl.h b/sherpa-onnx/csrc/online-recognizer-impl.h new file mode 100644 index 000000000..8b574a4d2 --- /dev/null +++ b/sherpa-onnx/csrc/online-recognizer-impl.h @@ -0,0 +1,52 @@ +// sherpa-onnx/csrc/online-recognizer-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ + +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-recognizer.h" +#include "sherpa-onnx/csrc/online-stream.h" + +namespace sherpa_onnx { + +class OnlineRecognizerImpl { + public: + static std::unique_ptr Create( + const OnlineRecognizerConfig &config); + +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const OnlineRecognizerConfig &config); +#endif + + virtual ~OnlineRecognizerImpl() = default; + + virtual void InitOnlineStream(OnlineStream *stream) const = 0; + + virtual std::unique_ptr CreateStream() const = 0; + + virtual std::unique_ptr CreateStream( + const std::vector> &contexts) const { + SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); + exit(-1); + } + + virtual bool IsReady(OnlineStream *s) const = 0; + + virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; + + virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0; + + virtual bool IsEndpoint(OnlineStream *s) const = 0; + + virtual void Reset(OnlineStream *s) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h new file mode 100644 index 000000000..8f63e2017 --- /dev/null +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -0,0 +1,250 @@ +// sherpa-onnx/csrc/online-recognizer-transducer-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-lm.h" +#include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/online-recognizer.h" +#include "sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" +#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, + const SymbolTable &sym_table, + int32_t frame_shift_ms, + int32_t subsampling_factor) { + OnlineRecognizerResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.tokens.size()); + + for (auto i : src.tokens) { + auto sym = sym_table[i]; + + r.text.append(sym); + r.tokens.push_back(std::move(sym)); + } + + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + return r; +} + +class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config) + : config_(config), + model_(OnlineTransducerModel::Create(config.model_config)), + sym_(config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (config.decoding_method == "modified_beam_search") { + if (!config_.lm_config.model.empty()) { + lm_ = OnlineLM::Create(config.lm_config); + } + + decoder_ = std::make_unique( + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale); + } else if (config.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(model_.get()); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config.decoding_method.c_str()); + exit(-1); + } + } + +#if __ANDROID_API__ >= 9 + explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr, + const OnlineRecognizerConfig &config) + : config_(config), + model_(OnlineTransducerModel::Create(mgr, config.model_config)), + sym_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (config.decoding_method == "modified_beam_search") { + decoder_ = std::make_unique( + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale); + } else if (config.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(model_.get()); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config.decoding_method.c_str()); + exit(-1); + } + } +#endif + + void InitOnlineStream(OnlineStream *stream) const override { + auto r = decoder_->GetEmptyResult(); + + if (config_.decoding_method == "modified_beam_search" && + nullptr != stream->GetContextGraph()) { + // r.hyps has only one element. + for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { + it->second.context_state = stream->GetContextGraph()->Root(); + } + } + + stream->SetResult(r); + stream->SetStates(model_->GetEncoderInitStates()); + } + + std::unique_ptr CreateStream() const override { + auto stream = std::make_unique(config_.feat_config); + InitOnlineStream(stream.get()); + return stream; + } + + std::unique_ptr CreateStream( + const std::vector> &contexts) const override { + // We create context_graph at this level, because we might have default + // context_graph(will be added later if needed) that belongs to the whole + // model rather than each stream. + auto context_graph = + std::make_shared(contexts, config_.context_score); + auto stream = + std::make_unique(config_.feat_config, context_graph); + InitOnlineStream(stream.get()); + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feature_dim = ss[0]->FeatureDim(); + + std::vector results(n); + std::vector features_vec(n * chunk_size * feature_dim); + std::vector> states_vec(n); + std::vector all_processed_frames(n); + bool has_context_graph = false; + + for (int32_t i = 0; i != n; ++i) { + if (!has_context_graph && ss[i]->GetContextGraph()) + has_context_graph = true; + + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); + std::vector features = + ss[i]->GetFrames(num_processed_frames, chunk_size); + + // Question: should num_processed_frames include chunk_shift? + ss[i]->GetNumProcessedFrames() += chunk_shift; + + std::copy(features.begin(), features.end(), + features_vec.data() + i * chunk_size * feature_dim); + + results[i] = std::move(ss[i]->GetResult()); + states_vec[i] = std::move(ss[i]->GetStates()); + all_processed_frames[i] = num_processed_frames; + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{n, chunk_size, feature_dim}; + + Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), + features_vec.size(), x_shape.data(), + x_shape.size()); + + std::array processed_frames_shape{ + static_cast(all_processed_frames.size())}; + + Ort::Value processed_frames = Ort::Value::CreateTensor( + memory_info, all_processed_frames.data(), all_processed_frames.size(), + processed_frames_shape.data(), processed_frames_shape.size()); + + auto states = model_->StackStates(states_vec); + + auto pair = model_->RunEncoder(std::move(x), std::move(states), + std::move(processed_frames)); + + if (has_context_graph) { + decoder_->Decode(std::move(pair.first), ss, &results); + } else { + decoder_->Decode(std::move(pair.first), &results); + } + + std::vector> next_states = + model_->UnStackStates(pair.second); + + for (int32_t i = 0; i != n; ++i) { + ss[i]->SetResult(results[i]); + ss[i]->SetStates(std::move(next_states[i])); + } + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + OnlineTransducerDecoderResult decoder_result = s->GetResult(); + decoder_->StripLeadingBlanks(&decoder_result); + + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) return false; + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + // subsampling factor is 4 + int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + // we keep the decoder_out + decoder_->UpdateDecoderOut(&s->GetResult()); + Ort::Value decoder_out = std::move(s->GetResult().decoder_out); + s->SetResult(decoder_->GetEmptyResult()); + s->GetResult().decoder_out = std::move(decoder_out); + + // Note: We only update counters. The underlying audio samples + // are not discarded. + s->Reset(); + } + + private: + OnlineRecognizerConfig config_; + std::unique_ptr model_; + std::unique_ptr lm_; + std::unique_ptr decoder_; + SymbolTable sym_; + Endpoint endpoint_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 39d3c1775..f72e7fc42 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -15,14 +15,7 @@ #include #include "nlohmann/json.hpp" -#include "sherpa-onnx/csrc/file-utils.h" -#include "sherpa-onnx/csrc/macros.h" -#include "sherpa-onnx/csrc/online-lm.h" -#include "sherpa-onnx/csrc/online-transducer-decoder.h" -#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" -#include "sherpa-onnx/csrc/online-transducer-model.h" -#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" -#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/online-recognizer-impl.h" namespace sherpa_onnx { @@ -54,30 +47,6 @@ std::string OnlineRecognizerResult::AsJsonString() const { return j.dump(); } -static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, - const SymbolTable &sym_table, - int32_t frame_shift_ms, - int32_t subsampling_factor) { - OnlineRecognizerResult r; - r.tokens.reserve(src.tokens.size()); - r.timestamps.reserve(src.tokens.size()); - - for (auto i : src.tokens) { - auto sym = sym_table[i]; - - r.text.append(sym); - r.tokens.push_back(std::move(sym)); - } - - float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; - for (auto t : src.timestamps) { - float time = frame_shift_s * t; - r.timestamps.push_back(time); - } - - return r; -} - void OnlineRecognizerConfig::Register(ParseOptions *po) { feat_config.Register(po); model_config.Register(po); @@ -124,210 +93,13 @@ std::string OnlineRecognizerConfig::ToString() const { return os.str(); } -class OnlineRecognizer::Impl { - public: - explicit Impl(const OnlineRecognizerConfig &config) - : config_(config), - model_(OnlineTransducerModel::Create(config.model_config)), - sym_(config.model_config.tokens), - endpoint_(config_.endpoint_config) { - if (config.decoding_method == "modified_beam_search") { - if (!config_.lm_config.model.empty()) { - lm_ = OnlineLM::Create(config.lm_config); - } - - decoder_ = std::make_unique( - model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale); - } else if (config.decoding_method == "greedy_search") { - decoder_ = - std::make_unique(model_.get()); - } else { - SHERPA_ONNX_LOGE("Unsupported decoding method: %s", - config.decoding_method.c_str()); - exit(-1); - } - } - -#if __ANDROID_API__ >= 9 - explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config) - : config_(config), - model_(OnlineTransducerModel::Create(mgr, config.model_config)), - sym_(mgr, config.model_config.tokens), - endpoint_(config_.endpoint_config) { - if (config.decoding_method == "modified_beam_search") { - decoder_ = std::make_unique( - model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale); - } else if (config.decoding_method == "greedy_search") { - decoder_ = - std::make_unique(model_.get()); - } else { - SHERPA_ONNX_LOGE("Unsupported decoding method: %s", - config.decoding_method.c_str()); - exit(-1); - } - } -#endif - - void InitOnlineStream(OnlineStream *stream) const { - auto r = decoder_->GetEmptyResult(); - - if (config_.decoding_method == "modified_beam_search" && - nullptr != stream->GetContextGraph()) { - // r.hyps has only one element. - for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { - it->second.context_state = stream->GetContextGraph()->Root(); - } - } - - stream->SetResult(r); - stream->SetStates(model_->GetEncoderInitStates()); - } - - std::unique_ptr CreateStream() const { - auto stream = std::make_unique(config_.feat_config); - InitOnlineStream(stream.get()); - return stream; - } - - std::unique_ptr CreateStream( - const std::vector> &contexts) const { - // We create context_graph at this level, because we might have default - // context_graph(will be added later if needed) that belongs to the whole - // model rather than each stream. - auto context_graph = - std::make_shared(contexts, config_.context_score); - auto stream = - std::make_unique(config_.feat_config, context_graph); - InitOnlineStream(stream.get()); - return stream; - } - - bool IsReady(OnlineStream *s) const { - return s->GetNumProcessedFrames() + model_->ChunkSize() < - s->NumFramesReady(); - } - - void DecodeStreams(OnlineStream **ss, int32_t n) const { - int32_t chunk_size = model_->ChunkSize(); - int32_t chunk_shift = model_->ChunkShift(); - - int32_t feature_dim = ss[0]->FeatureDim(); - - std::vector results(n); - std::vector features_vec(n * chunk_size * feature_dim); - std::vector> states_vec(n); - std::vector all_processed_frames(n); - bool has_context_graph = false; - - for (int32_t i = 0; i != n; ++i) { - if (!has_context_graph && ss[i]->GetContextGraph()) - has_context_graph = true; - - const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); - std::vector features = - ss[i]->GetFrames(num_processed_frames, chunk_size); - - // Question: should num_processed_frames include chunk_shift? - ss[i]->GetNumProcessedFrames() += chunk_shift; - - std::copy(features.begin(), features.end(), - features_vec.data() + i * chunk_size * feature_dim); - - results[i] = std::move(ss[i]->GetResult()); - states_vec[i] = std::move(ss[i]->GetStates()); - all_processed_frames[i] = num_processed_frames; - } - - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - std::array x_shape{n, chunk_size, feature_dim}; - - Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), - features_vec.size(), x_shape.data(), - x_shape.size()); - - std::array processed_frames_shape{ - static_cast(all_processed_frames.size())}; - - Ort::Value processed_frames = Ort::Value::CreateTensor( - memory_info, all_processed_frames.data(), all_processed_frames.size(), - processed_frames_shape.data(), processed_frames_shape.size()); - - auto states = model_->StackStates(states_vec); - - auto pair = model_->RunEncoder(std::move(x), std::move(states), - std::move(processed_frames)); - - if (has_context_graph) { - decoder_->Decode(std::move(pair.first), ss, &results); - } else { - decoder_->Decode(std::move(pair.first), &results); - } - - std::vector> next_states = - model_->UnStackStates(pair.second); - - for (int32_t i = 0; i != n; ++i) { - ss[i]->SetResult(results[i]); - ss[i]->SetStates(std::move(next_states[i])); - } - } - - OnlineRecognizerResult GetResult(OnlineStream *s) const { - OnlineTransducerDecoderResult decoder_result = s->GetResult(); - decoder_->StripLeadingBlanks(&decoder_result); - - // TODO(fangjun): Remember to change these constants if needed - int32_t frame_shift_ms = 10; - int32_t subsampling_factor = 4; - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); - } - - bool IsEndpoint(OnlineStream *s) const { - if (!config_.enable_endpoint) return false; - int32_t num_processed_frames = s->GetNumProcessedFrames(); - - // frame shift is 10 milliseconds - float frame_shift_in_seconds = 0.01; - - // subsampling factor is 4 - int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4; - - return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, - frame_shift_in_seconds); - } - - void Reset(OnlineStream *s) const { - // we keep the decoder_out - decoder_->UpdateDecoderOut(&s->GetResult()); - Ort::Value decoder_out = std::move(s->GetResult().decoder_out); - s->SetResult(decoder_->GetEmptyResult()); - s->GetResult().decoder_out = std::move(decoder_out); - - // Note: We only update counters. The underlying audio samples - // are not discarded. - s->Reset(); - } - - private: - OnlineRecognizerConfig config_; - std::unique_ptr model_; - std::unique_ptr lm_; - std::unique_ptr decoder_; - SymbolTable sym_; - Endpoint endpoint_; -}; - OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) - : impl_(std::make_unique(config)) {} + : impl_(OnlineRecognizerImpl::Create(config)) {} #if __ANDROID_API__ >= 9 OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config) - : impl_(std::make_unique(mgr, config)) {} + : impl_(OnlineRecognizerImpl::Create(mgr, config)) {} #endif OnlineRecognizer::~OnlineRecognizer() = default; diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index bd8321c15..cbac9d08f 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -17,6 +17,7 @@ #include "sherpa-onnx/csrc/endpoint.h" #include "sherpa-onnx/csrc/features.h" #include "sherpa-onnx/csrc/online-lm-config.h" +#include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -67,7 +68,7 @@ struct OnlineRecognizerResult { struct OnlineRecognizerConfig { FeatureExtractorConfig feat_config; - OnlineTransducerModelConfig model_config; + OnlineModelConfig model_config; OnlineLMConfig lm_config; EndpointConfig endpoint_config; bool enable_endpoint = true; @@ -83,7 +84,7 @@ struct OnlineRecognizerConfig { OnlineRecognizerConfig() = default; OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, - const OnlineTransducerModelConfig &model_config, + const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config, const EndpointConfig &endpoint_config, bool enable_endpoint, @@ -103,6 +104,8 @@ struct OnlineRecognizerConfig { std::string ToString() const; }; +class OnlineRecognizerImpl; + class OnlineRecognizer { public: explicit OnlineRecognizer(const OnlineRecognizerConfig &config); @@ -151,8 +154,7 @@ class OnlineRecognizer { void Reset(OnlineStream *s) const; private: - class Impl; - std::unique_ptr impl_; + std::unique_ptr impl_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-model-config.cc b/sherpa-onnx/csrc/online-transducer-model-config.cc index f13a27913..f7015f98d 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.cc +++ b/sherpa-onnx/csrc/online-transducer-model-config.cc @@ -11,46 +11,24 @@ namespace sherpa_onnx { void OnlineTransducerModelConfig::Register(ParseOptions *po) { - po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); - po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); - po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); - po->Register("tokens", &tokens, "Path to tokens.txt"); - po->Register("num_threads", &num_threads, - "Number of threads to run the neural network"); - po->Register("provider", &provider, - "Specify a provider to use: cpu, cuda, coreml"); - - po->Register("debug", &debug, - "true to print model information while loading it."); - po->Register("model-type", &model_type, - "Specify it to reduce model initialization time. " - "Valid values are: conformer, lstm, zipformer, zipformer2. " - "All other values lead to loading the model twice."); + po->Register("encoder", &encoder, "Path to encoder.onnx"); + po->Register("decoder", &decoder, "Path to decoder.onnx"); + po->Register("joiner", &joiner, "Path to joiner.onnx"); } bool OnlineTransducerModelConfig::Validate() const { - if (!FileExists(tokens)) { - SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str()); - return false; - } - - if (!FileExists(encoder_filename)) { - SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str()); - return false; - } - - if (!FileExists(decoder_filename)) { - SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str()); + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("transducer encoder: %s does not exist", encoder.c_str()); return false; } - if (!FileExists(joiner_filename)) { - SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str()); + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("transducer decoder: %s does not exist", decoder.c_str()); return false; } - if (num_threads < 1) { - SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + if (!FileExists(joiner)) { + SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner.c_str()); return false; } @@ -61,14 +39,9 @@ std::string OnlineTransducerModelConfig::ToString() const { std::ostringstream os; os << "OnlineTransducerModelConfig("; - os << "encoder_filename=\"" << encoder_filename << "\", "; - os << "decoder_filename=\"" << decoder_filename << "\", "; - os << "joiner_filename=\"" << joiner_filename << "\", "; - os << "tokens=\"" << tokens << "\", "; - os << "num_threads=" << num_threads << ", "; - os << "provider=\"" << provider << "\", "; - os << "model_type=\"" << model_type << "\", "; - os << "debug=" << (debug ? "True" : "False") << ")"; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\", "; + os << "joiner=\"" << joiner << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-transducer-model-config.h b/sherpa-onnx/csrc/online-transducer-model-config.h index 040dfe283..5d79e25bf 100644 --- a/sherpa-onnx/csrc/online-transducer-model-config.h +++ b/sherpa-onnx/csrc/online-transducer-model-config.h @@ -11,41 +11,15 @@ namespace sherpa_onnx { struct OnlineTransducerModelConfig { - std::string encoder_filename; - std::string decoder_filename; - std::string joiner_filename; - std::string tokens; - int32_t num_threads = 2; - bool debug = false; - std::string provider = "cpu"; - - // With the help of this field, we only need to load the model once - // instead of twice; and therefore it reduces initialization time. - // - // Valid values: - // - conformer - // - lstm - // - zipformer - // - zipformer2 - // - // All other values are invalid and lead to loading the model twice. - std::string model_type; + std::string encoder; + std::string decoder; + std::string joiner; OnlineTransducerModelConfig() = default; - OnlineTransducerModelConfig(const std::string &encoder_filename, - const std::string &decoder_filename, - const std::string &joiner_filename, - const std::string &tokens, int32_t num_threads, - bool debug, const std::string &provider, - const std::string &model_type) - : encoder_filename(encoder_filename), - decoder_filename(decoder_filename), - joiner_filename(joiner_filename), - tokens(tokens), - num_threads(num_threads), - debug(debug), - provider(provider), - model_type(model_type) {} + OnlineTransducerModelConfig(const std::string &encoder, + const std::string &decoder, + const std::string &joiner) + : encoder(encoder), decoder(decoder), joiner(joiner) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index bf4715263..83bdc3906 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -76,7 +76,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, } std::unique_ptr OnlineTransducerModel::Create( - const OnlineTransducerModelConfig &config) { + const OnlineModelConfig &config) { if (!config.model_type.empty()) { const auto &model_type = config.model_type; if (model_type == "conformer") { @@ -96,7 +96,7 @@ std::unique_ptr OnlineTransducerModel::Create( ModelType model_type = ModelType::kUnkown; { - auto buffer = ReadFile(config.encoder_filename); + auto buffer = ReadFile(config.transducer.encoder); model_type = GetModelType(buffer.data(), buffer.size(), config.debug); } @@ -155,7 +155,7 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput( #if __ANDROID_API__ >= 9 std::unique_ptr OnlineTransducerModel::Create( - AAssetManager *mgr, const OnlineTransducerModelConfig &config) { + AAssetManager *mgr, const OnlineModelConfig &config) { if (!config.model_type.empty()) { const auto &model_type = config.model_type; if (model_type == "conformer") { @@ -173,7 +173,7 @@ std::unique_ptr OnlineTransducerModel::Create( } } - auto buffer = ReadFile(mgr, config.encoder_filename); + auto buffer = ReadFile(mgr, config.transducer.encoder); auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); switch (model_type) { diff --git a/sherpa-onnx/csrc/online-transducer-model.h b/sherpa-onnx/csrc/online-transducer-model.h index 42539de98..bbfbec40e 100644 --- a/sherpa-onnx/csrc/online-transducer-model.h +++ b/sherpa-onnx/csrc/online-transducer-model.h @@ -15,6 +15,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" @@ -27,11 +28,11 @@ class OnlineTransducerModel { virtual ~OnlineTransducerModel() = default; static std::unique_ptr Create( - const OnlineTransducerModelConfig &config); + const OnlineModelConfig &config); #if __ANDROID_API__ >= 9 static std::unique_ptr Create( - AAssetManager *mgr, const OnlineTransducerModelConfig &config); + AAssetManager *mgr, const OnlineModelConfig &config); #endif /** Stack a list of individual states into a batch. @@ -64,15 +65,15 @@ class OnlineTransducerModel { * * @param features A tensor of shape (N, T, C). It is changed in-place. * @param states Encoder state of the previous chunk. It is changed in-place. - * @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t. + * @param processed_frames Processed frames before subsampling. It is a 1-D + * tensor with data type int64_t. * * @return Return a tuple containing: * - encoder_out, a tensor of shape (N, T', encoder_out_dim) * - next_states Encoder state for the next chunk. */ virtual std::pair> RunEncoder( - Ort::Value features, - std::vector states, + Ort::Value features, std::vector states, Ort::Value processed_frames) = 0; // NOLINT /** Run the decoder network. diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 238a84d39..31234ae74 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -30,46 +30,46 @@ namespace sherpa_onnx { OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( - const OnlineTransducerModelConfig &config) + const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { { - auto buf = ReadFile(config.encoder_filename); + auto buf = ReadFile(config.transducer.encoder); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.decoder_filename); + auto buf = ReadFile(config.transducer.decoder); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.joiner_filename); + auto buf = ReadFile(config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } } #if __ANDROID_API__ >= 9 OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( - AAssetManager *mgr, const OnlineTransducerModelConfig &config) + AAssetManager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { { - auto buf = ReadFile(mgr, config.encoder_filename); + auto buf = ReadFile(mgr, config.transducer.encoder); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(mgr, config.decoder_filename); + auto buf = ReadFile(mgr, config.transducer.decoder); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(mgr, config.joiner_filename); + auto buf = ReadFile(mgr, config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.h b/sherpa-onnx/csrc/online-zipformer-transducer-model.h index b3e1966a9..b2b7da040 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.h @@ -15,19 +15,18 @@ #endif #include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" namespace sherpa_onnx { class OnlineZipformerTransducerModel : public OnlineTransducerModel { public: - explicit OnlineZipformerTransducerModel( - const OnlineTransducerModelConfig &config); + explicit OnlineZipformerTransducerModel(const OnlineModelConfig &config); #if __ANDROID_API__ >= 9 OnlineZipformerTransducerModel(AAssetManager *mgr, - const OnlineTransducerModelConfig &config); + const OnlineModelConfig &config); #endif std::vector StackStates( @@ -87,7 +86,7 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { std::vector joiner_output_names_; std::vector joiner_output_names_ptr_; - OnlineTransducerModelConfig config_; + OnlineModelConfig config_; std::vector encoder_dims_; std::vector attention_dims_; diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index d5697e7ff..e818b0bc9 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -32,46 +32,46 @@ namespace sherpa_onnx { OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( - const OnlineTransducerModelConfig &config) + const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { { - auto buf = ReadFile(config.encoder_filename); + auto buf = ReadFile(config.transducer.encoder); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.decoder_filename); + auto buf = ReadFile(config.transducer.decoder); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.joiner_filename); + auto buf = ReadFile(config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } } #if __ANDROID_API__ >= 9 OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( - AAssetManager *mgr, const OnlineTransducerModelConfig &config) + AAssetManager *mgr, const OnlineModelConfig &config) : env_(ORT_LOGGING_LEVEL_WARNING), config_(config), sess_opts_(GetSessionOptions(config)), allocator_{} { { - auto buf = ReadFile(mgr, config.encoder_filename); + auto buf = ReadFile(mgr, config.transducer.encoder); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(mgr, config.decoder_filename); + auto buf = ReadFile(mgr, config.transducer.decoder); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(mgr, config.joiner_filename); + auto buf = ReadFile(mgr, config.transducer.joiner); InitJoiner(buf.data(), buf.size()); } } diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h index 57b63e023..666ad1989 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h @@ -15,19 +15,18 @@ #endif #include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" namespace sherpa_onnx { class OnlineZipformer2TransducerModel : public OnlineTransducerModel { public: - explicit OnlineZipformer2TransducerModel( - const OnlineTransducerModelConfig &config); + explicit OnlineZipformer2TransducerModel(const OnlineModelConfig &config); #if __ANDROID_API__ >= 9 OnlineZipformer2TransducerModel(AAssetManager *mgr, - const OnlineTransducerModelConfig &config); + const OnlineModelConfig &config); #endif std::vector StackStates( @@ -87,7 +86,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { std::vector joiner_output_names_; std::vector joiner_output_names_ptr_; - OnlineTransducerModelConfig config_; + OnlineModelConfig config_; std::vector encoder_dims_; std::vector query_head_dims_; diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 5c2abac6b..80c9471d3 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -60,8 +60,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, return sess_opts; } -Ort::SessionOptions GetSessionOptions( - const OnlineTransducerModelConfig &config) { +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 7f28742a4..42c93e0be 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -9,12 +9,11 @@ #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" -#include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/online-model-config.h" namespace sherpa_onnx { -Ort::SessionOptions GetSessionOptions( - const OnlineTransducerModelConfig &config); +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); diff --git a/sherpa-onnx/csrc/sherpa-onnx-alsa.cc b/sherpa-onnx/csrc/sherpa-onnx-alsa.cc index 835a772b9..d839dee4c 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-alsa.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-alsa.cc @@ -12,6 +12,7 @@ #include "sherpa-onnx/csrc/alsa.h" #include "sherpa-onnx/csrc/display.h" #include "sherpa-onnx/csrc/online-recognizer.h" +#include "sherpa-onnx/csrc/parse-options.h" bool stop = false; @@ -21,19 +22,19 @@ static void Handler(int sig) { } int main(int32_t argc, char *argv[]) { - if (argc < 6 || argc > 8) { - const char *usage = R"usage( + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( Usage: ./bin/sherpa-onnx-alsa \ - /path/to/tokens.txt \ - /path/to/encoder.onnx \ - /path/to/decoder.onnx \ - /path/to/joiner.onnx \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --provider=cpu \ + --num-threads=2 \ + --decoding-method=greedy_search \ device_name \ - [num_threads [decoding_method]] - -Default value for num_threads is 2. -Valid values for decoding_method: greedy_search (default), modified_beam_search. Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html @@ -55,44 +56,24 @@ and if you want to select card 3 and the device 0 on that card, please use: hw:3,0 -as the device_name. -)usage"; +or - fprintf(stderr, "%s\n", usage); - fprintf(stderr, "argc, %d\n", argc); - - return 0; - } - - signal(SIGINT, Handler); + plughw:3,0 +as the device_name. +)usage"; + sherpa_onnx::ParseOptions po(kUsageMessage); sherpa_onnx::OnlineRecognizerConfig config; - config.model_config.tokens = argv[1]; - - config.model_config.debug = false; - config.model_config.encoder_filename = argv[2]; - config.model_config.decoder_filename = argv[3]; - config.model_config.joiner_filename = argv[4]; - - const char *device_name = argv[5]; + config.Register(&po); - config.model_config.num_threads = 2; - if (argc == 7 && atoi(argv[6]) > 0) { - config.model_config.num_threads = atoi(argv[6]); + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Please provide only 1 argument: the device name\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); } - if (argc == 8) { - config.decoding_method = argv[7]; - } - config.max_active_paths = 4; - - config.enable_endpoint = true; - - config.endpoint_config.rule1.min_trailing_silence = 2.4; - config.endpoint_config.rule2.min_trailing_silence = 1.2; - config.endpoint_config.rule3.min_utterance_length = 300; - fprintf(stderr, "%s\n", config.ToString().c_str()); if (!config.Validate()) { @@ -103,8 +84,9 @@ as the device_name. int32_t expected_sample_rate = config.feat_config.sampling_rate; - sherpa_onnx::Alsa alsa(device_name); - fprintf(stderr, "Use recording device: %s\n", device_name); + std::string device_name = po.GetArg(1); + sherpa_onnx::Alsa alsa(device_name.c_str()); + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); if (alsa.GetExpectedSampleRate() != expected_sample_rate) { fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), diff --git a/sherpa-onnx/csrc/text-utils.h b/sherpa-onnx/csrc/text-utils.h index 8f1cb6963..870101527 100644 --- a/sherpa-onnx/csrc/text-utils.h +++ b/sherpa-onnx/csrc/text-utils.h @@ -4,6 +4,7 @@ // Copyright 2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_TEXT_UTILS_H_ #define SHERPA_ONNX_CSRC_TEXT_UTILS_H_ +#include . #include #include diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 5f9a7734e..d05140efd 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -159,47 +159,47 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { //---------- model config ---------- fid = env->GetFieldID(cls, "modelConfig", "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); - jobject model_config = env->GetObjectField(config, fid); - jclass model_config_cls = env->GetObjectClass(model_config); + jobject transducer_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(transducer_config); fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); + s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); - ans.model_config.encoder_filename = p; + ans.model_config.transducer.encoder = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); + s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); - ans.model_config.decoder_filename = p; + ans.model_config.transducer.decoder = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); + s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); - ans.model_config.joiner_filename = p; + ans.model_config.transducer.joiner = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); + s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.tokens = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "numThreads", "I"); - ans.model_config.num_threads = env->GetIntField(model_config, fid); + ans.model_config.num_threads = env->GetIntField(transducer_config, fid); fid = env->GetFieldID(model_config_cls, "debug", "Z"); - ans.model_config.debug = env->GetBooleanField(model_config, fid); + ans.model_config.debug = env->GetBooleanField(transducer_config, fid); fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); + s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.provider = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(model_config, fid); + s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.model_type = p; env->ReleaseStringUTFChars(s, p); @@ -328,7 +328,7 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens( jobjectArray result = env->NewObjectArray(size, stringClass, NULL); for (int i = 0; i < size; i++) { // Convert the C++ string to a C string - const char* cstr = tokens[i].c_str(); + const char *cstr = tokens[i].c_str(); // Convert the C string to a jstring jstring jstr = env->NewStringUTF(cstr); diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index b1d9db522..e58d60d90 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -13,6 +13,7 @@ pybind11_add_module(_sherpa_onnx offline-transducer-model-config.cc offline-whisper-model-config.cc online-lm-config.cc + online-model-config.cc online-recognizer.cc online-stream.cc online-transducer-model-config.cc diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc new file mode 100644 index 000000000..677d3b1f9 --- /dev/null +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -0,0 +1,35 @@ +// sherpa-onnx/python/csrc/online-model-config.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-onnx/python/csrc/online-model-config.h" + +#include +#include + +#include "sherpa-onnx/csrc/online-model-config.h" +#include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/python/csrc/online-transducer-model-config.h" + +namespace sherpa_onnx { + +void PybindOnlineModelConfig(py::module *m) { + PybindOnlineTransducerModelConfig(m); + + using PyClass = OnlineModelConfig; + py::class_(*m, "OnlineModelConfig") + .def(py::init(), + py::arg("transducer") = OnlineTransducerModelConfig(), + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, + py::arg("provider") = "cpu", py::arg("model_type") = "") + .def_readwrite("transducer", &PyClass::transducer) + .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def_readwrite("model_type", &PyClass::model_type) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-model-config.h b/sherpa-onnx/python/csrc/online-model-config.h new file mode 100644 index 000000000..73154fc9e --- /dev/null +++ b/sherpa-onnx/python/csrc/online-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/online-model-config.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOnlineModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index ec1b0af86..34a907ce5 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -27,10 +27,9 @@ static void PybindOnlineRecognizerResult(py::module *m) { static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") - .def(py::init(), + .def(py::init(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), py::arg("enable_endpoint"), py::arg("decoding_method"), diff --git a/sherpa-onnx/python/csrc/online-transducer-model-config.cc b/sherpa-onnx/python/csrc/online-transducer-model-config.cc index 8246acdd3..d1591090b 100644 --- a/sherpa-onnx/python/csrc/online-transducer-model-config.cc +++ b/sherpa-onnx/python/csrc/online-transducer-model-config.cc @@ -14,20 +14,11 @@ void PybindOnlineTransducerModelConfig(py::module *m) { using PyClass = OnlineTransducerModelConfig; py::class_(*m, "OnlineTransducerModelConfig") .def(py::init(), - py::arg("encoder_filename"), py::arg("decoder_filename"), - py::arg("joiner_filename"), py::arg("tokens"), - py::arg("num_threads"), py::arg("debug") = false, - py::arg("provider") = "cpu", py::arg("model_type") = "") - .def_readwrite("encoder_filename", &PyClass::encoder_filename) - .def_readwrite("decoder_filename", &PyClass::decoder_filename) - .def_readwrite("joiner_filename", &PyClass::joiner_filename) - .def_readwrite("tokens", &PyClass::tokens) - .def_readwrite("num_threads", &PyClass::num_threads) - .def_readwrite("debug", &PyClass::debug) - .def_readwrite("provider", &PyClass::provider) - .def_readwrite("model_type", &PyClass::model_type) + const std::string &>(), + py::arg("encoder"), py::arg("decoder"), py::arg("joiner")) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) + .def_readwrite("joiner", &PyClass::joiner) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 0850ee349..64f8aacf8 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -12,9 +12,9 @@ #include "sherpa-onnx/python/csrc/offline-recognizer.h" #include "sherpa-onnx/python/csrc/offline-stream.h" #include "sherpa-onnx/python/csrc/online-lm-config.h" +#include "sherpa-onnx/python/csrc/online-model-config.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-stream.h" -#include "sherpa-onnx/python/csrc/online-transducer-model-config.h" namespace sherpa_onnx { @@ -22,7 +22,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { m.doc() = "pybind11 binding of sherpa-onnx"; PybindFeatures(&m); - PybindOnlineTransducerModelConfig(&m); + PybindOnlineModelConfig(&m); PybindOnlineLMConfig(&m); PybindOnlineStream(&m); PybindEndpoint(&m); diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index e0f47068a..c49e1b438 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -5,6 +5,7 @@ from _sherpa_onnx import ( EndpointConfig, FeatureExtractorConfig, + OnlineModelConfig, OnlineRecognizer as _Recognizer, OnlineRecognizerConfig, OnlineStream, @@ -24,8 +25,9 @@ class OnlineRecognizer(object): - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py """ - def __init__( - self, + @classmethod + def from_transducer( + cls, tokens: str, encoder: str, decoder: str, @@ -95,6 +97,7 @@ def __init__( Online transducer model type. Valid values are: conformer, lstm, zipformer, zipformer2. All other values lead to loading the model twice. """ + self = cls.__new__(cls) _assert_file_exists(tokens) _assert_file_exists(encoder) _assert_file_exists(decoder) @@ -102,10 +105,14 @@ def __init__( assert num_threads > 0, num_threads - model_config = OnlineTransducerModelConfig( - encoder_filename=encoder, - decoder_filename=decoder, - joiner_filename=joiner, + transducer_config = OnlineTransducerModelConfig( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + + model_config = OnlineModelConfig( + transducer=transducer_config, tokens=tokens, num_threads=num_threads, provider=provider, @@ -135,6 +142,7 @@ def __init__( self.recognizer = _Recognizer(recognizer_config) self.config = recognizer_config + return self def create_stream(self, contexts_list: Optional[List[List[int]]] = None): if contexts_list is None: diff --git a/sherpa-onnx/python/tests/test_online_recognizer.py b/sherpa-onnx/python/tests/test_online_recognizer.py index 0769b5e65..f5c15e5c2 100755 --- a/sherpa-onnx/python/tests/test_online_recognizer.py +++ b/sherpa-onnx/python/tests/test_online_recognizer.py @@ -65,7 +65,7 @@ def test_transducer_single_file(self): return for decoding_method in ["greedy_search", "modified_beam_search"]: - recognizer = sherpa_onnx.OnlineRecognizer( + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( encoder=encoder, decoder=decoder, joiner=joiner, @@ -109,7 +109,7 @@ def test_transducer_multiple_files(self): return for decoding_method in ["greedy_search", "modified_beam_search"]: - recognizer = sherpa_onnx.OnlineRecognizer( + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( encoder=encoder, decoder=decoder, joiner=joiner, diff --git a/sherpa-onnx/python/tests/test_online_transducer_model_config.py b/sherpa-onnx/python/tests/test_online_transducer_model_config.py index 6c41bb4f5..e161437e9 100644 --- a/sherpa-onnx/python/tests/test_online_transducer_model_config.py +++ b/sherpa-onnx/python/tests/test_online_transducer_model_config.py @@ -14,19 +14,13 @@ class TestOnlineTransducerModelConfig(unittest.TestCase): def test_constructor(self): config = _sherpa_onnx.OnlineTransducerModelConfig( - encoder_filename="encoder.onnx", - decoder_filename="decoder.onnx", - joiner_filename="joiner.onnx", - tokens="tokens.txt", - num_threads=8, - debug=True, + encoder="encoder.onnx", + decoder="decoder.onnx", + joiner="joiner.onnx", ) - assert config.encoder_filename == "encoder.onnx", config.encoder_filename - assert config.decoder_filename == "decoder.onnx", config.decoder_filename - assert config.joiner_filename == "joiner.onnx", config.joiner_filename - assert config.tokens == "tokens.txt", config.tokens - assert config.num_threads == 8, config.num_threads - assert config.debug is True, config.debug + assert config.encoder == "encoder.onnx", config.encoder + assert config.decoder == "decoder.onnx", config.decoder + assert config.joiner == "joiner.onnx", config.joiner print(config)