Skip to content

Commit

Permalink
chore: use generate_keys
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Oct 24, 2023
1 parent b3d7ebd commit b194fb0
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 24 deletions.
16 changes: 3 additions & 13 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
from typing import List, Optional, Tuple, TYPE_CHECKING

import aiohttp
from cryptography.hazmat.primitives import serialization

from google.auth.transport.requests import Request
from google.cloud.alloydb.connector.version import __version__ as version

if TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric import rsa
from google.auth.credentials import Credentials

USER_AGENT: str = f"alloydb-python-connector/{version}"
Expand Down Expand Up @@ -116,7 +114,7 @@ async def _get_client_certificate(
project: str,
region: str,
cluster: str,
key: rsa.RSAPrivateKey,
pub_key: str,
) -> Tuple[str, List[str]]:
"""
Fetch a client certificate for the given AlloyDB cluster.
Expand All @@ -130,8 +128,7 @@ async def _get_client_certificate(
resides in.
region (str): Google Cloud region of the AlloyDB instance.
cluster (str): The name of the AlloyDB cluster.
key (rsa.RSAPrivateKey): Client private key used in refresh operation
to generate client certificate.
pub_key (str): PEM-encoded client public key.
Returns:
Tuple[str, list[str]]: Tuple containing the CA certificate
Expand All @@ -149,15 +146,8 @@ async def _get_client_certificate(

url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"

# get client public key
pub_key = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
pub_key_str = pub_key.decode("UTF-8")

data = {
"publicKey": pub_key_str,
"publicKey": pub_key,
"certDuration": "3600s",
}

Expand Down
7 changes: 3 additions & 4 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
from types import TracebackType
from typing import Any, Dict, Optional, Type, TYPE_CHECKING

from cryptography.hazmat.primitives.asymmetric import rsa

from google.auth import default
from google.auth.credentials import with_scopes_if_required
from google.cloud.alloydb.connector.client import AlloyDBClient
from google.cloud.alloydb.connector.instance import Instance
import google.cloud.alloydb.connector.pg8000 as pg8000
from google.cloud.alloydb.connector.utils import generate_keys

if TYPE_CHECKING:
from google.auth.credentials import Credentials
Expand Down Expand Up @@ -68,7 +67,7 @@ def __init__(
# otherwise use application default credentials
else:
self._credentials, _ = default(scopes=scopes)
self._key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
self._keys = generate_keys()
self._client: Optional[AlloyDBClient] = None

def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -127,7 +126,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
instance = Instance(
instance_uri,
self._client,
self._key,
self._keys,
)
self._instances[instance_uri] = instance

Expand Down
16 changes: 9 additions & 7 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ class Instance:
instance_uri (str): The instance URI of the AlloyDB instance.
ex. projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>
client (AlloyDBClient): Client used to make requests to AlloyDB Admin APIs.
key (rsa.RSAPrivateKey): Client private key used in refresh operation
to generate client certificate.
keys (Tuple[rsa.RSAPrivateKey, str]): Private and Public key pair.
"""

def __init__(
self, instance_uri: str, client: AlloyDBClient, key: rsa.RSAPrivateKey
self,
instance_uri: str,
client: AlloyDBClient,
keys: Tuple[rsa.RSAPrivateKey, str],
) -> None:
# validate and parse instance_uri
instance_uri_split = instance_uri.split("/")
Expand All @@ -70,7 +72,7 @@ def __init__(
)

self._client = client
self._key = key
self._keys = keys
self._refresh_rate_limiter = AsyncRateLimiter(
max_capacity=2,
rate=1 / 30,
Expand All @@ -96,7 +98,7 @@ async def _perform_refresh(self) -> RefreshResult:

try:
await self._refresh_rate_limiter.acquire()

priv_key, pub_key = self._keys
# fetch metadata
metadata_task = asyncio.create_task(
self._client._get_metadata(
Expand All @@ -112,7 +114,7 @@ async def _perform_refresh(self) -> RefreshResult:
self._project,
self._region,
self._cluster,
self._key,
pub_key,
)
)

Expand All @@ -127,7 +129,7 @@ async def _perform_refresh(self) -> RefreshResult:
finally:
self._refresh_in_progress.clear()

return RefreshResult(ip_addr, self._key, certs)
return RefreshResult(ip_addr, priv_key, certs)

def _schedule_refresh(self, delay: int) -> asyncio.Task:
"""
Expand Down

0 comments on commit b194fb0

Please sign in to comment.