Skip to content

Commit

Permalink
feat(python): Add credential_provider argument to more read functio…
Browse files Browse the repository at this point in the history
…ns (#19421)
  • Loading branch information
nameexhaustion authored Oct 25, 2024
1 parent 655dd0f commit 425e251
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 44 deletions.
37 changes: 31 additions & 6 deletions crates/polars-python/src/lazyframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl PyLazyFrame {
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (
source, sources, infer_schema_length, schema, schema_overrides, batch_size, n_rows, low_memory, rechunk,
row_index, ignore_errors, include_file_paths, cloud_options, retries, file_cache_ttl
row_index, ignore_errors, include_file_paths, cloud_options, credential_provider, retries, file_cache_ttl
))]
fn new_from_ndjson(
source: Option<PyObject>,
Expand All @@ -57,9 +57,11 @@ impl PyLazyFrame {
ignore_errors: bool,
include_file_paths: Option<String>,
cloud_options: Option<Vec<(String, String)>>,
credential_provider: Option<PyObject>,
retries: usize,
file_cache_ttl: Option<u64>,
) -> PyResult<Self> {
use cloud::credential_provider::PlCredentialProvider;
let row_index = row_index.map(|(name, offset)| RowIndex {
name: name.into(),
offset,
Expand All @@ -79,7 +81,11 @@ impl PyLazyFrame {

let mut cloud_options =
parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?;
cloud_options = cloud_options.with_max_retries(retries);
cloud_options = cloud_options
.with_max_retries(retries)
.with_credential_provider(
credential_provider.map(PlCredentialProvider::from_python_func_object),
);

if let Some(file_cache_ttl) = file_cache_ttl {
cloud_options.file_cache_ttl = file_cache_ttl;
Expand Down Expand Up @@ -111,7 +117,7 @@ impl PyLazyFrame {
low_memory, comment_prefix, quote_char, null_values, missing_utf8_is_empty_string,
infer_schema_length, with_schema_modify, rechunk, skip_rows_after_header,
encoding, row_index, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma, glob, schema,
cloud_options, retries, file_cache_ttl, include_file_paths
cloud_options, credential_provider, retries, file_cache_ttl, include_file_paths
)
)]
fn new_from_csv(
Expand Down Expand Up @@ -143,10 +149,13 @@ impl PyLazyFrame {
glob: bool,
schema: Option<Wrap<Schema>>,
cloud_options: Option<Vec<(String, String)>>,
credential_provider: Option<PyObject>,
retries: usize,
file_cache_ttl: Option<u64>,
include_file_paths: Option<String>,
) -> PyResult<Self> {
use cloud::credential_provider::PlCredentialProvider;

let null_values = null_values.map(|w| w.0);
let quote_char = quote_char
.map(|s| {
Expand Down Expand Up @@ -198,7 +207,11 @@ impl PyLazyFrame {
if let Some(file_cache_ttl) = file_cache_ttl {
cloud_options.file_cache_ttl = file_cache_ttl;
}
cloud_options = cloud_options.with_max_retries(retries);
cloud_options = cloud_options
.with_max_retries(retries)
.with_credential_provider(
credential_provider.map(PlCredentialProvider::from_python_func_object),
);
r = r.with_cloud_options(Some(cloud_options));
}

Expand Down Expand Up @@ -343,7 +356,11 @@ impl PyLazyFrame {

#[cfg(feature = "ipc")]
#[staticmethod]
#[pyo3(signature = (source, sources, n_rows, cache, rechunk, row_index, cloud_options, hive_partitioning, hive_schema, try_parse_hive_dates, retries, file_cache_ttl, include_file_paths))]
#[pyo3(signature = (
source, sources, n_rows, cache, rechunk, row_index, cloud_options,credential_provider,
hive_partitioning, hive_schema, try_parse_hive_dates, retries, file_cache_ttl,
include_file_paths
))]
fn new_from_ipc(
source: Option<PyObject>,
sources: Wrap<ScanSources>,
Expand All @@ -352,13 +369,15 @@ impl PyLazyFrame {
rechunk: bool,
row_index: Option<(String, IdxSize)>,
cloud_options: Option<Vec<(String, String)>>,
credential_provider: Option<PyObject>,
hive_partitioning: Option<bool>,
hive_schema: Option<Wrap<Schema>>,
try_parse_hive_dates: bool,
retries: usize,
file_cache_ttl: Option<u64>,
include_file_paths: Option<String>,
) -> PyResult<Self> {
use cloud::credential_provider::PlCredentialProvider;
let row_index = row_index.map(|(name, offset)| RowIndex {
name: name.into(),
offset,
Expand Down Expand Up @@ -397,7 +416,13 @@ impl PyLazyFrame {
if let Some(file_cache_ttl) = file_cache_ttl {
cloud_options.file_cache_ttl = file_cache_ttl;
}
args.cloud_options = Some(cloud_options.with_max_retries(retries));
args.cloud_options = Some(
cloud_options
.with_max_retries(retries)
.with_credential_provider(
credential_provider.map(PlCredentialProvider::from_python_func_object),
),
);
}

let lf = LazyFrame::scan_ipc_sources(sources, args).map_err(PyPolarsErr::from)?;
Expand Down
9 changes: 8 additions & 1 deletion py-polars/polars/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,5 +298,12 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
EngineType: TypeAlias = Union[Literal["cpu", "gpu"], "GPUEngine"]

ScanSource: TypeAlias = Union[
str, Path, IO[bytes], bytes, list[str], list[Path], list[IO[bytes]], list[bytes]
str,
Path,
IO[bytes],
bytes,
list[str],
list[Path],
list[IO[bytes]],
list[bytes],
]
16 changes: 11 additions & 5 deletions py-polars/polars/io/cloud/_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING
from typing import IO

from polars._utils.various import is_path_or_str_sequence

if TYPE_CHECKING:
from polars._typing import ScanSource


def _first_scan_path(
source: ScanSource,
source: str
| Path
| IO[str]
| IO[bytes]
| bytes
| list[str]
| list[Path]
| list[IO[str]]
| list[IO[bytes]]
| list[bytes],
) -> str | Path | None:
if isinstance(source, (str, Path)):
return source
Expand Down
35 changes: 27 additions & 8 deletions py-polars/polars/io/cloud/credential_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,17 @@
import os
import sys
import zoneinfo
from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union
from typing import IO, TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict, Union

if TYPE_CHECKING:
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from pathlib import Path

from polars._utils.unstable import issue_unstable_warning

if TYPE_CHECKING:
from polars._typing import ScanSource


# These typedefs are here to avoid circular import issues, as
# `CredentialProviderFunction` specifies "CredentialProvider"
CredentialProviderFunctionReturn: TypeAlias = tuple[
Expand Down Expand Up @@ -199,16 +196,38 @@ def _check_module_availability(cls) -> None:
raise ImportError(msg)


def _auto_select_credential_provider(
source: ScanSource,
) -> CredentialProvider | None:
def _maybe_init_credential_provider(
credential_provider: CredentialProviderFunction | Literal["auto"] | None,
source: str
| Path
| IO[str]
| IO[bytes]
| bytes
| list[str]
| list[Path]
| list[IO[str]]
| list[IO[bytes]]
| list[bytes],
storage_options: dict[str, Any] | None,
caller_name: str,
) -> CredentialProviderFunction | CredentialProvider | None:
from polars.io.cloud._utils import (
_first_scan_path,
_get_path_scheme,
_is_aws_cloud,
_is_gcp_cloud,
)

if credential_provider is not None:
msg = f"The `credential_provider` parameter of `{caller_name}` is considered unstable."
issue_unstable_warning(msg)

if credential_provider != "auto":
return credential_provider

if storage_options is not None:
return None

verbose = os.getenv("POLARS_VERBOSE") == "1"

if (path := _first_scan_path(source)) is None:
Expand Down
20 changes: 19 additions & 1 deletion py-polars/polars/io/csv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Sequence
from io import BytesIO, StringIO
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable
from typing import IO, TYPE_CHECKING, Any, Callable, Literal

import polars._reexport as pl
import polars.functions as F
Expand All @@ -24,6 +24,7 @@
parse_row_index_args,
prepare_file_arg,
)
from polars.io.cloud.credential_provider import _maybe_init_credential_provider
from polars.io.csv._utils import _check_arg_is_1byte, _update_columns
from polars.io.csv.batched_reader import BatchedCsvReader

Expand All @@ -35,6 +36,7 @@

from polars import DataFrame, LazyFrame
from polars._typing import CsvEncoding, PolarsDataType, SchemaDict
from polars.io.cloud import CredentialProviderFunction


@deprecate_renamed_parameter("dtypes", "schema_overrides", version="0.20.31")
Expand Down Expand Up @@ -1034,6 +1036,7 @@ def scan_csv(
decimal_comma: bool = False,
glob: bool = True,
storage_options: dict[str, Any] | None = None,
credential_provider: CredentialProviderFunction | Literal["auto"] | None = None,
retries: int = 2,
file_cache_ttl: int | None = None,
include_file_paths: str | None = None,
Expand Down Expand Up @@ -1154,6 +1157,14 @@ def scan_csv(
If `storage_options` is not provided, Polars will try to infer the information
from environment variables.
credential_provider
Provide a function that can be called to provide cloud storage
credentials. The function is expected to return a dictionary of
credential keys along with an optional credential expiry time.
.. warning::
This functionality is considered **unstable**. It may be changed
at any point without it being considered a breaking change.
retries
Number of retries if accessing a cloud instance fails.
file_cache_ttl
Expand Down Expand Up @@ -1259,6 +1270,10 @@ def with_column_names(cols: list[str]) -> list[str]:
if not infer_schema:
infer_schema_length = 0

credential_provider = _maybe_init_credential_provider(
credential_provider, source, storage_options, "scan_csv"
)

return _scan_csv_impl(
source,
has_header=has_header,
Expand Down Expand Up @@ -1289,6 +1304,7 @@ def with_column_names(cols: list[str]) -> list[str]:
glob=glob,
retries=retries,
storage_options=storage_options,
credential_provider=credential_provider,
file_cache_ttl=file_cache_ttl,
include_file_paths=include_file_paths,
)
Expand Down Expand Up @@ -1332,6 +1348,7 @@ def _scan_csv_impl(
decimal_comma: bool = False,
glob: bool = True,
storage_options: dict[str, Any] | None = None,
credential_provider: CredentialProviderFunction | None = None,
retries: int = 2,
file_cache_ttl: int | None = None,
include_file_paths: str | None = None,
Expand Down Expand Up @@ -1384,6 +1401,7 @@ def _scan_csv_impl(
glob=glob,
schema=schema,
cloud_options=storage_options,
credential_provider=credential_provider,
retries=retries,
file_cache_ttl=file_cache_ttl,
include_file_paths=include_file_paths,
Expand Down
25 changes: 24 additions & 1 deletion py-polars/polars/io/ipc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import contextlib
import os
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any
from typing import IO, TYPE_CHECKING, Any, Literal

import polars._reexport as pl
import polars.functions as F
Expand All @@ -22,6 +22,7 @@
parse_row_index_args,
prepare_file_arg,
)
from polars.io.cloud.credential_provider import _maybe_init_credential_provider

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import PyDataFrame, PyLazyFrame
Expand All @@ -32,6 +33,7 @@

from polars import DataFrame, DataType, LazyFrame
from polars._typing import SchemaDict
from polars.io.cloud import CredentialProviderFunction


@deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
Expand Down Expand Up @@ -362,6 +364,7 @@ def scan_ipc(
row_index_name: str | None = None,
row_index_offset: int = 0,
storage_options: dict[str, Any] | None = None,
credential_provider: CredentialProviderFunction | Literal["auto"] | None = None,
memory_map: bool = True,
retries: int = 2,
file_cache_ttl: int | None = None,
Expand Down Expand Up @@ -407,6 +410,15 @@ def scan_ipc(
If `storage_options` is not provided, Polars will try to infer the information
from environment variables.
credential_provider
Provide a function that can be called to provide cloud storage
credentials. The function is expected to return a dictionary of
credential keys along with an optional credential expiry time.
.. warning::
This functionality is considered **unstable**. It may be changed
at any point without it being considered a breaking change.
memory_map
Try to memory map the file. This can greatly improve performance on repeated
queries as the OS may cache pages.
Expand Down Expand Up @@ -451,6 +463,16 @@ def scan_ipc(
# Memory Mapping is now a no-op
_ = memory_map

credential_provider = _maybe_init_credential_provider(
credential_provider, source, storage_options, "scan_parquet"
)

if storage_options:
storage_options = list(storage_options.items()) # type: ignore[assignment]
else:
# Handle empty dict input
storage_options = None

pylf = PyLazyFrame.new_from_ipc(
source,
sources,
Expand All @@ -459,6 +481,7 @@ def scan_ipc(
rechunk,
parse_row_index_args(row_index_name, row_index_offset),
cloud_options=storage_options,
credential_provider=credential_provider,
retries=retries,
file_cache_ttl=file_cache_ttl,
hive_partitioning=hive_partitioning,
Expand Down
Loading

0 comments on commit 425e251

Please sign in to comment.