From c99735590923cc970e819f0a5a0971db078e9579 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 7 Oct 2024 15:58:58 -0700 Subject: [PATCH] Implement configured maximum node streams with documentation and unit tests --- README.md | 1 + neon_hana/app/routers/node_server.py | 26 +++++++++++++++--------- neon_hana/auth/client_manager.py | 24 ++++++++++++++++++++++ tests/test_auth.py | 30 ++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index a84a42b..d741025 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ hana: enable_email: True # Disabled by default; anyone with access to the API will be able to send emails from the configured address node_username: node_user # Username to authenticate Node API access; leave empty to disable Node API access node_password: node_password # Password associated with node_username + max_streaming_clients: -1 # Maximum audio streaming clients allowed (including 0). Default unset value allows infinite clients ``` It is recommended to generate unique values for configured tokens, these are 32 bytes in hexadecimal representation. diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index 22a9658..b6a3a34 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -85,15 +85,23 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str): raise HTTPException(status_code=401, detail=f"Client not known ({client_id})") - await websocket.accept() - disconnect_event = Event() - socket_api.new_stream(websocket, client_id) - while not disconnect_event.is_set(): - try: - client_in: bytes = await websocket.receive_bytes() - socket_api.handle_audio_input_stream(client_in, client_id) - except WebSocketDisconnect: - disconnect_event.set() + if not client_manager.check_connect_stream(): + raise HTTPException(status_code=503, + detail=f"Server is not accepting any more streams") + try: + await websocket.accept() + disconnect_event = Event() + socket_api.new_stream(websocket, client_id) + while not disconnect_event.is_set(): + try: + client_in: bytes = await websocket.receive_bytes() + socket_api.handle_audio_input_stream(client_in, client_id) + except WebSocketDisconnect: + disconnect_event.set() + except Exception as e: + print(e) + finally: + client_manager.disconnect_stream() @node_route.get("/v1/doc") diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index d1bb6d0..578ea13 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -23,6 +23,7 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from threading import Lock import jwt @@ -53,7 +54,10 @@ def __init__(self, config: dict): self._disable_auth = config.get("disable_auth") self._node_username = config.get("node_username") self._node_password = config.get("node_password") + self._max_streaming_clients = config.get("max_streaming_clients") self._jwt_algo = "HS256" + self._connected_streams = 0 + self._stream_check_lock = Lock() def _create_tokens(self, encode_data: dict) -> dict: # Permissions were not included in old tokens, allow refreshing with @@ -88,6 +92,26 @@ def get_permissions(self, client_id: str) -> ClientPermissions: client = self.authorized_clients[client_id] return ClientPermissions(**client.get('permissions', dict())) + def check_connect_stream(self) -> bool: + """ + Check if a new stream is allowed + """ + with self._stream_check_lock: + if not isinstance(self._max_streaming_clients, int) or \ + self._max_streaming_clients is False or \ + self._max_streaming_clients < 0: + self._connected_streams += 1 + return True + if self._connected_streams >= self._max_streaming_clients: + LOG.warning(f"No more streams allowed ({self._connected_streams})") + return False + self._connected_streams += 1 + return True + + def disconnect_stream(self): + with self._stream_check_lock: + self._connected_streams -= 1 + def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None, origin_ip: str = "127.0.0.1") -> dict: diff --git a/tests/test_auth.py b/tests/test_auth.py index d2fcbef..88422fb 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -202,3 +202,33 @@ def test_client_permissions(self): self.assertFalse(any([v for v in restricted_perms.as_dict().values()])) self.assertIsInstance(permissive_perms.as_dict(), dict) self.assertTrue(all([v for v in permissive_perms.as_dict().values()])) + + def test_stream_connections(self): + # Test configured maximum + self.client_manager._max_streaming_clients = 1 + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 1) + self.assertFalse(self.client_manager.check_connect_stream()) + self.assertFalse(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 1) + self.client_manager.disconnect_stream() + self.assertEqual(self.client_manager._connected_streams, 0) + + # Test explicitly disabled streaming + self.client_manager._max_streaming_clients = 0 + self.assertFalse(self.client_manager.check_connect_stream()) + + # Test unlimited clients + self.client_manager._max_streaming_clients = None + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 3) + + self.client_manager._max_streaming_clients = -1 + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 4) + + self.client_manager._max_streaming_clients = False + self.assertTrue(self.client_manager.check_connect_stream()) + self.assertEqual(self.client_manager._connected_streams, 5)