Skip to content

Commit

Permalink
Merge pull request #111 from aviramha/re-auth
Browse files Browse the repository at this point in the history
Refactor auth - add re-auth mechanism
  • Loading branch information
flavio authored Feb 7, 2024
2 parents 5b0ab7a + f8c23d2 commit 4c80c9c
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 74 deletions.
154 changes: 106 additions & 48 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::Arc;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::RwLock;
use tracing::{debug, trace, warn};

const MIME_TYPES_DISTRIBUTION_MANIFEST: &[&str] = &[
Expand Down Expand Up @@ -205,6 +206,8 @@ impl TryFrom<Config> for ConfigFile {
#[derive(Clone)]
pub struct Client {
config: Arc<ClientConfig>,
// Registry -> RegistryAuth
auth_store: Arc<RwLock<HashMap<String, RegistryAuth>>>,
tokens: TokenCache,
client: reqwest::Client,
push_chunk_size: usize,
Expand All @@ -213,9 +216,10 @@ pub struct Client {
impl Default for Client {
fn default() -> Self {
Self {
config: Arc::new(ClientConfig::default()),
tokens: TokenCache::new(),
client: reqwest::Client::new(),
config: Arc::default(),
auth_store: Arc::default(),
tokens: TokenCache::default(),
client: reqwest::Client::default(),
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
}
}
Expand Down Expand Up @@ -257,9 +261,9 @@ impl TryFrom<ClientConfig> for Client {

Ok(Self {
config: Arc::new(config),
tokens: TokenCache::new(),
client: client_builder.build()?,
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
..Default::default()
})
}
}
Expand All @@ -271,10 +275,8 @@ impl Client {
warn!("Cannot create OCI client from config: {:?}", err);
warn!("Creating client with default configuration");
Self {
config: Arc::new(ClientConfig::default()),
tokens: TokenCache::new(),
client: reqwest::Client::new(),
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
..Default::default()
}
})
}
Expand All @@ -284,6 +286,41 @@ impl Client {
Self::new(config_source.client_config())
}

async fn store_auth(&self, registry: &str, auth: RegistryAuth) {
self.auth_store
.write()
.await
.insert(registry.to_string(), auth);
}

async fn is_stored_auth(&self, registry: &str) -> bool {
self.auth_store.read().await.contains_key(registry)
}

async fn store_auth_if_needed(&self, registry: &str, auth: &RegistryAuth) {
if !self.is_stored_auth(registry).await {
self.store_auth(registry, auth.clone()).await;
}
}

/// Checks if we got a token, if we don't - create it and store it in cache.
async fn get_auth_token(
&self,
reference: &Reference,
op: RegistryOperation,
) -> Option<RegistryTokenType> {
let registry = reference.resolve_registry();
let auth = self.auth_store.read().await.get(registry)?.clone();
match self.tokens.get(reference, op).await {
Some(token) => Some(token),
None => {
let token = self._auth(reference, &auth, op).await.ok()??;
self.tokens.insert(reference, op, token.clone()).await;
Some(token)
}
}
}

/// Fetches the available Tags for the given Reference
///
/// The client will check if it's already been authenticated and if
Expand All @@ -298,9 +335,8 @@ impl Client {
let op = RegistryOperation::Pull;
let url = self.to_list_tags_url(image);

if !self.tokens.contains_key(image, op).await {
self.auth(image, auth, op).await?;
}
self.store_auth_if_needed(image.resolve_registry(), auth)
.await;

let request = self.client.get(&url);
let request = if let Some(num) = n {
Expand Down Expand Up @@ -342,10 +378,8 @@ impl Client {
accepted_media_types: Vec<&str>,
) -> Result<ImageData> {
debug!("Pulling image: {:?}", image);
let op = RegistryOperation::Pull;
if !self.tokens.contains_key(image, op).await {
self.auth(image, auth, op).await?;
}
self.store_auth_if_needed(image.resolve_registry(), auth)
.await;

let (manifest, digest, config) = self._pull_manifest_and_config(image).await?;

Expand Down Expand Up @@ -400,10 +434,8 @@ impl Client {
manifest: Option<OciImageManifest>,
) -> Result<PushResponse> {
debug!("Pushing image: {:?}", image_ref);
let op = RegistryOperation::Push;
if !self.tokens.contains_key(image_ref, op).await {
self.auth(image_ref, auth, op).await?;
}
self.store_auth_if_needed(image_ref.resolve_registry(), auth)
.await;

let manifest: OciImageManifest = match manifest {
Some(m) => m,
Expand Down Expand Up @@ -502,6 +534,38 @@ impl Client {
authentication: &RegistryAuth,
operation: RegistryOperation,
) -> Result<Option<String>> {
self.store_auth_if_needed(image.resolve_registry(), authentication)
.await;
// preserve old caching behavior
match self._auth(image, authentication, operation).await {
Ok(Some(RegistryTokenType::Bearer(token))) => {
self.tokens
.insert(image, operation, RegistryTokenType::Bearer(token.clone()))
.await;
Ok(Some(token.token().to_string()))
}
Ok(Some(RegistryTokenType::Basic(username, password))) => {
self.tokens
.insert(
image,
operation,
RegistryTokenType::Basic(username, password),
)
.await;
Ok(None)
}
Ok(None) => Ok(None),
Err(e) => Err(e),
}
}

/// Internal auth that retrieves token.
async fn _auth(
&self,
image: &Reference,
authentication: &RegistryAuth,
operation: RegistryOperation,
) -> Result<Option<RegistryTokenType>> {
debug!("Authorizing for image: {:?}", image);
// The version request will tell us where to go.
let url = format!(
Expand All @@ -521,13 +585,10 @@ impl Client {
Err(e) => {
debug!(error = ?e, "Falling back to HTTP Basic Auth");
if let RegistryAuth::Basic(username, password) = authentication {
self.tokens
.insert(
image,
operation,
RegistryTokenType::Basic(username.to_string(), password.to_string()),
)
.await;
return Ok(Some(RegistryTokenType::Basic(
username.to_string(),
password.to_string(),
)));
}
return Ok(None);
}
Expand Down Expand Up @@ -566,11 +627,7 @@ impl Client {
let token: RegistryToken = serde_json::from_str(&text)
.map_err(|e| OciDistributionError::RegistryTokenDecodeError(e.to_string()))?;
debug!("Successfully authorized for image '{:?}'", image);
let oauth_token = token.token().to_string();
self.tokens
.insert(image, operation, RegistryTokenType::Bearer(token))
.await;
Ok(Some(oauth_token))
Ok(Some(RegistryTokenType::Bearer(token)))
}
_ => {
let reason = auth_res.text().await?;
Expand All @@ -593,10 +650,8 @@ impl Client {
image: &Reference,
auth: &RegistryAuth,
) -> Result<String> {
let op = RegistryOperation::Pull;
if !self.tokens.contains_key(image, op).await {
self.auth(image, auth, op).await?;
}
self.store_auth_if_needed(image.resolve_registry(), auth)
.await;

let url = self.to_v2_manifest_url(image);
debug!("HEAD image manifest from {}", url);
Expand Down Expand Up @@ -670,10 +725,8 @@ impl Client {
image: &Reference,
auth: &RegistryAuth,
) -> Result<(OciImageManifest, String)> {
let op = RegistryOperation::Pull;
if !self.tokens.contains_key(image, op).await {
self.auth(image, auth, op).await?;
}
self.store_auth_if_needed(image.resolve_registry(), auth)
.await;

self._pull_image_manifest(image).await
}
Expand All @@ -690,10 +743,8 @@ impl Client {
image: &Reference,
auth: &RegistryAuth,
) -> Result<(OciManifest, String)> {
let op = RegistryOperation::Pull;
if !self.tokens.contains_key(image, op).await {
self.auth(image, auth, op).await?;
}
self.store_auth_if_needed(image.resolve_registry(), auth)
.await;

self._pull_manifest(image).await
}
Expand Down Expand Up @@ -811,10 +862,8 @@ impl Client {
image: &Reference,
auth: &RegistryAuth,
) -> Result<(OciImageManifest, String, String)> {
let op = RegistryOperation::Pull;
if !self.tokens.contains_key(image, op).await {
self.auth(image, auth, op).await?;
}
self.store_auth_if_needed(image.resolve_registry(), auth)
.await;

self._pull_manifest_and_config(image)
.await
Expand Down Expand Up @@ -855,7 +904,8 @@ impl Client {
auth: &RegistryAuth,
manifest: OciImageIndex,
) -> Result<String> {
self.auth(reference, auth, RegistryOperation::Push).await?;
self.store_auth_if_needed(reference.resolve_registry(), auth)
.await;
self.push_manifest(reference, &OciManifest::ImageIndex(manifest))
.await
}
Expand Down Expand Up @@ -1418,7 +1468,7 @@ impl<'a> RequestBuilderWrapper<'a> {
) -> Result<RequestBuilderWrapper> {
let mut headers = HeaderMap::new();

if let Some(token) = self.client.tokens.get(image, op).await {
if let Some(token) = self.client.get_auth_token(image, op).await {
match token {
RegistryTokenType::Bearer(token) => {
debug!("Using bearer token authentication.");
Expand Down Expand Up @@ -1816,6 +1866,14 @@ mod test {
.as_str()
.to_string();

// we have to have it in the stored auth so we'll get to the token cache check.
client
.store_auth(
&Reference::try_from(HELLO_IMAGE_TAG)?.resolve_registry(),
RegistryAuth::Anonymous,
)
.await;

client
.tokens
.insert(
Expand Down
59 changes: 33 additions & 26 deletions src/token_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,25 @@ pub enum RegistryOperation {
Pull,
}

type CacheType = BTreeMap<(String, String, RegistryOperation), (RegistryTokenType, u64)>;
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct TokenCacheKey {
registry: String,
repository: String,
operation: RegistryOperation,
}

struct TokenCacheValue {
token: RegistryTokenType,
expiration: u64,
}

#[derive(Default, Clone)]
pub(crate) struct TokenCache {
// (registry, repository, scope) -> (token, expiration)
tokens: Arc<RwLock<CacheType>>,
tokens: Arc<RwLock<BTreeMap<TokenCacheKey, TokenCacheValue>>>,
}

impl TokenCache {
pub(crate) fn new() -> Self {
TokenCache {
tokens: Arc::new(RwLock::new(BTreeMap::new())),
}
}

pub(crate) async fn insert(
&self,
reference: &Reference,
Expand Down Expand Up @@ -119,10 +123,14 @@ impl TokenCache {
let registry = reference.resolve_registry().to_string();
let repository = reference.repository().to_string();
debug!(%registry, %repository, ?op, %expiration, "Inserting token");
self.tokens
.write()
.await
.insert((registry, repository, op), (token, expiration));
self.tokens.write().await.insert(
TokenCacheKey {
registry,
repository,
operation: op,
},
TokenCacheValue { token, expiration },
);
}

pub(crate) async fn get(
Expand All @@ -132,34 +140,33 @@ impl TokenCache {
) -> Option<RegistryTokenType> {
let registry = reference.resolve_registry().to_string();
let repository = reference.repository().to_string();
match self
.tokens
.read()
.await
.get(&(registry.clone(), repository.clone(), op))
{
Some((ref token, expiration)) => {
let key = TokenCacheKey {
registry,
repository,
operation: op,
};
match self.tokens.read().await.get(&key) {
Some(TokenCacheValue {
ref token,
expiration,
}) => {
let now = SystemTime::now();
let epoch = now
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs();
if epoch > *expiration {
debug!(%registry, %repository, ?op, %expiration, miss=false, expired=true, "Fetching token");
debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=true, "Fetching token");
None
} else {
debug!(%registry, %repository, ?op, %expiration, miss=false, expired=false, "Fetching token");
debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=false, "Fetching token");
Some(token.clone())
}
}
None => {
debug!(%registry, %repository, ?op, miss=true, "Fetching token");
debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token");
None
}
}
}

pub(crate) async fn contains_key(&self, reference: &Reference, op: RegistryOperation) -> bool {
self.get(reference, op).await.is_some()
}
}

0 comments on commit 4c80c9c

Please sign in to comment.