From 120becd7143257c537df4914c1b3302074b453de Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 13 Sep 2024 18:18:46 +0000 Subject: [PATCH 01/24] chore: test dns as SAN --- google/cloud/sql/connector/client.py | 2 +- google/cloud/sql/connector/pg8000.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 1c805814..31a2f416 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -145,7 +145,7 @@ async def _get_metadata( # Note that we have to check for PSC enablement also because CAS # instances also set the dnsName field. # Remove trailing period from DNS name. Required for SSL in Python - dns_name = ret_dict.get("dnsName", "").rstrip(".") + dns_name = ret_dict.get("dnsName", "") if dns_name and ret_dict.get("pscEnabled"): ip_addresses["PSC"] = dns_name diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index 623738f8..85a4ec3f 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import socket import ssl from typing import Any, TYPE_CHECKING @@ -49,7 +50,7 @@ def connect( # Create socket and wrap with context. sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), + socket.create_connection((ip_address.rstrip("."), SERVER_PROXY_PORT)), server_hostname=ip_address, ) From a24822a91b7af8f6d38d347cee0c7e7957483e0a Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 13 Sep 2024 18:24:32 +0000 Subject: [PATCH 02/24] chore: check hostname --- google/cloud/sql/connector/connection_info.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 7181134d..60a942b2 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -57,9 +57,6 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont return self.context context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - # update ssl.PROTOCOL_TLS_CLIENT default - context.check_hostname = False - # TODO: remove if/else when Python 3.10 is min version. PEP 644 has been # implemented. The ssl module requires OpenSSL 1.1.1 or newer. # verify OpenSSL version supports TLSv1.3 From ce9c5fffd0799e7635432724d21a4a730ab4d783 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 30 Oct 2024 21:31:09 +0000 Subject: [PATCH 03/24] chore: add skeleton of resolver classes --- google/cloud/sql/connector/resolver.py | 37 ++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 google/cloud/sql/connector/resolver.py diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py new file mode 100644 index 00000000..05ed4605 --- /dev/null +++ b/google/cloud/sql/connector/resolver.py @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dns.asyncresolver import Resolver + +from google.cloud.sql.connector.instance import _parse_instance_connection_name + + +class DefaultResolver: + """DefaultResolver simply validates and parses instance connection name.""" + + async def resolve(connection_name: str) -> str: + pass + + +class DnsResolver(Resolver): + """ + DnsResolver resolves domain names into instance connection names using + TXT records in DNS. + """ + + pass + + +async def resolve(dns: str) -> str: + pass From 1d28ab97e48c8f66ca31a48c870c9421c67864ce Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 25 Nov 2024 01:21:20 +0000 Subject: [PATCH 04/24] chore: working code path --- google/cloud/sql/connector/__init__.py | 4 +++ google/cloud/sql/connector/connection_info.py | 3 ++ google/cloud/sql/connector/connector.py | 15 +++++++-- google/cloud/sql/connector/exceptions.py | 7 ++++ google/cloud/sql/connector/instance.py | 10 +++--- google/cloud/sql/connector/resolver.py | 32 ++++++++++++++----- 6 files changed, 55 insertions(+), 16 deletions(-) diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 5b06fcd7..99a5097a 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -18,12 +18,16 @@ from google.cloud.sql.connector.connector import create_async_connector from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.enums import RefreshStrategy +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.version import __version__ __all__ = [ "__version__", "create_async_connector", "Connector", + "DefaultResolver", + "DnsResolver", "IPTypes", "RefreshStrategy", ] diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 38aaf107..b738063c 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -57,6 +57,9 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont return self.context context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + # update ssl.PROTOCOL_TLS_CLIENT default + context.check_hostname = False + # TODO: remove if/else when Python 3.10 is min version. PEP 644 has been # implemented. The ssl module requires OpenSSL 1.1.1 or newer. # verify OpenSSL version supports TLSv1.3 diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 7a89d719..14969631 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -37,6 +37,8 @@ import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys @@ -63,6 +65,7 @@ def __init__( user_agent: Optional[str] = None, universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + resolver: DefaultResolver | DnsResolver = DefaultResolver, ) -> None: """Initializes a Connector instance. @@ -104,6 +107,12 @@ def __init__( of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + resolver (DefaultResolver | DnsResolver): The class name of the + resolver to use for resolving the Cloud SQL instance connection + name. To resolve a DNS record to an instance connection name, use + DnsResolver. + Default: DefaultResolver + """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -157,6 +166,7 @@ def __init__( self._enable_iam_auth = enable_iam_auth self._quota_project = quota_project self._user_agent = user_agent + self._resolver = resolver() # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) @@ -269,13 +279,14 @@ async def connect_async( if (instance_connection_string, enable_iam_auth) in self._cache: cache = self._cache[(instance_connection_string, enable_iam_auth)] else: + conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( f"['{instance_connection_string}']: Refresh strategy is set" " to lazy refresh" ) cache = LazyRefreshCache( - instance_connection_string, + conn_name, self._client, self._keys, enable_iam_auth, @@ -286,7 +297,7 @@ async def connect_async( " to backgound refresh" ) cache = RefreshAheadCache( - instance_connection_string, + conn_name, self._client, self._keys, enable_iam_auth, diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index 7bff2300..92e3e566 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -70,3 +70,10 @@ class IncompatibleDriverError(Exception): Exception to be raised when the database driver given is for the wrong database engine. (i.e. asyncpg for a MySQL database) """ + + +class DnsResolutionError(Exception): + """ + Exception to be raised when an instance connection name can not be resolved + from a DNS record. + """ diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index f244b8cf..4cd5e4ea 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -26,7 +26,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import RefreshNotValidError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from google.cloud.sql.connector.refresh_utils import _is_valid @@ -47,7 +47,7 @@ class RefreshAheadCache: def __init__( self, - instance_connection_string: str, + conn_name: ConnectionName, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, @@ -55,8 +55,8 @@ def __init__( """Initializes a RefreshAheadCache instance. Args: - instance_connection_string (str): The Cloud SQL Instance's - connection string (also known as an instance connection name). + conn_name (ConnectionName): The Cloud SQL Instance's + connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key pair. @@ -64,8 +64,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - # validate and parse instance connection name - conn_name = _parse_instance_connection_name(instance_connection_string) self._project, self._region, self._instance = ( conn_name.project, conn_name.region, diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 05ed4605..fd5d19ad 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -14,14 +14,15 @@ from dns.asyncresolver import Resolver -from google.cloud.sql.connector.instance import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.exceptions import DnsResolutionError class DefaultResolver: """DefaultResolver simply validates and parses instance connection name.""" - async def resolve(connection_name: str) -> str: - pass + async def resolve(self, connection_name: str) -> str: + return _parse_instance_connection_name(connection_name) class DnsResolver(Resolver): @@ -30,8 +31,23 @@ class DnsResolver(Resolver): TXT records in DNS. """ - pass - - -async def resolve(dns: str) -> str: - pass + async def resolve(self, dns: str) -> str: + try: + conn_name = _parse_instance_connection_name(dns) + except ValueError: + # The connection name was not project:region:instance format. + # Attempt to query a TXT record to get connection name. + try: + result = await super().resolve(dns, "TXT", raise_on_no_answer=True) + if result is not None: + rdata = result[0].to_text().strip('"') + conn_name = _parse_instance_connection_name(rdata) + except ValueError: + raise DnsResolutionError( + f"Unable to parse TXT for `{dns}` -> {result[0]}" + ) + except Exception as e: + raise DnsResolutionError( + f"Unable to resolve TXT record for `{dns}`" + ) from e + return conn_name From 5e2e3af2cdc0cfcc515e876131ebfc9931abe4d0 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Nov 2024 16:25:07 +0000 Subject: [PATCH 05/24] chore: update dnsName --- google/cloud/sql/connector/client.py | 2 +- google/cloud/sql/connector/pg8000.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index ed3883ff..ed305ec5 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -145,7 +145,7 @@ async def _get_metadata( # Note that we have to check for PSC enablement also because CAS # instances also set the dnsName field. # Remove trailing period from DNS name. Required for SSL in Python - dns_name = ret_dict.get("dnsName", "") + dns_name = ret_dict.get("dnsName", "").rstrip(".") if dns_name and ret_dict.get("pscEnabled"): ip_addresses["PSC"] = dns_name diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index 85a4ec3f..1f66dde2 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -50,7 +50,7 @@ def connect( # Create socket and wrap with context. sock = ctx.wrap_socket( - socket.create_connection((ip_address.rstrip("."), SERVER_PROXY_PORT)), + socket.create_connection((ip_address, SERVER_PROXY_PORT)), server_hostname=ip_address, ) From 697f9a705197517d6d28e0316585c1d3a93c7ea4 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Nov 2024 16:35:28 +0000 Subject: [PATCH 06/24] chore: add dnspython to requirements.txt --- requirements.txt | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index e285d4a0..8b5eb499 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiofiles==24.1.0 aiohttp==3.11.7 cryptography==43.0.3 +dnspython==2.7.0 Requests==2.32.3 google-auth==2.36.0 diff --git a/setup.py b/setup.py index bb70449a..79c6acf7 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "aiofiles", "aiohttp", "cryptography>=42.0.0", + "dnspython>=2.0.0", "Requests", "google-auth>=2.28.0", ] From d0d9d86849f8af88e055520d39735dcab8ce60e4 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Nov 2024 19:18:18 +0000 Subject: [PATCH 07/24] chore: update LazyRefresh cache --- google/cloud/sql/connector/instance.py | 2 +- google/cloud/sql/connector/lazy.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 4cd5e4ea..050c0f98 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -55,7 +55,7 @@ def __init__( """Initializes a RefreshAheadCache instance. Args: - conn_name (ConnectionName): The Cloud SQL Instance's + conn_name (ConnectionName): The Cloud SQL instance's connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index 672f989e..aca72cfb 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -38,7 +38,7 @@ class LazyRefreshCache: def __init__( self, - instance_connection_string: str, + conn_name: str, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, @@ -46,8 +46,8 @@ def __init__( """Initializes a LazyRefreshCache instance. Args: - instance_connection_string (str): The Cloud SQL Instance's - connection string (also known as an instance connection name). + conn_name (ConnectionName): The Cloud SQL instance's + connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key pair. @@ -55,8 +55,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - # validate and parse instance connection name - conn_name = _parse_instance_connection_name(instance_connection_string) self._project, self._region, self._instance = ( conn_name.project, conn_name.region, From 44ca807079177c1718fac885c92f269a5d76b657 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Nov 2024 19:19:29 +0000 Subject: [PATCH 08/24] chore: update variable type --- google/cloud/sql/connector/lazy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index aca72cfb..ab73785d 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -21,7 +21,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.refresh_utils import _refresh_buffer logger = logging.getLogger(name=__name__) @@ -38,7 +38,7 @@ class LazyRefreshCache: def __init__( self, - conn_name: str, + conn_name: ConnectionName, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, From 213ed089b5e7f568d1efe8c9566fa687226a26e6 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Nov 2024 20:17:45 +0000 Subject: [PATCH 09/24] chore: lint --- google/cloud/sql/connector/connector.py | 3 ++- google/cloud/sql/connector/resolver.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 14969631..1e67373e 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -65,7 +65,7 @@ def __init__( user_agent: Optional[str] = None, universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, - resolver: DefaultResolver | DnsResolver = DefaultResolver, + resolver: Type[DefaultResolver] | Type[DnsResolver] = DefaultResolver, ) -> None: """Initializes a Connector instance. @@ -107,6 +107,7 @@ def __init__( of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + resolver (DefaultResolver | DnsResolver): The class name of the resolver to use for resolving the Cloud SQL instance connection name. To resolve a DNS record to an instance connection name, use diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index fd5d19ad..fff6256d 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -15,13 +15,14 @@ from dns.asyncresolver import Resolver from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import DnsResolutionError class DefaultResolver: """DefaultResolver simply validates and parses instance connection name.""" - async def resolve(self, connection_name: str) -> str: + async def resolve(self, connection_name: str) -> ConnectionName: return _parse_instance_connection_name(connection_name) @@ -31,7 +32,7 @@ class DnsResolver(Resolver): TXT records in DNS. """ - async def resolve(self, dns: str) -> str: + async def resolve(self, dns: str) -> ConnectionName: # type: ignore try: conn_name = _parse_instance_connection_name(dns) except ValueError: From 7cf7273e1f0f0cac73888f1d39e80607c6f194dd Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Nov 2024 22:19:42 +0000 Subject: [PATCH 10/24] chore: sort records and more closely match Go --- google/cloud/sql/connector/resolver.py | 39 +++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index fff6256d..9dc07d23 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -38,17 +38,30 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore except ValueError: # The connection name was not project:region:instance format. # Attempt to query a TXT record to get connection name. - try: - result = await super().resolve(dns, "TXT", raise_on_no_answer=True) - if result is not None: - rdata = result[0].to_text().strip('"') - conn_name = _parse_instance_connection_name(rdata) - except ValueError: - raise DnsResolutionError( - f"Unable to parse TXT for `{dns}` -> {result[0]}" - ) - except Exception as e: - raise DnsResolutionError( - f"Unable to resolve TXT record for `{dns}`" - ) from e + conn_name = await self.query_dns(dns) return conn_name + + async def query_dns(self, dns: str) -> ConnectionName: + try: + # Attempt to query the TXT records. + records = await super().resolve(dns, "TXT", raise_on_no_answer=True) + # Sort the TXT record values alphabetically, strip quotes as record + # values can be returned as raw strings + rdata = [record.to_text().strip('"') for record in records] + rdata.sort() + # Attempt to parse records, returning the first valid record. + for record in rdata: + try: + conn_name = _parse_instance_connection_name(record) + return conn_name + except Exception: + continue + # If all records failed to parse, throw error + raise DnsResolutionError( + f"Unable to parse TXT record for `{dns}` -> {rdata[0]}" + ) + # Don't override above DnsResolutionError + except DnsResolutionError: + raise + except Exception as e: + raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e From e9cd9292d631058212902a2f97eb4338bd4a8280 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Nov 2024 22:43:02 +0000 Subject: [PATCH 11/24] chore: fix existing tests --- tests/conftest.py | 3 ++- tests/unit/test_connector.py | 17 +++++++++-------- tests/unit/test_instance.py | 5 +++-- tests/unit/test_lazy.py | 5 +++-- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 470fe19f..3a1a38a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ from unit.mocks import FakeCSQLInstance # type: ignore from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.utils import generate_keys @@ -144,7 +145,7 @@ async def fake_client( async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache, None]: keys = asyncio.create_task(generate_keys()) cache = RefreshAheadCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, ) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index fd18f2d5..d4f53ed5 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -26,6 +26,7 @@ from google.cloud.sql.connector import create_async_connector from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -322,18 +323,18 @@ async def test_Connector_remove_cached_bad_instance( async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: - conn_name = "bad-project:bad-region:bad-inst" + conn_name = ConnectionName("bad-project", "bad-region", "bad-inst") # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[(conn_name, False)] = cache + connector._cache[(str(conn_name), False)] = cache # aiohttp client should throw a 404 ClientResponseError with pytest.raises(ClientResponseError): await connector.connect_async( - conn_name, + str(conn_name), "pg8000", ) # check that cache has been removed from dict - assert (conn_name, False) not in connector._cache + assert (str(conn_name), False) not in connector._cache async def test_Connector_remove_cached_no_ip_type( @@ -348,21 +349,21 @@ async def test_Connector_remove_cached_no_ip_type( async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: - conn_name = "test-project:test-region:test-instance" + conn_name = ConnectionName("test-project", "test-region", "test-instance") # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[(conn_name, False)] = cache + connector._cache[(str(conn_name), False)] = cache # test instance does not have Private IP, thus should invalidate cache with pytest.raises(CloudSQLIPTypeError): await connector.connect_async( - conn_name, + str(conn_name), "pg8000", user="my-user", password="my-pass", ip_type="private", ) # check that cache has been removed from dict - assert (conn_name, False) not in connector._cache + assert (str(conn_name), False) not in connector._cache def test_default_universe_domain(fake_credentials: Credentials) -> None: diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 5b0887aa..3adfb37c 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -28,6 +28,7 @@ from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -300,7 +301,7 @@ async def test_ClientResponseError( repeat=True, ) cache = RefreshAheadCache( - "my-project:my-region:my-instance", + ConnectionName("my-project", "my-region", "my-instance"), client, keys, ) @@ -328,7 +329,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None # generate client key pair keys = asyncio.create_task(generate_keys()) cache = RefreshAheadCache( - "test-project:test-region:sqlserver-instance", + ConnectionName("test-project", "test-region", "sqlserver-instance"), client=fake_client, keys=keys, enable_iam_auth=True, diff --git a/tests/unit/test_lazy.py b/tests/unit/test_lazy.py index 27cd80b4..344b073e 100644 --- a/tests/unit/test_lazy.py +++ b/tests/unit/test_lazy.py @@ -16,6 +16,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.utils import generate_keys @@ -26,7 +27,7 @@ async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> Non """ keys = asyncio.create_task(generate_keys()) cache = LazyRefreshCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, enable_iam_auth=False, @@ -47,7 +48,7 @@ async def test_LazyRefreshCache_force_refresh(fake_client: CloudSQLClient) -> No """ keys = asyncio.create_task(generate_keys()) cache = LazyRefreshCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, enable_iam_auth=False, From 5abcc08d7bab7465f734245df68bf89c54c8b74c Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 29 Nov 2024 16:23:34 +0000 Subject: [PATCH 12/24] chore: first wave of unit tests --- tests/unit/test_resolver.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/unit/test_resolver.py diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py new file mode 100644 index 00000000..42ad6a19 --- /dev/null +++ b/tests/unit/test_resolver.py @@ -0,0 +1,34 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver + +conn_str = "test-project:test-region:test-instance" +conn_name = ConnectionName("test-project", "test-region", "test-instance") + + +async def test_DefaultResolver() -> None: + """Test DefaultResolver just parses instance connection string.""" + resolver = DefaultResolver() + result = await resolver.resolve(conn_str) + assert result == conn_name + + +async def test_DnsResolver_with_conn_str() -> None: + """Test DnsResolver with instance connection name just parses connection string.""" + resolver = DnsResolver() + result = await resolver.resolve(conn_str) + assert result == conn_name From 8207ff23c31cd9174984a2350e6b40fd7122fe3b Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Sat, 30 Nov 2024 02:50:51 +0000 Subject: [PATCH 13/24] chore: add local dns server and tests --- google/cloud/sql/connector/resolver.py | 6 ++-- requirements-test.txt | 1 + tests/conftest.py | 11 ++++++++ tests/test_zones.toml | 14 ++++++++++ tests/unit/test_resolver.py | 38 ++++++++++++++++++++++++++ 5 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 tests/test_zones.toml diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 9dc07d23..15ccd6a2 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dns.asyncresolver import Resolver +import dns.asyncresolver from google.cloud.sql.connector.connection_name import _parse_instance_connection_name from google.cloud.sql.connector.connection_name import ConnectionName @@ -26,7 +26,7 @@ async def resolve(self, connection_name: str) -> ConnectionName: return _parse_instance_connection_name(connection_name) -class DnsResolver(Resolver): +class DnsResolver(dns.asyncresolver.Resolver): """ DnsResolver resolves domain names into instance connection names using TXT records in DNS. @@ -58,7 +58,7 @@ async def query_dns(self, dns: str) -> ConnectionName: continue # If all records failed to parse, throw error raise DnsResolutionError( - f"Unable to parse TXT record for `{dns}` -> {rdata[0]}" + f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`" ) # Don't override above DnsResolutionError except DnsResolutionError: diff --git a/requirements-test.txt b/requirements-test.txt index 4aeecede..348563cb 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -11,3 +11,4 @@ asyncpg==0.30.0 python-tds==1.16.0 aioresponses==0.7.7 pytest-aiohttp==1.0.5 +dnserver==0.4.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 3a1a38a2..379a4226 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ from typing import Any, AsyncGenerator, Generator from aiohttp import web +from dnserver import DNSServer import pytest # noqa F401 Needed to run the tests from unit.mocks import FakeCredentials # type: ignore from unit.mocks import FakeCSQLInstance # type: ignore @@ -151,3 +152,13 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache ) yield cache await cache.close() + + +@pytest.fixture(autouse=True, scope="session") +def dns_server() -> Generator: + """Setup local DNS server for tests with TXT records.""" + server = DNSServer.from_toml("tests/test_zones.toml", port=5053, upstream=None) + server.start() + assert server.is_running + yield server + server.stop() diff --git a/tests/test_zones.toml b/tests/test_zones.toml new file mode 100644 index 00000000..5e7d3fb5 --- /dev/null +++ b/tests/test_zones.toml @@ -0,0 +1,14 @@ +[[zones]] +host = 'db.example.com' +type = 'TXT' +answer = "test-project:test-region:test-instance" + +[[zones]] +host = 'db.example.com' +type = 'TXT' +answer = "test-project2:test-region2:test-instance2" + +[[zones]] +host = 'bad.example.com' +type = 'TXT' +answer = "bad-instance-name" \ No newline at end of file diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index 42ad6a19..70ac3abc 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import DnsResolutionError from google.cloud.sql.connector.resolver import DefaultResolver from google.cloud.sql.connector.resolver import DnsResolver @@ -32,3 +35,38 @@ async def test_DnsResolver_with_conn_str() -> None: resolver = DnsResolver() result = await resolver.resolve(conn_str) assert result == conn_name + + +async def test_DnsResolver_with_dns_name() -> None: + """Test DnsResolver resolves TXT record into proper instance connection name.""" + resolver = DnsResolver() + resolver.port = 5053 + result = await resolver.resolve(conn_str) + assert result == conn_name + + +async def test_DnsResolver_with_malformed_txt() -> None: + """Test DnsResolver with TXT record that holds malformed instance connection name. + + Should throw DnsResolutionError + """ + resolver = DnsResolver() + resolver.port = 5053 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.example.com") + assert ( + exc_info.value.args[0] + == "Unable to parse TXT record for `bad.example.com` -> `bad-instance-name`" + ) + + +async def test_DnsResolver_with_bad_dns_name() -> None: + """Test DnsResolver with bad dns name. + + Should throw DnsResolutionError + """ + resolver = DnsResolver() + resolver.port = 5053 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.dns.com") + assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`" From c473ff23b1d5545af6c0ebfea24b26d9c4cb4197 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Sat, 30 Nov 2024 02:59:12 +0000 Subject: [PATCH 14/24] chore: whitespace --- tests/test_zones.toml | 2 +- tests/unit/test_resolver.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_zones.toml b/tests/test_zones.toml index 5e7d3fb5..e07bdf29 100644 --- a/tests/test_zones.toml +++ b/tests/test_zones.toml @@ -11,4 +11,4 @@ answer = "test-project2:test-region2:test-instance2" [[zones]] host = 'bad.example.com' type = 'TXT' -answer = "bad-instance-name" \ No newline at end of file +answer = "bad-instance-name" diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index 70ac3abc..4aa3adbd 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -54,10 +54,10 @@ async def test_DnsResolver_with_malformed_txt() -> None: resolver.port = 5053 with pytest.raises(DnsResolutionError) as exc_info: await resolver.resolve("bad.example.com") - assert ( - exc_info.value.args[0] - == "Unable to parse TXT record for `bad.example.com` -> `bad-instance-name`" - ) + assert ( + exc_info.value.args[0] + == "Unable to parse TXT record for `bad.example.com` -> `bad-instance-name`" + ) async def test_DnsResolver_with_bad_dns_name() -> None: From e0213c42a9844185c6ea3e72a767dc592d51127d Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Sat, 30 Nov 2024 03:10:13 +0000 Subject: [PATCH 15/24] chore: update test --- tests/unit/test_resolver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index 4aa3adbd..24341ac5 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -41,7 +41,7 @@ async def test_DnsResolver_with_dns_name() -> None: """Test DnsResolver resolves TXT record into proper instance connection name.""" resolver = DnsResolver() resolver.port = 5053 - result = await resolver.resolve(conn_str) + result = await resolver.resolve("db.example.com") assert result == conn_name From bd1d7dc925c266c211dcaec168ac3c73ff60b0cc Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Sat, 30 Nov 2024 03:17:37 +0000 Subject: [PATCH 16/24] chore: update alphabetic test --- tests/test_zones.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_zones.toml b/tests/test_zones.toml index e07bdf29..d91816ed 100644 --- a/tests/test_zones.toml +++ b/tests/test_zones.toml @@ -6,7 +6,7 @@ answer = "test-project:test-region:test-instance" [[zones]] host = 'db.example.com' type = 'TXT' -answer = "test-project2:test-region2:test-instance2" +answer = "zzzz-project:zzzz-region:zzzz-instance" [[zones]] host = 'bad.example.com' From d0f1e3d0dc42b9b6189a3cef1d3ca1938ae1ccd7 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 2 Dec 2024 14:46:17 +0000 Subject: [PATCH 17/24] chore: use dns_server only on DNS tests --- tests/conftest.py | 2 +- tests/unit/test_resolver.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 379a4226..e257bbf3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -154,7 +154,7 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache await cache.close() -@pytest.fixture(autouse=True, scope="session") +@pytest.fixture() def dns_server() -> Generator: """Setup local DNS server for tests with TXT records.""" server = DNSServer.from_toml("tests/test_zones.toml", port=5053, upstream=None) diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index 24341ac5..a3e33b06 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -37,6 +37,7 @@ async def test_DnsResolver_with_conn_str() -> None: assert result == conn_name +@pytest.mark.usefixtures("dns_server") async def test_DnsResolver_with_dns_name() -> None: """Test DnsResolver resolves TXT record into proper instance connection name.""" resolver = DnsResolver() @@ -45,6 +46,7 @@ async def test_DnsResolver_with_dns_name() -> None: assert result == conn_name +@pytest.mark.usefixtures("dns_server") async def test_DnsResolver_with_malformed_txt() -> None: """Test DnsResolver with TXT record that holds malformed instance connection name. @@ -60,6 +62,7 @@ async def test_DnsResolver_with_malformed_txt() -> None: ) +@pytest.mark.usefixtures("dns_server") async def test_DnsResolver_with_bad_dns_name() -> None: """Test DnsResolver with bad dns name. From ea95cf6513ce871d2ef2a1675f839925105ed77f Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 2 Dec 2024 14:57:06 +0000 Subject: [PATCH 18/24] chore: update zones --- tests/test_zones.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_zones.toml b/tests/test_zones.toml index d91816ed..be22188b 100644 --- a/tests/test_zones.toml +++ b/tests/test_zones.toml @@ -6,7 +6,7 @@ answer = "test-project:test-region:test-instance" [[zones]] host = 'db.example.com' type = 'TXT' -answer = "zzzz-project:zzzz-region:zzzz-instance" +answer = "z-project:z-region:z-instance" [[zones]] host = 'bad.example.com' From 90f8b4e169d9f5f88b4c0b71497670e1cddd0ff1 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 2 Dec 2024 15:40:34 +0000 Subject: [PATCH 19/24] chore: test longer lifetime --- tests/conftest.py | 2 +- tests/unit/test_resolver.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index e257bbf3..314a236c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -154,7 +154,7 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache await cache.close() -@pytest.fixture() +@pytest.fixture def dns_server() -> Generator: """Setup local DNS server for tests with TXT records.""" server = DNSServer.from_toml("tests/test_zones.toml", port=5053, upstream=None) diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index a3e33b06..319b6d48 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -42,6 +42,7 @@ async def test_DnsResolver_with_dns_name() -> None: """Test DnsResolver resolves TXT record into proper instance connection name.""" resolver = DnsResolver() resolver.port = 5053 + resolver.lifetime = 10.0 result = await resolver.resolve("db.example.com") assert result == conn_name From 9666a177acbd6e88cdd1d94e7dcbbc5f872e2be0 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 2 Dec 2024 15:45:57 +0000 Subject: [PATCH 20/24] chore: try adding back upstream --- tests/conftest.py | 2 +- tests/unit/test_resolver.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 314a236c..1331d7fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -157,7 +157,7 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache @pytest.fixture def dns_server() -> Generator: """Setup local DNS server for tests with TXT records.""" - server = DNSServer.from_toml("tests/test_zones.toml", port=5053, upstream=None) + server = DNSServer.from_toml("tests/test_zones.toml", port=5053) server.start() assert server.is_running yield server diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index 319b6d48..a3e33b06 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -42,7 +42,6 @@ async def test_DnsResolver_with_dns_name() -> None: """Test DnsResolver resolves TXT record into proper instance connection name.""" resolver = DnsResolver() resolver.port = 5053 - resolver.lifetime = 10.0 result = await resolver.resolve("db.example.com") assert result == conn_name From 7ce09d30dc5173f403454c1faa1fc5efde9ce3bd Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 2 Dec 2024 17:33:55 +0000 Subject: [PATCH 21/24] chore: re-add upstream --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1331d7fe..314a236c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -157,7 +157,7 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache @pytest.fixture def dns_server() -> Generator: """Setup local DNS server for tests with TXT records.""" - server = DNSServer.from_toml("tests/test_zones.toml", port=5053) + server = DNSServer.from_toml("tests/test_zones.toml", port=5053, upstream=None) server.start() assert server.is_running yield server From 2c8bace6869645af89f10c44601d2fd604da9b95 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 2 Dec 2024 20:58:43 +0000 Subject: [PATCH 22/24] chore: remove need dns fixture, use mocks --- requirements-test.txt | 3 +- tests/conftest.py | 11 ----- tests/test_zones.toml | 14 ------ tests/unit/test_resolver.py | 87 +++++++++++++++++++++++++++++-------- 4 files changed, 71 insertions(+), 44 deletions(-) delete mode 100644 tests/test_zones.toml diff --git a/requirements-test.txt b/requirements-test.txt index 348563cb..8d64dd5e 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,5 +10,4 @@ pg8000==1.31.2 asyncpg==0.30.0 python-tds==1.16.0 aioresponses==0.7.7 -pytest-aiohttp==1.0.5 -dnserver==0.4.0 \ No newline at end of file +pytest-aiohttp==1.0.5 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 314a236c..3a1a38a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,6 @@ from typing import Any, AsyncGenerator, Generator from aiohttp import web -from dnserver import DNSServer import pytest # noqa F401 Needed to run the tests from unit.mocks import FakeCredentials # type: ignore from unit.mocks import FakeCSQLInstance # type: ignore @@ -152,13 +151,3 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache ) yield cache await cache.close() - - -@pytest.fixture -def dns_server() -> Generator: - """Setup local DNS server for tests with TXT records.""" - server = DNSServer.from_toml("tests/test_zones.toml", port=5053, upstream=None) - server.start() - assert server.is_running - yield server - server.stop() diff --git a/tests/test_zones.toml b/tests/test_zones.toml deleted file mode 100644 index be22188b..00000000 --- a/tests/test_zones.toml +++ /dev/null @@ -1,14 +0,0 @@ -[[zones]] -host = 'db.example.com' -type = 'TXT' -answer = "test-project:test-region:test-instance" - -[[zones]] -host = 'db.example.com' -type = 'TXT' -answer = "z-project:z-region:z-instance" - -[[zones]] -host = 'bad.example.com' -type = 'TXT' -answer = "bad-instance-name" diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index a3e33b06..d7404890 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dns.message +import dns.rdataclass +import dns.rdatatype +import dns.resolver +from mock import patch import pytest from google.cloud.sql.connector.connection_name import ConnectionName @@ -19,8 +24,8 @@ from google.cloud.sql.connector.resolver import DefaultResolver from google.cloud.sql.connector.resolver import DnsResolver -conn_str = "test-project:test-region:test-instance" -conn_name = ConnectionName("test-project", "test-region", "test-instance") +conn_str = "my-project:my-region:my-instance" +conn_name = ConnectionName("my-project", "my-region", "my-instance") async def test_DefaultResolver() -> None: @@ -37,32 +42,78 @@ async def test_DnsResolver_with_conn_str() -> None: assert result == conn_name -@pytest.mark.usefixtures("dns_server") +query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN TXT +;ANSWER +db.example.com. 0 IN TXT "test-project:test-region:test-instance" +db.example.com. 0 IN TXT "my-project:my-region:my-instance" +;AUTHORITY +;ADDITIONAL +""" + + async def test_DnsResolver_with_dns_name() -> None: - """Test DnsResolver resolves TXT record into proper instance connection name.""" - resolver = DnsResolver() - resolver.port = 5053 - result = await resolver.resolve("db.example.com") - assert result == conn_name + """Test DnsResolver resolves TXT record into proper instance connection name. + + Should sort valid TXT records alphabetically and take first one. + """ + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + # Resolution should return first value sorted alphabetically + result = await resolver.resolve("db.example.com") + assert result == conn_name + + +query_text_malformed = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +bad.example.com. IN TXT +;ANSWER +bad.example.com. 0 IN TXT "malformed-instance-name" +;AUTHORITY +;ADDITIONAL +""" -@pytest.mark.usefixtures("dns_server") async def test_DnsResolver_with_malformed_txt() -> None: """Test DnsResolver with TXT record that holds malformed instance connection name. Should throw DnsResolutionError """ - resolver = DnsResolver() - resolver.port = 5053 - with pytest.raises(DnsResolutionError) as exc_info: - await resolver.resolve("bad.example.com") - assert ( - exc_info.value.args[0] - == "Unable to parse TXT record for `bad.example.com` -> `bad-instance-name`" + # patch DNS resolution with malformed TXT record + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "bad.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text_malformed), ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.example.com") + assert ( + exc_info.value.args[0] + == "Unable to parse TXT record for `bad.example.com` -> `malformed-instance-name`" + ) -@pytest.mark.usefixtures("dns_server") async def test_DnsResolver_with_bad_dns_name() -> None: """Test DnsResolver with bad dns name. @@ -70,6 +121,8 @@ async def test_DnsResolver_with_bad_dns_name() -> None: """ resolver = DnsResolver() resolver.port = 5053 + # set lifetime to 1 second for shorter timeout + resolver.lifetime = 1 with pytest.raises(DnsResolutionError) as exc_info: await resolver.resolve("bad.dns.com") assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`" From 712023b237a5294d05495c26149a5d22dcb9b85d Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 2 Dec 2024 21:22:45 +0000 Subject: [PATCH 23/24] chore: add usage example in README --- README.md | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/README.md b/README.md index 28553f97..1f0e633b 100644 --- a/README.md +++ b/README.md @@ -365,6 +365,69 @@ conn = connector.connect( ) ``` +### Using DNS domain names to identify instances + +The connector can be configured to use DNS to look up an instance. This would +allow you to configure your application to connect to a database instance, and +centrally configure which instance in your DNS zone. + +#### Configure your DNS Records + +Add a DNS TXT record for the Cloud SQL instance to a **private** DNS server +or a private Google Cloud DNS Zone used by your application. + +> [!NOTE] +> +> You are strongly discouraged from adding DNS records for your +> Cloud SQL instances to a public DNS server. This would allow anyone on the +> internet to discover the Cloud SQL instance name. + +For example: suppose you wanted to use the domain name +`prod-db.mycompany.example.com` to connect to your database instance +`my-project:region:my-instance`. You would create the following DNS record: + +* Record type: `TXT` +* Name: `prod-db.mycompany.example.com` – This is the domain name used by the application +* Value: `my-project:my-region:my-instance` – This is the Cloud SQL instance connection name + +#### Configure the connector + +Configure the connector to resolve DNS names by initializing it with +`resolver=DnsResolver` and replacing the instance connection name with the DNS +name in `connector.connect`: + +```python +from google.cloud.sql.connector import Connector, DnsResolver +import pymysql +import sqlalchemy + +# helper function to return SQLAlchemy connection pool +def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: + # function used to generate database connection + def getconn() -> pymysql.connections.Connection: + conn = connector.connect( + "prod-db.mycompany.example.com", # using DNS name + "pymysql", + user="my-user", + password="my-password", + db="my-db-name" + ) + return conn + + # create connection pool + pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=getconn, + ) + return pool + +# initialize Cloud SQL Python Connector with `resolver=DnsResolver` +with Connector(resolver=DnsResolver) as connector: + # initialize connection pool + pool = init_connection_pool(connector) + # ... use SQLAlchemy engine normally +``` + ### Using the Python Connector with Python Web Frameworks The Python Connector can be used alongside popular Python web frameworks such From feecd84fb131fb7d5203fada8bd2698ea0fdb1eb Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 2 Dec 2024 21:24:44 +0000 Subject: [PATCH 24/24] chore: add newline at EOF --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index 8d64dd5e..4aeecede 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,4 +10,4 @@ pg8000==1.31.2 asyncpg==0.30.0 python-tds==1.16.0 aioresponses==0.7.7 -pytest-aiohttp==1.0.5 \ No newline at end of file +pytest-aiohttp==1.0.5