Skip to content

Commit

Permalink
Implement configured maximum node streams with documentation and unit…
Browse files Browse the repository at this point in the history
… tests
  • Loading branch information
NeonDaniel committed Oct 7, 2024
1 parent 5fccf65 commit c997355
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ hana:
enable_email: True # Disabled by default; anyone with access to the API will be able to send emails from the configured address
node_username: node_user # Username to authenticate Node API access; leave empty to disable Node API access
node_password: node_password # Password associated with node_username
max_streaming_clients: -1 # Maximum audio streaming clients allowed (including 0). Default unset value allows infinite clients
```
It is recommended to generate unique values for configured tokens, these are 32
bytes in hexadecimal representation.
Expand Down
26 changes: 17 additions & 9 deletions neon_hana/app/routers/node_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,23 @@ async def node_v1_stream_endpoint(websocket: WebSocket, token: str):
raise HTTPException(status_code=401,
detail=f"Client not known ({client_id})")

await websocket.accept()
disconnect_event = Event()
socket_api.new_stream(websocket, client_id)
while not disconnect_event.is_set():
try:
client_in: bytes = await websocket.receive_bytes()
socket_api.handle_audio_input_stream(client_in, client_id)
except WebSocketDisconnect:
disconnect_event.set()
if not client_manager.check_connect_stream():
raise HTTPException(status_code=503,
detail=f"Server is not accepting any more streams")
try:
await websocket.accept()
disconnect_event = Event()
socket_api.new_stream(websocket, client_id)
while not disconnect_event.is_set():
try:
client_in: bytes = await websocket.receive_bytes()
socket_api.handle_audio_input_stream(client_in, client_id)
except WebSocketDisconnect:
disconnect_event.set()
except Exception as e:
print(e)
finally:
client_manager.disconnect_stream()


@node_route.get("/v1/doc")
Expand Down
24 changes: 24 additions & 0 deletions neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from threading import Lock

import jwt

Expand Down Expand Up @@ -53,7 +54,10 @@ def __init__(self, config: dict):
self._disable_auth = config.get("disable_auth")
self._node_username = config.get("node_username")
self._node_password = config.get("node_password")
self._max_streaming_clients = config.get("max_streaming_clients")
self._jwt_algo = "HS256"
self._connected_streams = 0
self._stream_check_lock = Lock()

def _create_tokens(self, encode_data: dict) -> dict:
# Permissions were not included in old tokens, allow refreshing with
Expand Down Expand Up @@ -88,6 +92,26 @@ def get_permissions(self, client_id: str) -> ClientPermissions:
client = self.authorized_clients[client_id]
return ClientPermissions(**client.get('permissions', dict()))

def check_connect_stream(self) -> bool:
"""
Check if a new stream is allowed
"""
with self._stream_check_lock:
if not isinstance(self._max_streaming_clients, int) or \
self._max_streaming_clients is False or \
self._max_streaming_clients < 0:
self._connected_streams += 1
return True
if self._connected_streams >= self._max_streaming_clients:
LOG.warning(f"No more streams allowed ({self._connected_streams})")
return False
self._connected_streams += 1
return True

def disconnect_stream(self):
with self._stream_check_lock:
self._connected_streams -= 1

def check_auth_request(self, client_id: str, username: str,
password: Optional[str] = None,
origin_ip: str = "127.0.0.1") -> dict:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,33 @@ def test_client_permissions(self):
self.assertFalse(any([v for v in restricted_perms.as_dict().values()]))
self.assertIsInstance(permissive_perms.as_dict(), dict)
self.assertTrue(all([v for v in permissive_perms.as_dict().values()]))

def test_stream_connections(self):
# Test configured maximum
self.client_manager._max_streaming_clients = 1
self.assertTrue(self.client_manager.check_connect_stream())
self.assertEqual(self.client_manager._connected_streams, 1)
self.assertFalse(self.client_manager.check_connect_stream())
self.assertFalse(self.client_manager.check_connect_stream())
self.assertEqual(self.client_manager._connected_streams, 1)
self.client_manager.disconnect_stream()
self.assertEqual(self.client_manager._connected_streams, 0)

# Test explicitly disabled streaming
self.client_manager._max_streaming_clients = 0
self.assertFalse(self.client_manager.check_connect_stream())

# Test unlimited clients
self.client_manager._max_streaming_clients = None
self.assertTrue(self.client_manager.check_connect_stream())
self.assertTrue(self.client_manager.check_connect_stream())
self.assertTrue(self.client_manager.check_connect_stream())
self.assertEqual(self.client_manager._connected_streams, 3)

self.client_manager._max_streaming_clients = -1
self.assertTrue(self.client_manager.check_connect_stream())
self.assertEqual(self.client_manager._connected_streams, 4)

self.client_manager._max_streaming_clients = False
self.assertTrue(self.client_manager.check_connect_stream())
self.assertEqual(self.client_manager._connected_streams, 5)

0 comments on commit c997355

Please sign in to comment.