Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AMQP websocket implementation #23722

Merged
merged 26 commits into from
May 7, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,6 @@ def _create_auth(self):
functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE),
token_type=token_type,
timeout=self._config.auth_timeout,
http_proxy=self._config.http_proxy,
transport_type=self._config.transport_type,
custom_endpoint_hostname=self._config.custom_endpoint_hostname,
port=self._config.connection_port,
verify=self._config.connection_verify,
Expand Down
2 changes: 2 additions & 0 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def _create_handler(self, auth):
auth=auth,
idle_timeout=self._idle_timeout,
network_trace=self._client._config.network_tracing, # pylint:disable=protected-access
transport_type=self._client._config.transport_type, # pylint:disable=protected-access
http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access
link_credit=self._prefetch,
link_properties=self._link_properties,
retry_policy=self._retry_policy,
Expand Down
4 changes: 3 additions & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def _create_handler(self, auth):
self._target,
auth=auth,
idle_timeout=self._idle_timeout,
network_trace=self._client._config.network_tracing, # pylint: disable=protected-access
network_trace=self._client._config.network_tracing, # pylint:disable=protected-access
transport_type=self._client._config.transport_type, # pylint:disable=protected-access
http_proxy=self._client._config.http_proxy, # pylint:disable=protected-access
retry_policy=self._retry_policy,
keep_alive_interval=self._keep_alive,
client_name=self._name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ssl import SSLError

from ._transport import Transport
from .sasl import SASLTransport
from .sasl import SASLTransport, SASLWithWebSocket
from .session import Session
from .performatives import OpenFrame, CloseFrame
from .constants import (
Expand All @@ -22,7 +22,8 @@
MAX_FRAME_SIZE_BYTES,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME
EMPTY_FRAME,
TransportType
)

