Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-standard implicit TLS connections, such as Google Cloud SQL #786

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codecov.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
codecov:
notify:
after_n_builds: 40
after_n_builds: 6
59 changes: 44 additions & 15 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -406,18 +406,18 @@ jobs:
- ubuntu-latest
py:
- '3.7'
- '3.8'
- '3.9'
- '3.10'
# - '3.8'
# - '3.9'
# - '3.10'
- '3.11-dev'
db:
- [mysql, '5.7']
- [mysql, '8.0']
- [mariadb, '10.3']
- [mariadb, '10.4']
- [mariadb, '10.5']
- [mariadb, '10.6']
- [mariadb, '10.7']
# - [mariadb, '10.3']
# - [mariadb, '10.4']
# - [mariadb, '10.5']
# - [mariadb, '10.6']
# - [mariadb, '10.7']
- [mariadb, '10.8']

fail-fast: false
Expand Down Expand Up @@ -449,6 +449,13 @@ jobs:
options: '--name=mysqld'
env:
MYSQL_ROOT_PASSWORD: rootpw
haproxy:
image: haproxytech/haproxy-alpine:2.6
ports:
- 13306:13306
volumes:
- "/tmp/run-${{ join(matrix.db, '-') }}/:/var/lib/haproxy/socket-mount/"
options: '--name=haproxy'

steps:
- name: Setup Python ${{ matrix.py }}
Expand Down Expand Up @@ -569,6 +576,18 @@ jobs:
# unfortunately we need this hacky workaround as GitHub Actions service containers can't reference data from our repo.
- name: Prepare mysql
run: |
# we need to ensure that the socket path is readable from haproxy and
# writable for the user running the DB process
sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }}

# inject HAproxy configuration
docker container stop haproxy

docker container cp "${{ github.workspace }}/tests/ssl_resources/haproxy.cfg" haproxy:/usr/local/etc/haproxy/haproxy.cfg
docker container cp "${{ github.workspace }}/tests/ssl_resources/ssl/server-combined.pem" haproxy:/usr/local/etc/haproxy/haproxy.pem

docker container start haproxy

# ensure server is started up
while :
do
Expand All @@ -582,9 +601,6 @@ jobs:
docker container cp "${{ github.workspace }}/tests/ssl_resources/tls.cnf" mysqld:/etc/mysql/conf.d/aiomysql-tls.cnf

# use custom socket path
# we need to ensure that the socket path is writable for the user running the DB process in the container
sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }}

docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/conf.d/aiomysql-socket.cnf

docker container start mysqld
Expand All @@ -599,10 +615,23 @@ jobs:
mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "SET GLOBAL local_infile=on"

- name: Run tests
run: |
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
timeout --preserve-status --signal=INT --verbose 570s \
pytest --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql --cov tests ./tests --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
run: >-
timeout
--preserve-status
--signal=INT
--verbose 570s
pytest
--capture=no
--verbosity 2
--cov-report term
--cov-report xml
--cov aiomysql
--cov tests
./tests
--mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock"
--mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
--mysql-address-tls "tls-${{ join(matrix.db, '') }}=127.0.0.1:13306"
env:
PYTHONUNBUFFERED: 1
timeout-minutes: 10
Expand Down
2 changes: 2 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ next (unreleased)
| aiomysql now reraises the original exception during connect() if it's not `IOError`, `OSError` or `asyncio.TimeoutError`.
| This was previously always raised as `OperationalError`.

* Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL #757

0.1.1 (2022-05-08)
^^^^^^^^^^^^^^^^^^

Expand Down
32 changes: 24 additions & 8 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
import configparser
import getpass
import ssl as ssllib
from functools import partial

from pymysql.charset import charset_by_name, charset_by_id
Expand Down Expand Up @@ -53,7 +54,7 @@ def connect(host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
program_name='', server_public_key=None, implicit_tls=False):
"""See connections.Connection.__init__() for information about
defaults."""
coro = _connect(host=host, user=user, password=password, db=db,
Expand All @@ -66,7 +67,8 @@ def connect(host="localhost", user=None, password="",
read_default_group=read_default_group,
autocommit=autocommit, echo=echo,
local_infile=local_infile, loop=loop, ssl=ssl,
auth_plugin=auth_plugin, program_name=program_name)
auth_plugin=auth_plugin, program_name=program_name,
implicit_tls=implicit_tls)
return _ConnectionContextManager(coro)


