Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: wrap generate_keys() in future #168

Merged
merged 8 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def __init__(
# otherwise use application default credentials
else:
self._credentials, _ = default(scopes=scopes)
self._keys = generate_keys()
self._keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
loop=self._loop,
)

self._client: Optional[AlloyDBClient] = None

def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -123,11 +127,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
if instance_uri in self._instances:
instance = self._instances[instance_uri]
else:
instance = Instance(
instance_uri,
self._client,
self._keys,
)
instance = Instance(instance_uri, self._client, self._keys)
self._instances[instance_uri] = instance

connect_func = {
Expand Down
7 changes: 4 additions & 3 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

if TYPE_CHECKING:
import ssl
from cryptography.hazmat.primitives.asymmetric import rsa

# from cryptography.hazmat.primitives.asymmetric import rsa
jackwotherspoon marked this conversation as resolved.
Show resolved Hide resolved
from google.cloud.alloydb.connector.client import AlloyDBClient

logger = logging.getLogger(name=__name__)
Expand All @@ -53,7 +54,7 @@ def __init__(
self,
instance_uri: str,
client: AlloyDBClient,
keys: Tuple[rsa.RSAPrivateKey, str],
keys: asyncio.Future,
Copy link
Collaborator

@jackwotherspoon jackwotherspoon Nov 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we can try to type hint the return of the future? I believe this was introduced in Python 3.8 so we should be okay.

Suggested change
keys: asyncio.Future,
keys: asyncio.Future[Tuple[rsa.RSAPrivateKey, str]],

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jackwotherspoon , thanks for the feedback! I have pushed the above change in the latest push

) -> None:
# validate and parse instance_uri
instance_uri_split = instance_uri.split("/")
Expand Down Expand Up @@ -98,7 +99,7 @@ async def _perform_refresh(self) -> RefreshResult:

try:
await self._refresh_rate_limiter.acquire()
priv_key, pub_key = self._keys
priv_key, pub_key = await self._keys
# fetch metadata
metadata_task = asyncio.create_task(
self._client._get_metadata(
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ def _write_to_file(
return (ca_filename, cert_chain_filename, key_filename)


def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
async def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:

priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

pub_key = (
priv_key.public_key()
.public_bytes(
Expand All @@ -57,4 +59,5 @@ def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
)
.decode("UTF-8")
)

return (priv_key, pub_key)
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def test__get_client_certificate(
Test _get_client_certificate returns successfully.
"""
test_client = AlloyDBClient("", "", credentials, client)
keys = generate_keys()
keys = await generate_keys()
certs = await test_client._get_client_certificate(
"test-project", "test-region", "test-cluster", keys[1]
)
Expand Down
40 changes: 32 additions & 8 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ async def test_Instance_init() -> None:
Test to check whether the __init__ method of Instance
can tell if the instance URI that's passed in is formatted correctly.
"""
keys = generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just await this directly here...? The reason we have to use the future in the main code is because the loop is running in a separate thread. The tests run in the same thread so I think we can just await it directly.

Suggested change
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
keys = await generate_keys()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for other test cases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jackwotherspoon , I have removed the separate Async loops.

But one doubt, since the keys are being passed to the Instance object as future, due to double await, the Tests are throwing error that "an awaited expression can't be awaited again". So I removed the await for the generate_keys() function so finally it's being awaited once in the main Instance loop.

Tell me if it's fine otherwise we can revert it back.

async with aiohttp.ClientSession() as client:
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -52,7 +55,10 @@ async def test_Instance_init_invalid_instant_uri() -> None:
Test to check whether the __init__ method of Instance
will throw error for invalid instance URI.
"""
keys = generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
async with aiohttp.ClientSession() as client:
with pytest.raises(ValueError):
Instance("invalid/instance/uri/", client, keys)
Expand All @@ -64,7 +70,10 @@ async def test_Instance_close() -> None:
Test that Instance's close method
cancels tasks gracefully.
"""
keys = generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -84,7 +93,10 @@ async def test_Instance_close() -> None:
@pytest.mark.asyncio
async def test_perform_refresh() -> None:
"""Test that _perform refresh returns valid RefreshResult"""
keys = generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -104,7 +116,10 @@ async def test_schedule_refresh_replaces_result() -> None:
Test to check whether _schedule_refresh replaces a valid refresh result
with another refresh result.
"""
keys = generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -131,7 +146,10 @@ async def test_schedule_refresh_wont_replace_valid_result_with_invalid() -> None
Test to check whether _schedule_refresh won't replace a valid
refresh result with an invalid one.
"""
keys = generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand Down Expand Up @@ -160,7 +178,10 @@ async def test_schedule_refresh_expired_cert() -> None:
Test to check whether _schedule_refresh will throw RefreshError on
expired certificate.
"""
keys = generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
# set certificate to be expired
client.instance.cert_before = datetime.now() - timedelta(minutes=20)
Expand All @@ -182,7 +203,10 @@ async def test_force_refresh_cancels_pending_refresh() -> None:
"""
Test that force_refresh cancels pending task if refresh_in_progress event is not set.
"""
keys = generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand Down
Loading