Skip to content

Commit

Permalink
Fix race condition in check_auth
Browse files Browse the repository at this point in the history
When a lot of requests arrive roughly at the same time, several requests
can enter the critical section where an HTTP request to upstream is made
to check the auth provided by the client. This means that potentially
thousands of requests can get through to the remote, leading to rate
limits and network errors with some remotes.

* Fix the issue by extending the lock region
* Switch mutex to async to avoid blocking runtime
* Improve tracing

commit-id:d44447ef
  • Loading branch information
vlad-ivanov-name committed May 14, 2024
1 parent b28d490 commit 90d1e4a
Showing 1 changed file with 45 additions and 18 deletions.
63 changes: 45 additions & 18 deletions josh-proxy/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use base64::engine::Engine as _;
lazy_static! {
static ref AUTH: std::sync::Mutex<std::collections::HashMap<Handle, Header>> =
std::sync::Mutex::new(std::collections::HashMap::new());
static ref AUTH_TIMERS: std::sync::Mutex<AuthTimers> =
std::sync::Mutex::new(std::collections::HashMap::new());
static ref AUTH_TIMERS: tokio::sync::Mutex<AuthTimers> =
tokio::sync::Mutex::new(std::collections::HashMap::new());
}

type AuthTimers = std::collections::HashMap<(String, Handle), std::time::Instant>;
Expand Down Expand Up @@ -74,20 +74,34 @@ pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
}

pub async fn check_auth(url: &str, auth: &Handle, required: bool) -> josh::JoshResult<bool> {
use opentelemetry_semantic_conventions::trace::HTTP_RESPONSE_STATUS_CODE;

if required && auth.hash.is_empty() {
return Ok(false);
}

if let Some(last) = AUTH_TIMERS.lock()?.get(&(url.to_string(), auth.clone())) {
let mut auth_timers = AUTH_TIMERS.lock().await;

if let Some(last) = auth_timers.get(&(url.to_string(), auth.clone())) {
let since = std::time::Instant::now().duration_since(*last);
tracing::trace!("last: {:?}, since: {:?}", last, since);
if since < std::time::Duration::from_secs(60 * 30) {
tracing::trace!("cached auth");
let expired = since > std::time::Duration::from_secs(60 * 30);

tracing::trace!(
last = ?last,
since = ?since,
expired = %expired,
"check_auth: found auth entry"
);

if !expired {
return Ok(true);
}
}

tracing::trace!("no cached auth {:?}", *AUTH_TIMERS.lock()?);
tracing::trace!(
auth_timers = ?auth_timers,
"check_auth: no valid cached auth"
);

let https = hyper_tls::HttpsConnector::new();
let client = hyper::Client::builder().build::<_, hyper::Body>(https);
Expand All @@ -114,24 +128,37 @@ pub async fn check_auth(url: &str, auth: &Handle, required: bool) -> josh::JoshR

let status = resp.status();

tracing::trace!("http resp.status {:?}", resp.status());

let err_msg = format!("got http response: {} {:?}", refs_url, resp);
tracing::event!(
tracing::Level::WARN,
{ HTTP_RESPONSE_STATUS_CODE } = status.as_u16(),
"check_auth: response"
);

if status == hyper::StatusCode::OK {
AUTH_TIMERS
.lock()?
.insert((url.to_string(), auth.clone()), std::time::Instant::now());
auth_timers.insert((url.to_string(), auth.clone()), std::time::Instant::now());
Ok(true)
} else if status == hyper::StatusCode::UNAUTHORIZED {
tracing::warn!("resp.status == 401: {:?}", &err_msg);
tracing::trace!(
"body: {:?}",
std::str::from_utf8(&hyper::body::to_bytes(resp.into_body()).await?)
tracing::event!(
tracing::Level::WARN,
{ HTTP_RESPONSE_STATUS_CODE } = status.as_u16(),
"check_auth: unauthorized"
);

let response = hyper::body::to_bytes(resp.into_body()).await?;
let response = String::from_utf8_lossy(&response);

tracing::event!(
tracing::Level::TRACE,
"http.response.body" = %response,
"check_auth: unauthorized",
);

Ok(false)
} else {
return Err(josh::josh_error(&err_msg));
return Err(josh::josh_error(&format!(
"check_auth: got http response: {} {:?}",
refs_url, resp
)));
}
}

Expand Down

0 comments on commit 90d1e4a

Please sign in to comment.