Expand Down Expand Up @@ -142,7 +144,7 @@ def __init__(self, host="localhost", user=None, password="",
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
program_name='', server_public_key=None, implicit_tls=False):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
Expand Down Expand Up @@ -184,6 +186,9 @@ def __init__(self, host="localhost", user=None, password="",
handshaking with MySQL. (omitted by default)
:param server_public_key: SHA256 authentication plugin public
key value.
:param implicit_tls: Establish TLS immediately, skipping non-TLS
preamble before upgrading to TLS.
(default: False)
:param loop: asyncio loop
"""
self._loop = loop or asyncio.get_event_loop()
Expand Down Expand Up @@ -218,6 +223,7 @@ def __init__(self, host="localhost", user=None, password="",
self._auth_plugin_used = ""
self._secure = False
self.server_public_key = server_public_key
self._implicit_tls = implicit_tls
self.salt = None

from . import __version__
Expand All @@ -241,7 +247,10 @@ def __init__(self, host="localhost", user=None, password="",
self.use_unicode = use_unicode

self._ssl_context = ssl
if ssl:
# TLS is required when implicit_tls is True
if implicit_tls and not self._ssl_context:
self._ssl_context = ssllib.create_default_context()
if ssl and not implicit_tls:
client_flag |= CLIENT.SSL

self._encoding = charset_by_name(self._charset).encoding
Expand Down Expand Up @@ -536,7 +545,8 @@ async def _connect(self):

self._next_seq_id = 0

await self._get_server_information()
if not self._implicit_tls:
await self._get_server_information()
await self._request_authentication()

self.connected_time = self._loop.time()
Expand Down Expand Up @@ -738,7 +748,8 @@ async def _execute_command(self, command, sql):

async def _request_authentication(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
if int(self.server_version.split('.', 1)[0]) >= 5:
# FIXME: change this before merge
if self._implicit_tls or int(self.server_version.split('.', 1)[0]) >= 5:
self.client_flag |= CLIENT.MULTI_RESULTS

if self.user is None:
Expand All @@ -748,8 +759,10 @@ async def _request_authentication(self):
data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
charset_id, b'')

if self._ssl_context and self.server_capabilities & CLIENT.SSL:
self.write_packet(data_init)
if self._ssl_context and \
(self._implicit_tls or self.server_capabilities & CLIENT.SSL):
if not self._implicit_tls:
self.write_packet(data_init)

# Stop sending events to data_received
self._writer.transport.pause_reading()
Expand All @@ -771,6 +784,9 @@ async def _request_authentication(self):
server_hostname=self._host
)

if self._implicit_tls:
await self._get_server_information()

self._secure = True

if isinstance(self.user, str):
Expand Down
7 changes: 6 additions & 1 deletion docs/connection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Example::
connect_timeout=None, read_default_group=None,
autocommit=False, echo=False
ssl=None, auth_plugin='', program_name='',
server_public_key=None, loop=None)
server_public_key=None, loop=None, implicit_tls=False)

A :ref:`coroutine <coroutine>` that connects to MySQL.

Expand Down Expand Up @@ -93,6 +93,11 @@ Example::
``sys.argv[0]`` is no longer passed by default
:param server_public_key: SHA256 authenticaiton plugin public key value.
:param loop: asyncio event loop instance or ``None`` for default one.
:param implicit_tls: Establish TLS immediately, skipping non-TLS
preamble before upgrading to TLS.
(default: False)

.. versionadded:: 0.2
:returns: :class:`Connection` instance.


Expand Down
63 changes: 54 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
import os
import re
import socket
import ssl
import sys

Expand Down Expand Up @@ -63,13 +64,26 @@ def pytest_generate_tests(metafunc):

if ":" in addr:
addr = addr.split(":", 1)
mysql_addresses.append((addr[0], int(addr[1])))
mysql_addresses.append((addr[0], int(addr[1]), False))
else:
mysql_addresses.append((addr, 3306))
mysql_addresses.append((addr, 3306, False))

opt_mysql_address_tls =\
list(metafunc.config.getoption("mysql_address_tls"))
for i in range(len(opt_mysql_address_tls)):
if "=" in opt_mysql_address_tls[i]:
label, addr = opt_mysql_address_tls[i].split("=", 1)
ids.append(label)
else:
addr = opt_mysql_address_tls[i]
ids.append("tls{}".format(i))

addr = addr.split(":", 1)
mysql_addresses.append((addr[0], int(addr[1]), True))

# default to connecting to localhost
if len(mysql_addresses) == 0:
mysql_addresses = [("127.0.0.1", 3306)]
mysql_addresses = [("127.0.0.1", 3306, False)]
ids = ["tcp-local"]

assert len(mysql_addresses) == len(set(mysql_addresses)), \
Expand Down Expand Up @@ -153,6 +167,12 @@ def pytest_addoption(parser):
default=[],
help="list of addresses to connect to: [name=]host[:port]",
)
parser.addoption(
"--mysql-address-tls",
action="append",
default=[],
help="list of addresses to connect to using implicit TLS: [name=]host:port",
)
parser.addoption(
"--mysql-unix-socket",
action="append",
Expand Down Expand Up @@ -249,6 +269,7 @@ def _register_table(table_name):
@pytest.fixture(scope='session')
def mysql_server(mysql_address):
unix_socket = type(mysql_address) is str
implicit_tls = not unix_socket and mysql_address[2]

if not unix_socket:
ssl_directory = os.path.join(os.path.dirname(__file__),
Expand All @@ -270,14 +291,34 @@ def mysql_server(mysql_address):
else:
server_params["host"] = mysql_address[0]
server_params["port"] = mysql_address[1]

if not unix_socket and not implicit_tls:
server_params["ssl"] = ctx

try:
connection = pymysql.connect(
db='mysql',
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor,
**server_params)
if implicit_tls:
sock = ctx.wrap_socket(
socket.create_connection(
(server_params["host"], server_params["port"]),
),
server_hostname=server_params["host"],
)
connection = pymysql.Connection(
db='mysql',
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor,
**server_params,
defer_connect=True,
)
connection.connect(sock)

else:
connection = pymysql.connect(
db='mysql',
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor,
**server_params,
)

with connection.cursor() as cursor:
cursor.execute("SELECT VERSION() AS version")
Expand All @@ -297,7 +338,7 @@ def mysql_server(mysql_address):
pytest.fail("Unable to determine database type from {!r}"
.format(server_version_tuple))

if not unix_socket:
if not unix_socket and not implicit_tls:
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")

result = cursor.fetchall()
Expand Down Expand Up @@ -353,6 +394,10 @@ def mysql_server(mysql_address):
except Exception:
pytest.fail("Cannot initialize MySQL environment")

if implicit_tls:
server_params["ssl"] = ctx
server_params["implicit_tls"] = implicit_tls

return {
"conn_params": server_params,
"server_version": server_version,
Expand Down
2 changes: 2 additions & 0 deletions tests/sa/test_sa_compiled_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ async def _make_engine(**kwargs):
}
if "ssl" in mysql_params:
conn_args["ssl"] = mysql_params["ssl"]
if "implicit_tls" in mysql_params:
conn_args["implicit_tls"] = mysql_params["implicit_tls"]

engine = await sa.create_engine(
db=mysql_params['db'],
Expand Down
2 changes: 2 additions & 0 deletions tests/sa/test_sa_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ async def _make_engine(**kwargs):
}
if "ssl" in mysql_params:
conn_args["ssl"] = mysql_params["ssl"]
if "implicit_tls" in mysql_params:
conn_args["implicit_tls"] = mysql_params["implicit_tls"]

engine = await sa.create_engine(
db=mysql_params['db'],
Expand Down
2 changes: 2 additions & 0 deletions tests/sa/test_sa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ async def _make_engine(**kwargs):
}
if "ssl" in mysql_params:
conn_args["ssl"] = mysql_params["ssl"]
if "implicit_tls" in mysql_params:
conn_args["implicit_tls"] = mysql_params["implicit_tls"]

engine = await sa.create_engine(
db=mysql_params['db'],
Expand Down
Loading