From 3d0c9f19751396560d3a0374f9ff17d216426baa Mon Sep 17 00:00:00 2001 From: Richard Schwab Date: Sun, 10 Jul 2022 03:34:51 +0200 Subject: [PATCH] Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL fixes #757 --- CHANGES.txt | 2 ++ aiomysql/connection.py | 26 +++++++++++++++++++------- docs/connection.rst | 7 ++++++- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index fd1746bb..31a5cee9 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -8,6 +8,8 @@ next (unreleased) * Remove deprecated Pool.get #706 +* Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL #757 + 0.1.1 (2022-05-08) ^^^^^^^^^^^^^^^^^^ diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 022315a9..2e7c9d57 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -53,7 +53,7 @@ def connect(host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name='', server_public_key=None): + program_name='', server_public_key=None, implicit_tls=False): """See connections.Connection.__init__() for information about defaults.""" coro = _connect(host=host, user=user, password=password, db=db, @@ -66,7 +66,8 @@ def connect(host="localhost", user=None, password="", read_default_group=read_default_group, autocommit=autocommit, echo=echo, local_infile=local_infile, loop=loop, ssl=ssl, - auth_plugin=auth_plugin, program_name=program_name) + auth_plugin=auth_plugin, program_name=program_name, + implicit_tls=implicit_tls) return _ConnectionContextManager(coro) @@ -142,7 +143,7 @@ def __init__(self, host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name='', server_public_key=None): + program_name='', server_public_key=None, implicit_tls=False): """ Establish a connection to the MySQL database. Accepts several arguments: @@ -184,6 +185,9 @@ def __init__(self, host="localhost", user=None, password="", handshaking with MySQL. (omitted by default) :param server_public_key: SHA256 authentication plugin public key value. + :param implicit_tls: Establish TLS immediately, skipping non-TLS + preamble before upgrading to TLS. + (default: False) :param loop: asyncio loop """ self._loop = loop or asyncio.get_event_loop() @@ -218,6 +222,7 @@ def __init__(self, host="localhost", user=None, password="", self._auth_plugin_used = "" self._secure = False self.server_public_key = server_public_key + self._implicit_tls = implicit_tls self.salt = None from . import __version__ @@ -536,7 +541,8 @@ async def _connect(self): self._next_seq_id = 0 - await self._get_server_information() + if not self._implicit_tls: + await self._get_server_information() await self._request_authentication() self.connected_time = self._loop.time() @@ -727,7 +733,8 @@ async def _execute_command(self, command, sql): async def _request_authentication(self): # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse - if int(self.server_version.split('.', 1)[0]) >= 5: + # FIXME: change this before merge + if self._implicit_tls or int(self.server_version.split('.', 1)[0]) >= 5: self.client_flag |= CLIENT.MULTI_RESULTS if self.user is None: @@ -737,8 +744,10 @@ async def _request_authentication(self): data_init = struct.pack('` that connects to MySQL. @@ -93,6 +93,11 @@ Example:: ``sys.argv[0]`` is no longer passed by default :param server_public_key: SHA256 authenticaiton plugin public key value. :param loop: asyncio event loop instance or ``None`` for default one. + :param implicit_tls: Establish TLS immediately, skipping non-TLS + preamble before upgrading to TLS. + (default: False) + + .. versionadded:: 0.2 :returns: :class:`Connection` instance.