Skip to content

Commit

Permalink
refactor: add ConnectionName class (#1186)
Browse files Browse the repository at this point in the history
This PR refactors all instance connection name related code into
its own file connection_name.py

It introduces the ConnectionName class which will make tracking
if a DNS name was given to the Connector easier in the future.
  • Loading branch information
jackwotherspoon authored Nov 1, 2024
1 parent 3b24c10 commit ef7d8fe
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 73 deletions.
51 changes: 51 additions & 0 deletions google/cloud/sql/connector/connection_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
import re

# Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE_NAME>
# Additionally, we have to support legacy "domain-scoped" projects
# (e.g. "google.com:PROJECT")
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))


@dataclass
class ConnectionName:
"""ConnectionName represents a Cloud SQL instance's "instance connection name".
Takes the format "<PROJECT>:<REGION>:<INSTANCE_NAME>".
"""

project: str
region: str
instance_name: str

def __str__(self) -> str:
return f"{self.project}:{self.region}:{self.instance_name}"


def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
raise ValueError(
"Arg `instance_connection_string` must have "
"format: PROJECT:REGION:INSTANCE, "
f"got {connection_name}."
)
connection_name_split = CONN_NAME_REGEX.split(connection_name)
return ConnectionName(
connection_name_split[1],
connection_name_split[3],
connection_name_split[4],
)
50 changes: 17 additions & 33 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from datetime import timedelta
from datetime import timezone
import logging
import re

import aiohttp

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.exceptions import RefreshNotValidError
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
from google.cloud.sql.connector.refresh_utils import _is_valid
Expand All @@ -36,22 +36,6 @@

APPLICATION_NAME = "cloud-sql-python-connector"

# Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE>
# Additionally, we have to support legacy "domain-scoped" projects
# (e.g. "google.com:PROJECT")
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))


def _parse_instance_connection_name(connection_name: str) -> tuple[str, str, str]:
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
raise ValueError(
"Arg `instance_connection_string` must have "
"format: PROJECT:REGION:INSTANCE, "
f"got {connection_name}."
)
connection_name_split = CONN_NAME_REGEX.split(connection_name)
return connection_name_split[1], connection_name_split[3], connection_name_split[4]


