diff --git a/crates/polars-io/src/cloud/credential_provider.rs b/crates/polars-io/src/cloud/credential_provider.rs index 8926c4d0a835..ddd23568e026 100644 --- a/crates/polars-io/src/cloud/credential_provider.rs +++ b/crates/polars-io/src/cloud/credential_provider.rs @@ -364,6 +364,19 @@ impl FetchedCredentialsCache { update_func: impl Future>, ) -> PolarsResult { let verbose = config::verbose(); + + fn expiry_msg(last_fetched_expiry: u64, now: u64) -> String { + if last_fetched_expiry == u64::MAX { + "expiry = (never expires)".into() + } else { + format!( + "expiry = {} (in {} seconds)", + last_fetched_expiry, + last_fetched_expiry.saturating_sub(now) + ) + } + } + let mut inner = self.0.lock().await; let (last_fetched_credentials, last_fetched_expiry) = &mut *inner; @@ -379,8 +392,8 @@ impl FetchedCredentialsCache { if last_fetched_expiry.saturating_sub(current_time) < REQUEST_TIME_BUFFER { if verbose { eprintln!( - "[FetchedCredentialsCache]: Call update_func: current_time = {},\ - last_fetched_expiry = {}", + "[FetchedCredentialsCache]: Call update_func: current_time = {}\ + , last_fetched_expiry = {}", current_time, *last_fetched_expiry ) } @@ -402,17 +415,27 @@ impl FetchedCredentialsCache { if verbose { eprintln!( - "[FetchedCredentialsCache]: Finish update_func: \ - new expiry = {} (in {} seconds)", - *last_fetched_expiry, - last_fetched_expiry.saturating_sub( + "[FetchedCredentialsCache]: Finish update_func: new {}", + expiry_msg( + *last_fetched_expiry, SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() - ), + ) ) } + } else if verbose { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + eprintln!( + "[FetchedCredentialsCache]: Using cached credentials: \ + current_time = {}, {}", + now, + expiry_msg(*last_fetched_expiry, now) + ) } Ok(last_fetched_credentials.clone()) diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs index b6464b109535..5f971f9a350d 100644 --- a/crates/polars-io/src/cloud/object_store_setup.rs +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use object_store::local::LocalFileSystem; use object_store::ObjectStore; use once_cell::sync::Lazy; +use polars_core::config; use polars_error::{polars_bail, to_compute_err, PolarsError, PolarsResult}; use polars_utils::aliases::PlHashMap; use tokio::sync::RwLock; @@ -58,6 +59,8 @@ pub async fn build_object_store( let parsed = parse_url(url).map_err(to_compute_err)?; let cloud_location = CloudLocation::from_url(&parsed, glob)?; + // FIXME: `credential_provider` is currently serializing the entire Python function here + // into a string with pickle for this cache key because we are using `serde_json::to_string` let key = url_and_creds_to_key(&parsed, options); let mut allow_cache = true; @@ -124,6 +127,12 @@ pub async fn build_object_store( let mut cache = OBJECT_STORE_CACHE.write().await; // Clear the cache if we surpass a certain amount of buckets. if cache.len() > 8 { + if config::verbose() { + eprintln!( + "build_object_store: clearing store cache (cache.len(): {})", + cache.len() + ); + } cache.clear() } cache.insert(key, store.clone()); diff --git a/py-polars/docs/source/reference/io.rst b/py-polars/docs/source/reference/io.rst index 1f088958a3c0..bdfe93ee9ebe 100644 --- a/py-polars/docs/source/reference/io.rst +++ b/py-polars/docs/source/reference/io.rst @@ -117,3 +117,14 @@ Connect to pyarrow datasets. :toctree: api/ scan_pyarrow_dataset + +Cloud Credentials +~~~~~~~~~~~~~~~~~ +Configuration for cloud credential provisioning. + +.. autosummary:: + :toctree: api/ + + CredentialProvider + CredentialProviderAWS + CredentialProviderGCP diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 10f0ee54228b..063f84c91126 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -176,6 +176,13 @@ scan_parquet, scan_pyarrow_dataset, ) +from polars.io.cloud import ( + CredentialProvider, + CredentialProviderAWS, + CredentialProviderFunction, + CredentialProviderFunctionReturn, + CredentialProviderGCP, +) from polars.lazyframe import GPUEngine, LazyFrame from polars.meta import ( build_info, @@ -266,6 +273,12 @@ "scan_ndjson", "scan_parquet", "scan_pyarrow_dataset", + # polars.io.cloud + "CredentialProvider", + "CredentialProviderAWS", + "CredentialProviderFunction", + "CredentialProviderFunctionReturn", + "CredentialProviderGCP", # polars.stringcache "StringCache", "disable_string_cache", diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index b53d67ee8b46..894ab9e79346 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -1,12 +1,12 @@ from __future__ import annotations from collections.abc import Collection, Iterable, Mapping, Sequence +from pathlib import Path from typing import ( + IO, TYPE_CHECKING, Any, - Callable, Literal, - Optional, Protocol, TypedDict, TypeVar, @@ -297,6 +297,6 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: # LazyFrame engine selection EngineType: TypeAlias = Union[Literal["cpu", "gpu"], "GPUEngine"] -CredentialProviderFunction: TypeAlias = Callable[ - [], tuple[dict[str, Optional[str]], Optional[int]] +ScanSource: TypeAlias = Union[ + str, Path, IO[bytes], bytes, list[str], list[Path], list[IO[bytes]], list[bytes] ] diff --git a/py-polars/polars/io/_utils.py b/py-polars/polars/io/_utils.py index 68d4b604d6a6..527f1da01240 100644 --- a/py-polars/polars/io/_utils.py +++ b/py-polars/polars/io/_utils.py @@ -7,7 +7,11 @@ from pathlib import Path from typing import IO, TYPE_CHECKING, Any, overload -from polars._utils.various import is_int_sequence, is_str_sequence, normalize_filepath +from polars._utils.various import ( + is_int_sequence, + is_str_sequence, + normalize_filepath, +) from polars.dependencies import _FSSPEC_AVAILABLE, fsspec from polars.exceptions import NoDataError diff --git a/py-polars/polars/io/cloud/__init__.py b/py-polars/polars/io/cloud/__init__.py new file mode 100644 index 000000000000..f5ef9c5fd0bf --- /dev/null +++ b/py-polars/polars/io/cloud/__init__.py @@ -0,0 +1,15 @@ +from polars.io.cloud.credential_provider import ( + CredentialProvider, + CredentialProviderAWS, + CredentialProviderFunction, + CredentialProviderFunctionReturn, + CredentialProviderGCP, +) + +__all__ = [ + "CredentialProvider", + "CredentialProviderAWS", + "CredentialProviderFunction", + "CredentialProviderFunctionReturn", + "CredentialProviderGCP", +] diff --git a/py-polars/polars/io/cloud/_utils.py b/py-polars/polars/io/cloud/_utils.py new file mode 100644 index 000000000000..6b7e69c7d12f --- /dev/null +++ b/py-polars/polars/io/cloud/_utils.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +from polars._utils.various import is_path_or_str_sequence + +if TYPE_CHECKING: + from polars._typing import ScanSource + + +def _first_scan_path( + source: ScanSource, +) -> str | Path | None: + if isinstance(source, (str, Path)): + return source + elif is_path_or_str_sequence(source) and source: + return source[0] + + return None + + +def _infer_cloud_type( + source: ScanSource, +) -> Literal["aws", "azure", "gcp", "file", "http", "hf"] | None: + if (path := _first_scan_path(source)) is None: + return None + + splitted = str(path).split("://", maxsplit=1) + + # Fast path - local file + if not splitted: + return None + + scheme = splitted[0] + + if scheme == "file": + return "file" + + if any(scheme == x for x in ["s3", "s3a"]): + return "aws" + + if any(scheme == x for x in ["az", "azure", "adl", "abfs", "abfss"]): + return "azure" + + if any(scheme == x for x in ["gs", "gcp", "gcs"]): + return "gcp" + + if any(scheme == x for x in ["http", "https"]): + return "http" + + if scheme == "hf": + return "hf" + + return None diff --git a/py-polars/polars/io/cloud/credential_provider.py b/py-polars/polars/io/cloud/credential_provider.py new file mode 100644 index 000000000000..4bdde4b3b18a --- /dev/null +++ b/py-polars/polars/io/cloud/credential_provider.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import abc +import importlib.util +import os +import sys +import zoneinfo +from typing import TYPE_CHECKING, Callable, Optional, Union + +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + +from polars._utils.unstable import issue_unstable_warning + +if TYPE_CHECKING: + from polars._typing import ScanSource + + +# These typedefs are here to avoid circular import issues, as +# `CredentialProviderFunction` specifies "CredentialProvider" +CredentialProviderFunctionReturn: TypeAlias = tuple[ + dict[str, Optional[str]], Optional[int] +] + +CredentialProviderFunction: TypeAlias = Union[ + Callable[[], CredentialProviderFunctionReturn], "CredentialProvider" +] + + +class CredentialProvider(abc.ABC): + """ + Base class for credential providers. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + @abc.abstractmethod + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetches the credentials.""" + + +class CredentialProviderAWS(CredentialProvider): + """ + AWS Credential Provider. + + Using this requires the `boto3` Python package to be installed. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + def __init__(self, *, profile_name: str | None = None) -> None: + """ + Initialize a credential provider for AWS. + + Parameters + ---------- + profile_name : str + Profile name to use from credentials file. + """ + msg = "`CredentialProviderAWS` functionality is considered unstable" + issue_unstable_warning(msg) + + self._check_module_availability() + self.profile_name = profile_name + + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetch the credentials for the configured profile name.""" + import boto3 + + session = boto3.Session(profile_name=self.profile_name) + creds = session.get_credentials() + + if creds is None: + msg = "unexpected None value returned from boto3.Session.get_credentials()" + raise ValueError(msg) + + return { + "aws_access_key_id": creds.access_key, + "aws_secret_access_key": creds.secret_key, + "aws_session_token": creds.token, + }, None + + @classmethod + def _check_module_availability(cls) -> None: + if importlib.util.find_spec("boto3") is None: + msg = "boto3 must be installed to use `CredentialProviderAWS`" + raise ImportError(msg) + + +class CredentialProviderGCP(CredentialProvider): + """ + GCP Credential Provider. + + Using this requires the `google-auth` Python package to be installed. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + + def __init__(self) -> None: + """Initialize a credential provider for Google Cloud (GCP).""" + msg = "`CredentialProviderAWS` functionality is considered unstable" + issue_unstable_warning(msg) + + self._check_module_availability() + + import google.auth + import google.auth.credentials + + # CI runs with both `mypy` and `mypy --allow-untyped-calls` depending on + # Python version. If we add a `type: ignore[no-untyped-call]`, then the + # check that runs with `--allow-untyped-calls` will complain about an + # unused "type: ignore" comment. And if we don't add the ignore, then + # he check that runs `mypy` will complain. + # + # So we just bypass it with a __dict__[] (because ruff complains about + # getattr) :| + creds, _ = google.auth.__dict__["default"]() + self.creds = creds + + def __call__(self) -> CredentialProviderFunctionReturn: + """Fetch the credentials for the configured profile name.""" + import google.auth.transport.requests + + self.creds.refresh(google.auth.transport.requests.__dict__["Request"]()) + + return {"bearer_token": self.creds.token}, ( + int( + expiry.replace( + # Google auth does not set this properly + tzinfo=zoneinfo.ZoneInfo("UTC") + ).timestamp() + ) + if (expiry := self.creds.expiry) is not None + else None + ) + + @classmethod + def _check_module_availability(cls) -> None: + if importlib.util.find_spec("google.auth") is None: + msg = "google-auth must be installed to use `CredentialProviderGCP`" + raise ImportError(msg) + + +def _auto_select_credential_provider( + source: ScanSource, +) -> CredentialProvider | None: + from polars.io.cloud._utils import _infer_cloud_type + + verbose = os.getenv("POLARS_VERBOSE") == "1" + cloud_type = _infer_cloud_type(source) + + provider = None + + try: + provider = ( + None + if cloud_type is None + else CredentialProviderAWS() + if cloud_type == "aws" + else CredentialProviderGCP() + if cloud_type == "gcp" + else None + ) + except ImportError as e: + if verbose: + msg = f"Unable to auto-select credential provider: {e}" + print(msg, file=sys.stderr) + + if provider is not None and verbose: + msg = f"Auto-selected credential provider: {type(provider).__name__}" + print(msg, file=sys.stderr) + + return provider diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py index 80f5e17e9849..1fd61aa9bac0 100644 --- a/py-polars/polars/io/parquet/functions.py +++ b/py-polars/polars/io/parquet/functions.py @@ -21,27 +21,24 @@ parse_row_index_args, prepare_file_arg, ) +from polars.io.cloud.credential_provider import _auto_select_credential_provider with contextlib.suppress(ImportError): from polars.polars import PyLazyFrame from polars.polars import read_parquet_schema as _read_parquet_schema if TYPE_CHECKING: + from typing import Literal + from polars import DataFrame, DataType, LazyFrame - from polars._typing import CredentialProviderFunction, ParallelStrategy, SchemaDict + from polars._typing import ParallelStrategy, ScanSource, SchemaDict + from polars.io.cloud import CredentialProviderFunction @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def read_parquet( - source: str - | Path - | IO[bytes] - | bytes - | list[str] - | list[Path] - | list[IO[bytes]] - | list[bytes], + source: ScanSource, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, @@ -203,6 +200,7 @@ def read_parquet( rechunk=rechunk, ) + # TODO: FIXME: Move this to `scan_parquet` # Read file and bytes inputs using `read_parquet` if isinstance(source, bytes): source = io.BytesIO(source) @@ -212,7 +210,7 @@ def read_parquet( # For other inputs, defer to `scan_parquet` lf = scan_parquet( - source, # type: ignore[arg-type] + source, n_rows=n_rows, row_index_name=row_index_name, row_index_offset=row_index_offset, @@ -322,7 +320,7 @@ def read_parquet_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, Dat @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def scan_parquet( - source: str | Path | IO[bytes] | list[str] | list[Path] | list[IO[bytes]], + source: ScanSource, *, n_rows: int | None = None, row_index_name: str | None = None, @@ -338,7 +336,7 @@ def scan_parquet( low_memory: bool = False, cache: bool = True, storage_options: dict[str, Any] | None = None, - credential_provider: CredentialProviderFunction | None = None, + credential_provider: CredentialProviderFunction | Literal["auto"] | None = None, retries: int = 2, include_file_paths: str | None = None, allow_missing_columns: bool = False, @@ -476,10 +474,6 @@ def scan_parquet( msg = "The `hive_schema` parameter of `scan_parquet` is considered unstable." issue_unstable_warning(msg) - if credential_provider is not None: - msg = "The `credential_provider` parameter of `scan_parquet` is considered unstable." - issue_unstable_warning(msg) - if isinstance(source, (str, Path)): source = normalize_filepath(source, check_not_directory=False) elif is_path_or_str_sequence(source): @@ -487,6 +481,17 @@ def scan_parquet( normalize_filepath(source, check_not_directory=False) for source in source ] + if credential_provider is not None: + msg = "The `credential_provider` parameter of `scan_parquet` is considered unstable." + issue_unstable_warning(msg) + + if credential_provider == "auto": + credential_provider = ( + _auto_select_credential_provider(source) + if storage_options is None + else None + ) + return _scan_parquet_impl( source, # type: ignore[arg-type] n_rows=n_rows, diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 9d0a2df292cb..e89a8e19c0b6 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -74,3 +74,4 @@ flask-cors # Stub files pandas-stubs boto3-stubs +google-auth-stubs diff --git a/py-polars/tests/unit/io/cloud/test_cloud.py b/py-polars/tests/unit/io/cloud/test_cloud.py index f943ab5e2c26..322d13c44a01 100644 --- a/py-polars/tests/unit/io/cloud/test_cloud.py +++ b/py-polars/tests/unit/io/cloud/test_cloud.py @@ -23,3 +23,32 @@ def test_scan_nonexistent_cloud_path_17444(format: str) -> None: # Upon collection, it should fail with pytest.raises(ComputeError): result.collect() + + +def test_scan_credential_provider(monkeypatch: pytest.MonkeyPatch) -> None: + err_magic = "err_magic_3" + + def raises(*_: None, **__: None) -> None: + raise AssertionError(err_magic) + + monkeypatch.setattr(pl.CredentialProviderAWS, "__init__", raises) + + with pytest.raises(AssertionError, match=err_magic): + pl.scan_parquet("s3://bucket/path", credential_provider="auto") + + # Passing `None` should disable the automatic instantiation of + # `CredentialProviderAWS` + pl.scan_parquet("s3://bucket/path", credential_provider=None) + # Passing `storage_options` should disable the automatic instantiation of + # `CredentialProviderAWS` + pl.scan_parquet("s3://bucket/path", credential_provider="auto", storage_options={}) + + err_magic = "err_magic_7" + + def raises_2() -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + # Note to reader: It is converted to a ComputeError as it is being called + # from Rust. + with pytest.raises(ComputeError, match=err_magic): + pl.scan_parquet("s3://bucket/path", credential_provider=raises_2).collect()