From b3219e84e97258e8ecfe4fcd1c9450fc1f23838a Mon Sep 17 00:00:00 2001 From: yanghua Date: Mon, 23 Dec 2024 17:05:23 +0800 Subject: [PATCH] feat(python): support customize credential provider for write_dataset api in PyLance --- python/Cargo.lock | 138 +++++++++++++++++++++++++++++++++++++ python/Cargo.toml | 1 + python/src/dataset.rs | 19 ++++- python/src/lib.rs | 1 + python/src/object_store.rs | 129 ++++++++++++++++++++++++++++++++++ 5 files changed, 285 insertions(+), 3 deletions(-) create mode 100644 python/src/object_store.rs diff --git a/python/Cargo.lock b/python/Cargo.lock index fbf557e426..ed36127d8c 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -1814,6 +1814,15 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "env_logger" version = "0.10.2" @@ -1943,6 +1952,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -2411,6 +2435,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.5.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.10" @@ -3394,6 +3434,23 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "noisy_float" version = "0.2.0" @@ -3586,12 +3643,50 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" +[[package]] +name = "openssl" +version = "0.10.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-sys" +version = "0.9.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -4122,6 +4217,7 @@ dependencies = [ "prost 0.13.3", "prost-build 0.11.9", "pyo3", + "reqwest", "serde", "serde_json", "serde_yaml", @@ -4446,6 +4542,7 @@ checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", "h2 0.4.7", @@ -4454,11 +4551,13 @@ dependencies = [ "http-body-util", "hyper 1.5.1", "hyper-rustls 0.27.3", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -4471,7 +4570,9 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper", + "system-configuration", "tokio", + "tokio-native-tls", "tokio-rustls 0.26.0", "tokio-util", "tower-service", @@ -5136,6 +5237,27 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tagptr" version = "0.2.0" @@ -5511,6 +5633,16 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -5814,6 +5946,12 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ef4c4aa54d5d05a279399bfa921ec387b7aba77caf7a682ae8d86785b8fdad2" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" diff --git a/python/Cargo.toml b/python/Cargo.toml index e9e9f867c4..85614aa326 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -60,6 +60,7 @@ tracing-subscriber = "0.3.17" tracing = "0.1.37" url = "2.5.0" bytes = "1.4" +reqwest = { version = "0.12.9", features = [] } [features] datagen = ["lance-datagen"] diff --git a/python/src/dataset.rs b/python/src/dataset.rs index c4fd94aa12..04e9760d45 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -59,7 +59,9 @@ use lance_io::object_store::ObjectStoreParams; use lance_linalg::distance::MetricType; use lance_table::format::Fragment; use lance_table::io::commit::CommitHandler; +use object_store::aws::AwsCredential; use object_store::path::Path; +use object_store::CredentialProvider; use pyo3::exceptions::{PyStopIteration, PyTypeError}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyInt, PyList, PySet, PyString}; @@ -71,17 +73,18 @@ use pyo3::{ }; use snafu::{location, Location}; +use self::cleanup::CleanupStats; +use self::commit::PyCommitLock; use crate::error::PythonErrorExt; use crate::file::object_store_from_uri_or_path; use crate::fragment::FileFragment; +use crate::object_store::UrlBasedCredentialProvider; use crate::schema::LanceSchema; use crate::session::Session; use crate::utils::PyLance; use crate::RT; use crate::{LanceReader, Scanner}; - -use self::cleanup::CleanupStats; -use self::commit::PyCommitLock; +use pyo3::types::PyAny; pub mod blob; pub mod cleanup; @@ -1590,8 +1593,18 @@ pub fn get_write_params(options: &PyDict) -> PyResult> { if let Some(storage_options) = get_dict_opt::>(options, "storage_options")? { + let credential_provider: Option< + Arc>, + > = if let Some(url) = storage_options.get("assume_role_credential_url") { + Some(Arc::new(UrlBasedCredentialProvider::new(url.parse()?)) + as Arc>) + } else { + None + }; + p.store_params = Some(ObjectStoreParams { storage_options: Some(storage_options), + aws_credentials: credential_provider, ..Default::default() }); } diff --git a/python/src/lib.rs b/python/src/lib.rs index c0b83dee9f..45acd761f3 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -59,6 +59,7 @@ pub(crate) mod executor; pub(crate) mod file; pub(crate) mod fragment; pub(crate) mod indices; +pub(crate) mod object_store; pub(crate) mod reader; pub(crate) mod scanner; pub(crate) mod schema; diff --git a/python/src/object_store.rs b/python/src/object_store.rs new file mode 100644 index 0000000000..903796f48d --- /dev/null +++ b/python/src/object_store.rs @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use object_store::aws::AwsCredential; +use object_store::CredentialProvider; +use reqwest::Client; +use serde::Deserialize; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; +use tokio::sync::Mutex; + +pub struct UrlBasedCredentialProvider { + url: String, + client: Client, + lock: Mutex<()>, + state: Mutex, +} + +#[derive(Deserialize, Debug)] +struct CredentialState { + #[serde(rename = "AccessKeyId")] + access_key_id: Option, + + #[serde(rename = "SecretAccessKey")] + secret_access_key: Option, + + #[serde(rename = "SessionToken")] + session_token: Option, + + #[serde(rename = "ExpiredTime")] + expired_time: Option, +} + +impl UrlBasedCredentialProvider { + pub fn new(url: String) -> Self { + Self { + url, + client: Client::new(), + lock: Mutex::new(()), + state: Mutex::new(CredentialState { + access_key_id: None, + secret_access_key: None, + session_token: None, + expired_time: None, + }), + } + } + + async fn try_get_credentials(&self) -> Option { + let state = self.state.lock().await; + if let Some(expiration) = state.expired_time { + if SystemTime::now() < expiration - Duration::from_secs(600) { + Some(AwsCredential { + key_id: state.access_key_id.clone().unwrap_or_default(), + secret_key: state.secret_access_key.clone().unwrap_or_default(), + token: state.session_token.clone(), + }); + } + } + None + } +} + +impl Debug for UrlBasedCredentialProvider { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "UrlBasedCredentialProvider {{ url: {} }}", self.url) + } +} + +#[async_trait] +impl CredentialProvider for UrlBasedCredentialProvider { + type Credential = AwsCredential; + + async fn get_credential(&self) -> object_store::Result> { + if let Some(credentials) = self.try_get_credentials().await { + return Ok(Arc::from(credentials)); + } + + let _guard = self.lock.lock().await; + if let Some(credentials) = self.try_get_credentials().await { + return Ok(Arc::from(credentials)); + } + + let response = match self.client.get(&self.url).send().await { + Ok(resp) => resp, + Err(_) => { + return Err(object_store::Error::Generic { + store: "Request credential error.", + source: Box::from(""), + }) + } + }; + + let credential_state: CredentialState = match response.json().await { + Ok(state) => state, + Err(_) => { + return Err(object_store::Error::Generic { + store: "Parse response JSON error.", + source: Box::from(""), + }) + } + }; + + let expiration_time: DateTime = match credential_state.expired_time { + Some(exp_time) => exp_time.into(), + None => { + return Err(object_store::Error::Generic { + store: "Parse expire time error.", + source: Box::from(""), + }) + } + }; + + let mut state = self.state.lock().await; + state.expired_time = Some(expiration_time.into()); + state.access_key_id = credential_state.access_key_id.clone(); + state.secret_access_key = credential_state.secret_access_key.clone(); + state.session_token = credential_state.session_token.clone(); + + Ok(Arc::from(AwsCredential { + key_id: state.access_key_id.clone().unwrap_or_default(), + secret_key: state.secret_access_key.clone().unwrap_or_default(), + token: state.session_token.clone(), + })) + } +}