Skip to content

Commit

Permalink
Refactor client connection check into MQWebsocketAPI class
Browse files Browse the repository at this point in the history
Add locking around session changes
  • Loading branch information
NeonDaniel committed Oct 8, 2024
1 parent eddc804 commit c3e533e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 29 deletions.
18 changes: 6 additions & 12 deletions neon_hana/app/routers/node_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from starlette.websockets import WebSocketDisconnect

from neon_hana.app.dependencies import config, client_manager
from neon_hana.mq_websocket_api import MQWebsocketAPI
from neon_hana.mq_websocket_api import MQWebsocketAPI, ClientNotKnown

from neon_hana.schema.node_v1 import (NodeAudioInput, NodeGetStt,
NodeGetTts, NodeKlatResponse,
Expand Down Expand Up @@ -76,16 +76,6 @@ async def node_v1_endpoint(websocket: WebSocket, token: str):
async def node_v1_stream_endpoint(websocket: WebSocket, token: str):
client_id = client_manager.get_client_id(token)

# Handle problem clients that don't explicitly wait for the Node WS to
# connect before starting a stream
retries = 0
while not socket_api.get_session(client_id) and retries < 3:
sleep(1)
retries += 1
if not socket_api.get_session(client_id):
raise HTTPException(status_code=401,
detail=f"Client not known ({client_id})")

if not client_manager.check_connect_stream():
raise HTTPException(status_code=503,
detail=f"Server is not accepting any more streams")
Expand All @@ -99,8 +89,12 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str):
socket_api.handle_audio_input_stream(client_in, client_id)
except WebSocketDisconnect:
disconnect_event.set()
except Exception as e:
except ClientNotKnown as e:
LOG.error(e)
raise HTTPException(status_code=401,
detail=f"Client not known ({client_id})")
except Exception as e:
LOG.exception(e)
finally:
client_manager.disconnect_stream()

Expand Down
47 changes: 30 additions & 17 deletions neon_hana/mq_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from asyncio import run, get_event_loop
from os import makedirs
from queue import Queue
from time import time
from time import time, sleep
from typing import Optional
from fastapi import WebSocket
from neon_iris.client import NeonAIClient
Expand All @@ -36,6 +36,12 @@
from ovos_utils import LOG


class ClientNotKnown(RuntimeError):
"""
Exception raised when a client tries to do something before authenticating
"""


class MQWebsocketAPI(NeonAIClient):
def __init__(self, config: dict):
"""
Expand Down Expand Up @@ -66,27 +72,34 @@ def new_stream(self, ws: WebSocket, session_id: str):
@param ws: Client WebSocket that handles byte audio
@param session_id: Session ID the websocket is associated with
"""
if session_id not in self._sessions:
raise RuntimeError(f"Stream cannot be established for {session_id}")
from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone
if not self._sessions[session_id].get('stream'):
LOG.info(f"starting stream for session {session_id}")
audio_queue = Queue()
stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id,
input_audio_callback=self.handle_client_input,
ww_callback=self.handle_ww_detected,
client_socket=ws)
self._sessions[session_id]['stream'] = stream
try:
stream.start()
except RuntimeError:
pass
timeout = time() + 5
while session_id not in self._sessions and time() < timeout:
# Handle problem clients that don't explicitly wait for the Node WS
# to connect before starting a stream
sleep(1)
with self._session_lock:
if session_id not in self._sessions:
raise ClientNotKnown(f"Stream cannot be established for {session_id}")
from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone
if not self._sessions[session_id].get('stream'):
LOG.info(f"starting stream for session {session_id}")
audio_queue = Queue()
stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id,
input_audio_callback=self.handle_client_input,
ww_callback=self.handle_ww_detected,
client_socket=ws)
self._sessions[session_id]['stream'] = stream
try:
stream.start()
except RuntimeError:
pass

def end_session(self, session_id: str):
"""
End a client connection upon WS disconnection
"""
session: Optional[dict] = self._sessions.pop(session_id, None)
with self._session_lock:
session: Optional[dict] = self._sessions.pop(session_id, None)
if not session:
LOG.error(f"Ended session is not established {session_id}")
return
Expand Down

0 comments on commit c3e533e

Please sign in to comment.