From c7752f39fea6e1b20ab508ce0908a8abaa13898d Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Tue, 6 Aug 2024 08:43:24 -0400 Subject: [PATCH] feat: Add API for poisoning connections (#121) This is a port of https://github.com/hyperium/hyper/pull/3145 from hyper v0.14.x. It introduces a PoisonPill atomic onto connection info. When set to true, this prevents the connection from being returned to the pool. --- src/client/legacy/client.rs | 6 ++- src/client/legacy/connect/mod.rs | 53 ++++++++++++++++++++++- tests/legacy_client.rs | 73 +++++++++++++++++++++++++++++++- 3 files changed, 129 insertions(+), 3 deletions(-) diff --git a/src/client/legacy/client.rs b/src/client/legacy/client.rs index 1508666..8562584 100644 --- a/src/client/legacy/client.rs +++ b/src/client/legacy/client.rs @@ -750,6 +750,10 @@ impl PoolClient { } } + fn is_poisoned(&self) -> bool { + self.conn_info.poisoned.poisoned() + } + fn is_ready(&self) -> bool { match self.tx { #[cfg(feature = "http1")] @@ -826,7 +830,7 @@ where B: Send + 'static, { fn is_open(&self) -> bool { - self.is_ready() + !self.is_poisoned() && self.is_ready() } fn reserve(self) -> pool::Reservation { diff --git a/src/client/legacy/connect/mod.rs b/src/client/legacy/connect/mod.rs index bd00baa..e3369b5 100644 --- a/src/client/legacy/connect/mod.rs +++ b/src/client/legacy/connect/mod.rs @@ -62,7 +62,13 @@ //! [`Read`]: hyper::rt::Read //! [`Write`]: hyper::rt::Write //! [`Connection`]: Connection -use std::fmt; +use std::{ + fmt::{self, Formatter}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; use ::http::Extensions; @@ -94,6 +100,39 @@ pub struct Connected { pub(super) alpn: Alpn, pub(super) is_proxied: bool, pub(super) extra: Option, + pub(super) poisoned: PoisonPill, +} + +#[derive(Clone)] +pub(crate) struct PoisonPill { + poisoned: Arc, +} + +impl fmt::Debug for PoisonPill { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // print the address of the pill—this makes debugging issues much easier + write!( + f, + "PoisonPill@{:p} {{ poisoned: {} }}", + self.poisoned, + self.poisoned.load(Ordering::Relaxed) + ) + } +} + +impl PoisonPill { + pub(crate) fn healthy() -> Self { + Self { + poisoned: Arc::new(AtomicBool::new(false)), + } + } + pub(crate) fn poison(&self) { + self.poisoned.store(true, Ordering::Relaxed) + } + + pub(crate) fn poisoned(&self) -> bool { + self.poisoned.load(Ordering::Relaxed) + } } pub(super) struct Extra(Box); @@ -111,6 +150,7 @@ impl Connected { alpn: Alpn::None, is_proxied: false, extra: None, + poisoned: PoisonPill::healthy(), } } @@ -170,6 +210,16 @@ impl Connected { self.alpn == Alpn::H2 } + /// Poison this connection + /// + /// A poisoned connection will not be reused for subsequent requests by the pool + pub fn poison(&self) { + self.poisoned.poison(); + tracing::debug!( + poison_pill = ?self.poisoned, "connection was poisoned. this connection will not be reused for subsequent requests" + ); + } + // Don't public expose that `Connected` is `Clone`, unsure if we want to // keep that contract... pub(super) fn clone(&self) -> Connected { @@ -177,6 +227,7 @@ impl Connected { alpn: self.alpn, is_proxied: self.is_proxied, extra: self.extra.clone(), + poisoned: self.poisoned.clone(), } } } diff --git a/tests/legacy_client.rs b/tests/legacy_client.rs index 28babd7..f2fd8b3 100644 --- a/tests/legacy_client.rs +++ b/tests/legacy_client.rs @@ -4,6 +4,7 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener}; use std::pin::Pin; use std::sync::atomic::Ordering; +use std::sync::Arc; use std::task::Poll; use std::thread; use std::time::Duration; @@ -891,7 +892,6 @@ fn capture_connection_on_client() { let addr = server.local_addr().unwrap(); thread::spawn(move || { let mut sock = server.accept().unwrap().0; - //drop(server); sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); sock.set_write_timeout(Some(Duration::from_secs(5))) .unwrap(); @@ -908,3 +908,74 @@ fn capture_connection_on_client() { rt.block_on(client.request(req)).expect("200 OK"); assert!(captured_conn.connection_metadata().is_some()); } + +#[cfg(not(miri))] +#[test] +fn connection_poisoning() { + use std::sync::atomic::AtomicUsize; + + let _ = pretty_env_logger::try_init(); + + let rt = runtime(); + let connector = DebugConnector::new(); + + let client = Client::builder(TokioExecutor::new()).build(connector); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let num_conns: Arc = Default::default(); + let num_requests: Arc = Default::default(); + let num_requests_tracker = num_requests.clone(); + let num_conns_tracker = num_conns.clone(); + thread::spawn(move || loop { + let mut sock = server.accept().unwrap().0; + num_conns_tracker.fetch_add(1, Ordering::Relaxed); + let num_requests_tracker = num_requests_tracker.clone(); + thread::spawn(move || { + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))) + .unwrap(); + let mut buf = [0; 4096]; + loop { + if sock.read(&mut buf).expect("read 1") > 0 { + num_requests_tracker.fetch_add(1, Ordering::Relaxed); + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + .expect("write 1"); + } + } + }); + }); + let make_request = || { + Request::builder() + .uri(&*format!("http://{}/a", addr)) + .body(Empty::::new()) + .unwrap() + }; + let mut req = make_request(); + let captured_conn = capture_connection(&mut req); + rt.block_on(client.request(req)).expect("200 OK"); + assert_eq!(num_conns.load(Ordering::SeqCst), 1); + assert_eq!(num_requests.load(Ordering::SeqCst), 1); + + rt.block_on(client.request(make_request())).expect("200 OK"); + rt.block_on(client.request(make_request())).expect("200 OK"); + // Before poisoning the connection is reused + assert_eq!(num_conns.load(Ordering::SeqCst), 1); + assert_eq!(num_requests.load(Ordering::SeqCst), 3); + captured_conn + .connection_metadata() + .as_ref() + .unwrap() + .poison(); + + rt.block_on(client.request(make_request())).expect("200 OK"); + + // After poisoning, a new connection is established + assert_eq!(num_conns.load(Ordering::SeqCst), 2); + assert_eq!(num_requests.load(Ordering::SeqCst), 4); + + rt.block_on(client.request(make_request())).expect("200 OK"); + // another request can still reuse: + assert_eq!(num_conns.load(Ordering::SeqCst), 2); + assert_eq!(num_requests.load(Ordering::SeqCst), 5); +}