From c65f700dfe61f0461d3d1f570f6976f7406f0e58 Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Tue, 24 Dec 2024 03:37:35 +0000 Subject: [PATCH] feat: support static connection info --- .../alloydb/connector/async_connector.py | 10 ++ google/cloud/alloydb/connector/connector.py | 10 ++ google/cloud/alloydb/connector/static.py | 90 ++++++++++++++++++ tests/unit/conftest.py | 9 +- tests/unit/mocks.py | 95 +++++++++++++++---- tests/unit/test_async_connector.py | 45 +++++++++ tests/unit/test_connector.py | 48 ++++++++++ 7 files changed, 281 insertions(+), 26 deletions(-) create mode 100644 google/cloud/alloydb/connector/static.py diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index fba74887..0035a5eb 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio +import io import logging from types import TracebackType from typing import Any, Dict, Optional, Type, TYPE_CHECKING, Union @@ -29,6 +30,7 @@ from google.cloud.alloydb.connector.enums import RefreshStrategy from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache +from google.cloud.alloydb.connector.static import StaticConnectionInfoCache from google.cloud.alloydb.connector.utils import generate_keys if TYPE_CHECKING: @@ -59,6 +61,9 @@ class AsyncConnector: of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + static_conn_info (io.TextIOBase): A file-like JSON object that contains + static connection info for the StaticConnectionInfoCache. + Defaults to None, which will not use the StaticConnectionInfoCache. """ def __init__( @@ -70,6 +75,7 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + static_conn_info: io.TextIOBase = None, ) -> None: self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} # initialize default params @@ -100,6 +106,7 @@ def __init__( except RuntimeError: self._keys = None self._client: Optional[AlloyDBClient] = None + self._static_conn_info = static_conn_info async def connect( self, @@ -138,10 +145,13 @@ async def connect( ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + static_conn_info = kwargs.pop("static_conn_info", self._static_conn_info) # use existing connection info if possible if instance_uri in self._cache: cache = self._cache[instance_uri] + elif static_conn_info: + cache = StaticConnectionInfoCache(instance_uri, static_conn_info) else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index c4ad2997..b0e8dcfa 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -16,6 +16,7 @@ import asyncio from functools import partial +import io import logging import socket import struct @@ -34,6 +35,7 @@ from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache import google.cloud.alloydb.connector.pg8000 as pg8000 +from google.cloud.alloydb.connector.static import StaticConnectionInfoCache from google.cloud.alloydb.connector.utils import generate_keys import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb @@ -71,6 +73,9 @@ class Connector: of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + static_conn_info (io.TextIOBase): A file-like JSON object that contains + static connection info for the StaticConnectionInfoCache. + Defaults to None, which will not use the StaticConnectionInfoCache. """ def __init__( @@ -82,6 +87,7 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + static_conn_info: io.TextIOBase = None, ) -> None: # create event loop and start it in background thread self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() @@ -113,6 +119,7 @@ def __init__( loop=self._loop, ) self._client: Optional[AlloyDBClient] = None + self._static_conn_info = static_conn_info def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any: """ @@ -168,9 +175,12 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + static_conn_info = kwargs.pop("static_conn_info", self._static_conn_info) # use existing connection info if possible if instance_uri in self._cache: cache = self._cache[instance_uri] + elif static_conn_info: + cache = StaticConnectionInfoCache(instance_uri, static_conn_info) else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py new file mode 100644 index 00000000..9731e594 --- /dev/null +++ b/google/cloud/alloydb/connector/static.py @@ -0,0 +1,90 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from datetime import datetime +from datetime import timedelta +from datetime import timezone +import io +import json + +from google.cloud.alloydb.connector.connection_info import ConnectionInfo + + +class StaticConnectionInfoCache: + """ + StaticConnectionInfoCache creates a connection info cache that will always + return a pre-defined connection info. + + This static connection info should hold JSON with the following format: + { + "publicKey": "", + "privateKey": "", + "projects//locations//clusters//instances/": { + "ipAddress": "", + "publicIpAddress": "", + "pscInstanceConfig": { + "pscDnsName": "" + }, + "pemCertificateChain": [ + "", "", "" + ], + "caCert": "" + } + } + """ + + def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: + """ + Initializes a StaticConnectionInfoCache instance. + + Args: + instance_uri (str): The AlloyDB instance's connection URI. + static_conn_info (io.TextIOBase): The static connection info JSON. + """ + static_info = json.load(static_conn_info) + ca_cert = static_info[instance_uri]["caCert"] + cert_chain = static_info[instance_uri]["pemCertificateChain"] + ip_addrs = { + "PRIVATE": static_info[instance_uri]["ipAddress"], + "PUBLIC": static_info[instance_uri]["publicIpAddress"], + "PSC": static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"], + } + expiration = datetime.now(timezone.utc) + timedelta(hours=1) + priv_key = static_info["privateKey"] + priv_key_bytes: rsa.RSAPrivateKey = serialization.load_pem_private_key( + priv_key.encode("UTF-8"), password=None, + ) + self._info = ConnectionInfo(cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration) + + async def force_refresh(self) -> None: + """ + This is a no-op as the cache holds only static connection information + and does no refresh. + """ + pass + + async def connect_info(self) -> ConnectionInfo: + """ + Retrieves ConnectionInfo instance for establishing a secure + connection to the AlloyDB instance. + """ + return self._info + + async def close(self) -> None: + """ + This is a no-op. + """ + pass \ No newline at end of file diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 45648fa7..c7e27f4a 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -66,8 +66,8 @@ async def start_proxy_server(instance: FakeInstance) -> None: # listen for incoming connections sock.listen(5) - while True: - with context.wrap_socket(sock, server_side=True) as ssock: + with context.wrap_socket(sock, server_side=True) as ssock: + while True: conn, _ = ssock.accept() metadata_exchange(conn) conn.sendall(instance.name.encode("utf-8")) @@ -75,7 +75,7 @@ async def start_proxy_server(instance: FakeInstance) -> None: @pytest.fixture(scope="session") -def proxy_server(fake_instance: FakeInstance) -> Generator: +def proxy_server(fake_instance: FakeInstance) -> None: """Run local proxy server capable of performing metadata exchange""" thread = Thread( target=asyncio.run, @@ -87,5 +87,4 @@ def proxy_server(fake_instance: FakeInstance) -> Generator: daemon=True, ) thread.start() - yield thread - thread.join() + thread.join(0.1) # wait 100ms to allow the proxy server to start diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index ae600356..31b4df30 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -16,7 +16,9 @@ from datetime import datetime from datetime import timedelta from datetime import timezone +import io import ipaddress +import json import ssl import struct from typing import Any, Callable, Dict, List, Literal, Optional, Tuple @@ -193,6 +195,34 @@ def get_pem_certs(self) -> Tuple[str, str, str]: encoding=serialization.Encoding.PEM ).decode("UTF-8") return (pem_root, pem_intermediate, pem_server) + + def generate_pem_certificate_chain(self, pub_key: str) -> Tuple[str, List[str]]: + """Generate the CA certificate and certificate chain for the AlloyDB instance.""" + root_cert, intermediate_cert, server_cert = self.get_pem_certs() + # encode public key to bytes + pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( + pub_key.encode("UTF-8"), + ) + # build client cert + client_cert = ( + x509.CertificateBuilder() + .subject_name(self.intermediate_cert.subject) + .issuer_name(self.intermediate_cert.issuer) + .public_key(pub_key_bytes) + .serial_number(x509.random_serial_number()) + .not_valid_before(self.cert_before) + .not_valid_after(self.cert_expiry) + ) + # sign client cert with intermediate cert + client_cert = client_cert.sign(self.intermediate_key, hashes.SHA256()) + client_cert = client_cert.public_bytes( + encoding=serialization.Encoding.PEM + ).decode("UTF-8") + return (server_cert, [client_cert, intermediate_cert, root_cert]) + + def uri(self) -> str: + """The URI of the AlloyDB instance.""" + return f"projects/{self.project}/locations/{self.region}/clusters/{self.cluster}/instances/{self.name}" class FakeAlloyDBClient: @@ -216,27 +246,7 @@ async def _get_client_certificate( cluster: str, pub_key: str, ) -> Tuple[str, List[str]]: - root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs() - # encode public key to bytes - pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( - pub_key.encode("UTF-8"), - ) - # build client cert - client_cert = ( - x509.CertificateBuilder() - .subject_name(self.instance.intermediate_cert.subject) - .issuer_name(self.instance.intermediate_cert.issuer) - .public_key(pub_key_bytes) - .serial_number(x509.random_serial_number()) - .not_valid_before(self.instance.cert_before) - .not_valid_after(self.instance.cert_expiry) - ) - # sign client cert with intermediate cert - client_cert = client_cert.sign(self.instance.intermediate_key, hashes.SHA256()) - client_cert = client_cert.public_bytes( - encoding=serialization.Encoding.PEM - ).decode("UTF-8") - return (server_cert, [client_cert, intermediate_cert, root_cert]) + return self.instance.generate_pem_certificate_chain(pub_key) async def get_connection_info( self, @@ -378,3 +388,46 @@ async def force_refresh(self) -> None: async def close(self) -> None: self._close_called = True + + +def write_static_info(i: FakeInstance) -> io.StringIO: + """ + Creates a static connection info JSON for the StaticConnectionInfoCache. + + Args: + i (FakeInstance): The FakeInstance to use to create the CA cert and + chain. + + Returns: + io.StringIO + """ + priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + pub_pem = ( + priv_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode("UTF-8") + ) + priv_pem = ( + priv_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + .decode("UTF-8") + ) + ca_cert, chain = i.generate_pem_certificate_chain(pub_pem) + static = { + "publicKey": pub_pem, + "privateKey": priv_pem, + } + static[i.uri()] = { + "pemCertificateChain": chain, + "caCert": ca_cert, + "ipAddress": "127.0.0.1", # "private" IP is localhost in testing + "publicIpAddress": "", + "pscInstanceConfig": {"pscDnsName": ""}, + } + return io.StringIO(json.dumps(static)) diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index 0f150875..8e518276 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -20,6 +20,7 @@ from mocks import FakeAlloyDBClient from mocks import FakeConnectionInfo from mocks import FakeCredentials +from mocks import write_static_info import pytest from google.cloud.alloydb.connector import AsyncConnector @@ -333,3 +334,47 @@ async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) await connector.connect(instance_uri, "asyncpg", ip_type="private") # check that cache has been removed from dict assert instance_uri not in connector._cache + +async def test_Connector_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + """ + Test that AsyncConnector.__init__() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + async with AsyncConnector(credentials=credentials, static_conn_info=static_info) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.asyncpg.connect") as mock_connect: + mock_connect.return_value = True + connection = await connector.connect( + fake_client.instance.uri(), + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + ) + # check connection is returned + assert connection is True + + +async def test_connect_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + """ + Test that AsyncConnector.connect() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + async with AsyncConnector(credentials=credentials) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.asyncpg.connect") as mock_connect: + mock_connect.return_value = True + connection = await connector.connect( + fake_client.instance.uri(), + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + static_conn_info=static_info, + ) + # check connection is returned + assert connection is True \ No newline at end of file diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index a02ad30e..902a032f 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -20,6 +20,7 @@ from mock import patch from mocks import FakeAlloyDBClient from mocks import FakeCredentials +from mocks import write_static_info import pytest from google.cloud.alloydb.connector import Connector @@ -248,3 +249,50 @@ async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) await connector.connect_async(instance_uri, "pg8000", ip_type="private") # check that cache has been removed from dict assert instance_uri not in connector._cache + + +@pytest.mark.usefixtures("proxy_server") +def test_Connector_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + """ + Test that Connector.__init__() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + with Connector(credentials=credentials, static_conn_info=static_info) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.pg8000.connect") as mock_connect: + mock_connect.return_value = True + connection = connector.connect( + fake_client.instance.uri(), + "pg8000", + user="test-user", + password="test-password", + db="test-db", + ) + # check connection is returned + assert connection is True + + +@pytest.mark.usefixtures("proxy_server") +def test_connect_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + """ + Test that Connector.connect() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + with Connector(credentials=credentials) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.pg8000.connect") as mock_connect: + mock_connect.return_value = True + connection = connector.connect( + fake_client.instance.uri(), + "pg8000", + user="test-user", + password="test-password", + db="test-db", + static_conn_info=static_info, + ) + # check connection is returned + assert connection is True \ No newline at end of file