Skip to content

Commit

Permalink
refactor: use google.auth TokenState for credentials validity
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed May 24, 2024
1 parent abbe586 commit 7b6f042
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 17 deletions.
9 changes: 0 additions & 9 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING

import aiohttp
from google.auth.transport.requests import Request

from google.cloud.alloydb.connector.version import __version__ as version

Expand Down Expand Up @@ -116,10 +115,6 @@ async def _get_metadata(
"""
logger.debug(f"['{project}/{region}/{cluster}/{name}']: Requesting metadata")

if not self._credentials.valid:
request = Request()
self._credentials.refresh(request)

headers = {
"Authorization": f"Bearer {self._credentials.token}",
}
Expand Down Expand Up @@ -167,10 +162,6 @@ async def _get_client_certificate(
"""
logger.debug(f"['{project}/{region}/{cluster}']: Requesting client certificate")

if not self._credentials.valid:
request = Request()
self._credentials.refresh(request)

headers = {
"Authorization": f"Bearer {self._credentials.token}",
}
Expand Down
8 changes: 8 additions & 0 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import re
from typing import Tuple, TYPE_CHECKING

from google.auth.credentials import TokenState
from google.auth.transport import requests

from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
from google.cloud.alloydb.connector.exceptions import RefreshError
from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter
Expand Down Expand Up @@ -130,6 +133,11 @@ async def _perform_refresh(self) -> RefreshResult:
try:
await self._refresh_rate_limiter.acquire()
priv_key, pub_key = await self._keys

# before making AlloyDB API calls, refresh creds if required
if not self._client._credentials.token_state == TokenState.FRESH:
self._client._credentials.refresh(requests.Request())

# fetch metadata
metadata_task = asyncio.create_task(
self._client._get_metadata(
Expand Down
39 changes: 31 additions & 8 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import ipaddress
import ssl
import struct
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from google.auth.credentials import _helpers
from google.auth.credentials import TokenState

import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

Expand All @@ -35,7 +37,7 @@ def __init__(self) -> None:
self.token: Optional[str] = None
self.expiry: Optional[datetime] = None

def refresh(self, request: Callable) -> None:
def refresh(self, _: Callable) -> None:
"""Refreshes the access token."""
self.token = "12345"
self.expiry = datetime.now(timezone.utc) + timedelta(minutes=60)
Expand All @@ -51,13 +53,33 @@ def expired(self) -> bool:
return False if not self.expiry else True

@property
def valid(self) -> bool:
"""Checks the validity of the credentials.
This is True if the credentials have a token and the token
is not expired.
def token_state(
self,
) -> Literal[TokenState.FRESH, TokenState.STALE, TokenState.INVALID]:
"""
return self.token is not None and not self.expired
Tracks the state of a token.
FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry.
STALE: The token is close to expired, and should be refreshed. The token can be used normally.
INVALID: The token is expired or invalid. The token cannot be used for a normal operation.
"""
if self.token is None:
return TokenState.INVALID

# Credentials that can't expire are always treated as fresh.
if self.expiry is None:
return TokenState.FRESH

expired = datetime.now(timezone.utc) >= self.expiry
if expired:
return TokenState.INVALID

is_stale = datetime.now(timezone.utc) >= (
self.expiry - _helpers.REFRESH_THRESHOLD
)
if is_stale:
return TokenState.STALE

return TokenState.FRESH


def generate_cert(
Expand Down Expand Up @@ -180,6 +202,7 @@ def __init__(
self.instance = FakeInstance() if instance is None else instance
self.closed = False
self._user_agent = f"test-user-agent+{driver}"
self._credentials = FakeCredentials()

async def _get_metadata(self, *args: Any, **kwargs: Any) -> str:
return self.instance.ip_addrs
Expand Down

0 comments on commit 7b6f042

Please sign in to comment.