diff --git a/Cargo.toml b/Cargo.toml
index a8a5a24..ef85636 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "cueball"
-version = "0.2.2"
+version = "0.3.0"
 authors = [
     "Kelly McLaughlin ",
     "Jon Anderson ",
@@ -9,6 +9,7 @@ authors = [
 edition = "2018"
 
 [dependencies]
+backoff = "0.1.5"
 base64 = "0.10.1"
 chrono = "0.4.9"
 derive_more = "0.14.0"
diff --git a/examples/basic.rs b/examples/basic.rs
index 7ca670a..94f1bb8 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -3,8 +3,8 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::sync::mpsc::Sender;
 use std::sync::{Arc, Barrier, Mutex};
-use std::{thread, time};
 use std::time::Duration;
+use std::{thread, time};
 
 use slog::{info, o, Drain, Logger};
@@ -66,7 +66,7 @@ impl FakeResolver {
 }
 
 impl Resolver for FakeResolver {
-    fn run(&mut self, s: Sender<BackendMsg>) {
+    fn run(&mut self, s: Sender<BackendMsg>) {
         if self.running {
             return;
         }
@@ -83,8 +83,13 @@ impl Resolver for FakeResolver {
         self.pool_tx = Some(s);
 
         loop {
-            if self.pool_tx.as_ref().unwrap().send(BackendMsg::HeartbeatMsg).
-                is_err() {
+            if self
+                .pool_tx
+                .as_ref()
+                .unwrap()
+                .send(BackendMsg::HeartbeatMsg)
+                .is_err()
+            {
                 break;
             }
             thread::sleep(HEARTBEAT_INTERVAL);
@@ -115,6 +120,7 @@ fn main() {
         log: Some(log),
         rebalancer_action_delay: None,
         decoherence_interval: None,
+        connection_check_interval: None,
     };
 
     let pool = ConnectionPool::new(pool_opts, resolver, DummyConnection::new);
diff --git a/src/connection.rs b/src/connection.rs
index 4010ea3..f2b9d93 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -28,6 +28,16 @@ pub trait Connection: Send + Sized + 'static {
     /// input parameters to `ConnectionPool::new`. Returns an [`error`](
     /// ../error/enum.Error.html) if the connection attempt fails.
     fn connect(&mut self) -> Result<(), Self::Error>;
+    /// Check to see if the connection is still up and working. The
+    /// connection pool runs this function as the connection is being
+    /// replaced and triggers a rebalance if the connection is unhealthy.
+    fn is_valid(&mut self) -> bool {
+        true
+    }
+    /// Check to see if the connection has closed or is not operational.
+    fn has_broken(&self) -> bool {
+        false
+    }
     /// Close the connection to the backend
     fn close(&mut self) -> Result<(), Self::Error>;
 }
diff --git a/src/connection_pool.rs b/src/connection_pool.rs
index feb59d7..01387b9 100644
--- a/src/connection_pool.rs
+++ b/src/connection_pool.rs
@@ -29,6 +29,7 @@ use crate::error::Error;
 use crate::resolver::{
     BackendAction, BackendAddedMsg, BackendMsg, BackendRemovedMsg, Resolver,
 };
+use backoff::{ExponentialBackoff, Operation};
 
 // Default number of maximum pool connections
 const DEFAULT_MAX_CONNECTIONS: u32 = 10;
@@ -36,6 +37,8 @@ const DEFAULT_MAX_CONNECTIONS: u32 = 10;
 const DEFAULT_REBALANCE_ACTION_DELAY: u64 = 100;
 // Decoherence interval in seconds
 const DEFAULT_DECOHERENCE_INTERVAL: u64 = 300;
+// Connection health check interval in seconds
+const DEFAULT_CONNECTION_CHECK_INTERVAL: u64 = 30;
 
 /// A pool of connections to a multi-node service
 pub struct ConnectionPool<C, R, F> {
@@ -52,7 +55,9 @@ pub struct ConnectionPool<C, R, F> {
     log: Logger,
     state: ConnectionPoolState,
     decoherence_timer: Option<timer::Timer>,
-    decoherence_timer_guard: Guard,
+    _decoherence_timer_guard: Option<Guard>,
+    connection_check_timer: Option<timer::Timer>,
+    _connection_check_timer_guard: Option<Guard>,
     _resolver: PhantomData<R>,
     _connection_function: PhantomData<F>,
 }
@@ -130,7 +135,9 @@ where
             log: self.log.clone(),
             state: self.state,
             decoherence_timer: None,
-            decoherence_timer_guard: self.decoherence_timer_guard.clone(),
+            connection_check_timer: None,
+            _connection_check_timer_guard: None,
+            _decoherence_timer_guard: None,
             _resolver: PhantomData,
             _connection_function: PhantomData,
         }
     }
@@ -220,15 +227,29 @@ where
             .decoherence_interval
             .unwrap_or(DEFAULT_DECOHERENCE_INTERVAL);
 
-        let timer = timer::Timer::new();
+        let decoherence_timer = timer::Timer::new();
 
         let decoherence_timer_guard = start_decoherence(
-            &timer,
+            &decoherence_timer,
             decoherence_interval,
             protected_data.clone(),
             logger.clone(),
         );
 
+        let connection_check_interval = cpo
+            .connection_check_interval
+            .unwrap_or(DEFAULT_CONNECTION_CHECK_INTERVAL);
+
+        let connection_check_timer = timer::Timer::new();
+
+        let connection_check_timer_guard = start_connection_check(
+            &connection_check_timer,
+            connection_check_interval,
+            protected_data.clone(),
+            rebalancer_check.clone(),
+            logger.clone(),
+        );
+
         let pool = ConnectionPool {
             protected_data,
             resolver_thread: Some(resolver_thread),
@@ -242,8 +263,10 @@ where
             decoherence_interval: Some(decoherence_interval),
             log: logger,
             state: ConnectionPoolState::Running,
-            decoherence_timer: Some(timer),
-            decoherence_timer_guard,
+            decoherence_timer: Some(decoherence_timer),
+            _decoherence_timer_guard: Some(decoherence_timer_guard),
+            connection_check_timer: Some(connection_check_timer),
+            _connection_check_timer_guard: Some(connection_check_timer_guard),
             _resolver: PhantomData,
             _connection_function: PhantomData,
         };
@@ -336,6 +359,9 @@ where
             let _timer = self.decoherence_timer.take();
         }
 
+        if self.connection_check_timer.is_some() {
+            let _timer = self.connection_check_timer.take();
+        }
         // Wait for all outstanding threads to be returned to the pool and
         // close those
         while connections_remaining > 0.into() {
@@ -602,8 +628,17 @@ where
     {
         let mut connection_data = self.protected_data.connection_data_lock();
         let (key, m_conn) = connection_key_pair.into();
-        connection_data.connections.push_back((key, m_conn).into());
-        connection_data.stats.idle_connections += 1.into();
+        match m_conn {
+            Some(conn) => {
+                if conn.has_broken() {
+                    warn!(self.log, "Found an invalid connection, not returning to the pool");
+                } else {
+                    connection_data.connections.push_back((key, conn).into());
+                    connection_data.stats.idle_connections += 1.into();
+                }
+            }
+            None => warn!(self.log, "Connection not found"),
+        }
         self.protected_data.condvar_notify();
     }
 }
@@ -737,6 +772,11 @@ where
     C: Connection,
 {
     let mut connection_data = protected_data.connection_data_lock();
+    debug!(
+        log,
+        "Running rebalancer on {} connections...",
+        connection_data.connections.len()
+    );
 
     // Calculate a new connection distribution over the set of available
     // backends and determine what additional connections need to be created and
@@ -882,10 +922,27 @@ fn add_connections(
             if net_total_connections < max_connections.into() {
                 // Try to establish connection
+                debug!(
+                    log,
+                    "Trying to add more connections: {}", net_total_connections
+                );
                 let m_backend = connection_data.backends.get(b_key);
 
                 if let Some(backend) = m_backend {
                     let mut conn = create_connection(backend);
-                    conn.connect()
+                    let mut backoff = ExponentialBackoff::default();
+                    let mut op = || {
+                        debug!(log, "attempting to connect with retry...");
+                        conn.connect().map_err(|e| {
+                            error!(
+                                log,
+                                "Retrying connection \
+                                 : {}",
+                                e
+                            );
+                        })?;
+                        Ok(())
+                    };
+                    op.retry(&mut backoff)
                         .and_then(|_| {
                             // Update connection info and stats
                             let connection_key_pair =
@@ -905,12 +962,10 @@ fn add_connections(
                             protected_data.condvar_notify();
                             Ok(())
                         })
-                        .unwrap_or_else(|e| {
+                        .unwrap_or_else(|_| {
                             error!(
                                 log,
-                                "Error occurred trying to establish connection \
-                                 : {}",
-                                e
+                                "Giving up trying to establish connection"
                             );
                         });
                 } else {
@@ -927,7 +982,7 @@ fn add_connections(
                     debug!(log, "{}", msg);
                 }
             }
-        })
+        });
 }
 
 fn resolver_recv_loop(
@@ -999,7 +1054,7 @@ fn rebalancer_loop(
                 time::Duration::from_millis(rebalance_action_delay);
             thread::sleep(sleep_time);
 
-            debug!(log, "Performing connection rebalance");
+            debug!(log, "rebalance var true");
 
             let rebalance_result = rebalance_connections(
                 max_connections,
@@ -1007,7 +1062,10 @@ fn rebalancer_loop(
                 protected_data.clone(),
             );
 
-            debug!(log, "Connection rebalance completed");
+            debug!(
+                log,
+                "Connection rebalance completed: {:#?}", rebalance_result
+            );
 
             if let Ok(Some(added_connection_count)) = rebalance_result {
                 debug!(log, "Adding new connections");
@@ -1077,3 +1135,112 @@ where
         connections.swap(i, new_idx);
     }
 }
+
+/// Start a thread to run periodic health checks on the connection pool
+fn start_connection_check<C>(
+    timer: &timer::Timer,
+    conn_check_interval: u64,
+    protected_data: ProtectedData<C>,
+    rebalance_check: RebalanceCheck,
+    log: Logger,
+) -> Guard
+where
+    C: Connection,
+{
+    debug!(
+        log,
+        "starting connection health task, interval {} seconds",
+        conn_check_interval
+    );
+    timer.schedule_repeating(
+        Duration::seconds(conn_check_interval as i64),
+        move || {
+            check_pool_connections(
+                protected_data.clone(),
+                rebalance_check.clone(),
+                log.clone(),
+            )
+        },
+    )
+}
+
+fn check_pool_connections<C>(
+    protected_data: ProtectedData<C>,
+    rebalance_check: RebalanceCheck,
+    log: Logger,
+) where
+    C: Connection,
+{
+    let mut connection_data = protected_data.connection_data_lock();
+    let len = connection_data.connections.len();
+
+    if len == 0 {
+        debug!(log, "No connections to check, signaling rebalance check");
+        let mut rebalance = rebalance_check.get_lock();
+        *rebalance = true;
+        rebalance_check.condvar_notify();
+        return;
+    }
+
+    debug!(log, "Performing connection check on {} connections", len);
+
+    let backend_count = connection_data.backends.len();
+    let mut remove_count = HashMap::with_capacity(backend_count);
+    let mut removed = 0;
+    connection_data.connections.retain(|pair| match pair {
+        ConnectionKeyPair((key, Some(conn))) => {
+            if conn.has_broken() {
+                removed += 1;
+                *remove_count.entry(key.clone()).or_insert(0) += 1;
+                false
+            } else {
+                true
+            }
+        }
+        ConnectionKeyPair((key, None)) => {
+            warn!(log, "found malformed connection");
+            removed += 1;
+            *remove_count.entry(key.clone()).or_insert(0) += 1;
+            false
+        }
+    });
+
+    debug!(log, "Removed {} from connection pool", removed);
+
+    if removed > 0 {
+        for (key, count) in remove_count.iter() {
+            connection_data
+                .connection_distribution
+                .entry(key.clone())
+                .and_modify(|e| {
+                    *e -= ConnectionCount::from(*count);
+                    debug!(
+                        log,
+                        "Connection count for {} now: {}",
+                        key.clone(),
+                        *e
+                    );
+                });
+            connection_data
+                .unwanted_connection_counts
+                .entry(key.clone())
+                .and_modify(|e| {
+                    *e += ConnectionCount::from(*count);
+                    debug!(
+                        log,
+                        "Unwanted connection count for {} now: {}",
+                        key.clone(),
+                        *e
+                    );
+                })
+                .or_insert_with(|| ConnectionCount::from(*count));
+        }
+
+        connection_data.stats.idle_connections -= removed.into();
+
+        let mut rebalance = rebalance_check.get_lock();
+        if !*rebalance {
+            debug!(log, "attempting to signal rebalance check");
+            *rebalance = true;
+            rebalance_check.condvar_notify();
+        }
+    }
+}
diff --git a/src/connection_pool/types.rs b/src/connection_pool/types.rs
index 18778d3..b251ac3 100644
--- a/src/connection_pool/types.rs
+++ b/src/connection_pool/types.rs
@@ -63,6 +63,10 @@ pub struct ConnectionPoolOptions {
     /// the period of the decoherence shuffle. If not specified the default is
     /// 300 seconds.
     pub decoherence_interval: Option<u64>,
+    /// Optional connection check interval in seconds. This represents the length of
+    /// the period of the pool connection check task. If not specified the default is
+    /// 30 seconds.
+    pub connection_check_interval: Option<u64>,
 }
 
 // This type wraps a pair that associates a `BackendKey` with a connection of
diff --git a/src/resolver.rs b/src/resolver.rs
index 0e6d400..6e4100b 100644
--- a/src/resolver.rs
+++ b/src/resolver.rs
@@ -57,25 +57,19 @@ pub enum BackendMsg {
     // For internal pool use only. Resolver implementations can send this
     // message to test whether or not the channel has been closed.
     #[doc(hidden)]
-    HeartbeatMsg
+    HeartbeatMsg,
 }
 
 impl PartialEq for BackendMsg {
     fn eq(&self, other: &Self) -> bool {
         match (self, other) {
-            (BackendMsg::AddedMsg(a), BackendMsg::AddedMsg(b)) => {
-                a == b
-            },
+            (BackendMsg::AddedMsg(a), BackendMsg::AddedMsg(b)) => a == b,
             (BackendMsg::RemovedMsg(a), BackendMsg::RemovedMsg(b)) => {
                 a.0 == b.0
-            },
-            (BackendMsg::StopMsg, BackendMsg::StopMsg) => {
-                true
-            },
-            (BackendMsg::HeartbeatMsg, BackendMsg::HeartbeatMsg) => {
-                true
-            },
-            _ => false
+            }
+            (BackendMsg::StopMsg, BackendMsg::StopMsg) => true,
+            (BackendMsg::HeartbeatMsg, BackendMsg::HeartbeatMsg) => true,
+            _ => false,
         }
     }
 }
diff --git a/tests/basic_test.rs b/tests/basic_test.rs
index 869d8b1..89c4c04 100644
--- a/tests/basic_test.rs
+++ b/tests/basic_test.rs
@@ -3,8 +3,8 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::sync::mpsc::Sender;
 use std::sync::{Arc, Barrier};
-use std::{thread, time};
 use std::time::Duration;
+use std::{thread, time};
 
 use cueball::backend;
 use cueball::backend::{Backend, BackendAddress, BackendPort};
@@ -81,8 +81,13 @@ impl Resolver for FakeResolver {
         self.pool_tx = Some(s);
 
         loop {
-            if self.pool_tx.as_ref().unwrap().send(BackendMsg::HeartbeatMsg).
-                is_err() {
+            if self
+                .pool_tx
+                .as_ref()
+                .unwrap()
+                .send(BackendMsg::HeartbeatMsg)
+                .is_err()
+            {
                 break;
             }
             thread::sleep(HEARTBEAT_INTERVAL);
@@ -107,6 +112,7 @@ fn connection_pool_claim() {
         log: None,
         rebalancer_action_delay: None,
         decoherence_interval: None,
+        connection_check_interval: None,
     };
 
     let max_connections = pool_opts.max_connections.unwrap().clone();
@@ -192,6 +198,7 @@ fn connection_pool_stop() {
         log: None,
         rebalancer_action_delay: None,
        decoherence_interval: None,
+        connection_check_interval: None,
     };
 
     let max_connections = pool_opts.max_connections.unwrap().clone();
@@ -231,6 +238,7 @@ fn connection_pool_accounting() {
         log: None,
         rebalancer_action_delay: None,
         decoherence_interval: None,
+        connection_check_interval: None,
     };
 
     let max_connections: ConnectionCount =
@@ -342,6 +350,7 @@ fn connection_pool_decoherence() {
         log: None,
         rebalancer_action_delay: Some(10000),
         decoherence_interval: Some(5),
+        connection_check_interval: None,
     };
 
     let max_connections: ConnectionCount =
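
For context, a minimal sketch of how a `Connection` implementation might opt in to the health-check hook added above. The `TcpConnection` type and its fields are hypothetical and not part of this change; only `is_valid`, `has_broken`, and `connection_check_interval` come from the diff.

use std::net::{SocketAddr, TcpStream};

use cueball::connection::Connection;

// Hypothetical connection type used only to illustrate the new trait hooks.
pub struct TcpConnection {
    addr: SocketAddr,
    stream: Option<TcpStream>,
}

impl Connection for TcpConnection {
    type Error = std::io::Error;

    fn connect(&mut self) -> Result<(), Self::Error> {
        self.stream = Some(TcpStream::connect(self.addr)?);
        Ok(())
    }

    // New optional hook from this change: report an unusable connection so
    // the periodic health check drops it from the pool and signals a rebalance.
    fn has_broken(&self) -> bool {
        self.stream.is_none()
    }

    fn close(&mut self) -> Result<(), Self::Error> {
        self.stream = None;
        Ok(())
    }
}

Setting `connection_check_interval: Some(60)` in `ConnectionPoolOptions` would run the check every 60 seconds; `None` falls back to the new 30-second default.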