Skip to content

Commit

Permalink
fix: change counter to increment only one for safety
Browse files Browse the repository at this point in the history
  • Loading branch information
junkurihara committed Dec 6, 2023
1 parent 9596916 commit d2daa30
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 50 deletions.
136 changes: 93 additions & 43 deletions proxy-lib/src/proxy/counter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::log::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

#[derive(Debug, Clone)]
pub enum CounterType {
Expand All @@ -16,31 +15,55 @@ impl CounterType {
}
}

#[derive(Debug, Clone, Default)]
#[derive(Debug, Default)]
/// Connection counter inner that is an increment-only counter
pub struct CounterInner {
/// total number of incoming connections
cnt_in: AtomicUsize,
/// total number of served connections
cnt_out: AtomicUsize,
}

impl CounterInner {
/// output difference between cnt_in and cnt_out as current in-flight connection count
pub fn get_current(&self) -> isize {
self.cnt_in.load(Ordering::Relaxed) as isize - self.cnt_out.load(Ordering::Relaxed) as isize
}
/// increment cnt_in and output current in-flight connection count
pub fn increment(&self) -> isize {
let total_in = self.cnt_in.fetch_add(1, Ordering::Relaxed) as isize;
total_in + 1 - self.cnt_out.load(Ordering::Relaxed) as isize
}
/// increment cnt_out and output current in-flight connection count
pub fn decrement(&self) -> isize {
let total_out = self.cnt_out.fetch_add(1, Ordering::Relaxed) as isize;
self.cnt_in.load(Ordering::Relaxed) as isize - total_out - 1
}
}

#[derive(Debug, Default)]
/// Connection counter
pub struct ConnCounter {
pub cnt_total: Arc<AtomicUsize>,
pub cnt_udp: Arc<AtomicUsize>,
pub cnt_tcp: Arc<AtomicUsize>,
pub cnt_udp: CounterInner,
pub cnt_tcp: CounterInner,
}

