From 9c56f9f17316dc3a0af7bca338c23684b6753546 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 20 Oct 2023 23:33:36 +0900 Subject: [PATCH] feat: implementing authenticator from scratch --- dap-bin/src/config/toml.rs | 42 +++-- dap-bin/src/config/utils_verifier.rs | 9 +- dap-lib/Cargo.toml | 24 +++ dap-lib/src/auth/auth.rs | 228 +++++++++++++++++++++++++++ dap-lib/src/auth/mod.rs | 4 + dap-lib/src/auth/token.rs | 96 +++++++++++ dap-lib/src/bootstrap.rs | 97 ++++++++++++ dap-lib/src/constants.rs | 11 +- dap-lib/src/error.rs | 29 +++- dap-lib/src/globals.rs | 40 +++-- dap-lib/src/http.rs | 64 ++++++++ dap-lib/src/lib.rs | 60 ++++++- dap-lib/src/proxy/mod.rs | 0 doh-auth-proxy.toml | 2 +- 14 files changed, 663 insertions(+), 43 deletions(-) create mode 100644 dap-lib/src/auth/auth.rs create mode 100644 dap-lib/src/auth/mod.rs create mode 100644 dap-lib/src/auth/token.rs create mode 100644 dap-lib/src/bootstrap.rs create mode 100644 dap-lib/src/http.rs create mode 100644 dap-lib/src/proxy/mod.rs diff --git a/dap-bin/src/config/toml.rs b/dap-bin/src/config/toml.rs index 9f16b9f..375b977 100644 --- a/dap-bin/src/config/toml.rs +++ b/dap-bin/src/config/toml.rs @@ -66,18 +66,18 @@ impl TryInto for &ConfigToml { ///////////////////////////// // bootstrap dns if let Some(val) = &self.bootstrap_dns { - if !val.iter().all(|v| verify_sock_addr(v).is_ok()) { + if !val.iter().all(|v| verify_ip_addr(v).is_ok()) { bail!("Invalid bootstrap DNS address"); } - proxy_config.bootstrap_dns = val.iter().map(|x| x.parse().unwrap()).collect() + proxy_config.bootstrap_dns.ips = val.iter().map(|x| x.parse().unwrap()).collect() }; - info!("Bootstrap DNS: {:?}", proxy_config.bootstrap_dns); + info!("Bootstrap DNS: {:?}", proxy_config.bootstrap_dns.ips); if let Some(val) = self.reboot_period { - proxy_config.rebootstrap_period_sec = Duration::from_secs((val as u64) * 60); + proxy_config.bootstrap_dns.rebootstrap_period_sec = Duration::from_secs((val as u64) * 60); } info!( "Target DoH Address is re-fetched every {:?} min via Bootsrap DNS", - proxy_config.rebootstrap_period_sec.as_secs() / 60 + proxy_config.bootstrap_dns.rebootstrap_period_sec.as_secs() / 60 ); ///////////////////////////// @@ -93,11 +93,16 @@ impl TryInto for &ConfigToml { if !val.iter().all(|x| verify_target_url(x).is_ok()) { bail!("Invalid target urls"); } - proxy_config.target_config.doh_target_urls = val.to_owned(); + proxy_config.target_config.doh_target_urls = val.iter().map(|v| url::Url::parse(v).unwrap()).collect(); } info!( "Target (O)DoH resolvers: {:?}", - proxy_config.target_config.doh_target_urls + proxy_config + .target_config + .doh_target_urls + .iter() + .map(|x| x.as_str()) + .collect::>() ); if let Some(val) = &self.target_randomization { if !val { @@ -122,11 +127,18 @@ impl TryInto for &ConfigToml { bail!("Invalid ODoH relay urls"); } let mut nexthop_relay_config = NextHopRelayConfig { - odoh_relay_urls: odoh_relay_urls.to_owned(), + odoh_relay_urls: odoh_relay_urls.iter().map(|v| url::Url::parse(v).unwrap()).collect(), odoh_relay_randomization: true, }; info!("[ODoH] Oblivious DNS over HTTPS is enabled"); - info!("[ODoH] Nexthop relay URL: {:?}", nexthop_relay_config.odoh_relay_urls); + info!( + "[ODoH] Nexthop relay URL: {:?}", + nexthop_relay_config + .odoh_relay_urls + .iter() + .map(|x| x.as_str()) + .collect::>() + ); if let Some(val) = anon.odoh_relay_randomization { nexthop_relay_config.odoh_relay_randomization = val; @@ -149,14 +161,18 @@ impl TryInto for &ConfigToml { bail!("max_mid_relays must be equal to or less than # of mid_relay_urls."); } let subseq_relay_config = SubseqRelayConfig { - mid_relay_urls: val.to_owned(), + mid_relay_urls: val.iter().map(|v| url::Url::parse(v).unwrap()).collect(), max_mid_relays: anon.max_mid_relays.unwrap_or(1), }; info!("[m-ODoH] Multiple-relay-based Oblivious DNS over HTTPS is enabled"); info!( "[m-ODoH] Intermediate relay URLs employed after the next hop: {:?}", - subseq_relay_config.mid_relay_urls + subseq_relay_config + .mid_relay_urls + .iter() + .map(|x| x.as_str()) + .collect::>() ); info!( "[m-ODoH] Maximum number of intermediate relays after the nexthop: {}", @@ -193,7 +209,7 @@ impl TryInto for &ConfigToml { username, password, client_id, - token_api: token_api.to_owned(), + token_api: token_api.parse().unwrap(), }; proxy_config.authentication_config = Some(authentication_config); } @@ -217,7 +233,7 @@ impl TryInto for &ConfigToml { //////////////////////// - // TODO: plugin関係は既存のコンフィグ何も読んでないので注意。rpxyのcryptosourcereloaderと同じように処理しなければいけない + // TODO: plugin関係は既存のコンフィグ何も読んでないので注意。rpxyのcrypto sourcere loaderと同じように処理しなければいけない Ok(proxy_config) } diff --git a/dap-bin/src/config/utils_verifier.rs b/dap-bin/src/config/utils_verifier.rs index fb8d47b..b27c246 100644 --- a/dap-bin/src/config/utils_verifier.rs +++ b/dap-bin/src/config/utils_verifier.rs @@ -1,5 +1,5 @@ // functions to verify the startup arguments as correct -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use url::Url; pub(crate) fn verify_sock_addr(arg_val: &str) -> Result<(), String> { @@ -12,6 +12,13 @@ pub(crate) fn verify_sock_addr(arg_val: &str) -> Result<(), String> { } } +pub(crate) fn verify_ip_addr(arg_val: &str) -> Result<(), String> { + match arg_val.parse::() { + Ok(_addr) => Ok(()), + Err(_) => Err(format!("Could not parse \"{}\" as a valid ip address.", arg_val)), + } +} + pub(crate) fn verify_target_url(arg_val: &str) -> Result<(), String> { let url = match Url::parse(arg_val) { Ok(addr) => addr, diff --git a/dap-lib/Cargo.toml b/dap-lib/Cargo.toml index a82e33a..1f1f611 100644 --- a/dap-lib/Cargo.toml +++ b/dap-lib/Cargo.toml @@ -41,3 +41,27 @@ tokio = { version = "1.32.0", features = [ futures = { version = "0.3.28", default-features = false } anyhow = "1.0.75" tracing = "0.1.37" +thiserror = "1.0.50" +async-trait = "0.1.74" + +# http client +reqwest = { version = "0.11.22", default-features = false, features = [ + "json", + "trust-dns", + "default", +] } +url = "2.4.1" + + +# for bootstrap dns resolver +trust-dns-resolver = { version = "0.23.1", default-features = false, features = [ + "tokio-runtime", +] } + + +# authentication +jwt-simple = "0.11.7" +chrono = "0.4.31" +serde = { version = "1.0.189", features = ["derive"] } +serde_json = "1.0.107" +p256 = { version = "0.13.2", features = ["jwk", "pem"] } diff --git a/dap-lib/src/auth/auth.rs b/dap-lib/src/auth/auth.rs new file mode 100644 index 0000000..340d252 --- /dev/null +++ b/dap-lib/src/auth/auth.rs @@ -0,0 +1,228 @@ +use super::token::{Algorithm, TokenInner, TokenMeta}; +use crate::{ + constants::{ENDPOINT_JWKS_PATH, ENDPOINT_LOGIN_PATH}, + error::*, + globals::AuthenticationConfig, + http::HttpClient, + log::*, +}; +use jwt_simple::prelude::{JWTClaims, NoCustomClaims}; +use serde::{Deserialize, Serialize}; +use std::{str::FromStr, sync::Arc}; +use tokio::sync::RwLock; + +/// Authentication request +#[derive(Serialize)] +pub struct AuthenticationRequest { + auth: AuthenticationReqInner, + client_id: String, +} +#[derive(Serialize)] +/// Auth req inner +pub struct AuthenticationReqInner { + username: String, + password: String, +} + +#[derive(Deserialize, Debug)] +/// Auth response +pub struct AuthenticationResponse { + pub token: TokenInner, + pub metadata: TokenMeta, + pub message: String, +} + +#[derive(Deserialize, Debug)] +pub struct Jwks { + pub keys: Vec, +} + +/// Authenticator client +pub struct Authenticator { + config: AuthenticationConfig, + http_client: Arc>, + id_token: Arc>>, + refresh_token: Arc>>, + validation_key: Arc>>, +} + +impl Authenticator { + /// Build authenticator + pub async fn new(auth_config: &AuthenticationConfig, http_client: Arc>) -> Result { + Ok(Self { + config: auth_config.clone(), + http_client, + id_token: Arc::new(RwLock::new(None)), + refresh_token: Arc::new(RwLock::new(None)), + validation_key: Arc::new(RwLock::new(None)), + }) + } + + /// Update jwks key + async fn update_validation_key(&self) -> Result<()> { + let id_token_lock = self.id_token.read().await; + let Some(id_token) = id_token_lock.as_ref() else { + return Err(DapError::AuthenticationError("No id token".to_string())); + }; + let meta = id_token.decode_id_token().await?; + drop(id_token_lock); + + let mut jwks_endpoint = self.config.token_api.clone(); + jwks_endpoint + .path_segments_mut() + .map_err(|_| DapError::Other(anyhow!("Failed to parse token api url".to_string())))? + .push(ENDPOINT_JWKS_PATH); + + let client_lock = self.http_client.read().await; + let res = client_lock + .get(jwks_endpoint) + .await + .send() + .await + .map_err(|e| DapError::AuthenticationError(e.to_string()))?; + drop(client_lock); + + if !res.status().is_success() { + error!("Jwks retrieval error!: {:?}", res.status()); + return Err(DapError::AuthenticationError(format!( + "Jwks retrieval error!: {:?}", + res.status() + ))); + } + + let jwks = res + .json::() + .await + .map_err(|_e| DapError::AuthenticationError("Failed to parse jwks response".to_string()))?; + + let key_id = meta + .key_id() + .ok_or_else(|| DapError::AuthenticationError("No key id".to_string()))?; + + let matched_key = jwks.keys.iter().find(|x| { + let kid = x["kid"].as_str().unwrap_or(""); + kid == key_id + }); + if matched_key.is_none() { + return Err(DapError::AuthenticationError(format!( + "No JWK matched to Id token is given at jwks endpoint! key_id: {}", + key_id + ))); + } + + let mut matched = matched_key.unwrap().clone(); + let Some(matched_jwk) = matched.as_object_mut() else { + return Err(DapError::AuthenticationError( + "Invalid jwk retrieved from jwks endpoint".to_string(), + )); + }; + matched_jwk.remove_entry("kid"); + let Ok(jwk_string) = serde_json::to_string(matched_jwk) else { + return Err(DapError::AuthenticationError("Failed to serialize jwk".to_string())); + }; + debug!("Matched JWK given at jwks endpoint is {}", &jwk_string); + + let pem = match Algorithm::from_str(meta.algorithm())? { + Algorithm::ES256 => { + let pk = + p256::PublicKey::from_jwk_str(&jwk_string).map_err(|e| DapError::AuthenticationError(e.to_string()))?; + pk.to_string() + } + }; + + let mut validation_key_lock = self.validation_key.write().await; + validation_key_lock.replace(pem.clone()); + drop(validation_key_lock); + + info!("validation key updated"); + + Ok(()) + } + + /// Verify id token + async fn verify_id_token(&self) -> Result> { + let pk_str_lock = self.validation_key.read().await; + let Some(pk_str) = pk_str_lock.as_ref() else { + return Err(DapError::AuthenticationError("No validation key".to_string())); + }; + let pk_str = pk_str.clone(); + drop(pk_str_lock); + + let token_lock = self.id_token.read().await; + let Some(token_inner) = token_lock.as_ref() else { + return Err(DapError::AuthenticationError("No id token".to_string())); + }; + let token = token_inner.clone(); + drop(token_lock); + + token.verify_id_token(&pk_str, &self.config).await + } + + /// Login to the authentication server + pub async fn login(&self) -> Result<()> { + let mut login_endpoint = self.config.token_api.clone(); + login_endpoint + .path_segments_mut() + .map_err(|_| DapError::Other(anyhow!("Failed to parse token api url".to_string())))? + .push(ENDPOINT_LOGIN_PATH); + + let json_request = AuthenticationRequest { + auth: AuthenticationReqInner { + username: self.config.username.clone(), + password: self.config.password.clone(), + }, + client_id: self.config.client_id.clone(), + }; + + let client_lock = self.http_client.read().await; + let res = client_lock + .post(login_endpoint) + .await + .json(&json_request) + .send() + .await + .map_err(|e| DapError::AuthenticationError(e.to_string()))?; + drop(client_lock); + + if !res.status().is_success() { + error!("Login error!: {:?}", res.status()); + return Err(DapError::AuthenticationError(format!( + "Login error!: {:?}", + res.status() + ))); + } + + // parse token + let token = res + .json::() + .await + .map_err(|_e| DapError::AuthenticationError("Failed to parse token response".to_string()))?; + + if let Some(refresh) = &token.token.refresh { + let mut refresh_token_lock = self.refresh_token.write().await; + refresh_token_lock.replace(refresh.clone()); + drop(refresh_token_lock); + } + + let mut id_token_lock = self.id_token.write().await; + id_token_lock.replace(token.token); + drop(id_token_lock); + + info!("Token retrieved"); + + // update validation key + self.update_validation_key().await?; + + // verify id token with validation key + let Ok(_clm) = self.verify_id_token().await else { + return Err(DapError::AuthenticationError( + "Invalid Id token! Carefully check if target DNS or Token API is compromised!".to_string(), + )); + }; + + info!("Login success!"); + Ok(()) + } + + // TODO: refresh by checking the expiration time +} diff --git a/dap-lib/src/auth/mod.rs b/dap-lib/src/auth/mod.rs new file mode 100644 index 0000000..4a29225 --- /dev/null +++ b/dap-lib/src/auth/mod.rs @@ -0,0 +1,4 @@ +mod auth; +mod token; + +pub use auth::Authenticator; diff --git a/dap-lib/src/auth/token.rs b/dap-lib/src/auth/token.rs new file mode 100644 index 0000000..bc2e56e --- /dev/null +++ b/dap-lib/src/auth/token.rs @@ -0,0 +1,96 @@ +use crate::{error::*, AuthenticationConfig}; +use jwt_simple::{ + prelude::*, + token::{Token, TokenMetadata}, +}; +use p256::elliptic_curve::sec1::ToEncodedPoint; +use serde::Deserialize; +use std::str::FromStr; + +#[derive(Debug)] +pub(crate) enum Algorithm { + ES256, +} +impl FromStr for Algorithm { + type Err = DapError; + fn from_str(s: &str) -> Result { + match s { + "ES256" => Ok(Algorithm::ES256), + _ => Err(DapError::Other(anyhow!("Invalid Algorithm Name"))), + } + } +} + +#[derive(Deserialize, Debug, Clone)] +pub struct TokenInner { + /// id_token jwt itself is given here as string + pub id: String, + /// refresh token if required + pub refresh: Option, + /// issued at in unix time + pub issued_at: String, + /// expires in unix time + pub expires: String, + /// allowed apps, i.e, client_ids + pub allowed_apps: Vec, + /// issuer specified by url like 'https://example.com/' for IdToken + pub issuer: String, + /// subscriber id generated by the token server + pub subscriber_id: String, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct TokenMeta { + pub username: String, + pub is_admin: bool, +} + +impl TokenInner { + /// Decode id token and retrieve metadata + pub async fn decode_id_token(&self) -> Result { + Token::decode_metadata(&self.id).map_err(|e| DapError::TokenError(e.to_string())) + } + + /// Verify id token with key string + pub async fn verify_id_token( + &self, + validation_key: &str, + config: &AuthenticationConfig, + ) -> Result> { + let meta = self.decode_id_token().await?; + + let options = VerificationOptions { + accept_future: true, // accept future + allowed_audiences: Some(HashSet::from_strings(&[&config.client_id])), + allowed_issuers: Some(HashSet::from_strings(&[&config.token_api])), + ..Default::default() + }; + + let clm: JWTClaims = match Algorithm::from_str(meta.algorithm())? { + Algorithm::ES256 => { + let public_key = validation_key + .parse::() + .map_err(|e| DapError::Other(anyhow!(e)))?; + let sec1key = public_key.to_encoded_point(false); + let key = ES256PublicKey::from_bytes(sec1key.as_bytes())?; + key.verify_token::(&self.id, Some(options))? + } + }; + + Ok(clm) + } + // pub async fn id_token_expires_in_secs(&self) -> Result { + // // This returns unix time in secs + // let clm = self.verify_id_token().await?; + // let expires_at: i64 = clm.expires_at.unwrap().as_secs() as i64; + // let dt: DateTime = Local::now(); + // let timestamp = dt.timestamp(); + // let expires_in_secs = expires_at - timestamp; + // if expires_in_secs < CREDENTIAL_REFRESH_MARGIN { + // // try to refresh immediately + // return Ok(0); + // } + + // Ok(expires_in_secs) + // } +} diff --git a/dap-lib/src/bootstrap.rs b/dap-lib/src/bootstrap.rs new file mode 100644 index 0000000..268eba0 --- /dev/null +++ b/dap-lib/src/bootstrap.rs @@ -0,0 +1,97 @@ +use crate::{error::*, globals::BootstrapDns, log::*, ResolveIpResponse, ResolveIps}; +use async_trait::async_trait; +use reqwest::Url; +use std::net::SocketAddr; +use trust_dns_resolver::{ + config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, + name_server::{GenericConnector, TokioRuntimeProvider}, + AsyncResolver, TokioAsyncResolver, +}; + +/// stub resolver using bootstrap DNS resolver +pub struct BootstrapDnsResolver { + pub inner: AsyncResolver>, +} + +impl BootstrapDnsResolver { + /// Build stub resolver using bootstrap dns resolver + pub async fn try_new(bootstrap_dns: &BootstrapDns, runtime_handle: tokio::runtime::Handle) -> Result { + let ips = &bootstrap_dns.ips; + let port = &bootstrap_dns.port; + let name_servers = NameServerConfigGroup::from_ips_clear(ips, *port, true); + let resolver_config = ResolverConfig::from_parts(None, vec![], name_servers); + + let resolver = runtime_handle + .spawn(async { TokioAsyncResolver::tokio(resolver_config, ResolverOpts::default()) }) + .await + .map_err(|e| DapError::Other(anyhow!(e)))?; + + Ok(Self { inner: resolver }) + } +} + +#[async_trait] +impl ResolveIps for BootstrapDnsResolver { + /// Lookup the IP addresses associated with a name using the bootstrap resolver + async fn resolve_ips(&self, target_url: &Url) -> Result { + // The final dot forces this to be an FQDN, otherwise the search rules as specified + // in `ResolverOpts` will take effect. FQDN's are generally cheaper queries. + let host_str = target_url + .host_str() + .ok_or_else(|| DapError::Other(anyhow!("Unable to parse target host name")))?; + let port = target_url + .port() + .unwrap_or_else(|| if target_url.scheme() == "https" { 443 } else { 80 }); + let response = self + .inner + .lookup_ip(format!("{}.", host_str)) + .await + .map_err(DapError::BootstrapResolverError)?; + + // There can be many addresses associated with the name, + // this can return IPv4 and/or IPv6 addresses + let target_addrs = response + .iter() + .filter_map(|addr| format!("{}:{}", addr, port).parse::().ok()) + .collect::>(); + + if target_addrs.is_empty() { + return Err(DapError::Other(anyhow!( + "Invalid target url: {target_url}, cannot resolve ip address" + ))); + } + debug!( + "Updated target url {} ip addresses by using bootstrap dns: {:?}", + host_str, target_addrs + ); + + Ok(ResolveIpResponse { + hostname: host_str.to_string(), + addresses: target_addrs, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::IpAddr; + + #[tokio::test] + async fn test_bootstrap_dns_resolver() { + let bootstrap_dns = BootstrapDns { + ips: vec![IpAddr::from([8, 8, 8, 8])], + port: 53, + rebootstrap_period_sec: tokio::time::Duration::from_secs(1), + }; + let resolver = BootstrapDnsResolver::try_new(&bootstrap_dns, tokio::runtime::Handle::current()) + .await + .unwrap(); + let target_url = Url::parse("https://dns.google").unwrap(); + let response = resolver.resolve_ips(&target_url).await.unwrap(); + + assert_eq!(response.hostname.as_str(), "dns.google"); + assert!(response.addresses.contains(&SocketAddr::from(([8, 8, 8, 8], 443)))); + assert!(response.addresses.contains(&SocketAddr::from(([8, 8, 4, 4], 443)))); + } +} diff --git a/dap-lib/src/constants.rs b/dap-lib/src/constants.rs index d03b18a..b98fc7a 100644 --- a/dap-lib/src/constants.rs +++ b/dap-lib/src/constants.rs @@ -15,7 +15,8 @@ pub const MIN_TTL: u32 = 10; // TTL for overridden records (plugin) // Can override by specifying values in config.toml pub const LISTEN_ADDRESSES: &[&str] = &["127.0.0.1:50053", "[::1]:50053"]; -pub const BOOTSTRAP_DNS: &[&str] = &["1.1.1.1:53"]; +pub const BOOTSTRAP_DNS_IPS: &[&str] = &["1.1.1.1"]; +pub const BOOTSTRAP_DNS_PORT: u16 = 53; pub const REBOOTSTRAP_PERIOD_MIN: u64 = 60; pub const DOH_TARGET_URL: &[&str] = &["https://dns.google/dns-query"]; @@ -25,10 +26,10 @@ pub const MAX_CACHE_SIZE: usize = 16384; // Constant Values for Proxy // /////////////////////////////// // Cannot override below by config.toml -pub const ODOH_CONFIG_PATH: &str = "/.well-known/odohconfigs"; // client -pub const ENDPOINT_LOGIN_PATH: &str = "/tokens"; // client::credential -pub const ENDPOINT_REFRESH_PATH: &str = "/refresh"; // client::credential -pub const ENDPOINT_JWKS_PATH: &str = "/jwks"; // client::credential +pub const ODOH_CONFIG_PATH: &str = ".well-known/odohconfigs"; // client +pub const ENDPOINT_LOGIN_PATH: &str = "tokens"; // client::credential +pub const ENDPOINT_REFRESH_PATH: &str = "refresh"; // client::credential +pub const ENDPOINT_JWKS_PATH: &str = "jwks"; // client::credential // pub const CREDENTIAL_REFRESH_BEFORE_EXPIRATION_IN_SECS: i64 = 600; // refresh 10 minutes before expiration // proxy // pub const CREDENTIAL_REFRESH_MARGIN: i64 = 10; // at least 10 secs must be left to refresh // client::credential diff --git a/dap-lib/src/error.rs b/dap-lib/src/error.rs index 68a6fd5..f2d3aac 100644 --- a/dap-lib/src/error.rs +++ b/dap-lib/src/error.rs @@ -1,5 +1,26 @@ -pub use anyhow::{anyhow, bail, ensure, Context, Result}; -// use std::io; -// use thiserror::Error; +pub use anyhow::{anyhow, bail, ensure, Context}; +use thiserror::Error; -// pub type Result = std::result::Result; +pub type Result = std::result::Result; + +/// Describes things that can go wrong in the Rpxy +#[derive(Debug, Error)] +pub enum DapError { + #[error("Bootstrap resolver error: {0}")] + BootstrapResolverError(#[from] trust_dns_resolver::error::ResolveError), + + #[error("Http client build error: {0}")] + HttpClientError(String), + + #[error("Url error: {0}")] + UrlError(#[from] url::ParseError), + + #[error("Authentication error: {0}")] + AuthenticationError(String), + + #[error("Token error: {0}")] + TokenError(String), + + #[error(transparent)] + Other(#[from] anyhow::Error), +} diff --git a/dap-lib/src/globals.rs b/dap-lib/src/globals.rs index a30aff0..3e641ec 100644 --- a/dap-lib/src/globals.rs +++ b/dap-lib/src/globals.rs @@ -1,16 +1,22 @@ -use crate::{client::DoHMethod, constants::*}; +use crate::{client::DoHMethod, constants::*, http::HttpClient}; // use futures::future; // use rand::Rng; -use std::net::SocketAddr; -use tokio::time::Duration; +use std::{ + net::{IpAddr, SocketAddr}, + sync::{Arc, RwLock}, +}; +use tokio::{sync::Notify, time::Duration}; +use url::Url; #[derive(Debug, Clone)] pub struct Globals { // pub cache: Arc, // pub counter: ConnCounter, - // pub doh_clients: Arc>>>, + pub http_client: Arc>, + pub proxy_config: ProxyConfig, pub runtime_handle: tokio::runtime::Handle, + pub term_notify: Option>, } #[derive(PartialEq, Eq, Debug, Clone)] @@ -20,8 +26,7 @@ pub struct ProxyConfig { pub max_cache_size: usize, /// bootstrap DNS - pub bootstrap_dns: Vec, - pub rebootstrap_period_sec: Duration, + pub bootstrap_dns: BootstrapDns, // udp proxy setting pub udp_buffer_size: usize, @@ -44,25 +49,33 @@ pub struct ProxyConfig { // pub credential: Arc>>, } +#[derive(PartialEq, Eq, Debug, Clone)] +/// Bootstrap DNS Addresses +pub struct BootstrapDns { + pub ips: Vec, + pub port: u16, + pub rebootstrap_period_sec: Duration, +} + #[derive(PartialEq, Eq, Debug, Clone)] /// doh, odoh, modoh target settings pub struct TargetConfig { pub doh_method: DoHMethod, - pub doh_target_urls: Vec, + pub doh_target_urls: Vec, pub target_randomization: bool, } #[derive(PartialEq, Eq, Debug, Clone)] /// odoh and modoh nexthop pub struct NextHopRelayConfig { - pub odoh_relay_urls: Vec, + pub odoh_relay_urls: Vec, pub odoh_relay_randomization: bool, } #[derive(PartialEq, Eq, Debug, Clone)] /// modoh pub struct SubseqRelayConfig { - pub mid_relay_urls: Vec, + pub mid_relay_urls: Vec, pub max_mid_relays: usize, } @@ -71,7 +84,7 @@ pub struct AuthenticationConfig { pub username: String, pub password: String, pub client_id: String, - pub token_api: String, + pub token_api: Url, } impl Default for TargetConfig { @@ -91,8 +104,11 @@ impl Default for ProxyConfig { max_connections: MAX_CONNECTIONS, max_cache_size: MAX_CACHE_SIZE, - bootstrap_dns: BOOTSTRAP_DNS.iter().map(|v| v.parse().unwrap()).collect(), - rebootstrap_period_sec: Duration::from_secs(REBOOTSTRAP_PERIOD_MIN * 60), + bootstrap_dns: BootstrapDns { + ips: BOOTSTRAP_DNS_IPS.iter().map(|v| v.parse().unwrap()).collect(), + port: BOOTSTRAP_DNS_PORT, + rebootstrap_period_sec: Duration::from_secs(REBOOTSTRAP_PERIOD_MIN * 60), + }, udp_buffer_size: UDP_BUFFER_SIZE, udp_channel_capacity: UDP_CHANNEL_CAPACITY, diff --git a/dap-lib/src/http.rs b/dap-lib/src/http.rs new file mode 100644 index 0000000..c59fda5 --- /dev/null +++ b/dap-lib/src/http.rs @@ -0,0 +1,64 @@ +use crate::{error::*, ResolveIps}; +use futures::future::join_all; +use reqwest::{header::HeaderMap, Client, IntoUrl, RequestBuilder, Url}; +use tokio::time::Duration; + +#[derive(Debug)] +/// HttpClient that is a wrapper of reqwest::Client +pub struct HttpClient { + /// client: reqwest::Client, + client: Client, + + /// domain: endpoint candidates that the client will connect to, where these ip addresses are resolved when instantiated by a given resolver implementing ResolveIps. + /// This would be targets for DoH, nexthop relay for ODoH (path including target, not mid-relays for dynamic randomization) + endpoints: Vec, + + /// timeout for http request + timeout_sec: Duration, +} + +impl HttpClient { + /// Build HttpClient + pub async fn new( + endpoints: &[Url], + timeout_sec: Duration, + default_headers: Option<&HeaderMap>, + resolver_ips: impl ResolveIps, + ) -> Result { + let resolve_ips_fut = endpoints.iter().map(|endpoint| resolver_ips.resolve_ips(endpoint)); + let resolve_ips = join_all(resolve_ips_fut).await; + if resolve_ips.iter().any(|resolve_ip| resolve_ip.is_err()) { + return Err(DapError::HttpClientError("Failed to resolve ip addresses".to_string())); + } + let resolve_ips_iter = resolve_ips.into_iter().map(|resolve_ip| resolve_ip.unwrap()); + + let mut client = Client::builder() + .user_agent(format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"))) + .timeout(timeout_sec) + .trust_dns(true); + + // Override pre-resolved ip addresses + client = resolve_ips_iter.fold(client, |client, resolve_ip| { + client.resolve_to_addrs(&resolve_ip.hostname, &resolve_ip.addresses) + }); + + // Set default headers + if let Some(headers) = default_headers { + client = client.default_headers(headers.clone()); + } + + Ok(Self { + client: client.build().map_err(|e| DapError::HttpClientError(e.to_string()))?, + timeout_sec, + endpoints: endpoints.to_vec(), + }) + } + + pub async fn post(&self, url: impl IntoUrl) -> RequestBuilder { + self.client.post(url) + } + + pub async fn get(&self, url: impl IntoUrl) -> RequestBuilder { + self.client.get(url) + } +} diff --git a/dap-lib/src/lib.rs b/dap-lib/src/lib.rs index b15b3c9..a86da13 100644 --- a/dap-lib/src/lib.rs +++ b/dap-lib/src/lib.rs @@ -1,27 +1,73 @@ +mod auth; +mod bootstrap; mod client; mod constants; mod error; mod globals; +mod http; mod log; +mod proxy; -use crate::{error::*, globals::Globals, log::info}; -use std::sync::Arc; +use crate::{error::*, globals::Globals, http::HttpClient, log::info}; +use async_trait::async_trait; +use std::{net::SocketAddr, sync::Arc}; +use tokio::sync::RwLock; +use url::Url; pub use client::DoHMethod; pub use globals::{AuthenticationConfig, NextHopRelayConfig, ProxyConfig, SubseqRelayConfig, TargetConfig}; +#[async_trait] +pub trait ResolveIps { + async fn resolve_ips(&self, target_url: &Url) -> Result; +} +pub struct ResolveIpResponse { + pub hostname: String, + pub addresses: Vec, +} + pub async fn entrypoint( proxy_config: &ProxyConfig, runtime_handle: &tokio::runtime::Handle, term_notify: Option>, ) -> Result<()> { - info!("Hello, world!"); + info!("Start DoH w/ Auth Proxy"); + + // build bootstrap DNS resolver + let bootstrap_dns_resolver = + bootstrap::BootstrapDnsResolver::try_new(&proxy_config.bootstrap_dns, runtime_handle.clone()).await?; + + // build http client that is used commonly by DoH client and authentication client + let mut endpoint_candidates = vec![]; + if let Some(nexthop_relay_config) = &proxy_config.nexthop_relay_config { + endpoint_candidates.extend(nexthop_relay_config.odoh_relay_urls.clone()); + } else { + endpoint_candidates.extend(proxy_config.target_config.doh_target_urls.clone()); + } + if let Some(auth) = &proxy_config.authentication_config { + endpoint_candidates.push(auth.token_api.clone()); + } + let http_client = HttpClient::new( + &endpoint_candidates, + proxy_config.timeout_sec, + None, + bootstrap_dns_resolver, + ) + .await?; + + let http_client = Arc::new(RwLock::new(http_client)); + + if let Some(auth_config) = &proxy_config.authentication_config { + let authenticator = auth::Authenticator::new(auth_config, http_client).await?; + authenticator.login().await?; + } // build global - let globals = Arc::new(Globals { - proxy_config: proxy_config.clone(), - runtime_handle: runtime_handle.clone(), - }); + // let globals = Arc::new(Globals { + // proxy_config: proxy_config.clone(), + // runtime_handle: runtime_handle.clone(), + // term_notify: term_notify.clone(), + // }); Ok(()) } diff --git a/dap-lib/src/proxy/mod.rs b/dap-lib/src/proxy/mod.rs new file mode 100644 index 0000000..e69de29 diff --git a/doh-auth-proxy.toml b/doh-auth-proxy.toml index d46946b..ef7420b 100644 --- a/doh-auth-proxy.toml +++ b/doh-auth-proxy.toml @@ -12,7 +12,7 @@ listen_addresses = ['127.0.0.1:50053', '[::1]:50053'] ## DNS (Do53) resolver addresses for bootstrap -bootstrap_dns = ["8.8.8.8:53", "1.1.1.1:53"] +bootstrap_dns = ["8.8.8.8", "1.1.1.1"] ## Minutes to re-fetch the IP addr of the target url host via the bootstrap DNS reboot_period = 3