From 6879692b49b73a194d05ae3dfc26a56a4e410df5 Mon Sep 17 00:00:00 2001 From: Alexei Kornienko Date: Wed, 9 Dec 2020 19:53:22 +0200 Subject: [PATCH] Initial version of socket monitor impl. Related to #103 --- src/backend.rs | 9 ++++++++- src/dealer.rs | 12 +++++++++++- src/lib.rs | 21 +++++++++++++++++++++ src/pub.rs | 19 +++++++++++++++++-- src/pull.rs | 12 +++++++++++- src/push.rs | 11 +++++++++-- src/rep.rs | 15 +++++++++++++++ src/req.rs | 21 ++++++++++++--------- src/router.rs | 9 ++++++++- src/sub.rs | 9 ++++++++- src/util.rs | 41 +++++++++++++++++++++++++++++++++-------- tests/req_rep.rs | 6 ++++++ 12 files changed, 159 insertions(+), 26 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index ba880d4..b59a234 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,9 +1,10 @@ use crate::codec::{FramedIo, Message, ZmqFramedRead, ZmqFramedWrite}; use crate::fair_queue::QueueInner; use crate::util::PeerIdentity; -use crate::{MultiPeerBackend, SocketBackend, SocketType, ZmqError, ZmqResult}; +use crate::{MultiPeerBackend, SocketBackend, SocketEvent, SocketType, ZmqError, ZmqResult}; use crossbeam::queue::SegQueue; use dashmap::DashMap; +use futures::channel::mpsc; use futures::SinkExt; use parking_lot::Mutex; use std::sync::Arc; @@ -17,6 +18,7 @@ pub(crate) struct GenericSocketBackend { fair_queue_inner: Option>>>, pub(crate) round_robin: SegQueue, socket_type: SocketType, + pub(crate) socket_monitor: Mutex>>, } impl GenericSocketBackend { @@ -29,6 +31,7 @@ impl GenericSocketBackend { fair_queue_inner, round_robin: SegQueue::new(), socket_type, + socket_monitor: Mutex::new(None), } } @@ -85,6 +88,10 @@ impl SocketBackend for GenericSocketBackend { fn shutdown(&self) { self.peers.clear(); } + + fn monitor(&self) -> &Mutex>> { + &self.socket_monitor + } } impl MultiPeerBackend for GenericSocketBackend { diff --git a/src/dealer.rs b/src/dealer.rs index c7a5c7b..7b3c7ea 100644 --- a/src/dealer.rs +++ b/src/dealer.rs @@ -3,8 +3,12 @@ use crate::codec::{Message, ZmqFramedRead}; use crate::fair_queue::FairQueue; use crate::transport::AcceptStopHandle; use crate::util::PeerIdentity; -use crate::{Endpoint, MultiPeerBackend, Socket, SocketBackend, SocketType, ZmqMessage, ZmqResult}; +use crate::{ + Endpoint, MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketType, ZmqMessage, + ZmqResult, +}; use async_trait::async_trait; +use futures::channel::mpsc; use futures::StreamExt; use std::collections::hash_map::RandomState; use std::collections::HashMap; @@ -42,6 +46,12 @@ impl Socket for DealerSocket { fn binds(&mut self) -> &mut HashMap { &mut self.binds } + + fn monitor(&mut self) -> mpsc::Receiver { + let (sender, receiver) = mpsc::channel(1024); + self.backend.socket_monitor.lock().replace(sender); + receiver + } } impl DealerSocket { diff --git a/src/lib.rs b/src/lib.rs index 6933f69..ea7f009 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,8 +38,10 @@ use util::PeerIdentity; extern crate enum_primitive_derive; use async_trait::async_trait; +use futures::channel::mpsc; use futures_codec::FramedWrite; use num_traits::ToPrimitive; +use parking_lot::Mutex; use std::collections::HashMap; use std::convert::TryFrom; use std::fmt::{Debug, Display}; @@ -102,6 +104,18 @@ impl Display for SocketType { } } +#[derive(Debug)] +pub enum SocketEvent { + Connected, + ConnectDelayed, + ConnectRetried, + Accepted(Endpoint, PeerIdentity), + AcceptFailed(ZmqError), + Closed, + CloseFailed, + Disconnected(PeerIdentity), +} + pub trait MultiPeerBackend: SocketBackend { /// This should not be public.. /// Find a better way of doing this @@ -113,6 +127,7 @@ pub trait MultiPeerBackend: SocketBackend { pub trait SocketBackend: Send + Sync { fn socket_type(&self) -> SocketType; fn shutdown(&self); + fn monitor(&self) -> &Mutex>>; } #[async_trait] @@ -195,6 +210,12 @@ pub trait Socket: Sized + Send { Ok(()) } + /// Creates and setups new socket monitor + /// + /// Subsequent calls to this method each create a new monitor channel. + /// Sender side of previous one is dropped. + fn monitor(&mut self) -> mpsc::Receiver; + // TODO: async fn connections(&self) -> ? /// Disconnects from the given endpoint, blocking until finished. diff --git a/src/pub.rs b/src/pub.rs index ce03df6..a5ee0f4 100644 --- a/src/pub.rs +++ b/src/pub.rs @@ -4,11 +4,14 @@ use crate::error::ZmqResult; use crate::message::*; use crate::transport::AcceptStopHandle; use crate::util::PeerIdentity; -use crate::{BlockingSend, MultiPeerBackend, Socket, SocketBackend, SocketType, ZmqError}; -use futures::channel::oneshot; +use crate::{ + BlockingSend, MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketType, ZmqError, +}; +use futures::channel::{mpsc, oneshot}; use async_trait::async_trait; use dashmap::DashMap; +use parking_lot::Mutex; use std::collections::HashMap; use std::io::ErrorKind; use std::pin::Pin; @@ -22,6 +25,7 @@ pub(crate) struct Subscriber { pub(crate) struct PubSocketBackend { subscribers: DashMap, + socket_monitor: Mutex>>, } impl PubSocketBackend { @@ -81,6 +85,10 @@ impl SocketBackend for PubSocketBackend { fn shutdown(&self) { self.subscribers.clear(); } + + fn monitor(&self) -> &Mutex>> { + &self.socket_monitor + } } impl MultiPeerBackend for PubSocketBackend { @@ -184,6 +192,7 @@ impl Socket for PubSocket { Self { backend: Arc::new(PubSocketBackend { subscribers: DashMap::new(), + socket_monitor: Mutex::new(None), }), binds: HashMap::new(), } @@ -196,6 +205,12 @@ impl Socket for PubSocket { fn binds(&mut self) -> &mut HashMap { &mut self.binds } + + fn monitor(&mut self) -> mpsc::Receiver { + let (sender, receiver) = mpsc::channel(1024); + self.backend.socket_monitor.lock().replace(sender); + receiver + } } #[cfg(test)] diff --git a/src/pull.rs b/src/pull.rs index 476f137..fbe9d88 100644 --- a/src/pull.rs +++ b/src/pull.rs @@ -3,8 +3,12 @@ use crate::codec::{Message, ZmqFramedRead}; use crate::fair_queue::FairQueue; use crate::transport::AcceptStopHandle; use crate::util::PeerIdentity; -use crate::{BlockingRecv, Endpoint, MultiPeerBackend, Socket, SocketType, ZmqMessage, ZmqResult}; +use crate::{ + BlockingRecv, Endpoint, MultiPeerBackend, Socket, SocketEvent, SocketType, ZmqMessage, + ZmqResult, +}; use async_trait::async_trait; +use futures::channel::mpsc; use futures::StreamExt; use std::collections::hash_map::RandomState; use std::collections::HashMap; @@ -37,6 +41,12 @@ impl Socket for PullSocket { fn binds(&mut self) -> &mut HashMap { &mut self.binds } + + fn monitor(&mut self) -> mpsc::Receiver { + let (sender, receiver) = mpsc::channel(1024); + self.backend.socket_monitor.lock().replace(sender); + receiver + } } #[async_trait] diff --git a/src/push.rs b/src/push.rs index 4e42f40..8d2de96 100644 --- a/src/push.rs +++ b/src/push.rs @@ -2,10 +2,11 @@ use crate::backend::GenericSocketBackend; use crate::codec::Message; use crate::transport::AcceptStopHandle; use crate::{ - BlockingSend, Endpoint, MultiPeerBackend, Socket, SocketBackend, SocketType, ZmqMessage, - ZmqResult, + BlockingSend, Endpoint, MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketType, + ZmqMessage, ZmqResult, }; use async_trait::async_trait; +use futures::channel::mpsc; use std::collections::hash_map::RandomState; use std::collections::HashMap; use std::sync::Arc; @@ -36,6 +37,12 @@ impl Socket for PushSocket { fn binds(&mut self) -> &mut HashMap { &mut self.binds } + + fn monitor(&mut self) -> mpsc::Receiver { + let (sender, receiver) = mpsc::channel(1024); + self.backend.socket_monitor.lock().replace(sender); + receiver + } } #[async_trait] diff --git a/src/rep.rs b/src/rep.rs index 5c0fd6e..61c9e4d 100644 --- a/src/rep.rs +++ b/src/rep.rs @@ -21,6 +21,7 @@ struct RepPeer { struct RepSocketBackend { pub(crate) peers: DashMap, fair_queue_inner: Arc>>, + socket_monitor: Mutex>>, } pub struct RepSocket { @@ -44,6 +45,7 @@ impl Socket for RepSocket { backend: Arc::new(RepSocketBackend { peers: DashMap::new(), fair_queue_inner: fair_queue.inner(), + socket_monitor: Mutex::new(None), }), current_request: None, fair_queue, @@ -58,6 +60,12 @@ impl Socket for RepSocket { fn binds(&mut self) -> &mut HashMap { &mut self.binds } + + fn monitor(&mut self) -> mpsc::Receiver { + let (sender, receiver) = mpsc::channel(1024); + self.backend.socket_monitor.lock().replace(sender); + receiver + } } impl MultiPeerBackend for RepSocketBackend { @@ -77,6 +85,9 @@ impl MultiPeerBackend for RepSocketBackend { } fn peer_disconnected(&self, peer_id: &PeerIdentity) { + if let Some(monitor) = self.monitor().lock().as_mut() { + let _ = monitor.try_send(SocketEvent::Disconnected(peer_id.clone())); + } self.peers.remove(peer_id); } } @@ -89,6 +100,10 @@ impl SocketBackend for RepSocketBackend { fn shutdown(&self) { self.peers.clear(); } + + fn monitor(&self) -> &Mutex>> { + &self.socket_monitor + } } #[async_trait] diff --git a/src/req.rs b/src/req.rs index 51ac00f..6f6eea5 100644 --- a/src/req.rs +++ b/src/req.rs @@ -9,7 +9,6 @@ use crate::{SocketType, ZmqResult}; use async_trait::async_trait; use crossbeam::queue::SegQueue; use dashmap::DashMap; -use futures::lock::Mutex; use futures::{SinkExt, StreamExt}; use std::collections::HashMap; use std::sync::Arc; @@ -17,7 +16,7 @@ use std::sync::Arc; struct ReqSocketBackend { pub(crate) peers: DashMap, pub(crate) round_robin: SegQueue, - pub(crate) current_request_peer_id: Mutex>, + socket_monitor: Mutex>>, } pub struct ReqSocket { @@ -64,11 +63,6 @@ impl BlockingSend for ReqSocket { message, ]; peer.send_queue.send(Message::Multipart(frames)).await?; - self.backend - .current_request_peer_id - .lock() - .await - .replace(next_peer_id.clone()); self.current_request = Some(next_peer_id); return Ok(()); } @@ -110,7 +104,7 @@ impl Socket for ReqSocket { backend: Arc::new(ReqSocketBackend { peers: DashMap::new(), round_robin: SegQueue::new(), - current_request_peer_id: Mutex::new(None), + socket_monitor: Mutex::new(None), }), current_request: None, binds: HashMap::new(), @@ -124,6 +118,12 @@ impl Socket for ReqSocket { fn binds(&mut self) -> &mut HashMap { &mut self.binds } + + fn monitor(&mut self) -> mpsc::Receiver { + let (sender, receiver) = mpsc::channel(1024); + self.backend.socket_monitor.lock().replace(sender); + receiver + } } impl MultiPeerBackend for ReqSocketBackend { @@ -151,7 +151,10 @@ impl SocketBackend for ReqSocketBackend { } fn shutdown(&self) { - println!("Shutting down req backend"); self.peers.clear(); } + + fn monitor(&self) -> &Mutex>> { + &self.socket_monitor + } } diff --git a/src/router.rs b/src/router.rs index 302491c..37d11ef 100644 --- a/src/router.rs +++ b/src/router.rs @@ -12,8 +12,9 @@ use crate::fair_queue::FairQueue; use crate::message::*; use crate::transport::AcceptStopHandle; use crate::util::PeerIdentity; -use crate::{MultiPeerBackend, SocketType}; +use crate::{MultiPeerBackend, SocketEvent, SocketType}; use crate::{Socket, SocketBackend}; +use futures::channel::mpsc; use futures::SinkExt; pub struct RouterSocket { @@ -49,6 +50,12 @@ impl Socket for RouterSocket { fn binds(&mut self) -> &mut HashMap { &mut self.binds } + + fn monitor(&mut self) -> mpsc::Receiver { + let (sender, receiver) = mpsc::channel(1024); + self.backend.socket_monitor.lock().replace(sender); + receiver + } } impl RouterSocket { diff --git a/src/sub.rs b/src/sub.rs index 4b8e413..b150fdc 100644 --- a/src/sub.rs +++ b/src/sub.rs @@ -4,12 +4,13 @@ use crate::error::ZmqResult; use crate::message::*; use crate::transport::AcceptStopHandle; use crate::util::PeerIdentity; -use crate::{BlockingRecv, MultiPeerBackend, Socket, SocketBackend, SocketType}; +use crate::{BlockingRecv, MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketType}; use crate::backend::GenericSocketBackend; use crate::fair_queue::FairQueue; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; +use futures::channel::mpsc; use futures::{SinkExt, StreamExt}; use std::collections::HashMap; use std::sync::Arc; @@ -74,6 +75,12 @@ impl Socket for SubSocket { fn binds(&mut self) -> &mut HashMap { &mut self.binds } + + fn monitor(&mut self) -> mpsc::Receiver { + let (sender, receiver) = mpsc::channel(1024); + self.backend.socket_monitor.lock().replace(sender); + receiver + } } #[async_trait] diff --git a/src/util.rs b/src/util.rs index 9895c4f..02ca3df 100644 --- a/src/util.rs +++ b/src/util.rs @@ -172,14 +172,39 @@ pub(crate) async fn peer_connected( accept_result: ZmqResult<(FramedIo, Endpoint)>, backend: Arc, ) { - let (mut raw_socket, _remote_endpoint) = accept_result.expect("Failed to accept"); - greet_exchange(&mut raw_socket) - .await - .expect("Failed to exchange greetings"); - let peer_id = ready_exchange(&mut raw_socket, backend.socket_type()) - .await - .expect("Failed to exchange ready messages"); - + // TODO find a better way of writing this + let (mut raw_socket, remote_endpoint) = match accept_result { + Ok((socket, remote_endpoint)) => (socket, remote_endpoint), + Err(e) => { + if let Some(monitor) = backend.monitor().lock().as_mut() { + let _ = monitor.try_send(SocketEvent::AcceptFailed(e)); + } + return; + } + }; + match greet_exchange(&mut raw_socket).await { + Ok(_) => (), + Err(e) => { + if let Some(monitor) = backend.monitor().lock().as_mut() { + let _ = monitor.try_send(SocketEvent::AcceptFailed(e)); + } + return; + } + }; + let peer_id = match ready_exchange(&mut raw_socket, backend.socket_type()).await { + Ok(peer_id) => { + if let Some(monitor) = backend.monitor().lock().as_mut() { + let _ = monitor.try_send(SocketEvent::Accepted(remote_endpoint, peer_id.clone())); + }; + peer_id + } + Err(e) => { + if let Some(monitor) = backend.monitor().lock().as_mut() { + let _ = monitor.try_send(SocketEvent::AcceptFailed(e)); + } + return; + } + }; backend.peer_connected(&peer_id, raw_socket); } diff --git a/tests/req_rep.rs b/tests/req_rep.rs index d64bad2..83ca194 100644 --- a/tests/req_rep.rs +++ b/tests/req_rep.rs @@ -1,6 +1,7 @@ use zeromq::prelude::*; use zeromq::RepSocket; +use futures::StreamExt; use std::convert::TryInto; use std::error::Error; use std::time::Duration; @@ -27,6 +28,7 @@ async fn test_req_rep_sockets() -> Result<(), Box> { pretty_env_logger::try_init().ok(); let mut rep_socket = zeromq::RepSocket::new(); + let monitor = rep_socket.monitor(); let endpoint = rep_socket.bind("tcp://localhost:0").await?; println!("Started rep server on {}", endpoint); @@ -42,6 +44,10 @@ async fn test_req_rep_sockets() -> Result<(), Box> { let repl: String = req_socket.recv().await?.try_into()?; assert_eq!(format!("Req - {} Rep - {}", i, i), repl) } + req_socket.close().await; + let events: Vec<_> = monitor.collect().await; + // Currently it only contains Accepted event + assert_eq!(1, events.len()); Ok(()) }