From b194fb0941936a99b7df066c2e45f0de201df4cf Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 24 Oct 2023 11:14:01 +0000 Subject: [PATCH] chore: use generate_keys --- google/cloud/alloydb/connector/client.py | 16 +++------------- google/cloud/alloydb/connector/connector.py | 7 +++---- google/cloud/alloydb/connector/instance.py | 16 +++++++++------- 3 files changed, 15 insertions(+), 24 deletions(-) diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index 8d7e4664..a39447c4 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -18,13 +18,11 @@ from typing import List, Optional, Tuple, TYPE_CHECKING import aiohttp -from cryptography.hazmat.primitives import serialization from google.auth.transport.requests import Request from google.cloud.alloydb.connector.version import __version__ as version if TYPE_CHECKING: - from cryptography.hazmat.primitives.asymmetric import rsa from google.auth.credentials import Credentials USER_AGENT: str = f"alloydb-python-connector/{version}" @@ -116,7 +114,7 @@ async def _get_client_certificate( project: str, region: str, cluster: str, - key: rsa.RSAPrivateKey, + pub_key: str, ) -> Tuple[str, List[str]]: """ Fetch a client certificate for the given AlloyDB cluster. @@ -130,8 +128,7 @@ async def _get_client_certificate( resides in. region (str): Google Cloud region of the AlloyDB instance. cluster (str): The name of the AlloyDB cluster. - key (rsa.RSAPrivateKey): Client private key used in refresh operation - to generate client certificate. + pub_key (str): PEM-encoded client public key. Returns: Tuple[str, list[str]]: Tuple containing the CA certificate @@ -149,15 +146,8 @@ async def _get_client_certificate( url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate" - # get client public key - pub_key = key.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) - pub_key_str = pub_key.decode("UTF-8") - data = { - "publicKey": pub_key_str, + "publicKey": pub_key, "certDuration": "3600s", } diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index c960cb5a..71031fad 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -20,13 +20,12 @@ from types import TracebackType from typing import Any, Dict, Optional, Type, TYPE_CHECKING -from cryptography.hazmat.primitives.asymmetric import rsa - from google.auth import default from google.auth.credentials import with_scopes_if_required from google.cloud.alloydb.connector.client import AlloyDBClient from google.cloud.alloydb.connector.instance import Instance import google.cloud.alloydb.connector.pg8000 as pg8000 +from google.cloud.alloydb.connector.utils import generate_keys if TYPE_CHECKING: from google.auth.credentials import Credentials @@ -68,7 +67,7 @@ def __init__( # otherwise use application default credentials else: self._credentials, _ = default(scopes=scopes) - self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + self._keys = generate_keys() self._client: Optional[AlloyDBClient] = None def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any: @@ -127,7 +126,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> instance = Instance( instance_uri, self._client, - self._key, + self._keys, ) self._instances[instance_uri] = instance diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index 62b686d5..4a06ade1 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -46,12 +46,14 @@ class Instance: instance_uri (str): The instance URI of the AlloyDB instance. ex. projects//locations//clusters//instances/ client (AlloyDBClient): Client used to make requests to AlloyDB Admin APIs. - key (rsa.RSAPrivateKey): Client private key used in refresh operation - to generate client certificate. + keys (Tuple[rsa.RSAPrivateKey, str]): Private and Public key pair. """ def __init__( - self, instance_uri: str, client: AlloyDBClient, key: rsa.RSAPrivateKey + self, + instance_uri: str, + client: AlloyDBClient, + keys: Tuple[rsa.RSAPrivateKey, str], ) -> None: # validate and parse instance_uri instance_uri_split = instance_uri.split("/") @@ -70,7 +72,7 @@ def __init__( ) self._client = client - self._key = key + self._keys = keys self._refresh_rate_limiter = AsyncRateLimiter( max_capacity=2, rate=1 / 30, @@ -96,7 +98,7 @@ async def _perform_refresh(self) -> RefreshResult: try: await self._refresh_rate_limiter.acquire() - + priv_key, pub_key = self._keys # fetch metadata metadata_task = asyncio.create_task( self._client._get_metadata( @@ -112,7 +114,7 @@ async def _perform_refresh(self) -> RefreshResult: self._project, self._region, self._cluster, - self._key, + pub_key, ) ) @@ -127,7 +129,7 @@ async def _perform_refresh(self) -> RefreshResult: finally: self._refresh_in_progress.clear() - return RefreshResult(ip_addr, self._key, certs) + return RefreshResult(ip_addr, priv_key, certs) def _schedule_refresh(self, delay: int) -> asyncio.Task: """