Skip to content

Commit

Permalink
Add unit tests for async connector
Browse files Browse the repository at this point in the history
  • Loading branch information
rhatgadkar-goog committed Nov 13, 2024
1 parent 01cc3aa commit e4537a5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import asyncio
from aiohttp import ClientResponseError
from aiohttp import ClientResponseError, RequestInfo
from datetime import datetime
from datetime import timedelta
from datetime import timezone
Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(
async def _get_metadata(self, *args: Any, **kwargs: Any) -> str:
instance_uri = f"projects/{self.instance.project}/locations/{self.instance.region}/clusters/{self.instance.cluster}/instances/{self.instance.name}"
if instance_uri not in self.existing_instances:
raise ClientResponseError(None, 404)
raise ClientResponseError(RequestInfo(url = instance_uri, method = "GET", headers = None), 404)
return self.instance.ip_addrs

async def _get_client_certificate(
Expand All @@ -226,7 +226,7 @@ async def _get_client_certificate(
) -> Tuple[str, List[str]]:
instance_uri = f"projects/{self.instance.project}/locations/{self.instance.region}/clusters/{self.instance.cluster}/instances/{self.instance.name}"
if instance_uri not in self.existing_instances:
raise ClientResponseError(None, 404)
raise ClientResponseError(RequestInfo(url = instance_uri, method = "POST", headers = None), 404)
root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs()
# encode public key to bytes
pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key(
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
from mocks import FakeAlloyDBClient
from mocks import FakeConnectionInfo
from mocks import FakeCredentials
from mocks import FakeInstance
from aiohttp import ClientResponseError
import pytest

from google.cloud.alloydb.connector import AsyncConnector
from google.cloud.alloydb.connector import IPTypes
from google.cloud.alloydb.connector.instance import RefreshAheadCache

ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com"

Expand Down Expand Up @@ -294,3 +297,47 @@ async def test_async_connect_bad_ip_type(
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE', 'PSC'."
)

async def test_Connector_remove_cached_bad_instance(credentials: FakeCredentials) -> None:
"""When a Connector attempts to retrieve connection info for a
non-existent instance, it should delete the instance from
the cache and ensure no background refresh happens (which would be
wasted cycles).
"""
instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance"
async with AsyncConnector(credentials=credentials) as connector:
connector._client = FakeAlloyDBClient(instance = FakeInstance(name = "bad-test-instance"))
cache = RefreshAheadCache(instance_uri, connector._client, connector._keys)
connector._cache[instance_uri] = cache
with pytest.raises(ClientResponseError):
await connector.connect(instance_uri, "asyncpg")
assert instance_uri not in connector._cache


# async def test_Connector_remove_cached_no_ip_type(
# fake_credentials: Credentials, fake_client: CloudSQLClient
# ) -> None:
# """When a Connector attempts to connect and preferred IP type is not present,
# it should delete the instance from the cache and ensure no background refresh
# happens (which would be wasted cycles).
# """
# # set instance to only have public IP
# fake_client.instance.ip_addrs = {"PRIMARY": "127.0.0.1"}
# async with Connector(
# credentials=fake_credentials, loop=asyncio.get_running_loop()
# ) as connector:
# conn_name = "test-project:test-region:test-instance"
# # populate cache
# cache = RefreshAheadCache(conn_name, fake_client, connector._keys)
# connector._cache[(conn_name, False)] = cache
# # test instance does not have Private IP, thus should invalidate cache
# with pytest.raises(CloudSQLIPTypeError):
# await connector.connect_async(
# conn_name,
# "pg8000",
# user="my-user",
# password="my-pass",
# ip_type="private",
# )
# # check that cache has been removed from dict
# assert (conn_name, False) not in connector._cache

0 comments on commit e4537a5

Please sign in to comment.