From d2daa30fafcedc6f22056a2e1e0bab28c13018e5 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Wed, 6 Dec 2023 17:42:38 +0900 Subject: [PATCH] fix: change counter to increment only one for safety --- proxy-lib/src/proxy/counter.rs | 136 ++++++++++++++++++++---------- proxy-lib/src/proxy/proxy_main.rs | 4 +- proxy-lib/src/proxy/proxy_tcp.rs | 2 +- proxy-lib/src/proxy/proxy_udp.rs | 6 +- 4 files changed, 98 insertions(+), 50 deletions(-) diff --git a/proxy-lib/src/proxy/counter.rs b/proxy-lib/src/proxy/counter.rs index df69554..5b5f006 100644 --- a/proxy-lib/src/proxy/counter.rs +++ b/proxy-lib/src/proxy/counter.rs @@ -1,6 +1,5 @@ use crate::log::*; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; #[derive(Debug, Clone)] pub enum CounterType { @@ -16,31 +15,55 @@ impl CounterType { } } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Default)] +/// Connection counter inner that is an increment-only counter +pub struct CounterInner { + /// total number of incoming connections + cnt_in: AtomicUsize, + /// total number of served connections + cnt_out: AtomicUsize, +} + +impl CounterInner { + /// output difference between cnt_in and cnt_out as current in-flight connection count + pub fn get_current(&self) -> isize { + self.cnt_in.load(Ordering::Relaxed) as isize - self.cnt_out.load(Ordering::Relaxed) as isize + } + /// increment cnt_in and output current in-flight connection count + pub fn increment(&self) -> isize { + let total_in = self.cnt_in.fetch_add(1, Ordering::Relaxed) as isize; + total_in + 1 - self.cnt_out.load(Ordering::Relaxed) as isize + } + /// increment cnt_out and output current in-flight connection count + pub fn decrement(&self) -> isize { + let total_out = self.cnt_out.fetch_add(1, Ordering::Relaxed) as isize; + self.cnt_in.load(Ordering::Relaxed) as isize - total_out - 1 + } +} + +#[derive(Debug, Default)] /// Connection counter pub struct ConnCounter { - pub cnt_total: Arc, - pub cnt_udp: Arc, - pub cnt_tcp: Arc, + pub cnt_udp: CounterInner, + pub cnt_tcp: CounterInner, } impl ConnCounter { - pub fn get_current_total(&self) -> usize { - self.cnt_total.load(Ordering::Relaxed) + pub fn get_current_total(&self) -> isize { + self.cnt_tcp.get_current() + self.cnt_udp.get_current() } - pub fn get_current(&self, ctype: CounterType) -> usize { + pub fn get_current(&self, ctype: CounterType) -> isize { match ctype { - CounterType::Tcp => self.cnt_tcp.load(Ordering::Relaxed), - CounterType::Udp => self.cnt_udp.load(Ordering::Relaxed), + CounterType::Tcp => self.cnt_tcp.get_current(), + CounterType::Udp => self.cnt_udp.get_current(), } } - pub fn increment(&self, ctype: CounterType) -> usize { - self.cnt_total.fetch_add(1, Ordering::Relaxed); + pub fn increment(&self, ctype: CounterType) -> isize { let c = match ctype { - CounterType::Tcp => self.cnt_tcp.fetch_add(1, Ordering::Relaxed), - CounterType::Udp => self.cnt_udp.fetch_add(1, Ordering::Relaxed), + CounterType::Tcp => self.cnt_tcp.increment(), + CounterType::Udp => self.cnt_udp.increment(), }; debug!( @@ -52,36 +75,11 @@ impl ConnCounter { c } - pub fn decrement(&self, ctype: CounterType) { - let cnt; - match ctype { - CounterType::Tcp => { - let res = { - cnt = self.cnt_tcp.load(Ordering::Relaxed); - cnt > 0 - && self - .cnt_tcp - .compare_exchange(cnt, cnt - 1, Ordering::Relaxed, Ordering::Relaxed) - != Ok(cnt) - }; - if res {} - } - CounterType::Udp => { - let res = { - cnt = self.cnt_udp.load(Ordering::Relaxed); - cnt > 0 - && self - .cnt_udp - .compare_exchange(cnt, cnt - 1, Ordering::Relaxed, Ordering::Relaxed) - != Ok(cnt) - }; - if res {} - } + pub fn decrement(&self, ctype: CounterType) -> isize { + let c = match ctype { + CounterType::Tcp => self.cnt_tcp.decrement(), + CounterType::Udp => self.cnt_udp.decrement(), }; - self.cnt_total.store( - self.cnt_udp.load(Ordering::Relaxed) + self.cnt_tcp.load(Ordering::Relaxed), - Ordering::Relaxed, - ); debug!( "{} connection count--: {} (total = {})", @@ -89,5 +87,57 @@ impl ConnCounter { self.get_current(ctype), self.get_current_total() ); + c + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_counter_inner() { + let counter = CounterInner::default(); + assert_eq!(counter.get_current(), 0); + assert_eq!(counter.increment(), 1); + assert_eq!(counter.get_current(), 1); + assert_eq!(counter.increment(), 2); + assert_eq!(counter.get_current(), 2); + assert_eq!(counter.decrement(), 1); + assert_eq!(counter.get_current(), 1); + assert_eq!(counter.decrement(), 0); + assert_eq!(counter.get_current(), 0); + } + + #[test] + fn test_conn_counter() { + let counter = ConnCounter::default(); + assert_eq!(counter.get_current_total(), 0); + assert_eq!(counter.get_current(CounterType::Tcp), 0); + assert_eq!(counter.get_current(CounterType::Udp), 0); + assert_eq!(counter.increment(CounterType::Tcp), 1); + assert_eq!(counter.get_current_total(), 1); + assert_eq!(counter.get_current(CounterType::Tcp), 1); + assert_eq!(counter.get_current(CounterType::Udp), 0); + assert_eq!(counter.increment(CounterType::Tcp), 2); + assert_eq!(counter.get_current_total(), 2); + assert_eq!(counter.get_current(CounterType::Tcp), 2); + assert_eq!(counter.get_current(CounterType::Udp), 0); + assert_eq!(counter.increment(CounterType::Udp), 1); + assert_eq!(counter.get_current_total(), 3); + assert_eq!(counter.get_current(CounterType::Tcp), 2); + assert_eq!(counter.get_current(CounterType::Udp), 1); + assert_eq!(counter.decrement(CounterType::Tcp), 1); + assert_eq!(counter.get_current_total(), 2); + assert_eq!(counter.get_current(CounterType::Tcp), 1); + assert_eq!(counter.get_current(CounterType::Udp), 1); + assert_eq!(counter.decrement(CounterType::Tcp), 0); + assert_eq!(counter.get_current_total(), 1); + assert_eq!(counter.get_current(CounterType::Tcp), 0); + assert_eq!(counter.get_current(CounterType::Udp), 1); + assert_eq!(counter.decrement(CounterType::Udp), 0); + assert_eq!(counter.get_current_total(), 0); + assert_eq!(counter.get_current(CounterType::Tcp), 0); + assert_eq!(counter.get_current(CounterType::Udp), 0); } } diff --git a/proxy-lib/src/proxy/proxy_main.rs b/proxy-lib/src/proxy/proxy_main.rs index f87dd9f..b55d8ad 100644 --- a/proxy-lib/src/proxy/proxy_main.rs +++ b/proxy-lib/src/proxy/proxy_main.rs @@ -7,7 +7,7 @@ use std::{net::SocketAddr, sync::Arc}; #[derive(Clone)] pub struct Proxy { pub(super) globals: Arc, - pub(super) counter: ConnCounter, + pub(super) counter: Arc, pub(super) doh_client: Arc, pub(super) listening_on: SocketAddr, } @@ -17,7 +17,7 @@ impl Proxy { pub fn new(globals: Arc, listening_on: &SocketAddr, doh_client: &Arc) -> Self { Self { globals, - counter: ConnCounter::default(), + counter: Arc::new(ConnCounter::default()), doh_client: doh_client.clone(), listening_on: *listening_on, } diff --git a/proxy-lib/src/proxy/proxy_tcp.rs b/proxy-lib/src/proxy/proxy_tcp.rs index 0557cff..60845f9 100644 --- a/proxy-lib/src/proxy/proxy_tcp.rs +++ b/proxy-lib/src/proxy/proxy_tcp.rs @@ -40,7 +40,7 @@ impl Proxy { pub async fn serve_tcp_query(self, mut stream: TcpStream, src_addr: SocketAddr) -> Result<()> { debug!("handle tcp query from {:?}", src_addr); let counter = self.counter.clone(); - if counter.increment(CounterType::Tcp) >= self.globals.proxy_config.max_connections { + if counter.increment(CounterType::Tcp) >= self.globals.proxy_config.max_connections as isize { error!( "Too many connections: max = {} (udp+tcp)", self.globals.proxy_config.max_connections diff --git a/proxy-lib/src/proxy/proxy_udp.rs b/proxy-lib/src/proxy/proxy_udp.rs index 9a98e39..931e890 100644 --- a/proxy-lib/src/proxy/proxy_udp.rs +++ b/proxy-lib/src/proxy/proxy_udp.rs @@ -113,7 +113,7 @@ impl Proxy { ) -> Result<()> { debug!("handle udp query from {:?}", src_addr); let counter = self.counter.clone(); - if counter.increment(CounterType::Udp) >= self.globals.proxy_config.max_connections { + if counter.increment(CounterType::Udp) >= self.globals.proxy_config.max_connections as isize { error!( "Too many connections: max = {} (udp+tcp)", self.globals.proxy_config.max_connections @@ -122,7 +122,6 @@ impl Proxy { return Err(DapError::TooManyConnections); } - // self.globals.runtime_handle.clone().spawn(async move { let res = tokio::time::timeout( self.globals.proxy_config.http_timeout_sec + Duration::from_secs(1), // serve udp dns message here @@ -130,10 +129,9 @@ impl Proxy { ) .await .ok(); - // debug!("response from DoH server: {:?}", res); // send response via channel to the dispatch socket - counter.decrement(CounterType::Udp); + counter.decrement(CounterType::Udp); // decrement counter anyways let Some(Ok(r)) = res else { return Err(DapError::FailedToMakeDohQuery);