Skip to content

Commit

Permalink
Running test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
RahulDubey391 committed Nov 24, 2023
1 parent 3328422 commit 39bb079
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 20 deletions.
12 changes: 4 additions & 8 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def __init__(
self._credentials, _ = default(scopes=scopes)
self._keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
loop=self.__loop
)
loop=self._loop,
)

self._client: Optional[AlloyDBClient] = None

def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -127,11 +127,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
if instance_uri in self._instances:
instance = self._instances[instance_uri]
else:
instance = Instance(
instance_uri,
self._client,
self._keys,
)
instance = Instance(instance_uri, self._client, self._keys)
self._instances[instance_uri] = instance

connect_func = {
Expand Down
5 changes: 3 additions & 2 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

if TYPE_CHECKING:
import ssl
from cryptography.hazmat.primitives.asymmetric import rsa

# from cryptography.hazmat.primitives.asymmetric import rsa
from google.cloud.alloydb.connector.client import AlloyDBClient

logger = logging.getLogger(name=__name__)
Expand All @@ -53,7 +54,7 @@ def __init__(
self,
instance_uri: str,
client: AlloyDBClient,
keys: Tuple[rsa.RSAPrivateKey, str],
keys: asyncio.Future,
) -> None:
# validate and parse instance_uri
instance_uri_split = instance_uri.split("/")
Expand Down
3 changes: 1 addition & 2 deletions google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def _write_to_file(

async def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:

priv_key = rsa.generate_private_key(public_exponent=65537,
key_size=2048)
priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

pub_key = (
priv_key.public_key()
Expand Down
40 changes: 32 additions & 8 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ async def test_Instance_init() -> None:
Test to check whether the __init__ method of Instance
can tell if the instance URI that's passed in is formatted correctly.
"""
keys = await generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
async with aiohttp.ClientSession() as client:
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -52,7 +55,10 @@ async def test_Instance_init_invalid_instant_uri() -> None:
Test to check whether the __init__ method of Instance
will throw error for invalid instance URI.
"""
keys = await generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
async with aiohttp.ClientSession() as client:
with pytest.raises(ValueError):
Instance("invalid/instance/uri/", client, keys)
Expand All @@ -64,7 +70,10 @@ async def test_Instance_close() -> None:
Test that Instance's close method
cancels tasks gracefully.
"""
keys = await generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -84,7 +93,10 @@ async def test_Instance_close() -> None:
@pytest.mark.asyncio
async def test_perform_refresh() -> None:
"""Test that _perform refresh returns valid RefreshResult"""
keys = await generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -104,7 +116,10 @@ async def test_schedule_refresh_replaces_result() -> None:
Test to check whether _schedule_refresh replaces a valid refresh result
with another refresh result.
"""
keys = await generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand All @@ -131,7 +146,10 @@ async def test_schedule_refresh_wont_replace_valid_result_with_invalid() -> None
Test to check whether _schedule_refresh won't replace a valid
refresh result with an invalid one.
"""
keys = await generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand Down Expand Up @@ -160,7 +178,10 @@ async def test_schedule_refresh_expired_cert() -> None:
Test to check whether _schedule_refresh will throw RefreshError on
expired certificate.
"""
keys = await generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
# set certificate to be expired
client.instance.cert_before = datetime.now() - timedelta(minutes=20)
Expand All @@ -182,7 +203,10 @@ async def test_force_refresh_cancels_pending_refresh() -> None:
"""
Test that force_refresh cancels pending task if refresh_in_progress event is not set.
"""
keys = await generate_keys()
event_loop = asyncio.get_running_loop()
keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop
)
client = FakeAlloyDBClient()
instance = Instance(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
Expand Down

0 comments on commit 39bb079

Please sign in to comment.