Skip to content

Commit

Permalink
Fix race condition in check_auth (#1331)
Browse files Browse the repository at this point in the history
When a lot of requests arrive roughly at the same time, several requests
can enter the critical section where an HTTP request to upstream is made
to check the auth provided by the client. This means that potentially
thousands of requests can get through to the remote, leading to rate
limits and network errors with some remotes.

* Introduce a much more granular lock scope
* Fix the issue by extending the lock region
* Switch mutex to async to avoid blocking runtime
* Improve tracing

commit-id:7d950008
  • Loading branch information
vlad-ivanov-name authored May 24, 2024
1 parent 012f8dc commit 3043aeb
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 92 deletions.
248 changes: 168 additions & 80 deletions josh-proxy/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,53 @@
use std::sync::Arc;

// Import the base64 crate Engine trait anonymously so we can
// call its methods without adding to the namespace.
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::engine::Engine as _;
use tracing::Instrument;

// Auths in those groups are independent of each other.
// This lets us reduce mutex contention
#[derive(Hash, Eq, PartialEq, Clone)]
struct AuthTimersGroupKey {
url: String,
username: String,
}

lazy_static! {
static ref AUTH: std::sync::Mutex<std::collections::HashMap<Handle, Header>> =
std::sync::Mutex::new(std::collections::HashMap::new());
static ref AUTH_TIMERS: std::sync::Mutex<AuthTimers> =
std::sync::Mutex::new(std::collections::HashMap::new());
impl AuthTimersGroupKey {
fn new(url: &str, handle: &Handle) -> Self {
let (username, _) = handle.parse().unwrap_or_default();

Self {
url: url.to_string(),
username,
}
}
}

type AuthTimers = std::collections::HashMap<(String, Handle), std::time::Instant>;
// Within a group, we can hold the lock for longer to verify the auth with upstream
type AuthTimersGroup = std::collections::HashMap<Handle, std::time::Instant>;
type AuthTimers =
std::collections::HashMap<AuthTimersGroupKey, Arc<tokio::sync::Mutex<AuthTimersGroup>>>;

lazy_static! {
// Note the use of std::sync::Mutex: access to those structures should only be performed
// shortly, without blocking the async runtime for long time and without holding the
// lock across an await point.
static ref AUTH: std::sync::Mutex<std::collections::HashMap<Handle, Header>> = Default::default();
static ref AUTH_TIMERS: std::sync::Mutex<AuthTimers> = Default::default();
}

// Wrapper struct for storing passwords to avoid having
// them output to traces by accident
#[derive(Clone)]
#[derive(Clone, Default)]
struct Header {
pub header: Option<hyper::header::HeaderValue>,
}

#[derive(Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
pub struct Handle {
pub hash: String,
pub hash: Option<String>,
}

impl std::fmt::Debug for Handle {
Expand All @@ -32,39 +58,50 @@ impl std::fmt::Debug for Handle {

impl Handle {
// Returns a pair: (username, password)
pub fn parse(&self) -> josh::JoshResult<(String, String)> {
let line = josh::some_or!(
AUTH.lock()
pub fn parse(&self) -> Option<(String, String)> {
let get_result = || -> josh::JoshResult<(String, String)> {
let line = AUTH
.lock()
.unwrap()
.get(self)
.and_then(|h| h.header.as_ref())
.map(|h| h.as_bytes().to_owned()),
{
return Ok(("".to_string(), "".to_string()));
}
);
.map(|h| h.as_bytes().to_owned())
.ok_or_else(|| josh::josh_error("no auth found"))?;

let u = josh::ok_or!(String::from_utf8(line[6..].to_vec()), {
return Ok(("".to_string(), "".to_string()));
});
let decoded = josh::ok_or!(BASE64.decode(u), {
return Ok(("".to_string(), "".to_string()));
});
let s = josh::ok_or!(String::from_utf8(decoded), {
return Ok(("".to_string(), "".to_string()));
});
let (username, password) = s.as_str().split_once(':').unwrap_or(("", ""));
Ok((username.to_string(), password.to_string()))
let line = String::from_utf8(line)?;
let (_, token) = line
.split_once(' ')
.ok_or_else(|| josh::josh_error("Unsupported auth type"))?;

let decoded = BASE64.decode(token)?;
let decoded = String::from_utf8(decoded)?;

let (username, password) = decoded
.split_once(':')
.ok_or_else(|| josh::josh_error("No password found"))?;

Ok((username.to_string(), password.to_string()))
};

match get_result() {
Ok(pair) => Some(pair),
Err(e) => {
tracing::trace!(
handle = ?self,
"Falling back to default auth: {:?}",
e
);

None
}
}
}
}

pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
let header = hyper::header::HeaderValue::from_str(&format!("Basic {}", BASE64.encode(token)))?;
let hp = Handle {
hash: format!(
"{:?}",
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?
),
hash: Some(git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?.to_string()),
};
let p = Header {
header: Some(header),
Expand All @@ -73,65 +110,122 @@ pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
Ok(hp)
}

pub async fn check_auth(url: &str, auth: &Handle, required: bool) -> josh::JoshResult<bool> {
if required && auth.hash.is_empty() {
return Ok(false);
}
#[tracing::instrument()]
pub async fn check_http_auth(url: &str, auth: &Handle, required: bool) -> josh::JoshResult<bool> {
use opentelemetry_semantic_conventions::trace::HTTP_RESPONSE_STATUS_CODE;

if let Some(last) = AUTH_TIMERS.lock()?.get(&(url.to_string(), auth.clone())) {
let since = std::time::Instant::now().duration_since(*last);
tracing::trace!("last: {:?}, since: {:?}", last, since);
if since < std::time::Duration::from_secs(60 * 30) {
tracing::trace!("cached auth");
return Ok(true);
}
if required && auth.hash.is_none() {
return Ok(false);
}

tracing::trace!("no cached auth {:?}", *AUTH_TIMERS.lock()?);
let group_key = AuthTimersGroupKey::new(url, &auth);
let auth_timers = AUTH_TIMERS
.lock()
.unwrap()
.entry(group_key.clone())
.or_default()
.clone();

let https = hyper_tls::HttpsConnector::new();
let client = hyper::Client::builder().build::<_, hyper::Body>(https);
let auth_header = AUTH.lock().unwrap().get(auth).cloned().unwrap_or_default();

let password = AUTH
.lock()?
.get(auth)
.unwrap_or(&Header { header: None })
.to_owned();
let refs_url = format!("{}/info/refs?service=git-upload-pack", url);
let do_request = || {
let refs_url = refs_url.clone();
let do_request_span = tracing::info_span!("check_http_auth: make request");

let builder = hyper::Request::builder()
.method(hyper::Method::GET)
.uri(&refs_url);
async move {
let https = hyper_tls::HttpsConnector::new();
let client = hyper::Client::builder().build::<_, hyper::Body>(https);

let builder = if let Some(value) = password.header {
builder.header(hyper::header::AUTHORIZATION, value)
} else {
builder
let builder = hyper::Request::builder()
.method(hyper::Method::GET)
.uri(&refs_url);

let builder = if let Some(value) = auth_header.header {
builder.header(hyper::header::AUTHORIZATION, value)
} else {
builder
};

let request = builder.body(hyper::Body::empty())?;
let resp = client.request(request).await?;

Ok::<_, josh::JoshError>(resp)
}
.instrument(do_request_span)
};

let request = builder.body(hyper::Body::empty())?;
let resp = client.request(request).await?;
// Only lock the mutex if auth handle is not empty, because otherwise
// for remotes that require auth, we could run into situation where
// multiple requests are executed essentially sequentially because
// remote always returns 401 for authenticated requests and we never
// populate the auth_timers map
let resp = if auth.hash.is_some() {
let mut auth_timers = auth_timers.lock().await;

if let Some(last) = auth_timers.get(auth) {
let since = std::time::Instant::now().duration_since(*last);
let expired = since > std::time::Duration::from_secs(60 * 30);

tracing::info!(
last = ?last,
since = ?since,
expired = %expired,
"check_http_auth: found auth entry"
);

if !expired {
return Ok(true);
}
}

let status = resp.status();
tracing::info!(
auth_timers = ?auth_timers,
"check_http_auth: no valid cached auth"
);

tracing::trace!("http resp.status {:?}", resp.status());
let resp = do_request().await?;
if resp.status().is_success() {
auth_timers.insert(auth.clone(), std::time::Instant::now());
}

resp
} else {
do_request().await?
};

let status = resp.status();

let err_msg = format!("got http response: {} {:?}", refs_url, resp);
tracing::event!(
tracing::Level::INFO,
{ HTTP_RESPONSE_STATUS_CODE } = status.as_u16(),
"check_http_auth: response"
);

if status == hyper::StatusCode::OK {
AUTH_TIMERS
.lock()?
.insert((url.to_string(), auth.clone()), std::time::Instant::now());
Ok(true)
} else if status == hyper::StatusCode::UNAUTHORIZED {
tracing::warn!("resp.status == 401: {:?}", &err_msg);
tracing::trace!(
"body: {:?}",
std::str::from_utf8(&hyper::body::to_bytes(resp.into_body()).await?)
tracing::event!(
tracing::Level::WARN,
{ HTTP_RESPONSE_STATUS_CODE } = status.as_u16(),
"check_http_auth: unauthorized"
);

let response = hyper::body::to_bytes(resp.into_body()).await?;
let response = String::from_utf8_lossy(&response);

tracing::event!(
tracing::Level::TRACE,
"http.response.body" = %response,
"check_http_auth: unauthorized",
);

Ok(false)
} else {
return Err(josh::josh_error(&err_msg));
return Err(josh::josh_error(&format!(
"check_http_auth: got http response: {} {:?}",
refs_url, resp
)));
}
}

Expand All @@ -144,9 +238,8 @@ pub fn strip_auth(

if let Some(header) = header {
let hp = Handle {
hash: format!(
"{:?}",
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?
hash: Some(
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?.to_string(),
),
};
let p = Header {
Expand All @@ -156,10 +249,5 @@ pub fn strip_auth(
return Ok((hp, req));
}

Ok((
Handle {
hash: "".to_owned(),
},
req,
))
Ok((Handle { hash: None }, req))
}
19 changes: 8 additions & 11 deletions josh-proxy/src/bin/josh-proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,13 @@ async fn fetch_upstream(

match (fetch_result, remote_auth) {
(Ok(_), RemoteAuth::Http { auth }) => {
let (auth_user, _) = auth.parse().map_err(FetchError::from_josh_error)?;

if matches!(&ARGS.poll_user, Some(user) if auth_user == user.as_str()) {
service
.poll
.lock()?
.insert((upstream_repo, auth.clone(), remote_url));
if let Some((auth_user, _)) = auth.parse() {
if matches!(&ARGS.poll_user, Some(user) if auth_user == user.as_str()) {
service
.poll
.lock()?
.insert((upstream_repo, auth.clone(), remote_url));
}
}

Ok(())
Expand Down Expand Up @@ -1275,10 +1275,7 @@ async fn call_service(

let http_auth_required = ARGS.require_auth && parsed_url.pathinfo == "/git-receive-pack";

if !josh_proxy::auth::check_auth(&remote_url, &auth, http_auth_required)
.in_current_span()
.await?
{
if !josh_proxy::auth::check_http_auth(&remote_url, &auth, http_auth_required).await? {
tracing::trace!("require-auth");
let builder = Response::builder()
.header(
Expand Down
2 changes: 1 addition & 1 deletion josh-proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ pub fn run_git_with_auth(
Ok(shell.command_env(cmd, &env, &env_notrace))
}
RemoteAuth::Http { auth } => {
let (username, password) = auth.parse()?;
let (username, password) = auth.parse().unwrap_or_default();
let env_notrace = [
[
("GIT_PASSWORD", password.as_str()),
Expand Down

0 comments on commit 3043aeb

Please sign in to comment.