from .error import (
Expand Down Expand Up @@ -77,6 +78,12 @@ class Connection(object):
Default value is `0.1`.
:keyword bool network_trace: Whether to log the network traffic. Default value is `False`. If enabled, frames
will be logged at the logging.INFO level.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
annatisch marked this conversation as resolved.
Show resolved Hide resolved
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
Expand All @@ -92,16 +99,20 @@ def __init__(self, endpoint, **kwargs):
self.state = None # type: Optional[ConnectionState]

transport = kwargs.get('transport')
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
swathipil marked this conversation as resolved.
Show resolved Hide resolved
if transport:
self._transport = transport
elif 'sasl_credential' in kwargs:
self._transport = SASLTransport(
sasl_transport = SASLWithWebSocket if (
self._transport_type is TransportType.AmqpOverWebsocket or kwargs.get("http_proxy")
swathipil marked this conversation as resolved.
Show resolved Hide resolved
) else SASLTransport
self._transport = sasl_transport(
host=parsed_url.netloc,
credential=kwargs['sasl_credential'],
**kwargs
)
else:
self._transport = Transport(parsed_url.netloc, **kwargs)
self._transport = Transport(parsed_url.netloc, self._transport_type, **kwargs)

self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str
self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack
from ._encode import encode_frame
from ._decode import decode_frame, decode_empty_frame
from .constants import TLS_HEADER_FRAME
from .constants import TLS_HEADER_FRAME, WEBSOCKET_PORT, TransportType, AMQP_WS_SUBPROTOCOL


try:
Expand Down Expand Up @@ -456,7 +456,6 @@ def send_frame(self, channel, frame, **kwargs):
else:
encoded_channel = struct.pack('>H', channel)
data = header + encoded_channel + performative

self.write(data)

def negotiate(self, encode, decode):
Expand Down Expand Up @@ -647,11 +646,82 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
return result


def Transport(host, connect_timeout=None, ssl=False, **kwargs):
def Transport(host, transport_type, connect_timeout=None, ssl=False, **kwargs):
"""Create transport.

Given a few parameters from the Connection constructor,
swathipil marked this conversation as resolved.
Show resolved Hide resolved
select and create a subclass of _AbstractTransport.
"""
transport = SSLTransport if ssl else TCPTransport
if transport_type == TransportType.AmqpOverWebsocket:
transport = WebSocketTransport
else:
transport = SSLTransport if ssl else TCPTransport
return transport(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

class WebSocketTransport(_AbstractTransport):
def __init__(self, host, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs
):
self.sslopts = ssl if isinstance(ssl, dict) else {}
self._connect_timeout = connect_timeout
self._host = host
super().__init__(
host, port, connect_timeout, **kwargs
)
self.ws = None
self._http_proxy = kwargs.get('http_proxy', None)

def connect(self):
http_proxy_host, http_proxy_port, http_proxy_auth = None, None, None
if self._http_proxy:
http_proxy_host = self._http_proxy['proxy_hostname']
http_proxy_port = self._http_proxy['proxy_port']
username = self._http_proxy.get('username', None)
password = self._http_proxy.get('password', None)
if username or password:
http_proxy_auth = (username, password)
try:
from websocket import create_connection
self.ws = create_connection(
url="wss://{}/$servicebus/websocket/".format(self._host),
rakshith91 marked this conversation as resolved.
Show resolved Hide resolved
subprotocols=[AMQP_WS_SUBPROTOCOL],
timeout=self._connect_timeout,
skip_utf8_validation=True,
sslopt=self.sslopts,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port,
http_proxy_auth=http_proxy_auth
)
except ImportError:
raise ValueError("Please install websocket-client library to use websocket transport.")

def _read(self, n, buffer=None, **kwargs): # pylint: disable=unused-arguments
annatisch marked this conversation as resolved.
Show resolved Hide resolved
"""Read exactly n bytes from the peer."""

length = 0
view = buffer or memoryview(bytearray(n))
nbytes = self._read_buffer.readinto(view)
length += nbytes
n -= nbytes
while n:
data = self.ws.recv()

if len(data) <= n:
view[length: length + len(data)] = data
n -= len(data)
else:
view[length: length + n] = data[0:n]
self._read_buffer = BytesIO(data[n:])
n = 0
swathipil marked this conversation as resolved.
Show resolved Hide resolved
return view

def _shutdown_transport(self):
"""Do any preliminary work in shutting down the connection."""
self.ws.close()

def _write(self, s):
"""Completely write a string to the peer.
ABNF, OPCODE_BINARY = 0x2
See http://tools.ietf.org/html/rfc5234
http://tools.ietf.org/html/rfc6455#section-5.2
"""
self.ws.send_binary(s)
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ async def _do_retryable_operation_async(self, operation, *args, **kwargs):
absolute_timeout -= (end_time - start_time)
raise retry_settings['history'][-1]

async def _keep_alive_worker_async(self):
interval = 10 if self._keep_alive is True else self._keep_alive
start_time = time.time()
try:
while self._connection and not self._shutdown:
current_time = time.time()
elapsed_time = (current_time - start_time)
if elapsed_time >= interval:
_logger.info("Keeping %r connection alive. %r",
self.__class__.__name__,
self._connection._container_id)
await self._connection._get_remote_timeout(current_time)
start_time = current_time
await asyncio.sleep(1)
except Exception as e: # pylint: disable=broad-except
_logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e)

annatisch marked this conversation as resolved.
Show resolved Hide resolved
async def open_async(self):
"""Asynchronously open the client. The client can create a new Connection
or an existing Connection can be passed in. This existing Connection
Expand All @@ -200,6 +217,8 @@ async def open_async(self):
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
transport_type=self._transport_type,
http_proxy=self._http_proxy,
properties=self._properties,
network_trace=self._network_trace
)
Expand All @@ -217,6 +236,8 @@ async def open_async(self):
auth_timeout=self._auth_timeout
)
await self._cbs_authenticator.open()
if self._keep_alive:
self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_worker_async())
self._shutdown = False

async def close_async(self):
Expand All @@ -228,6 +249,9 @@ async def close_async(self):
self._shutdown = True
if not self._session:
return # already closed.
if self._keep_alive_thread:
await self._keep_alive_thread
self._keep_alive_thread = None
await self._close_link_async(close=True)
if self._cbs_authenticator:
await self._cbs_authenticator.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import asyncio

from ._transport_async import AsyncTransport
from ._sasl_async import SASLTransport
from ._sasl_async import SASLTransport, SASLWithWebSocket
from ._session_async import Session
from ..performatives import OpenFrame, CloseFrame
from .._connection import get_local_timeout
Expand All @@ -27,7 +27,8 @@
MAX_CHANNELS,
HEADER_FRAME,
ConnectionState,
EMPTY_FRAME
EMPTY_FRAME,
TransportType
)

from ..error import (
Expand Down Expand Up @@ -58,24 +59,33 @@ class Connection(object):
:param list(str) offered_capabilities: The extension capabilities the sender supports.
:param list(str) desired_capabilities: The extension capabilities the sender may use if the receiver supports
:param dict properties: Connection properties.
:keyword str transport_type: Determines if the transport type is Amqp or AmqpOverWebSocket.
rakshith91 marked this conversation as resolved.
Show resolved Hide resolved
Defaults to TransportType.Amqp. It will be AmqpOverWebSocket if using http_proxy.
:keyword Dict http_proxy: HTTP proxy settings. This must be a dictionary with the following
keys: `'proxy_hostname'` (str value) and `'proxy_port'` (int value). When using these settings,
the transport_type would be AmqpOverWebSocket.
Additionally the following keys may also be present: `'username', 'password'`.
"""

def __init__(self, endpoint, **kwargs):
parsed_url = urlparse(endpoint)
self.hostname = parsed_url.hostname
self._transport_type = kwargs.pop('transport_type', TransportType.Amqp)
swathipil marked this conversation as resolved.
Show resolved Hide resolved
if parsed_url.port:
self.port = parsed_url.port
elif parsed_url.scheme == 'amqps':
self.port = SECURE_PORT
else:
self.port = PORT
self.state = None

transport = kwargs.get('transport')
if transport:
self.transport = transport
elif 'sasl_credential' in kwargs:
self.transport = SASLTransport(
sasl_transport = SASLWithWebSocket if (
self._transport_type is TransportType.AmqpOverWebsocket or kwargs.get("http_proxy")
) else SASLTransport
self.transport = sasl_transport(
swathipil marked this conversation as resolved.
Show resolved Hide resolved
host=parsed_url.netloc,
credential=kwargs['sasl_credential'],
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import struct
from enum import Enum

from ._transport_async import AsyncTransport
from ._transport_async import AsyncTransport, WebSocketTransportAsync
from ..types import AMQPTypes, TYPE, VALUE
from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME
from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME, WEBSOCKET_PORT, TransportType
from .._transport import AMQPS_PORT
from ..performatives import (
SASLOutcome,
Expand Down Expand Up @@ -72,14 +72,7 @@ class SASLExternalCredential(object):
def start(self):
return b''


class SASLTransport(AsyncTransport):

def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs):
self.credential = credential
ssl = ssl or True
super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

class SASLTransportMixinAsync():
async def negotiate(self):
await self.write(SASL_HEADER_FRAME)
_, returned_header = await self.receive_frame()
Expand All @@ -104,3 +97,26 @@ async def negotiate(self):
return
else:
raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields))

class SASLTransport(AsyncTransport, SASLTransportMixinAsync):
def __init__(self, host, credential, connect_timeout=None, ssl=None, **kwargs):
self.credential = credential
ssl = ssl or True
super(SASLTransport, self).__init__(host, connect_timeout=connect_timeout, ssl=ssl, **kwargs)

class SASLWithWebSocket(WebSocketTransportAsync, SASLTransportMixinAsync):
def __init__(
self, host, credential, port=WEBSOCKET_PORT, connect_timeout=None, ssl=None, **kwargs
): # pylint: disable=super-init-not-called
swathipil marked this conversation as resolved.
Show resolved Hide resolved
self.credential = credential
ssl = ssl or True
http_proxy = kwargs.pop('http_proxy', None)
self._transport = WebSocketTransportAsync(
host,
port=port,
connect_timeout=connect_timeout,
ssl=ssl,
http_proxy=http_proxy,
**kwargs
)
super().__init__(host, port, connect_timeout, ssl, **kwargs)
Loading