Skip to content

Commit

Permalink
fix: use IAM login creds in expiration logic (#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon authored Nov 15, 2023
1 parent 501d8da commit 7f8a3a4
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 43 deletions.
15 changes: 1 addition & 14 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
)

import aiohttp
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate

from google.auth.credentials import Credentials
from google.cloud.sql.connector.exceptions import (
Expand Down Expand Up @@ -323,18 +321,7 @@ async def _perform_refresh(self) -> ConnectionInfo:
ephemeral_task.cancel()
raise

ephemeral_cert = await ephemeral_task

x509 = load_pem_x509_certificate(
ephemeral_cert.encode("UTF-8"), default_backend()
)
expiration = x509.not_valid_after

if self._enable_iam_auth:
if self._credentials is not None:
token_expiration: datetime.datetime = self._credentials.expiry
if expiration > token_expiration:
expiration = token_expiration
ephemeral_cert, expiration = await ephemeral_task

except aiohttp.ClientResponseError as e:
logger.debug(
Expand Down
25 changes: 19 additions & 6 deletions google/cloud/sql/connector/refresh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import copy
import datetime
import logging
from typing import Any, Dict, List, TYPE_CHECKING
from typing import Any, Dict, List, Tuple, TYPE_CHECKING

from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate

from google.auth.credentials import Credentials, Scoped
import google.auth.transport.requests
Expand Down Expand Up @@ -133,7 +136,7 @@ async def _get_ephemeral(
instance: str,
pub_key: str,
enable_iam_auth: bool = False,
) -> str:
) -> Tuple[str, datetime.datetime]:
"""Asynchronously requests an ephemeral certificate from the Cloud SQL Instance.
:type sqladmin_api_endpoint: str
Expand Down Expand Up @@ -204,7 +207,18 @@ async def _get_ephemeral(

ret_dict = await resp.json()

return ret_dict["ephemeralCert"]["cert"]
ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"]

# decode cert to read expiration
x509 = load_pem_x509_certificate(ephemeral_cert.encode("UTF-8"), default_backend())
expiration = x509.not_valid_after
# for IAM authentication OAuth2 token is embedded in cert so it
# must still be valid for successful connection
if enable_iam_auth:
token_expiration: datetime.datetime = login_creds.expiry
if expiration > token_expiration:
expiration = token_expiration
return ephemeral_cert, expiration


def _seconds_until_refresh(
Expand Down Expand Up @@ -274,7 +288,6 @@ def _downscope_credentials(
# Cloud SDK reference: https://github.com/google-cloud-sdk-unofficial/google-cloud-sdk/blob/93920ccb6d2cce0fe6d1ce841e9e33410551d66b/lib/googlecloudsdk/command_lib/sql/generate_login_token_util.py#L116
scoped_creds._scopes = scopes
# down-scoped credentials require refresh, are invalid after being re-scoped
if not scoped_creds.valid:
request = google.auth.transport.requests.Request()
scoped_creds.refresh(request)
request = google.auth.transport.requests.Request()
scoped_creds.refresh(request)
return scoped_creds
11 changes: 7 additions & 4 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ async def test_perform_refresh(

@pytest.mark.asyncio
async def test_perform_refresh_expiration(
instance: Instance,
instance: Instance, fake_credentials: Credentials
) -> None:
"""
Test that _perform_refresh returns ConnectionInfo with proper expiration.
Expand All @@ -254,10 +254,13 @@ async def test_perform_refresh_expiration(
"""
# set credentials expiration to 1 minute from now
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=1)
setattr(instance._credentials, "expiry", expiration)
setattr(instance, "_enable_iam_auth", True)
# set all credentials to valid so downscoped credential does not refresh
with patch.object(Credentials, "valid", True):
# set downscoped credential to mock
with patch(
"google.cloud.sql.connector.refresh_utils._downscope_credentials"
) as mock_auth:
setattr(fake_credentials, "expiry", expiration)
mock_auth.return_value = fake_credentials
instance_metadata = await instance._perform_refresh()

# verify instance metadata object is returned
Expand Down
43 changes: 24 additions & 19 deletions tests/unit/test_refresh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
import pytest # noqa F401 Needed to run the tests

import google.auth
from google.auth.credentials import Credentials
from google.cloud.sql.connector.refresh_utils import (
_downscope_credentials,
Expand Down Expand Up @@ -76,12 +77,13 @@ async def test_get_ephemeral(
instance,
pub_key,
)
result = result.strip() # remove any trailing whitespace
result = result.split("\n")
cert, _ = result
cert = cert.strip() # remove any trailing whitespace
cert = cert.split("\n")

assert (
result[0] == "-----BEGIN CERTIFICATE-----"
and result[len(result) - 1] == "-----END CERTIFICATE-----"
cert[0] == "-----BEGIN CERTIFICATE-----"
and cert[len(cert) - 1] == "-----END CERTIFICATE-----"
)


Expand Down Expand Up @@ -287,19 +289,20 @@ async def test_is_valid_with_expired_metadata() -> None:
assert not await _is_valid(task)


def test_downscope_credentials_service_account(fake_credentials: Credentials) -> None:
"""
Test _downscope_credentials with google.oauth2.service_account.Credentials
which mimics an authenticated service account.
"""
# set all credentials to valid to skip refreshing credentials
with patch.object(Credentials, "valid", True):
credentials = _downscope_credentials(fake_credentials)
# verify default credential scopes have not been altered
assert fake_credentials.scopes == SCOPES
# verify downscoped credentials have new scope
assert credentials.scopes == ["https://www.googleapis.com/auth/sqlservice.login"]
assert credentials != fake_credentials
# TODO: https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/901
# def test_downscope_credentials_service_account(fake_credentials: Credentials) -> None:
# """
# Test _downscope_credentials with google.oauth2.service_account.Credentials
# which mimics an authenticated service account.
# """
# # override actual refresh URI
# setattr(fake_credentials, "with_scopes", google.auth.credentials.Credentials(scopes=["https://www.googleapis.com/auth/sqlservice.login"]))
# credentials = _downscope_credentials(fake_credentials)
# # verify default credential scopes have not been altered
# assert fake_credentials.scopes == SCOPES
# # verify downscoped credentials have new scope
# assert credentials.scopes == ["https://www.googleapis.com/auth/sqlservice.login"]
# assert credentials != fake_credentials


def test_downscope_credentials_user() -> None:
Expand All @@ -308,8 +311,10 @@ def test_downscope_credentials_user() -> None:
which mimics an authenticated user.
"""
creds = google.oauth2.credentials.Credentials("token", scopes=SCOPES)
# set all credentials to valid to skip refreshing credentials
with patch.object(Credentials, "valid", True):
# override actual refresh URI
with patch.object(
google.oauth2.credentials.Credentials, "refresh", lambda *args: None
):
credentials = _downscope_credentials(creds)
# verify default credential scopes have not been altered
assert creds.scopes == SCOPES
Expand Down

0 comments on commit 7f8a3a4

Please sign in to comment.