Skip to content

Commit

Permalink
Add implicit_tls connect arg to support non-standard implicit TLS c…
Browse files Browse the repository at this point in the history
…onnections, such as Google Cloud SQL

fixes #757
  • Loading branch information
Nothing4You committed Jul 10, 2022
1 parent b21f0ed commit 3d0c9f1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ next (unreleased)

* Remove deprecated Pool.get #706

* Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL #757

0.1.1 (2022-05-08)
^^^^^^^^^^^^^^^^^^

Expand Down
26 changes: 19 additions & 7 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def connect(host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
program_name='', server_public_key=None, implicit_tls=False):
"""See connections.Connection.__init__() for information about
defaults."""
coro = _connect(host=host, user=user, password=password, db=db,
Expand All @@ -66,7 +66,8 @@ def connect(host="localhost", user=None, password="",
read_default_group=read_default_group,
autocommit=autocommit, echo=echo,
local_infile=local_infile, loop=loop, ssl=ssl,
auth_plugin=auth_plugin, program_name=program_name)
auth_plugin=auth_plugin, program_name=program_name,
implicit_tls=implicit_tls)
return _ConnectionContextManager(coro)


Expand Down Expand Up @@ -142,7 +143,7 @@ def __init__(self, host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
program_name='', server_public_key=None, implicit_tls=False):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
Expand Down Expand Up @@ -184,6 +185,9 @@ def __init__(self, host="localhost", user=None, password="",
handshaking with MySQL. (omitted by default)
:param server_public_key: SHA256 authentication plugin public
key value.
:param implicit_tls: Establish TLS immediately, skipping non-TLS
preamble before upgrading to TLS.
(default: False)
:param loop: asyncio loop
"""
self._loop = loop or asyncio.get_event_loop()
Expand Down Expand Up @@ -218,6 +222,7 @@ def __init__(self, host="localhost", user=None, password="",
self._auth_plugin_used = ""
self._secure = False
self.server_public_key = server_public_key
self._implicit_tls = implicit_tls
self.salt = None

from . import __version__
Expand Down Expand Up @@ -536,7 +541,8 @@ async def _connect(self):

self._next_seq_id = 0

await self._get_server_information()
if not self._implicit_tls:
await self._get_server_information()
await self._request_authentication()

self.connected_time = self._loop.time()
Expand Down Expand Up @@ -727,7 +733,8 @@ async def _execute_command(self, command, sql):

async def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
if int(self.server_version.split('.', 1)[0]) >= 5:
# FIXME: change this before merge
if self._implicit_tls or int(self.server_version.split('.', 1)[0]) >= 5:
self.client_flag |= CLIENT.MULTI_RESULTS

if self.user is None:
Expand All @@ -737,8 +744,10 @@ async def _request_authentication(self):
data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
charset_id, b'')

if self._ssl_context and self.server_capabilities & CLIENT.SSL:
self.write_packet(data_init)
if self._ssl_context and \
(self._implicit_tls or self.server_capabilities & CLIENT.SSL):
if not self._implicit_tls:
self.write_packet(data_init)

# Stop sending events to data_received
self._writer.transport.pause_reading()
Expand All @@ -760,6 +769,9 @@ async def _request_authentication(self):
server_hostname=self._host
)

if self._implicit_tls:
await self._get_server_information()

self._secure = True

if isinstance(self.user, str):
Expand Down
7 changes: 6 additions & 1 deletion docs/connection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Example::
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False
ssl=None, auth_plugin='', program_name='',
server_public_key=None, loop=None)
server_public_key=None, loop=None, implicit_tls=False)

A :ref:`coroutine <coroutine>` that connects to MySQL.

Expand Down Expand Up @@ -93,6 +93,11 @@ Example::
``sys.argv[0]`` is no longer passed by default
:param server_public_key: SHA256 authenticaiton plugin public key value.
:param loop: asyncio event loop instance or ``None`` for default one.
:param implicit_tls: Establish TLS immediately, skipping non-TLS
preamble before upgrading to TLS.
(default: False)

.. versionadded:: 0.2
:returns: :class:`Connection` instance.


Expand Down

0 comments on commit 3d0c9f1

Please sign in to comment.