diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b15eddf7..b0e15b54 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -558,21 +558,27 @@ def handle_redirect(self, uri: str) -> None: raise SecurityError("redirect from WSS to WS") same_origin = ( - old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port + old_wsuri.secure == new_wsuri.secure + and old_wsuri.host == new_wsuri.host + and old_wsuri.port == new_wsuri.port ) - # Rewrite the host and port arguments for cross-origin redirects. + # Rewrite secure, host, and port for cross-origin redirects. # This preserves connection overrides with the host and port # arguments if the redirect points to the same host and port. if not same_origin: - # Replace the host and port argument passed to the protocol factory. factory = self._create_connection.args[0] + # Support TLS upgrade. + if not old_wsuri.secure and new_wsuri.secure: + factory.keywords["secure"] = True + self._create_connection.keywords.setdefault("ssl", True) + # Replace secure, host, and port arguments of the protocol factory. factory = functools.partial( factory.func, *factory.args, **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), ) - # Replace the host and port argument passed to create_connection. + # Replace secure, host, and port arguments of create_connection. self._create_connection = functools.partial( self._create_connection.func, *(factory, new_wsuri.host, new_wsuri.port), diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c3808657..09b3b361 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -75,6 +75,8 @@ async def redirect_request(path, headers, test, status): location = "/" elif path == "/infinite": location = get_server_uri(test.server, test.secure, "/infinite") + elif path == "/force_secure": + location = get_server_uri(test.server, True, "/") elif path == "/force_insecure": location = get_server_uri(test.server, False, "/") elif path == "/missing_location": @@ -1290,7 +1292,16 @@ def test_connection_error_during_closing_handshake(self, close): class ClientServerTests( CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase ): - pass + + def test_redirect_secure(self): + with temp_test_redirecting_server(self): + # websockets doesn't support serving non-TLS and TLS connections + # from the same server and this test suite makes it difficult to + # run two servers. Therefore, we expect the redirect to create a + # TLS client connection to a non-TLS server, which will fail. + with self.assertRaises(ssl.SSLError): + with self.temp_client("/force_secure"): + self.fail("did not raise") class SecureClientServerTests(