diff --git a/aiomysql/connection.py b/aiomysql/connection.py index f2f84bc0..0c338130 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -41,7 +41,7 @@ # from aiomysql.utils import _convert_to_str from .cursors import Cursor from .utils import _ConnectionContextManager, _ContextManager -# from .log import logger +from .log import logger DEFAULT_USER = getpass.getuser() @@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, no_delay=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name=''): + program_name='', server_public_key=None): """See connections.Connection.__init__() for information about defaults.""" coro = _connect(host=host, user=user, password=password, db=db, @@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="", connect_timeout=None, read_default_group=None, no_delay=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', - program_name=''): + program_name='', server_public_key=None): """ Establish a connection to the MySQL database. Accepts several arguments: @@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="", (default: Server Default) :param program_name: Program name string to provide when handshaking with MySQL. (default: sys.argv[0]) + :param server_public_key: SHA256 authentication plugin public + key value. :param loop: asyncio loop """ self._loop = loop or asyncio.get_event_loop() @@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="", self._client_auth_plugin = auth_plugin self._server_auth_plugin = "" self._auth_plugin_used = "" + self.server_public_key = server_public_key + self.salt = None # TODO somehow import version from __init__.py self._connect_attrs = { @@ -712,6 +716,20 @@ async def _request_authentication(self): if auth_plugin in ('', 'mysql_native_password'): authresp = _auth.scramble_native_password( self._password.encode('latin1'), self.salt) + elif auth_plugin == 'caching_sha2_password': + if self._password: + authresp = _auth.scramble_caching_sha2( + self._password.encode('latin1'), self.salt + ) + # Else: empty password + elif auth_plugin == 'sha256_password': + if self._ssl_context and self.server_capabilities & CLIENT.SSL: + authresp = self._password.encode('latin1') + b'\0' + elif self._password: + authresp = b'\1' # request public key + else: + authresp = b'\0' # empty password + elif auth_plugin in ('', 'mysql_clear_password'): authresp = self._password.encode('latin1') + b'\0' @@ -768,35 +786,174 @@ async def _request_authentication(self): auth_packet.read_all()) + b'\0' self.write_packet(data) await self._read_packet() + elif auth_packet.is_extra_auth_data(): + if auth_plugin == "caching_sha2_password": + await self.caching_sha2_password_auth(auth_packet) + elif auth_plugin == "sha256_password": + await self.sha256_password_auth(auth_packet) + else: + raise OperationalError("Received extra packet " + "for auth method %r", auth_plugin) async def _process_auth(self, plugin_name, auth_packet): - if plugin_name == b"mysql_native_password": - # https://dev.mysql.com/doc/internals/en/ - # secure-password-authentication.html#packet-Authentication:: - # Native41 - data = _auth.scramble_native_password( - self._password.encode('latin1'), - auth_packet.read_all()) - elif plugin_name == b"mysql_old_password": - # https://dev.mysql.com/doc/internals/en/ - # old-password-authentication.html - data = _auth.scramble_old_password(self._password.encode('latin1'), - auth_packet.read_all()) + b'\0' - elif plugin_name == b"mysql_clear_password": - # https://dev.mysql.com/doc/internals/en/ - # clear-text-authentication.html - data = self._password.encode('latin1') + b'\0' + # These auth plugins do their own packet handling + if plugin_name == b"caching_sha2_password": + await self.caching_sha2_password_auth(auth_packet) + self._auth_plugin_used = plugin_name.decode() + elif plugin_name == b"sha256_password": + await self.sha256_password_auth(auth_packet) + self._auth_plugin_used = plugin_name.decode() else: + + if plugin_name == b"mysql_native_password": + # https://dev.mysql.com/doc/internals/en/ + # secure-password-authentication.html#packet-Authentication:: + # Native41 + data = _auth.scramble_native_password( + self._password.encode('latin1'), + auth_packet.read_all()) + elif plugin_name == b"mysql_old_password": + # https://dev.mysql.com/doc/internals/en/ + # old-password-authentication.html + data = _auth.scramble_old_password( + self._password.encode('latin1'), + auth_packet.read_all() + ) + b'\0' + elif plugin_name == b"mysql_clear_password": + # https://dev.mysql.com/doc/internals/en/ + # clear-text-authentication.html + data = self._password.encode('latin1') + b'\0' + else: + raise OperationalError( + 2059, "Authentication plugin '{0}'" + " not configured".format(plugin_name) + ) + + self.write_packet(data) + pkt = await self._read_packet() + pkt.check_error() + + self._auth_plugin_used = plugin_name.decode() + + return pkt + + async def caching_sha2_password_auth(self, pkt): + # No password fast path + if not self._password: + self.write_packet(b'') + pkt = await self._read_packet() + pkt.check_error() + return pkt + + if pkt.is_auth_switch_request(): + # Try from fast auth + logger.debug("caching sha2: Trying fast path") + self.salt = pkt.read_all() + scrambled = _auth.scramble_caching_sha2( + self._password.encode('latin1'), self.salt + ) + + self.write_packet(scrambled) + pkt = await self._read_packet() + pkt.check_error() + + # else: fast auth is tried in initial handshake + + if not pkt.is_extra_auth_data(): raise OperationalError( - 2059, "Authentication plugin '%s' not configured" % plugin_name + "caching sha2: Unknown packet " + "for fast auth: {0}".format(pkt._data[:1]) ) + # magic numbers: + # 2 - request public key + # 3 - fast auth succeeded + # 4 - need full auth + + pkt.advance(1) + n = pkt.read_uint8() + + if n == 3: + logger.debug("caching sha2: succeeded by fast path.") + pkt = await self._read_packet() + pkt.check_error() # pkt must be OK packet + return pkt + + if n != 4: + raise OperationalError("caching sha2: Unknown " + "result for fast auth: {0}".format(n)) + + logger.debug("caching sha2: Trying full auth...") + + if self._ssl_context: + logger.debug("caching sha2: Sending plain " + "password via secure connection") + self.write_packet(self._password.encode('latin1') + b'\0') + pkt = await self._read_packet() + pkt.check_error() + return pkt + + if not self.server_public_key: + self.write_packet(b'\x02') + pkt = await self._read_packet() # Request public key + pkt.check_error() + + if not pkt.is_extra_auth_data(): + raise OperationalError( + "caching sha2: Unknown packet " + "for public key: {0}".format(pkt._data[:1]) + ) + + self.server_public_key = pkt._data[1:] + logger.debug(self.server_public_key.decode('ascii')) + + data = _auth.sha2_rsa_encrypt( + self._password.encode('latin1'), self.salt, + self.server_public_key + ) self.write_packet(data) pkt = await self._read_packet() pkt.check_error() - self._auth_plugin_used = plugin_name + async def sha256_password_auth(self, pkt): + if self._ssl_context: + logger.debug("sha256: Sending plain password") + data = self._password.encode('latin1') + b'\0' + self.write_packet(data) + pkt = await self._read_packet() + pkt.check_error() + return pkt + + if pkt.is_auth_switch_request(): + self.salt = pkt.read_all() + if not self.server_public_key and self._password: + # Request server public key + logger.debug("sha256: Requesting server public key") + self.write_packet(b'\1') + pkt = await self._read_packet() + pkt.check_error() + + if pkt.is_extra_auth_data(): + self.server_public_key = pkt._data[1:] + logger.debug( + "Received public key:\n", + self.server_public_key.decode('ascii') + ) + + if self._password: + if not self.server_public_key: + raise OperationalError("Couldn't receive server's public key") + + data = _auth.sha2_rsa_encrypt( + self._password.encode('latin1'), self.salt, + self.server_public_key + ) + else: + data = b'' + self.write_packet(data) + pkt = await self._read_packet() + pkt.check_error() return pkt # _mysql support diff --git a/docs/connection.rst b/docs/connection.rst index 604a0bc3..de3dc0c8 100644 --- a/docs/connection.rst +++ b/docs/connection.rst @@ -47,7 +47,8 @@ Example:: client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=None, read_default_group=None, no_delay=False, autocommit=False, echo=False, - ssl=None, auth_plugin='', program_name='', loop=None) + ssl=None, auth_plugin='', program_name='', + server_public_key=None, loop=None) A :ref:`coroutine ` that connects to MySQL. @@ -89,6 +90,7 @@ Example:: (default: Server Default) :param program_name: Program name string to provide when handshaking with MySQL. (default: sys.argv[0]) + :param server_public_key: SHA256 authenticaiton plugin public key value. :param loop: asyncio event loop instance or ``None`` for default one. :returns: :class:`Connection` instance. diff --git a/examples/example_ssl.py b/examples/example_ssl.py new file mode 100644 index 00000000..e66c267d --- /dev/null +++ b/examples/example_ssl.py @@ -0,0 +1,38 @@ +import asyncio +import ssl +import aiomysql + +ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +ctx.check_hostname = False +ctx.load_verify_locations(cafile='../tests/ssl_resources/ssl/ca.pem') + + +async def main(): + async with aiomysql.create_pool( + host='localhost', port=3306, user='root', + password='rootpw', ssl=ctx, + auth_plugin='mysql_clear_password') as pool: + + async with pool.get() as conn: + async with conn.cursor() as cur: + # Run simple command + await cur.execute("SHOW DATABASES;") + value = await cur.fetchall() + + values = [item[0] for item in value] + # Spot check the answers, we should at least have mysql + # and information_schema + assert 'mysql' in values, \ + 'Could not find the "mysql" table' + assert 'information_schema' in values, \ + 'Could not find the "mysql" table' + + # Check TLS variables + await cur.execute("SHOW STATUS LIKE 'Ssl_version%';") + value = await cur.fetchone() + + # The context has TLS + assert value[1].startswith('TLS'), \ + 'Not connected to the database with TLS' + +asyncio.get_event_loop().run_until_complete(main()) diff --git a/tests/conftest.py b/tests/conftest.py index 4a9b9bbb..16304adf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,15 +35,15 @@ def pytest_generate_tests(metafunc): loop_type = ['asyncio', 'uvloop'] if uvloop else ['asyncio'] metafunc.parametrize("loop_type", loop_type) - # if 'mysql_tag' in metafunc.fixturenames: - # tags = set(metafunc.config.option.mysql_tag) - # if not tags: - # tags = ['5.7'] - # elif 'all' in tags: - # tags = ['5.6', '5.7', '8.0'] - # else: - # tags = list(tags) - # metafunc.parametrize("mysql_tag", tags, scope='session') + if 'mysql_tag' in metafunc.fixturenames: + tags = set(metafunc.config.option.mysql_tag) + if not tags: + tags = ['5.6', '8.0'] + elif 'all' in tags: + tags = ['5.6', '5.7', '8.0'] + else: + tags = list(tags) + metafunc.parametrize("mysql_tag", tags, scope='session') # This is here unless someone fixes the generate_tests bit @@ -218,8 +218,18 @@ def docker(): return APIClient(version='auto') +@pytest.fixture(autouse=True) +def ensure_mysql_verison(request, mysql_tag): + if request.node.get_marker('mysql_verison'): + if request.node.get_marker('mysql_verison').args[0] != mysql_tag: + pytest.skip('Not applicable for ' + 'MySQL version: {0}'.format(mysql_tag)) + + @pytest.fixture(scope='session') def mysql_server(unused_port, docker, session_id, mysql_tag, request): + print('\nSTARTUP CONTAINER - {0}\n'.format(mysql_tag)) + if not request.config.option.no_pull: docker.pull('mysql:{}'.format(mysql_tag)) @@ -288,13 +298,39 @@ def mysql_server(unused_port, docker, session_id, mysql_tag, request): assert result['have_ssl'] == "YES", \ "SSL Not Enabled on docker'd MySQL" - cursor.execute("SHOW STATUS LIKE '%Ssl_version%'") + cursor.execute("SHOW STATUS LIKE 'Ssl_version%'") result = cursor.fetchone() # As we connected with TLS, it should start with that :D assert result['Value'].startswith('TLS'), \ "Not connected to the database with TLS" + # Create Databases + cursor.execute('CREATE DATABASE test_pymysql ' + 'DEFAULT CHARACTER SET utf8 ' + 'DEFAULT COLLATE utf8_general_ci;') + cursor.execute('CREATE DATABASE test_pymysql2 ' + 'DEFAULT CHARACTER SET utf8 ' + 'DEFAULT COLLATE utf8_general_ci;') + + # Do MySQL8+ Specific Setup + if mysql_tag in ('8.0',): + # Create Users to test SHA256 + cursor.execute('CREATE USER user_sha256 ' + 'IDENTIFIED WITH "sha256_password" ' + 'BY "pass_sha256"') + cursor.execute('CREATE USER nopass_sha256 ' + 'IDENTIFIED WITH "sha256_password"') + cursor.execute('CREATE USER user_caching_sha2 ' + 'IDENTIFIED ' + 'WITH "caching_sha2_password" ' + 'BY "pass_caching_sha2"') + cursor.execute('CREATE USER nopass_caching_sha2 ' + 'IDENTIFIED ' + 'WITH "caching_sha2_password" ' + 'PASSWORD EXPIRE NEVER') + cursor.execute('FLUSH PRIVILEGES') + break except Exception as err: time.sleep(delay) @@ -308,5 +344,6 @@ def mysql_server(unused_port, docker, session_id, mysql_tag, request): yield container finally: + print('\nTEARDOWN CONTAINER - {0}\n'.format(mysql_tag)) docker.kill(container=container['Id']) docker.remove_container(container['Id']) diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py new file mode 100644 index 00000000..f2a108d8 --- /dev/null +++ b/tests/test_sha_connection.py @@ -0,0 +1,81 @@ +import copy +from aiomysql import create_pool + +import pytest + + +# You could parameterise these tests with this, but then pytest +# does some funky stuff and spins up and tears down containers +# per function call. Remember it would be +# mysql_versions * event_loops * 4 auth tests ~= 3*2*4 ~= 24 tests + +# As the MySQL daemon restarts at least 3 times in the container +# before it becomes stable, there's a sleep(10) so that's +# around a 4min wait time. + +# @pytest.mark.parametrize("user,password,plugin", [ +# ("nopass_sha256", None, 'sha256_password'), +# ("user_sha256", 'pass_sha256', 'sha256_password'), +# ("nopass_caching_sha2", None, 'caching_sha2_password'), +# ("user_caching_sha2", 'pass_caching_sha2', 'caching_sha2_password'), +# ]) + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +async def test_sha256_nopw(mysql_server, loop): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = 'nopass_sha256' + connection_data['password'] = None + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == 'sha256_password' + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +async def test_sha256_pw(mysql_server, loop): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = 'user_sha256' + connection_data['password'] = 'pass_sha256' + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == 'sha256_password' + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +async def test_cached_sha256_nopw(mysql_server, loop): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = 'nopass_caching_sha2' + connection_data['password'] = None + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == 'caching_sha2_password' + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +async def test_cached_sha256_pw(mysql_server, loop): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = 'user_caching_sha2' + connection_data['password'] = 'pass_caching_sha2' + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == 'caching_sha2_password' diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 07c8ef61..ff1ea740 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -22,7 +22,7 @@ async def test_tls_connect(mysql_server, loop): 'Could not find the "mysql" table' # Check TLS variables - await cur.execute("SHOW STATUS LIKE '%Ssl_version%';") + await cur.execute("SHOW STATUS LIKE 'Ssl_version%';") value = await cur.fetchone() # The context has TLS @@ -44,9 +44,15 @@ async def test_auth_plugin_renegotiation(mysql_server, loop): assert len(value), 'No databases found' + # Check we tried to use the cleartext plugin assert conn._client_auth_plugin == 'mysql_clear_password', \ 'Client did not try clear password auth' - assert conn._server_auth_plugin == 'mysql_native_password', \ + + # Check the server asked us to use MySQL's default plugin + assert conn._server_auth_plugin in ( + 'mysql_native_password', 'caching_sha2_password'), \ 'Server did not ask for native auth' - assert conn._auth_plugin_used == b'mysql_native_password', \ - 'Client did not renegotiate with native auth' + # Check we actually used the servers default plugin + assert conn._auth_plugin_used in ( + 'mysql_native_password', 'caching_sha2_password'), \ + 'Client did not renegotiate with server\'s default auth'