Skip to content

Commit

Permalink
Add basic support for proxies.
Browse files Browse the repository at this point in the history
This is missing proper error handling, tests and support for WSS.
  • Loading branch information
aaugustin committed Jun 2, 2018
1 parent caecbe4 commit e0b69ab
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 18 deletions.
2 changes: 1 addition & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Client

.. automodule:: websockets.client

.. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds)
.. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', proxy_uri=USE_SYSTEM_PROXY, proxy_ssl=None, **kwds)

.. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None)

Expand Down
100 changes: 85 additions & 15 deletions websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
import collections.abc
import sys
import urllib.request

from .exceptions import (
InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError
Expand All @@ -18,11 +19,13 @@
)
from .http import USER_AGENT, basic_auth_header, build_headers, read_response
from .protocol import WebSocketCommonProtocol
from .uri import parse_uri
from .uri import parse_proxy_uri, parse_uri


__all__ = ['connect', 'WebSocketClientProtocol']

USE_SYSTEM_PROXY = object()


class WebSocketClientProtocol(WebSocketCommonProtocol):
"""
Expand Down Expand Up @@ -195,6 +198,37 @@ def process_subprotocol(headers, available_subprotocols):

return subprotocol

@asyncio.coroutine
def proxy_connect(self, proxy_uri, uri, ssl=None):
assert ssl is None, "proxying TLS/SSL connections isn't supported yet"

request = ['CONNECT {uri.host}:{uri.port} HTTP/1.1'.format(uri=uri)]

headers = []

if uri.port == (443 if uri.secure else 80): # pragma: no cover
headers.append(('Host', uri.host))
else:
headers.append(('Host', '{uri.host}:{uri.port}'.format(uri=uri)))

if proxy_uri.user_info:
headers.append((
'Proxy-Authorization',
basic_auth_header(*proxy_uri.user_info),
))

request.extend('{}: {}'.format(k, v) for k, v in headers)
request.append('\r\n')
request = '\r\n'.join(request).encode()

self.writer.write(request)

status_code, headers = yield from read_response(self.reader)

if not 200 <= status_code < 300:
# TODO improve error handling
raise ValueError("proxy error: HTTP {}".format(status_code))

@asyncio.coroutine
def handshake(self, uri, origin=None, available_extensions=None,
available_subprotocols=None, extra_headers=None):
Expand Down Expand Up @@ -223,10 +257,10 @@ def handshake(self, uri, origin=None, available_extensions=None,
if uri.port == (443 if uri.secure else 80): # pragma: no cover
set_header('Host', uri.host)
else:
set_header('Host', '{}:{}'.format(uri.host, uri.port))
set_header('Host', '{uri.host}:{uri.port}'.format(uri=uri))

if uri.user_info:
set_header(*basic_auth_header(*uri.user_info))
set_header('Authorization', basic_auth_header(*uri.user_info))

if origin is not None:
set_header('Origin', origin)
Expand Down Expand Up @@ -318,6 +352,12 @@ class Connect:
* ``compression`` is a shortcut to configure compression extensions;
by default it enables the "permessage-deflate" extension; set it to
``None`` to disable compression
* ``proxy`` defines the HTTP proxy for establishing the connection; by
default, :func:`connect` uses proxies configured in the environment or
the system (see :func:`~urllib.request.getproxies` for details); set
``proxy`` to ``None`` to disable this behavior
* ``proxy_ssl`` may be set to a :class:`~ssl.SSLContext` to enforce TLS
settings for connecting to a ``https://`` proxy; it defaults to ``True``
:func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is
invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening
Expand All @@ -331,7 +371,9 @@ def __init__(self, uri, *,
read_limit=2 ** 16, write_limit=2 ** 16,
loop=None, legacy_recv=False, klass=None,
origin=None, extensions=None, subprotocols=None,
extra_headers=None, compression='deflate', **kwds):
extra_headers=None, compression='deflate',
proxy_uri=USE_SYSTEM_PROXY, proxy_ssl=None,
ssl=None, sock=None, **kwds):
if loop is None:
loop = asyncio.get_event_loop()

Expand All @@ -345,10 +387,13 @@ def __init__(self, uri, *,

uri = parse_uri(uri)
if uri.secure:
kwds.setdefault('ssl', True)
elif kwds.get('ssl') is not None:
raise ValueError("connect() received a SSL context for a ws:// "
"URI, use a wss:// URI to enable TLS")
if ssl is None:
ssl = True
elif ssl is not None:
raise ValueError(
"connect() received a TLS/SSL context for a ws:// URI;"
"use a wss:// URI to enable TLS",
)

if compression == 'deflate':
if extensions is None:
Expand All @@ -372,18 +417,39 @@ def __init__(self, uri, *,
extra_headers=extra_headers,
)

if kwds.get('sock') is None:
host, port = uri.host, uri.port
else:
if proxy_uri is USE_SYSTEM_PROXY:
proxies = urllib.request.getproxies()
# RFC 6455 recommends to prefer the proxy configured for HTTPS
# connections over the proxy configured for HTTP connections.
proxy_uri = proxies.get('https', proxies.get('http'))

if proxy_uri is not None:
proxy_uri = parse_proxy_uri(proxy_uri)
if proxy_uri.secure:
if proxy_ssl is None:
proxy_ssl = True
elif proxy_ssl is not None:
raise ValueError(
"connect() received a TLS/SSL context for a HTTP proxy; "
"use a HTTPS proxy to enable TLS",
)

if sock is not None:
# If sock is given, host and port mustn't be specified.
host, port = None, None
conn_host, conn_port, conn_ssl = None, None, ssl
elif proxy_uri is not None:
conn_host, conn_port, conn_ssl = (
proxy_uri.host, proxy_uri.port, proxy_ssl)
else:
conn_host, conn_port, conn_ssl = uri.host, uri.port, ssl

self._proxy_uri = proxy_uri
self._uri = uri
self._origin = origin
self._ssl = ssl

# This is a coroutine object.
self._creating_connection = loop.create_connection(
factory, host, port, **kwds)
factory, conn_host, conn_port, ssl=conn_ssl, sock=sock, **kwds)

@asyncio.coroutine
def __aenter__(self):
Expand All @@ -397,8 +463,12 @@ def __await__(self):
transport, protocol = yield from self._creating_connection

try:
if self._proxy_uri is not None:
yield from protocol.proxy_connect(
self._proxy_uri, self._uri, self._ssl)
yield from protocol.handshake(
self._uri, origin=self._origin,
self._uri,
origin=protocol.origin,
available_extensions=protocol.available_extensions,
available_subprotocols=protocol.available_subprotocols,
extra_headers=protocol.extra_headers,
Expand Down
2 changes: 1 addition & 1 deletion websockets/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,4 @@ def basic_auth_header(username, password):
assert ':' not in username
user_pass = '{}:{}'.format(username, password)
basic_credentials = base64.b64encode(user_pass.encode()).decode()
return ('Authorization', 'Basic ' + basic_credentials)
return 'Basic ' + basic_credentials
2 changes: 1 addition & 1 deletion websockets/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,5 @@ def test_basic_auth_header(self):
# Test vector from RFC 7617.
self.assertEqual(
basic_auth_header("Aladdin", "open sesame"),
('Authorization', 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='),
'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==',
)

0 comments on commit e0b69ab

Please sign in to comment.