diff --git a/py-polars/polars/io/cloud/_utils.py b/py-polars/polars/io/cloud/_utils.py
index 6b7e69c7d12f..c057916c62d8 100644
--- a/py-polars/polars/io/cloud/_utils.py
+++ b/py-polars/polars/io/cloud/_utils.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING
 
 from polars._utils.various import is_path_or_str_sequence
 
@@ -20,36 +20,19 @@ def _first_scan_path(
     return None
 
 
-def _infer_cloud_type(
-    source: ScanSource,
-) -> Literal["aws", "azure", "gcp", "file", "http", "hf"] | None:
-    if (path := _first_scan_path(source)) is None:
-        return None
-
+def _get_path_scheme(path: str | Path) -> str | None:
+    """Return the text before "://" in `path`, or None if `path` has no scheme."""
     splitted = str(path).split("://", maxsplit=1)
 
-    # Fast path - local file
-    if not splitted:
-        return None
-
-    scheme = splitted[0]
-
-    if scheme == "file":
-        return "file"
+    # `str.split` never returns an empty list, so check that "://" was found.
+    return splitted[0] if len(splitted) == 2 else None
 
-    if any(scheme == x for x in ["s3", "s3a"]):
-        return "aws"
 
-    if any(scheme == x for x in ["az", "azure", "adl", "abfs", "abfss"]):
-        return "azure"
+def _is_aws_cloud(scheme: str) -> bool:
+    """Return True if `scheme` is "s3" or "s3a"."""
+    return any(scheme == x for x in ["s3", "s3a"])
 
-    if any(scheme == x for x in ["gs", "gcp", "gcs"]):
-        return "gcp"
 
-    if any(scheme == x for x in ["http", "https"]):
-        return "http"
-
-    if scheme == "hf":
-        return "hf"
-
-    return None
+def _is_gcp_cloud(scheme: str) -> bool:
+    """Return True if `scheme` is "gs", "gcp" or "gcs"."""
+    return any(scheme == x for x in ["gs", "gcp", "gcs"])
diff --git a/py-polars/polars/io/cloud/credential_provider.py b/py-polars/polars/io/cloud/credential_provider.py
index 4bdde4b3b18a..1dd79662b642 100644
--- a/py-polars/polars/io/cloud/credential_provider.py
+++ b/py-polars/polars/io/cloud/credential_provider.py
@@ -5,7 +5,7 @@
 import os
 import sys
 import zoneinfo
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union
 
 if TYPE_CHECKING:
     if sys.version_info >= (3, 10):
@@ -30,6 +30,23 @@
 ]
 
 
+class AWSAssumeRoleKWArgs(TypedDict):
+    """Parameters for [STS.Client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role)."""
+
+    RoleArn: str
+    RoleSessionName: str
+    PolicyArns: list[dict[str, str]]
+    Policy: str
+    DurationSeconds: int
+    Tags: list[dict[str, str]]
+    TransitiveTagKeys: list[str]
+    ExternalId: str
+    SerialNumber: str
+    TokenCode: str
+    SourceIdentity: str
+    ProvidedContexts: list[dict[str, str]]
+
+
 class CredentialProvider(abc.ABC):
     """
     Base class for credential providers.
@@ -55,7 +72,12 @@ class CredentialProviderAWS(CredentialProvider):
     at any point without it being considered a breaking change.
     """
 
-    def __init__(self, *, profile_name: str | None = None) -> None:
+    def __init__(
+        self,
+        *,
+        profile_name: str | None = None,
+        assume_role: AWSAssumeRoleKWArgs | None = None,
+    ) -> None:
         """
         Initialize a credential provider for AWS.
 
@@ -63,18 +85,26 @@ def __init__(self, *, profile_name: str | None = None) -> None:
         ----------
         profile_name : str
             Profile name to use from credentials file.
+        assume_role : AWSAssumeRoleKWArgs | None
+            Configure a role to assume. These are passed as kwarg parameters to
+            [STS.Client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role).
         """
         msg = "`CredentialProviderAWS` functionality is considered unstable"
         issue_unstable_warning(msg)
 
         self._check_module_availability()
         self.profile_name = profile_name
+        self.assume_role = assume_role
 
     def __call__(self) -> CredentialProviderFunctionReturn:
         """Fetch the credentials for the configured profile name."""
         import boto3
 
         session = boto3.Session(profile_name=self.profile_name)
+
+        if self.assume_role is not None:
+            return self._finish_assume_role(session)
+
         creds = session.get_credentials()
 
         if creds is None:
@@ -87,6 +117,25 @@ def __call__(self) -> CredentialProviderFunctionReturn:
             "aws_session_token": creds.token,
         }, None
 
+    def _finish_assume_role(self, session: Any) -> CredentialProviderFunctionReturn:
+        """Call STS `assume_role`; return the credentials and expiry timestamp."""
+        client = session.client("sts")
+
+        sts_response = client.assume_role(**self.assume_role)
+        creds = sts_response["Credentials"]
+
+        expiry = creds["Expiration"]
+
+        if expiry.tzinfo is None:
+            msg = "expiration time in STS response did not contain timezone information"
+            raise ValueError(msg)
+
+        return {
+            "aws_access_key_id": creds["AccessKeyId"],
+            "aws_secret_access_key": creds["SecretAccessKey"],
+            "aws_session_token": creds["SessionToken"],
+        }, int(expiry.timestamp())
+
     @classmethod
     def _check_module_availability(cls) -> None:
         if importlib.util.find_spec("boto3") is None:
@@ -134,9 +183,10 @@ def __call__(self) -> CredentialProviderFunctionReturn:
 
         return {"bearer_token": self.creds.token}, (
             int(
-                expiry.replace(
-                    # Google auth does not set this properly
-                    tzinfo=zoneinfo.ZoneInfo("UTC")
+                (
+                    expiry.replace(tzinfo=zoneinfo.ZoneInfo("UTC"))
+                    if expiry.tzinfo is None
+                    else expiry
                 ).timestamp()
             )
             if (expiry := self.creds.expiry) is not None
@@ -153,21 +203,29 @@ def _check_module_availability(cls) -> None:
 def _auto_select_credential_provider(
     source: ScanSource,
 ) -> CredentialProvider | None:
-    from polars.io.cloud._utils import _infer_cloud_type
+    from polars.io.cloud._utils import (
+        _first_scan_path,
+        _get_path_scheme,
+        _is_aws_cloud,
+        _is_gcp_cloud,
+    )
 
     verbose = os.getenv("POLARS_VERBOSE") == "1"
-    cloud_type = _infer_cloud_type(source)
+
+    if (path := _first_scan_path(source)) is None:
+        return None
+
+    if (scheme := _get_path_scheme(path)) is None:
+        return None
 
     provider = None
 
     try:
         provider = (
-            None
-            if cloud_type is None
-            else CredentialProviderAWS()
-            if cloud_type == "aws"
+            CredentialProviderAWS()
+            if _is_aws_cloud(scheme)
             else CredentialProviderGCP()
-            if cloud_type == "gcp"
+            if _is_gcp_cloud(scheme)
             else None
         )
     except ImportError as e: