diff --git a/dap-bin/src/config/toml.rs b/dap-bin/src/config/toml.rs index 375b977..6a706cb 100644 --- a/dap-bin/src/config/toml.rs +++ b/dap-bin/src/config/toml.rs @@ -72,12 +72,15 @@ impl TryInto for &ConfigToml { proxy_config.bootstrap_dns.ips = val.iter().map(|x| x.parse().unwrap()).collect() }; info!("Bootstrap DNS: {:?}", proxy_config.bootstrap_dns.ips); + + ///////////////////////////// + // reboot period if let Some(val) = self.reboot_period { - proxy_config.bootstrap_dns.rebootstrap_period_sec = Duration::from_secs((val as u64) * 60); + proxy_config.endpoint_resolution_period_sec = Duration::from_secs((val as u64) * 60); } info!( - "Target DoH Address is re-fetched every {:?} min via Bootsrap DNS", - proxy_config.bootstrap_dns.rebootstrap_period_sec.as_secs() / 60 + "Target DoH and auth server addresses are re-fetched every {:?} min via DoH itself or Bootsrap DNS", + proxy_config.endpoint_resolution_period_sec.as_secs() / 60 ); ///////////////////////////// diff --git a/dap-lib/Cargo.toml b/dap-lib/Cargo.toml index 62ee936..87ed20e 100644 --- a/dap-lib/Cargo.toml +++ b/dap-lib/Cargo.toml @@ -38,12 +38,18 @@ tokio = { version = "1.33.0", features = [ "sync", "macros", ] } -futures = { version = "0.3.28", default-features = false } +futures = { version = "0.3.28", default-features = false, features = [ + "std", + "async-await", +] } anyhow = "1.0.75" tracing = "0.1.40" thiserror = "1.0.50" async-trait = "0.1.74" +# network +socket2 = "0.5.5" + # http client reqwest = { version = "0.11.22", default-features = false, features = [ "json", @@ -52,14 +58,11 @@ reqwest = { version = "0.11.22", default-features = false, features = [ ] } url = "2.4.1" - # for bootstrap dns resolver trust-dns-resolver = { version = "0.23.2", default-features = false, features = [ "tokio-runtime", ] } - # authentication auth-client = { git = "https://github.com/junkurihara/rust-token-server", package = "rust-token-server-client", branch = "develop" } serde = { version = "1.0.189", features = ["derive"] } -socket2 = "0.5.5" diff --git a/dap-lib/src/bootstrap.rs b/dap-lib/src/bootstrap.rs index f3aa975..3aed770 100644 --- a/dap-lib/src/bootstrap.rs +++ b/dap-lib/src/bootstrap.rs @@ -6,13 +6,14 @@ use crate::{ }; use async_trait::async_trait; use reqwest::Url; -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use trust_dns_resolver::{ config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, name_server::{GenericConnector, TokioRuntimeProvider}, AsyncResolver, TokioAsyncResolver, }; +#[derive(Clone)] /// stub resolver using bootstrap DNS resolver pub struct BootstrapDnsResolver { /// wrapper of trust-dns-resolver @@ -37,7 +38,7 @@ impl BootstrapDnsResolver { } #[async_trait] -impl ResolveIps for BootstrapDnsResolver { +impl ResolveIps for Arc { /// Lookup the IP addresses associated with a name using the bootstrap resolver async fn resolve_ips(&self, target_url: &Url) -> Result { // The final dot forces this to be an FQDN, otherwise the search rules as specified @@ -88,11 +89,11 @@ mod tests { let bootstrap_dns = BootstrapDns { ips: vec![IpAddr::from([8, 8, 8, 8])], port: 53, - rebootstrap_period_sec: tokio::time::Duration::from_secs(1), }; let resolver = BootstrapDnsResolver::try_new(&bootstrap_dns, tokio::runtime::Handle::current()) .await .unwrap(); + let resolver = Arc::new(resolver); let target_url = Url::parse("https://dns.google").unwrap(); let response = resolver.resolve_ips(&target_url).await.unwrap(); diff --git a/dap-lib/src/constants.rs b/dap-lib/src/constants.rs index 89fe2c3..86a61f7 100644 --- a/dap-lib/src/constants.rs +++ b/dap-lib/src/constants.rs @@ -16,11 +16,19 @@ pub const MIN_TTL: u32 = 10; // TTL for overridden records (plugin) // Default Values for Config // //////////////////////////////// // Can override by specifying values in config.toml + +/// Default listen address pub const LISTEN_ADDRESSES: &[&str] = &["127.0.0.1:50053", "[::1]:50053"]; +/// Bootstrap DNS address pub const BOOTSTRAP_DNS_IPS: &[&str] = &["1.1.1.1"]; +/// Bootstrap DNS port pub const BOOTSTRAP_DNS_PORT: u16 = 53; -pub const REBOOTSTRAP_PERIOD_MIN: u64 = 60; + +/// Endpoint resolution period in minutes +pub const ENDPOINT_RESOLUTION_PERIOD_MIN: u64 = 60; + +/// Default DoH target server pub const DOH_TARGET_URL: &[&str] = &["https://dns.google/dns-query"]; pub const MAX_CACHE_SIZE: usize = 16384; diff --git a/dap-lib/src/doh_client/doh_client_main.rs b/dap-lib/src/doh_client/doh_client_main.rs index d092e87..392ef07 100644 --- a/dap-lib/src/doh_client/doh_client_main.rs +++ b/dap-lib/src/doh_client/doh_client_main.rs @@ -1,6 +1,13 @@ -use crate::{error::*, globals::Globals, http_client::HttpClientInner}; +use crate::{ + error::*, + globals::Globals, + http_client::HttpClientInner, + trait_resolve_ips::{ResolveIpResponse, ResolveIps}, +}; +use async_trait::async_trait; use std::sync::Arc; use tokio::sync::RwLock; +use url::Url; #[derive(Debug)] /// DoH, ODoH, MODoH client @@ -19,3 +26,12 @@ impl DoHClient { Ok(vec![]) } } + +// TODO: implement ResolveIps for DoHClient +#[async_trait] +impl ResolveIps for Arc { + /// Resolve ip addresses of the given domain name + async fn resolve_ips(&self, domain: &Url) -> Result { + Err(DapError::Other(anyhow!("Not implemented"))) + } +} diff --git a/dap-lib/src/error.rs b/dap-lib/src/error.rs index 2b23f51..81865ca 100644 --- a/dap-lib/src/error.rs +++ b/dap-lib/src/error.rs @@ -21,8 +21,10 @@ pub enum DapError { #[error("HttpClient error")] HttpClientError(#[from] reqwest::Error), - #[error("HttpClient build error")] - HttpClientBuildError, + #[error("Failed to resolve Ips for HTTP client")] + FailedToResolveIpsForHttpClient, + #[error("Too many fails to resolve Ips for HTTP client in periodic task")] + TooManyFailsToResolveIps, #[error("Io Error: {0}")] Io(#[from] std::io::Error), #[error("Null TCP stream")] diff --git a/dap-lib/src/globals.rs b/dap-lib/src/globals.rs index 24ee878..162ecd8 100644 --- a/dap-lib/src/globals.rs +++ b/dap-lib/src/globals.rs @@ -1,17 +1,10 @@ -use crate::{ - constants::*, - doh_client::{DoHClient, DoHMethod}, - http_client::HttpClient, -}; +use crate::{constants::*, doh_client::DoHMethod, http_client::HttpClient}; use auth_client::AuthenticationConfig; use std::{ net::{IpAddr, SocketAddr}, sync::Arc, }; -use tokio::{ - sync::{Notify, RwLock}, - time::Duration, -}; +use tokio::{sync::Notify, time::Duration}; use url::Url; #[derive(Debug)] @@ -39,6 +32,7 @@ pub struct ProxyConfig { /// bootstrap DNS pub bootstrap_dns: BootstrapDns, + pub endpoint_resolution_period_sec: Duration, // udp and tcp proxy setting pub udp_buffer_size: usize, @@ -70,7 +64,6 @@ pub struct ProxyConfig { pub struct BootstrapDns { pub ips: Vec, pub port: u16, - pub rebootstrap_period_sec: Duration, } #[derive(PartialEq, Eq, Debug, Clone)] @@ -115,8 +108,8 @@ impl Default for ProxyConfig { bootstrap_dns: BootstrapDns { ips: BOOTSTRAP_DNS_IPS.iter().map(|v| v.parse().unwrap()).collect(), port: BOOTSTRAP_DNS_PORT, - rebootstrap_period_sec: Duration::from_secs(REBOOTSTRAP_PERIOD_MIN * 60), }, + endpoint_resolution_period_sec: Duration::from_secs(ENDPOINT_RESOLUTION_PERIOD_MIN * 60), udp_buffer_size: UDP_BUFFER_SIZE, udp_channel_capacity: UDP_CHANNEL_CAPACITY, diff --git a/dap-lib/src/http_client/http_client_main.rs b/dap-lib/src/http_client/http_client_main.rs index 9f5865a..6226529 100644 --- a/dap-lib/src/http_client/http_client_main.rs +++ b/dap-lib/src/http_client/http_client_main.rs @@ -1,11 +1,9 @@ -use std::sync::Arc; - use crate::{ error::*, - trait_resolve_ips::{ResolveIpResponse, ResolveIps}, + trait_resolve_ips::{resolve_ips, ResolveIpResponse, ResolveIps}, }; -use futures::future::join_all; use reqwest::{header::HeaderMap, Client, IntoUrl, RequestBuilder, Url}; +use std::sync::Arc; use tokio::{sync::RwLock, time::Duration}; #[derive(Debug)] @@ -18,8 +16,14 @@ pub struct HttpClient { /// This would be targets for DoH, nexthop relay for ODoH (path including target, not mid-relays for dynamic randomization) endpoints: Vec, + /// default headers + default_headers: Option, + /// timeout for http request timeout_sec: Duration, + + /// rebootstrap period for endpoint ip resolution + rebootstrap_period_sec: Duration, } impl HttpClient { @@ -29,14 +33,17 @@ impl HttpClient { timeout_sec: Duration, default_headers: Option<&HeaderMap>, resolver_ips: impl ResolveIps, + rebootstrap_period_sec: Duration, ) -> Result { let resolved_ips = resolve_ips(endpoints, resolver_ips).await?; Ok(Self { inner: Arc::new(RwLock::new( HttpClientInner::new(timeout_sec, default_headers, &resolved_ips).await?, )), + default_headers: default_headers.cloned(), timeout_sec, endpoints: endpoints.to_vec(), + rebootstrap_period_sec, }) } @@ -44,6 +51,26 @@ impl HttpClient { pub fn inner(&self) -> Arc> { self.inner.clone() } + + /// Get endpoints + pub fn endpoints(&self) -> &[Url] { + &self.endpoints + } + + /// Get default headers + pub fn default_headers(&self) -> Option<&HeaderMap> { + self.default_headers.as_ref() + } + + /// Get timeout + pub fn timeout_sec(&self) -> Duration { + self.timeout_sec + } + + /// Get rebootstrap period + pub fn rebootstrap_period_sec(&self) -> Duration { + self.rebootstrap_period_sec + } } #[derive(Debug)] @@ -79,23 +106,12 @@ impl HttpClientInner { } /// Post wrapper - pub async fn post(&self, url: impl IntoUrl) -> RequestBuilder { + pub fn post(&self, url: impl IntoUrl) -> RequestBuilder { self.client.post(url) } /// Get wrapper - pub async fn get(&self, url: impl IntoUrl) -> RequestBuilder { + pub fn get(&self, url: impl IntoUrl) -> RequestBuilder { self.client.get(url) } } - -/// Resolve ip addresses for given endpoints -async fn resolve_ips(endpoints: &[Url], resolver_ips: impl ResolveIps) -> Result> { - let resolve_ips_fut = endpoints.iter().map(|endpoint| resolver_ips.resolve_ips(endpoint)); - let resolve_ips = join_all(resolve_ips_fut).await; - if resolve_ips.iter().any(|resolve_ip| resolve_ip.is_err()) { - return Err(DapError::HttpClientBuildError); - } - let resolve_ips_vec = resolve_ips.into_iter().map(|resolve_ip| resolve_ip.unwrap()).collect(); - Ok(resolve_ips_vec) -} diff --git a/dap-lib/src/http_client/http_client_service.rs b/dap-lib/src/http_client/http_client_service.rs index 96492e4..1e7871f 100644 --- a/dap-lib/src/http_client/http_client_service.rs +++ b/dap-lib/src/http_client/http_client_service.rs @@ -1,44 +1,89 @@ -use super::HttpClient; -use crate::{error::*, log::*, trait_resolve_ips::ResolveIps}; +use tokio::time::sleep; + +use super::{HttpClient, HttpClientInner}; +use crate::{ + error::*, + log::*, + trait_resolve_ips::{resolve_ips, ResolveIpResponse, ResolveIps}, +}; use std::sync::Arc; impl HttpClient { /// Periodically resolves endpoints to ip addresses, and override their ip addresses in the inner client. - pub async fn start_ip_update_service( + pub async fn start_endpoint_ip_update_service( &self, - primary_resolver: impl ResolveIps, - fallback_resolver: impl ResolveIps, + primary_resolver: impl ResolveIps + Clone, + fallback_resolver: impl ResolveIps + Clone, term_notify: Option>, ) -> Result<()> { - info!("start periodic service updating endpoint ip addresses"); - - // match term_notify { - // Some(term) => { - // tokio::select! { - // _ = self.auth_service() => { - // warn!("Auth service got down"); - // } - // _ = term.notified() => { - // info!("Auth service receives term signal"); - // } - // } - // } - // None => { - // self.auth_service().await?; - // warn!("Auth service got down"); - // } - // } + info!("start periodic service for resolution of endpoint ip addresses"); + + match term_notify { + Some(term) => { + tokio::select! { + _ = self.resolve_endpoint_ip_service(primary_resolver, fallback_resolver) => { + warn!("Endpoint ip resolution service got down"); + } + _ = term.notified() => { + info!("Endpoint ip resolution service receives term signal"); + } + } + } + None => { + self + .resolve_endpoint_ip_service(primary_resolver, fallback_resolver) + .await?; + warn!("Endpoint ip resolution service got down"); + } + } Ok(()) } - // /// periodic refresh checker - // async fn auth_service(&self) -> Result<()> { - // loop { - // self - // .refresh_or_login() - // .await - // .with_context(|| "auth service failed to refresh or login")?; - // sleep(Duration::from_secs(TOKEN_REFRESH_WATCH_DELAY as u64)).await; - // } - // } + /// periodic refresh checker + async fn resolve_endpoint_ip_service( + &self, + primary_resolver: impl ResolveIps + Clone, + fallback_resolver: impl ResolveIps + Clone, + ) -> Result<()> { + let mut fail_cnt = 0; + loop { + sleep(self.rebootstrap_period_sec()).await; + let endpoints = self.endpoints(); + + let primary_res = resolve_ips(endpoints, primary_resolver.clone()).await; + if primary_res.is_ok() { + self.update_inner(&primary_res.unwrap()).await?; + fail_cnt = 0; + info!("Resolved endpoint ip addresses by DoH resolver"); + continue; + } + warn!( + "Failed to resolve endpoint ip addresses by doh resolver, trying fallback with bootstrap resolver: {}", + primary_res.err().unwrap() + ); + + let fallback_res = resolve_ips(endpoints, fallback_resolver.clone()).await; + if fallback_res.is_ok() { + self.update_inner(&fallback_res.unwrap()).await?; + fail_cnt = 0; + info!("Resolved endpoint ip addresses by bootstrap resolver"); + continue; + } + warn!("Failed to resolve endpoint ip addresses by both DoH and bootstrap resolvers"); + + fail_cnt += 1; + if fail_cnt > 3 { + return Err(DapError::TooManyFailsToResolveIps); + } + } + } + + /// Update http client inner + async fn update_inner(&self, resolved_ips: &[ResolveIpResponse]) -> Result<()> { + let inner = self.inner(); + let mut inner_lock = inner.write().await; + *inner_lock = HttpClientInner::new(self.timeout_sec(), self.default_headers(), resolved_ips).await?; + drop(inner_lock); + Ok(()) + } } diff --git a/dap-lib/src/lib.rs b/dap-lib/src/lib.rs index fdf94b2..d669f96 100644 --- a/dap-lib/src/lib.rs +++ b/dap-lib/src/lib.rs @@ -10,7 +10,10 @@ mod proxy; mod trait_resolve_ips; use crate::{doh_client::DoHClient, error::*, globals::Globals, http_client::HttpClient, log::*, proxy::Proxy}; -use futures::future::select_all; +use futures::{ + future::{select_all, FutureExt}, + select, +}; use std::sync::Arc; pub use auth_client::AuthenticationConfig; @@ -27,7 +30,7 @@ pub async fn entrypoint( // build bootstrap DNS resolver let bootstrap_dns_resolver = - bootstrap::BootstrapDnsResolver::try_new(&proxy_config.bootstrap_dns, runtime_handle.clone()).await?; + Arc::new(bootstrap::BootstrapDnsResolver::try_new(&proxy_config.bootstrap_dns, runtime_handle.clone()).await?); // build http client that is used commonly by DoH client and authentication client let mut endpoint_candidates = vec![]; @@ -43,9 +46,11 @@ pub async fn entrypoint( &endpoint_candidates, proxy_config.http_timeout_sec, None, - bootstrap_dns_resolver, + bootstrap_dns_resolver.clone(), + proxy_config.endpoint_resolution_period_sec, ) .await?; + let http_client = Arc::new(http_client); // spawn authentication service let term_notify_clone = term_notify.clone(); @@ -61,19 +66,23 @@ pub async fn entrypoint( auth_service = Some(auth_service_inner); } - // TODO: services - // - Authentication refresh/re-login service loop (Done) - // - HTTP client update service loop, changing DNS resolver to the self when it works - // - Health check service checking every path, flag unreachable patterns as unhealthy - // build doh_client let doh_client = Arc::new(DoHClient::new(http_client.inner())); - // TODO: doh_clientにResolveIps traitを実装、http client ip updateサービスをここでspawn + // spawn endpoint ip update service + let doh_client_clone = doh_client.clone(); + let term_notify_clone = term_notify.clone(); + let http_client_clone = http_client.clone(); + let ip_resolution_service = runtime_handle.spawn(async move { + http_client_clone + .start_endpoint_ip_update_service(doh_client_clone, bootstrap_dns_resolver, term_notify_clone) + .await + .with_context(|| "endpoint ip update service got down") + }); // build global let globals = Arc::new(Globals { - http_client: Arc::new(http_client), + http_client, proxy_config: proxy_config.clone(), runtime_handle: runtime_handle.clone(), term_notify, @@ -88,12 +97,32 @@ pub async fn entrypoint( // wait for all future if let Some(auth_service) = auth_service { - futures::future::select(proxy_service, auth_service).await; - warn!("Some proxy services and auth service are down or term notified"); + select! { + _ = proxy_service.fuse() => { + warn!("Proxy services are down, or term notified"); + }, + _ = auth_service.fuse() => { + warn!("Auth services are down, or term notified"); + } + _ = ip_resolution_service.fuse() => { + warn!("Ip resolution service are down, or term notified"); + } + } } else { - let _res = proxy_service.await; - warn!("Some proxy services are down or term notified"); + select! { + _ = proxy_service.fuse() => { + warn!("Proxy services are down, or term notified"); + }, + _ = ip_resolution_service.fuse() => { + warn!("Ip resolution service are down, or term notified"); + } + } } + // TODO: services + // - Authentication refresh/re-login service loop (Done) + // - HTTP client update service loop, changing DNS resolver to the self when it works (Done) + // - Health check service checking every path, flag unreachable patterns as unhealthy (as individual service using doh_client?) + Ok(()) } diff --git a/dap-lib/src/proxy/proxy_main.rs b/dap-lib/src/proxy/proxy_main.rs index 94d877a..f87dd9f 100644 --- a/dap-lib/src/proxy/proxy_main.rs +++ b/dap-lib/src/proxy/proxy_main.rs @@ -1,5 +1,6 @@ use super::counter::ConnCounter; use crate::{doh_client::DoHClient, error::*, globals::Globals, log::*}; +use futures::future::select; use std::{net::SocketAddr, sync::Arc}; /// Proxy object serving UDP and TCP queries @@ -66,7 +67,7 @@ impl Proxy { } }); - futures::future::select(udp_fut, tcp_fut).await; + select(udp_fut, tcp_fut).await; Ok(()) } diff --git a/dap-lib/src/trait_resolve_ips.rs b/dap-lib/src/trait_resolve_ips.rs index ee971ee..92f7b09 100644 --- a/dap-lib/src/trait_resolve_ips.rs +++ b/dap-lib/src/trait_resolve_ips.rs @@ -1,4 +1,4 @@ -use crate::error::Result; +use crate::error::*; use async_trait::async_trait; use std::net::SocketAddr; use url::Url; @@ -16,3 +16,14 @@ pub struct ResolveIpResponse { /// resolved ip addresses pub addresses: Vec, } + +/// Resolve ip addresses for given endpoints +pub async fn resolve_ips(endpoints: &[Url], resolver_ips: impl ResolveIps) -> Result> { + let resolve_ips_fut = endpoints.iter().map(|endpoint| resolver_ips.resolve_ips(endpoint)); + let resolve_ips = futures::future::join_all(resolve_ips_fut).await; + if resolve_ips.iter().any(|resolve_ip| resolve_ip.is_err()) { + return Err(DapError::FailedToResolveIpsForHttpClient); + } + let resolve_ips_vec = resolve_ips.into_iter().map(|resolve_ip| resolve_ip.unwrap()).collect(); + Ok(resolve_ips_vec) +}