diff --git a/.github/workflows/test-pip-install.yaml b/.github/workflows/test-pip-install.yaml index 6467fc228..34a153600 100644 --- a/.github/workflows/test-pip-install.yaml +++ b/.github/workflows/test-pip-install.yaml @@ -23,12 +23,12 @@ permissions: jobs: test_pip_install: runs-on: ${{ matrix.os }} - name: Test pip install on ${{ matrix.os }} + name: ${{ matrix.os }} ${{ matrix.python-version }} strategy: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 @@ -50,3 +50,15 @@ jobs: run: | python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)" python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)" + + sherpa-onnx --help + sherpa-onnx-offline --help + + sherpa-onnx-microphone --help + sherpa-onnx-microphone-offline --help + + sherpa-onnx-offline-websocket-server --help + sherpa-onnx-offline-websocket-client --help + + sherpa-onnx-online-websocket-server --help + sherpa-onnx-online-websocket-client --help diff --git a/.github/workflows/test-python-offline-websocket-server.yaml b/.github/workflows/test-python-offline-websocket-server.yaml new file mode 100644 index 000000000..ce538d180 --- /dev/null +++ b/.github/workflows/test-python-offline-websocket-server.yaml @@ -0,0 +1,174 @@ +name: Python offline websocket server + +on: + push: + branches: + - master + pull_request: + branches: + - master + +concurrency: + group: python-offline-websocket-server-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + python_offline_websocket_server: + runs-on: ${{ matrix.os }} + name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + shell: bash + run: | + python3 -m pip install --upgrade pip numpy + + - name: Install sherpa-onnx + shell: bash + run: | + python3 -m pip install --no-deps --verbose . + python3 -m pip install websockets + + + - name: Start server for transducer models + if: matrix.model_type == 'transducer' + shell: bash + run: | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26 + cd sherpa-onnx-zipformer-en-2023-06-26 + git lfs pull --include "*.onnx" + cd .. 
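+
+          # The server below is started in the background ("&"); the
+          # "sleep 10" that follows gives it time to start before the
+          # client step connects.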
+
+          python3 ./python-api-examples/non_streaming_server.py \
+            --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \
+            --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \
+            --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \
+            --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt &
+
+          echo "sleep 10 seconds to wait for the server to start"
+          sleep 10
+
+      - name: Start client for transducer models
+        if: matrix.model_type == 'transducer'
+        shell: bash
+        run: |
+          python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
+            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \
+            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \
+            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
+
+          python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
+            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/0.wav \
+            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/1.wav \
+            ./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav
+
+      - name: Start server for paraformer models
+        if: matrix.model_type == 'paraformer'
+        shell: bash
+        run: |
+          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
+          cd sherpa-onnx-paraformer-zh-2023-03-28
+          git lfs pull --include "*.onnx"
+          cd ..
+
+          python3 ./python-api-examples/non_streaming_server.py \
+            --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
+            --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt &
+
+          echo "sleep 10 seconds to wait for the server to start"
+          sleep 10
+
+      - name: Start client for paraformer models
+        if: matrix.model_type == 'paraformer'
+        shell: bash
+        run: |
+          python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
+            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
+            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
+            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
+            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
+
+          python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
+            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
+            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
+            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
+            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
+
+      - name: Start server for nemo_ctc models
+        if: matrix.model_type == 'nemo_ctc'
+        shell: bash
+        run: |
+          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium
+          cd sherpa-onnx-nemo-ctc-en-conformer-medium
+          git lfs pull --include "*.onnx"
+          cd ..
+
+          python3 ./python-api-examples/non_streaming_server.py \
+            --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
+            --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt &
+
+          echo "sleep 10 seconds to wait for the server to start"
+          sleep 10
+
+      - name: Start client for nemo_ctc models
+        if: matrix.model_type == 'nemo_ctc'
+        shell: bash
+        run: |
+          python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
+            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
+            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
+            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
+
+          python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
+            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav \
+            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/1.wav \
+            ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/8k.wav
+
+      - name: Start server for whisper models
+        if: matrix.model_type == 'whisper'
+        shell: bash
+        run: |
+          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en
+          cd sherpa-onnx-whisper-tiny.en
+          git lfs pull --include "*.onnx"
+          cd ..
+
+          python3 ./python-api-examples/non_streaming_server.py \
+            --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
+            --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
+            --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt &
+
+          echo "sleep 10 seconds to wait for the server to start"
+          sleep 10
+
+      - name: Start client for whisper models
+        if: matrix.model_type == 'whisper'
+        shell: bash
+        run: |
+          python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
+            ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
+            ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
+            ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
+
+          python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
+            ./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
+            ./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
+            ./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
diff --git a/.github/workflows/test-python-online-websocket-server.yaml b/.github/workflows/test-python-online-websocket-server.yaml
new file mode 100644
index 000000000..c7e3319d6
--- /dev/null
+++ b/.github/workflows/test-python-online-websocket-server.yaml
@@ -0,0 +1,73 @@
+name: Python online websocket server
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
+
+concurrency:
+  group: python-online-websocket-server-${{ github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+
+jobs:
+  python_online_websocket_server:
+    runs-on: ${{ matrix.os }}
+    name: ${{ matrix.os }} ${{ matrix.python-version }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, windows-latest, macos-latest]
+        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
+        model_type: ["transducer"]
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Python dependencies
+        shell: bash
+        run: |
+          python3 -m pip install --upgrade pip numpy
+
+      - name: Install sherpa-onnx
+        shell: bash
+        run: |
+          python3 -m pip install --no-deps --verbose .
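+          # --no-deps is used because the dependencies (numpy above,
+          # websockets below) are installed explicitly in this workflow.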
+          python3 -m pip install websockets
+
+
+      - name: Start server for transducer models
+        if: matrix.model_type == 'transducer'
+        shell: bash
+        run: |
+          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
+          cd sherpa-onnx-streaming-zipformer-en-2023-06-26
+          git lfs pull --include "*.onnx"
+          cd ..
+
+          python3 ./python-api-examples/streaming_server.py \
+            --encoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-128.onnx \
+            --decoder ./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-128.onnx \
+            --joiner ./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-128.onnx \
+            --tokens ./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt &
+          echo "sleep 10 seconds to wait for the server to start"
+          sleep 10
+
+      - name: Start client for transducer models
+        if: matrix.model_type == 'transducer'
+        shell: bash
+        run: |
+          python3 ./python-api-examples/online-websocket-client-decode-file.py \
+            ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 89ee2c87f..c0ff29011 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)

-set(SHERPA_ONNX_VERSION "1.7.1")
+set(SHERPA_ONNX_VERSION "1.7.2")

 # Disable warning about
 #
diff --git a/c-api-examples/README.md b/c-api-examples/README.md
new file mode 100644
index 000000000..85f2e505f
--- /dev/null
+++ b/c-api-examples/README.md
@@ -0,0 +1,9 @@
+# Introduction
+
+This folder contains C API examples for [sherpa-onnx][sherpa-onnx].
+
+Please refer to the documentation
+https://k2-fsa.github.io/sherpa/onnx/c-api/index.html
+for details.
+
+[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
diff --git a/dotnet-examples/README.md b/dotnet-examples/README.md
new file mode 100644
index 000000000..d65c942eb
--- /dev/null
+++ b/dotnet-examples/README.md
@@ -0,0 +1,9 @@
+# Introduction
+
+This folder contains C# API examples for [sherpa-onnx][sherpa-onnx].
+
+Please refer to the documentation
+https://k2-fsa.github.io/sherpa/onnx/csharp-api/index.html
+for details.
+
+[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
diff --git a/go-api-examples/README.md b/go-api-examples/README.md
new file mode 100644
index 000000000..8d50381b9
--- /dev/null
+++ b/go-api-examples/README.md
@@ -0,0 +1,9 @@
+# Introduction
+
+This folder contains Go API examples for [sherpa-onnx][sherpa-onnx].
+
+Please refer to the documentation
+https://k2-fsa.github.io/sherpa/onnx/go-api/index.html
+for details.
+
+[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py
new file mode 100755
index 000000000..52210d8b1
--- /dev/null
+++ b/python-api-examples/non_streaming_server.py
@@ -0,0 +1,835 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp.
+"""
+A server for non-streaming speech recognition. Non-streaming means you send all
+the content of the audio at once for recognition.
+
+It supports multiple clients sending audio at the same time.
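+
+Each client first sends an 8-byte little-endian header (the sample rate, then
+the payload size in bytes), followed by the audio samples as float32, and
+finally the string "Done". See ``NonStreamingServer.recv_audio_samples``
+below for details.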
+ +Usage: + ./non_streaming_server.py --help + +Please refer to + +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html + +for pre-trained models to download. + +Usage examples: + +(1) Use a non-streaming transducer model + +cd /path/to/sherpa-onnx +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-06-26 +cd sherpa-onnx-zipformer-en-2023-06-26 +git lfs pull --include "*.onnx" +cd .. + +python3 ./python-api-examples/non_streaming_server.py \ + --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \ + --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \ + --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \ + --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt + +(2) Use a non-streaming paraformer + +cd /path/to/sherpa-onnx +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 +cd sherpa-onnx-paraformer-zh-2023-03-28 +git lfs pull --include "*.onnx" +cd .. + +python3 ./python-api-examples/non_streaming_server.py \ + --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \ + --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt + +(3) Use a non-streaming CTC model from NeMo + +cd /path/to/sherpa-onnx +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-medium +cd sherpa-onnx-nemo-ctc-en-conformer-medium +git lfs pull --include "*.onnx" +cd .. + +python3 ./python-api-examples/non_streaming_server.py \ + --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \ + --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt + +(4) Use a Whisper model + +cd /path/to/sherpa-onnx +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en +cd sherpa-onnx-whisper-tiny.en +git lfs pull --include "*.onnx" +cd .. + +python3 ./python-api-examples/non_streaming_server.py \ + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ + --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt + +---- + +To use a certificate so that you can use https, please use + +python3 ./python-api-examples/non_streaming_server.py \ + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ + --certificate=/path/to/your/cert.pem + +If you don't have a certificate, please run: + + cd ./python-api-examples/web + ./generate-certificate.py + +It will generate 3 files, one of which is the required `cert.pem`. +""" # noqa + +import argparse +import asyncio +import http +import logging +import socket +import ssl +import sys +import warnings +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from pathlib import Path +from typing import Optional, Tuple + +import numpy as np +import sherpa_onnx + +import websockets + +from http_server import HttpServer + + +def setup_logger( + log_filename: str, + log_level: str = "info", + use_console: bool = True, +) -> None: + """Setup log level. + + Args: + log_filename: + The filename to save the log. 
+ log_level: + The log level to use, e.g., "debug", "info", "warning", "error", + "critical" + use_console: + True to also print logs to console. + """ + now = datetime.now() + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + log_filename = f"{log_filename}-{date_time}.txt" + + Path(log_filename).parent.mkdir(parents=True, exist_ok=True) + + level = logging.ERROR + if log_level == "debug": + level = logging.DEBUG + elif log_level == "info": + level = logging.INFO + elif log_level == "warning": + level = logging.WARNING + elif log_level == "critical": + level = logging.CRITICAL + + logging.basicConfig( + filename=log_filename, + format=formatter, + level=level, + filemode="w", + ) + if use_console: + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger("").addHandler(console) + + +def add_transducer_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the transducer decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the transducer joiner model", + ) + + +def add_paraformer_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + +def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", + ) + + +def add_whisper_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + +def add_model_args(parser: argparse.ArgumentParser): + add_transducer_model_args(parser) + add_paraformer_model_args(parser) + add_nemo_ctc_model_args(parser) + add_whisper_model_args(parser) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=2, + help="Number of threads to run the neural network model", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + +def add_feature_config_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="Sample rate of the data used to train the model. ", + ) + + parser.add_argument( + "--feat-dim", + type=int, + default=80, + help="Feature dimension of the model", + ) + + +def add_decoding_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Decoding method to use. Current supported methods are: + - greedy_search + - modified_beam_search (for transducer models only) + """, + ) + + add_modified_beam_search_args(parser) + + +def add_modified_beam_search_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--max-active-paths", + type=int, + default=4, + help="""Used only when --decoding-method is modified_beam_search. + It specifies number of active paths to keep during decoding. 
+ """, + ) + + +def check_args(args): + if not Path(args.tokens).is_file(): + raise ValueError(f"{args.tokens} does not exist") + + if args.decoding_method not in ( + "greedy_search", + "modified_beam_search", + ): + raise ValueError(f"Unsupported decoding method {args.decoding_method}") + + if args.decoding_method == "modified_beam_search": + assert args.num_active_paths > 0, args.num_active_paths + assert Path(args.encoder).is_file(), args.encoder + assert Path(args.decoder).is_file(), args.decoder + assert Path(args.joiner).is_file(), args.joiner + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + add_model_args(parser) + add_feature_config_args(parser) + add_decoding_args(parser) + + parser.add_argument( + "--port", + type=int, + default=6006, + help="The server will listen on this port", + ) + + parser.add_argument( + "--max-batch-size", + type=int, + default=25, + help="""Max batch size for computation. Note if there are not enough + requests in the queue, it will wait for max_wait_ms time. After that, + even if there are not enough requests, it still sends the + available requests in the queue for computation. + """, + ) + + parser.add_argument( + "--max-wait-ms", + type=float, + default=5, + help="""Max time in millisecond to wait to build batches for inference. + If there are not enough requests in the feature queue to build a batch + of max_batch_size, it waits up to this time before fetching available + requests for computation. + """, + ) + + parser.add_argument( + "--nn-pool-size", + type=int, + default=1, + help="Number of threads for NN computation and decoding.", + ) + + parser.add_argument( + "--max-message-size", + type=int, + default=(1 << 20), + help="""Max message size in bytes. + The max size per message cannot exceed this limit. + """, + ) + + parser.add_argument( + "--max-queue-size", + type=int, + default=32, + help="Max number of messages in the queue for each connection.", + ) + + parser.add_argument( + "--max-active-connections", + type=int, + default=500, + help="""Maximum number of active connections. The server will refuse + to accept new connections once the current number of active connections + equals to this limit. + """, + ) + + parser.add_argument( + "--certificate", + type=str, + help="""Path to the X.509 certificate. You need it only if you want to + use a secure websocket connection, i.e., use wss:// instead of ws://. + You can use ./web/generate-certificate.py + to generate the certificate `cert.pem`. + Note ./web/generate-certificate.py will generate three files but you + only need to pass the generated cert.pem to this option. + """, + ) + + parser.add_argument( + "--doc-root", + type=str, + default="./python-api-examples/web", + help="Path to the web root", + ) + + return parser.parse_args() + + +class NonStreamingServer: + def __init__( + self, + recognizer: sherpa_onnx.OfflineRecognizer, + max_batch_size: int, + max_wait_ms: float, + nn_pool_size: int, + max_message_size: int, + max_queue_size: int, + max_active_connections: int, + doc_root: str, + certificate: Optional[str] = None, + ): + """ + Args: + recognizer: + An instance of the sherpa_onnx.OfflineRecognizer. + max_batch_size: + Max batch size for inference. + max_wait_ms: + Max wait time in milliseconds in order to build a batch of + `max_batch_size`. + nn_pool_size: + Number of threads for the thread pool that is used for NN + computation and decoding. + max_message_size: + Max size in bytes per message. 
+          max_queue_size:
+            Max number of messages in the queue for each connection.
+          max_active_connections:
+            Max number of active connections. Once the number of active
+            connections equals this limit, the server refuses to accept
+            new connections.
+          doc_root:
+            Path to the directory where files like index.html for the HTTP
+            server are located.
+          certificate:
+            Optional. If not None, it will use secure websocket.
+            You can use ./web/generate-certificate.py to generate
+            it (the default generated filename is `cert.pem`).
+        """
+        self.recognizer = recognizer
+
+        self.certificate = certificate
+        self.http_server = HttpServer(doc_root)
+
+        self.nn_pool = ThreadPoolExecutor(
+            max_workers=nn_pool_size,
+            thread_name_prefix="nn",
+        )
+
+        self.stream_queue = asyncio.Queue()
+
+        self.max_wait_ms = max_wait_ms
+        self.max_batch_size = max_batch_size
+        self.max_message_size = max_message_size
+        self.max_queue_size = max_queue_size
+        self.max_active_connections = max_active_connections
+
+        self.current_active_connections = 0
+        self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
+
+    async def process_request(
+        self,
+        path: str,
+        request_headers: websockets.Headers,
+    ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]:
+        if "sec-websocket-key" not in request_headers:
+            # This is a normal HTTP request
+            if path == "/":
+                path = "/index.html"
+            if path[-1] == "?":
+                path = path[:-1]
+
+            if path == "/streaming_record.html":
+                response = r"""
+<!doctype html><html>
+<head><title>Speech recognition with next-gen Kaldi</title></head>
+<body>
+<h2>Only /upload.html and /offline_record.html are available for the non-streaming server.</h2>
+<br/>
+<br/>
+Go back to <a href="/upload.html">/upload.html</a>
+or <a href="/offline_record.html">/offline_record.html</a>
+</body></html>
+"""
+                found = True
+                mime_type = "text/html"
+            else:
+                found, response, mime_type = self.http_server.process_request(path)
+            if isinstance(response, str):
+                response = response.encode("utf-8")
+
+            if not found:
+                status = http.HTTPStatus.NOT_FOUND
+            else:
+                status = http.HTTPStatus.OK
+            header = {"Content-Type": mime_type}
+            return status, header, response
+
+        if self.current_active_connections < self.max_active_connections:
+            self.current_active_connections += 1
+            return None
+
+        # Refuse new connections
+        status = http.HTTPStatus.SERVICE_UNAVAILABLE  # 503
+        header = {"Hint": "The server is overloaded. Please retry later."}
+        response = b"The server is busy. Please retry later."
+
+        return status, header, response
+
+    async def run(self, port: int):
+        logging.info("started")
+
+        task = asyncio.create_task(self.stream_consumer_task())
+
+        if self.certificate:
+            logging.info(f"Using certificate: {self.certificate}")
+            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+            ssl_context.load_cert_chain(self.certificate)
+        else:
+            ssl_context = None
+            logging.info("No certificate provided")
+
+        async with websockets.serve(
+            self.handle_connection,
+            host="",
+            port=port,
+            max_size=self.max_message_size,
+            max_queue=self.max_queue_size,
+            process_request=self.process_request,
+            ssl=ssl_context,
+        ):
+            ip_list = ["localhost"]
+            if ssl_context:
+                ip_list += ["0.0.0.0", "127.0.0.1"]
+                ip_list.append(socket.gethostbyname(socket.gethostname()))
+
+            proto = "http://" if ssl_context is None else "https://"
+            s = "Please visit one of the following addresses:\n\n"
+            for p in ip_list:
+                s += "  " + proto + p + f":{port}" "\n"
+            logging.info(s)
+
+            await asyncio.Future()  # run forever
+
+        await task  # not reachable
+
+    async def recv_audio_samples(
+        self,
+        socket: websockets.WebSocketServerProtocol,
+    ) -> Tuple[Optional[np.ndarray], Optional[int]]:
+        """Receive a tensor from the client.
+
+        The message from the client is a **bytes** buffer.
+
+        The first message can be either "Done" meaning the client won't send
+        anything in the future or it can be a buffer containing 8 bytes.
+        The first 4 bytes in little endian specify the sample
+        rate of the audio samples; the second 4 bytes in little endian specify
+        the number of bytes in the audio file, which will be sent by the client
+        in the subsequent messages.
+        Since there is a limit on the message size posed by the websocket
+        protocol, the client may send the audio file in multiple messages if the
+        audio file is very large.
+
+        The second and remaining messages contain audio samples.
+
+        Please refer to ./offline-websocket-client-decode-files-paralell.py
+        and ./offline-websocket-client-decode-files-sequential.py
+        for how the client sends the message.
+
+        Args:
+          socket:
+            The socket for communicating with the client.
+        Returns:
+          Return a tuple containing:
+            - a 1-D np.float32 array containing the audio samples
+            - the sample rate of the audio samples
+          or return (None, None) indicating the end of utterance.
+        """
+        header = await socket.recv()
+        if header == "Done":
+            return None, None
+
+        assert len(header) >= 8, (
+            "The first message should contain at least 8 bytes."
+            + f" Given {len(header)}"
+        )
+
+        sample_rate = int.from_bytes(header[:4], "little", signed=True)
+        expected_num_bytes = int.from_bytes(header[4:8], "little", signed=True)
+
+        received = []
+        num_received_bytes = 0
+        if len(header) > 8:
+            received.append(header[8:])
+            num_received_bytes += len(header) - 8
+
+        if num_received_bytes < expected_num_bytes:
+            async for message in socket:
+                received.append(message)
+                num_received_bytes += len(message)
+                if num_received_bytes >= expected_num_bytes:
+                    break
+
+        assert num_received_bytes == expected_num_bytes, (
+            num_received_bytes,
+            expected_num_bytes,
+        )
+
+        samples = b"".join(received)
+        array = np.frombuffer(samples, dtype=np.float32)
+        return array, sample_rate
+
+    async def stream_consumer_task(self):
+        """This function extracts streams from the queue, batches them up, and
+        sends them to the neural network model for computation and decoding.
+        """
+        while True:
+            if self.stream_queue.empty():
+                await asyncio.sleep(self.max_wait_ms / 1000)
+                continue
+
+            batch = []
+            try:
+                while len(batch) < self.max_batch_size:
+                    item = self.stream_queue.get_nowait()
+
+                    batch.append(item)
+            except asyncio.QueueEmpty:
+                pass
+            stream_list = [b[0] for b in batch]
+            future_list = [b[1] for b in batch]
+
+            loop = asyncio.get_running_loop()
+            await loop.run_in_executor(
+                self.nn_pool,
+                self.recognizer.decode_streams,
+                stream_list,
+            )
+
+            for f in future_list:
+                self.stream_queue.task_done()
+                f.set_result(None)
+
+    async def compute_and_decode(
+        self,
+        stream: sherpa_onnx.OfflineStream,
+    ) -> None:
+        """Put the stream into the queue and wait for it to be processed by the
+        consumer task.
+
+        Args:
+          stream:
+            The stream to be processed. Note: It is changed in-place.
+        """
+        loop = asyncio.get_running_loop()
+        future = loop.create_future()
+        await self.stream_queue.put((stream, future))
+        await future
+
+    async def handle_connection(
+        self,
+        socket: websockets.WebSocketServerProtocol,
+    ):
+        """Receive audio samples from the client, process them, and send
+        the decoding result back to the client.
+
+        Args:
+          socket:
+            The socket for communicating with the client.
+        """
+        try:
+            await self.handle_connection_impl(socket)
+        except websockets.exceptions.ConnectionClosedError:
+            logging.info(f"{socket.remote_address} disconnected")
+        finally:
+            # Decrement so that it can accept new connections
+            self.current_active_connections -= 1
+
+            logging.info(
+                f"Disconnected: {socket.remote_address}. "
+                f"Number of connections: {self.current_active_connections}/{self.max_active_connections}"  # noqa
+            )
+
+    async def handle_connection_impl(
+        self,
+        socket: websockets.WebSocketServerProtocol,
+    ):
+        """Receive audio samples from the client, process them, and send
+        decoding results back to the client.
+
+        Args:
+          socket:
+            The socket for communicating with the client.
+        """
+        logging.info(
+            f"Connected: {socket.remote_address}. "
+            f"Number of connections: {self.current_active_connections}/{self.max_active_connections}"  # noqa
+        )
+
+        while True:
+            stream = self.recognizer.create_stream()
+            samples, sample_rate = await self.recv_audio_samples(socket)
+            if samples is None:
+                break
+            # stream.accept_waveform() runs in the main thread
+
+            stream.accept_waveform(sample_rate, samples)
+
+            await self.compute_and_decode(stream)
+            result = stream.result.text
+            logging.info(f"result: {result}")
+
+            if result:
+                await socket.send(result)
+            else:
+                # If result is an empty string, send something to the client.
+ # Otherwise, socket.send() is a no-op and the client will + # wait for a reply indefinitely. + await socket.send("") + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: + if args.encoder: + assert len(args.paraformer) == 0, args.paraformer + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + max_active_paths=args.max_active_paths, + ) + elif args.paraformer: + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.paraformer) + + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + ) + elif args.nemo_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.nemo_ctc) + + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( + model=args.nemo_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + ) + elif args.whisper_encoder: + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + + recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + ) + else: + raise ValueError("Please specify at least one model") + + return recognizer + + +def main(): + args = get_args() + logging.info(vars(args)) + check_args(args) + + recognizer = create_recognizer(args) + + port = args.port + max_wait_ms = args.max_wait_ms + max_batch_size = args.max_batch_size + nn_pool_size = args.nn_pool_size + max_message_size = args.max_message_size + max_queue_size = args.max_queue_size + max_active_connections = args.max_active_connections + certificate = args.certificate + doc_root = args.doc_root + + if certificate and not Path(certificate).is_file(): + raise ValueError(f"{certificate} does not exist") + + if not Path(doc_root).is_dir(): + raise ValueError(f"Directory {doc_root} does not exist") + + non_streaming_server = NonStreamingServer( + recognizer=recognizer, + max_wait_ms=max_wait_ms, + max_batch_size=max_batch_size, + nn_pool_size=nn_pool_size, + max_message_size=max_message_size, + max_queue_size=max_queue_size, + max_active_connections=max_active_connections, + certificate=certificate, + doc_root=doc_root, + ) + asyncio.run(non_streaming_server.run(port)) + + +if __name__ == "__main__": + log_filename = "log/log-non-streaming-server" + 
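+    # setup_logger() appends a timestamp and ".txt" to this name and creates
+    # the parent "log" directory if it does not exist.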
    setup_logger(log_filename)
+    main()
diff --git a/python-api-examples/offline-websocket-client-decode-files-paralell.py b/python-api-examples/offline-websocket-client-decode-files-paralell.py
index d1d691a29..f97f9c44a 100755
--- a/python-api-examples/offline-websocket-client-decode-files-paralell.py
+++ b/python-api-examples/offline-websocket-client-decode-files-paralell.py
@@ -119,7 +119,13 @@ async def run(
     buf += (samples.size * 4).to_bytes(4, byteorder="little")
     buf += samples.tobytes()

-    await websocket.send(buf)
+    payload_len = 10240
+    while len(buf) > payload_len:
+        await websocket.send(buf[:payload_len])
+        buf = buf[payload_len:]
+
+    if buf:
+        await websocket.send(buf)

     decoding_results = await websocket.recv()
     logging.info(f"{wave_filename}\n{decoding_results}")
diff --git a/python-api-examples/offline-websocket-client-decode-files-sequential.py b/python-api-examples/offline-websocket-client-decode-files-sequential.py
index 935226e46..7dac1fc07 100755
--- a/python-api-examples/offline-websocket-client-decode-files-sequential.py
+++ b/python-api-examples/offline-websocket-client-decode-files-sequential.py
@@ -116,11 +116,18 @@ async def run(
     assert isinstance(sample_rate, int)
     assert samples.dtype == np.float32, samples.dtype
     assert samples.ndim == 1, samples.dim
+
     buf = sample_rate.to_bytes(4, byteorder="little")  # 4 bytes
     buf += (samples.size * 4).to_bytes(4, byteorder="little")
     buf += samples.tobytes()

-    await websocket.send(buf)
+    payload_len = 10240
+    while len(buf) > payload_len:
+        await websocket.send(buf[:payload_len])
+        buf = buf[payload_len:]
+
+    if buf:
+        await websocket.send(buf)

     decoding_results = await websocket.recv()
     print(decoding_results)
diff --git a/python-api-examples/online-websocket-client-decode-file.py b/python-api-examples/online-websocket-client-decode-file.py
index b03b99286..e0ad8d256 100755
--- a/python-api-examples/online-websocket-client-decode-file.py
+++ b/python-api-examples/online-websocket-client-decode-file.py
@@ -15,10 +15,9 @@

 (Note: You have to first start the server before starting the client)

-You can find the server at
+You can find the C++ server at
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
-
-Note: The server is implemented in C++.
+or use the Python server ./python-api-examples/streaming_server.py

 There is also a C++ version of the client. Please see
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
@@ -115,7 +114,8 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
             last_message = message
             logging.info(message)
         else:
-            return last_message
+            break
+    return last_message


 async def run(
@@ -142,6 +142,7 @@

         await websocket.send(d)

+        # Simulate streaming. You can remove the sleep if you want.
         await asyncio.sleep(seconds_per_message)  # in seconds

         start += samples_per_message
diff --git a/python-api-examples/online-websocket-client-microphone.py b/python-api-examples/online-websocket-client-microphone.py
index ab3d57335..f42dd0086 100755
--- a/python-api-examples/online-websocket-client-microphone.py
+++ b/python-api-examples/online-websocket-client-microphone.py
@@ -12,10 +12,9 @@

 (Note: You have to first start the server before starting the client)

-You can find the server at
+You can find the C++ server at
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
-
-Note: The server is implemented in C++.
+or use the Python server ./python-api-examples/streaming_server.py

 There is also a C++ version of the client. Please see
 https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py
index be79979a6..c707a70c3 100755
--- a/python-api-examples/streaming_server.py
+++ b/python-api-examples/streaming_server.py
@@ -13,11 +13,37 @@

 Example:

+(1) Without a certificate
+
 python3 ./python-api-examples/streaming_server.py \
   --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
   --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
   --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
   --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
+
+(2) With a certificate
+
+(a) Generate a certificate first:
+
+    cd python-api-examples/web
+    ./generate-certificate.py
+    cd ../..
+
+(b) Start the server:
+
+python3 ./python-api-examples/streaming_server.py \
+  --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+  --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+  --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+  --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
+  --certificate ./python-api-examples/web/cert.pem
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
+to download pre-trained models.
+
+The model used in the examples above is from
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
 """

 import argparse
@@ -35,6 +61,7 @@

 import numpy as np
 import sherpa_onnx
 import websockets
+
 from http_server import HttpServer


@@ -269,8 +296,8 @@ def get_args():
     parser.add_argument(
         "--num-threads",
         type=int,
-        default=1,
-        help="Sets the number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU.",
+        default=2,
+        help="Number of threads to run the neural network model",
     )

     parser.add_argument(
@@ -278,8 +305,10 @@
         type=str,
         help="""Path to the X.509 certificate. You need it only if you want to
         use a secure websocket connection, i.e., use wss:// instead of ws://.
-        You can use sherpa/bin/web/generate-certificate.py
+        You can use ./web/generate-certificate.py
         to generate the certificate `cert.pem`.
+        Note ./web/generate-certificate.py will generate three files, but you
+        only need to pass the generated cert.pem to this option.
""", ) @@ -287,7 +316,7 @@ def get_args(): "--doc-root", type=str, default="./python-api-examples/web", - help="""Path to the web root""", + help="Path to the web root", ) return parser.parse_args() @@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: encoder=args.encoder_model, decoder=args.decoder_model, joiner=args.joiner_model, - num_threads=1, - sample_rate=16000, - feature_dim=80, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, decoding_method=args.decoding_method, max_active_paths=args.num_active_paths, enable_endpoint_detection=args.use_endpoint != 0, @@ -359,7 +388,7 @@ def __init__( server locate. certificate: Optional. If not None, it will use secure websocket. - You can use ./sherpa/bin/web/generate-certificate.py to generate + You can use ./web/generate-certificate.py to generate it (the default generated filename is `cert.pem`). """ self.recognizer = recognizer @@ -373,6 +402,7 @@ def __init__( ) self.stream_queue = asyncio.Queue() + self.max_wait_ms = max_wait_ms self.max_batch_size = max_batch_size self.max_message_size = max_message_size @@ -382,11 +412,10 @@ def __init__( self.current_active_connections = 0 self.sample_rate = int(recognizer.config.feat_config.sampling_rate) - self.decoding_method = recognizer.config.decoding_method async def stream_consumer_task(self): """This function extracts streams from the queue, batches them up, sends - them to the RNN-T model for computation and decoding. + them to the neural network model for computation and decoding. """ while True: if self.stream_queue.empty(): @@ -442,7 +471,22 @@ async def process_request( # This is a normal HTTP request if path == "/": path = "/index.html" - found, response, mime_type = self.http_server.process_request(path) + + if path in ("/upload.html", "/offline_record.html"): + response = r""" + +Speech recognition with next-gen Kaldi +

Only /streaming_record.html is available for the streaming server.

+
+
+Go back to /streaming_record.html + +""" + found = True + mime_type = "text/html" + else: + found, response, mime_type = self.http_server.process_request(path) + if isinstance(response, str): response = response.encode("utf-8") @@ -484,12 +528,21 @@ async def run(self, port: int): process_request=self.process_request, ssl=ssl_context, ): - ip_list = ["0.0.0.0", "localhost", "127.0.0.1"] - ip_list.append(socket.gethostbyname(socket.gethostname())) + ip_list = ["localhost"] + if ssl_context: + ip_list += ["0.0.0.0", "127.0.0.1"] + ip_list.append(socket.gethostbyname(socket.gethostname())) proto = "http://" if ssl_context is None else "https://" s = "Please visit one of the following addresses:\n\n" for p in ip_list: s += " " + proto + p + f":{port}" "\n" + + if not ssl_context: + s += "\nSince you are not providing a certificate, you cannot " + s += "use your microphone from within the browser using " + s += "public IP addresses. Only localhost can be used." + s += "You also cannot use 0.0.0.0 or 127.0.0.1" + logging.info(s) await asyncio.Future() # run forever @@ -525,7 +578,7 @@ async def handle_connection_impl( socket: websockets.WebSocketServerProtocol, ): """Receive audio samples from the client, process it, and send - deocoding result back to the client. + decoding result back to the client. Args: socket: @@ -560,8 +613,6 @@ async def handle_connection_impl( self.recognizer.reset(stream) segment += 1 - print(message) - await socket.send(json.dumps(message)) tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32) @@ -583,7 +634,7 @@ async def recv_audio_samples( self, socket: websockets.WebSocketServerProtocol, ) -> Optional[np.ndarray]: - """Receives a tensor from the client. + """Receive a tensor from the client. Each message contains either a bytes buffer containing audio samples in 16 kHz or contains "Done" meaning the end of utterance. @@ -660,6 +711,6 @@ def main(): if __name__ == "__main__": - log_filename = "log/log-streaming-zipformer" + log_filename = "log/log-streaming-server" setup_logger(log_filename) main() diff --git a/python-api-examples/web/README.md b/python-api-examples/web/README.md deleted file mode 100644 index d00b8d533..000000000 --- a/python-api-examples/web/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# How to use - -```bash -git clone https://github.com/k2-fsa/sherpa - -cd sherpa/sherpa/bin/web -python3 -m http.server 6009 -``` -and then go to - -You will see a page like the following screenshot: - -![Screenshot if you visit http://localhost:6009](./pic/web-ui.png) - -If your server is listening at the port *6006* with address **localhost**, -then you can either click **Upload**, **Streaming_Record** or **Offline_Record** to play with it. 
- -## File descriptions - -### ./css/bootstrap.min.css - -It is downloaded from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css - -### ./js/jquery-3.6.0.min.js - -It is downloaded from https://code.jquery.com/jquery-3.6.0.min.js - -### ./js/popper.min.js - -It is downloaded from https://cdn.jsdelivr.net/npm/popper.js@1.14.7/dist/umd/popper.min.js - -### ./js/bootstrap.min.js - -It is download from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/js/bootstrap.min.js diff --git a/python-api-examples/web/generate-certificate.py b/python-api-examples/web/generate-certificate.py index e1364ee75..e8154bcef 100755 --- a/python-api-examples/web/generate-certificate.py +++ b/python-api-examples/web/generate-certificate.py @@ -35,8 +35,8 @@ def cert_gen( - emailAddress="https://github.com/k2-fsa/k2", - commonName="sherpa", + emailAddress="https://github.com/k2-fsa/sherpa-onnx", + commonName="sherpa-onnx", countryName="CN", localityName="k2-fsa", stateOrProvinceName="k2-fsa", @@ -70,17 +70,13 @@ def cert_gen( cert.set_pubkey(k) cert.sign(k, "sha512") with open(CERT_FILE, "wt") as f: - f.write( - crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8") - ) + f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")) with open(KEY_FILE, "wt") as f: f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) with open(ALL_IN_ONE_FILE, "wt") as f: f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) - f.write( - crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8") - ) + f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")) print(f"Generated {CERT_FILE}") print(f"Generated {KEY_FILE}") print(f"Generated {ALL_IN_ONE_FILE}") diff --git a/python-api-examples/web/index.html b/python-api-examples/web/index.html index 65b111d43..4da4e228e 100644 --- a/python-api-examples/web/index.html +++ b/python-api-examples/web/index.html @@ -53,7 +53,7 @@
Offline_Record
Code is available at - https://github.com/k2-fsa/sherpa + https://github.com/k2-fsa/sherpa-onnx diff --git a/python-api-examples/web/js/offline_record.js b/python-api-examples/web/js/offline_record.js index d1bf8fe19..294ede3e2 100644 --- a/python-api-examples/web/js/offline_record.js +++ b/python-api-examples/web/js/offline_record.js @@ -60,6 +60,7 @@ const soundClips = document.getElementById('sound-clips'); const canvas = document.getElementById('canvas'); const mainSection = document.querySelector('.container'); +recordBtn.disabled = true; stopBtn.disabled = true; window.onload = (event) => { @@ -95,9 +96,10 @@ clearBtn.onclick = function() { }; function send_header(n) { - const header = new ArrayBuffer(4); - new DataView(header).setInt32(0, n, true /* littleEndian */); - socket.send(new Int32Array(header, 0, 1)); + const header = new ArrayBuffer(8); + new DataView(header).setInt32(0, expectedSampleRate, true /* littleEndian */); + new DataView(header).setInt32(4, n, true /* littleEndian */); + socket.send(new Int32Array(header, 0, 2)); } // copied/modified from https://mdn.github.io/web-dictaphone/ diff --git a/python-api-examples/web/js/streaming_record.js b/python-api-examples/web/js/streaming_record.js index e1be94dae..1d123d864 100644 --- a/python-api-examples/web/js/streaming_record.js +++ b/python-api-examples/web/js/streaming_record.js @@ -88,6 +88,7 @@ const canvas = document.getElementById('canvas'); const mainSection = document.querySelector('.container'); stopBtn.disabled = true; +recordBtn.disabled = true; let audioCtx; const canvasCtx = canvas.getContext('2d'); diff --git a/python-api-examples/web/js/upload.js b/python-api-examples/web/js/upload.js index a6fb6e9e4..343150106 100644 --- a/python-api-examples/web/js/upload.js +++ b/python-api-examples/web/js/upload.js @@ -74,9 +74,11 @@ connectBtn.onclick = function() { }; function send_header(n) { - const header = new ArrayBuffer(4); - new DataView(header).setInt32(0, n, true /* littleEndian */); - socket.send(new Int32Array(header, 0, 1)); + const header = new ArrayBuffer(8); + // assume the uploaded wave is 16000 Hz + new DataView(header).setInt32(0, 16000, true /* littleEndian */); + new DataView(header).setInt32(4, n, true /* littleEndian */); + socket.send(new Int32Array(header, 0, 2)); } function onFileChange() { diff --git a/python-api-examples/web/offline_record.html b/python-api-examples/web/offline_record.html index ff5f6f2e9..8c17f84f9 100644 --- a/python-api-examples/web/offline_record.html +++ b/python-api-examples/web/offline_record.html @@ -33,9 +33,9 @@

Recognition from offline recordings

ws:// - + : - +
diff --git a/python-api-examples/web/streaming_record.html b/python-api-examples/web/streaming_record.html index b31fee68c..46ef82563 100644 --- a/python-api-examples/web/streaming_record.html +++ b/python-api-examples/web/streaming_record.html @@ -33,9 +33,9 @@

Recognition from real-time recordings

ws:// - + : - +
diff --git a/python-api-examples/web/upload.html b/python-api-examples/web/upload.html index a50936bb8..e1003369a 100644 --- a/python-api-examples/web/upload.html +++ b/python-api-examples/web/upload.html @@ -32,9 +32,9 @@

Recognition from a selected file

ws:// - + : - +
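----

For reference, below is a minimal standalone client that exercises the wire format used by the new non-streaming server (the same format that the updated `send_header()` in upload.js and offline_record.js now produces): a 4-byte little-endian sample rate, a 4-byte little-endian payload size, the float32 samples, and finally the string "Done". This is only an illustrative sketch, not a file from this patch: the `decode_file` helper and the one second of silence are made up for the example, and it assumes a server running on the default port 6006 as in the workflows above.

    #!/usr/bin/env python3
    # Illustrative sketch only -- not a file from this patch.
    import asyncio

    import numpy as np
    import websockets


    async def decode_file(uri: str, samples: np.ndarray, sample_rate: int):
        assert samples.dtype == np.float32
        async with websockets.connect(uri) as ws:
            # 8-byte header: sample rate, then payload size, little endian
            buf = sample_rate.to_bytes(4, byteorder="little")
            buf += (samples.size * 4).to_bytes(4, byteorder="little")
            buf += samples.tobytes()

            # Send in chunks, mirroring payload_len in the example clients
            payload_len = 10240
            while buf:
                await ws.send(buf[:payload_len])
                buf = buf[payload_len:]

            print(await ws.recv())  # the decoding result
            await ws.send("Done")   # tell the server we are finished


    if __name__ == "__main__":
        # One second of silence at 16 kHz, just to exercise the protocol
        samples = np.zeros(16000, dtype=np.float32)
        asyncio.run(decode_file("ws://localhost:6006", samples, 16000))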
diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 27c3e5494..0f1e23f52 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -1,12 +1,7 @@ from typing import Dict, List, Optional -from _sherpa_onnx import Display +from _sherpa_onnx import Display, OfflineStream, OnlineStream -from .online_recognizer import OnlineRecognizer -from .online_recognizer import OnlineStream from .offline_recognizer import OfflineRecognizer - +from .online_recognizer import OnlineRecognizer from .utils import encode_contexts - - - diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 0312c01c5..cc5b5559e 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -41,6 +41,7 @@ def from_transducer( sample_rate: int = 16000, feature_dim: int = 80, decoding_method: str = "greedy_search", + max_active_paths: int = 4, context_score: float = 1.5, debug: bool = False, provider: str = "cpu", @@ -72,6 +73,9 @@ def from_transducer( Dimension of the feature used to train the model. decoding_method: Valid values: greedy_search, modified_beam_search. + max_active_paths: + Maximum number of active paths to keep. Used only when + decoding_method is modified_beam_search. debug: True to show debug messages. provider: @@ -103,6 +107,7 @@ def from_transducer( context_score=context_score, ) self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config return self @classmethod @@ -166,6 +171,7 @@ def from_paraformer( decoding_method=decoding_method, ) self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config return self @classmethod @@ -229,6 +235,7 @@ def from_nemo_ctc( decoding_method=decoding_method, ) self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config return self @classmethod @@ -291,6 +298,7 @@ def from_whisper( decoding_method=decoding_method, ) self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config return self def create_stream(self, contexts_list: Optional[List[List[int]]] = None):