diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index d6ca4b67..24dc9cbc 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -13,7 +13,7 @@ # limitations under the License. import asyncio -from aiohttp import ClientResponseError +from aiohttp import ClientResponseError, RequestInfo from datetime import datetime from datetime import timedelta from datetime import timezone @@ -214,7 +214,7 @@ def __init__( async def _get_metadata(self, *args: Any, **kwargs: Any) -> str: instance_uri = f"projects/{self.instance.project}/locations/{self.instance.region}/clusters/{self.instance.cluster}/instances/{self.instance.name}" if instance_uri not in self.existing_instances: - raise ClientResponseError(None, 404) + raise ClientResponseError(RequestInfo(url = instance_uri, method = "GET", headers = None), 404) return self.instance.ip_addrs async def _get_client_certificate( @@ -226,7 +226,7 @@ async def _get_client_certificate( ) -> Tuple[str, List[str]]: instance_uri = f"projects/{self.instance.project}/locations/{self.instance.region}/clusters/{self.instance.cluster}/instances/{self.instance.name}" if instance_uri not in self.existing_instances: - raise ClientResponseError(None, 404) + raise ClientResponseError(RequestInfo(url = instance_uri, method = "POST", headers = None), 404) root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs() # encode public key to bytes pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index e2b22b10..96e950cf 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -19,10 +19,13 @@ from mocks import FakeAlloyDBClient from mocks import FakeConnectionInfo from mocks import FakeCredentials +from mocks import FakeInstance +from aiohttp import ClientResponseError import pytest from google.cloud.alloydb.connector import AsyncConnector from google.cloud.alloydb.connector import IPTypes +from google.cloud.alloydb.connector.instance import RefreshAheadCache ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com" @@ -294,3 +297,47 @@ async def test_async_connect_bad_ip_type( exc_info.value.args[0] == f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE', 'PSC'." ) + +async def test_Connector_remove_cached_bad_instance(credentials: FakeCredentials) -> None: + """When a Connector attempts to retrieve connection info for a + non-existent instance, it should delete the instance from + the cache and ensure no background refresh happens (which would be + wasted cycles). + """ + instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance" + async with AsyncConnector(credentials=credentials) as connector: + connector._client = FakeAlloyDBClient(instance = FakeInstance(name = "bad-test-instance")) + cache = RefreshAheadCache(instance_uri, connector._client, connector._keys) + connector._cache[instance_uri] = cache + with pytest.raises(ClientResponseError): + await connector.connect(instance_uri, "asyncpg") + assert instance_uri not in connector._cache + + +# async def test_Connector_remove_cached_no_ip_type( +# fake_credentials: Credentials, fake_client: CloudSQLClient +# ) -> None: +# """When a Connector attempts to connect and preferred IP type is not present, +# it should delete the instance from the cache and ensure no background refresh +# happens (which would be wasted cycles). +# """ +# # set instance to only have public IP +# fake_client.instance.ip_addrs = {"PRIMARY": "127.0.0.1"} +# async with Connector( +# credentials=fake_credentials, loop=asyncio.get_running_loop() +# ) as connector: +# conn_name = "test-project:test-region:test-instance" +# # populate cache +# cache = RefreshAheadCache(conn_name, fake_client, connector._keys) +# connector._cache[(conn_name, False)] = cache +# # test instance does not have Private IP, thus should invalidate cache +# with pytest.raises(CloudSQLIPTypeError): +# await connector.connect_async( +# conn_name, +# "pg8000", +# user="my-user", +# password="my-pass", +# ip_type="private", +# ) +# # check that cache has been removed from dict +# assert (conn_name, False) not in connector._cache