Skip to content

Commit

Permalink
chore: use certChain
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Oct 24, 2023
1 parent 1df5b75 commit 7babf3a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 18 deletions.
6 changes: 3 additions & 3 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import logging
from typing import List, Optional, Tuple, TYPE_CHECKING
from typing import List, Optional, TYPE_CHECKING

import aiohttp
from cryptography.hazmat.primitives import serialization
Expand Down Expand Up @@ -117,7 +117,7 @@ async def _get_client_certificate(
region: str,
cluster: str,
key: rsa.RSAPrivateKey,
) -> Tuple[str, List[str]]:
) -> List[str]:
"""
Fetch a client certificate for the given AlloyDB cluster.
Expand Down Expand Up @@ -166,7 +166,7 @@ async def _get_client_certificate(
)
resp_dict = await resp.json()

return (resp_dict["pemCertificate"], resp_dict["pemCertificateChain"])
return resp_dict["pemCertificateChain"]

async def close(self) -> None:
"""Close AlloyDBClient gracefully."""
Expand Down
11 changes: 4 additions & 7 deletions google/cloud/alloydb/connector/refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import ssl
from tempfile import TemporaryDirectory
from typing import List, Tuple, TYPE_CHECKING
from typing import List, TYPE_CHECKING

from cryptography import x509

Expand Down Expand Up @@ -75,11 +75,9 @@ class RefreshResult:
"""

def __init__(
self, instance_ip: str, key: rsa.RSAPrivateKey, certs: Tuple[str, List[str]]
self, instance_ip: str, key: rsa.RSAPrivateKey, cert_chain: List[str]
) -> None:
self.instance_ip = instance_ip
self._key = key
self._certs = certs

# create TLS context
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
Expand All @@ -90,17 +88,16 @@ def __init__(
# add request_ssl attribute to ssl.SSLContext, required for pg8000 driver
self.context.request_ssl = False # type: ignore

client_cert, cert_chain = self._certs
# get expiration from client certificate
cert_obj = x509.load_pem_x509_certificate(client_cert.encode("UTF-8"))
cert_obj = x509.load_pem_x509_certificate(cert_chain[0].encode("UTF-8"))
self.expiration = cert_obj.not_valid_after

# tmpdir and its contents are automatically deleted after the CA cert
# and cert chain are loaded into the SSLcontext. The values
# need to be written to files in order to be loaded by the SSLContext
with TemporaryDirectory() as tmpdir:
ca_filename, cert_chain_filename, key_filename = _write_to_file(
tmpdir, cert_chain, client_cert, self._key
tmpdir, cert_chain, key
)
self.context.load_cert_chain(cert_chain_filename, keyfile=key_filename)
self.context.load_verify_locations(cafile=ca_filename)
Expand Down
11 changes: 3 additions & 8 deletions google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@

from typing import List, Tuple, TYPE_CHECKING

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import serialization

if TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric import rsa


def _write_to_file(
dir_path: str, cert_chain: List[str], client_cert: str, key: rsa.RSAPrivateKey
dir_path: str, cert_chain: List[str], key: rsa.RSAPrivateKey
) -> Tuple[str, str, str]:
"""
Helper function to write the server_ca, client certificate and
Expand All @@ -41,13 +39,10 @@ def _write_to_file(
encryption_algorithm=serialization.NoEncryption(),
)

# add client cert to beginning of cert chain
full_chain = [client_cert] + cert_chain

with open(ca_filename, "w+") as ca_out:
ca_out.write("".join(cert_chain))
with open(cert_chain_filename, "w+") as chain_out:
chain_out.write("".join(full_chain))
chain_out.write("".join(cert_chain))
with open(key_filename, "wb") as priv_out:
priv_out.write(key_bytes)

Expand Down

0 comments on commit 7babf3a

Please sign in to comment.