From b9ac8536b71a0bfc2cc8cf37378da938b690304b Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 28 Mar 2022 23:06:31 -0700 Subject: [PATCH 01/26] Initial implementation --- .../azure/eventhub/_client_base.py | 1 - .../azure/eventhub/_consumer.py | 1 + .../azure/eventhub/_producer.py | 3 +- .../azure/eventhub/_pyamqp/_connection.py | 1 + .../azure/eventhub/_pyamqp/_transport.py | 36 ++++++++++++++++++- .../azure/eventhub/_pyamqp/client.py | 7 +++- .../azure/eventhub/_pyamqp/constants.py | 12 +++++++ .../azure/eventhub/_pyamqp/sasl.py | 14 +++++--- 8 files changed, 67 insertions(+), 8 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index fdd8c7f297bd..8e357dd2995b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -325,7 +325,6 @@ def _create_auth(self): token_type=token_type, timeout=self._config.auth_timeout, http_proxy=self._config.http_proxy, - transport_type=self._config.transport_type, custom_endpoint_hostname=self._config.custom_endpoint_hostname, port=self._config.connection_port, verify=self._config.connection_verify, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index ddb9a14a166f..29ae7ef906d5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -156,6 +156,7 @@ def _create_handler(self, auth): auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + transport_type=self._client._config.transport_type, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, retry_policy=self._retry_policy, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index a72ce0753980..669ff6a78552 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -130,7 +130,8 @@ def _create_handler(self, auth): self._target, auth=auth, idle_timeout=self._idle_timeout, - network_trace=self._client._config.network_tracing, # pylint: disable=protected-access + network_trace=self._client._config.network_tracing, # pylint:disable=protected-access + transport_type=self._client._config.transport_type # pylint:disable=protected-access retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index a26d220f3286..3b4ceeafa00c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -77,6 +77,7 @@ class Connection(object): Default value is `0.1`. :keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames will be logged at the logging.INFO level. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. """ def __init__(self, endpoint, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 85371fdd07d9..cd1647ee4341 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -51,7 +51,7 @@ from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame -from .constants import TLS_HEADER_FRAME +from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT try: @@ -655,3 +655,37 @@ def Transport(host, connect_timeout=None, ssl=False, **kwargs): """ transport = SSLTransport if ssl else TCPTransport return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + +class WebSocketTransport(_AbstractTransport): + def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, read_timeout=None, write_timeout=None, + socket_settings=None, raise_on_initial_eintr=True, **kwargs + ): + super().__init__( + host, port, connect_timeout, read_timeout, write_timeout, socket_settings, raise_on_initial_eintr, **kwargs + ) + self.ws = None + try: + from websocket import create_connection + # TODO: transform ssl to sslopt + self.ws = create_connection( + host, + timeout=connect_timeout, + skip_utf8_validation=True, + sslopt=kwargs.pop('ssl', None) + ) + except ImportError: + raise ValueError("Please install websocket-client library to use websocket transport.") + + + def _read(self, n, initial=False): # pylint: disable=unused-arguments + """Read exactly n bytes from the peer.""" + result = self.ws.recv() + return result + + def _shutdown_transport(self): + """Do any preliminary work in shutting down the connection.""" + self.ws.close() + + def _write(self, s): + """Completely write a string to the peer.""" + self.ws.send(s) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index 09d3303a2698..6e7fa3dce98c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -37,6 +37,7 @@ SenderSettleMode, ReceiverSettleMode, LinkDeliverySettleReason, + TransportType, SEND_DISPOSITION_ACCEPT, SEND_DISPOSITION_REJECT, AUTH_TYPE_CBS, @@ -155,6 +156,9 @@ def __init__(self, hostname, auth=None, **kwargs): self._receive_settle_mode = kwargs.pop('receive_settle_mode', ReceiverSettleMode.Second) self._desired_capabilities = kwargs.pop('desired_capabilities', None) + # transport + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) + def __enter__(self): """Run Client in a context manager.""" self.open() @@ -240,7 +244,8 @@ def open(self): channel_max=self._channel_max, idle_timeout=self._idle_timeout, properties=self._properties, - network_trace=self._network_trace + network_trace=self._network_trace, + transport=self._transport_type ) self._connection.open() if not self._session: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py index 0e60bbca7a56..a5e2928d1567 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py @@ -14,6 +14,8 @@ #: The port number is reserved for future transport mappings to these protocols. PORT = 5672 +# default port for AMQP over Websocket +WEBSOCKET_PORT = 443 #: The IANA assigned port number for secure AMQP (amqps).The standard AMQP port number that has been assigned #: by IANA for secure TCP using TLS. Implementations listening on this port should NOT expect a protocol @@ -302,3 +304,13 @@ class MessageDeliveryState(object): MessageDeliveryState.Timeout, MessageDeliveryState.Cancelled ) + +class TransportType(Enum): + """Transport type + The underlying transport protocol type: + Amqp: AMQP over the default TCP transport protocol, it uses port 5671. + AmqpOverWebsocket: Amqp over the Web Sockets transport protocol, it uses + port 443. + """ + Amqp = 1 + AmqpOverWebsocket = 2 diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index 99dd25d43730..c7882ac8a94f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -7,9 +7,9 @@ import struct from enum import Enum -from ._transport import SSLTransport, AMQPS_PORT +from ._transport import SSLTransport, WebSocketTransport, AMQPS_PORT from .types import AMQPTypes, TYPE, VALUE -from .constants import FIELD, SASLCode, SASL_HEADER_FRAME +from .constants import FIELD, SASLCode, SASL_HEADER_FRAME, TransportType, WEBSOCKET_PORT from .performatives import ( SASLOutcome, SASLResponse, @@ -69,12 +69,18 @@ def start(self): return b'' -class SASLTransport(SSLTransport): +class SASLTransport(): def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential ssl = ssl or True - super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + self._transport = SSLTransport(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + amqp_over_websocket = kwargs.pop('transport_type') + if amqp_over_websocket is TransportType.AmqpOverWebSocket: + self._transport = WebSocketTransport(host, port=WEBSOCKET_PORT, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + self.read = self._transport.read + self.write = self._transport.write + super(SASLTransport, self).__init__( **kwargs) def negotiate(self): with self.block(): From 845cd167fd79e96e258b40766ae998949ad92585 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 28 Mar 2022 23:29:57 -0700 Subject: [PATCH 02/26] http proxy support --- .../azure-eventhub/azure/eventhub/_client_base.py | 1 - .../azure-eventhub/azure/eventhub/_consumer.py | 1 + .../azure-eventhub/azure/eventhub/_producer.py | 3 ++- .../azure/eventhub/_pyamqp/_connection.py | 3 +++ .../azure/eventhub/_pyamqp/_transport.py | 14 +++++++++++++- .../azure/eventhub/_pyamqp/client.py | 4 +++- .../azure-eventhub/azure/eventhub/_pyamqp/sasl.py | 10 +++++++++- 7 files changed, 31 insertions(+), 5 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index 8e357dd2995b..85c57dbe7b81 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -324,7 +324,6 @@ def _create_auth(self): functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE), token_type=token_type, timeout=self._config.auth_timeout, - http_proxy=self._config.http_proxy, custom_endpoint_hostname=self._config.custom_endpoint_hostname, port=self._config.connection_port, verify=self._config.connection_verify, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 29ae7ef906d5..0c33ab46a798 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -157,6 +157,7 @@ def _create_handler(self, auth): idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access transport_type=self._client._config.transport_type, # pylint:disable=protected-access + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, retry_policy=self._retry_policy, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index 669ff6a78552..feb5fcc523ca 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -131,7 +131,8 @@ def _create_handler(self, auth): auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access - transport_type=self._client._config.transport_type # pylint:disable=protected-access + transport_type=self._client._config.transport_type, # pylint:disable=protected-access + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index 3b4ceeafa00c..33b57e51999d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -78,6 +78,9 @@ class Connection(object): :keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames will be logged at the logging.INFO level. :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. """ def __init__(self, endpoint, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index cd1647ee4341..866720393426 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -664,6 +664,15 @@ def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, read_timeout host, port, connect_timeout, read_timeout, write_timeout, socket_settings, raise_on_initial_eintr, **kwargs ) self.ws = None + http_proxy = kwargs.get('http_proxy', None) + http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None + if http_proxy: + http_proxy_host = http_proxy['proxy_hostname'] + http_proxy_port = http_proxy['proxy_hostname'] + username = http_proxy.get('username', None) + password = http_proxy.get('password', None) + if username or password: + http_proxy_auth = (username, password) try: from websocket import create_connection # TODO: transform ssl to sslopt @@ -671,7 +680,10 @@ def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, read_timeout host, timeout=connect_timeout, skip_utf8_validation=True, - sslopt=kwargs.pop('ssl', None) + sslopt=kwargs.pop('ssl', None), + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, + http_proxy_auth=http_proxy_auth ) except ImportError: raise ValueError("Please install websocket-client library to use websocket transport.") diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index 6e7fa3dce98c..cb987f06ff51 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -158,6 +158,7 @@ def __init__(self, hostname, auth=None, **kwargs): # transport self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) + self._http_proxy = kwargs.pop('http_proxy', None) def __enter__(self): """Run Client in a context manager.""" @@ -245,7 +246,8 @@ def open(self): idle_timeout=self._idle_timeout, properties=self._properties, network_trace=self._network_trace, - transport=self._transport_type + transport=self._transport_type, + http_proxy=self._http_proxy ) self._connection.open() if not self._session: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index c7882ac8a94f..c80d5b68023c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -74,10 +74,18 @@ class SASLTransport(): def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential ssl = ssl or True + http_proxy = kwargs.pop('http_proxy', None) self._transport = SSLTransport(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) amqp_over_websocket = kwargs.pop('transport_type') if amqp_over_websocket is TransportType.AmqpOverWebSocket: - self._transport = WebSocketTransport(host, port=WEBSOCKET_PORT, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + self._transport = WebSocketTransport( + host, + port=WEBSOCKET_PORT, + connect_timeout=connect_timeout, + ssl=ssl, + http_proxy=http_proxy, + **kwargs + ) self.read = self._transport.read self.write = self._transport.write super(SASLTransport, self).__init__( **kwargs) From 4f524b1d3be6183eb744443e420b7f40423786e8 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Wed, 6 Apr 2022 21:10:52 -0700 Subject: [PATCH 03/26] change impl --- .../azure/eventhub/_pyamqp/_connection.py | 23 +++++-- .../azure/eventhub/_pyamqp/_transport.py | 64 ++++++++++++++----- .../azure/eventhub/_pyamqp/client.py | 2 +- .../azure/eventhub/_pyamqp/message.py | 2 + .../azure/eventhub/_pyamqp/sasl.py | 63 ++++++++++++------ .../pyamqp_tests/synctests/test_websocket.py | 27 ++++++++ 6 files changed, 138 insertions(+), 43 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index 33b57e51999d..deb6134696d0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -12,7 +12,7 @@ from ssl import SSLError from ._transport import Transport -from .sasl import SASLTransport +from .sasl import SASLTransport, SASLTransportWithWebSocket from .session import Session from .performatives import OpenFrame, CloseFrame from .constants import ( @@ -22,7 +22,8 @@ MAX_FRAME_SIZE_BYTES, HEADER_FRAME, ConnectionState, - EMPTY_FRAME + EMPTY_FRAME, + TransportType ) from .error import ( @@ -96,14 +97,22 @@ def __init__(self, endpoint, **kwargs): self.state = None # type: Optional[ConnectionState] transport = kwargs.get('transport') + self._transport_type = kwargs.pop('transport_type') if transport: self._transport = transport elif 'sasl_credential' in kwargs: - self._transport = SASLTransport( - host=parsed_url.netloc, - credential=kwargs['sasl_credential'], - **kwargs - ) + if self._transport_type is TransportType.AmqpOverWebsocket: + self._transport = SASLTransportWithWebSocket( + host=parsed_url.netloc, + credential=kwargs['sasl_credential'], + **kwargs + ) + else: + self._transport = SASLTransport( + host=parsed_url.netloc, + credential=kwargs['sasl_credential'], + **kwargs + ) else: self._transport = Transport(parsed_url.netloc, **kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 866720393426..c11c739a67e5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -51,7 +51,7 @@ from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame -from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT +from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType try: @@ -653,34 +653,42 @@ def Transport(host, connect_timeout=None, ssl=False, **kwargs): Given a few parameters from the Connection constructor, select and create a subclass of _AbstractTransport. """ + transport_type = kwargs.pop('transport_type') + if transport_type == TransportType.AmqpOverWebsocket: + transport = WebSocketTransport transport = SSLTransport if ssl else TCPTransport return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) class WebSocketTransport(_AbstractTransport): - def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, read_timeout=None, write_timeout=None, - socket_settings=None, raise_on_initial_eintr=True, **kwargs + def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs ): + self.sslopts = ssl if isinstance(ssl, dict) else {} + self._read_buffer = BytesIO() super().__init__( - host, port, connect_timeout, read_timeout, write_timeout, socket_settings, raise_on_initial_eintr, **kwargs + host, port, connect_timeout, **kwargs ) self.ws = None - http_proxy = kwargs.get('http_proxy', None) + self._http_proxy = kwargs.get('http_proxy', None) + + def connect(self): + http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None - if http_proxy: - http_proxy_host = http_proxy['proxy_hostname'] - http_proxy_port = http_proxy['proxy_hostname'] - username = http_proxy.get('username', None) - password = http_proxy.get('password', None) + if self._http_proxy: + http_proxy_host = self._http_proxy['proxy_hostname'] + http_proxy_port = self._http_proxy['proxy_hostname'] + username = self._http_proxy.get('username', None) + password = self._http_proxy.get('password', None) if username or password: http_proxy_auth = (username, password) try: from websocket import create_connection # TODO: transform ssl to sslopt self.ws = create_connection( - host, - timeout=connect_timeout, + url="wss://{}/$servicebus/websocket/".format("testeh.servicebus.windows.net"), + subprotocols=['AMQPWSB10'], + timeout=self.connect_timeout, skip_utf8_validation=True, - sslopt=kwargs.pop('ssl', None), + sslopt=self.sslopts, http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port, http_proxy_auth=http_proxy_auth @@ -688,11 +696,33 @@ def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, read_timeout except ImportError: raise ValueError("Please install websocket-client library to use websocket transport.") - - def _read(self, n, initial=False): # pylint: disable=unused-arguments + def _read(self, n, initial=False, **kwargs): # pylint: disable=unused-arguments """Read exactly n bytes from the peer.""" - result = self.ws.recv() - return result + length = 0 + buffer = kwargs.get('buffer', None) + view = buffer or memoryview(bytearray(toread)) + nbytes = self._read_buffer.readinto(view) + toread -= nbytes + length += nbytes + try: + while toread: + data = self.ws.recv() + if len(data) <= toread: + view = memoryview(bytes(view) + data) + toread = 0 + else: + view = memoryview(bytes(view) + data[:n]) + toread -= n + length += len(data) + if not nbytes: + raise IOError('Server unexpectedly closed connection') + + length += nbytes + toread -= nbytes + except: # noqa + self._read_buffer = BytesIO(view[:length]) + raise + return view def _shutdown_transport(self): """Do any preliminary work in shutting down the connection.""" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index cb987f06ff51..fdf98b77e241 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -246,7 +246,7 @@ def open(self): idle_timeout=self._idle_timeout, properties=self._properties, network_trace=self._network_trace, - transport=self._transport_type, + transport_type=self._transport_type, http_proxy=self._http_proxy ) self._connection.open() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py index a2ef0087fd94..1cf777d3ec24 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py @@ -265,3 +265,5 @@ def __init__(self, message, state=MessageDeliveryState.WaitingToBeSent, expiry=N self.reason = None self.delivery = None self.error = None + +__all__ = ['Message'] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index c80d5b68023c..fd0023b395a1 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -69,28 +69,14 @@ def start(self): return b'' -class SASLTransport(): +class SASLTransport(SSLTransport): def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential ssl = ssl or True - http_proxy = kwargs.pop('http_proxy', None) - self._transport = SSLTransport(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) - amqp_over_websocket = kwargs.pop('transport_type') - if amqp_over_websocket is TransportType.AmqpOverWebSocket: - self._transport = WebSocketTransport( - host, - port=WEBSOCKET_PORT, - connect_timeout=connect_timeout, - ssl=ssl, - http_proxy=http_proxy, - **kwargs - ) - self.read = self._transport.read - self.write = self._transport.write - super(SASLTransport, self).__init__( **kwargs) - - def negotiate(self): + super().__init__(host, port, connect_timeout, ssl, **kwargs) + + def negotiate(self): with self.block(): self.write(SASL_HEADER_FRAME) _, returned_header = self.receive_frame() @@ -115,3 +101,44 @@ def negotiate(self): return else: raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + +class SASLTransportWithWebSocket(WebSocketTransport): + + def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): + self.credential = credential + ssl = ssl or True + http_proxy = kwargs.pop('http_proxy', None) + self._transport = WebSocketTransport( + host, + port=WEBSOCKET_PORT, + connect_timeout=connect_timeout, + ssl=ssl, + http_proxy=http_proxy, + **kwargs + ) + super().__init__(host, port, connect_timeout, ssl, **kwargs) + + def negotiate(self): + self.write(SASL_HEADER_FRAME) + _, returned_header = self.receive_frame() + if returned_header[1] != SASL_HEADER_FRAME: + raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( + SASL_HEADER_FRAME, returned_header[1])) + + _, supported_mechansisms = self.receive_frame(verify_frame_type=1) + if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms + raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) + sasl_init = SASLInit( + mechanism=self.credential.mechanism, + initial_response=self.credential.start(), + hostname=self.host) + self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) + + _, next_frame = self.receive_frame(verify_frame_type=1) + frame_type, fields = next_frame + if frame_type != 0x00000044: # SASLOutcome + raise NotImplementedError("Unsupported SASL challenge") + if fields[0] == SASLCode.Ok: # code + return + else: + raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py new file mode 100644 index 000000000000..79bf6d5ecbe8 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import pytest + +from azure.eventhub import TransportType +from azure.eventhub._pyamqp import authentication, ReceiveClient + +def test_event_hubs_client_web_socket(live_eventhub): + uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) + sas_auth = authentication.SASTokenAuth( + uri=uri, + audience=uri, + username=live_eventhub['key_name'], + password=live_eventhub['access_key'] + ) + + source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( + live_eventhub['hostname'], + live_eventhub['event_hub'], + live_eventhub['consumer_group'], + live_eventhub['partition']) + + with ReceiveClient(live_eventhub['hostname'], source, auth=sas_auth, debug=False, timeout=5000, prefetch=50) as receive_client: + receive_client.receive_message_batch(max_batch_size=10) From 5d42efc4a2c1b7785bce0fae5020c36d0fe8ef5e Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Thu, 14 Apr 2022 16:23:33 -0700 Subject: [PATCH 04/26] more changes --- .../azure/eventhub/_pyamqp/_transport.py | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index c11c739a67e5..a70487da6deb 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -392,7 +392,8 @@ def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? read_frame_buffer = BytesIO() try: frame_header = memoryview(bytearray(8)) - read_frame_buffer.write(read(8, buffer=frame_header, initial=True)) + data = read(8, buffer=frame_header, initial=True) + read_frame_buffer.write(data) channel = struct.unpack('>H', frame_header[6:])[0] size = frame_header[0:4] @@ -663,7 +664,6 @@ class WebSocketTransport(_AbstractTransport): def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs ): self.sslopts = ssl if isinstance(ssl, dict) else {} - self._read_buffer = BytesIO() super().__init__( host, port, connect_timeout, **kwargs ) @@ -686,7 +686,7 @@ def connect(self): self.ws = create_connection( url="wss://{}/$servicebus/websocket/".format("testeh.servicebus.windows.net"), subprotocols=['AMQPWSB10'], - timeout=self.connect_timeout, + timeout=9999, skip_utf8_validation=True, sslopt=self.sslopts, http_proxy_host=http_proxy_host, @@ -696,31 +696,34 @@ def connect(self): except ImportError: raise ValueError("Please install websocket-client library to use websocket transport.") - def _read(self, n, initial=False, **kwargs): # pylint: disable=unused-arguments + def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments """Read exactly n bytes from the peer.""" + length = 0 - buffer = kwargs.get('buffer', None) - view = buffer or memoryview(bytearray(toread)) + view = buffer or memoryview(bytearray(n)) + rbuf = self._read_buffer nbytes = self._read_buffer.readinto(view) - toread -= nbytes length += nbytes try: - while toread: + while length < n: data = self.ws.recv() - if len(data) <= toread: - view = memoryview(bytes(view) + data) - toread = 0 + try: + data = bytes(data, 'utf-8') + except TypeError: + pass + + if len(data) + length < n: + for i in range(len(data)): + view[length+i] = data[i] + length += len(data) else: - view = memoryview(bytes(view) + data[:n]) - toread -= n - length += len(data) - if not nbytes: - raise IOError('Server unexpectedly closed connection') - - length += nbytes - toread -= nbytes - except: # noqa - self._read_buffer = BytesIO(view[:length]) + for i in range(n-length): + view[length+i] = data[i] + rbuf.write(data[n-length:]) + length = n + print('done with while loop') + except: + self._read_buffer = rbuf raise return view @@ -730,4 +733,7 @@ def _shutdown_transport(self): def _write(self, s): """Completely write a string to the peer.""" - self.ws.send(s) + try: + self.ws.send(s) + except: + raise From 4eba1c173f1ad1d47722aba6959954a23a2edcdc Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 18 Apr 2022 15:10:20 -0700 Subject: [PATCH 05/26] working sol --- .../azure/eventhub/_pyamqp/_transport.py | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index a70487da6deb..462f4aad01a3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -47,6 +47,7 @@ from threading import Lock import certifi +import websockets from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from ._encode import encode_frame @@ -457,7 +458,6 @@ def send_frame(self, channel, frame, **kwargs): else: encoded_channel = struct.pack('>H', channel) data = header + encoded_channel + performative - self.write(data) def negotiate(self, encode, decode): @@ -664,6 +664,8 @@ class WebSocketTransport(_AbstractTransport): def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs ): self.sslopts = ssl if isinstance(ssl, dict) else {} + self._connect_timeout = connect_timeout + self._host = host super().__init__( host, port, connect_timeout, **kwargs ) @@ -671,7 +673,6 @@ def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, self._http_proxy = kwargs.get('http_proxy', None) def connect(self): - http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None if self._http_proxy: http_proxy_host = self._http_proxy['proxy_hostname'] @@ -684,9 +685,9 @@ def connect(self): from websocket import create_connection # TODO: transform ssl to sslopt self.ws = create_connection( - url="wss://{}/$servicebus/websocket/".format("testeh.servicebus.windows.net"), + url="wss://{}/$servicebus/websocket/".format(self._host), subprotocols=['AMQPWSB10'], - timeout=9999, + timeout=self._connect_timeout, skip_utf8_validation=True, sslopt=self.sslopts, http_proxy_host=http_proxy_host, @@ -701,29 +702,21 @@ def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments length = 0 view = buffer or memoryview(bytearray(n)) - rbuf = self._read_buffer nbytes = self._read_buffer.readinto(view) length += nbytes + n -= nbytes try: - while length < n: + while n: data = self.ws.recv() - try: - data = bytes(data, 'utf-8') - except TypeError: - pass - - if len(data) + length < n: - for i in range(len(data)): - view[length+i] = data[i] - length += len(data) + + if len(data) <= n: + view[length: length + len(data)] = data + n -= len(data) else: - for i in range(n-length): - view[length+i] = data[i] - rbuf.write(data[n-length:]) - length = n - print('done with while loop') + view[length: length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 except: - self._read_buffer = rbuf raise return view @@ -734,6 +727,12 @@ def _shutdown_transport(self): def _write(self, s): """Completely write a string to the peer.""" try: - self.ws.send(s) - except: - raise + """ + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + from websocket import ABNF + self.ws.send(s, opcode=ABNF.OPCODE_BINARY) + except ImportError: + raise ValueError("Please install websocket-client library to use websocket transport.") From 3ce869ed2da5fd60fad6a76c7c70f958990c0e13 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 18 Apr 2022 22:18:06 -0700 Subject: [PATCH 06/26] async impl --- .../azure/eventhub/_pyamqp/_transport.py | 23 ++++----- .../eventhub/_pyamqp/aio/_client_async.py | 4 +- .../eventhub/_pyamqp/aio/_connection_async.py | 27 +++++++--- .../azure/eventhub/_pyamqp/aio/_sasl_async.py | 21 +++++++- .../eventhub/_pyamqp/aio/_transport_async.py | 50 ++++++++++++++++++- .../azure/eventhub/_pyamqp/sasl.py | 6 +-- .../azure/eventhub/aio/_consumer_async.py | 2 + .../azure/eventhub/aio/_producer_async.py | 2 + .../pyamqp_tests/synctests/test_websocket.py | 5 +- 9 files changed, 109 insertions(+), 31 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 462f4aad01a3..fa7f433edaae 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -661,7 +661,7 @@ def Transport(host, connect_timeout=None, ssl=False, **kwargs): return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) class WebSocketTransport(_AbstractTransport): - def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs + def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs ): self.sslopts = ssl if isinstance(ssl, dict) else {} self._connect_timeout = connect_timeout @@ -705,19 +705,16 @@ def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments nbytes = self._read_buffer.readinto(view) length += nbytes n -= nbytes - try: - while n: - data = self.ws.recv() + while n: + data = self.ws.recv() - if len(data) <= n: - view[length: length + len(data)] = data - n -= len(data) - else: - view[length: length + n] = data[0:n] - self._read_buffer = BytesIO(data[n:]) - n = 0 - except: - raise + if len(data) <= n: + view[length: length + len(data)] = data + n -= len(data) + else: + view[length: length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 return view def _shutdown_transport(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index e1b88b192690..5c10bcdbd4d4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -201,7 +201,9 @@ async def open_async(self): channel_max=self._channel_max, idle_timeout=self._idle_timeout, properties=self._properties, - network_trace=self._network_trace + network_trace=self._network_trace, + transport_type=self._transport_type, + http_proxy=self._http_proxy ) await self._connection.open() if not self._session: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index 3bfa62569e9a..15baa7df1875 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -16,7 +16,7 @@ import asyncio from ._transport_async import AsyncTransport -from ._sasl_async import SASLTransport +from ._sasl_async import SASLTransport, SASLTransportWithWebSocket from ._session_async import Session from ..performatives import OpenFrame, CloseFrame from .._connection import get_local_timeout @@ -27,7 +27,8 @@ MAX_CHANNELS, HEADER_FRAME, ConnectionState, - EMPTY_FRAME + EMPTY_FRAME, + TransportType, ) from ..error import ( @@ -58,11 +59,16 @@ class Connection(object): :param list(str) offered_capabilities: The extension capabilities the sender supports. :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports :param dict properties: Connection properties. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + Additionally the following keys may also be present: `'username', 'password'`. """ def __init__(self, endpoint, **kwargs): parsed_url = urlparse(endpoint) self.hostname = parsed_url.hostname + self._transport_type = kwargs.pop('transport_type') if parsed_url.port: self.port = parsed_url.port elif parsed_url.scheme == 'amqps': @@ -75,11 +81,18 @@ def __init__(self, endpoint, **kwargs): if transport: self.transport = transport elif 'sasl_credential' in kwargs: - self.transport = SASLTransport( - host=parsed_url.netloc, - credential=kwargs['sasl_credential'], - **kwargs - ) + if self._transport_type is TransportType.AmqpOverWebsocket: + self._transport = SASLTransportWithWebSocket( + host=parsed_url.netloc, + credential=kwargs['sasl_credential'], + **kwargs + ) + else: + self._transport = SASLTransport( + host=parsed_url.netloc, + credential=kwargs['sasl_credential'], + **kwargs + ) else: self.transport = AsyncTransport(parsed_url.netloc, **kwargs) self._container_id = kwargs.get('container_id') or str(uuid.uuid4()) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py index dda1931b909b..8c493760fade 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py @@ -7,9 +7,9 @@ import struct from enum import Enum -from ._transport_async import AsyncTransport +from ._transport_async import AsyncTransport, WebSocketTransportAsync from ..types import AMQPTypes, TYPE, VALUE -from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME +from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT from .._transport import AMQPS_PORT from ..performatives import ( SASLOutcome, @@ -104,3 +104,20 @@ async def negotiate(self): return else: raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + +class SASLTransportWithWebSocket(SASLTransport): + + def __init__( + self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs + ): # pylint: disable=super-init-not-called + self.credential = credential + ssl = ssl or True + http_proxy = kwargs.pop('http_proxy', None) + self._transport = WebSocketTransportAsync( + host, + port=WEBSOCKET_PORT, + connect_timeout=connect_timeout, + ssl=ssl, + http_proxy=http_proxy, + **kwargs + ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index acbdd8af8e76..766cc06ddd0f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -49,7 +49,7 @@ from .._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from .._encode import encode_frame from .._decode import decode_frame, decode_empty_frame -from ..constants import TLS_HEADER_FRAME +from ..constants import TLS_HEADER_FRAME, WEBSOCKET_PORT from .._transport import ( AMQP_FRAME, get_errno, @@ -59,7 +59,8 @@ SIGNED_INT_MAX, _UNAVAIL, set_cloexec, - AMQP_PORT + AMQP_PORT, + WebSocketTransport ) @@ -412,3 +413,48 @@ async def negotiate(self): if returned_header[1] == TLS_HEADER_FRAME: raise ValueError("Mismatching TLS header protocol. Excpected: {}, received: {}".format( TLS_HEADER_FRAME, returned_header[1])) + + +class WebSocketTransport(WebSocketTransport): + async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments + """Read exactly n bytes from the peer.""" + + length = 0 + view = buffer or memoryview(bytearray(n)) + nbytes = self._read_buffer.readinto(view) + length += nbytes + n -= nbytes + while n: + data = await asyncio.get_event_loop().run_in_executor( + None, self.ws.recv + ) + + if len(data) <= n: + view[length: length + len(data)] = data + n -= len(data) + else: + view[length: length + n] = data[0:n] + self._read_buffer = BytesIO(data[n:]) + n = 0 + return view + + async def _shutdown_transport(self): + """Do any preliminary work in shutting down the connection.""" + await asyncio.get_event_loop().run_in_executor( + None, self.ws.close + ) + + async def _write(self, s): + """Completely write a string to the peer.""" + try: + """ + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + from websocket import ABNF + await asyncio.get_event_loop().run_in_executor( + None, self.ws.recv, s, opcode=ABNF.OPCODE_BINARY + ) + except ImportError: + raise ValueError("Please install websocket-client library to use websocket transport.") diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index fd0023b395a1..7791e6036ac8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -74,9 +74,9 @@ class SASLTransport(SSLTransport): def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential ssl = ssl or True - super().__init__(host, port, connect_timeout, ssl, **kwargs) + super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) - def negotiate(self): + def negotiate(self): with self.block(): self.write(SASL_HEADER_FRAME) _, returned_header = self.receive_frame() @@ -104,7 +104,7 @@ def negotiate(self): class SASLTransportWithWebSocket(WebSocketTransport): - def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): + def __init__(self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential ssl = ssl or True http_proxy = kwargs.pop('http_proxy', None) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index d5be74195636..22bd621f9a1c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -152,6 +152,8 @@ def _create_handler(self, auth: "JWTTokenAuthAsync") -> None: network_trace=self._client._config.network_tracing, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, + transport_type=self._client._config.transport_type, # pylint:disable=protected-access + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access retry_policy=self._retry_policy, client_name=self._name, receive_settle_mode=pyamqp_constants.ReceiverSettleMode.First, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 18467d15a416..552e1868f5cb 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -107,6 +107,8 @@ def _create_handler(self, auth: "JWTTokenAsync") -> None: network_trace=self._client._config.network_tracing, # pylint: disable=protected-access retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, + transport_type=self._client._config.transport_type, # pylint:disable=protected-access + http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access client_name=self._name, link_properties=self._link_properties, properties=create_properties(self._client._config.user_agent), # pylint: disable=protected-access diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py index 79bf6d5ecbe8..031f60a649b3 100644 --- a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py @@ -5,8 +5,7 @@ import pytest -from azure.eventhub import TransportType -from azure.eventhub._pyamqp import authentication, ReceiveClient +from azure.eventhub._pyamqp import authentication, ReceiveClient, TransportType def test_event_hubs_client_web_socket(live_eventhub): uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) @@ -23,5 +22,5 @@ def test_event_hubs_client_web_socket(live_eventhub): live_eventhub['consumer_group'], live_eventhub['partition']) - with ReceiveClient(live_eventhub['hostname'], source, auth=sas_auth, debug=False, timeout=5000, prefetch=50) as receive_client: + with ReceiveClient(live_eventhub['hostname'], source, auth=sas_auth, debug=False, timeout=5000, prefetch=50, transport_type=TransportType.AmqpOverWebsocket) as receive_client: receive_client.receive_message_batch(max_batch_size=10) From 696b4d7a74d41631acf46202241de955393ead73 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 18 Apr 2022 22:20:13 -0700 Subject: [PATCH 07/26] Update sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py --- sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index fa7f433edaae..46728d20cb26 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -683,7 +683,6 @@ def connect(self): http_proxy_auth = (username, password) try: from websocket import create_connection - # TODO: transform ssl to sslopt self.ws = create_connection( url="wss://{}/$servicebus/websocket/".format(self._host), subprotocols=['AMQPWSB10'], From 23abd815bdd29a69fe8c334f2adc4441744e4a91 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 18 Apr 2022 22:54:04 -0700 Subject: [PATCH 08/26] more changes --- .../azure-eventhub/azure/eventhub/_pyamqp/_transport.py | 4 ++-- .../azure/eventhub/_pyamqp/aio/_transport_async.py | 2 +- .../azure-eventhub/azure/eventhub/_pyamqp/constants.py | 3 +++ .../tests/pyamqp_tests/synctests/test_websocket.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 46728d20cb26..d4178fc0601e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -52,7 +52,7 @@ from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame -from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType +from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType, AMQP_WS_SUBPROTOCOL try: @@ -685,7 +685,7 @@ def connect(self): from websocket import create_connection self.ws = create_connection( url="wss://{}/$servicebus/websocket/".format(self._host), - subprotocols=['AMQPWSB10'], + subprotocols=[AMQP_WS_SUBPROTOCOL], timeout=self._connect_timeout, skip_utf8_validation=True, sslopt=self.sslopts, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 766cc06ddd0f..9426a2439d60 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -415,7 +415,7 @@ async def negotiate(self): TLS_HEADER_FRAME, returned_header[1])) -class WebSocketTransport(WebSocketTransport): +class WebSocketTransportAsync(WebSocketTransport): async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments """Read exactly n bytes from the peer.""" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py index a5e2928d1567..66e4ff1ae327 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/constants.py @@ -17,6 +17,9 @@ # default port for AMQP over Websocket WEBSOCKET_PORT = 443 +# subprotocol for AMQP over Websocket +AMQP_WS_SUBPROTOCOL = 'AMQPWSB10' + #: The IANA assigned port number for secure AMQP (amqps).The standard AMQP port number that has been assigned #: by IANA for secure TCP using TLS. Implementations listening on this port should NOT expect a protocol #: handshake before TLS is negotiated. diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py index 031f60a649b3..71ed02b9be20 100644 --- a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py @@ -5,7 +5,8 @@ import pytest -from azure.eventhub._pyamqp import authentication, ReceiveClient, TransportType +from azure.eventhub._pyamqp import authentication, ReceiveClient +from azure.eventhub._pyamqp.constants import TransportType def test_event_hubs_client_web_socket(live_eventhub): uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) From 6e3f3a1f95fc63146964b6bc2259fa15849811d0 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 18 Apr 2022 23:11:47 -0700 Subject: [PATCH 09/26] sasl mixin --- .../azure/eventhub/_pyamqp/sasl.py | 79 +++++++------------ 1 file changed, 30 insertions(+), 49 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index 7791e6036ac8..667b13cb2d36 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -68,8 +68,33 @@ class SASLExternalCredential(object): def start(self): return b'' +class SASLTransportMixin(): + def _negotiate(self): + self.write(SASL_HEADER_FRAME) + _, returned_header = self.receive_frame() + if returned_header[1] != SASL_HEADER_FRAME: + raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( + SASL_HEADER_FRAME, returned_header[1])) + + _, supported_mechansisms = self.receive_frame(verify_frame_type=1) + if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms + raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) + sasl_init = SASLInit( + mechanism=self.credential.mechanism, + initial_response=self.credential.start(), + hostname=self.host) + self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) + + _, next_frame = self.receive_frame(verify_frame_type=1) + frame_type, fields = next_frame + if frame_type != 0x00000044: # SASLOutcome + raise NotImplementedError("Unsupported SASL challenge") + if fields[0] == SASLCode.Ok: # code + return + else: + raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) -class SASLTransport(SSLTransport): +class SASLTransport(SSLTransport, SASLTransportMixin): def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential @@ -78,31 +103,9 @@ def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl= def negotiate(self): with self.block(): - self.write(SASL_HEADER_FRAME) - _, returned_header = self.receive_frame() - if returned_header[1] != SASL_HEADER_FRAME: - raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( - SASL_HEADER_FRAME, returned_header[1])) - - _, supported_mechansisms = self.receive_frame(verify_frame_type=1) - if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms - raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) - sasl_init = SASLInit( - mechanism=self.credential.mechanism, - initial_response=self.credential.start(), - hostname=self.host) - self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) - - _, next_frame = self.receive_frame(verify_frame_type=1) - frame_type, fields = next_frame - if frame_type != 0x00000044: # SASLOutcome - raise NotImplementedError("Unsupported SASL challenge") - if fields[0] == SASLCode.Ok: # code - return - else: - raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) - -class SASLTransportWithWebSocket(WebSocketTransport): + self._negotiate() + +class SASLTransportWithWebSocket(WebSocketTransport, SASLTransportMixin): def __init__(self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential @@ -119,26 +122,4 @@ def __init__(self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, super().__init__(host, port, connect_timeout, ssl, **kwargs) def negotiate(self): - self.write(SASL_HEADER_FRAME) - _, returned_header = self.receive_frame() - if returned_header[1] != SASL_HEADER_FRAME: - raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( - SASL_HEADER_FRAME, returned_header[1])) - - _, supported_mechansisms = self.receive_frame(verify_frame_type=1) - if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms - raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) - sasl_init = SASLInit( - mechanism=self.credential.mechanism, - initial_response=self.credential.start(), - hostname=self.host) - self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) - - _, next_frame = self.receive_frame(verify_frame_type=1) - frame_type, fields = next_frame - if frame_type != 0x00000044: # SASLOutcome - raise NotImplementedError("Unsupported SASL challenge") - if fields[0] == SASLCode.Ok: # code - return - else: - raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + self._negotiate() From 69053e711a629127ac809f3e640a2ddc9310f240 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 18 Apr 2022 23:12:18 -0700 Subject: [PATCH 10/26] Update sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py --- sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py index 1cf777d3ec24..a2ef0087fd94 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/message.py @@ -265,5 +265,3 @@ def __init__(self, message, state=MessageDeliveryState.WaitingToBeSent, expiry=N self.reason = None self.delivery = None self.error = None - -__all__ = ['Message'] From 13d661ab5e201f365d4bb601be3afc0c7e0a685e Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 18 Apr 2022 23:28:12 -0700 Subject: [PATCH 11/26] refactor --- .../azure/eventhub/_pyamqp/_connection.py | 20 +++++++------------ .../eventhub/_pyamqp/aio/_connection_async.py | 20 +++++++------------ .../azure/eventhub/_pyamqp/aio/_sasl_async.py | 2 +- .../azure/eventhub/_pyamqp/sasl.py | 2 +- 4 files changed, 16 insertions(+), 28 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index deb6134696d0..8b567c8e91f0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -12,7 +12,7 @@ from ssl import SSLError from ._transport import Transport -from .sasl import SASLTransport, SASLTransportWithWebSocket +from .sasl import SASLTransport, SASLWithWebSocket from .session import Session from .performatives import OpenFrame, CloseFrame from .constants import ( @@ -101,18 +101,12 @@ def __init__(self, endpoint, **kwargs): if transport: self._transport = transport elif 'sasl_credential' in kwargs: - if self._transport_type is TransportType.AmqpOverWebsocket: - self._transport = SASLTransportWithWebSocket( - host=parsed_url.netloc, - credential=kwargs['sasl_credential'], - **kwargs - ) - else: - self._transport = SASLTransport( - host=parsed_url.netloc, - credential=kwargs['sasl_credential'], - **kwargs - ) + func = SASLWithWebSocket if self._transport_type is TransportType.AmqpOverWebsocket else SASLTransport + self._transport = func( + host=parsed_url.netloc, + credential=kwargs['sasl_credential'], + **kwargs + ) else: self._transport = Transport(parsed_url.netloc, **kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index 15baa7df1875..b9f73bae8bf4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -16,7 +16,7 @@ import asyncio from ._transport_async import AsyncTransport -from ._sasl_async import SASLTransport, SASLTransportWithWebSocket +from ._sasl_async import SASLTransport, SASLWithWebSocket from ._session_async import Session from ..performatives import OpenFrame, CloseFrame from .._connection import get_local_timeout @@ -81,18 +81,12 @@ def __init__(self, endpoint, **kwargs): if transport: self.transport = transport elif 'sasl_credential' in kwargs: - if self._transport_type is TransportType.AmqpOverWebsocket: - self._transport = SASLTransportWithWebSocket( - host=parsed_url.netloc, - credential=kwargs['sasl_credential'], - **kwargs - ) - else: - self._transport = SASLTransport( - host=parsed_url.netloc, - credential=kwargs['sasl_credential'], - **kwargs - ) + func = SASLWithWebSocket if self._transport_type is TransportType.AmqpOverWebsocket else SASLTransport + self._transport = func( + host=parsed_url.netloc, + credential=kwargs['sasl_credential'], + **kwargs + ) else: self.transport = AsyncTransport(parsed_url.netloc, **kwargs) self._container_id = kwargs.get('container_id') or str(uuid.uuid4()) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py index 8c493760fade..6dfd8bdf448e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py @@ -105,7 +105,7 @@ async def negotiate(self): else: raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) -class SASLTransportWithWebSocket(SASLTransport): +class SASLWithWebSocket(SASLTransport): def __init__( self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index 667b13cb2d36..f00042fd9b02 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -105,7 +105,7 @@ def negotiate(self): with self.block(): self._negotiate() -class SASLTransportWithWebSocket(WebSocketTransport, SASLTransportMixin): +class SASLWithWebSocket(WebSocketTransport, SASLTransportMixin): def __init__(self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs): self.credential = credential From 74240d5a52dab98504cb7d7c20c06720fa84ef1a Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 18 Apr 2022 23:47:41 -0700 Subject: [PATCH 12/26] Update sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py --- .../azure-eventhub/azure/eventhub/_pyamqp/_transport.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index d4178fc0601e..d6361778a0c4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -393,8 +393,7 @@ def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? read_frame_buffer = BytesIO() try: frame_header = memoryview(bytearray(8)) - data = read(8, buffer=frame_header, initial=True) - read_frame_buffer.write(data) + read_frame_buffer.write(read(8, buffer=frame_header, initial=True)) channel = struct.unpack('>H', frame_header[6:])[0] size = frame_header[0:4] From 59a2e2fd4a44df93cf80ca237a1a03cb6258d582 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Tue, 19 Apr 2022 13:36:09 -0700 Subject: [PATCH 13/26] Update sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py --- sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index d6361778a0c4..2773826331cf 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -47,7 +47,6 @@ from threading import Lock import certifi -import websockets from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from ._encode import encode_frame From 699d60f3e703e703ecdd5371716d1955dd26ee05 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Wed, 20 Apr 2022 14:49:44 -0700 Subject: [PATCH 14/26] oops --- .../azure/eventhub/_pyamqp/aio/_transport_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 9426a2439d60..8172797c80c8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -454,7 +454,7 @@ async def _write(self, s): """ from websocket import ABNF await asyncio.get_event_loop().run_in_executor( - None, self.ws.recv, s, opcode=ABNF.OPCODE_BINARY + None, self.ws.send, s, ABNF.OPCODE_BINARY ) except ImportError: raise ValueError("Please install websocket-client library to use websocket transport.") From 1e82f1d96a5170ec72014678686ee402d66aa78f Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 25 Apr 2022 22:06:59 -0700 Subject: [PATCH 15/26] comments --- .../azure/eventhub/_pyamqp/_connection.py | 6 ++++-- .../azure/eventhub/_pyamqp/_transport.py | 20 +++++++----------- .../eventhub/_pyamqp/aio/_connection_async.py | 6 ++++-- .../eventhub/_pyamqp/aio/_transport_async.py | 21 +++++++------------ 4 files changed, 24 insertions(+), 29 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index 8b567c8e91f0..b3629eb73e3c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -101,8 +101,10 @@ def __init__(self, endpoint, **kwargs): if transport: self._transport = transport elif 'sasl_credential' in kwargs: - func = SASLWithWebSocket if self._transport_type is TransportType.AmqpOverWebsocket else SASLTransport - self._transport = func( + sasl_transport = SASLWithWebSocket if ( + self._transport_type is TransportType.AmqpOverWebsocket or kwargs.get("http_proxy") + ) else SASLTransport + self._transport = sasl_transport( host=parsed_url.netloc, credential=kwargs['sasl_credential'], **kwargs diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 2773826331cf..25477f13e270 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -655,7 +655,8 @@ def Transport(host, connect_timeout=None, ssl=False, **kwargs): transport_type = kwargs.pop('transport_type') if transport_type == TransportType.AmqpOverWebsocket: transport = WebSocketTransport - transport = SSLTransport if ssl else TCPTransport + else: + transport = SSLTransport if ssl else TCPTransport return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) class WebSocketTransport(_AbstractTransport): @@ -719,14 +720,9 @@ def _shutdown_transport(self): self.ws.close() def _write(self, s): - """Completely write a string to the peer.""" - try: - """ - ABNF, OPCODE_BINARY = 0x2 - See http://tools.ietf.org/html/rfc5234 - http://tools.ietf.org/html/rfc6455#section-5.2 - """ - from websocket import ABNF - self.ws.send(s, opcode=ABNF.OPCODE_BINARY) - except ImportError: - raise ValueError("Please install websocket-client library to use websocket transport.") + """Completely write a string to the peer. + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + self.ws.send_binary(s) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index b9f73bae8bf4..921d2ad7b66e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -81,8 +81,10 @@ def __init__(self, endpoint, **kwargs): if transport: self.transport = transport elif 'sasl_credential' in kwargs: - func = SASLWithWebSocket if self._transport_type is TransportType.AmqpOverWebsocket else SASLTransport - self._transport = func( + sasl_transport = SASLWithWebSocket if ( + self._transport_type is TransportType.AmqpOverWebsocket or kwargs.get("http_proxy") + ) else SASLTransport + self._transport = sasl_transport( host=parsed_url.netloc, credential=kwargs['sasl_credential'], **kwargs diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 8172797c80c8..fa27910508d6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -445,16 +445,11 @@ async def _shutdown_transport(self): ) async def _write(self, s): - """Completely write a string to the peer.""" - try: - """ - ABNF, OPCODE_BINARY = 0x2 - See http://tools.ietf.org/html/rfc5234 - http://tools.ietf.org/html/rfc6455#section-5.2 - """ - from websocket import ABNF - await asyncio.get_event_loop().run_in_executor( - None, self.ws.send, s, ABNF.OPCODE_BINARY - ) - except ImportError: - raise ValueError("Please install websocket-client library to use websocket transport.") + """Completely write a string to the peer. + ABNF, OPCODE_BINARY = 0x2 + See http://tools.ietf.org/html/rfc5234 + http://tools.ietf.org/html/rfc6455#section-5.2 + """ + await asyncio.get_event_loop().run_in_executor( + None, self.ws.send_binary, s + ) From 0215429eaf5c9fca045a42e81b559833e648ca26 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Tue, 26 Apr 2022 15:41:30 -0700 Subject: [PATCH 16/26] comment --- .../azure-eventhub/azure/eventhub/_pyamqp/_transport.py | 2 +- .../azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py | 2 +- sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 25477f13e270..bdf1ee5470df 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -675,7 +675,7 @@ def connect(self): http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None if self._http_proxy: http_proxy_host = self._http_proxy['proxy_hostname'] - http_proxy_port = self._http_proxy['proxy_hostname'] + http_proxy_port = self._http_proxy['proxy_port'] username = self._http_proxy.get('username', None) password = self._http_proxy.get('password', None) if username or password: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py index 6dfd8bdf448e..d809310072e5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py @@ -115,7 +115,7 @@ def __init__( http_proxy = kwargs.pop('http_proxy', None) self._transport = WebSocketTransportAsync( host, - port=WEBSOCKET_PORT, + port=port, connect_timeout=connect_timeout, ssl=ssl, http_proxy=http_proxy, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py index f00042fd9b02..51848304bfae 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/sasl.py @@ -113,7 +113,7 @@ def __init__(self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, http_proxy = kwargs.pop('http_proxy', None) self._transport = WebSocketTransport( host, - port=WEBSOCKET_PORT, + port=port, connect_timeout=connect_timeout, ssl=ssl, http_proxy=http_proxy, From f4088f8de218f87f777cabaa302a329a72ee3ae5 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Tue, 26 Apr 2022 15:59:16 -0700 Subject: [PATCH 17/26] Apply suggestions from code review Co-authored-by: swathipil <76007337+swathipil@users.noreply.github.com> --- .../azure/eventhub/_pyamqp/aio/_connection_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index 921d2ad7b66e..54896b139529 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -59,7 +59,7 @@ class Connection(object): :param list(str) offered_capabilities: The extension capabilities the sender supports. :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports :param dict properties: Connection properties. - :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + :keyword str transport_type: Required. Determines if the transport type is Amqp or AmqpOverWebSocket. :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). Additionally the following keys may also be present: `'username', 'password'`. From e9d94864189750a153c8c36c4dc959dc4743bacd Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Mon, 2 May 2022 10:54:41 -0700 Subject: [PATCH 18/26] comments --- .../azure/eventhub/_pyamqp/_connection.py | 6 ++++-- .../eventhub/_pyamqp/aio/_connection_async.py | 8 +++++--- .../azure/eventhub/_pyamqp/aio/_sasl_async.py | 17 +++++++++-------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index b3629eb73e3c..dbdb37341381 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -79,8 +79,10 @@ class Connection(object): :keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames will be logged at the logging.INFO level. :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy. :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following - keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings, + the transport_type would be AmqpOverWebSocket. Additionally the following keys may also be present: `'username', 'password'`. """ @@ -97,7 +99,7 @@ def __init__(self, endpoint, **kwargs): self.state = None # type: Optional[ConnectionState] transport = kwargs.get('transport') - self._transport_type = kwargs.pop('transport_type') + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) if transport: self._transport = transport elif 'sasl_credential' in kwargs: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index 54896b139529..c0e97dd61d81 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -59,9 +59,11 @@ class Connection(object): :param list(str) offered_capabilities: The extension capabilities the sender supports. :param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports :param dict properties: Connection properties. - :keyword str transport_type: Required. Determines if the transport type is Amqp or AmqpOverWebSocket. + :keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket. + Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy. :keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following - keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). + keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings, + the transport_type would be AmqpOverWebSocket. Additionally the following keys may also be present: `'username', 'password'`. """ @@ -77,7 +79,7 @@ def __init__(self, endpoint, **kwargs): self.port = PORT self.state = None - transport = kwargs.get('transport') + transport = kwargs.get('transport', TransportType.Amqp) if transport: self.transport = transport elif 'sasl_credential' in kwargs: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py index d809310072e5..36aa8f93839c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py @@ -73,13 +73,7 @@ def start(self): return b'' -class SASLTransport(AsyncTransport): - - def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): - self.credential = credential - ssl = ssl or True - super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) - +class SASLTransportMixinAsync(): async def negotiate(self): await self.write(SASL_HEADER_FRAME) _, returned_header = await self.receive_frame() @@ -105,7 +99,14 @@ async def negotiate(self): else: raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) -class SASLWithWebSocket(SASLTransport): + +class SASLTransport(AsyncTransport, SASLTransportMixinAsync): + def __init__(self, host, credential, connect_timeout=None, ssl=None, **kwargs): + self.credential = credential + ssl = ssl or True + super(SASLTransport, self).__init__(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + +class SASLWithWebSocket(AsyncTransport, SASLTransportMixinAsync): def __init__( self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs From e04f18c16e5cb19a07bb2c40e87cdccd060e7f38 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Wed, 4 May 2022 15:46:20 -0700 Subject: [PATCH 19/26] changes --- .../eventhub/_pyamqp/aio/_client_async.py | 28 ++- .../eventhub/_pyamqp/aio/_connection_async.py | 9 +- .../azure/eventhub/_pyamqp/aio/_sasl_async.py | 9 +- .../eventhub/_pyamqp/aio/_transport_async.py | 182 +++++++++++------- 4 files changed, 143 insertions(+), 85 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index 5c10bcdbd4d4..957588d2a921 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -176,6 +176,23 @@ async def _do_retryable_operation_async(self, operation, *args, **kwargs): absolute_timeout -= (end_time - start_time) raise retry_settings['history'][-1] + async def _keep_alive_worker_async(self): + interval = 10 if self._keep_alive is True else self._keep_alive + start_time = time.time() + try: + while self._connection and not self._shutdown: + current_time = time.time() + elapsed_time = (current_time - start_time) + if elapsed_time >= interval: + _logger.info("Keeping %r connection alive. %r", + self.__class__.__name__, + self._connection._container_id) + await self._connection._get_remote_timeout(current_time) + start_time = current_time + await asyncio.sleep(1) + except Exception as e: # pylint: disable=broad-except + _logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e) + async def open_async(self): """Asynchronously open the client. The client can create a new Connection or an existing Connection can be passed in. This existing Connection @@ -200,10 +217,10 @@ async def open_async(self): max_frame_size=self._max_frame_size, channel_max=self._channel_max, idle_timeout=self._idle_timeout, - properties=self._properties, - network_trace=self._network_trace, transport_type=self._transport_type, - http_proxy=self._http_proxy + http_proxy=self._http_proxy, + properties=self._properties, + network_trace=self._network_trace ) await self._connection.open() if not self._session: @@ -219,6 +236,8 @@ async def open_async(self): auth_timeout=self._auth_timeout ) await self._cbs_authenticator.open() + if self._keep_alive: + self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_worker_async()) self._shutdown = False async def close_async(self): @@ -230,6 +249,9 @@ async def close_async(self): self._shutdown = True if not self._session: return # already closed. + if self._keep_alive_thread: + await self._keep_alive_thread + self._keep_alive_thread = None await self._close_link_async(close=True) if self._cbs_authenticator: await self._cbs_authenticator.close() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index c0e97dd61d81..689712740123 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -28,7 +28,7 @@ HEADER_FRAME, ConnectionState, EMPTY_FRAME, - TransportType, + TransportType ) from ..error import ( @@ -70,7 +70,7 @@ class Connection(object): def __init__(self, endpoint, **kwargs): parsed_url = urlparse(endpoint) self.hostname = parsed_url.hostname - self._transport_type = kwargs.pop('transport_type') + self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) if parsed_url.port: self.port = parsed_url.port elif parsed_url.scheme == 'amqps': @@ -78,15 +78,14 @@ def __init__(self, endpoint, **kwargs): else: self.port = PORT self.state = None - - transport = kwargs.get('transport', TransportType.Amqp) + transport = kwargs.get('transport') if transport: self.transport = transport elif 'sasl_credential' in kwargs: sasl_transport = SASLWithWebSocket if ( self._transport_type is TransportType.AmqpOverWebsocket or kwargs.get("http_proxy") ) else SASLTransport - self._transport = sasl_transport( + self.transport = sasl_transport( host=parsed_url.netloc, credential=kwargs['sasl_credential'], **kwargs diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py index 36aa8f93839c..3a69dbd0f0ed 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py @@ -4,12 +4,13 @@ # license information. #-------------------------------------------------------------------------- +import http import struct from enum import Enum from ._transport_async import AsyncTransport, WebSocketTransportAsync from ..types import AMQPTypes, TYPE, VALUE -from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT +from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT, TransportType from .._transport import AMQPS_PORT from ..performatives import ( SASLOutcome, @@ -72,7 +73,6 @@ class SASLExternalCredential(object): def start(self): return b'' - class SASLTransportMixinAsync(): async def negotiate(self): await self.write(SASL_HEADER_FRAME) @@ -99,15 +99,13 @@ async def negotiate(self): else: raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) - class SASLTransport(AsyncTransport, SASLTransportMixinAsync): def __init__(self, host, credential, connect_timeout=None, ssl=None, **kwargs): self.credential = credential ssl = ssl or True super(SASLTransport, self).__init__(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs) -class SASLWithWebSocket(AsyncTransport, SASLTransportMixinAsync): - +class SASLWithWebSocket(WebSocketTransportAsync, SASLTransportMixinAsync): def __init__( self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs ): # pylint: disable=super-init-not-called @@ -122,3 +120,4 @@ def __init__( http_proxy=http_proxy, **kwargs ) + super().__init__(host, port, connect_timeout, ssl, **kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index fa27910508d6..157466189697 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -49,7 +49,7 @@ from .._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack from .._encode import encode_frame from .._decode import decode_frame, decode_empty_frame -from ..constants import TLS_HEADER_FRAME, WEBSOCKET_PORT +from ..constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, AMQP_WS_SUBPROTOCOL from .._transport import ( AMQP_FRAME, get_errno, @@ -82,8 +82,77 @@ def get_running_loop(): loop = asyncio.get_event_loop() return loop +class AsyncTransportMixin(): + def __init__(self): + self._read_buffer = BytesIO() + self.loop = get_running_loop() + self.socket_lock = asyncio.Lock() + + async def receive_frame(self, *args, **kwargs): + try: + header, channel, payload = await self.read(**kwargs) + if not payload: + decoded = decode_empty_frame(header) + else: + decoded = decode_frame(payload) + # TODO: Catch decode error and return amqp:decode-error + #_LOGGER.info("ICH%d <- %r", channel, decoded) + return channel, decoded + except (socket.timeout, asyncio.IncompleteReadError, asyncio.TimeoutError): + return None, None + + async def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? + async with self.socket_lock: + read_frame_buffer = BytesIO() + try: + frame_header = memoryview(bytearray(8)) + read_frame_buffer.write(await self._read(8, buffer=frame_header, initial=True)) + + channel = struct.unpack('>H', frame_header[6:])[0] + size = frame_header[0:4] + if size == AMQP_FRAME: # Empty frame or AMQP header negotiation + return frame_header, channel, None + size = struct.unpack('>I', size)[0] + offset = frame_header[4] + frame_type = frame_header[5] -class AsyncTransport(object): + # >I is an unsigned int, but the argument to sock.recv is signed, + # so we know the size can be at most 2 * SIGNED_INT_MAX + payload_size = size - len(frame_header) + payload = memoryview(bytearray(payload_size)) + if size > SIGNED_INT_MAX: + read_frame_buffer.write(await self._read(SIGNED_INT_MAX, buffer=payload)) + read_frame_buffer.write(await self._read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:])) + else: + read_frame_buffer.write(await self._read(payload_size, buffer=payload)) + except (socket.timeout, asyncio.IncompleteReadError): + read_frame_buffer.write(self._read_buffer.getvalue()) + self._read_buffer = read_frame_buffer + self._read_buffer.seek(0) + raise + except (OSError, IOError, SSLError, socket.error) as exc: + # Don't disconnect for ssl read time outs + # http://bugs.python.org/issue10272 + if isinstance(exc, SSLError) and 'timed out' in str(exc): + raise socket.timeout() + if get_errno(exc) not in _UNAVAIL: + self.connected = False + raise + offset -= 2 + return frame_header, channel, payload[offset:] + + async def send_frame(self, channel, frame, **kwargs): + header, performative = encode_frame(frame, **kwargs) + if performative is None: + data = header + else: + encoded_channel = struct.pack('>H', channel) + data = header + encoded_channel + performative + + await self.write(data) + #_LOGGER.info("OCH%d -> %r", channel, frame) + +class AsyncTransport(AsyncTransportMixin): """Common superclass for TCP and SSL transports.""" def __init__(self, host, port=AMQP_PORT, connect_timeout=None, @@ -319,46 +388,6 @@ def close(self): self.sock = None self.connected = False - async def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? - async with self.socket_lock: - read_frame_buffer = BytesIO() - try: - frame_header = memoryview(bytearray(8)) - read_frame_buffer.write(await self._read(8, buffer=frame_header, initial=True)) - - channel = struct.unpack('>H', frame_header[6:])[0] - size = frame_header[0:4] - if size == AMQP_FRAME: # Empty frame or AMQP header negotiation - return frame_header, channel, None - size = struct.unpack('>I', size)[0] - offset = frame_header[4] - frame_type = frame_header[5] - - # >I is an unsigned int, but the argument to sock.recv is signed, - # so we know the size can be at most 2 * SIGNED_INT_MAX - payload_size = size - len(frame_header) - payload = memoryview(bytearray(payload_size)) - if size > SIGNED_INT_MAX: - read_frame_buffer.write(await self._read(SIGNED_INT_MAX, buffer=payload)) - read_frame_buffer.write(await self._read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:])) - else: - read_frame_buffer.write(await self._read(payload_size, buffer=payload)) - except (socket.timeout, asyncio.IncompleteReadError): - read_frame_buffer.write(self._read_buffer.getvalue()) - self._read_buffer = read_frame_buffer - self._read_buffer.seek(0) - raise - except (OSError, IOError, SSLError, socket.error) as exc: - # Don't disconnect for ssl read time outs - # http://bugs.python.org/issue10272 - if isinstance(exc, SSLError) and 'timed out' in str(exc): - raise socket.timeout() - if get_errno(exc) not in _UNAVAIL: - self.connected = False - raise - offset -= 2 - return frame_header, channel, payload[offset:] - async def write(self, s): try: await self._write(s) @@ -369,19 +398,6 @@ async def write(self, s): self.connected = False raise - async def receive_frame(self, *args, **kwargs): - try: - header, channel, payload = await self.read(**kwargs) - if not payload: - decoded = decode_empty_frame(header) - else: - decoded = decode_frame(payload) - # TODO: Catch decode error and return amqp:decode-error - #_LOGGER.info("ICH%d <- %r", channel, decoded) - return channel, decoded - except (socket.timeout, asyncio.IncompleteReadError, asyncio.TimeoutError): - return None, None - async def receive_frame_with_lock(self, *args, **kwargs): try: async with self.socket_lock: @@ -394,17 +410,6 @@ async def receive_frame_with_lock(self, *args, **kwargs): except socket.timeout: return None, None - async def send_frame(self, channel, frame, **kwargs): - header, performative = encode_frame(frame, **kwargs) - if performative is None: - data = header - else: - encoded_channel = struct.pack('>H', channel) - data = header + encoded_channel + performative - - await self.write(data) - #_LOGGER.info("OCH%d -> %r", channel, frame) - async def negotiate(self): if not self.sslopts: return @@ -415,7 +420,42 @@ async def negotiate(self): TLS_HEADER_FRAME, returned_header[1])) -class WebSocketTransportAsync(WebSocketTransport): +class WebSocketTransportAsync(AsyncTransportMixin): + def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs + ): + self.sslopts = ssl if isinstance(ssl, dict) else {} + self._connect_timeout = connect_timeout + self.host = host + super().__init__( + ) + self.socket_lock = asyncio.Lock() + self.ws = None + self._http_proxy = kwargs.get('http_proxy', None) + + async def connect(self): + http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None + if self._http_proxy: + http_proxy_host = self._http_proxy['proxy_hostname'] + http_proxy_port = self._http_proxy['proxy_port'] + username = self._http_proxy.get('username', None) + password = self._http_proxy.get('password', None) + if username or password: + http_proxy_auth = (username, password) + try: + from websocket import create_connection + self.ws = create_connection( + url="wss://{}/$servicebus/websocket/".format(self.host), + subprotocols=[AMQP_WS_SUBPROTOCOL], + timeout=self._connect_timeout, + skip_utf8_validation=True, + sslopt=self.sslopts, + http_proxy_host=http_proxy_host, + http_proxy_port=http_proxy_port, + http_proxy_auth=http_proxy_auth + ) + except ImportError: + raise ValueError("Please install websocket-client library to use websocket transport.") + async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments """Read exactly n bytes from the peer.""" @@ -438,13 +478,11 @@ async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-argume n = 0 return view - async def _shutdown_transport(self): + def close(self): """Do any preliminary work in shutting down the connection.""" - await asyncio.get_event_loop().run_in_executor( - None, self.ws.close - ) + self.ws.close() - async def _write(self, s): + async def write(self, s): """Completely write a string to the peer. ABNF, OPCODE_BINARY = 0x2 See http://tools.ietf.org/html/rfc5234 From 50e8594b9c1dbbe7f2ff5d44817772acc7fd7e1a Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Wed, 4 May 2022 15:50:22 -0700 Subject: [PATCH 20/26] async test --- .../async/test_websocket_async.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py new file mode 100644 index 000000000000..4fa43d329b67 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py @@ -0,0 +1,35 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import pytest +import asyncio +import logging +from uamqp.aio import ReceiveClientAsync, SASTokenAuthAsync +from uamqp.constants import TransportType + +@pytest.mark.asyncio +async def test_event_hubs_client_web_socket(eventhub_config): + uri = "sb://{}/{}".format(eventhub_config['hostname'], eventhub_config['event_hub']) + sas_auth = SASTokenAuthAsync( + uri=uri, + audience=uri, + username=eventhub_config['key_name'], + password=eventhub_config['access_key'] + ) + + source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( + eventhub_config['hostname'], + eventhub_config['event_hub'], + eventhub_config['consumer_group'], + eventhub_config['partition']) + + receive_client = ReceiveClientAsync(eventhub_config['hostname'], source, auth=sas_auth, debug=False, timeout=5000, prefetch=50, transport_type=TransportType.AmqpOverWebsocket) + await receive_client.open_async() + while not await receive_client.client_ready_async(): + await asyncio.sleep(0.05) + messages = await receive_client.receive_message_batch_async(max_batch_size=1) + logging.info(len(messages)) + logging.info(messages[0]) + await receive_client.close_async() From eb8cdc0d68fc66b3ea1ee020d6218fae0af717ec Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Thu, 5 May 2022 00:27:44 -0700 Subject: [PATCH 21/26] rasie --- .../azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py | 1 - sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py index 3a69dbd0f0ed..014681787c27 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_sasl_async.py @@ -4,7 +4,6 @@ # license information. #-------------------------------------------------------------------------- -import http import struct from enum import Enum diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index fdf98b77e241..2b6c06070347 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -157,6 +157,8 @@ def __init__(self, hostname, auth=None, **kwargs): self._desired_capabilities = kwargs.pop('desired_capabilities', None) # transport + if kwargs.get('transport_type') is TransportType.Amqp and kwargs.get('http_proxy') is not None: + raise ValueError("Http proxy settings can't be passed if transport_type is explicitly set to Amqp") self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) self._http_proxy = kwargs.pop('http_proxy', None) From 4f7690544f6f252ea32aae723ae07e55fc4de6b7 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Thu, 5 May 2022 00:35:49 -0700 Subject: [PATCH 22/26] lint --- .../azure-eventhub/azure/eventhub/_pyamqp/_connection.py | 2 +- .../azure-eventhub/azure/eventhub/_pyamqp/_transport.py | 3 +-- .../azure/eventhub/_pyamqp/aio/_transport_async.py | 4 +--- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index dbdb37341381..bffc7940648c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -112,7 +112,7 @@ def __init__(self, endpoint, **kwargs): **kwargs ) else: - self._transport = Transport(parsed_url.netloc, **kwargs) + self._transport = Transport(parsed_url.netloc, self._transport_type, **kwargs) self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index bdf1ee5470df..bc0f34e4eaf6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -646,13 +646,12 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): return result -def Transport(host, connect_timeout=None, ssl=False, **kwargs): +def Transport(host, transport_type, connect_timeout=None, ssl=False, **kwargs): """Create transport. Given a few parameters from the Connection constructor, select and create a subclass of _AbstractTransport. """ - transport_type = kwargs.pop('transport_type') if transport_type == TransportType.AmqpOverWebsocket: transport = WebSocketTransport else: diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 157466189697..e42a02b11fb5 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -423,12 +423,10 @@ async def negotiate(self): class WebSocketTransportAsync(AsyncTransportMixin): def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs ): + super().__init__() self.sslopts = ssl if isinstance(ssl, dict) else {} self._connect_timeout = connect_timeout self.host = host - super().__init__( - ) - self.socket_lock = asyncio.Lock() self.ws = None self._http_proxy = kwargs.get('http_proxy', None) From 0db10f8bde13cdac881da0661532c1e0f299ca98 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Thu, 5 May 2022 11:33:28 -0700 Subject: [PATCH 23/26] changelog --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index 7d105b3b0fba..1af1b168b856 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,9 +1,11 @@ # Release History -## 5.8.0b4 (Unreleased) +## 5.8.0a4 (Unreleased) ### Features Added +- Added suppport for connection using websocket and http proxy. + ### Breaking Changes ### Bugs Fixed From 17df37ec488a925fe86c2dd60ea5f54115b78b4a Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Thu, 5 May 2022 11:49:11 -0700 Subject: [PATCH 24/26] version --- sdk/eventhub/azure-eventhub/azure/eventhub/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py index 440fcc69d1d5..03c76c0832af 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "5.8.0b4" +VERSION = "5.8.0a4" From ac822f4dcc8b012a0fa692abdd3f588e77ff8c05 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Thu, 5 May 2022 14:32:22 -0700 Subject: [PATCH 25/26] comments --- .../eventhub/_pyamqp/aio/_transport_async.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index e42a02b11fb5..16755614b8a2 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -79,15 +79,10 @@ def get_running_loop(): _LOGGER.warning('This version of Python is deprecated, please upgrade to >= v3.6') if loop is None: _LOGGER.warning('No running event loop') - loop = asyncio.get_event_loop() + loop = self.loop return loop class AsyncTransportMixin(): - def __init__(self): - self._read_buffer = BytesIO() - self.loop = get_running_loop() - self.socket_lock = asyncio.Lock() - async def receive_frame(self, *args, **kwargs): try: header, channel, payload = await self.read(**kwargs) @@ -165,7 +160,6 @@ def __init__(self, host, port=AMQP_PORT, connect_timeout=None, self.raise_on_initial_eintr = raise_on_initial_eintr self._read_buffer = BytesIO() self.host, self.port = to_host_port(host, port) - self.connect_timeout = connect_timeout self.read_timeout = read_timeout self.write_timeout = write_timeout @@ -423,7 +417,9 @@ async def negotiate(self): class WebSocketTransportAsync(AsyncTransportMixin): def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs ): - super().__init__() + self._read_buffer = BytesIO() + self.loop = get_running_loop() + self.socket_lock = asyncio.Lock() self.sslopts = ssl if isinstance(ssl, dict) else {} self._connect_timeout = connect_timeout self.host = host @@ -463,7 +459,7 @@ async def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-argume length += nbytes n -= nbytes while n: - data = await asyncio.get_event_loop().run_in_executor( + data = await self.loop.run_in_executor( None, self.ws.recv ) @@ -486,6 +482,6 @@ async def write(self, s): See http://tools.ietf.org/html/rfc5234 http://tools.ietf.org/html/rfc6455#section-5.2 """ - await asyncio.get_event_loop().run_in_executor( + await self.loop.run_in_executor( None, self.ws.send_binary, s ) From d351b5a1c8c82cc49d3cd46ebec408de52af2d77 Mon Sep 17 00:00:00 2001 From: Rakshith Bhyravabhotla Date: Sat, 7 May 2022 00:19:15 -0700 Subject: [PATCH 26/26] move path to EH --- .../azure-eventhub/azure/eventhub/_consumer.py | 8 ++++++-- .../azure-eventhub/azure/eventhub/_producer.py | 8 ++++++-- .../azure/eventhub/_pyamqp/_connection.py | 10 ++++++---- .../azure/eventhub/_pyamqp/_transport.py | 2 +- .../azure/eventhub/_pyamqp/aio/_connection_async.py | 12 +++++++----- .../azure/eventhub/_pyamqp/aio/_transport_async.py | 2 +- .../azure/eventhub/aio/_consumer_async.py | 9 ++++++--- .../azure/eventhub/aio/_producer_async.py | 8 ++++++-- .../tests/pyamqp_tests/async/test_websocket_async.py | 2 +- .../tests/pyamqp_tests/synctests/test_websocket.py | 2 +- 10 files changed, 41 insertions(+), 22 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py index 0c33ab46a798..31c6e8cda89c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py @@ -136,6 +136,10 @@ def __init__(self, client, source, **kwargs): def _create_handler(self, auth): # type: (JWTTokenAuth) -> None + transport_type = self._client._config.transport_type, # pylint:disable=protected-access + hostname = urlparse(source.address).hostname + if transport_type.name is 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' source = Source(address=self._source, filters={}) if self._offset is not None: filter_key = ApacheFilters.selector_filter @@ -151,12 +155,12 @@ def _create_handler(self, auth): desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None self._handler = ReceiveClient( - urlparse(source.address).hostname, + hostname, source, auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access - transport_type=self._client._config.transport_type, # pylint:disable=protected-access + transport_type=transport_type, http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py index feb5fcc523ca..2867b67cea03 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py @@ -125,13 +125,17 @@ def __init__(self, client, target, **kwargs): def _create_handler(self, auth): # type: (JWTTokenAuth) -> None + transport_type=self._client._config.transport_type, # pylint:disable=protected-access + hostname = self._client._address.hostname, # pylint: disable=protected-access + if transport_type.name is 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' self._handler = SendClient( - self._client._address.hostname, # pylint: disable=protected-access + hostname, self._target, auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access - transport_type=self._client._config.transport_type, # pylint:disable=protected-access + transport_type=transport_type, http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py index bffc7940648c..cc84d6870e02 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_connection.py @@ -90,6 +90,7 @@ def __init__(self, endpoint, **kwargs): # type(str, Any) -> None parsed_url = urlparse(endpoint) self._hostname = parsed_url.hostname + endpoint = self._hostname if parsed_url.port: self._port = parsed_url.port elif parsed_url.scheme == 'amqps': @@ -103,11 +104,12 @@ def __init__(self, endpoint, **kwargs): if transport: self._transport = transport elif 'sasl_credential' in kwargs: - sasl_transport = SASLWithWebSocket if ( - self._transport_type is TransportType.AmqpOverWebsocket or kwargs.get("http_proxy") - ) else SASLTransport + sasl_transport = SASLTransport + if self._transport_type.name is 'AmqpOverWebsocket' or kwargs.get("http_proxy"): + sasl_transport = SASLWithWebSocket + endpoint = parsed_url.hostname + parsed_url.path self._transport = sasl_transport( - host=parsed_url.netloc, + host=endpoint, credential=kwargs['sasl_credential'], **kwargs ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index bc0f34e4eaf6..29e506177cd3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -682,7 +682,7 @@ def connect(self): try: from websocket import create_connection self.ws = create_connection( - url="wss://{}/$servicebus/websocket/".format(self._host), + url="wss://{}".format(self._host), subprotocols=[AMQP_WS_SUBPROTOCOL], timeout=self._connect_timeout, skip_utf8_validation=True, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py index 689712740123..7b69e1f2e64e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_connection_async.py @@ -70,6 +70,7 @@ class Connection(object): def __init__(self, endpoint, **kwargs): parsed_url = urlparse(endpoint) self.hostname = parsed_url.hostname + endpoint = self._hostname self._transport_type = kwargs.pop('transport_type', TransportType.Amqp) if parsed_url.port: self.port = parsed_url.port @@ -82,11 +83,12 @@ def __init__(self, endpoint, **kwargs): if transport: self.transport = transport elif 'sasl_credential' in kwargs: - sasl_transport = SASLWithWebSocket if ( - self._transport_type is TransportType.AmqpOverWebsocket or kwargs.get("http_proxy") - ) else SASLTransport - self.transport = sasl_transport( - host=parsed_url.netloc, + sasl_transport = SASLTransport + if self._transport_type.name is 'AmqpOverWebsocket' or kwargs.get("http_proxy"): + sasl_transport = SASLWithWebSocket + endpoint = parsed_url.hostname + parsed_url.path + self._transport = sasl_transport( + host=endpoint, credential=kwargs['sasl_credential'], **kwargs ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 16755614b8a2..39d09213eba3 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -438,7 +438,7 @@ async def connect(self): try: from websocket import create_connection self.ws = create_connection( - url="wss://{}/$servicebus/websocket/".format(self.host), + url="wss://{}".format(self.host), subprotocols=[AMQP_WS_SUBPROTOCOL], timeout=self._connect_timeout, skip_utf8_validation=True, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py index 22bd621f9a1c..73d5effc2697 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py @@ -143,16 +143,19 @@ def _create_handler(self, auth: "JWTTokenAuthAsync") -> None: ) ) desired_capabilities = [RECEIVER_RUNTIME_METRIC_SYMBOL] if self._track_last_enqueued_event_properties else None - + hostname = urlparse(source.address).hostname + transport_type = self._client._config.transport_type, # pylint:disable=protected-access + if transport_type.name is 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' self._handler = ReceiveClientAsync( - urlparse(source.address).hostname, + hostname, source, auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint:disable=protected-access link_credit=self._prefetch, link_properties=self._link_properties, - transport_type=self._client._config.transport_type, # pylint:disable=protected-access + transport_type=transport_type, http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access retry_policy=self._retry_policy, client_name=self._name, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py index 552e1868f5cb..4e8a628e5d19 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_async.py @@ -99,15 +99,19 @@ def __init__(self, client: "EventHubProducerClient", target: str, **kwargs) -> N self._link_properties = {TIMEOUT_SYMBOL: pyamqp_utils.amqp_long_value(int(self._timeout * 1000))} def _create_handler(self, auth: "JWTTokenAsync") -> None: + hostname = self._client._address.hostname, # pylint: disable=protected-access + transport_type = self._client._config.transport_type, # pylint:disable=protected-access + if transport_type.name is 'AmqpOverWebsocket': + hostname += '/$servicebus/websocket/' self._handler = SendClientAsync( - self._client._address.hostname, # pylint: disable=protected-access + hostname, self._target, auth=auth, idle_timeout=self._idle_timeout, network_trace=self._client._config.network_tracing, # pylint: disable=protected-access retry_policy=self._retry_policy, keep_alive_interval=self._keep_alive, - transport_type=self._client._config.transport_type, # pylint:disable=protected-access + transport_type=transport_type, http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access client_name=self._name, link_properties=self._link_properties, diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py index 4fa43d329b67..443862be0c97 100644 --- a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/async/test_websocket_async.py @@ -25,7 +25,7 @@ async def test_event_hubs_client_web_socket(eventhub_config): eventhub_config['consumer_group'], eventhub_config['partition']) - receive_client = ReceiveClientAsync(eventhub_config['hostname'], source, auth=sas_auth, debug=False, timeout=5000, prefetch=50, transport_type=TransportType.AmqpOverWebsocket) + receive_client = ReceiveClientAsync(eventhub_config['hostname'] + '/$servicebus/websocket/', source, auth=sas_auth, debug=False, timeout=5000, prefetch=50, transport_type=TransportType.AmqpOverWebsocket) await receive_client.open_async() while not await receive_client.client_ready_async(): await asyncio.sleep(0.05) diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py index 71ed02b9be20..7dd9e5bfbe9c 100644 --- a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/synctests/test_websocket.py @@ -23,5 +23,5 @@ def test_event_hubs_client_web_socket(live_eventhub): live_eventhub['consumer_group'], live_eventhub['partition']) - with ReceiveClient(live_eventhub['hostname'], source, auth=sas_auth, debug=False, timeout=5000, prefetch=50, transport_type=TransportType.AmqpOverWebsocket) as receive_client: + with ReceiveClient(live_eventhub['hostname'] + '/$servicebus/websocket/', source, auth=sas_auth, debug=False, timeout=5000, prefetch=50, transport_type=TransportType.AmqpOverWebsocket) as receive_client: receive_client.receive_message_batch(max_batch_size=10)