Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in check_auth #1331

Merged
merged 3 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 168 additions & 80 deletions josh-proxy/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,53 @@
use std::sync::Arc;

// Import the base64 crate Engine trait anonymously so we can
// call its methods without adding to the namespace.
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::engine::Engine as _;
use tracing::Instrument;

// Auths in those groups are independent of each other.
// This lets us reduce mutex contention
#[derive(Hash, Eq, PartialEq, Clone)]
struct AuthTimersGroupKey {
url: String,
username: String,
}

lazy_static! {
static ref AUTH: std::sync::Mutex<std::collections::HashMap<Handle, Header>> =
std::sync::Mutex::new(std::collections::HashMap::new());
static ref AUTH_TIMERS: std::sync::Mutex<AuthTimers> =
std::sync::Mutex::new(std::collections::HashMap::new());
impl AuthTimersGroupKey {
fn new(url: &str, handle: &Handle) -> Self {
let (username, _) = handle.parse().unwrap_or_default();

Self {
url: url.to_string(),
username,
}
}
}

type AuthTimers = std::collections::HashMap<(String, Handle), std::time::Instant>;
// Within a group, we can hold the lock for longer to verify the auth with upstream
type AuthTimersGroup = std::collections::HashMap<Handle, std::time::Instant>;
type AuthTimers =
std::collections::HashMap<AuthTimersGroupKey, Arc<tokio::sync::Mutex<AuthTimersGroup>>>;

lazy_static! {
// Note the use of std::sync::Mutex: access to those structures should only be performed
// shortly, without blocking the async runtime for long time and without holding the
// lock across an await point.
static ref AUTH: std::sync::Mutex<std::collections::HashMap<Handle, Header>> = Default::default();
static ref AUTH_TIMERS: std::sync::Mutex<AuthTimers> = Default::default();
}

// Wrapper struct for storing passwords to avoid having
// them output to traces by accident
#[derive(Clone)]
#[derive(Clone, Default)]
struct Header {
pub header: Option<hyper::header::HeaderValue>,
}

#[derive(Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
pub struct Handle {
pub hash: String,
pub hash: Option<String>,
}

impl std::fmt::Debug for Handle {
Expand All @@ -32,39 +58,50 @@ impl std::fmt::Debug for Handle {

impl Handle {
// Returns a pair: (username, password)
pub fn parse(&self) -> josh::JoshResult<(String, String)> {
let line = josh::some_or!(
AUTH.lock()
pub fn parse(&self) -> Option<(String, String)> {
let get_result = || -> josh::JoshResult<(String, String)> {
let line = AUTH
.lock()
.unwrap()
.get(self)
.and_then(|h| h.header.as_ref())
.map(|h| h.as_bytes().to_owned()),
{
return Ok(("".to_string(), "".to_string()));
}
);
.map(|h| h.as_bytes().to_owned())
.ok_or_else(|| josh::josh_error("no auth found"))?;

let u = josh::ok_or!(String::from_utf8(line[6..].to_vec()), {
return Ok(("".to_string(), "".to_string()));
});
let decoded = josh::ok_or!(BASE64.decode(u), {
return Ok(("".to_string(), "".to_string()));
});
let s = josh::ok_or!(String::from_utf8(decoded), {
return Ok(("".to_string(), "".to_string()));
});
let (username, password) = s.as_str().split_once(':').unwrap_or(("", ""));
Ok((username.to_string(), password.to_string()))
let line = String::from_utf8(line)?;
let (_, token) = line
.split_once(' ')
.ok_or_else(|| josh::josh_error("Unsupported auth type"))?;

let decoded = BASE64.decode(token)?;
let decoded = String::from_utf8(decoded)?;

let (username, password) = decoded
.split_once(':')
.ok_or_else(|| josh::josh_error("No password found"))?;

Ok((username.to_string(), password.to_string()))
};

match get_result() {
Ok(pair) => Some(pair),
Err(e) => {
tracing::trace!(
handle = ?self,
"Falling back to default auth: {:?}",
e
);

None
}
}
}
}

pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
let header = hyper::header::HeaderValue::from_str(&format!("Basic {}", BASE64.encode(token)))?;
let hp = Handle {
hash: format!(
"{:?}",
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?
),
hash: Some(git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?.to_string()),
};
let p = Header {
header: Some(header),
Expand All @@ -73,65 +110,122 @@ pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
Ok(hp)
}

pub async fn check_auth(url: &str, auth: &Handle, required: bool) -> josh::JoshResult<bool> {
if required && auth.hash.is_empty() {
return Ok(false);
}
#[tracing::instrument()]
pub async fn check_http_auth(url: &str, auth: &Handle, required: bool) -> josh::JoshResult<bool> {
use opentelemetry_semantic_conventions::trace::HTTP_RESPONSE_STATUS_CODE;

if let Some(last) = AUTH_TIMERS.lock()?.get(&(url.to_string(), auth.clone())) {
let since = std::time::Instant::now().duration_since(*last);
tracing::trace!("last: {:?}, since: {:?}", last, since);
if since < std::time::Duration::from_secs(60 * 30) {
tracing::trace!("cached auth");
return Ok(true);
}
if required && auth.hash.is_none() {
return Ok(false);
}

tracing::trace!("no cached auth {:?}", *AUTH_TIMERS.lock()?);
let group_key = AuthTimersGroupKey::new(url, &auth);
let auth_timers = AUTH_TIMERS
.lock()
.unwrap()
.entry(group_key.clone())
.or_default()
.clone();

let https = hyper_tls::HttpsConnector::new();
let client = hyper::Client::builder().build::<_, hyper::Body>(https);
let auth_header = AUTH.lock().unwrap().get(auth).cloned().unwrap_or_default();

let password = AUTH
.lock()?
.get(auth)
.unwrap_or(&Header { header: None })
.to_owned();
let refs_url = format!("{}/info/refs?service=git-upload-pack", url);
let do_request = || {
let refs_url = refs_url.clone();
let do_request_span = tracing::info_span!("check_http_auth: make request");

let builder = hyper::Request::builder()
.method(hyper::Method::GET)
.uri(&refs_url);
async move {
let https = hyper_tls::HttpsConnector::new();
let client = hyper::Client::builder().build::<_, hyper::Body>(https);

let builder = if let Some(value) = password.header {
builder.header(hyper::header::AUTHORIZATION, value)
} else {
builder
let builder = hyper::Request::builder()
.method(hyper::Method::GET)
.uri(&refs_url);

let builder = if let Some(value) = auth_header.header {
builder.header(hyper::header::AUTHORIZATION, value)
} else {
builder
};

let request = builder.body(hyper::Body::empty())?;
let resp = client.request(request).await?;

Ok::<_, josh::JoshError>(resp)
}
.instrument(do_request_span)
};

let request = builder.body(hyper::Body::empty())?;
let resp = client.request(request).await?;
// Only lock the mutex if auth handle is not empty, because otherwise
// for remotes that require auth, we could run into situation where
// multiple requests are executed essentially sequentially because
// remote always returns 401 for authenticated requests and we never
// populate the auth_timers map
let resp = if auth.hash.is_some() {
let mut auth_timers = auth_timers.lock().await;

if let Some(last) = auth_timers.get(auth) {
let since = std::time::Instant::now().duration_since(*last);
let expired = since > std::time::Duration::from_secs(60 * 30);

tracing::info!(
last = ?last,
since = ?since,
expired = %expired,
"check_http_auth: found auth entry"
);

if !expired {
return Ok(true);
}
}

let status = resp.status();
tracing::info!(
auth_timers = ?auth_timers,
"check_http_auth: no valid cached auth"
);

tracing::trace!("http resp.status {:?}", resp.status());
let resp = do_request().await?;
if resp.status().is_success() {
auth_timers.insert(auth.clone(), std::time::Instant::now());
}

resp
} else {
do_request().await?
};

let status = resp.status();

let err_msg = format!("got http response: {} {:?}", refs_url, resp);
tracing::event!(
tracing::Level::INFO,
{ HTTP_RESPONSE_STATUS_CODE } = status.as_u16(),
"check_http_auth: response"
);

if status == hyper::StatusCode::OK {
AUTH_TIMERS
.lock()?
.insert((url.to_string(), auth.clone()), std::time::Instant::now());
Ok(true)
} else if status == hyper::StatusCode::UNAUTHORIZED {
tracing::warn!("resp.status == 401: {:?}", &err_msg);
tracing::trace!(
"body: {:?}",
std::str::from_utf8(&hyper::body::to_bytes(resp.into_body()).await?)
tracing::event!(
tracing::Level::WARN,
{ HTTP_RESPONSE_STATUS_CODE } = status.as_u16(),
"check_http_auth: unauthorized"
);

let response = hyper::body::to_bytes(resp.into_body()).await?;
let response = String::from_utf8_lossy(&response);

tracing::event!(
tracing::Level::TRACE,
"http.response.body" = %response,
"check_http_auth: unauthorized",
);

Ok(false)
} else {
return Err(josh::josh_error(&err_msg));
return Err(josh::josh_error(&format!(
"check_http_auth: got http response: {} {:?}",
refs_url, resp
)));
}
}

Expand All @@ -144,9 +238,8 @@ pub fn strip_auth(

if let Some(header) = header {
let hp = Handle {
hash: format!(
"{:?}",
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?
hash: Some(
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?.to_string(),
),
};
let p = Header {
Expand All @@ -156,10 +249,5 @@ pub fn strip_auth(
return Ok((hp, req));
}

Ok((
Handle {
hash: "".to_owned(),
},
req,
))
Ok((Handle { hash: None }, req))
}
19 changes: 8 additions & 11 deletions josh-proxy/src/bin/josh-proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,13 @@ async fn fetch_upstream(

match (fetch_result, remote_auth) {
(Ok(_), RemoteAuth::Http { auth }) => {
let (auth_user, _) = auth.parse().map_err(FetchError::from_josh_error)?;

if matches!(&ARGS.poll_user, Some(user) if auth_user == user.as_str()) {
service
.poll
.lock()?
.insert((upstream_repo, auth.clone(), remote_url));
if let Some((auth_user, _)) = auth.parse() {
if matches!(&ARGS.poll_user, Some(user) if auth_user == user.as_str()) {
service
.poll
.lock()?
.insert((upstream_repo, auth.clone(), remote_url));
}
}

Ok(())
Expand Down Expand Up @@ -1275,10 +1275,7 @@ async fn call_service(

let http_auth_required = ARGS.require_auth && parsed_url.pathinfo == "/git-receive-pack";

if !josh_proxy::auth::check_auth(&remote_url, &auth, http_auth_required)
.in_current_span()
.await?
{
if !josh_proxy::auth::check_http_auth(&remote_url, &auth, http_auth_required).await? {
tracing::trace!("require-auth");
let builder = Response::builder()
.header(
Expand Down
2 changes: 1 addition & 1 deletion josh-proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ pub fn run_git_with_auth(
Ok(shell.command_env(cmd, &env, &env_notrace))
}
RemoteAuth::Http { auth } => {
let (username, password) = auth.parse()?;
let (username, password) = auth.parse().unwrap_or_default();
let env_notrace = [
[
("GIT_PASSWORD", password.as_str()),
Expand Down
Loading