diff --git a/src/selkies_gstreamer/__main__.py b/src/selkies_gstreamer/__main__.py index 0e33db4..14fe51a 100644 --- a/src/selkies_gstreamer/__main__.py +++ b/src/selkies_gstreamer/__main__.py @@ -471,6 +471,9 @@ def main(): parser.add_argument('--metrics_http_port', default=os.environ.get('SELKIES_METRICS_HTTP_PORT', '8000'), help='Port to start the Prometheus metrics server on') + parser.add_argument('--websockets_log_level', + default=os.environ.get('SELKIES_WEBSOCKETS_LOG_LEVEL', "CRITICAL"), + help='By default websockets log level is set to "CRITICAL", allowed values: https://docs.python.org/3/library/logging.html#levels') parser.add_argument('--debug', action='store_true', help='Enable debug logging') args = parser.parse_args() diff --git a/src/selkies_gstreamer/signalling_web.py b/src/selkies_gstreamer/signalling_web.py index 4f7008a..45b9c5f 100644 --- a/src/selkies_gstreamer/signalling_web.py +++ b/src/selkies_gstreamer/signalling_web.py @@ -36,9 +36,20 @@ from pathlib import Path from http import HTTPStatus +from websockets.http11 import Response +from websockets.datastructures import Headers +import email.utils + logger = logging.getLogger("signaling") web_logger = logging.getLogger("web") +# websockets 14.0 (and presumably later) logs an error if a connection is opened and +# closed before any data is sent. The websockets client seems to do same thing causing +# an inital handshake error. So better to just suppress all errors until we think we have +# a problem. You can unsuppress by setting the environment variable to DEBUG. +loglevel = os.getenv("SELKIES_WEBSOCKETS_LOG_LEVEL", "CRITICAL") +logging.getLogger("websockets").setLevel(loglevel) + MIME_TYPES = { "html": "text/html", "js": "text/javascript", @@ -178,28 +189,53 @@ def cache_file(self, full_path): data = open(full_path, 'rb').read() self.http_cache[full_path] = (data, now) return data - - async def process_request(self, server_root, path, request_headers): + + def custom_response(self, status, custom_headers, body): + """A wrapper indentical to https://github.com/python-websockets/websockets/blob/main/src/websockets/server.py#L482 + but allows 'body' type to be either bytes or string. + """ + status = http.HTTPStatus(status) + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "text/plain; charset=utf-8"), + ] + ) + + # overriding and appending headers if provided + for key, value in custom_headers: + if headers.get(key) is not None: + del headers[key] + headers[key] = value + + # Expecting bytes, but if it's string then convert to bytes + if isinstance(body, str): + body = body.encode() + return Response(status.value, status.phrase, headers, body) + + async def process_request(self, server_root, connection, request): response_headers = [ - ('Server', 'asyncio websocket server'), ('Connection', 'close'), ] - + request_headers = request.headers username = '' if self.enable_basic_auth: if "basic" in request_headers.get("authorization", "").lower(): username, passwd = basicauth.decode(request_headers.get("authorization")) if not (username == self.basic_auth_user and passwd == self.basic_auth_password): - return http.HTTPStatus.UNAUTHORIZED, response_headers, b'Unauthorized' + return self.custom_response(http.HTTPStatus.UNAUTHORIZED, response_headers, 'Unauthorized') else: response_headers.append(('WWW-Authenticate', 'Basic realm="restricted, charset="UTF-8"')) - return http.HTTPStatus.UNAUTHORIZED, response_headers, b'Authorization required' + return self.custom_response(http.HTTPStatus.UNAUTHORIZED, response_headers, 'Authorization required') + path = request.path if path == "/ws/" or path == "/ws" or path.endswith("/signalling/") or path.endswith("/signalling"): return None if path == self.health_path + "/" or path == self.health_path: - return http.HTTPStatus.OK, response_headers, b"OK\n" + return connection.respond(http.HTTPStatus.OK, "OK") if path == "/turn/" or path == "/turn": if self.turn_shared_secret: @@ -208,20 +244,18 @@ async def process_request(self, server_root, path, request_headers): username = request_headers.get(self.turn_auth_header_name, "username") if not username: web_logger.warning("HTTP GET {} 401 Unauthorized - missing auth header: {}".format(path, self.turn_auth_header_name)) - return HTTPStatus.UNAUTHORIZED, response_headers, b'401 Unauthorized - missing auth header' + return self.custom_response(HTTPStatus.UNAUTHORIZED, response_headers, '401 Unauthorized - missing auth header') web_logger.info("Generating HMAC credential for user: {}".format(username)) rtc_config = generate_rtc_config(self.turn_host, self.turn_port, self.turn_shared_secret, username, self.turn_protocol, self.turn_tls, self.stun_host, self.stun_port) - return http.HTTPStatus.OK, response_headers, str.encode(rtc_config) + return self.custom_response(http.HTTPStatus.OK, response_headers, rtc_config) elif self.rtc_config: data = self.rtc_config - if type(data) == str: - data = str.encode(data) response_headers.append(('Content-Type', 'application/json')) - return http.HTTPStatus.OK, response_headers, data + return self.custom_response(http.HTTPStatus.OK, response_headers, data) else: web_logger.warning("HTTP GET {} 404 NOT FOUND - Missing RTC config".format(path)) - return HTTPStatus.NOT_FOUND, response_headers, b'404 NOT FOUND' + return self.custom_response(http.HTTPStatus.NOT_FOUND, response_headers, '404 NOT FOUND') path = path.split("?")[0] if path == '/': @@ -235,18 +269,16 @@ async def process_request(self, server_root, path, request_headers): not os.path.exists(full_path) or not os.path.isfile(full_path): response_headers.append(('Content-Type', 'text/html')) web_logger.info("HTTP GET {} 404 NOT FOUND".format(path)) - return HTTPStatus.NOT_FOUND, response_headers, b'404 NOT FOUND' + return self.custom_response(http.HTTPStatus.NOT_FOUND, response_headers, '404 NOT FOUND') # Guess file content type extension = full_path.split(".")[-1] mime_type = MIME_TYPES.get(extension, "application/octet-stream") response_headers.append(('Content-Type', mime_type)) - # Read the whole file into memory and send it out body = self.cache_file(full_path) - response_headers.append(('Content-Length', str(len(body)))) web_logger.info("HTTP GET {} 200 OK".format(path)) - return HTTPStatus.OK, response_headers, body + return self.custom_response(http.HTTPStatus.OK, response_headers, body) async def recv_msg_ping(self, ws, raddr): ''' @@ -450,7 +482,7 @@ def get_ssl_ctx(self, https_server=True): return sslctx async def run(self): - async def handler(ws, path): + async def handler(ws): ''' All incoming messages are handled here. @path is unused. ''' @@ -475,7 +507,7 @@ async def handler(ws, path): # Websocket and HTTP server http_handler = functools.partial(self.process_request, self.web_root) self.stop_server = self.loop.create_future() - async with websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=http_handler, loop=self.loop, + async with websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=http_handler, # Maximum number of messages that websockets will pop # off the asyncio and OS buffers per connection. See: # https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol diff --git a/src/selkies_gstreamer/webrtc_signalling.py b/src/selkies_gstreamer/webrtc_signalling.py index 008c23f..8f042e6 100644 --- a/src/selkies_gstreamer/webrtc_signalling.py +++ b/src/selkies_gstreamer/webrtc_signalling.py @@ -108,7 +108,7 @@ async def connect(self): while True: try: - self.conn = await websockets.connect(self.server, extra_headers=headers, ssl=sslctx) + self.conn = await websockets.connect(self.server, additional_headers=headers, ssl=sslctx) break except ConnectionRefusedError: logger.info("Connecting to signal server...")