-
Notifications
You must be signed in to change notification settings - Fork 0
/
src.txt
253 lines (219 loc) · 9.73 KB
/
src.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
class Connect:
    """
    Connect to the WebSocket server at the given ``uri``.

    Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
    can then be used to send and receive messages.

    :func:`connect` can also be used as an asynchronous context manager. In
    that case, the connection is closed when exiting the context.

    :func:`connect` is a wrapper around the event loop's
    :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments
    are passed to :meth:`~asyncio.loop.create_connection`.

    For example, you can set the ``ssl`` keyword argument to a
    :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to
    a ``wss://`` URI, if this argument isn't provided explicitly,
    :func:`ssl.create_default_context` is called to create a context.

    You can connect to a different host and port from those found in ``uri``
    by setting ``host`` and ``port`` keyword arguments. This only changes the
    destination of the TCP connection. The host name from ``uri`` is still
    used in the TLS handshake for secure connections and in the ``Host`` HTTP
    header.

    The ``create_protocol`` parameter allows customizing the
    :class:`~asyncio.Protocol` that manages the connection. It should be a
    callable or class accepting the same arguments as
    :class:`WebSocketClientProtocol` and returning an instance of
    :class:`WebSocketClientProtocol` or a subclass. It defaults to
    :class:`WebSocketClientProtocol`.

    The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is
    described in :class:`~websockets.protocol.WebSocketCommonProtocol`.

    :func:`connect` also accepts the following optional arguments:

    * ``compression`` is a shortcut to configure compression extensions;
      by default it enables the "permessage-deflate" extension; set it to
      ``None`` to disable compression
    * ``origin`` sets the Origin HTTP header
    * ``extensions`` is a list of supported extensions in order of
      decreasing preference
    * ``subprotocols`` is a list of supported subprotocols in order of
      decreasing preference
    * ``extra_headers`` sets additional HTTP request headers; it can be a
      :class:`~websockets.http.Headers` instance, a
      :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
      pairs

    :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid
    :raises ~websockets.handshake.InvalidHandshake: if the opening handshake
        fails

    """

    # Upper bound on connection attempts made while following redirects.
    MAX_REDIRECTS_ALLOWED = 10

    def __init__(
        self,
        uri: str,
        *,
        path: Optional[str] = None,
        create_protocol: Optional[Type[WebSocketClientProtocol]] = None,
        ping_interval: float = 20,
        ping_timeout: float = 20,
        close_timeout: Optional[float] = None,
        max_size: int = 2 ** 20,
        max_queue: int = 2 ** 5,
        read_limit: int = 2 ** 16,
        write_limit: int = 2 ** 16,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        legacy_recv: bool = False,
        klass: Optional[Type[WebSocketClientProtocol]] = None,
        timeout: Optional[float] = None,
        compression: Optional[str] = "deflate",
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        **kwargs: Any,
    ) -> None:
        # Backwards compatibility: close_timeout used to be called timeout.
        if timeout is None:
            timeout = 10
        else:
            # stacklevel=2 attributes the warning to the caller's line
            # rather than to this constructor.
            warnings.warn(
                "rename timeout to close_timeout",
                DeprecationWarning,
                stacklevel=2,
            )

        # If both are specified, timeout is ignored.
        if close_timeout is None:
            close_timeout = timeout

        # Backwards compatibility: create_protocol used to be called klass.
        if klass is None:
            klass = WebSocketClientProtocol
        else:
            warnings.warn(
                "rename klass to create_protocol",
                DeprecationWarning,
                stacklevel=2,
            )

        # If both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        if loop is None:
            loop = asyncio.get_event_loop()

        wsuri = parse_uri(uri)
        if wsuri.secure:
            kwargs.setdefault("ssl", True)
        elif kwargs.get("ssl") is not None:
            raise ValueError(
                "connect() received a ssl argument for a ws:// URI, "
                "use a wss:// URI to enable TLS"
            )

        if compression == "deflate":
            if extensions is None:
                extensions = []
            # Append permessage-deflate unless the caller already configured
            # it; a new list is built so the caller's sequence isn't mutated.
            if not any(
                extension_factory.name == ClientPerMessageDeflateFactory.name
                for extension_factory in extensions
            ):
                extensions = list(extensions) + [
                    ClientPerMessageDeflateFactory(client_max_window_bits=True)
                ]
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        factory = functools.partial(
            create_protocol,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            close_timeout=close_timeout,
            max_size=max_size,
            max_queue=max_queue,
            read_limit=read_limit,
            write_limit=write_limit,
            loop=loop,
            host=wsuri.host,
            port=wsuri.port,
            secure=wsuri.secure,
            legacy_recv=legacy_recv,
            origin=origin,
            extensions=extensions,
            subprotocols=subprotocols,
            extra_headers=extra_headers,
        )

        if path is None:
            host: Optional[str]
            port: Optional[int]
            if kwargs.get("sock") is None:
                host, port = wsuri.host, wsuri.port
            else:
                # If sock is given, host and port shouldn't be specified.
                host, port = None, None
            # If host and port are given, override values from the URI.
            host = kwargs.pop("host", host)
            port = kwargs.pop("port", port)
            create_connection = functools.partial(
                loop.create_connection, factory, host, port, **kwargs
            )
        else:
            create_connection = functools.partial(
                loop.create_unix_connection, factory, path, **kwargs
            )

        # This is a coroutine function.
        self._create_connection = create_connection
        self._wsuri = wsuri

    def handle_redirect(self, uri: str) -> None:
        """
        Update the state of this instance to connect to a new URI.

        :param uri: redirect target, as carried by the
            ``RedirectHandshake`` exception caught while connecting
        :raises SecurityError: if the redirect goes from ``wss://`` to
            ``ws://``

        """
        old_wsuri = self._wsuri
        new_wsuri = parse_uri(uri)

        # Forbid TLS downgrade.
        if old_wsuri.secure and not new_wsuri.secure:
            raise SecurityError("redirect from WSS to WS")

        # Support TLS upgrade. The stored partials were built for a ws://
        # URI: the protocol factory has secure=False and create_connection
        # has no usable ssl argument. Without this, the next attempt would
        # open a plain-text connection to a TLS endpoint.
        if new_wsuri.secure and not old_wsuri.secure:
            factory = self._create_connection.args[0]
            factory = functools.partial(
                factory.func,
                *factory.args,
                **dict(factory.keywords, secure=True),
            )
            keywords = dict(self._create_connection.keywords)
            # For a ws:// URI, __init__ only accepts a missing or None ssl
            # argument, so there is no user-provided context to preserve.
            if keywords.get("ssl") is None:
                keywords["ssl"] = True
            self._create_connection = functools.partial(
                self._create_connection.func,
                factory,
                *self._create_connection.args[1:],
                **keywords,
            )

        same_origin = (
            old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port
        )

        # Rewrite the host and port arguments for cross-origin redirects.
        # This preserves connection overrides with the host and port
        # arguments if the redirect points to the same host and port.
        if not same_origin:
            # Replace the host and port argument passed to the protocol factory.
            factory = self._create_connection.args[0]
            factory = functools.partial(
                factory.func,
                *factory.args,
                **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
            )
            # Replace the host and port argument passed to create_connection.
            self._create_connection = functools.partial(
                self._create_connection.func,
                *(factory, new_wsuri.host, new_wsuri.port),
                **self._create_connection.keywords,
            )

        # Set the new WebSocket URI. This suffices for same-origin redirects.
        self._wsuri = new_wsuri

    # async with connect(...)

    async def __aenter__(self) -> WebSocketClientProtocol:
        """Open the connection and return the protocol instance."""
        return await self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Close the connection stored by a successful handshake."""
        await self.ws_client.close()

    # await connect(...)

    def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl__().__await__()

    async def __await_impl__(self) -> WebSocketClientProtocol:
        """
        Open a connection and perform the opening handshake.

        Makes at most ``MAX_REDIRECTS_ALLOWED`` connection attempts,
        following a redirect after each ``RedirectHandshake``.

        :raises SecurityError: if every allowed attempt ends in a redirect,
            or if a redirect is rejected by :meth:`handle_redirect`

        """
        for _ in range(self.MAX_REDIRECTS_ALLOWED):
            transport, protocol = await self._create_connection()
            # https://github.com/python/typeshed/pull/2756
            transport = cast(asyncio.Transport, transport)
            protocol = cast(WebSocketClientProtocol, protocol)

            try:
                try:
                    await protocol.handshake(
                        self._wsuri,
                        origin=protocol.origin,
                        available_extensions=protocol.available_extensions,
                        available_subprotocols=protocol.available_subprotocols,
                        extra_headers=protocol.extra_headers,
                    )
                except Exception:
                    # Tear down the failed connection, then re-raise so the
                    # outer handler can pick out redirects.
                    protocol.fail_connection()
                    await protocol.wait_closed()
                    raise
                else:
                    self.ws_client = protocol
                    return protocol
            except RedirectHandshake as exc:
                self.handle_redirect(exc.uri)
        else:
            # Reached only when every attempt ended in a redirect.
            raise SecurityError("too many redirects")

    # yield from connect(...)

    __iter__ = __await__