diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index e177152d..4d4a18ba 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import List, Optional, Tuple, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING import aiohttp from cryptography.hazmat.primitives import serialization @@ -117,7 +117,7 @@ async def _get_client_certificate( region: str, cluster: str, key: rsa.RSAPrivateKey, - ) -> Tuple[str, List[str]]: + ) -> List[str]: """ Fetch a client certificate for the given AlloyDB cluster. @@ -166,7 +166,7 @@ async def _get_client_certificate( ) resp_dict = await resp.json() - return (resp_dict["pemCertificate"], resp_dict["pemCertificateChain"]) + return resp_dict["pemCertificateChain"] async def close(self) -> None: """Close AlloyDBClient gracefully.""" diff --git a/google/cloud/alloydb/connector/refresh.py b/google/cloud/alloydb/connector/refresh.py index 9031298d..fa0e0cf7 100644 --- a/google/cloud/alloydb/connector/refresh.py +++ b/google/cloud/alloydb/connector/refresh.py @@ -19,7 +19,7 @@ import logging import ssl from tempfile import TemporaryDirectory -from typing import List, Tuple, TYPE_CHECKING +from typing import List, TYPE_CHECKING from cryptography import x509 @@ -75,11 +75,9 @@ class RefreshResult: """ def __init__( - self, instance_ip: str, key: rsa.RSAPrivateKey, certs: Tuple[str, List[str]] + self, instance_ip: str, key: rsa.RSAPrivateKey, cert_chain: List[str] ) -> None: self.instance_ip = instance_ip - self._key = key - self._certs = certs # create TLS context self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) @@ -90,9 +88,8 @@ def __init__( # add request_ssl attribute to ssl.SSLContext, required for pg8000 driver self.context.request_ssl = False # type: ignore - client_cert, cert_chain = self._certs # get expiration from client certificate - cert_obj = x509.load_pem_x509_certificate(client_cert.encode("UTF-8")) + cert_obj = x509.load_pem_x509_certificate(cert_chain[0].encode("UTF-8")) self.expiration = cert_obj.not_valid_after # tmpdir and its contents are automatically deleted after the CA cert @@ -100,7 +97,7 @@ def __init__( # need to be written to files in order to be loaded by the SSLContext with TemporaryDirectory() as tmpdir: ca_filename, cert_chain_filename, key_filename = _write_to_file( - tmpdir, cert_chain, client_cert, self._key + tmpdir, cert_chain, key ) self.context.load_cert_chain(cert_chain_filename, keyfile=key_filename) self.context.load_verify_locations(cafile=ca_filename) diff --git a/google/cloud/alloydb/connector/utils.py b/google/cloud/alloydb/connector/utils.py index 96dee4c6..aa9e7563 100644 --- a/google/cloud/alloydb/connector/utils.py +++ b/google/cloud/alloydb/connector/utils.py @@ -16,16 +16,14 @@ from typing import List, Tuple, TYPE_CHECKING -from cryptography import x509 -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import serialization if TYPE_CHECKING: from cryptography.hazmat.primitives.asymmetric import rsa def _write_to_file( - dir_path: str, cert_chain: List[str], client_cert: str, key: rsa.RSAPrivateKey + dir_path: str, cert_chain: List[str], key: rsa.RSAPrivateKey ) -> Tuple[str, str, str]: """ Helper function to write the server_ca, client certificate and @@ -41,13 +39,10 @@ def _write_to_file( encryption_algorithm=serialization.NoEncryption(), ) - # add client cert to beginning of cert chain - full_chain = [client_cert] + cert_chain - with open(ca_filename, "w+") as ca_out: ca_out.write("".join(cert_chain)) with open(cert_chain_filename, "w+") as chain_out: - chain_out.write("".join(full_chain)) + chain_out.write("".join(cert_chain)) with open(key_filename, "wb") as priv_out: priv_out.write(key_bytes)