+
+
+Go back to /upload.html
+or /offline_record.html
+
+"""
+ found = True
+ mime_type = "text/html"
+ else:
+ found, response, mime_type = self.http_server.process_request(path)
+ if isinstance(response, str):
+ response = response.encode("utf-8")
+
+ if not found:
+ status = http.HTTPStatus.NOT_FOUND
+ else:
+ status = http.HTTPStatus.OK
+ header = {"Content-Type": mime_type}
+ return status, header, response
+
+ if self.current_active_connections < self.max_active_connections:
+ self.current_active_connections += 1
+ return None
+
+ # Refuse new connections
+ status = http.HTTPStatus.SERVICE_UNAVAILABLE # 503
+ header = {"Hint": "The server is overloaded. Please retry later."}
+ response = b"The server is busy. Please retry later."
+
+ return status, header, response
+
+ async def run(self, port: int):
+ logging.info("started")
+
+ task = asyncio.create_task(self.stream_consumer_task())
+
+ if self.certificate:
+ logging.info(f"Using certificate: {self.certificate}")
+ ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ ssl_context.load_cert_chain(self.certificate)
+ else:
+ ssl_context = None
+ logging.info("No certificate provided")
+
+ async with websockets.serve(
+ self.handle_connection,
+ host="",
+ port=port,
+ max_size=self.max_message_size,
+ max_queue=self.max_queue_size,
+ process_request=self.process_request,
+ ssl=ssl_context,
+ ):
+ ip_list = ["localhost"]
+ if ssl_context:
+ ip_list += ["0.0.0.0", "127.0.0.1"]
+ ip_list.append(socket.gethostbyname(socket.gethostname()))
+
+ proto = "http://" if ssl_context is None else "https://"
+ s = "Please visit one of the following addresses:\n\n"
+ for p in ip_list:
+ s += " " + proto + p + f":{port}" "\n"
+ logging.info(s)
+
+ await asyncio.Future() # run forever
+
+ await task # not reachable
+
+ async def recv_audio_samples(
+ self,
+ socket: websockets.WebSocketServerProtocol,
+ ) -> Tuple[Optional[np.ndarray], Optional[float]]:
+ """Receive a tensor from the client.
+
+ The message from the client is a **bytes** buffer.
+
+ The first message can be either "Done" meaning the client won't send
+ anything in the future or it can be a buffer containing 8 bytes.
+ The first 4 bytes in little endian specifies the sample
+ rate of the audio samples; the second 4 bytes in little endian specifies
+ the number of bytes in the audio file, which will be sent by the client
+ in the subsequent messages.
+ Since there is a limit in the message size posed by the websocket
+ protocol, the client may send the audio file in multiple messages if the
+ audio file is very large.
+
+ The second and remaining messages contain audio samples.
+
+ Please refer to ./offline-websocket-client-decode-files-paralell.py
+ and ./offline-websocket-client-decode-files-sequential.py
+ for how the client sends the message.
+
+ Args:
+ socket:
+ The socket for communicating with the client.
+ Returns:
+ Return a containing:
+ - 1-D np.float32 array containing the audio samples
+ - sample rate of the audio samples
+ or return (None, None) indicating the end of utterance.
+ """
+ header = await socket.recv()
+ if header == "Done":
+ return None, None
+
+ assert len(header) >= 8, (
+ "The first message should contain at least 8 bytes."
+ + f"Given {len(header)}"
+ )
+
+ sample_rate = int.from_bytes(header[:4], "little", signed=True)
+ expected_num_bytes = int.from_bytes(header[4:8], "little", signed=True)
+
+ received = []
+ num_received_bytes = 0
+ if len(header) > 8:
+ received.append(header[8:])
+ num_received_bytes += len(header) - 8
+
+ if num_received_bytes < expected_num_bytes:
+ async for message in socket:
+ received.append(message)
+ num_received_bytes += len(message)
+ if num_received_bytes >= expected_num_bytes:
+ break
+
+ assert num_received_bytes == expected_num_bytes, (
+ num_received_bytes,
+ expected_num_bytes,
+ )
+
+ samples = b"".join(received)
+ array = np.frombuffer(samples, dtype=np.float32)
+ return array, sample_rate
+
+ async def stream_consumer_task(self):
+ """This function extracts streams from the queue, batches them up, sends
+ them to the RNN-T model for computation and decoding.
+ """
+ while True:
+ if self.stream_queue.empty():
+ await asyncio.sleep(self.max_wait_ms / 1000)
+ continue
+
+ batch = []
+ try:
+ while len(batch) < self.max_batch_size:
+ item = self.stream_queue.get_nowait()
+
+ batch.append(item)
+ except asyncio.QueueEmpty:
+ pass
+ stream_list = [b[0] for b in batch]
+ future_list = [b[1] for b in batch]
+
+ loop = asyncio.get_running_loop()
+ await loop.run_in_executor(
+ self.nn_pool,
+ self.recognizer.decode_streams,
+ stream_list,
+ )
+
+ for f in future_list:
+ self.stream_queue.task_done()
+ f.set_result(None)
+
+ async def compute_and_decode(
+ self,
+ stream: sherpa_onnx.OfflineStream,
+ ) -> None:
+ """Put the stream into the queue and wait it to be processed by the
+ consumer task.
+
+ Args:
+ stream:
+ The stream to be processed. Note: It is changed in-place.
+ """
+ loop = asyncio.get_running_loop()
+ future = loop.create_future()
+ await self.stream_queue.put((stream, future))
+ await future
+
+ async def handle_connection(
+ self,
+ socket: websockets.WebSocketServerProtocol,
+ ):
+ """Receive audio samples from the client, process it, and sends
+ deocoding result back to the client.
+
+ Args:
+ socket:
+ The socket for communicating with the client.
+ """
+ try:
+ await self.handle_connection_impl(socket)
+ except websockets.exceptions.ConnectionClosedError:
+ logging.info(f"{socket.remote_address} disconnected")
+ finally:
+ # Decrement so that it can accept new connections
+ self.current_active_connections -= 1
+
+ logging.info(
+ f"Disconnected: {socket.remote_address}. "
+ f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa
+ )
+
+ async def handle_connection_impl(
+ self,
+ socket: websockets.WebSocketServerProtocol,
+ ):
+ """Receive audio samples from the client, process it, and send
+ decoding results back to the client.
+
+ Args:
+ socket:
+ The socket for communicating with the client.
+ """
+ logging.info(
+ f"Connected: {socket.remote_address}. "
+ f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa
+ )
+
+ while True:
+ stream = self.recognizer.create_stream()
+ samples, sample_rate = await self.recv_audio_samples(socket)
+ if samples is None:
+ break
+ # stream.accept_samples() runs in the main thread
+
+ stream.accept_waveform(sample_rate, samples)
+
+ await self.compute_and_decode(stream)
+ result = stream.result.text
+ logging.info(f"result: {result}")
+
+ if result:
+ await socket.send(result)
+ else:
+ # If result is an empty string, send something to the client.
+ # Otherwise, socket.send() is a no-op and the client will
+ # wait for a reply indefinitely.
+ await socket.send("")
+
+
+def assert_file_exists(filename: str):
+ assert Path(filename).is_file(), (
+ f"{filename} does not exist!\n"
+ "Please refer to "
+ "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+ )
+
+
+def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
+ if args.encoder:
+ assert len(args.paraformer) == 0, args.paraformer
+ assert len(args.nemo_ctc) == 0, args.nemo_ctc
+ assert len(args.whisper_encoder) == 0, args.whisper_encoder
+ assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
+ assert_file_exists(args.encoder)
+ assert_file_exists(args.decoder)
+ assert_file_exists(args.joiner)
+
+ recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
+ encoder=args.encoder,
+ decoder=args.decoder,
+ joiner=args.joiner,
+ tokens=args.tokens,
+ num_threads=args.num_threads,
+ sample_rate=args.sample_rate,
+ feature_dim=args.feat_dim,
+ decoding_method=args.decoding_method,
+ max_active_paths=args.max_active_paths,
+ )
+ elif args.paraformer:
+ assert len(args.nemo_ctc) == 0, args.nemo_ctc
+ assert len(args.whisper_encoder) == 0, args.whisper_encoder
+ assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
+ assert_file_exists(args.paraformer)
+
+ recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+ paraformer=args.paraformer,
+ tokens=args.tokens,
+ num_threads=args.num_threads,
+ sample_rate=args.sample_rate,
+ feature_dim=args.feat_dim,
+ decoding_method=args.decoding_method,
+ )
+ elif args.nemo_ctc:
+ assert len(args.whisper_encoder) == 0, args.whisper_encoder
+ assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
+ assert_file_exists(args.nemo_ctc)
+
+ recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
+ model=args.nemo_ctc,
+ tokens=args.tokens,
+ num_threads=args.num_threads,
+ sample_rate=args.sample_rate,
+ feature_dim=args.feat_dim,
+ decoding_method=args.decoding_method,
+ )
+ elif args.whisper_encoder:
+ assert_file_exists(args.whisper_encoder)
+ assert_file_exists(args.whisper_decoder)
+
+ recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
+ encoder=args.whisper_encoder,
+ decoder=args.whisper_decoder,
+ tokens=args.tokens,
+ num_threads=args.num_threads,
+ decoding_method=args.decoding_method,
+ )
+ else:
+ raise ValueError("Please specify at least one model")
+
+ return recognizer
+
+
+def main():
+ args = get_args()
+ logging.info(vars(args))
+ check_args(args)
+
+ recognizer = create_recognizer(args)
+
+ port = args.port
+ max_wait_ms = args.max_wait_ms
+ max_batch_size = args.max_batch_size
+ nn_pool_size = args.nn_pool_size
+ max_message_size = args.max_message_size
+ max_queue_size = args.max_queue_size
+ max_active_connections = args.max_active_connections
+ certificate = args.certificate
+ doc_root = args.doc_root
+
+ if certificate and not Path(certificate).is_file():
+ raise ValueError(f"{certificate} does not exist")
+
+ if not Path(doc_root).is_dir():
+ raise ValueError(f"Directory {doc_root} does not exist")
+
+ non_streaming_server = NonStreamingServer(
+ recognizer=recognizer,
+ max_wait_ms=max_wait_ms,
+ max_batch_size=max_batch_size,
+ nn_pool_size=nn_pool_size,
+ max_message_size=max_message_size,
+ max_queue_size=max_queue_size,
+ max_active_connections=max_active_connections,
+ certificate=certificate,
+ doc_root=doc_root,
+ )
+ asyncio.run(non_streaming_server.run(port))
+
+
+if __name__ == "__main__":
+ log_filename = "log/log-non-streaming-server"
+ setup_logger(log_filename)
+ main()
diff --git a/python-api-examples/offline-websocket-client-decode-files-paralell.py b/python-api-examples/offline-websocket-client-decode-files-paralell.py
index d1d691a29..f97f9c44a 100755
--- a/python-api-examples/offline-websocket-client-decode-files-paralell.py
+++ b/python-api-examples/offline-websocket-client-decode-files-paralell.py
@@ -119,7 +119,13 @@ async def run(
buf += (samples.size * 4).to_bytes(4, byteorder="little")
buf += samples.tobytes()
- await websocket.send(buf)
+ payload_len = 10240
+ while len(buf) > payload_len:
+ await websocket.send(buf[:payload_len])
+ buf = buf[payload_len:]
+
+ if buf:
+ await websocket.send(buf)
decoding_results = await websocket.recv()
logging.info(f"{wave_filename}\n{decoding_results}")
diff --git a/python-api-examples/offline-websocket-client-decode-files-sequential.py b/python-api-examples/offline-websocket-client-decode-files-sequential.py
index 935226e46..7dac1fc07 100755
--- a/python-api-examples/offline-websocket-client-decode-files-sequential.py
+++ b/python-api-examples/offline-websocket-client-decode-files-sequential.py
@@ -116,11 +116,18 @@ async def run(
assert isinstance(sample_rate, int)
assert samples.dtype == np.float32, samples.dtype
assert samples.ndim == 1, samples.dim
+
buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes
buf += (samples.size * 4).to_bytes(4, byteorder="little")
buf += samples.tobytes()
- await websocket.send(buf)
+ payload_len = 10240
+ while len(buf) > payload_len:
+ await websocket.send(buf[:payload_len])
+ buf = buf[payload_len:]
+
+ if buf:
+ await websocket.send(buf)
decoding_results = await websocket.recv()
print(decoding_results)
diff --git a/python-api-examples/online-websocket-client-decode-file.py b/python-api-examples/online-websocket-client-decode-file.py
index b03b99286..e0ad8d256 100755
--- a/python-api-examples/online-websocket-client-decode-file.py
+++ b/python-api-examples/online-websocket-client-decode-file.py
@@ -15,10 +15,9 @@
(Note: You have to first start the server before starting the client)
-You can find the server at
+You can find the c++ server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
-
-Note: The server is implemented in C++.
+or use the python server ./python-api-examples/streaming_server.py
There is also a C++ version of the client. Please see
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
@@ -115,7 +114,8 @@ async def receive_results(socket: websockets.WebSocketServerProtocol):
last_message = message
logging.info(message)
else:
- return last_message
+ break
+ return last_message
async def run(
@@ -142,6 +142,7 @@ async def run(
await websocket.send(d)
+ # Simulate streaming. You can remove the sleep if you want
await asyncio.sleep(seconds_per_message) # in seconds
start += samples_per_message
diff --git a/python-api-examples/online-websocket-client-microphone.py b/python-api-examples/online-websocket-client-microphone.py
index ab3d57335..f42dd0086 100755
--- a/python-api-examples/online-websocket-client-microphone.py
+++ b/python-api-examples/online-websocket-client-microphone.py
@@ -12,10 +12,9 @@
(Note: You have to first start the server before starting the client)
-You can find the server at
+You can find the C++ server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
-
-Note: The server is implemented in C++.
+or use the python server ./python-api-examples/streaming_server.py
There is also a C++ version of the client. Please see
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py
index be79979a6..c707a70c3 100755
--- a/python-api-examples/streaming_server.py
+++ b/python-api-examples/streaming_server.py
@@ -13,11 +13,37 @@
Example:
+(1) Without a certificate
+
python3 ./python-api-examples/streaming_server.py \
--encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
--decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
--joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
--tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
+
+(2) With a certificate
+
+(a) Generate a certificate first:
+
+ cd python-api-examples/web
+ ./generate-certificate.py
+ cd ../..
+
+(b) Start the server
+
+python3 ./python-api-examples/streaming_server.py \
+ --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+ --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+ --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+ --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
+ --certificate ./python-api-examples/web/cert.pem
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
+to download pre-trained models.
+
+The model in the above help messages is from
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
"""
import argparse
@@ -35,6 +61,7 @@
import numpy as np
import sherpa_onnx
import websockets
+
from http_server import HttpServer
@@ -269,8 +296,8 @@ def get_args():
parser.add_argument(
"--num-threads",
type=int,
- default=1,
- help="Sets the number of threads used for interop parallelism (e.g. in JIT interpreter) on CPU.",
+ default=2,
+ help="Number of threads to run the neural network model",
)
parser.add_argument(
@@ -278,8 +305,10 @@ def get_args():
type=str,
help="""Path to the X.509 certificate. You need it only if you want to
use a secure websocket connection, i.e., use wss:// instead of ws://.
- You can use sherpa/bin/web/generate-certificate.py
+ You can use ./web/generate-certificate.py
to generate the certificate `cert.pem`.
+ Note ./web/generate-certificate.py will generate three files but you
+ only need to pass the generated cert.pem to this option.
""",
)
@@ -287,7 +316,7 @@ def get_args():
"--doc-root",
type=str,
default="./python-api-examples/web",
- help="""Path to the web root""",
+ help="Path to the web root",
)
return parser.parse_args()
@@ -299,9 +328,9 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
encoder=args.encoder_model,
decoder=args.decoder_model,
joiner=args.joiner_model,
- num_threads=1,
- sample_rate=16000,
- feature_dim=80,
+ num_threads=args.num_threads,
+ sample_rate=args.sample_rate,
+ feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
max_active_paths=args.num_active_paths,
enable_endpoint_detection=args.use_endpoint != 0,
@@ -359,7 +388,7 @@ def __init__(
server locate.
certificate:
Optional. If not None, it will use secure websocket.
- You can use ./sherpa/bin/web/generate-certificate.py to generate
+ You can use ./web/generate-certificate.py to generate
it (the default generated filename is `cert.pem`).
"""
self.recognizer = recognizer
@@ -373,6 +402,7 @@ def __init__(
)
self.stream_queue = asyncio.Queue()
+
self.max_wait_ms = max_wait_ms
self.max_batch_size = max_batch_size
self.max_message_size = max_message_size
@@ -382,11 +412,10 @@ def __init__(
self.current_active_connections = 0
self.sample_rate = int(recognizer.config.feat_config.sampling_rate)
- self.decoding_method = recognizer.config.decoding_method
async def stream_consumer_task(self):
"""This function extracts streams from the queue, batches them up, sends
- them to the RNN-T model for computation and decoding.
+ them to the neural network model for computation and decoding.
"""
while True:
if self.stream_queue.empty():
@@ -442,7 +471,22 @@ async def process_request(
# This is a normal HTTP request
if path == "/":
path = "/index.html"
- found, response, mime_type = self.http_server.process_request(path)
+
+ if path in ("/upload.html", "/offline_record.html"):
+ response = r"""
+
+Speech recognition with next-gen Kaldi
+
Only /streaming_record.html is available for the streaming server.
+
+
+Go back to /streaming_record.html
+
+"""
+ found = True
+ mime_type = "text/html"
+ else:
+ found, response, mime_type = self.http_server.process_request(path)
+
if isinstance(response, str):
response = response.encode("utf-8")
@@ -484,12 +528,21 @@ async def run(self, port: int):
process_request=self.process_request,
ssl=ssl_context,
):
- ip_list = ["0.0.0.0", "localhost", "127.0.0.1"]
- ip_list.append(socket.gethostbyname(socket.gethostname()))
+ ip_list = ["localhost"]
+ if ssl_context:
+ ip_list += ["0.0.0.0", "127.0.0.1"]
+ ip_list.append(socket.gethostbyname(socket.gethostname()))
proto = "http://" if ssl_context is None else "https://"
s = "Please visit one of the following addresses:\n\n"
for p in ip_list:
s += " " + proto + p + f":{port}" "\n"
+
+ if not ssl_context:
+ s += "\nSince you are not providing a certificate, you cannot "
+ s += "use your microphone from within the browser using "
+ s += "public IP addresses. Only localhost can be used."
+ s += "You also cannot use 0.0.0.0 or 127.0.0.1"
+
logging.info(s)
await asyncio.Future() # run forever
@@ -525,7 +578,7 @@ async def handle_connection_impl(
socket: websockets.WebSocketServerProtocol,
):
"""Receive audio samples from the client, process it, and send
- deocoding result back to the client.
+ decoding result back to the client.
Args:
socket:
@@ -560,8 +613,6 @@ async def handle_connection_impl(
self.recognizer.reset(stream)
segment += 1
- print(message)
-
await socket.send(json.dumps(message))
tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32)
@@ -583,7 +634,7 @@ async def recv_audio_samples(
self,
socket: websockets.WebSocketServerProtocol,
) -> Optional[np.ndarray]:
- """Receives a tensor from the client.
+ """Receive a tensor from the client.
Each message contains either a bytes buffer containing audio samples
in 16 kHz or contains "Done" meaning the end of utterance.
@@ -660,6 +711,6 @@ def main():
if __name__ == "__main__":
- log_filename = "log/log-streaming-zipformer"
+ log_filename = "log/log-streaming-server"
setup_logger(log_filename)
main()
diff --git a/python-api-examples/web/README.md b/python-api-examples/web/README.md
deleted file mode 100644
index d00b8d533..000000000
--- a/python-api-examples/web/README.md
+++ /dev/null
@@ -1,34 +0,0 @@
-# How to use
-
-```bash
-git clone https://github.com/k2-fsa/sherpa
-
-cd sherpa/sherpa/bin/web
-python3 -m http.server 6009
-```
-and then go to
-
-You will see a page like the following screenshot:
-
-![Screenshot if you visit http://localhost:6009](./pic/web-ui.png)
-
-If your server is listening at the port *6006* with address **localhost**,
-then you can either click **Upload**, **Streaming_Record** or **Offline_Record** to play with it.
-
-## File descriptions
-
-### ./css/bootstrap.min.css
-
-It is downloaded from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/css/bootstrap.min.css
-
-### ./js/jquery-3.6.0.min.js
-
-It is downloaded from https://code.jquery.com/jquery-3.6.0.min.js
-
-### ./js/popper.min.js
-
-It is downloaded from https://cdn.jsdelivr.net/npm/popper.js@1.14.7/dist/umd/popper.min.js
-
-### ./js/bootstrap.min.js
-
-It is download from https://cdn.jsdelivr.net/npm/bootstrap@4.3.1/dist/js/bootstrap.min.js
diff --git a/python-api-examples/web/generate-certificate.py b/python-api-examples/web/generate-certificate.py
index e1364ee75..e8154bcef 100755
--- a/python-api-examples/web/generate-certificate.py
+++ b/python-api-examples/web/generate-certificate.py
@@ -35,8 +35,8 @@
def cert_gen(
- emailAddress="https://github.com/k2-fsa/k2",
- commonName="sherpa",
+ emailAddress="https://github.com/k2-fsa/sherpa-onnx",
+ commonName="sherpa-onnx",
countryName="CN",
localityName="k2-fsa",
stateOrProvinceName="k2-fsa",
@@ -70,17 +70,13 @@ def cert_gen(
cert.set_pubkey(k)
cert.sign(k, "sha512")
with open(CERT_FILE, "wt") as f:
- f.write(
- crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")
- )
+ f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
with open(KEY_FILE, "wt") as f:
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
with open(ALL_IN_ONE_FILE, "wt") as f:
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
- f.write(
- crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")
- )
+ f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
print(f"Generated {CERT_FILE}")
print(f"Generated {KEY_FILE}")
print(f"Generated {ALL_IN_ONE_FILE}")
diff --git a/python-api-examples/web/index.html b/python-api-examples/web/index.html
index 65b111d43..4da4e228e 100644
--- a/python-api-examples/web/index.html
+++ b/python-api-examples/web/index.html
@@ -53,7 +53,7 @@
Offline_Record
Code is available at
- https://github.com/k2-fsa/sherpa
+ https://github.com/k2-fsa/sherpa-onnx
diff --git a/python-api-examples/web/js/offline_record.js b/python-api-examples/web/js/offline_record.js
index d1bf8fe19..294ede3e2 100644
--- a/python-api-examples/web/js/offline_record.js
+++ b/python-api-examples/web/js/offline_record.js
@@ -60,6 +60,7 @@ const soundClips = document.getElementById('sound-clips');
const canvas = document.getElementById('canvas');
const mainSection = document.querySelector('.container');
+recordBtn.disabled = true;
stopBtn.disabled = true;
window.onload = (event) => {
@@ -95,9 +96,10 @@ clearBtn.onclick = function() {
};
function send_header(n) {
- const header = new ArrayBuffer(4);
- new DataView(header).setInt32(0, n, true /* littleEndian */);
- socket.send(new Int32Array(header, 0, 1));
+ const header = new ArrayBuffer(8);
+ new DataView(header).setInt32(0, expectedSampleRate, true /* littleEndian */);
+ new DataView(header).setInt32(4, n, true /* littleEndian */);
+ socket.send(new Int32Array(header, 0, 2));
}
// copied/modified from https://mdn.github.io/web-dictaphone/
diff --git a/python-api-examples/web/js/streaming_record.js b/python-api-examples/web/js/streaming_record.js
index e1be94dae..1d123d864 100644
--- a/python-api-examples/web/js/streaming_record.js
+++ b/python-api-examples/web/js/streaming_record.js
@@ -88,6 +88,7 @@ const canvas = document.getElementById('canvas');
const mainSection = document.querySelector('.container');
stopBtn.disabled = true;
+recordBtn.disabled = true;
let audioCtx;
const canvasCtx = canvas.getContext('2d');
diff --git a/python-api-examples/web/js/upload.js b/python-api-examples/web/js/upload.js
index a6fb6e9e4..343150106 100644
--- a/python-api-examples/web/js/upload.js
+++ b/python-api-examples/web/js/upload.js
@@ -74,9 +74,11 @@ connectBtn.onclick = function() {
};
function send_header(n) {
- const header = new ArrayBuffer(4);
- new DataView(header).setInt32(0, n, true /* littleEndian */);
- socket.send(new Int32Array(header, 0, 1));
+ const header = new ArrayBuffer(8);
+ // assume the uploaded wave is 16000 Hz
+ new DataView(header).setInt32(0, 16000, true /* littleEndian */);
+ new DataView(header).setInt32(4, n, true /* littleEndian */);
+ socket.send(new Int32Array(header, 0, 2));
}
function onFileChange() {
diff --git a/python-api-examples/web/offline_record.html b/python-api-examples/web/offline_record.html
index ff5f6f2e9..8c17f84f9 100644
--- a/python-api-examples/web/offline_record.html
+++ b/python-api-examples/web/offline_record.html
@@ -33,9 +33,9 @@