Skip to content

Commit

Permalink
Java api update for adding modelType in config class (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaomingwork authored Jul 30, 2023
1 parent daffdab commit 5a54961
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 29 deletions.
9 changes: 5 additions & 4 deletions java-api-examples/modelconfig.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ feature_dim=80
rule1_min_trailing_silence=2.4
rule2_min_trailing_silence=1.2
rule3_min_utterance_length=20
encoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx
decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx
joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
encoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx
decoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx
joiner=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
tokens=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
num_threads=4
enable_endpoint_detection=true
decoding_method=modified_beam_search
max_active_paths=4
lm_model=
lm_scale=0.5
model_type=zipformer

#websocket server config
port=8890
Expand Down
12 changes: 7 additions & 5 deletions java-api-examples/src/DecodeFile.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ public void initModelWithPara() {
float rule3MinUtteranceLength = 20F;
String decodingMethod = "greedy_search";
int maxActivePaths = 4;
String lm_model="";
float lm_scale=0.5F;
String lm_model = "";
float lm_scale = 0.5F;
String modelType = "zipformer";
rcgOjb =
new OnlineRecognizer(
tokens,
Expand All @@ -65,9 +66,10 @@ public void initModelWithPara() {
rule2MinTrailingSilence,
rule3MinUtteranceLength,
decodingMethod,
lm_model,
lm_scale,
maxActivePaths);
lm_model,
lm_scale,
maxActivePaths,
modelType);
streamObj = rcgOjb.createStream();
} catch (Exception e) {
System.err.println(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class OnlineRecognizer {
private long ptr = 0; // this is the asr engine ptrss

private int sampleRate = 16000;

// load config file for OnlineRecognizer
public OnlineRecognizer(String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
Expand All @@ -62,17 +63,20 @@ public OnlineRecognizer(String modelCfgPath) {
proMap.get("joiner").trim(),
proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()),
false);
false,
proMap.get("model_type").trim());
FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));

OnlineRecognizerConfig rcgCfg =
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));

OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
onlineLmConfig,
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
proMap.get("decoding_method").trim(),
Integer.parseInt(proMap.get("max_active_paths").trim()));
Expand Down Expand Up @@ -107,18 +111,21 @@ public OnlineRecognizer(Object assetManager, String modelCfgPath) {
proMap.get("joiner").trim(),
proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()),
false);
false,
proMap.get("model_type").trim());
FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));

OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));

OnlineRecognizerConfig rcgCfg =

OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));

OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
onlineLmConfig,
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
proMap.get("decoding_method").trim(),
Integer.parseInt(proMap.get("max_active_paths").trim()));
Expand All @@ -144,21 +151,29 @@ public OnlineRecognizer(
float rule2MinTrailingSilence,
float rule3MinUtteranceLength,
String decodingMethod,
String lm_model,
float lm_scale,
int maxActivePaths) {
String lm_model,
float lm_scale,
int maxActivePaths,
String modelType) {
this.sampleRate = sampleRate;
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineTransducerModelConfig modelCfg =
new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false);
new OnlineTransducerModelConfig(
encoder, decoder, joiner, tokens, numThreads, false, modelType);
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale);
OnlineRecognizerConfig rcgCfg =
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths);
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
enableEndpointDetection,
decodingMethod,
maxActivePaths);
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
}
Expand Down Expand Up @@ -284,9 +299,10 @@ public void release() {
public void releaseStream(OnlineStream s) {
s.release();
}

// JNI interface libsherpa-onnx-jni.so

private static native Object[] readWave(String fileName); // static
private static native Object[] readWave(String fileName); // static

private native String getResult(long ptr, long streamPtr);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig {
private final String tokens;
private final int numThreads;
private final boolean debug;
private final String provider = "cpu";
private String modelType = "";

public OnlineTransducerModelConfig(
String encoder, String decoder, String joiner, String tokens, int numThreads, boolean debug) {
String encoder,
String decoder,
String joiner,
String tokens,
int numThreads,
boolean debug,
String modelType) {
this.encoder = encoder;
this.decoder = decoder;
this.joiner = joiner;
this.tokens = tokens;
this.numThreads = numThreads;
this.debug = debug;
this.modelType = modelType;
}

public String getEncoder() {
Expand Down

0 comments on commit 5a54961

Please sign in to comment.