Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: wrap generate_keys() in future #168

Merged
merged 8 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def __init__(
# otherwise use application default credentials
else:
self._credentials, _ = default(scopes=scopes)
self._keys = generate_keys()
self._keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
loop=self._loop,
)
self._client: Optional[AlloyDBClient] = None

def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -123,11 +126,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
if instance_uri in self._instances:
instance = self._instances[instance_uri]
else:
instance = Instance(
instance_uri,
self._client,
self._keys,
)
instance = Instance(instance_uri, self._client, self._keys)
self._instances[instance_uri] = instance

connect_func = {
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self,
instance_uri: str,
client: AlloyDBClient,
keys: Tuple[rsa.RSAPrivateKey, str],
keys: asyncio.Future[Tuple[rsa.RSAPrivateKey, str]],
) -> None:
# validate and parse instance_uri
instance_uri_split = instance_uri.split("/")
Expand Down Expand Up @@ -98,7 +98,7 @@ async def _perform_refresh(self) -> RefreshResult:

try:
await self._refresh_rate_limiter.acquire()
priv_key, pub_key = self._keys
priv_key, pub_key = await self._keys
# fetch metadata
metadata_task = asyncio.create_task(
self._client._get_metadata(
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _write_to_file(
return (ca_filename, cert_chain_filename, key_filename)


def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
async def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
pub_key = (
priv_key.public_key()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def test__get_client_certificate(
Test _get_client_certificate returns successfully.
"""
test_client = AlloyDBClient("", "", credentials, client)
keys = generate_keys()
keys = await generate_keys()
certs = await test_client._get_client_certificate(
"test-project", "test-region", "test-cluster", keys[1]
)
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_Instance_init() -> None:
Test to check whether the __init__ method of Instance
can tell if the instance URI that's passed in is formatted correctly.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
async with aiohttp.ClientSession() as client:
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -52,7 +52,7 @@ async def test_Instance_init_invalid_instant_uri() -> None:
Test to check whether the __init__ method of Instance
will throw error for invalid instance URI.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
async with aiohttp.ClientSession() as client:
with pytest.raises(ValueError):
Instance("invalid/instance/uri/", client, keys)
Expand All @@ -64,7 +64,7 @@ async def test_Instance_close() -> None:
Test that Instance's close method
cancels tasks gracefully.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -84,7 +84,7 @@ async def test_Instance_close() -> None:
@pytest.mark.asyncio
async def test_perform_refresh() -> None:
"""Test that _perform refresh returns valid RefreshResult"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -104,7 +104,7 @@ async def test_schedule_refresh_replaces_result() -> None:
Test to check whether _schedule_refresh replaces a valid refresh result
with another refresh result.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -131,7 +131,7 @@ async def test_schedule_refresh_wont_replace_valid_result_with_invalid() -> None
Test to check whether _schedule_refresh won't replace a valid
refresh result with an invalid one.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand Down Expand Up @@ -160,7 +160,7 @@ async def test_schedule_refresh_expired_cert() -> None:
Test to check whether _schedule_refresh will throw RefreshError on
expired certificate.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
# set certificate to be expired
client.instance.cert_before = datetime.now() - timedelta(minutes=20)
Expand All @@ -182,7 +182,7 @@ async def test_force_refresh_cancels_pending_refresh() -> None:
"""
Test that force_refresh cancels pending task if refresh_in_progress event is not set.
"""
keys = generate_keys()
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand Down
Loading