class RefreshAheadCache:
"""Cache that refreshes connection info in the background prior to expiration.
Expand Down Expand Up @@ -81,10 +65,13 @@ def __init__(
connections.
"""
# validate and parse instance connection name
self._project, self._region, self._instance = _parse_instance_connection_name(
instance_connection_string
conn_name = _parse_instance_connection_name(instance_connection_string)
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
self._instance_connection_string = instance_connection_string
self._conn_name = conn_name

self._enable_iam_auth = enable_iam_auth
self._keys = keys
Expand Down Expand Up @@ -121,8 +108,7 @@ async def _perform_refresh(self) -> ConnectionInfo:
"""
self._refresh_in_progress.set()
logger.debug(
f"['{self._instance_connection_string}']: Connection info refresh "
"operation started"
f"['{self._conn_name}']: Connection info refresh " "operation started"
)

try:
Expand All @@ -135,17 +121,16 @@ async def _perform_refresh(self) -> ConnectionInfo:
self._enable_iam_auth,
)
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
"refresh operation complete"
f"['{self._conn_name}']: Connection info " "refresh operation complete"
)
logger.debug(
f"['{self._instance_connection_string}']: Current certificate "
f"['{self._conn_name}']: Current certificate "
f"expiration = {connection_info.expiration.isoformat()}"
)

except aiohttp.ClientResponseError as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{self._conn_name}']: Connection info "
f"refresh operation failed: {str(e)}"
)
if e.status == 403:
Expand All @@ -154,7 +139,7 @@ async def _perform_refresh(self) -> ConnectionInfo:

except Exception as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{self._conn_name}']: Connection info "
f"refresh operation failed: {str(e)}"
)
raise
Expand Down Expand Up @@ -188,18 +173,17 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
# check that refresh is valid
if not await _is_valid(refresh_task):
raise RefreshNotValidError(
f"['{self._instance_connection_string}']: Invalid refresh operation. Certficate appears to be expired."
f"['{self._conn_name}']: Invalid refresh operation. Certficate appears to be expired."
)
except asyncio.CancelledError:
logger.debug(
f"['{self._instance_connection_string}']: Scheduled refresh"
" operation cancelled"
f"['{self._conn_name}']: Scheduled refresh" " operation cancelled"
)
raise
# bad refresh attempt
except Exception as e:
logger.exception(
f"['{self._instance_connection_string}']: "
f"['{self._conn_name}']: "
"An error occurred while performing refresh. "
"Scheduling another refresh attempt immediately",
exc_info=e,
Expand All @@ -216,7 +200,7 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
# calculate refresh delay based on certificate expiration
delay = _seconds_until_refresh(refresh_data.expiration)
logger.debug(
f"['{self._instance_connection_string}']: Connection info refresh"
f"['{self._conn_name}']: Connection info refresh"
" operation scheduled for "
f"{(datetime.now(timezone.utc) + timedelta(seconds=delay)).isoformat(timespec='seconds')} "
f"(now + {timedelta(seconds=delay)})"
Expand All @@ -240,7 +224,7 @@ async def close(self) -> None:
graceful exit.
"""
logger.debug(
f"['{self._instance_connection_string}']: Canceling connection info "
f"['{self._conn_name}']: Canceling connection info "
"refresh operation tasks"
)
self._current.cancel()
Expand Down
22 changes: 12 additions & 10 deletions google/cloud/sql/connector/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.instance import _parse_instance_connection_name
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.refresh_utils import _refresh_buffer

logger = logging.getLogger(name=__name__)
Expand Down Expand Up @@ -56,10 +56,13 @@ def __init__(
connections.
"""
# validate and parse instance connection name
self._project, self._region, self._instance = _parse_instance_connection_name(
instance_connection_string
conn_name = _parse_instance_connection_name(instance_connection_string)
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
self._instance_connection_string = instance_connection_string
self._conn_name = conn_name

self._enable_iam_auth = enable_iam_auth
self._keys = keys
Expand Down Expand Up @@ -91,13 +94,12 @@ async def connect_info(self) -> ConnectionInfo:
< (self._cached.expiration - timedelta(seconds=_refresh_buffer))
):
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{self._conn_name}']: Connection info "
"is still valid, using cached info"
)
return self._cached
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
"refresh operation started"
f"['{self._conn_name}']: Connection info " "refresh operation started"
)
try:
conn_info = await self._client.get_connection_info(
Expand All @@ -109,16 +111,16 @@ async def connect_info(self) -> ConnectionInfo:
)
except Exception as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{self._conn_name}']: Connection info "
f"refresh operation failed: {str(e)}"
)
raise
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{self._conn_name}']: Connection info "
"refresh operation completed successfully"
)
logger.debug(
f"['{self._instance_connection_string}']: Current certificate "
f"['{self._conn_name}']: Current certificate "
f"expiration = {str(conn_info.expiration)}"
)
self._cached = conn_info
Expand Down
1 change: 1 addition & 0 deletions google/cloud/sql/connector/pg8000.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import socket
import ssl
from typing import Any, TYPE_CHECKING
Expand Down
1 change: 1 addition & 0 deletions google/cloud/sql/connector/pymysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import socket
import ssl
from typing import Any, TYPE_CHECKING
Expand Down
1 change: 1 addition & 0 deletions google/cloud/sql/connector/pytds.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import platform
import socket
import ssl
Expand Down
2 changes: 2 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def lint(session):
"--check-only",
"--diff",
"--profile=google",
"-w=88",
*LINT_PATHS,
)
session.run("black", "--check", "--diff", *LINT_PATHS)
Expand Down Expand Up @@ -85,6 +86,7 @@ def format(session):
"isort",
"--fss",
"--profile=google",
"-w=88",
*LINT_PATHS,
)
session.run(
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import os
import socket
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_connection_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest # noqa F401 Needed to run the tests

from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName


def test_ConnectionName() -> None:
conn_name = ConnectionName("project", "region", "instance")
# test class attributes are set properly
assert conn_name.project == "project"
assert conn_name.region == "region"
assert conn_name.instance_name == "instance"
# test ConnectionName str() method prints instance connection name
assert str(conn_name) == "project:region:instance"


@pytest.mark.parametrize(
"connection_name, expected",
[
("project:region:instance", ConnectionName("project", "region", "instance")),
(
"domain-prefix:project:region:instance",
ConnectionName("domain-prefix:project", "region", "instance"),
),
],
)
def test_parse_instance_connection_name(
connection_name: str, expected: ConnectionName
) -> None:
"""
Test that _parse_instance_connection_name works correctly on
normal instance connection names and domain-scoped projects.
"""
assert expected == _parse_instance_connection_name(connection_name)


def test_parse_instance_connection_name_bad_conn_name() -> None:
"""
Tests that ValueError is thrown for bad instance connection names.
"""
with pytest.raises(ValueError):
_parse_instance_connection_name("project:instance") # missing region
Loading

0 comments on commit ef7d8fe

Please sign in to comment.