Skip to content

Commit

Permalink
feat: implementing authenticator from scratch
Browse files Browse the repository at this point in the history
  • Loading branch information
junkurihara committed Oct 20, 2023
1 parent e78d507 commit 9c56f9f
Show file tree
Hide file tree
Showing 14 changed files with 663 additions and 43 deletions.
42 changes: 29 additions & 13 deletions dap-bin/src/config/toml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,18 @@ impl TryInto<ProxyConfig> for &ConfigToml {
/////////////////////////////
// bootstrap dns
if let Some(val) = &self.bootstrap_dns {
if !val.iter().all(|v| verify_sock_addr(v).is_ok()) {
if !val.iter().all(|v| verify_ip_addr(v).is_ok()) {
bail!("Invalid bootstrap DNS address");
}
proxy_config.bootstrap_dns = val.iter().map(|x| x.parse().unwrap()).collect()
proxy_config.bootstrap_dns.ips = val.iter().map(|x| x.parse().unwrap()).collect()
};
info!("Bootstrap DNS: {:?}", proxy_config.bootstrap_dns);
info!("Bootstrap DNS: {:?}", proxy_config.bootstrap_dns.ips);
if let Some(val) = self.reboot_period {
proxy_config.rebootstrap_period_sec = Duration::from_secs((val as u64) * 60);
proxy_config.bootstrap_dns.rebootstrap_period_sec = Duration::from_secs((val as u64) * 60);
}
info!(
"Target DoH Address is re-fetched every {:?} min via Bootsrap DNS",
proxy_config.rebootstrap_period_sec.as_secs() / 60
proxy_config.bootstrap_dns.rebootstrap_period_sec.as_secs() / 60
);

/////////////////////////////
Expand All @@ -93,11 +93,16 @@ impl TryInto<ProxyConfig> for &ConfigToml {
if !val.iter().all(|x| verify_target_url(x).is_ok()) {
bail!("Invalid target urls");
}
proxy_config.target_config.doh_target_urls = val.to_owned();
proxy_config.target_config.doh_target_urls = val.iter().map(|v| url::Url::parse(v).unwrap()).collect();
}
info!(
"Target (O)DoH resolvers: {:?}",
proxy_config.target_config.doh_target_urls
proxy_config
.target_config
.doh_target_urls
.iter()
.map(|x| x.as_str())
.collect::<Vec<_>>()
);
if let Some(val) = &self.target_randomization {
if !val {
Expand All @@ -122,11 +127,18 @@ impl TryInto<ProxyConfig> for &ConfigToml {
bail!("Invalid ODoH relay urls");
}
let mut nexthop_relay_config = NextHopRelayConfig {
odoh_relay_urls: odoh_relay_urls.to_owned(),
odoh_relay_urls: odoh_relay_urls.iter().map(|v| url::Url::parse(v).unwrap()).collect(),
odoh_relay_randomization: true,
};
info!("[ODoH] Oblivious DNS over HTTPS is enabled");
info!("[ODoH] Nexthop relay URL: {:?}", nexthop_relay_config.odoh_relay_urls);
info!(
"[ODoH] Nexthop relay URL: {:?}",
nexthop_relay_config
.odoh_relay_urls
.iter()
.map(|x| x.as_str())
.collect::<Vec<_>>()
);

if let Some(val) = anon.odoh_relay_randomization {
nexthop_relay_config.odoh_relay_randomization = val;
Expand All @@ -149,14 +161,18 @@ impl TryInto<ProxyConfig> for &ConfigToml {
bail!("max_mid_relays must be equal to or less than # of mid_relay_urls.");
}
let subseq_relay_config = SubseqRelayConfig {
mid_relay_urls: val.to_owned(),
mid_relay_urls: val.iter().map(|v| url::Url::parse(v).unwrap()).collect(),
max_mid_relays: anon.max_mid_relays.unwrap_or(1),
};

info!("[m-ODoH] Multiple-relay-based Oblivious DNS over HTTPS is enabled");
info!(
"[m-ODoH] Intermediate relay URLs employed after the next hop: {:?}",
subseq_relay_config.mid_relay_urls
subseq_relay_config
.mid_relay_urls
.iter()
.map(|x| x.as_str())
.collect::<Vec<_>>()
);
info!(
"[m-ODoH] Maximum number of intermediate relays after the nexthop: {}",
Expand Down Expand Up @@ -193,7 +209,7 @@ impl TryInto<ProxyConfig> for &ConfigToml {
username,
password,
client_id,
token_api: token_api.to_owned(),
token_api: token_api.parse().unwrap(),
};
proxy_config.authentication_config = Some(authentication_config);
}
Expand All @@ -217,7 +233,7 @@ impl TryInto<ProxyConfig> for &ConfigToml {

////////////////////////

// TODO: plugin関係は既存のコンフィグ何も読んでないので注意。rpxyのcryptosourcereloaderと同じように処理しなければいけない
// TODO: plugin関係は既存のコンフィグ何も読んでないので注意。rpxyのcrypto sourcere loaderと同じように処理しなければいけない

Ok(proxy_config)
}
Expand Down
9 changes: 8 additions & 1 deletion dap-bin/src/config/utils_verifier.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// functions to verify the startup arguments as correct
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use url::Url;

pub(crate) fn verify_sock_addr(arg_val: &str) -> Result<(), String> {
Expand All @@ -12,6 +12,13 @@ pub(crate) fn verify_sock_addr(arg_val: &str) -> Result<(), String> {
}
}

pub(crate) fn verify_ip_addr(arg_val: &str) -> Result<(), String> {
match arg_val.parse::<IpAddr>() {
Ok(_addr) => Ok(()),
Err(_) => Err(format!("Could not parse \"{}\" as a valid ip address.", arg_val)),
}
}

pub(crate) fn verify_target_url(arg_val: &str) -> Result<(), String> {
let url = match Url::parse(arg_val) {
Ok(addr) => addr,
Expand Down
24 changes: 24 additions & 0 deletions dap-lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,27 @@ tokio = { version = "1.32.0", features = [
futures = { version = "0.3.28", default-features = false }
anyhow = "1.0.75"
tracing = "0.1.37"
thiserror = "1.0.50"
async-trait = "0.1.74"

# http client
reqwest = { version = "0.11.22", default-features = false, features = [
"json",
"trust-dns",
"default",
] }
url = "2.4.1"


# for bootstrap dns resolver
trust-dns-resolver = { version = "0.23.1", default-features = false, features = [
"tokio-runtime",
] }


# authentication
jwt-simple = "0.11.7"
chrono = "0.4.31"
serde = { version = "1.0.189", features = ["derive"] }
serde_json = "1.0.107"
p256 = { version = "0.13.2", features = ["jwk", "pem"] }
228 changes: 228 additions & 0 deletions dap-lib/src/auth/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
use super::token::{Algorithm, TokenInner, TokenMeta};
use crate::{
constants::{ENDPOINT_JWKS_PATH, ENDPOINT_LOGIN_PATH},
error::*,
globals::AuthenticationConfig,
http::HttpClient,
log::*,
};
use jwt_simple::prelude::{JWTClaims, NoCustomClaims};
use serde::{Deserialize, Serialize};
use std::{str::FromStr, sync::Arc};
use tokio::sync::RwLock;

/// Authentication request
#[derive(Serialize)]
pub struct AuthenticationRequest {
auth: AuthenticationReqInner,
client_id: String,
}
#[derive(Serialize)]
/// Auth req inner
pub struct AuthenticationReqInner {
username: String,
password: String,
}

#[derive(Deserialize, Debug)]
/// Auth response
pub struct AuthenticationResponse {
pub token: TokenInner,
pub metadata: TokenMeta,
pub message: String,
}

#[derive(Deserialize, Debug)]
pub struct Jwks {
pub keys: Vec<serde_json::Value>,
}

/// Authenticator client
pub struct Authenticator {
config: AuthenticationConfig,
http_client: Arc<RwLock<HttpClient>>,
id_token: Arc<RwLock<Option<TokenInner>>>,
refresh_token: Arc<RwLock<Option<String>>>,
validation_key: Arc<RwLock<Option<String>>>,
}

impl Authenticator {
/// Build authenticator
pub async fn new(auth_config: &AuthenticationConfig, http_client: Arc<RwLock<HttpClient>>) -> Result<Self> {
Ok(Self {
config: auth_config.clone(),
http_client,
id_token: Arc::new(RwLock::new(None)),
refresh_token: Arc::new(RwLock::new(None)),
validation_key: Arc::new(RwLock::new(None)),
})
}

/// Update jwks key
async fn update_validation_key(&self) -> Result<()> {
let id_token_lock = self.id_token.read().await;
let Some(id_token) = id_token_lock.as_ref() else {
return Err(DapError::AuthenticationError("No id token".to_string()));
};
let meta = id_token.decode_id_token().await?;
drop(id_token_lock);

let mut jwks_endpoint = self.config.token_api.clone();
jwks_endpoint
.path_segments_mut()
.map_err(|_| DapError::Other(anyhow!("Failed to parse token api url".to_string())))?
.push(ENDPOINT_JWKS_PATH);

let client_lock = self.http_client.read().await;
let res = client_lock
.get(jwks_endpoint)
.await
.send()
.await
.map_err(|e| DapError::AuthenticationError(e.to_string()))?;
drop(client_lock);

if !res.status().is_success() {
error!("Jwks retrieval error!: {:?}", res.status());
return Err(DapError::AuthenticationError(format!(
"Jwks retrieval error!: {:?}",
res.status()
)));
}

let jwks = res
.json::<Jwks>()
.await
.map_err(|_e| DapError::AuthenticationError("Failed to parse jwks response".to_string()))?;

let key_id = meta
.key_id()
.ok_or_else(|| DapError::AuthenticationError("No key id".to_string()))?;

let matched_key = jwks.keys.iter().find(|x| {
let kid = x["kid"].as_str().unwrap_or("");
kid == key_id
});
if matched_key.is_none() {
return Err(DapError::AuthenticationError(format!(
"No JWK matched to Id token is given at jwks endpoint! key_id: {}",
key_id
)));
}

let mut matched = matched_key.unwrap().clone();
let Some(matched_jwk) = matched.as_object_mut() else {
return Err(DapError::AuthenticationError(
"Invalid jwk retrieved from jwks endpoint".to_string(),
));
};
matched_jwk.remove_entry("kid");
let Ok(jwk_string) = serde_json::to_string(matched_jwk) else {
return Err(DapError::AuthenticationError("Failed to serialize jwk".to_string()));
};
debug!("Matched JWK given at jwks endpoint is {}", &jwk_string);

let pem = match Algorithm::from_str(meta.algorithm())? {
Algorithm::ES256 => {
let pk =
p256::PublicKey::from_jwk_str(&jwk_string).map_err(|e| DapError::AuthenticationError(e.to_string()))?;
pk.to_string()
}
};

let mut validation_key_lock = self.validation_key.write().await;
validation_key_lock.replace(pem.clone());
drop(validation_key_lock);

info!("validation key updated");

Ok(())
}

/// Verify id token
async fn verify_id_token(&self) -> Result<JWTClaims<NoCustomClaims>> {
let pk_str_lock = self.validation_key.read().await;
let Some(pk_str) = pk_str_lock.as_ref() else {
return Err(DapError::AuthenticationError("No validation key".to_string()));
};
let pk_str = pk_str.clone();
drop(pk_str_lock);

let token_lock = self.id_token.read().await;
let Some(token_inner) = token_lock.as_ref() else {
return Err(DapError::AuthenticationError("No id token".to_string()));
};
let token = token_inner.clone();
drop(token_lock);

token.verify_id_token(&pk_str, &self.config).await
}

/// Login to the authentication server
pub async fn login(&self) -> Result<()> {
let mut login_endpoint = self.config.token_api.clone();
login_endpoint
.path_segments_mut()
.map_err(|_| DapError::Other(anyhow!("Failed to parse token api url".to_string())))?
.push(ENDPOINT_LOGIN_PATH);

let json_request = AuthenticationRequest {
auth: AuthenticationReqInner {
username: self.config.username.clone(),
password: self.config.password.clone(),
},
client_id: self.config.client_id.clone(),
};

let client_lock = self.http_client.read().await;
let res = client_lock
.post(login_endpoint)
.await
.json(&json_request)
.send()
.await
.map_err(|e| DapError::AuthenticationError(e.to_string()))?;
drop(client_lock);

if !res.status().is_success() {
error!("Login error!: {:?}", res.status());
return Err(DapError::AuthenticationError(format!(
"Login error!: {:?}",
res.status()
)));
}

// parse token
let token = res
.json::<AuthenticationResponse>()
.await
.map_err(|_e| DapError::AuthenticationError("Failed to parse token response".to_string()))?;

if let Some(refresh) = &token.token.refresh {
let mut refresh_token_lock = self.refresh_token.write().await;
refresh_token_lock.replace(refresh.clone());
drop(refresh_token_lock);
}

let mut id_token_lock = self.id_token.write().await;
id_token_lock.replace(token.token);
drop(id_token_lock);

info!("Token retrieved");

// update validation key
self.update_validation_key().await?;

// verify id token with validation key
let Ok(_clm) = self.verify_id_token().await else {
return Err(DapError::AuthenticationError(
"Invalid Id token! Carefully check if target DNS or Token API is compromised!".to_string(),
));
};

info!("Login success!");
Ok(())
}

// TODO: refresh by checking the expiration time
}
4 changes: 4 additions & 0 deletions dap-lib/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod auth;
mod token;

pub use auth::Authenticator;
Loading

0 comments on commit 9c56f9f

Please sign in to comment.