diff --git a/aiomysql/connection.py b/aiomysql/connection.py index ef99d483..290712f9 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -798,10 +798,10 @@ async def _process_auth(self, plugin_name, auth_packet): # These auth plugins do their own packet handling if plugin_name == b"caching_sha2_password": await self.caching_sha2_password_auth(auth_packet) - self._auth_plugin_used = plugin_name + self._auth_plugin_used = plugin_name.decode() elif plugin_name == b"sha256_password": await self.sha256_password_auth(auth_packet) - self._auth_plugin_used = plugin_name + self._auth_plugin_used = plugin_name.decode() else: if plugin_name == b"mysql_native_password": @@ -832,7 +832,7 @@ async def _process_auth(self, plugin_name, auth_packet): pkt = await self._read_packet() pkt.check_error() - self._auth_plugin_used = plugin_name + self._auth_plugin_used = plugin_name.decode() return pkt diff --git a/tests/conftest.py b/tests/conftest.py index 5fe3f93a..cbcfa1f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -218,6 +218,14 @@ def docker(): return APIClient(version='auto') +@pytest.fixture(autouse=True) +def ensure_mysql_verison(request, mysql_tag): + if request.node.get_marker('mysql_verison'): + if request.node.get_marker('mysql_verison').args[0] != mysql_tag: + pytest.skip('Not applicable for ' + 'MySQL version: {0}'.format(mysql_tag)) + + @pytest.fixture(scope='session') def mysql_server(unused_port, docker, session_id, mysql_tag, request): if not request.config.option.no_pull: @@ -295,6 +303,32 @@ def mysql_server(unused_port, docker, session_id, mysql_tag, request): assert result['Value'].startswith('TLS'), \ "Not connected to the database with TLS" + # Create Databases + cursor.execute('CREATE DATABASE test_pymysql ' + 'DEFAULT CHARACTER SET utf8 ' + 'DEFAULT COLLATE utf8_general_ci;') + cursor.execute('CREATE DATABASE test_pymysql2 ' + 'DEFAULT CHARACTER SET utf8 ' + 'DEFAULT COLLATE utf8_general_ci;') + + # Do MySQL8+ Specific Setup + if mysql_tag in ('8.0',): + # Create Users to test SHA256 + cursor.execute('CREATE USER user_sha256 ' + 'IDENTIFIED WITH "sha256_password" ' + 'BY "pass_sha256"') + cursor.execute('CREATE USER nopass_sha256 ' + 'IDENTIFIED WITH "sha256_password"') + cursor.execute('CREATE USER user_caching_sha2 ' + 'IDENTIFIED ' + 'WITH "caching_sha2_password" ' + 'BY "pass_caching_sha2"') + cursor.execute('CREATE USER nopass_caching_sha2 ' + 'IDENTIFIED ' + 'WITH "caching_sha2_password" ' + 'PASSWORD EXPIRE NEVER') + cursor.execute('FLUSH PRIVILEGES') + break except Exception as err: time.sleep(delay) diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py new file mode 100644 index 00000000..f0ac1ddd --- /dev/null +++ b/tests/test_sha_connection.py @@ -0,0 +1,25 @@ +import copy +from aiomysql import create_pool + +import pytest + + +@pytest.mark.mysql_verison('8.0') +@pytest.mark.run_loop +@pytest.mark.parametrize("user,password,plugin", [ + ("nopass_sha256", None, 'sha256_password'), + ("user_sha256", 'pass_sha256', 'sha256_password'), + ("nopass_caching_sha2", None, 'caching_sha2_password'), + ("user_caching_sha2", 'pass_caching_sha2', 'caching_sha2_password'), +]) +async def test_sha(mysql_server, loop, user, password, plugin): + connection_data = copy.copy(mysql_server['conn_params']) + connection_data['user'] = user + connection_data['password'] = password + + async with create_pool(**connection_data, + loop=loop) as pool: + async with pool.get() as conn: + # User doesnt have any permissions to look at DBs + # But as 8.0 will default to caching_sha2_password + assert conn._auth_plugin_used == plugin diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 044d759e..ff1ea740 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -54,5 +54,5 @@ async def test_auth_plugin_renegotiation(mysql_server, loop): 'Server did not ask for native auth' # Check we actually used the servers default plugin assert conn._auth_plugin_used in ( - b'mysql_native_password', b'caching_sha2_password'), \ + 'mysql_native_password', 'caching_sha2_password'), \ 'Client did not renegotiate with server\'s default auth'