diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 71031fad..b77deee7 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -67,7 +67,10 @@ def __init__( # otherwise use application default credentials else: self._credentials, _ = default(scopes=scopes) - self._keys = generate_keys() + self._keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), + loop=self._loop, + ) self._client: Optional[AlloyDBClient] = None def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any: @@ -123,11 +126,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> if instance_uri in self._instances: instance = self._instances[instance_uri] else: - instance = Instance( - instance_uri, - self._client, - self._keys, - ) + instance = Instance(instance_uri, self._client, self._keys) self._instances[instance_uri] = instance connect_func = { diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index 4a06ade1..3e5545dd 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -53,7 +53,7 @@ def __init__( self, instance_uri: str, client: AlloyDBClient, - keys: Tuple[rsa.RSAPrivateKey, str], + keys: asyncio.Future[Tuple[rsa.RSAPrivateKey, str]], ) -> None: # validate and parse instance_uri instance_uri_split = instance_uri.split("/") @@ -98,7 +98,7 @@ async def _perform_refresh(self) -> RefreshResult: try: await self._refresh_rate_limiter.acquire() - priv_key, pub_key = self._keys + priv_key, pub_key = await self._keys # fetch metadata metadata_task = asyncio.create_task( self._client._get_metadata( diff --git a/google/cloud/alloydb/connector/utils.py b/google/cloud/alloydb/connector/utils.py index 956e8bd1..a549f707 100644 --- a/google/cloud/alloydb/connector/utils.py +++ b/google/cloud/alloydb/connector/utils.py @@ -47,7 +47,7 @@ def _write_to_file( return (ca_filename, cert_chain_filename, key_filename) -def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]: +async def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]: priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) pub_key = ( priv_key.public_key() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 77201558..fa686ed4 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -77,7 +77,7 @@ async def test__get_client_certificate( Test _get_client_certificate returns successfully. """ test_client = AlloyDBClient("", "", credentials, client) - keys = generate_keys() + keys = await generate_keys() certs = await test_client._get_client_certificate( "test-project", "test-region", "test-cluster", keys[1] ) diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 949d42c8..cf518e97 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -31,7 +31,7 @@ async def test_Instance_init() -> None: Test to check whether the __init__ method of Instance can tell if the instance URI that's passed in is formatted correctly. """ - keys = generate_keys() + keys = asyncio.create_task(generate_keys()) async with aiohttp.ClientSession() as client: instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -52,7 +52,7 @@ async def test_Instance_init_invalid_instant_uri() -> None: Test to check whether the __init__ method of Instance will throw error for invalid instance URI. """ - keys = generate_keys() + keys = asyncio.create_task(generate_keys()) async with aiohttp.ClientSession() as client: with pytest.raises(ValueError): Instance("invalid/instance/uri/", client, keys) @@ -64,7 +64,7 @@ async def test_Instance_close() -> None: Test that Instance's close method cancels tasks gracefully. """ - keys = generate_keys() + keys = asyncio.create_task(generate_keys()) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -84,7 +84,7 @@ async def test_Instance_close() -> None: @pytest.mark.asyncio async def test_perform_refresh() -> None: """Test that _perform refresh returns valid RefreshResult""" - keys = generate_keys() + keys = asyncio.create_task(generate_keys()) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -104,7 +104,7 @@ async def test_schedule_refresh_replaces_result() -> None: Test to check whether _schedule_refresh replaces a valid refresh result with another refresh result. """ - keys = generate_keys() + keys = asyncio.create_task(generate_keys()) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -131,7 +131,7 @@ async def test_schedule_refresh_wont_replace_valid_result_with_invalid() -> None Test to check whether _schedule_refresh won't replace a valid refresh result with an invalid one. """ - keys = generate_keys() + keys = asyncio.create_task(generate_keys()) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -160,7 +160,7 @@ async def test_schedule_refresh_expired_cert() -> None: Test to check whether _schedule_refresh will throw RefreshError on expired certificate. """ - keys = generate_keys() + keys = asyncio.create_task(generate_keys()) client = FakeAlloyDBClient() # set certificate to be expired client.instance.cert_before = datetime.now() - timedelta(minutes=20) @@ -182,7 +182,7 @@ async def test_force_refresh_cancels_pending_refresh() -> None: """ Test that force_refresh cancels pending task if refresh_in_progress event is not set. """ - keys = generate_keys() + keys = asyncio.create_task(generate_keys()) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",