Add support for IRSA authentication for S3 (#694)
mwylde committed Jul 30, 2024
1 parent a254ff7 commit 67d33c7
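
IRSA (IAM Roles for Service Accounts) is the EKS mechanism that ties an IAM role to a Kubernetes service account: the pod receives a projected web-identity token plus the AWS_ROLE_ARN and AWS_WEB_IDENTITY_TOKEN_FILE environment variables, and the SDK's default credential chain exchanges that token for role credentials through STS. rusoto's ChainProvider has no web-identity step, which is why this commit replaces it with the official aws-config chain. A minimal standalone sketch of the resolution path the new code relies on (a sketch only; assumes tokio and anyhow are available):

    use aws_config::BehaviorVersion;
    use aws_credential_types::provider::ProvideCredentials;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Load region and credential chain from the environment; under IRSA the
        // chain's web-identity step reads AWS_ROLE_ARN and
        // AWS_WEB_IDENTITY_TOKEN_FILE, so no static keys are needed.
        let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
        let provider = config
            .credentials_provider()
            .ok_or_else(|| anyhow::anyhow!("no credentials provider in the default chain"))?;
        let creds = provider.provide_credentials().await?;
        println!("resolved credentials for access key {}", creds.access_key_id());
        Ok(())
    }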
Showing 10 changed files with 172 additions and 414 deletions.
378 changes: 96 additions & 282 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -51,6 +51,7 @@ deltalake = { version = "0.17.3" }
cornucopia = { version = "0.9.0" }
cornucopia_async = {version = "0.6.0"}
deadpool-postgres = "0.12"

[profile.release]
debug = 1

4 changes: 3 additions & 1 deletion crates/arroyo-server-common/src/lib.rs
@@ -61,7 +61,9 @@ pub fn init_logging_with_filter(_name: &str, filter: EnvFilter) -> WorkerGuard {
         eprintln!("Failed to initialize log tracer {:?}", e);
     }
 
-    let filter = filter.add_directive("refinery_core=warn".parse().unwrap());
+    let filter = filter
+        .add_directive("refinery_core=warn".parse().unwrap())
+        .add_directive("aws_config::profile::credentials=warn".parse().unwrap());
 
     let (nonblocking, guard) = tracing_appender::non_blocking(std::io::stderr());
27 changes: 26 additions & 1 deletion crates/arroyo-state/src/lib.rs
@@ -1,4 +1,4 @@
-use anyhow::Result;
+use anyhow::{Context, Result};
 use arrow_array::RecordBatch;
 use arroyo_rpc::grpc::rpc::{
     CheckpointMetadata, ExpiringKeyedTimeTableConfig, GlobalKeyedTableConfig,
@@ -9,12 +9,15 @@ use async_trait::async_trait;
 use bincode::config::Configuration;
 use bincode::{Decode, Encode};
 
+use arroyo_rpc::config::config;
 use arroyo_rpc::df::ArroyoSchema;
+use arroyo_storage::StorageProvider;
 use prost::Message;
 use std::collections::hash_map::DefaultHasher;
 use std::collections::HashMap;
 use std::hash::{Hash, Hasher};
 use std::ops::RangeInclusive;
+use std::sync::Arc;
 use std::time::{Duration, SystemTime};
 
 pub mod checkpoint_state;
@@ -160,3 +163,25 @@ pub fn hash_key<K: Hash>(key: &K) -> u64 {
     key.hash(&mut hasher);
     hasher.finish()
 }
+
+static STORAGE_PROVIDER: tokio::sync::OnceCell<Arc<StorageProvider>> =
+    tokio::sync::OnceCell::const_new();
+
+pub(crate) async fn get_storage_provider() -> Result<&'static Arc<StorageProvider>> {
+    // TODO: this should be encoded in the config so that the controller doesn't need
+    // to be synchronized with the workers
+
+    STORAGE_PROVIDER
+        .get_or_try_init(|| async {
+            let storage_url = &config().checkpoint_url;
+
+            StorageProvider::for_url(storage_url)
+                .await
+                .context(format!(
+                    "failed to construct checkpoint backend for URL {}",
+                    storage_url
+                ))
+                .map(Arc::new)
+        })
+        .await
+}
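
The get_storage_provider added above memoizes a single StorageProvider per process behind a tokio OnceCell, replacing the per-call constructors deleted from parquet.rs and table_manager.rs below. A minimal sketch of the same get_or_try_init pattern, with a stand-in initializer:

    use tokio::sync::OnceCell;

    static PROVIDER: OnceCell<String> = OnceCell::const_new();

    async fn provider() -> anyhow::Result<&'static String> {
        PROVIDER
            .get_or_try_init(|| async {
                // stand-in for StorageProvider::for_url(&config().checkpoint_url);
                // an Err result is not cached, so the next caller retries the init,
                // while concurrent callers await the in-flight initialization
                Ok("s3://checkpoint-bucket".to_string())
            })
            .await
    }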
20 changes: 4 additions & 16 deletions crates/arroyo-state/src/parquet.rs
@@ -1,12 +1,11 @@
 use crate::tables::expiring_time_key_map::ExpiringTimeKeyTable;
 use crate::tables::global_keyed_map::GlobalKeyedTable;
 use crate::tables::{CompactionConfig, ErasedTable};
-use crate::BackingStore;
-use anyhow::{bail, Context, Result};
+use crate::{get_storage_provider, BackingStore};
+use anyhow::{bail, Result};
 use arroyo_rpc::grpc::rpc::{
     CheckpointMetadata, OperatorCheckpointMetadata, TableCheckpointMetadata,
 };
-use arroyo_storage::StorageProvider;
 use futures::stream::FuturesUnordered;
 use futures::StreamExt;
 
@@ -23,17 +22,6 @@ use tracing::{debug, info};
 pub const FULL_KEY_RANGE: RangeInclusive<u64> = 0..=u64::MAX;
 pub const GENERATIONS_TO_COMPACT: u32 = 1; // only compact generation 0 files
 
-async fn get_storage_provider() -> anyhow::Result<StorageProvider> {
-    // TODO: this should be encoded in the config so that the controller doesn't need
-    // to be synchronized with the workers
-    let storage_url = &config().checkpoint_url;
-
-    StorageProvider::for_url(storage_url).await.context(format!(
-        "failed to construct checkpoint backend for URL {}",
-        storage_url
-    ))
-}
-
 pub struct ParquetBackend;
 
 fn base_path(job_id: &str, epoch: u32) -> String {
@@ -178,11 +166,11 @@ impl ParquetBackend {
             Self::load_operator_metadata(&job_id, &operator_id, epoch)
                 .await?
                 .expect("expect operator metadata to still be present");
-        let storage_provider = Arc::new(get_storage_provider().await?);
+        let storage_provider = get_storage_provider().await?;
         let compaction_config = CompactionConfig {
-            storage_provider,
             compact_generations: vec![0].into_iter().collect(),
             min_compaction_epochs: min_files_to_compact,
+            storage_provider: Arc::clone(storage_provider),
         };
         let operator_metadata = operator_checkpoint_metadata.operator_metadata.unwrap();
18 changes: 2 additions & 16 deletions crates/arroyo-state/src/tables/table_manager.rs
@@ -21,7 +21,7 @@ use tokio::sync::{
 use arroyo_rpc::config::config;
 use tracing::{debug, error, info, warn};
 
-use crate::{tables::global_keyed_map::GlobalKeyedTable, StateMessage};
+use crate::{get_storage_provider, tables::global_keyed_map::GlobalKeyedTable, StateMessage};
 use crate::{CheckpointMessage, TableData};
 
 use super::expiring_time_key_map::{
@@ -225,20 +225,6 @@ impl BackendWriter {
     }
 }
 
-async fn get_storage_provider() -> anyhow::Result<StorageProviderRef> {
-    // TODO: this should be encoded in the config so that the controller doesn't need
-    // to be synchronized with the workers
-
-    Ok(Arc::new(
-        StorageProvider::for_url(&config().checkpoint_url)
-            .await
-            .context(format!(
-                "failed to construct checkpoint backend for URL {}",
-                config().checkpoint_url
-            ))?,
-    ))
-}
-
 impl TableManager {
     pub async fn new(
         task_info: TaskInfoRef,
@@ -320,7 +306,7 @@ impl TableManager {
             tables,
             writer,
             task_info,
-            storage,
+            storage: Arc::clone(storage),
             caches: HashMap::new(),
         })
     }
5 changes: 2 additions & 3 deletions crates/arroyo-storage/Cargo.toml
@@ -12,10 +12,9 @@ arroyo-types = { path = "../arroyo-types" }
 arroyo-rpc = { path = "../arroyo-rpc" }
 bytes = "1.4.0"
 tracing = "0.1"
-# used only for getting local AWS credentials; can be removed once we have a
-# better way to do this
-rusoto_core = "0.48.0"
 
+aws-credential-types = "1.2.0"
+aws-config = { version = "1.5.4" }
 rand = "0.8"
 object_store = {workspace = true, features = ["aws", "gcp"]}
 regex = "1.9.5"
62 changes: 34 additions & 28 deletions crates/arroyo-storage/src/aws.rs
@@ -1,14 +1,11 @@
-use std::sync::Arc;
-
-use object_store::{aws::AwsCredential, CredentialProvider};
-use rusoto_core::credential::{
-    AutoRefreshingProvider, ChainProvider, ProfileProvider, ProvideAwsCredentials,
-};
-
 use crate::StorageError;
+use aws_config::BehaviorVersion;
+use aws_credential_types::provider::ProvideCredentials;
+use object_store::{aws::AwsCredential, CredentialProvider};
+use std::sync::Arc;
 
 pub struct ArroyoCredentialProvider {
-    provider: AutoRefreshingProvider<ChainProvider>,
+    provider: aws_credential_types::provider::SharedCredentialsProvider,
 }
 
 impl std::fmt::Debug for ArroyoCredentialProvider {
@@ -18,38 +15,47 @@ impl std::fmt::Debug for ArroyoCredentialProvider {
 }
 
 impl ArroyoCredentialProvider {
-    pub fn try_new() -> Result<Self, StorageError> {
-        let inner: AutoRefreshingProvider<ChainProvider> =
-            AutoRefreshingProvider::new(ChainProvider::new())
-                .map_err(|e| StorageError::CredentialsError(e.to_string()))?;
-
-        Ok(Self { provider: inner })
+    pub async fn try_new() -> Result<Self, StorageError> {
+        let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
+
+        let credentials = config
+            .credentials_provider()
+            .ok_or_else(|| {
+                StorageError::CredentialsError(
+                    "Unable to load S3 credentials from environment".to_string(),
+                )
+            })?
+            .clone();
+
+        Ok(Self {
+            provider: credentials,
+        })
     }
 
     pub async fn default_region() -> Option<String> {
-        ProfileProvider::region().ok()?
+        aws_config::defaults(BehaviorVersion::latest())
+            .load()
+            .await
+            .region()
+            .map(|r| r.to_string())
     }
 }
 
 #[async_trait::async_trait]
 impl CredentialProvider for ArroyoCredentialProvider {
-    #[doc = " The type of credential returned by this provider"]
     type Credential = AwsCredential;
 
     /// Return a credential
     async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
-        let credentials =
-            self.provider
-                .credentials()
-                .await
-                .map_err(|err| object_store::Error::Generic {
-                    store: "s3",
-                    source: Box::new(err),
-                })?;
+        let creds = self.provider.provide_credentials().await.map_err(|e| {
+            object_store::Error::Generic {
                store: "S3",
                source: Box::new(e),
            }
+        })?;
         Ok(Arc::new(AwsCredential {
-            key_id: credentials.aws_access_key_id().to_string(),
-            secret_key: credentials.aws_secret_access_key().to_string(),
-            token: credentials.token().clone(),
+            key_id: creds.access_key_id().to_string(),
+            secret_key: creds.secret_access_key().to_string(),
+            token: creds.session_token().map(ToString::to_string),
         }))
     }
 }
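
Downstream, a provider like this is handed to object_store's S3 builder (see the construct_s3 hunk in lib.rs below). A hedged sketch of that wiring inside arroyo-storage, with a hypothetical bucket name:

    use object_store::aws::AmazonS3Builder;
    use object_store::ObjectStore;
    use std::sync::Arc;

    use crate::{aws::ArroyoCredentialProvider, StorageError};

    async fn build_s3(bucket: &str) -> Result<Arc<dyn ObjectStore>, StorageError> {
        // object_store consults the provider when it signs requests, so
        // refreshed IRSA credentials are picked up without rebuilding the store.
        let credentials = Arc::new(ArroyoCredentialProvider::try_new().await?);
        let store = AmazonS3Builder::from_env()
            .with_bucket_name(bucket) // e.g. the hypothetical "checkpoint-bucket"
            .with_credentials(credentials)
            .build()
            .map_err(Into::<StorageError>::into)?;
        Ok(Arc::new(store))
    }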
67 changes: 4 additions & 63 deletions crates/arroyo-storage/src/lib.rs
@@ -16,12 +16,9 @@ use object_store::multipart::PartId;
 use object_store::path::Path;
 use object_store::{aws::AmazonS3Builder, local::LocalFileSystem, ObjectStore};
 use object_store::{CredentialProvider, MultipartId};
-use once_cell::sync::Lazy;
 use regex::{Captures, Regex};
-use std::time::{Duration, Instant};
 use thiserror::Error;
-use tokio::sync::RwLock;
-use tracing::{debug, error, trace};
+use tracing::{debug, error};
 
 mod aws;
 
@@ -296,23 +293,11 @@ fn last<I: Sized, const COUNT: usize>(opts: [Option<I>; COUNT]) -> Option<I> {
 }
 
 pub async fn get_current_credentials() -> Result<Arc<AwsCredential>, StorageError> {
-    let provider = ArroyoCredentialProvider::try_new()?;
+    let provider = ArroyoCredentialProvider::try_new().await?;
     let credentials = provider.get_credential().await?;
     Ok(credentials)
 }
 
-static OBJECT_STORE_CACHE: Lazy<RwLock<HashMap<String, CacheEntry<Arc<dyn ObjectStore>>>>> =
-    Lazy::new(Default::default);
-
-struct CacheEntry<T> {
-    value: T,
-    inserted_at: Instant,
-}
-
-// The bearer token should last for 3600 seconds,
-// but regenerating it every 5 minutes to avoid token expiry
-const GCS_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
-
 impl StorageProvider {
     pub async fn for_url(url: &str) -> Result<Self, StorageError> {
         Self::for_url_with_options(url, HashMap::new()).await
@@ -360,11 +345,6 @@ impl StorageProvider {
         Ok(key.clone())
     }
 
-    pub async fn url_exists(url: &str) -> Result<bool, StorageError> {
-        let provider = Self::for_url(url).await?;
-        provider.exists("").await
-    }
-
     async fn construct_s3(
         mut config: S3Config,
         options: HashMap<String, String>,
@@ -386,7 +366,7 @@ impl StorageProvider {
 
         if !aws_key_manually_set {
             let credentials: Arc<ArroyoCredentialProvider> =
-                Arc::new(ArroyoCredentialProvider::try_new()?);
+                Arc::new(ArroyoCredentialProvider::try_new().await?);
             builder = builder.with_credentials(credentials);
         }
 
@@ -444,45 +424,6 @@ impl StorageProvider {
         })
     }
 
-    async fn get_or_create_object_store(
-        builder: GoogleCloudStorageBuilder,
-        bucket: &str,
-    ) -> Result<Arc<dyn ObjectStore>, StorageError> {
-        let mut cache = OBJECT_STORE_CACHE.write().await;
-
-        if let Some(entry) = cache.get(bucket) {
-            if entry.inserted_at.elapsed() < GCS_CACHE_TTL {
-                trace!(
-                    "Cache hit - using cached object store for bucket {}",
-                    bucket
-                );
-                return Ok(entry.value.clone());
-            } else {
-                debug!(
-                    "Cache expired - constructing new object store for bucket {}",
-                    bucket
-                );
-            }
-        } else {
-            debug!(
-                "Cache miss - constructing new object store for bucket {}",
-                bucket
-            );
-        }
-
-        let new_store = Arc::new(builder.build().map_err(Into::<StorageError>::into)?);
-
-        cache.insert(
-            bucket.to_string(),
-            CacheEntry {
-                value: new_store.clone(),
-                inserted_at: Instant::now(),
-            },
-        );
-
-        Ok(new_store)
-    }
-
     async fn construct_gcs(config: GCSConfig) -> Result<Self, StorageError> {
         let mut builder = GoogleCloudStorageBuilder::from_env().with_bucket_name(&config.bucket);
 
@@ -498,7 +439,7 @@ impl StorageProvider {
 
         let object_store_base_url = format!("https://{}.storage.googleapis.com", config.bucket);
 
-        let object_store = Self::get_or_create_object_store(builder, &config.bucket).await?;
+        let object_store = Arc::new(builder.build()?);
 
         Ok(Self {
             config: BackendConfig::GCS(config),
4 changes: 0 additions & 4 deletions crates/arroyo-worker/Cargo.toml
@@ -52,11 +52,7 @@ parquet = { workspace = true, features = ["async"]}
 arrow-array = { workspace = true}
 arrow-json = { workspace = true }
 
-aws-sdk-kinesis = { version = "0.21", default-features = false, features = ["rt-tokio", "native-tls"] }
-aws-config = { version = "0.51", default-features = false, features = ["rt-tokio", "native-tls"] }
 uuid = {version = "1.4.1", features = ["v4"]}
-rusoto_core = "0.48.0"
-rusoto_s3 = "0.48.0"
 
 tonic = { workspace = true }
 prost = "0.12"
