Skip to content

Commit

Permalink
feat(python): AssumeRole support for AWS Credential Provider (#19346)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Oct 22, 2024
1 parent 791c336 commit e8cfa44
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 40 deletions.
35 changes: 7 additions & 28 deletions py-polars/polars/io/cloud/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

from polars._utils.various import is_path_or_str_sequence

Expand All @@ -20,36 +20,15 @@ def _first_scan_path(
return None


def _infer_cloud_type(
source: ScanSource,
) -> Literal["aws", "azure", "gcp", "file", "http", "hf"] | None:
if (path := _first_scan_path(source)) is None:
return None

def _get_path_scheme(path: str | Path) -> str | None:
splitted = str(path).split("://", maxsplit=1)

# Fast path - local file
if not splitted:
return None

scheme = splitted[0]

if scheme == "file":
return "file"
return None if not splitted else splitted[0]

if any(scheme == x for x in ["s3", "s3a"]):
return "aws"

if any(scheme == x for x in ["az", "azure", "adl", "abfs", "abfss"]):
return "azure"
def _is_aws_cloud(scheme: str) -> bool:
return any(scheme == x for x in ["s3", "s3a"])

if any(scheme == x for x in ["gs", "gcp", "gcs"]):
return "gcp"

if any(scheme == x for x in ["http", "https"]):
return "http"

if scheme == "hf":
return "hf"

return None
def _is_gcp_cloud(scheme: str) -> bool:
return any(scheme == x for x in ["gs", "gcp", "gcs"])
81 changes: 69 additions & 12 deletions py-polars/polars/io/cloud/credential_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import sys
import zoneinfo
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union

if TYPE_CHECKING:
if sys.version_info >= (3, 10):
Expand All @@ -30,6 +30,23 @@
]


class AWSAssumeRoleKWArgs(TypedDict):
"""Parameters for [STS.Client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role)."""

RoleArn: str
RoleSessionName: str
PolicyArns: list[dict[str, str]]
Policy: str
DurationSeconds: int
Tags: list[dict[str, str]]
TransitiveTagKeys: list[str]
ExternalId: str
SerialNumber: str
TokenCode: str
SourceIdentity: str
ProvidedContexts: list[dict[str, str]]


class CredentialProvider(abc.ABC):
"""
Base class for credential providers.
Expand All @@ -55,26 +72,39 @@ class CredentialProviderAWS(CredentialProvider):
at any point without it being considered a breaking change.
"""

def __init__(self, *, profile_name: str | None = None) -> None:
def __init__(
self,
*,
profile_name: str | None = None,
assume_role: AWSAssumeRoleKWArgs | None = None,
) -> None:
"""
Initialize a credential provider for AWS.
Parameters
----------
profile_name : str
Profile name to use from credentials file.
assume_role : AWSAssumeRoleKWArgs | None
Configure a role to assume. These are passed as kwarg parameters to
[STS.client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role)
"""
msg = "`CredentialProviderAWS` functionality is considered unstable"
issue_unstable_warning(msg)

self._check_module_availability()
self.profile_name = profile_name
self.assume_role = assume_role

def __call__(self) -> CredentialProviderFunctionReturn:
"""Fetch the credentials for the configured profile name."""
import boto3

session = boto3.Session(profile_name=self.profile_name)

if self.assume_role is not None:
return self._finish_assume_role(session)

creds = session.get_credentials()

if creds is None:
Expand All @@ -87,6 +117,24 @@ def __call__(self) -> CredentialProviderFunctionReturn:
"aws_session_token": creds.token,
}, None

def _finish_assume_role(self, session: Any) -> CredentialProviderFunctionReturn:
client = session.client("sts")

sts_response = client.assume_role(**self.assume_role)
creds = sts_response["Credentials"]

expiry = creds["Expiration"]

if expiry.tzinfo is None:
msg = "expiration time in STS response did not contain timezone information"
raise ValueError(msg)

return {
"aws_access_key_id": creds["AccessKeyId"],
"aws_secret_access_key": creds["SecretAccessKey"],
"aws_session_token": creds["SessionToken"],
}, int(expiry.timestamp())

@classmethod
def _check_module_availability(cls) -> None:
if importlib.util.find_spec("boto3") is None:
Expand Down Expand Up @@ -134,9 +182,10 @@ def __call__(self) -> CredentialProviderFunctionReturn:

return {"bearer_token": self.creds.token}, (
int(
expiry.replace(
# Google auth does not set this properly
tzinfo=zoneinfo.ZoneInfo("UTC")
(
expiry.replace(tzinfo=zoneinfo.ZoneInfo("UTC"))
if expiry.tzinfo is None
else expiry
).timestamp()
)
if (expiry := self.creds.expiry) is not None
Expand All @@ -153,21 +202,29 @@ def _check_module_availability(cls) -> None:
def _auto_select_credential_provider(
source: ScanSource,
) -> CredentialProvider | None:
from polars.io.cloud._utils import _infer_cloud_type
from polars.io.cloud._utils import (
_first_scan_path,
_get_path_scheme,
_is_aws_cloud,
_is_gcp_cloud,
)

verbose = os.getenv("POLARS_VERBOSE") == "1"
cloud_type = _infer_cloud_type(source)

if (path := _first_scan_path(source)) is None:
return None

if (scheme := _get_path_scheme(path)) is None:
return None

provider = None

try:
provider = (
None
if cloud_type is None
else CredentialProviderAWS()
if cloud_type == "aws"
CredentialProviderAWS()
if _is_aws_cloud(scheme)
else CredentialProviderGCP()
if cloud_type == "gcp"
if _is_gcp_cloud(scheme)
else None
)
except ImportError as e:
Expand Down

0 comments on commit e8cfa44

Please sign in to comment.