impl ConnCounter {
pub fn get_current_total(&self) -> usize {
self.cnt_total.load(Ordering::Relaxed)
pub fn get_current_total(&self) -> isize {
self.cnt_tcp.get_current() + self.cnt_udp.get_current()
}

pub fn get_current(&self, ctype: CounterType) -> usize {
pub fn get_current(&self, ctype: CounterType) -> isize {
match ctype {
CounterType::Tcp => self.cnt_tcp.load(Ordering::Relaxed),
CounterType::Udp => self.cnt_udp.load(Ordering::Relaxed),
CounterType::Tcp => self.cnt_tcp.get_current(),
CounterType::Udp => self.cnt_udp.get_current(),
}
}

pub fn increment(&self, ctype: CounterType) -> usize {
self.cnt_total.fetch_add(1, Ordering::Relaxed);
pub fn increment(&self, ctype: CounterType) -> isize {
let c = match ctype {
CounterType::Tcp => self.cnt_tcp.fetch_add(1, Ordering::Relaxed),
CounterType::Udp => self.cnt_udp.fetch_add(1, Ordering::Relaxed),
CounterType::Tcp => self.cnt_tcp.increment(),
CounterType::Udp => self.cnt_udp.increment(),
};

debug!(
Expand All @@ -52,42 +75,69 @@ impl ConnCounter {
c
}

pub fn decrement(&self, ctype: CounterType) {
let cnt;
match ctype {
CounterType::Tcp => {
let res = {
cnt = self.cnt_tcp.load(Ordering::Relaxed);
cnt > 0
&& self
.cnt_tcp
.compare_exchange(cnt, cnt - 1, Ordering::Relaxed, Ordering::Relaxed)
!= Ok(cnt)
};
if res {}
}
CounterType::Udp => {
let res = {
cnt = self.cnt_udp.load(Ordering::Relaxed);
cnt > 0
&& self
.cnt_udp
.compare_exchange(cnt, cnt - 1, Ordering::Relaxed, Ordering::Relaxed)
!= Ok(cnt)
};
if res {}
}
pub fn decrement(&self, ctype: CounterType) -> isize {
let c = match ctype {
CounterType::Tcp => self.cnt_tcp.decrement(),
CounterType::Udp => self.cnt_udp.decrement(),
};
self.cnt_total.store(
self.cnt_udp.load(Ordering::Relaxed) + self.cnt_tcp.load(Ordering::Relaxed),
Ordering::Relaxed,
);

debug!(
"{} connection count--: {} (total = {})",
&ctype.as_str(),
self.get_current(ctype),
self.get_current_total()
);
c
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_counter_inner() {
let counter = CounterInner::default();
assert_eq!(counter.get_current(), 0);
assert_eq!(counter.increment(), 1);
assert_eq!(counter.get_current(), 1);
assert_eq!(counter.increment(), 2);
assert_eq!(counter.get_current(), 2);
assert_eq!(counter.decrement(), 1);
assert_eq!(counter.get_current(), 1);
assert_eq!(counter.decrement(), 0);
assert_eq!(counter.get_current(), 0);
}

#[test]
fn test_conn_counter() {
let counter = ConnCounter::default();
assert_eq!(counter.get_current_total(), 0);
assert_eq!(counter.get_current(CounterType::Tcp), 0);
assert_eq!(counter.get_current(CounterType::Udp), 0);
assert_eq!(counter.increment(CounterType::Tcp), 1);
assert_eq!(counter.get_current_total(), 1);
assert_eq!(counter.get_current(CounterType::Tcp), 1);
assert_eq!(counter.get_current(CounterType::Udp), 0);
assert_eq!(counter.increment(CounterType::Tcp), 2);
assert_eq!(counter.get_current_total(), 2);
assert_eq!(counter.get_current(CounterType::Tcp), 2);
assert_eq!(counter.get_current(CounterType::Udp), 0);
assert_eq!(counter.increment(CounterType::Udp), 1);
assert_eq!(counter.get_current_total(), 3);
assert_eq!(counter.get_current(CounterType::Tcp), 2);
assert_eq!(counter.get_current(CounterType::Udp), 1);
assert_eq!(counter.decrement(CounterType::Tcp), 1);
assert_eq!(counter.get_current_total(), 2);
assert_eq!(counter.get_current(CounterType::Tcp), 1);
assert_eq!(counter.get_current(CounterType::Udp), 1);
assert_eq!(counter.decrement(CounterType::Tcp), 0);
assert_eq!(counter.get_current_total(), 1);
assert_eq!(counter.get_current(CounterType::Tcp), 0);
assert_eq!(counter.get_current(CounterType::Udp), 1);
assert_eq!(counter.decrement(CounterType::Udp), 0);
assert_eq!(counter.get_current_total(), 0);
assert_eq!(counter.get_current(CounterType::Tcp), 0);
assert_eq!(counter.get_current(CounterType::Udp), 0);
}
}
4 changes: 2 additions & 2 deletions proxy-lib/src/proxy/proxy_main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{net::SocketAddr, sync::Arc};
#[derive(Clone)]
pub struct Proxy {
pub(super) globals: Arc<Globals>,
pub(super) counter: ConnCounter,
pub(super) counter: Arc<ConnCounter>,
pub(super) doh_client: Arc<DoHClient>,
pub(super) listening_on: SocketAddr,
}
Expand All @@ -17,7 +17,7 @@ impl Proxy {
pub fn new(globals: Arc<Globals>, listening_on: &SocketAddr, doh_client: &Arc<DoHClient>) -> Self {
Self {
globals,
counter: ConnCounter::default(),
counter: Arc::new(ConnCounter::default()),
doh_client: doh_client.clone(),
listening_on: *listening_on,
}
Expand Down
2 changes: 1 addition & 1 deletion proxy-lib/src/proxy/proxy_tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl Proxy {
pub async fn serve_tcp_query(self, mut stream: TcpStream, src_addr: SocketAddr) -> Result<()> {
debug!("handle tcp query from {:?}", src_addr);
let counter = self.counter.clone();
if counter.increment(CounterType::Tcp) >= self.globals.proxy_config.max_connections {
if counter.increment(CounterType::Tcp) >= self.globals.proxy_config.max_connections as isize {
error!(
"Too many connections: max = {} (udp+tcp)",
self.globals.proxy_config.max_connections
Expand Down
6 changes: 2 additions & 4 deletions proxy-lib/src/proxy/proxy_udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl Proxy {
) -> Result<()> {
debug!("handle udp query from {:?}", src_addr);
let counter = self.counter.clone();
if counter.increment(CounterType::Udp) >= self.globals.proxy_config.max_connections {
if counter.increment(CounterType::Udp) >= self.globals.proxy_config.max_connections as isize {
error!(
"Too many connections: max = {} (udp+tcp)",
self.globals.proxy_config.max_connections
Expand All @@ -122,18 +122,16 @@ impl Proxy {
return Err(DapError::TooManyConnections);
}

// self.globals.runtime_handle.clone().spawn(async move {
let res = tokio::time::timeout(
self.globals.proxy_config.http_timeout_sec + Duration::from_secs(1),
// serve udp dns message here
self.doh_client.make_doh_query(&packet_buf),
)
.await
.ok();
// debug!("response from DoH server: {:?}", res);

// send response via channel to the dispatch socket
counter.decrement(CounterType::Udp);
counter.decrement(CounterType::Udp); // decrement counter anyways

let Some(Ok(r)) = res else {
return Err(DapError::FailedToMakeDohQuery);
Expand Down

0 comments on commit d2daa30

Please sign in to comment.