diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 5370fee9..38aab682 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -69,9 +69,9 @@ def __init__( self._credentials, _ = default(scopes=scopes) self._keys = asyncio.wrap_future( asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), - loop=self.__loop - ) - + loop=self._loop, + ) + self._client: Optional[AlloyDBClient] = None def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any: @@ -127,11 +127,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> if instance_uri in self._instances: instance = self._instances[instance_uri] else: - instance = Instance( - instance_uri, - self._client, - self._keys, - ) + instance = Instance(instance_uri, self._client, self._keys) self._instances[instance_uri] = instance connect_func = { diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index 24fd142b..769ed1ee 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -28,7 +28,8 @@ if TYPE_CHECKING: import ssl - from cryptography.hazmat.primitives.asymmetric import rsa + + # from cryptography.hazmat.primitives.asymmetric import rsa from google.cloud.alloydb.connector.client import AlloyDBClient logger = logging.getLogger(name=__name__) @@ -53,7 +54,7 @@ def __init__( self, instance_uri: str, client: AlloyDBClient, - keys: Tuple[rsa.RSAPrivateKey, str], + keys: asyncio.Future, ) -> None: # validate and parse instance_uri instance_uri_split = instance_uri.split("/") diff --git a/google/cloud/alloydb/connector/utils.py b/google/cloud/alloydb/connector/utils.py index 60ade716..86c3a7bb 100644 --- a/google/cloud/alloydb/connector/utils.py +++ b/google/cloud/alloydb/connector/utils.py @@ -49,8 +49,7 @@ def _write_to_file( async def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]: - priv_key = rsa.generate_private_key(public_exponent=65537, - key_size=2048) + priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) pub_key = ( priv_key.public_key() diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index ec212035..4f558a3c 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -31,7 +31,10 @@ async def test_Instance_init() -> None: Test to check whether the __init__ method of Instance can tell if the instance URI that's passed in is formatted correctly. """ - keys = await generate_keys() + event_loop = asyncio.get_running_loop() + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) async with aiohttp.ClientSession() as client: instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -52,7 +55,10 @@ async def test_Instance_init_invalid_instant_uri() -> None: Test to check whether the __init__ method of Instance will throw error for invalid instance URI. """ - keys = await generate_keys() + event_loop = asyncio.get_running_loop() + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) async with aiohttp.ClientSession() as client: with pytest.raises(ValueError): Instance("invalid/instance/uri/", client, keys) @@ -64,7 +70,10 @@ async def test_Instance_close() -> None: Test that Instance's close method cancels tasks gracefully. """ - keys = await generate_keys() + event_loop = asyncio.get_running_loop() + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -84,7 +93,10 @@ async def test_Instance_close() -> None: @pytest.mark.asyncio async def test_perform_refresh() -> None: """Test that _perform refresh returns valid RefreshResult""" - keys = await generate_keys() + event_loop = asyncio.get_running_loop() + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -104,7 +116,10 @@ async def test_schedule_refresh_replaces_result() -> None: Test to check whether _schedule_refresh replaces a valid refresh result with another refresh result. """ - keys = await generate_keys() + event_loop = asyncio.get_running_loop() + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -131,7 +146,10 @@ async def test_schedule_refresh_wont_replace_valid_result_with_invalid() -> None Test to check whether _schedule_refresh won't replace a valid refresh result with an invalid one. """ - keys = await generate_keys() + event_loop = asyncio.get_running_loop() + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", @@ -160,7 +178,10 @@ async def test_schedule_refresh_expired_cert() -> None: Test to check whether _schedule_refresh will throw RefreshError on expired certificate. """ - keys = await generate_keys() + event_loop = asyncio.get_running_loop() + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) client = FakeAlloyDBClient() # set certificate to be expired client.instance.cert_before = datetime.now() - timedelta(minutes=20) @@ -182,7 +203,10 @@ async def test_force_refresh_cancels_pending_refresh() -> None: """ Test that force_refresh cancels pending task if refresh_in_progress event is not set. """ - keys = await generate_keys() + event_loop = asyncio.get_running_loop() + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) client = FakeAlloyDBClient() instance = Instance( "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",