Skip to content

Commit

Permalink
Implement streaming audio Websocket (#30)
Browse files Browse the repository at this point in the history
# Description
Implements a WS API for streaming raw audio (input and output) with a
per-client listener running in the backend

# Issues
<!-- If this is related to or closes an issue/other PR, please note them
here -->

# Other Notes
- Consider implementing streaming response audio
> Implemented. Stream socket handles chunked audio input and sends wav
audio segments, all as bytes.
- Consider WW config and included plugins
> Global configuration will be used (at least initially). Plugins will
match neon-speech defaults.
- Consider configuring max allowed clients to prevent overloading the
backend server with listener instances
  > Implemented in configuration with unit test coverage

---------

Co-authored-by: Daniel McKnight <[email protected]>
  • Loading branch information
NeonDaniel and NeonDaniel authored Oct 9, 2024
1 parent 3bdeabc commit d851940
Show file tree
Hide file tree
Showing 12 changed files with 351 additions and 13 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
- name: Install system dependencies
run: |
sudo apt update
sudo apt install -y swig gcc libpulse-dev portaudio19-dev
- name: Install package
run: |
python -m pip install --upgrade pip
pip install . -r requirements/test_requirements.txt
pip install .[streaming] -r requirements/test_requirements.txt
- name: Run Tests
run: |
pytest tests
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ ENV OVOS_CONFIG_BASE_FOLDER neon
ENV OVOS_CONFIG_FILENAME diana.yaml
ENV XDG_CONFIG_HOME /config

RUN apt update && apt install -y swig gcc libpulse-dev portaudio19-dev

COPY docker_overlay/ /

WORKDIR /app
COPY . /app
RUN pip install /app[websocket]

RUN pip install /app[websocket,streaming]

CMD ["python3", "/app/neon_hana/app/__main__.py"]
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ hana:
enable_email: True # Disabled by default; anyone with access to the API will be able to send emails from the configured address
node_username: node_user # Username to authenticate Node API access; leave empty to disable Node API access
node_password: node_password # Password associated with node_username
max_streaming_clients: -1 # Maximum audio streaming clients allowed (including 0). Default unset value allows infinite clients
```
It is recommended to generate unique values for configured tokens, these are 32
bytes in hexadecimal representation.
Expand Down
4 changes: 3 additions & 1 deletion docker_overlay/etc/neon/diana.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ hana:
fastapi_summary: "HANA (HTTP API for Neon Applications) is the HTTP component of the Device Independent API for Neon Applications (DIANA)"
stt_max_length_encoded: 500000
tts_max_words: 128
enable_email: False
enable_email: False
vad:
module: ovos-vad-plugin-webrtcvad
68 changes: 64 additions & 4 deletions neon_hana/app/routers/node_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,23 @@

from asyncio import Event
from signal import signal, SIGINT
from time import sleep
from typing import Optional, Union

from fastapi import APIRouter, WebSocket, HTTPException, Request
from fastapi import APIRouter, WebSocket, HTTPException
from ovos_utils import LOG
from starlette.websockets import WebSocketDisconnect

from neon_hana.app.dependencies import config, client_manager
from neon_hana.mq_websocket_api import MQWebsocketAPI
from neon_hana.mq_websocket_api import MQWebsocketAPI, ClientNotKnown

from neon_hana.schema.node_v1 import (NodeAudioInput, NodeGetStt,
NodeGetTts, NodeKlatResponse,
NodeAudioInputResponse,
NodeGetSttResponse,
NodeGetTtsResponse)
NodeGetTtsResponse, CoreWWDetected,
CoreIntentFailure, CoreErrorResponse,
CoreClearData, CoreAlertExpired)
node_route = APIRouter(prefix="/node", tags=["node"])

socket_api = MQWebsocketAPI(config)
Expand All @@ -65,11 +69,67 @@ async def node_v1_endpoint(websocket: WebSocket, token: str):
socket_api.handle_client_input(client_in, client_id)
except WebSocketDisconnect:
disconnect_event.set()
socket_api.end_session(session_id=client_id)


@node_route.websocket("/v1/stream")
async def node_v1_stream_endpoint(websocket: WebSocket, token: str):
client_id = client_manager.get_client_id(token)

if not client_manager.check_connect_stream():
raise HTTPException(status_code=503,
detail=f"Server is not accepting any more streams")
try:
await websocket.accept()
disconnect_event = Event()
socket_api.new_stream(websocket, client_id)
while not disconnect_event.is_set():
try:
client_in: bytes = await websocket.receive_bytes()
socket_api.handle_audio_input_stream(client_in, client_id)
except WebSocketDisconnect:
disconnect_event.set()
except ClientNotKnown as e:
LOG.error(e)
raise HTTPException(status_code=401,
detail=f"Client not known ({client_id})")
except Exception as e:
LOG.exception(e)
finally:
client_manager.disconnect_stream()


@node_route.get("/v1/doc")
async def node_v1_doc(_: Optional[Union[NodeAudioInput, NodeGetStt,
NodeGetTts]]) -> \
Optional[Union[NodeKlatResponse, NodeAudioInputResponse,
NodeGetSttResponse, NodeGetTtsResponse]]:
NodeGetSttResponse, NodeGetTtsResponse,
CoreWWDetected, CoreIntentFailure, CoreErrorResponse,
CoreClearData, CoreAlertExpired]]:
"""
The node endpoint (`/node/v1`) accepts and returns JSON objects representing
Messages. All inputs and responses will contain keys:
`msg_type`, `data`, `context`. Only the inputs and responses documented here
are explicitly supported. Other messages sent or received on this socket are
not guaranteed to be stable.
"""
pass


@node_route.get("/v1/stream/doc")
async def node_v1_stream_doc():
"""
The stream endpoint accepts input audio as raw bytes. It expects inputs to
have:
- sample_rate=16000
- sample_width=2
- sample_channels=1
- chunk_size=4096
Response audio is WAV audio as raw bytes, with each message containing one
full audio file. A client should queue all responses for playback.
Any client accessing the stream endpoint (`/node/v1/stream`), must first
establish a connection to the node endpoint (`/node/v1`).
"""
pass
24 changes: 24 additions & 0 deletions neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from threading import Lock

import jwt

Expand Down Expand Up @@ -53,7 +54,10 @@ def __init__(self, config: dict):
self._disable_auth = config.get("disable_auth")
self._node_username = config.get("node_username")
self._node_password = config.get("node_password")
self._max_streaming_clients = config.get("max_streaming_clients")
self._jwt_algo = "HS256"
self._connected_streams = 0
self._stream_check_lock = Lock()

def _create_tokens(self, encode_data: dict) -> dict:
# Permissions were not included in old tokens, allow refreshing with
Expand Down Expand Up @@ -88,6 +92,26 @@ def get_permissions(self, client_id: str) -> ClientPermissions:
client = self.authorized_clients[client_id]
return ClientPermissions(**client.get('permissions', dict()))

def check_connect_stream(self) -> bool:
"""
Check if a new stream is allowed
"""
with self._stream_check_lock:
if not isinstance(self._max_streaming_clients, int) or \
self._max_streaming_clients is False or \
self._max_streaming_clients < 0:
self._connected_streams += 1
return True
if self._connected_streams >= self._max_streaming_clients:
LOG.warning(f"No more streams allowed ({self._connected_streams})")
return False
self._connected_streams += 1
return True

def disconnect_stream(self):
with self._stream_check_lock:
self._connected_streams -= 1

def check_auth_request(self, client_id: str, username: str,
password: Optional[str] = None,
origin_ip: str = "127.0.0.1") -> dict:
Expand Down
75 changes: 71 additions & 4 deletions neon_hana/mq_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,22 @@

from asyncio import run, get_event_loop
from os import makedirs
from time import time
from queue import Queue
from time import time, sleep
from typing import Optional
from fastapi import WebSocket
from neon_iris.client import NeonAIClient
from ovos_bus_client.message import Message
from threading import RLock
from ovos_utils import LOG


class ClientNotKnown(RuntimeError):
"""
Exception raised when a client tries to do something before authenticating
"""


class MQWebsocketAPI(NeonAIClient):
def __init__(self, config: dict):
"""
Expand All @@ -58,6 +66,49 @@ def new_connection(self, ws: WebSocket, session_id: str):
"socket": ws,
"user": self.user_config}

def new_stream(self, ws: WebSocket, session_id: str):
"""
Establish a new streaming connection, associated with an existing session.
@param ws: Client WebSocket that handles byte audio
@param session_id: Session ID the websocket is associated with
"""
timeout = time() + 5
while session_id not in self._sessions and time() < timeout:
# Handle problem clients that don't explicitly wait for the Node WS
# to connect before starting a stream
sleep(1)
with self._session_lock:
if session_id not in self._sessions:
raise ClientNotKnown(f"Stream cannot be established for {session_id}")
from neon_hana.streaming_client import RemoteStreamHandler, StreamMicrophone
if not self._sessions[session_id].get('stream'):
LOG.info(f"starting stream for session {session_id}")
audio_queue = Queue()
stream = RemoteStreamHandler(StreamMicrophone(audio_queue), session_id,
input_audio_callback=self.handle_client_input,
ww_callback=self.handle_ww_detected,
client_socket=ws)
self._sessions[session_id]['stream'] = stream
try:
stream.start()
except RuntimeError:
pass

def end_session(self, session_id: str):
"""
End a client connection upon WS disconnection
"""
with self._session_lock:
session: Optional[dict] = self._sessions.pop(session_id, None)
if not session:
LOG.error(f"Ended session is not established {session_id}")
return
stream = session.get('stream')
if stream:
stream.shutdown()
stream.join()
LOG.info(f"Ended stream handler for: {session_id}")

def get_session(self, session_id: str) -> dict:
"""
Get the latest session context for the given session_id.
Expand Down Expand Up @@ -114,6 +165,15 @@ def _update_session_data(self, message: Message):
if user_config:
self._sessions[session_id]['user'] = user_config

def handle_audio_input_stream(self, audio: bytes, session_id: str):
self._sessions[session_id]['stream'].mic.queue.put(audio)

def handle_ww_detected(self, ww_context: dict, session_id: str):
session = self.get_session(session_id)
message = Message("neon.ww_detected", ww_context,
{"session": session})
run(self.send_to_client(message))

def handle_client_input(self, data: dict, session_id: str):
"""
Handle some client input data.
Expand All @@ -134,9 +194,16 @@ def handle_klat_response(self, message: Message):
Handle a Neon text+audio response to a user input.
@param message: `klat.response` message from Neon
"""
self._update_session_data(message)
run(self.send_to_client(message))
LOG.debug(message.context.get("timing"))
try:
self._update_session_data(message)
run(self.send_to_client(message))
session_id = message.context.get('session', {}).get('session_id')
if stream := self._sessions.get(session_id, {}).get('stream'):
LOG.info("Stream response audio")
stream.on_response_audio(message.data)
LOG.debug(message.context.get("timing"))
except Exception as e:
LOG.exception(e)

def handle_complete_intent_failure(self, message: Message):
"""
Expand Down
30 changes: 30 additions & 0 deletions neon_hana/schema/node_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,33 @@ class NodeGetTtsResponse(BaseModel):
msg_type: str = "neon.get_tts.response"
data: KlatResponseData
context: dict


class CoreWWDetected(BaseModel):
msg_type: str = "neon.ww_detected"
data: dict
context: dict


class CoreIntentFailure(BaseModel):
msg_type: str = "complete.intent.failure"
data: dict
context: dict


class CoreErrorResponse(BaseModel):
msg_type: str = "klat.error"
data: dict
context: dict


class CoreClearData(BaseModel):
msg_type: str = "neon.clear_data"
data: dict
context: dict


class CoreAlertExpired(BaseModel):
msg_type: str = "neon.alert_expired"
data: dict
context: dict
Loading

0 comments on commit d851940

Please sign in to comment.