From 8df1130a2a825454de7caee42cefa23651383df2 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Fri, 20 Sep 2024 19:27:41 -0700 Subject: [PATCH 01/14] Add a streaming audio input endpoint Outline handling of client audio stream --- neon_hana/app/routers/node_server.py | 21 +++++++++++++++++++++ neon_hana/mq_websocket_api.py | 4 ++++ 2 files changed, 25 insertions(+) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index 33fec4a..a56473d 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -67,6 +67,27 @@ async def node_v1_endpoint(websocket: WebSocket, token: str): disconnect_event.set() +@node_route.websocket("/v1/stream") +async def node_v1_endpoint(websocket: WebSocket, token: str): + """ + Endpoint to handle a stream of raw audio bytes. A client using this endpoint + must first establish a connection to the `/v1` endpoint. + """ + client_id = client_manager.get_client_id(token) + if not socket_api.get_session(client_id): + raise HTTPException(status_code=401, + detail=f"Client not known ({client_id})") + await websocket.accept() + disconnect_event = Event() + + while not disconnect_event.is_set(): + try: + client_in: bytes = await websocket.receive_bytes() + socket_api.handle_audio_stream(client_in, client_id) + except WebSocketDisconnect: + disconnect_event.set() + + @node_route.get("/v1/doc") async def node_v1_doc(_: Optional[Union[NodeAudioInput, NodeGetStt, NodeGetTts]]) -> \ diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 2c65af1..3390e42 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -114,6 +114,10 @@ def _update_session_data(self, message: Message): if user_config: self._sessions[session_id]['user'] = user_config + def handle_audio_stream(self, audio: bytes, session_id: str): + LOG.info(f"Got {len(audio)} bytes from {session_id}") + # TODO: Do something with this audio stream + def handle_client_input(self, data: dict, session_id: str): """ Handle some client input data. From c9254c55ff7cfeb879352311129c62475756298f Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 23 Sep 2024 12:07:08 -0700 Subject: [PATCH 02/14] Implement RemoteStreamHandler to consume input audio chunks Lazy init streaming when clients connect to the endpoint TODO note client cleanup upon disconnection --- neon_hana/app/routers/node_server.py | 3 +- neon_hana/mq_websocket_api.py | 107 ++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index a56473d..c17d9b3 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -65,10 +65,11 @@ async def node_v1_endpoint(websocket: WebSocket, token: str): socket_api.handle_client_input(client_in, client_id) except WebSocketDisconnect: disconnect_event.set() + # TODO: Delete client from socket_api @node_route.websocket("/v1/stream") -async def node_v1_endpoint(websocket: WebSocket, token: str): +async def node_v1_stream_endpoint(websocket: WebSocket, token: str): """ Endpoint to handle a stream of raw audio bytes. A client using this endpoint must first establish a connection to the `/v1` endpoint. diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 3390e42..6e87323 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -25,13 +25,26 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from asyncio import run, get_event_loop +from base64 import b64encode from os import makedirs +from queue import Queue from time import time +from typing import Optional, Callable + from fastapi import WebSocket +from mock.mock import Mock from neon_iris.client import NeonAIClient from ovos_bus_client.message import Message -from threading import RLock +from threading import RLock, Thread + +from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop +from ovos_dinkum_listener.voice_loop.hotwords import HotwordContainer +from ovos_dinkum_listener.voice_loop.voice_loop import ChunkInfo +from ovos_plugin_manager.templates.microphone import Microphone +from ovos_plugin_manager.vad import OVOSVADFactory from ovos_utils import LOG +from ovos_utils.fakebus import FakeBus +from speech_recognition import AudioData class MQWebsocketAPI(NeonAIClient): @@ -115,8 +128,25 @@ def _update_session_data(self, message: Message): self._sessions[session_id]['user'] = user_config def handle_audio_stream(self, audio: bytes, session_id: str): - LOG.info(f"Got {len(audio)} bytes from {session_id}") - # TODO: Do something with this audio stream + if not self._sessions[session_id].get('stream'): + LOG.info(f"starting stream for session {session_id}") + audio_queue = Queue() + stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id, + audio_callback=self.handle_client_input, + ww_callback=self.handle_ww_detected) + self._sessions[session_id]['stream'] = stream + try: + stream.start() + except RuntimeError: + pass + + self._sessions[session_id]['stream'].mic.queue.put(audio) + + def handle_ww_detected(self, ww_context: dict, session_id: str): + session = self.get_session(session_id) + message = Message("neon.ww_detected", ww_context, + {"session": session}) + run(self.send_to_client(message)) def handle_client_input(self, data: dict, session_id: str): """ @@ -205,3 +235,74 @@ def shutdown(self, *_, **__): loop.call_soon_threadsafe(loop.stop) LOG.info("Stopped Event Loop") super().shutdown() + + +class StreamMicrophone(Microphone): + def __init__(self, queue: Queue): + self.queue = queue + + def start(self): + pass + + def stop(self): + pass + + def read_chunk(self) -> Optional[bytes]: + return self.queue.get() + + +class RemoteStreamHandler(Thread): + def __init__(self, mic: StreamMicrophone, session_id: str, + audio_callback: Callable, + ww_callback: Callable, lang: str = "en-us"): + Thread.__init__(self) + self.session_id = session_id + self.ww_callback = ww_callback + self.audio_callback = audio_callback + self.bus = FakeBus() + self.mic = mic + self.lang = lang + self.hotwords = HotwordContainer(self.bus) + self.hotwords.load_hotword_engines() + self.vad = OVOSVADFactory.create() + self.voice_loop = DinkumVoiceLoop(mic=self.mic, + vad=self.vad, + hotwords=self.hotwords, + listenword_audio_callback=self.on_hotword, + hotword_audio_callback=self.on_hotword, + stopword_audio_callback=self.on_hotword, + wakeupword_audio_callback=self.on_hotword, + stt_audio_callback=self.on_audio, + stt=Mock(), + fallback_stt=Mock(), + transformers=MockTransformers(), + chunk_callback=self.on_chunk, + speech_seconds=0.5, + num_hotword_keep_chunks=0, + num_stt_rewind_chunks=0) + + def run(self): + self.voice_loop.start() + self.voice_loop.run() + + def on_hotword(self, audio_bytes: bytes, context: dict): + self.lang = context.get("stt_lang") or self.lang + LOG.info(f"Hotword: {context}") + self.ww_callback(context, self.session_id) + + def on_audio(self, audio_bytes: bytes, context: dict): + LOG.info(f"Audio: {context}") + audio_data = AudioData(audio_bytes, self.mic.sample_rate, + self.mic.sample_width).get_wav_data() + audio_data = b64encode(audio_data).decode("utf-8") + callback_data = {"type": "neon.audio_input", + "data": {"audio_data": audio_data, "lang": self.lang}} + self.audio_callback(callback_data, self.session_id) + + def on_chunk(self, chunk: ChunkInfo): + LOG.debug(f"Chunk: {chunk}") + + +class MockTransformers(Mock): + def transform(self, chunk): + return chunk, dict() From e13511914bc154cfde8fbcb0561de3b26137735b Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 23 Sep 2024 15:47:12 -0700 Subject: [PATCH 03/14] Cleanup after clients upon disconnection --- neon_hana/app/routers/node_server.py | 2 +- neon_hana/mq_websocket_api.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index c17d9b3..6badae3 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -65,7 +65,7 @@ async def node_v1_endpoint(websocket: WebSocket, token: str): socket_api.handle_client_input(client_in, client_id) except WebSocketDisconnect: disconnect_event.set() - # TODO: Delete client from socket_api + socket_api.end_session(session_id=client_id) @node_route.websocket("/v1/stream") diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 6e87323..c4e285d 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -71,6 +71,20 @@ def new_connection(self, ws: WebSocket, session_id: str): "socket": ws, "user": self.user_config} + def end_session(self, session_id: str): + """ + End a client connection upon WS disconnection + """ + session: Optional[dict] = self._sessions.pop(session_id, None) + if not session: + LOG.error(f"Ended session is not established {session_id}") + return + stream: RemoteStreamHandler = session.get('stream') + if stream: + stream.shutdown() + stream.join() + LOG.info(f"Ended stream handler for: {session_id}") + def get_session(self, session_id: str) -> dict: """ Get the latest session context for the given session_id. @@ -245,7 +259,7 @@ def start(self): pass def stop(self): - pass + self.queue.put(None) def read_chunk(self) -> Optional[bytes]: return self.queue.get() @@ -302,6 +316,10 @@ def on_audio(self, audio_bytes: bytes, context: dict): def on_chunk(self, chunk: ChunkInfo): LOG.debug(f"Chunk: {chunk}") + def shutdown(self): + self.mic.stop() + self.voice_loop.stop() + class MockTransformers(Mock): def transform(self, chunk): From 69553e4bd9a608d6a8564d317b3c7e195a2cea88 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Tue, 24 Sep 2024 09:07:13 -0700 Subject: [PATCH 04/14] Update docker config to enable webrtcvad for Node server streaming audio support Update mocked methods for compat with dinkum 0.1.0+ Add websocket dependencies for streaming client Add apt dependencies to Dockerfile for Python module builds --- Dockerfile | 3 +++ docker_overlay/etc/neon/diana.yaml | 4 +++- neon_hana/mq_websocket_api.py | 4 ++-- requirements/websocket.txt | 7 ++++++- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index a4e5665..c3d62ca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,10 +7,13 @@ ENV OVOS_CONFIG_BASE_FOLDER neon ENV OVOS_CONFIG_FILENAME diana.yaml ENV XDG_CONFIG_HOME /config +RUN apt update && apt install -y swig gcc libpulse-dev + COPY docker_overlay/ / WORKDIR /app COPY . /app + RUN pip install /app[websocket] CMD ["python3", "/app/neon_hana/app/__main__.py"] \ No newline at end of file diff --git a/docker_overlay/etc/neon/diana.yaml b/docker_overlay/etc/neon/diana.yaml index c8fa873..71b0990 100644 --- a/docker_overlay/etc/neon/diana.yaml +++ b/docker_overlay/etc/neon/diana.yaml @@ -28,4 +28,6 @@ hana: fastapi_summary: "HANA (HTTP API for Neon Applications) is the HTTP component of the Device Independent API for Neon Applications (DIANA)" stt_max_length_encoded: 500000 tts_max_words: 128 - enable_email: False \ No newline at end of file + enable_email: False +vad: + module: ovos-vad-plugin-webrtcvad \ No newline at end of file diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index c4e285d..40d8531 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -287,8 +287,8 @@ def __init__(self, mic: StreamMicrophone, session_id: str, stopword_audio_callback=self.on_hotword, wakeupword_audio_callback=self.on_hotword, stt_audio_callback=self.on_audio, - stt=Mock(), - fallback_stt=Mock(), + stt=Mock(transcribe=Mock(return_value=[])), + fallback_stt=Mock(transcribe=Mock(return_value=[])), transformers=MockTransformers(), chunk_callback=self.on_chunk, speech_seconds=0.5, diff --git a/requirements/websocket.txt b/requirements/websocket.txt index 99b6eb0..c9957af 100644 --- a/requirements/websocket.txt +++ b/requirements/websocket.txt @@ -1,2 +1,7 @@ neon-iris~=0.1,>=0.1.1a5 -websockets~=12.0 \ No newline at end of file +websockets~=12.0 +mock~=5.0 +ovos-dinkum-listener~=0.1 +ovos-vad-plugin-webrtcvad +#ovos-ww-plugin-pocketsphinx +ovos-ww-plugin-vosk \ No newline at end of file From 270479b396cff14abcaac67ec622c53931e52752 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Thu, 3 Oct 2024 17:18:44 -0700 Subject: [PATCH 05/14] Add streaming dependencies to unit tests Separate streaming dependencies from basic WS Refactor streaming client code into a separate module --- .github/workflows/unit_tests.yml | 2 +- Dockerfile | 2 +- neon_hana/mq_websocket_api.py | 97 +++----------------------------- neon_hana/streaming_client.py | 89 +++++++++++++++++++++++++++++ requirements/streaming.txt | 5 ++ requirements/websocket.txt | 5 -- setup.py | 3 +- 7 files changed, 106 insertions(+), 97 deletions(-) create mode 100644 neon_hana/streaming_client.py create mode 100644 requirements/streaming.txt diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 068ca91..08a253b 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . -r requirements/test_requirements.txt + pip install .[streaming] -r requirements/test_requirements.txt - name: Run Tests run: | pytest tests diff --git a/Dockerfile b/Dockerfile index c3d62ca..6392e63 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,6 +14,6 @@ COPY docker_overlay/ / WORKDIR /app COPY . /app -RUN pip install /app[websocket] +RUN pip install /app[websocket,streaming] CMD ["python3", "/app/neon_hana/app/__main__.py"] \ No newline at end of file diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 40d8531..41d4beb 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -25,26 +25,15 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from asyncio import run, get_event_loop -from base64 import b64encode from os import makedirs from queue import Queue from time import time -from typing import Optional, Callable - +from typing import Optional from fastapi import WebSocket -from mock.mock import Mock from neon_iris.client import NeonAIClient from ovos_bus_client.message import Message -from threading import RLock, Thread - -from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop -from ovos_dinkum_listener.voice_loop.hotwords import HotwordContainer -from ovos_dinkum_listener.voice_loop.voice_loop import ChunkInfo -from ovos_plugin_manager.templates.microphone import Microphone -from ovos_plugin_manager.vad import OVOSVADFactory +from threading import RLock from ovos_utils import LOG -from ovos_utils.fakebus import FakeBus -from speech_recognition import AudioData class MQWebsocketAPI(NeonAIClient): @@ -79,7 +68,7 @@ def end_session(self, session_id: str): if not session: LOG.error(f"Ended session is not established {session_id}") return - stream: RemoteStreamHandler = session.get('stream') + stream = session.get('stream') if stream: stream.shutdown() stream.join() @@ -142,6 +131,7 @@ def _update_session_data(self, message: Message): self._sessions[session_id]['user'] = user_config def handle_audio_stream(self, audio: bytes, session_id: str): + from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone if not self._sessions[session_id].get('stream'): LOG.info(f"starting stream for session {session_id}") audio_queue = Queue() @@ -184,6 +174,10 @@ def handle_klat_response(self, message: Message): """ self._update_session_data(message) run(self.send_to_client(message)) + session_id = message.context.get('session', {}).get('session_id') + if self._sessions.get(session_id, {}).get('stream'): + # TODO: stream response audio to streaming socket + pass LOG.debug(message.context.get("timing")) def handle_complete_intent_failure(self, message: Message): @@ -249,78 +243,3 @@ def shutdown(self, *_, **__): loop.call_soon_threadsafe(loop.stop) LOG.info("Stopped Event Loop") super().shutdown() - - -class StreamMicrophone(Microphone): - def __init__(self, queue: Queue): - self.queue = queue - - def start(self): - pass - - def stop(self): - self.queue.put(None) - - def read_chunk(self) -> Optional[bytes]: - return self.queue.get() - - -class RemoteStreamHandler(Thread): - def __init__(self, mic: StreamMicrophone, session_id: str, - audio_callback: Callable, - ww_callback: Callable, lang: str = "en-us"): - Thread.__init__(self) - self.session_id = session_id - self.ww_callback = ww_callback - self.audio_callback = audio_callback - self.bus = FakeBus() - self.mic = mic - self.lang = lang - self.hotwords = HotwordContainer(self.bus) - self.hotwords.load_hotword_engines() - self.vad = OVOSVADFactory.create() - self.voice_loop = DinkumVoiceLoop(mic=self.mic, - vad=self.vad, - hotwords=self.hotwords, - listenword_audio_callback=self.on_hotword, - hotword_audio_callback=self.on_hotword, - stopword_audio_callback=self.on_hotword, - wakeupword_audio_callback=self.on_hotword, - stt_audio_callback=self.on_audio, - stt=Mock(transcribe=Mock(return_value=[])), - fallback_stt=Mock(transcribe=Mock(return_value=[])), - transformers=MockTransformers(), - chunk_callback=self.on_chunk, - speech_seconds=0.5, - num_hotword_keep_chunks=0, - num_stt_rewind_chunks=0) - - def run(self): - self.voice_loop.start() - self.voice_loop.run() - - def on_hotword(self, audio_bytes: bytes, context: dict): - self.lang = context.get("stt_lang") or self.lang - LOG.info(f"Hotword: {context}") - self.ww_callback(context, self.session_id) - - def on_audio(self, audio_bytes: bytes, context: dict): - LOG.info(f"Audio: {context}") - audio_data = AudioData(audio_bytes, self.mic.sample_rate, - self.mic.sample_width).get_wav_data() - audio_data = b64encode(audio_data).decode("utf-8") - callback_data = {"type": "neon.audio_input", - "data": {"audio_data": audio_data, "lang": self.lang}} - self.audio_callback(callback_data, self.session_id) - - def on_chunk(self, chunk: ChunkInfo): - LOG.debug(f"Chunk: {chunk}") - - def shutdown(self): - self.mic.stop() - self.voice_loop.stop() - - -class MockTransformers(Mock): - def transform(self, chunk): - return chunk, dict() diff --git a/neon_hana/streaming_client.py b/neon_hana/streaming_client.py new file mode 100644 index 0000000..67a5b7f --- /dev/null +++ b/neon_hana/streaming_client.py @@ -0,0 +1,89 @@ +from base64 import b64encode +from typing import Optional, Callable +from mock.mock import Mock +from threading import Thread +from queue import Queue + +from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop +from ovos_dinkum_listener.voice_loop.hotwords import HotwordContainer +from ovos_dinkum_listener.voice_loop.voice_loop import ChunkInfo +from ovos_plugin_manager.templates.microphone import Microphone +from ovos_plugin_manager.vad import OVOSVADFactory +from ovos_utils.fakebus import FakeBus +from speech_recognition import AudioData +from ovos_utils import LOG + + +class StreamMicrophone(Microphone): + def __init__(self, queue: Queue): + self.queue = queue + + def start(self): + pass + + def stop(self): + self.queue.put(None) + + def read_chunk(self) -> Optional[bytes]: + return self.queue.get() + + +class RemoteStreamHandler(Thread): + def __init__(self, mic: StreamMicrophone, session_id: str, + audio_callback: Callable, + ww_callback: Callable, lang: str = "en-us"): + Thread.__init__(self) + self.session_id = session_id + self.ww_callback = ww_callback + self.audio_callback = audio_callback + self.bus = FakeBus() + self.mic = mic + self.lang = lang + self.hotwords = HotwordContainer(self.bus) + self.hotwords.load_hotword_engines() + self.vad = OVOSVADFactory.create() + self.voice_loop = DinkumVoiceLoop(mic=self.mic, + vad=self.vad, + hotwords=self.hotwords, + listenword_audio_callback=self.on_hotword, + hotword_audio_callback=self.on_hotword, + stopword_audio_callback=self.on_hotword, + wakeupword_audio_callback=self.on_hotword, + stt_audio_callback=self.on_audio, + stt=Mock(transcribe=Mock(return_value=[])), + fallback_stt=Mock(transcribe=Mock(return_value=[])), + transformers=MockTransformers(), + chunk_callback=self.on_chunk, + speech_seconds=0.5, + num_hotword_keep_chunks=0, + num_stt_rewind_chunks=0) + + def run(self): + self.voice_loop.start() + self.voice_loop.run() + + def on_hotword(self, audio_bytes: bytes, context: dict): + self.lang = context.get("stt_lang") or self.lang + LOG.info(f"Hotword: {context}") + self.ww_callback(context, self.session_id) + + def on_audio(self, audio_bytes: bytes, context: dict): + LOG.info(f"Audio: {context}") + audio_data = AudioData(audio_bytes, self.mic.sample_rate, + self.mic.sample_width).get_wav_data() + audio_data = b64encode(audio_data).decode("utf-8") + callback_data = {"type": "neon.audio_input", + "data": {"audio_data": audio_data, "lang": self.lang}} + self.audio_callback(callback_data, self.session_id) + + def on_chunk(self, chunk: ChunkInfo): + LOG.debug(f"Chunk: {chunk}") + + def shutdown(self): + self.mic.stop() + self.voice_loop.stop() + + +class MockTransformers(Mock): + def transform(self, chunk): + return chunk, dict() diff --git a/requirements/streaming.txt b/requirements/streaming.txt new file mode 100644 index 0000000..6a241e7 --- /dev/null +++ b/requirements/streaming.txt @@ -0,0 +1,5 @@ +mock~=5.0 +ovos-dinkum-listener~=0.1 +ovos-vad-plugin-webrtcvad +#ovos-ww-plugin-pocketsphinx +ovos-ww-plugin-vosk \ No newline at end of file diff --git a/requirements/websocket.txt b/requirements/websocket.txt index c9957af..2a5a773 100644 --- a/requirements/websocket.txt +++ b/requirements/websocket.txt @@ -1,7 +1,2 @@ neon-iris~=0.1,>=0.1.1a5 websockets~=12.0 -mock~=5.0 -ovos-dinkum-listener~=0.1 -ovos-vad-plugin-webrtcvad -#ovos-ww-plugin-pocketsphinx -ovos-ww-plugin-vosk \ No newline at end of file diff --git a/setup.py b/setup.py index b9928ee..3777817 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,8 @@ def get_requirements(requirements_filename: str): license='BSD-3-Clause', packages=find_packages(), install_requires=get_requirements("requirements.txt"), - extras_require={"websocket": get_requirements("websocket.txt")}, + extras_require={"websocket": get_requirements("websocket.txt"), + "steaming": get_requirements("streaming.txt")}, zip_safe=True, classifiers=[ 'Intended Audience :: Developers', From 2970e1d3fe4baafe83b4eea6bbf963481e4c07b4 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Thu, 3 Oct 2024 19:28:16 -0700 Subject: [PATCH 06/14] Fix typo in extra dependencies Handle streaming socket retry if too early Implement streaming audio responses --- neon_hana/app/routers/node_server.py | 13 +++++-- neon_hana/mq_websocket_api.py | 54 +++++++++++++++++----------- neon_hana/streaming_client.py | 30 ++++++++++++---- setup.py | 2 +- 4 files changed, 69 insertions(+), 30 deletions(-) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index 6badae3..65838d7 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -26,6 +26,7 @@ from asyncio import Event from signal import signal, SIGINT +from time import sleep from typing import Optional, Union from fastapi import APIRouter, WebSocket, HTTPException, Request @@ -75,16 +76,24 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str): must first establish a connection to the `/v1` endpoint. """ client_id = client_manager.get_client_id(token) + + # Handle problem clients that don't explicitly wait for the Node WS to + # connect before starting a stream + retries = 0 + while not socket_api.get_session(client_id) and retries < 3: + sleep(1) + retries += 1 if not socket_api.get_session(client_id): raise HTTPException(status_code=401, detail=f"Client not known ({client_id})") + await websocket.accept() disconnect_event = Event() - + socket_api.new_stream(websocket, client_id) while not disconnect_event.is_set(): try: client_in: bytes = await websocket.receive_bytes() - socket_api.handle_audio_stream(client_in, client_id) + socket_api.handle_audio_input_stream(client_in, client_id) except WebSocketDisconnect: disconnect_event.set() diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 41d4beb..c46907a 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -60,6 +60,28 @@ def new_connection(self, ws: WebSocket, session_id: str): "socket": ws, "user": self.user_config} + def new_stream(self, ws: WebSocket, session_id: str): + """ + Establish a new streaming connection, associated with an existing session. + @param ws: Client WebSocket that handles byte audio + @param session_id: Session ID the websocket is associated with + """ + if session_id not in self._sessions: + raise RuntimeError(f"Stream cannot be established for {session_id}") + from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone + if not self._sessions[session_id].get('stream'): + LOG.info(f"starting stream for session {session_id}") + audio_queue = Queue() + stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id, + input_audio_callback=self.handle_client_input, + ww_callback=self.handle_ww_detected, + client_socket=ws) + self._sessions[session_id]['stream'] = stream + try: + stream.start() + except RuntimeError: + pass + def end_session(self, session_id: str): """ End a client connection upon WS disconnection @@ -130,20 +152,7 @@ def _update_session_data(self, message: Message): if user_config: self._sessions[session_id]['user'] = user_config - def handle_audio_stream(self, audio: bytes, session_id: str): - from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone - if not self._sessions[session_id].get('stream'): - LOG.info(f"starting stream for session {session_id}") - audio_queue = Queue() - stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id, - audio_callback=self.handle_client_input, - ww_callback=self.handle_ww_detected) - self._sessions[session_id]['stream'] = stream - try: - stream.start() - except RuntimeError: - pass - + def handle_audio_input_stream(self, audio: bytes, session_id: str): self._sessions[session_id]['stream'].mic.queue.put(audio) def handle_ww_detected(self, ww_context: dict, session_id: str): @@ -172,13 +181,16 @@ def handle_klat_response(self, message: Message): Handle a Neon text+audio response to a user input. @param message: `klat.response` message from Neon """ - self._update_session_data(message) - run(self.send_to_client(message)) - session_id = message.context.get('session', {}).get('session_id') - if self._sessions.get(session_id, {}).get('stream'): - # TODO: stream response audio to streaming socket - pass - LOG.debug(message.context.get("timing")) + try: + self._update_session_data(message) + run(self.send_to_client(message)) + session_id = message.context.get('session', {}).get('session_id') + if stream := self._sessions.get(session_id, {}).get('stream'): + LOG.info("Stream response audio") + stream.on_response_audio(message.data) + LOG.debug(message.context.get("timing")) + except Exception as e: + LOG.exception(e) def handle_complete_intent_failure(self, message: Message): """ diff --git a/neon_hana/streaming_client.py b/neon_hana/streaming_client.py index 67a5b7f..536c9d6 100644 --- a/neon_hana/streaming_client.py +++ b/neon_hana/streaming_client.py @@ -1,4 +1,6 @@ -from base64 import b64encode +import io +from asyncio import run +from base64 import b64encode, b64decode from typing import Optional, Callable from mock.mock import Mock from threading import Thread @@ -12,6 +14,7 @@ from ovos_utils.fakebus import FakeBus from speech_recognition import AudioData from ovos_utils import LOG +from starlette.websockets import WebSocket class StreamMicrophone(Microphone): @@ -30,12 +33,14 @@ def read_chunk(self) -> Optional[bytes]: class RemoteStreamHandler(Thread): def __init__(self, mic: StreamMicrophone, session_id: str, - audio_callback: Callable, + input_audio_callback: Callable, + client_socket: WebSocket, ww_callback: Callable, lang: str = "en-us"): Thread.__init__(self) self.session_id = session_id self.ww_callback = ww_callback - self.audio_callback = audio_callback + self.input_audio_callback = input_audio_callback + self.client_socket = client_socket self.bus = FakeBus() self.mic = mic self.lang = lang @@ -49,7 +54,7 @@ def __init__(self, mic: StreamMicrophone, session_id: str, hotword_audio_callback=self.on_hotword, stopword_audio_callback=self.on_hotword, wakeupword_audio_callback=self.on_hotword, - stt_audio_callback=self.on_audio, + stt_audio_callback=self.on_input_audio, stt=Mock(transcribe=Mock(return_value=[])), fallback_stt=Mock(transcribe=Mock(return_value=[])), transformers=MockTransformers(), @@ -67,14 +72,27 @@ def on_hotword(self, audio_bytes: bytes, context: dict): LOG.info(f"Hotword: {context}") self.ww_callback(context, self.session_id) - def on_audio(self, audio_bytes: bytes, context: dict): + def on_input_audio(self, audio_bytes: bytes, context: dict): LOG.info(f"Audio: {context}") audio_data = AudioData(audio_bytes, self.mic.sample_rate, self.mic.sample_width).get_wav_data() audio_data = b64encode(audio_data).decode("utf-8") callback_data = {"type": "neon.audio_input", "data": {"audio_data": audio_data, "lang": self.lang}} - self.audio_callback(callback_data, self.session_id) + self.input_audio_callback(callback_data, self.session_id) + + def on_response_audio(self, data: dict): + async def _send_bytes(audio_bytes: bytes): + await self.client_socket.send_bytes(audio_bytes) + + i = 0 + for lang_response in data.get('responses', {}).values(): + for encoded_audio in lang_response.get('audio', {}).values(): + i += 1 + wav_audio_bytes = b64decode(encoded_audio) + LOG.info(f"Sending {len(wav_audio_bytes)} bytes of audio") + run(_send_bytes(wav_audio_bytes)) + LOG.info(f"Sent {i} binary audio response") def on_chunk(self, chunk: ChunkInfo): LOG.debug(f"Chunk: {chunk}") diff --git a/setup.py b/setup.py index 3777817..85ddbef 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ def get_requirements(requirements_filename: str): packages=find_packages(), install_requires=get_requirements("requirements.txt"), extras_require={"websocket": get_requirements("websocket.txt"), - "steaming": get_requirements("streaming.txt")}, + "streaming": get_requirements("streaming.txt")}, zip_safe=True, classifiers=[ 'Intended Audience :: Developers', From 832c5d925956a89b429be59de0de69e1aaf212b5 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 7 Oct 2024 09:17:21 -0700 Subject: [PATCH 07/14] Add docstrings to Node documentation endpoints Remove duplicate docstring not included in OpanAPI pages --- neon_hana/app/routers/node_server.py | 30 +++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index 65838d7..feffb5a 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -29,7 +29,7 @@ from time import sleep from typing import Optional, Union -from fastapi import APIRouter, WebSocket, HTTPException, Request +from fastapi import APIRouter, WebSocket, HTTPException from starlette.websockets import WebSocketDisconnect from neon_hana.app.dependencies import config, client_manager @@ -71,10 +71,6 @@ async def node_v1_endpoint(websocket: WebSocket, token: str): @node_route.websocket("/v1/stream") async def node_v1_stream_endpoint(websocket: WebSocket, token: str): - """ - Endpoint to handle a stream of raw audio bytes. A client using this endpoint - must first establish a connection to the `/v1` endpoint. - """ client_id = client_manager.get_client_id(token) # Handle problem clients that don't explicitly wait for the Node WS to @@ -103,4 +99,28 @@ async def node_v1_doc(_: Optional[Union[NodeAudioInput, NodeGetStt, NodeGetTts]]) -> \ Optional[Union[NodeKlatResponse, NodeAudioInputResponse, NodeGetSttResponse, NodeGetTtsResponse]]: + """ + The node endpoint (`/node/v1`) accepts and returns JSON objects representing + Messages. All inputs and responses will contain keys: + `msg_type`, `data`, `context`. + """ + pass + + +@node_route.get("/v1/stream/doc") +async def node_v1_stream_doc(): + """ + The stream endpoint accepts input audio as raw bytes. It expects inputs to + have: + - sample_rate=16000 + - sample_width=2 + - sample_channels=1 + - chunk_size=4096 + + Response audio is WAV audio as raw bytes, with each message containing one + full audio file. A client should queue all responses for playback. + + Any client accessing the stream endpoint (`/node/v1/stream`), must first + establish a connection to the node endpoint (`/node/v1`). + """ pass From 2dd7ba1abb582a0d46fd9b1a9487f965fa8b4991 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 7 Oct 2024 15:24:28 -0700 Subject: [PATCH 08/14] Update streaming WW plugin dependencies to match Neon speech container --- Dockerfile | 2 +- requirements/streaming.txt | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 6392e63..d8394da 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ ENV OVOS_CONFIG_BASE_FOLDER neon ENV OVOS_CONFIG_FILENAME diana.yaml ENV XDG_CONFIG_HOME /config -RUN apt update && apt install -y swig gcc libpulse-dev +RUN apt update && apt install -y swig gcc libpulse-dev portaudio19-dev COPY docker_overlay/ / diff --git a/requirements/streaming.txt b/requirements/streaming.txt index 6a241e7..da3ca46 100644 --- a/requirements/streaming.txt +++ b/requirements/streaming.txt @@ -1,5 +1,9 @@ mock~=5.0 ovos-dinkum-listener~=0.1 -ovos-vad-plugin-webrtcvad -#ovos-ww-plugin-pocketsphinx -ovos-ww-plugin-vosk \ No newline at end of file +ovos-vad-plugin-webrtcvad~=0.0.1 +ovos-ww-plugin-pocketsphinx~=0.1 +ovos-ww-plugin-vosk~=0.1 +ovos-ww-plugin-precise-lite[tflite]~=0.1 +ovos-ww-plugin-precise~=0.1 +# TODO: numpy patching tflite plugin compat. issue https://github.com/OpenVoiceOS/ovos-ww-plugin-precise-lite/pull/8 +numpy~=1.0 From 5fccf65301c1da72ea9e45a1e06183dd20fe4cab Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 7 Oct 2024 15:25:42 -0700 Subject: [PATCH 09/14] Add supported responses to Node endpoint documentation Update docstring to note undocumented functionality may change --- neon_hana/app/routers/node_server.py | 12 ++++++++--- neon_hana/schema/node_v1.py | 30 ++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index feffb5a..22a9658 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -39,7 +39,9 @@ NodeGetTts, NodeKlatResponse, NodeAudioInputResponse, NodeGetSttResponse, - NodeGetTtsResponse) + NodeGetTtsResponse, CoreWWDetected, + CoreIntentFailure, CoreErrorResponse, + CoreClearData, CoreAlertExpired) node_route = APIRouter(prefix="/node", tags=["node"]) socket_api = MQWebsocketAPI(config) @@ -98,11 +100,15 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str): async def node_v1_doc(_: Optional[Union[NodeAudioInput, NodeGetStt, NodeGetTts]]) -> \ Optional[Union[NodeKlatResponse, NodeAudioInputResponse, - NodeGetSttResponse, NodeGetTtsResponse]]: + NodeGetSttResponse, NodeGetTtsResponse, + CoreWWDetected, CoreIntentFailure, CoreErrorResponse, + CoreClearData, CoreAlertExpired]]: """ The node endpoint (`/node/v1`) accepts and returns JSON objects representing Messages. All inputs and responses will contain keys: - `msg_type`, `data`, `context`. + `msg_type`, `data`, `context`. Only the inputs and responses documented here + are explicitly supported. Other messages sent or received on this socket are + not guaranteed to be stable. """ pass diff --git a/neon_hana/schema/node_v1.py b/neon_hana/schema/node_v1.py index 11c2725..913c6b0 100644 --- a/neon_hana/schema/node_v1.py +++ b/neon_hana/schema/node_v1.py @@ -122,3 +122,33 @@ class NodeGetTtsResponse(BaseModel): msg_type: str = "neon.get_tts.response" data: KlatResponseData context: dict + + +class CoreWWDetected(BaseModel): + msg_type: str = "neon.ww_detected" + data: dict + context: dict + + +class CoreIntentFailure(BaseModel): + msg_type: str = "complete.intent.failure" + data: dict + context: dict + + +class CoreErrorResponse(BaseModel): + msg_type: str = "klat.error" + data: dict + context: dict + + +class CoreClearData(BaseModel): + msg_type: str = "neon.clear_data" + data: dict + context: dict + + +class CoreAlertExpired(BaseModel): + msg_type: str = "neon.alert_expired" + data: dict + context: dict From c99735590923cc970e819f0a5a0971db078e9579 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 7 Oct 2024 15:58:58 -0700 Subject: [PATCH 10/14] Implement configured maximum node streams with documentation and unit tests --- README.md | 1 + neon_hana/app/routers/node_server.py | 26 +++++++++++++++--------- neon_hana/auth/client_manager.py | 24 ++++++++++++++++++++++ tests/test_auth.py | 30 ++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index a84a42b..d741025 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ hana: enable_email: True # Disabled by default; anyone with access to the API will be able to send emails from the configured address node_username: node_user # Username to authenticate Node API access; leave empty to disable Node API access node_password: node_password # Password associated with node_username + max_streaming_clients: -1 # Maximum audio streaming clients allowed (including 0). Default unset value allows infinite clients ``` It is recommended to generate unique values for configured tokens, these are 32 bytes in hexadecimal representation. diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index 22a9658..b6a3a34 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -85,15 +85,23 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str): raise HTTPException(status_code=401, detail=f"Client not known ({client_id})") - await websocket.accept() - disconnect_event = Event() - socket_api.new_stream(websocket, client_id) - while not disconnect_event.is_set(): - try: - client_in: bytes = await websocket.receive_bytes() - socket_api.handle_audio_input_stream(client_in, client_id) - except WebSocketDisconnect: - disconnect_event.set() + if not client_manager.check_connect_stream(): + raise HTTPException(status_code=503, + detail=f"Server is not accepting any more streams") + try: + await websocket.accept() + disconnect_event = Event() + socket_api.new_stream(websocket, client_id) + while not disconnect_event.is_set(): + try: + client_in: bytes = await websocket.receive_bytes() + socket_api.handle_audio_input_stream(client_in, client_id) + except WebSocketDisconnect: + disconnect_event.set() + except Exception as e: + print(e) + finally: + client_manager.disconnect_stream() @node_route.get("/v1/doc") diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index d1bb6d0..578ea13 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -23,6 +23,7 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from threading import Lock import jwt @@ -53,7 +54,10 @@ def __init__(self, config: dict): self._disable_auth = config.get("disable_auth") self._node_username = config.get("node_username") self._node_password = config.get("node_password") + self._max_streaming_clients = config.get("max_streaming_clients") self._jwt_algo = "HS256" + self._connected_streams = 0 + self._stream_check_lock = Lock() def _create_tokens(self, encode_data: dict) -> dict: # Permissions were not included in old tokens, allow refreshing with @@ -88,6 +92,26 @@ def get_permissions(self, client_id: str) -> ClientPermissions: client = self.authorized_clients[client_id] return ClientPermissions(**client.get('permissions', dict())) + def check_connect_stream(self) -> bool: + """ + Check if a new stream is allowed + """ + with self._stream_check_lock: + if not isinstance(self._max_streaming_clients, int) or \ + self._max_streaming_clients is False or \ + self._max_streaming_clients < 0: + self._connected_streams += 1 + return True + if self._connected_streams >= self._max_streaming_clients: + LOG.warning(f"No more streams allowed ({self._connected_streams})") + return False + self._connected_streams += 1 + return True + + def disconnect_stream(self): + with self._stream_check_lock: + self._connected_streams -= 1 + def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None, origin_ip: str = "127.0.0.1") -> dict: diff --git a/tests/test_auth.py b/tests/test_auth.py index d2fcbef..88422fb 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -202,3 +202,33 @@ def test_client_permissions(self): self.assertFalse(any([v for v in restricted_perms.as_dict().values()])) self.assertIsInstance(permissive_perms.as_dict(), dict) self.assertTrue(all([v for v in permissive_perms.as_dict().values()])) + + def test_stream_connections(self): + # Test configured maximum + self.client_manager._max_streaming_clients = 1 + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 1) + self.assertFalse(self.client_manager.check_connect_stream()) + self.assertFalse(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 1) + self.client_manager.disconnect_stream() + self.assertEqual(self.client_manager._connected_streams, 0) + + # Test explicitly disabled streaming + self.client_manager._max_streaming_clients = 0 + self.assertFalse(self.client_manager.check_connect_stream()) + + # Test unlimited clients + self.client_manager._max_streaming_clients = None + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 3) + + self.client_manager._max_streaming_clients = -1 + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 4) + + self.client_manager._max_streaming_clients = False + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 5) From fffef5cb70208955c5617ef2bdfeea9d922f551f Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 7 Oct 2024 16:01:24 -0700 Subject: [PATCH 11/14] Add apt dependencies to unit test automation --- .github/workflows/unit_tests.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 08a253b..0e264d4 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -20,7 +20,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + - name: Install system dependencies + run: | + sudo apt update + sudo apt install -y swig gcc libpulse-dev portaudio19-dev + - name: Install package run: | python -m pip install --upgrade pip pip install .[streaming] -r requirements/test_requirements.txt From 391a978a3ac23c294dc0fc3820590af2256b4e88 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Tue, 8 Oct 2024 09:16:33 -0700 Subject: [PATCH 12/14] Address logging review comments --- neon_hana/app/routers/node_server.py | 3 ++- neon_hana/streaming_client.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index b6a3a34..90444bd 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -30,6 +30,7 @@ from typing import Optional, Union from fastapi import APIRouter, WebSocket, HTTPException +from ovos_utils import LOG from starlette.websockets import WebSocketDisconnect from neon_hana.app.dependencies import config, client_manager @@ -99,7 +100,7 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str): except WebSocketDisconnect: disconnect_event.set() except Exception as e: - print(e) + LOG.error(e) finally: client_manager.disconnect_stream() diff --git a/neon_hana/streaming_client.py b/neon_hana/streaming_client.py index 536c9d6..c3263a7 100644 --- a/neon_hana/streaming_client.py +++ b/neon_hana/streaming_client.py @@ -92,7 +92,7 @@ async def _send_bytes(audio_bytes: bytes): wav_audio_bytes = b64decode(encoded_audio) LOG.info(f"Sending {len(wav_audio_bytes)} bytes of audio") run(_send_bytes(wav_audio_bytes)) - LOG.info(f"Sent {i} binary audio response") + LOG.info(f"Sent {i} binary audio response(s)") def on_chunk(self, chunk: ChunkInfo): LOG.debug(f"Chunk: {chunk}") From eddc804a5071e38a69aebc2f6282e22d6905474b Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Tue, 8 Oct 2024 09:29:13 -0700 Subject: [PATCH 13/14] Remove newline change --- requirements/websocket.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/websocket.txt b/requirements/websocket.txt index 2a5a773..99b6eb0 100644 --- a/requirements/websocket.txt +++ b/requirements/websocket.txt @@ -1,2 +1,2 @@ neon-iris~=0.1,>=0.1.1a5 -websockets~=12.0 +websockets~=12.0 \ No newline at end of file From c3e533ef368e4ffacfc4d9a294ad3993aa30da34 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Tue, 8 Oct 2024 10:01:20 -0700 Subject: [PATCH 14/14] Refactor client connection check into MQWebsocketAPI class Add locking around session changes --- neon_hana/app/routers/node_server.py | 18 ++++------- neon_hana/mq_websocket_api.py | 47 ++++++++++++++++++---------- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index 90444bd..c7acd3f 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -34,7 +34,7 @@ from starlette.websockets import WebSocketDisconnect from neon_hana.app.dependencies import config, client_manager -from neon_hana.mq_websocket_api import MQWebsocketAPI +from neon_hana.mq_websocket_api import MQWebsocketAPI, ClientNotKnown from neon_hana.schema.node_v1 import (NodeAudioInput, NodeGetStt, NodeGetTts, NodeKlatResponse, @@ -76,16 +76,6 @@ async def node_v1_endpoint(websocket: WebSocket, token: str): async def node_v1_stream_endpoint(websocket: WebSocket, token: str): client_id = client_manager.get_client_id(token) - # Handle problem clients that don't explicitly wait for the Node WS to - # connect before starting a stream - retries = 0 - while not socket_api.get_session(client_id) and retries < 3: - sleep(1) - retries += 1 - if not socket_api.get_session(client_id): - raise HTTPException(status_code=401, - detail=f"Client not known ({client_id})") - if not client_manager.check_connect_stream(): raise HTTPException(status_code=503, detail=f"Server is not accepting any more streams") @@ -99,8 +89,12 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str): socket_api.handle_audio_input_stream(client_in, client_id) except WebSocketDisconnect: disconnect_event.set() - except Exception as e: + except ClientNotKnown as e: LOG.error(e) + raise HTTPException(status_code=401, + detail=f"Client not known ({client_id})") + except Exception as e: + LOG.exception(e) finally: client_manager.disconnect_stream() diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index c46907a..6b9b96e 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -27,7 +27,7 @@ from asyncio import run, get_event_loop from os import makedirs from queue import Queue -from time import time +from time import time, sleep from typing import Optional from fastapi import WebSocket from neon_iris.client import NeonAIClient @@ -36,6 +36,12 @@ from ovos_utils import LOG +class ClientNotKnown(RuntimeError): + """ + Exception raised when a client tries to do something before authenticating + """ + + class MQWebsocketAPI(NeonAIClient): def __init__(self, config: dict): """ @@ -66,27 +72,34 @@ def new_stream(self, ws: WebSocket, session_id: str): @param ws: Client WebSocket that handles byte audio @param session_id: Session ID the websocket is associated with """ - if session_id not in self._sessions: - raise RuntimeError(f"Stream cannot be established for {session_id}") - from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone - if not self._sessions[session_id].get('stream'): - LOG.info(f"starting stream for session {session_id}") - audio_queue = Queue() - stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id, - input_audio_callback=self.handle_client_input, - ww_callback=self.handle_ww_detected, - client_socket=ws) - self._sessions[session_id]['stream'] = stream - try: - stream.start() - except RuntimeError: - pass + timeout = time() + 5 + while session_id not in self._sessions and time() < timeout: + # Handle problem clients that don't explicitly wait for the Node WS + # to connect before starting a stream + sleep(1) + with self._session_lock: + if session_id not in self._sessions: + raise ClientNotKnown(f"Stream cannot be established for {session_id}") + from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone + if not self._sessions[session_id].get('stream'): + LOG.info(f"starting stream for session {session_id}") + audio_queue = Queue() + stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id, + input_audio_callback=self.handle_client_input, + ww_callback=self.handle_ww_detected, + client_socket=ws) + self._sessions[session_id]['stream'] = stream + try: + stream.start() + except RuntimeError: + pass def end_session(self, session_id: str): """ End a client connection upon WS disconnection """ - session: Optional[dict] = self._sessions.pop(session_id, None) + with self._session_lock: + session: Optional[dict] = self._sessions.pop(session_id, None) if not session: LOG.error(f"Ended session is not established {session_id}") return