Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update webserver to comply with websockets 14.1 version #171

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/selkies_gstreamer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,9 @@ def main():
parser.add_argument('--metrics_http_port',
default=os.environ.get('SELKIES_METRICS_HTTP_PORT', '8000'),
help='Port to start the Prometheus metrics server on')
parser.add_argument('--websockets_log_level',
default=os.environ.get('SELKIES_WEBSOCKETS_LOG_LEVEL', "CRITICAL"),
help='By default websockets log level is set to "CRITICAL", allowed values: https://docs.python.org/3/library/logging.html#levels')
parser.add_argument('--debug', action='store_true',
help='Enable debug logging')
args = parser.parse_args()
Expand Down
70 changes: 51 additions & 19 deletions src/selkies_gstreamer/signalling_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,20 @@
from pathlib import Path
from http import HTTPStatus

from websockets.http11 import Response
from websockets.datastructures import Headers
import email.utils

logger = logging.getLogger("signaling")
web_logger = logging.getLogger("web")

# websockets 14.0 (and presumably later) logs an error if a connection is opened and
# closed before any data is sent. The websockets client seems to do same thing causing
# an inital handshake error. So better to just suppress all errors until we think we have
# a problem. You can unsuppress by setting the environment variable to DEBUG.
loglevel = os.getenv("SELKIES_WEBSOCKETS_LOG_LEVEL", "CRITICAL")
logging.getLogger("websockets").setLevel(loglevel)

MIME_TYPES = {
"html": "text/html",
"js": "text/javascript",
Expand Down Expand Up @@ -178,28 +189,53 @@ def cache_file(self, full_path):
data = open(full_path, 'rb').read()
self.http_cache[full_path] = (data, now)
return data

async def process_request(self, server_root, path, request_headers):

def custom_response(self, status, custom_headers, body):
"""A wrapper indentical to https://github.com/python-websockets/websockets/blob/main/src/websockets/server.py#L482
but allows 'body' type to be either bytes or string.
"""
status = http.HTTPStatus(status)
headers = Headers(
[
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", "text/plain; charset=utf-8"),
]
)

# overriding and appending headers if provided
for key, value in custom_headers:
if headers.get(key) is not None:
del headers[key]
headers[key] = value

# Expecting bytes, but if it's string then convert to bytes
if isinstance(body, str):
body = body.encode()
return Response(status.value, status.phrase, headers, body)

async def process_request(self, server_root, connection, request):
response_headers = [
('Server', 'asyncio websocket server'),
('Connection', 'close'),
]

request_headers = request.headers
username = ''
if self.enable_basic_auth:
if "basic" in request_headers.get("authorization", "").lower():
username, passwd = basicauth.decode(request_headers.get("authorization"))
if not (username == self.basic_auth_user and passwd == self.basic_auth_password):
return http.HTTPStatus.UNAUTHORIZED, response_headers, b'Unauthorized'
return self.custom_response(http.HTTPStatus.UNAUTHORIZED, response_headers, 'Unauthorized')
else:
response_headers.append(('WWW-Authenticate', 'Basic realm="restricted, charset="UTF-8"'))
return http.HTTPStatus.UNAUTHORIZED, response_headers, b'Authorization required'
return self.custom_response(http.HTTPStatus.UNAUTHORIZED, response_headers, 'Authorization required')

path = request.path
if path == "/ws/" or path == "/ws" or path.endswith("/signalling/") or path.endswith("/signalling"):
return None

if path == self.health_path + "/" or path == self.health_path:
return http.HTTPStatus.OK, response_headers, b"OK\n"
return connection.respond(http.HTTPStatus.OK, "OK")

if path == "/turn/" or path == "/turn":
if self.turn_shared_secret:
Expand All @@ -208,20 +244,18 @@ async def process_request(self, server_root, path, request_headers):
username = request_headers.get(self.turn_auth_header_name, "username")
if not username:
web_logger.warning("HTTP GET {} 401 Unauthorized - missing auth header: {}".format(path, self.turn_auth_header_name))
return HTTPStatus.UNAUTHORIZED, response_headers, b'401 Unauthorized - missing auth header'
return self.custom_response(HTTPStatus.UNAUTHORIZED, response_headers, '401 Unauthorized - missing auth header')
web_logger.info("Generating HMAC credential for user: {}".format(username))
rtc_config = generate_rtc_config(self.turn_host, self.turn_port, self.turn_shared_secret, username, self.turn_protocol, self.turn_tls, self.stun_host, self.stun_port)
return http.HTTPStatus.OK, response_headers, str.encode(rtc_config)
return self.custom_response(http.HTTPStatus.OK, response_headers, rtc_config)

elif self.rtc_config:
data = self.rtc_config
if type(data) == str:
data = str.encode(data)
response_headers.append(('Content-Type', 'application/json'))
return http.HTTPStatus.OK, response_headers, data
return self.custom_response(http.HTTPStatus.OK, response_headers, data)
else:
web_logger.warning("HTTP GET {} 404 NOT FOUND - Missing RTC config".format(path))
return HTTPStatus.NOT_FOUND, response_headers, b'404 NOT FOUND'
return self.custom_response(http.HTTPStatus.NOT_FOUND, response_headers, '404 NOT FOUND')

path = path.split("?")[0]
if path == '/':
Expand All @@ -235,18 +269,16 @@ async def process_request(self, server_root, path, request_headers):
not os.path.exists(full_path) or not os.path.isfile(full_path):
response_headers.append(('Content-Type', 'text/html'))
web_logger.info("HTTP GET {} 404 NOT FOUND".format(path))
return HTTPStatus.NOT_FOUND, response_headers, b'404 NOT FOUND'
return self.custom_response(http.HTTPStatus.NOT_FOUND, response_headers, '404 NOT FOUND')

# Guess file content type
extension = full_path.split(".")[-1]
mime_type = MIME_TYPES.get(extension, "application/octet-stream")
response_headers.append(('Content-Type', mime_type))

# Read the whole file into memory and send it out
body = self.cache_file(full_path)
response_headers.append(('Content-Length', str(len(body))))
web_logger.info("HTTP GET {} 200 OK".format(path))
return HTTPStatus.OK, response_headers, body
return self.custom_response(http.HTTPStatus.OK, response_headers, body)

async def recv_msg_ping(self, ws, raddr):
'''
Expand Down Expand Up @@ -450,7 +482,7 @@ def get_ssl_ctx(self, https_server=True):
return sslctx

async def run(self):
async def handler(ws, path):
async def handler(ws):
'''
All incoming messages are handled here. @path is unused.
'''
Expand All @@ -475,7 +507,7 @@ async def handler(ws, path):
# Websocket and HTTP server
http_handler = functools.partial(self.process_request, self.web_root)
self.stop_server = self.loop.create_future()
async with websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=http_handler, loop=self.loop,
async with websockets.serve(handler, self.addr, self.port, ssl=sslctx, process_request=http_handler,
# Maximum number of messages that websockets will pop
# off the asyncio and OS buffers per connection. See:
# https://websockets.readthedocs.io/en/stable/api.html#websockets.protocol.WebSocketCommonProtocol
Expand Down
2 changes: 1 addition & 1 deletion src/selkies_gstreamer/webrtc_signalling.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def connect(self):

while True:
try:
self.conn = await websockets.connect(self.server, extra_headers=headers, ssl=sslctx)
self.conn = await websockets.connect(self.server, additional_headers=headers, ssl=sslctx)
break
except ConnectionRefusedError:
logger.info("Connecting to signal server...")
Expand Down
Loading