Support TDNN models from the yesno recipe from icefall (#262)
csukuangfj authored Aug 12, 2023
1 parent b094868 commit a4bff28
Showing 23 changed files with 613 additions and 37 deletions.
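
The central API addition in this commit is a from_tdnn_ctc constructor on sherpa_onnx.OfflineRecognizer, which the Python examples below wire up. As a quick orientation (not part of the diff), the sketch that follows shows how the new constructor could be exercised on one of the yesno test waves. The constructor arguments come from this commit; the wave loading, the decode_streams call, and the local clone path ./sherpa-onnx-tdnn-yesno are assumptions modeled on the existing offline examples.

import wave

import numpy as np
import sherpa_onnx

# Assumed local clone of https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
repo = "./sherpa-onnx-tdnn-yesno"

# Constructor arguments introduced by this commit: the yesno recipe uses
# 8 kHz audio and 23-dimensional fbank features.
recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
    model=f"{repo}/model-epoch-14-avg-2.onnx",
    tokens=f"{repo}/tokens.txt",
    sample_rate=8000,
    feature_dim=23,
    num_threads=1,
    decoding_method="greedy_search",
)

# Read one 8 kHz, 16-bit mono test wave and scale it to float32 in [-1, 1].
with wave.open(f"{repo}/test_wavs/0_0_0_1_0_0_0_1.wav") as f:
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768

stream = recognizer.create_stream()
stream.accept_waveform(8000, samples)
recognizer.decode_streams([stream])
print(stream.result.text)
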
44 changes: 44 additions & 0 deletions .github/scripts/test-offline-ctc.sh
@@ -13,6 +13,50 @@ echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run tdnn yesno (Hebrew)"
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"

GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd

log "test float32 models"
time $EXE \
--sample-rate=8000 \
--feat-dim=23 \
\
--tokens=$repo/tokens.txt \
--tdnn-model=$repo/model-epoch-14-avg-2.onnx \
$repo/test_wavs/0_0_0_1_0_0_0_1.wav \
$repo/test_wavs/0_0_1_0_0_0_1_0.wav \
$repo/test_wavs/0_0_1_0_0_1_1_1.wav \
$repo/test_wavs/0_0_1_0_1_0_0_1.wav \
$repo/test_wavs/0_0_1_1_0_0_0_1.wav \
$repo/test_wavs/0_0_1_1_0_1_1_0.wav

log "test int8 models"
time $EXE \
--sample-rate=8000 \
--feat-dim=23 \
\
--tokens=$repo/tokens.txt \
--tdnn-model=$repo/model-epoch-14-avg-2.int8.onnx \
$repo/test_wavs/0_0_0_1_0_0_0_1.wav \
$repo/test_wavs/0_0_1_0_0_0_1_0.wav \
$repo/test_wavs/0_0_1_0_0_1_1_1.wav \
$repo/test_wavs/0_0_1_0_1_0_0_1.wav \
$repo/test_wavs/0_0_1_1_0_0_0_1.wav \
$repo/test_wavs/0_0_1_1_0_1_1_0.wav

rm -rf $repo

log "------------------------------------------------------------"
log "Run Citrinet (stt_en_citrinet_512, English)"
log "------------------------------------------------------------"
40 changes: 39 additions & 1 deletion .github/workflows/test-python-offline-websocket-server.yaml
@@ -24,7 +24,7 @@ jobs:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
model_type: ["transducer", "paraformer", "nemo_ctc", "whisper"]
model_type: ["transducer", "paraformer", "nemo_ctc", "whisper", "tdnn"]

steps:
- uses: actions/checkout@v2
@@ -172,3 +172,41 @@ jobs:
./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav \
./sherpa-onnx-whisper-tiny.en/test_wavs/1.wav \
./sherpa-onnx-whisper-tiny.en/test_wavs/8k.wav
- name: Start server for tdnn models
if: matrix.model_type == 'tdnn'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
cd sherpa-onnx-tdnn-yesno
git lfs pull --include "*.onnx"
cd ..
python3 ./python-api-examples/non_streaming_server.py \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
--sample-rate=8000 \
--feat-dim=23 &
echo "sleep 10 seconds to wait the server start"
sleep 10
- name: Start client for tdnn models
if: matrix.model_type == 'tdnn'
shell: bash
run: |
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav
python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_1_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_1_0_1_1_0.wav
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.7.2")
set(SHERPA_ONNX_VERSION "1.7.3")

# Disable warning about
#
39 changes: 39 additions & 0 deletions python-api-examples/non_streaming_server.py
@@ -71,6 +71,20 @@
--whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
--tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
(5) Use a tdnn model of the yesno recipe from icefall
cd /path/to/sherpa-onnx
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
cd sherpa-onnx-tdnn-yesno
git lfs pull --include "*.onnx"
python3 ./python-api-examples/non_streaming_server.py \
--sample-rate=8000 \
--feat-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt
----
To use a certificate so that you can use https, please use
@@ -196,6 +210,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
)


def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)


def add_whisper_model_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--whisper-encoder",
@@ -216,6 +239,7 @@ def add_model_args(parser: argparse.ArgumentParser):
add_transducer_model_args(parser)
add_paraformer_model_args(parser)
add_nemo_ctc_model_args(parser)
add_tdnn_ctc_model_args(parser)
add_whisper_model_args(parser)

parser.add_argument(
@@ -730,6 +754,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
@@ -750,6 +775,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.paraformer)

@@ -764,6 +790,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.nemo_ctc)

@@ -776,6 +803,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
decoding_method=args.decoding_method,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)

@@ -786,6 +814,17 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)

recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
model=args.tdnn_model,
tokens=args.tokens,
sample_rate=args.sample_rate,
feature_dim=args.feat_dim,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
)
else:
raise ValueError("Please specify at least one model")

40 changes: 39 additions & 1 deletion python-api-examples/offline-decode-files.py
@@ -8,6 +8,7 @@
file(s) with a non-streaming model.
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
@@ -20,6 +21,7 @@
/path/to/1.wav
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
@@ -56,9 +58,20 @@
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
--feature-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
@@ -159,6 +172,13 @@ def get_args():
help="Path to the model.onnx from NeMo CTC",
)

parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)

parser.add_argument(
"--num-threads",
type=int,
@@ -285,6 +305,7 @@ def main():
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
@@ -311,6 +332,7 @@ def main():
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.paraformer)

@@ -326,6 +348,7 @@ def main():
elif args.nemo_ctc:
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model

assert_file_exists(args.nemo_ctc)

@@ -339,6 +362,7 @@ def main():
debug=args.debug,
)
elif args.whisper_encoder:
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.whisper_encoder)
assert_file_exists(args.whisper_decoder)

@@ -347,6 +371,20 @@ def main():
decoder=args.whisper_decoder,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.tdnn_model:
assert_file_exists(args.tdnn_model)

recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
model=args.tdnn_model,
tokens=args.tokens,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
num_threads=args.num_threads,
decoding_method=args.decoding_method,
debug=args.debug,
)
23 changes: 10 additions & 13 deletions python-api-examples/web/js/upload.js
@@ -97,20 +97,18 @@ function onFileChange() {
console.log('file.type ' + file.type);
console.log('file.size ' + file.size);

let audioCtx = new AudioContext({sampleRate: 16000});

let reader = new FileReader();
reader.onload = function() {
console.log('reading file!');
let view = new Int16Array(reader.result);
// we assume the input file is a wav file.
// TODO: add some checks here.
let int16_samples = view.subarray(22); // header has 44 bytes == 22 shorts
let num_samples = int16_samples.length;
let float32_samples = new Float32Array(num_samples);
console.log('num_samples ' + num_samples)

for (let i = 0; i < num_samples; ++i) {
float32_samples[i] = int16_samples[i] / 32768.
}
audioCtx.decodeAudioData(reader.result, decodedDone);
};

function decodedDone(decoded) {
let typedArray = new Float32Array(decoded.length);
let float32_samples = decoded.getChannelData(0);
let buf = float32_samples.buffer

// Send 1024 audio samples per request.
//
@@ -119,14 +117,13 @@
// (2) There is a limit on the number of bytes in the payload that can be
// sent by websocket, which is 1MB, I think. We can send a large
// audio file for decoding in this approach.
let buf = float32_samples.buffer
let n = 1024 * 4; // send this number of bytes per request.
console.log('buf length, ' + buf.byteLength);
send_header(buf.byteLength);
for (let start = 0; start < buf.byteLength; start += n) {
socket.send(buf.slice(start, start + n));
}
};
}

reader.readAsArrayBuffer(file);
}
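
The rewritten upload.js above now decodes the selected file with the Web Audio API and streams the resulting float32 samples in fixed-size websocket messages, for the reasons given in its comments. For readers following along in Python, here is a minimal sketch of the same chunked-send pattern. The websockets package, the server URI, and the header layout are assumptions; the JS send_header() body is not shown in this diff, so the int32 size header below is illustrative only.

import asyncio
import struct

import numpy as np
import websockets  # assumed third-party dependency: pip install websockets


async def send_in_chunks(uri: str, samples: np.ndarray, chunk_bytes: int = 4096) -> None:
    # Raw float32 bytes, mirroring float32_samples.buffer in the JS code.
    buf = samples.astype(np.float32).tobytes()
    async with websockets.connect(uri) as ws:
        # Assumed header: total payload size as a little-endian int32, standing in
        # for the JS send_header(buf.byteLength) call.
        await ws.send(struct.pack("<i", len(buf)))
        # Send the samples in fixed-size chunks, as the JS loop does with buf.slice().
        for start in range(0, len(buf), chunk_bytes):
            await ws.send(buf[start:start + chunk_bytes])


if __name__ == "__main__":
    # One second of silence at 8 kHz just to exercise the send path;
    # the server address is an assumption.
    asyncio.run(send_in_chunks("ws://localhost:6006", np.zeros(8000, dtype=np.float32)))
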
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
@@ -32,6 +32,8 @@ set(sources
offline-recognizer.cc
offline-rnn-lm.cc
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-model-config.cc
offline-transducer-model.cc