Skip to content

Commit

Permalink
feat(cloneable-clients): make service account interceptor cloneable
Browse files Browse the repository at this point in the history
  • Loading branch information
sprudel committed Aug 26, 2024
1 parent 2cabaff commit d8c00e1
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 29 deletions.
6 changes: 4 additions & 2 deletions src/api/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ pub struct ClientBuilderWithInterceptor<T: Interceptor> {
interceptor: T,
}


impl ClientBuilder {
/// Create a new client builder with the the provided endpoint.
pub fn new(api_endpoint: &str) -> ClientBuilder {
Expand All @@ -75,7 +74,10 @@ impl ClientBuilder {
/// Clients with this authentication method will have the [`AccessTokenInterceptor`]
/// attached.
#[cfg(feature = "interceptors")]
pub fn with_access_token(self, access_token: &str) -> ClientBuilderWithInterceptor<AccessTokenInterceptor> {
pub fn with_access_token(
self,
access_token: &str,
) -> ClientBuilderWithInterceptor<AccessTokenInterceptor> {
ClientBuilderWithInterceptor {
api_endpoint: self.api_endpoint,
interceptor: AccessTokenInterceptor::new(access_token),
Expand Down
156 changes: 129 additions & 27 deletions src/api/interceptors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
//! interceptors is to authenticate the clients to ZITADEL with
//! provided credentials.
use std::ops::Deref;
use std::sync::{Arc, RwLock};
use std::thread;

use tokio::runtime::Builder;
Expand Down Expand Up @@ -41,6 +43,7 @@ use crate::credentials::{AuthenticationOptions, ServiceAccount};
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct AccessTokenInterceptor {
access_token: String,
}
Expand Down Expand Up @@ -125,12 +128,21 @@ impl Interceptor for AccessTokenInterceptor {
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct ServiceAccountInterceptor {
inner: Arc<ServiceAccountInterceptorInner>,
}

struct ServiceAccountInterceptorInner {
audience: String,
service_account: ServiceAccount,
auth_options: AuthenticationOptions,
token: Option<String>,
token_expiry: Option<time::OffsetDateTime>,
state: RwLock<Option<ServiceAccountInterceptorState>>,
}

struct ServiceAccountInterceptorState {
token: String,
token_expiry: time::OffsetDateTime,
}

impl ServiceAccountInterceptor {
Expand All @@ -144,11 +156,12 @@ impl ServiceAccountInterceptor {
auth_options: Option<AuthenticationOptions>,
) -> Self {
Self {
audience: audience.to_string(),
service_account: service_account.clone(),
auth_options: auth_options.unwrap_or_default(),
token: None,
token_expiry: None,
inner: Arc::new(ServiceAccountInterceptorInner {
audience: audience.to_string(),
service_account: service_account.clone(),
auth_options: auth_options.unwrap_or_default(),
state: RwLock::new(None),
}),
}
}
}
Expand All @@ -157,25 +170,32 @@ impl Interceptor for ServiceAccountInterceptor {
fn call(&mut self, mut request: tonic::Request<()>) -> Result<tonic::Request<()>, Status> {
let meta = request.metadata_mut();
if !meta.contains_key("authorization") {
if let Some(token) = &self.token {
if let Some(expiry) = self.token_expiry {
if expiry > time::OffsetDateTime::now_utc() {
meta.insert(
"authorization",
format!("Bearer {}", token).parse().unwrap(),
);

return Ok(request);
}
// We unwrap the RWLock to propagate the error if any
// thread panics and the lock is poisoned
let state_read_guard = self.inner.state.read().unwrap();

if let Some(ServiceAccountInterceptorState {
token,
token_expiry,
}) = state_read_guard.deref()
{
if token_expiry > &time::OffsetDateTime::now_utc() {
meta.insert(
"authorization",
format!("Bearer {}", token).parse().unwrap(),
);

return Ok(request);
}
}
drop(state_read_guard);

let aud = self.audience.clone();
let auth = self.auth_options.clone();
let sa = self.service_account.clone();
let aud = self.inner.audience.clone();
let auth = self.inner.auth_options.clone();
let sa = self.inner.service_account.clone();

let token = thread::spawn(move || {
let rt = Builder::new_multi_thread().enable_all().build().unwrap();
let rt = Builder::new_current_thread().enable_all().build().unwrap();
rt.block_on(async {
sa.authenticate_with_options(&aud, &auth)
.await
Expand All @@ -187,8 +207,14 @@ impl Interceptor for ServiceAccountInterceptor {
.join()
.map_err(|_| Status::internal("could not fetch token"))??;

self.token = Some(token.clone());
self.token_expiry = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(59));
// We unwrap the RWLock to propagate the error if any
// thread panics and the lock is poisoned
let mut state_write_guard = self.inner.state.write().unwrap();

*state_write_guard = Some(ServiceAccountInterceptorState {
token: token.clone(),
token_expiry: time::OffsetDateTime::now_utc() + time::Duration::minutes(59),
});

meta.insert(
"authorization",
Expand Down Expand Up @@ -288,6 +314,46 @@ mod tests {
.is_empty());
}

#[test]
fn service_account_interceptor_can_be_cloned_and_shares_token_sync_context() {
let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap();
let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None);
let mut second_interceptor = interceptor.clone();
let request = Request::new(());
let second_request = Request::new(());

assert!(request.metadata().is_empty());
assert!(second_request.metadata().is_empty());

let request = interceptor.call(request).unwrap();
let second_request = second_interceptor.call(second_request).unwrap();

assert_eq!(
request.metadata().get("authorization"),
second_request.metadata().get("authorization")
);
}

#[tokio::test]
async fn service_account_interceptor_can_be_cloned_and_shares_token_async_context() {
let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap();
let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None);
let mut second_interceptor = interceptor.clone();
let request = Request::new(());
let second_request = Request::new(());

assert!(request.metadata().is_empty());
assert!(second_request.metadata().is_empty());

let request = interceptor.call(request).unwrap();
let second_request = second_interceptor.call(second_request).unwrap();

assert_eq!(
request.metadata().get("authorization"),
second_request.metadata().get("authorization")
);
}

#[test]
fn service_account_interceptor_ignore_existing_auth_metadata_sync_context() {
let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap();
Expand Down Expand Up @@ -333,20 +399,56 @@ mod tests {
let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap();
let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None);
interceptor.call(Request::new(())).unwrap();
let token = interceptor.token.clone().unwrap();
let token = interceptor
.inner
.state
.read()
.unwrap()
.as_ref()
.unwrap()
.token
.clone();
interceptor.call(Request::new(())).unwrap();

assert_eq!(token, interceptor.token.unwrap());
assert_eq!(
token,
interceptor
.inner
.state
.read()
.unwrap()
.as_ref()
.unwrap()
.token
);
}

#[tokio::test]
async fn service_account_interceptor_should_respect_token_lifetime_async() {
let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap();
let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None);
interceptor.call(Request::new(())).unwrap();
let token = interceptor.token.clone().unwrap();
let token = interceptor
.inner
.state
.read()
.unwrap()
.as_ref()
.unwrap()
.token
.clone();
interceptor.call(Request::new(())).unwrap();

assert_eq!(token, interceptor.token.unwrap());
assert_eq!(
token,
interceptor
.inner
.state
.read()
.unwrap()
.as_ref()
.unwrap()
.token
);
}
}

0 comments on commit d8c00e1

Please sign in to comment.