diff --git a/socketioxide/src/ack.rs b/socketioxide/src/ack.rs index 75588037..373bf25c 100644 --- a/socketioxide/src/ack.rs +++ b/socketioxide/src/ack.rs @@ -19,7 +19,7 @@ use serde::de::DeserializeOwned; use serde_json::Value; use tokio::{sync::oneshot::Receiver, time::Timeout}; -use crate::{adapter::Adapter, errors::AckError, extract::SocketRef, packet::Packet, SocketError}; +use crate::{errors::AckError, extract::SocketRef, packet::Packet, SocketError}; /// An acknowledgement sent by the client. /// It contains the data sent by the client and the binary payloads if there are any. @@ -145,9 +145,9 @@ impl AckInnerStream { /// /// The [`AckInnerStream`] will wait for the default timeout specified in the config /// (5s by default) if no custom timeout is specified. - pub fn broadcast( + pub fn broadcast( packet: Packet<'static>, - sockets: Vec>, + sockets: Vec, duration: Option, ) -> Self { let rxs = FuturesUnordered::new(); @@ -312,13 +312,13 @@ mod test { use engineioxide::sid::Sid; use futures_util::StreamExt; - use crate::{adapter::LocalAdapter, ns::Namespace, socket::Socket}; + use crate::{ns::Namespace, socket::Socket}; use super::*; - fn create_socket() -> Arc> { + fn create_socket() -> Arc { let sid = Sid::new(); - let ns = Namespace::::new_dummy([sid]).into(); + let ns = Namespace::new_dummy([sid]).into(); let socket = Socket::new_dummy(sid, ns); socket.into() } diff --git a/socketioxide/src/adapter.rs b/socketioxide/src/adapter.rs index 58850c7d..eae5a41b 100644 --- a/socketioxide/src/adapter.rs +++ b/socketioxide/src/adapter.rs @@ -6,7 +6,6 @@ use std::{ borrow::Cow, collections::{HashMap, HashSet}, - convert::Infallible, sync::{RwLock, Weak}, time::Duration, }; @@ -48,32 +47,28 @@ pub struct BroadcastOptions { pub sid: Option, } //TODO: Make an AsyncAdapter trait -/// An adapter is responsible for managing the state of the server. +/// An adapter is responsible for managing the state of the server. There is one adapter per namespace. /// This adapter can be implemented to share the state between multiple servers. /// The default adapter is the [`LocalAdapter`], which stores the state in memory. pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { - /// An error that can occur when using the adapter. The default [`LocalAdapter`] has an [`Infallible`] error. - type Error: std::error::Error + Into + Send + Sync + 'static; - - /// Create a new adapter and give the namespace ref to retrieve sockets. - fn new(ns: Weak>) -> Self - where - Self: Sized; + /// Returns a boxed clone of the adapter. + /// It is used to create a new empty instance of the adapter for a new namespace. + fn boxed_clone(&self) -> Box; /// Initializes the adapter. - fn init(&self) -> Result<(), Self::Error>; + fn init(&mut self, ns: Weak) -> Result<(), AdapterError>; /// Closes the adapter. - fn close(&self) -> Result<(), Self::Error>; + fn close(&self) -> Result<(), AdapterError>; /// Returns the number of servers. - fn server_count(&self) -> Result; + fn server_count(&self) -> Result; /// Adds the socket to all the rooms. - fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; + fn add_all(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError>; /// Removes the socket from the rooms. - fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; + fn del(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError>; /// Removes the socket from all the rooms. - fn del_all(&self, sid: Sid) -> Result<(), Self::Error>; + fn del_all(&self, sid: Sid) -> Result<(), AdapterError>; /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`]. fn broadcast(&self, packet: Packet<'_>, opts: BroadcastOptions) -> Result<(), BroadcastError>; @@ -87,28 +82,24 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { ) -> AckInnerStream; /// Returns the sockets ids that match the [`BroadcastOptions`]. - fn sockets(&self, rooms: impl RoomParam) -> Result, Self::Error>; + fn sockets(&self, rooms: Vec) -> Result, AdapterError>; /// Returns the rooms of the socket. - fn socket_rooms(&self, sid: Sid) -> Result, Self::Error>; + fn socket_rooms(&self, sid: Sid) -> Result, AdapterError>; /// Returns the sockets that match the [`BroadcastOptions`]. - fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>, Self::Error> - where - Self: Sized; + fn fetch_sockets(&self, opts: BroadcastOptions) -> Result, AdapterError>; /// Adds the sockets that match the [`BroadcastOptions`] to the rooms. - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) - -> Result<(), Self::Error>; + fn add_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError>; /// Removes the sockets that match the [`BroadcastOptions`] from the rooms. - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) - -> Result<(), Self::Error>; + fn del_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError>; /// Disconnects the sockets that match the [`BroadcastOptions`]. fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec>; /// Returns all the rooms for this adapter. - fn rooms(&self) -> Result, Self::Error>; + fn rooms(&self) -> Result, AdapterError>; //TODO: implement // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result; @@ -117,33 +108,23 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { } /// The default adapter. Store the state in memory. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct LocalAdapter { rooms: RwLock>>, - ns: Weak>, -} - -impl From for AdapterError { - fn from(_: Infallible) -> AdapterError { - unreachable!() - } + ns: Weak, } impl Adapter for LocalAdapter { - type Error = Infallible; - - fn new(ns: Weak>) -> Self { - Self { - rooms: HashMap::new().into(), - ns, - } + fn boxed_clone(&self) -> Box { + Box::new(Self::new()) } - fn init(&self) -> Result<(), Infallible> { + fn init(&mut self, ns: Weak) -> Result<(), AdapterError> { + self.ns = ns; Ok(()) } - fn close(&self) -> Result<(), Infallible> { + fn close(&self) -> Result<(), AdapterError> { #[cfg(feature = "tracing")] tracing::debug!("closing local adapter: {}", self.ns.upgrade().unwrap().path); let mut rooms = self.rooms.write().unwrap(); @@ -152,21 +133,21 @@ impl Adapter for LocalAdapter { Ok(()) } - fn server_count(&self) -> Result { + fn server_count(&self) -> Result { Ok(1) } - fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { + fn add_all(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms.into_room_iter() { + for room in rooms { rooms_map.entry(room).or_default().insert(sid); } Ok(()) } - fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { + fn del(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms.into_room_iter() { + for room in rooms { if let Some(room) = rooms_map.get_mut(&room) { room.remove(&sid); } @@ -174,7 +155,7 @@ impl Adapter for LocalAdapter { Ok(()) } - fn del_all(&self, sid: Sid) -> Result<(), Infallible> { + fn del_all(&self, sid: Sid) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); for room in rooms_map.values_mut() { room.remove(&sid); @@ -214,7 +195,7 @@ impl Adapter for LocalAdapter { AckInnerStream::broadcast(packet, sockets, timeout) } - fn sockets(&self, rooms: impl RoomParam) -> Result, Infallible> { + fn sockets(&self, rooms: Vec) -> Result, AdapterError> { let mut opts = BroadcastOptions::default(); opts.rooms.extend(rooms.into_room_iter()); Ok(self @@ -225,7 +206,7 @@ impl Adapter for LocalAdapter { } //TODO: make this operation O(1) - fn socket_rooms(&self, sid: Sid) -> Result>, Infallible> { + fn socket_rooms(&self, sid: Sid) -> Result>, AdapterError> { let rooms_map = self.rooms.read().unwrap(); Ok(rooms_map .iter() @@ -234,11 +215,11 @@ impl Adapter for LocalAdapter { .collect()) } - fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>, Infallible> { + fn fetch_sockets(&self, opts: BroadcastOptions) -> Result, AdapterError> { Ok(self.apply_opts(opts)) } - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { + fn add_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError> { let rooms: Vec = rooms.into_room_iter().collect(); for socket in self.apply_opts(opts) { self.add_all(socket.id, rooms.clone()).unwrap(); @@ -246,7 +227,7 @@ impl Adapter for LocalAdapter { Ok(()) } - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { + fn del_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError> { let rooms: Vec = rooms.into_room_iter().collect(); for socket in self.apply_opts(opts) { self.del(socket.id, rooms.clone()).unwrap(); @@ -270,14 +251,18 @@ impl Adapter for LocalAdapter { } } - fn rooms(&self) -> Result, Self::Error> { + fn rooms(&self) -> Result, AdapterError> { Ok(self.rooms.read().unwrap().keys().cloned().collect()) } } impl LocalAdapter { + /// Creates a new [LocalAdapter]. + pub fn new() -> Self { + Self::default() + } /// Applies the given `opts` and return the sockets that match. - fn apply_opts(&self, opts: BroadcastOptions) -> Vec> { + fn apply_opts(&self, opts: BroadcastOptions) -> Vec { let rooms = opts.rooms; let except = self.get_except_sids(&opts.except); @@ -335,19 +320,26 @@ mod test { }; } + fn local_adapter(ns: &Arc) -> LocalAdapter { + let mut adapter = LocalAdapter::new(); + adapter.init(Arc::downgrade(ns)).unwrap(); + adapter + } #[tokio::test] async fn test_server_count() { - let ns = Namespace::new_dummy([]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([]).into(); + let adapter = local_adapter(&ns); assert_eq!(adapter.server_count().unwrap(), 1); } #[tokio::test] async fn test_add_all() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 1); @@ -357,10 +349,12 @@ mod test { #[tokio::test] async fn test_del() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); - adapter.del(socket, "room1").unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter.del(socket, vec!["room1".into()]).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 0); @@ -370,9 +364,11 @@ mod test { #[tokio::test] async fn test_del_all() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); adapter.del_all(socket).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); @@ -385,11 +381,13 @@ mod test { let sid1 = Sid::new(); let sid2 = Sid::new(); let sid3 = Sid::new(); - let ns = Namespace::new_dummy([sid1, sid2, sid3]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(sid1, ["room1", "room2"]).unwrap(); - adapter.add_all(sid2, ["room1"]).unwrap(); - adapter.add_all(sid3, ["room2"]).unwrap(); + let ns = Namespace::new_dummy([sid1, sid2, sid3]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(sid1, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter.add_all(sid2, vec!["room1".into()]).unwrap(); + adapter.add_all(sid3, vec!["room2".into()]).unwrap(); assert!(adapter .socket_rooms(sid1) .unwrap() @@ -405,16 +403,16 @@ mod test { #[tokio::test] async fn test_add_socket() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter.add_all(socket, vec!["room1".into()]).unwrap(); let mut opts = BroadcastOptions { sid: Some(socket), ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.add_sockets(opts, "room2").unwrap(); + adapter.add_sockets(opts, vec!["room2".into()]).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); @@ -425,16 +423,16 @@ mod test { #[tokio::test] async fn test_del_socket() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter.add_all(socket, vec!["room1".into()]).unwrap(); let mut opts = BroadcastOptions { sid: Some(socket), ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.add_sockets(opts, "room2").unwrap(); + adapter.add_sockets(opts, vec!["room2".into()]).unwrap(); { let rooms_map = adapter.rooms.read().unwrap(); @@ -449,7 +447,7 @@ mod test { ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.del_sockets(opts, "room2").unwrap(); + adapter.del_sockets(opts, vec!["room2".into()]).unwrap(); { let rooms_map = adapter.rooms.read().unwrap(); @@ -465,23 +463,29 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket0, ["room1", "room2"]).unwrap(); - adapter.add_all(socket1, ["room1", "room3"]).unwrap(); - adapter.add_all(socket2, ["room2", "room3"]).unwrap(); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket0, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter + .add_all(socket1, vec!["room1".into(), "room3".into()]) + .unwrap(); + adapter + .add_all(socket2, vec!["room2".into(), "room3".into()]) + .unwrap(); - let sockets = adapter.sockets("room1").unwrap(); + let sockets = adapter.sockets(vec!["room1".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket0)); assert!(sockets.contains(&socket1)); - let sockets = adapter.sockets("room2").unwrap(); + let sockets = adapter.sockets(vec!["room2".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket0)); assert!(sockets.contains(&socket2)); - let sockets = adapter.sockets("room3").unwrap(); + let sockets = adapter.sockets(vec!["room3".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket1)); assert!(sockets.contains(&socket2)); @@ -492,16 +496,25 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); adapter - .add_all(socket0, ["room1", "room2", "room4"]) + .add_all( + socket0, + vec!["room1".into(), "room2".into(), "room4".into()], + ) .unwrap(); adapter - .add_all(socket1, ["room1", "room3", "room5"]) + .add_all( + socket1, + vec!["room1".into(), "room3".into(), "room5".into()], + ) .unwrap(); adapter - .add_all(socket2, ["room2", "room3", "room6"]) + .add_all( + socket2, + vec!["room2".into(), "room3".into(), "room6".into()], + ) .unwrap(); let mut opts = BroadcastOptions { @@ -511,7 +524,7 @@ mod test { opts.rooms = hash_set!["room5".into()]; adapter.disconnect_socket(opts).unwrap(); - let sockets = adapter.sockets("room2").unwrap(); + let sockets = adapter.sockets(vec!["room2".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket2)); assert!(sockets.contains(&socket0)); @@ -521,15 +534,22 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); // Add socket 0 to room1 and room2 - adapter.add_all(socket0, ["room1", "room2"]).unwrap(); + adapter + .add_all(socket0, vec!["room1".into(), "room2".into()]) + .unwrap(); // Add socket 1 to room1 and room3 - adapter.add_all(socket1, ["room1", "room3"]).unwrap(); + adapter + .add_all(socket1, vec!["room1".into(), "room3".into()]) + .unwrap(); // Add socket 2 to room2 and room3 adapter - .add_all(socket2, ["room1", "room2", "room3"]) + .add_all( + socket2, + vec!["room1".into(), "room2".into(), "room3".into()], + ) .unwrap(); // socket 2 is the sender diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 0cd7c475..2ad47d40 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -12,7 +12,6 @@ use engineioxide::sid::Sid; use matchit::{Match, Router}; use tokio::sync::oneshot; -use crate::adapter::Adapter; use crate::handler::ConnectHandler; use crate::ns::NamespaceCtr; use crate::socket::DisconnectReason; @@ -24,10 +23,10 @@ use crate::{ }; use crate::{ProtocolVersion, SocketIo}; -pub struct Client { +pub struct Client { pub(crate) config: SocketIoConfig, - ns: RwLock, Arc>>>, - router: RwLock>>, + ns: RwLock, Arc>>, + router: RwLock>, #[cfg(feature = "state")] pub(crate) state: state::TypeMap![Send + Sync], @@ -35,7 +34,7 @@ pub struct Client { /// ==== impl Client ==== -impl Client { +impl Client { pub fn new( config: SocketIoConfig, #[cfg(feature = "state")] mut state: state::TypeMap![Send + Sync], @@ -57,26 +56,25 @@ impl Client { &self, auth: Option, ns_path: Str, - esocket: &Arc>>, + esocket: &Arc>, ) { #[cfg(feature = "tracing")] tracing::debug!("auth: {:?}", auth); let protocol: ProtocolVersion = esocket.protocol.into(); - let connect = - move |ns: Arc>, esocket: Arc>>| async move { - if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() { - // cancel the connect timeout task for v5 - if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() { - tx.send(()).ok(); - } + let connect = move |ns: Arc, esocket: Arc>| async move { + if ns.connect(esocket.id, esocket.clone(), auth).await.is_ok() { + // cancel the connect timeout task for v5 + if let Some(tx) = esocket.data.connect_recv_tx.lock().unwrap().take() { + tx.send(()).ok(); } - }; + } + }; if let Some(ns) = self.get_ns(&ns_path) { tokio::spawn(connect(ns, esocket.clone())); } else if let Ok(Match { value: ns_ctr, .. }) = self.router.read().unwrap().at(&ns_path) { let path: Cow<'static, str> = Cow::Owned(ns_path.clone().into()); - let ns = ns_ctr.get_new_ns(ns_path); //TODO: check memory leak here + let ns = ns_ctr.get_new_ns(ns_path, self.config.adapter.boxed_clone()); //TODO: check memory leak here self.ns.write().unwrap().insert(path, ns.clone()); tokio::spawn(connect(ns, esocket.clone())); } else if protocol == ProtocolVersion::V4 && ns_path == "/" { @@ -107,7 +105,7 @@ impl Client { /// Spawn a task that will close the socket if it is not connected to a namespace /// after the [`SocketIoConfig::connect_timeout`] duration - fn spawn_connect_timeout_task(&self, socket: Arc>>) { + fn spawn_connect_timeout_task(&self, socket: Arc>) { #[cfg(feature = "tracing")] tracing::debug!("spawning connect timeout task"); let (tx, rx) = oneshot::channel(); @@ -125,18 +123,22 @@ impl Client { /// Adds a new namespace handler pub fn add_ns(&self, path: Cow<'static, str>, callback: C) where - C: ConnectHandler, + C: ConnectHandler, T: Send + Sync + 'static, { #[cfg(feature = "tracing")] tracing::debug!("adding namespace {}", path); - let ns = Namespace::new(Str::from(&path), callback); + let ns = Namespace::new( + Str::from(&path), + callback, + self.config.adapter.boxed_clone(), + ); self.ns.write().unwrap().insert(path, ns); } pub fn add_dyn_ns(&self, path: String, callback: C) -> Result<(), matchit::InsertError> where - C: ConnectHandler, + C: ConnectHandler, T: Send + Sync + 'static, { #[cfg(feature = "tracing")] @@ -161,7 +163,7 @@ impl Client { } } - pub fn get_ns(&self, path: &str) -> Option>> { + pub fn get_ns(&self, path: &str) -> Option> { self.ns.read().unwrap().get(path).cloned() } @@ -182,7 +184,7 @@ impl Client { } #[derive(Debug)] -pub struct SocketData { +pub struct SocketData { /// Partial binary packet that is being received /// Stored here until all the binary payloads are received pub partial_bin_packet: Mutex>>, @@ -191,9 +193,9 @@ pub struct SocketData { pub connect_recv_tx: Mutex>>, /// Used to store the [`SocketIo`] instance so it can be accessed by any sockets - pub io: OnceLock>, + pub io: OnceLock, } -impl Default for SocketData { +impl Default for SocketData { fn default() -> Self { Self { partial_bin_packet: Default::default(), @@ -203,11 +205,11 @@ impl Default for SocketData { } } -impl EngineIoHandler for Client { - type Data = SocketData; +impl EngineIoHandler for Client { + type Data = SocketData; #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))] - fn on_connect(self: Arc, socket: Arc>>) { + fn on_connect(self: Arc, socket: Arc>) { socket.data.io.set(SocketIo::from(self.clone())).ok(); #[cfg(feature = "tracing")] @@ -228,7 +230,7 @@ impl EngineIoHandler for Client { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))] - fn on_disconnect(&self, socket: Arc>>, reason: EIoDisconnectReason) { + fn on_disconnect(&self, socket: Arc>, reason: EIoDisconnectReason) { #[cfg(feature = "tracing")] tracing::debug!("eio socket disconnected"); let socks: Vec<_> = self @@ -255,7 +257,7 @@ impl EngineIoHandler for Client { } } - fn on_message(&self, msg: Str, socket: Arc>>) { + fn on_message(&self, msg: Str, socket: Arc>) { #[cfg(feature = "tracing")] tracing::debug!("Received message: {:?}", msg); let packet = match Packet::try_from(msg) { @@ -303,7 +305,7 @@ impl EngineIoHandler for Client { /// When a binary payload is received from a socket, it is applied to the partial binary packet /// /// If the packet is complete, it is propagated to the namespace - fn on_binary(&self, data: Bytes, socket: Arc>>) { + fn on_binary(&self, data: Bytes, socket: Arc>) { if apply_payload_on_packet(data, &socket) { if let Some(packet) = socket.data.partial_bin_packet.lock().unwrap().take() { if let Err(ref err) = self.sock_propagate_packet(packet, socket.id) { @@ -321,7 +323,7 @@ impl EngineIoHandler for Client { } } } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut f = f.debug_struct("Client"); f.field("config", &self.config).field("ns", &self.ns); @@ -335,7 +337,7 @@ impl std::fmt::Debug for Client { /// waiting to be filled with all the payloads /// /// Returns true if the packet is complete and should be processed -fn apply_payload_on_packet(data: Bytes, socket: &EIoSocket>) -> bool { +fn apply_payload_on_packet(data: Bytes, socket: &EIoSocket) -> bool { #[cfg(feature = "tracing")] tracing::debug!("[sid={}] applying payload on packet", socket.id); if let Some(ref mut packet) = *socket.data.partial_bin_packet.lock().unwrap() { @@ -354,7 +356,7 @@ fn apply_payload_on_packet(data: Bytes, socket: &EIoSocket Client { +impl Client { pub async fn new_dummy_sock( self: Arc, ns: &'static str, @@ -366,7 +368,7 @@ impl Client { let buffer_size = self.config.engine_config.max_buffer_size; let sid = Sid::new(); let (esock, rx) = - EIoSocket::>::new_dummy_piped(sid, Box::new(|_, _| {}), buffer_size); + EIoSocket::::new_dummy_piped(sid, Box::new(|_, _| {}), buffer_size); esock.data.io.set(SocketIo::from(self.clone())).ok(); let (tx1, mut rx1) = tokio::sync::mpsc::channel(buffer_size); tokio::spawn({ @@ -408,15 +410,14 @@ mod test { use super::*; use tokio::sync::mpsc; - use crate::adapter::LocalAdapter; const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10); - fn create_client() -> Arc> { + fn create_client() -> Arc { let config = crate::SocketIoConfig { connect_timeout: CONNECT_TIMEOUT, ..Default::default() }; - let client = Client::::new( + let client = Client::new( config, #[cfg(feature = "state")] Default::default(), diff --git a/socketioxide/src/extract/data.rs b/socketioxide/src/extract/data.rs index acc86f87..a1ca4cdb 100644 --- a/socketioxide/src/extract/data.rs +++ b/socketioxide/src/extract/data.rs @@ -2,7 +2,7 @@ use std::convert::Infallible; use std::sync::Arc; use crate::handler::{FromConnectParts, FromMessage, FromMessageParts}; -use crate::{adapter::Adapter, socket::Socket}; +use crate::socket::Socket; use bytes::Bytes; use serde::de::DeserializeOwned; use serde_json::Value; @@ -21,27 +21,25 @@ fn upwrap_array(v: &mut Value) { /// If a deserialization error occurs, the handler won't be called /// and an error log will be print if the `tracing` feature is enabled. pub struct Data(pub T); -impl FromConnectParts for Data +impl FromConnectParts for Data where T: DeserializeOwned, - A: Adapter, { type Error = serde_json::Error; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts(_: &Arc, auth: &Option) -> Result { auth.as_ref() .map(|a| serde_json::from_str::(a)) .unwrap_or(serde_json::from_str::("{}")) .map(Data) } } -impl FromMessageParts for Data +impl FromMessageParts for Data where T: DeserializeOwned, - A: Adapter, { type Error = serde_json::Error; fn from_message_parts( - _: &Arc>, + _: &Arc, v: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -54,13 +52,12 @@ where /// An Extractor that returns the deserialized data related to the event. pub struct TryData(pub Result); -impl FromConnectParts for TryData +impl FromConnectParts for TryData where T: DeserializeOwned, - A: Adapter, { type Error = Infallible; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts(_: &Arc, auth: &Option) -> Result { let v: Result = auth .as_ref() .map(|a| serde_json::from_str(a)) @@ -68,14 +65,13 @@ where Ok(TryData(v)) } } -impl FromMessageParts for TryData +impl FromMessageParts for TryData where T: DeserializeOwned, - A: Adapter, { type Error = Infallible; fn from_message_parts( - _: &Arc>, + _: &Arc, v: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -88,10 +84,10 @@ where /// An Extractor that returns the binary data of the message. /// If there is no binary data, it will contain an empty vec. pub struct Bin(pub Vec); -impl FromMessage for Bin { +impl FromMessage for Bin { type Error = Infallible; fn from_message( - _: Arc>, + _: Arc, _: serde_json::Value, bin: Vec, _: Option, diff --git a/socketioxide/src/extract/extensions.rs b/socketioxide/src/extract/extensions.rs index 5d5624a1..3294b3ac 100644 --- a/socketioxide/src/extract/extensions.rs +++ b/socketioxide/src/extract/extensions.rs @@ -1,7 +1,6 @@ use std::convert::Infallible; use std::sync::Arc; -use crate::adapter::Adapter; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::socket::{DisconnectReason, Socket}; use bytes::Bytes; @@ -30,7 +29,7 @@ impl std::fmt::Debug for ExtensionNotFound { impl std::error::Error for ExtensionNotFound {} fn extract_http_extension( - s: &Arc>, + s: &Arc, ) -> Result> { s.req_parts() .extensions @@ -44,45 +43,43 @@ pub struct HttpExtension(pub T); /// An Extractor that returns a clone extension from the request parts if it exists. pub struct MaybeHttpExtension(pub Option); -impl FromConnectParts for HttpExtension { +impl FromConnectParts for HttpExtension { type Error = ExtensionNotFound; fn from_connect_parts( - s: &Arc>, + s: &Arc, _: &Option, ) -> Result> { extract_http_extension(s).map(HttpExtension) } } -impl FromConnectParts for MaybeHttpExtension { +impl FromConnectParts for MaybeHttpExtension { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(MaybeHttpExtension(extract_http_extension(s).ok())) } } -impl FromDisconnectParts for HttpExtension { +impl FromDisconnectParts for HttpExtension { type Error = ExtensionNotFound; fn from_disconnect_parts( - s: &Arc>, + s: &Arc, _: DisconnectReason, ) -> Result> { extract_http_extension(s).map(HttpExtension) } } -impl FromDisconnectParts - for MaybeHttpExtension -{ +impl FromDisconnectParts for MaybeHttpExtension { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(MaybeHttpExtension(extract_http_extension(s).ok())) } } -impl FromMessageParts for HttpExtension { +impl FromMessageParts for HttpExtension { type Error = ExtensionNotFound; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -90,10 +87,10 @@ impl FromMessageParts for HttpE extract_http_extension(s).map(HttpExtension) } } -impl FromMessageParts for MaybeHttpExtension { +impl FromMessageParts for MaybeHttpExtension { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -110,7 +107,7 @@ mod extensions_extract { use super::*; fn extract_extension( - s: &Arc>, + s: &Arc, ) -> Result> { s.extensions .get::() @@ -127,43 +124,40 @@ mod extensions_extract { /// An Extractor that returns the extension of the given type T if it exists or [`None`] otherwise. pub struct MaybeExtension(pub Option); - impl FromConnectParts for Extension { + impl FromConnectParts for Extension { type Error = ExtensionNotFound; fn from_connect_parts( - s: &Arc>, + s: &Arc, _: &Option, ) -> Result> { extract_extension(s).map(Extension) } } - impl FromConnectParts for MaybeExtension { + impl FromConnectParts for MaybeExtension { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(MaybeExtension(extract_extension(s).ok())) } } - impl FromDisconnectParts for Extension { + impl FromDisconnectParts for Extension { type Error = ExtensionNotFound; fn from_disconnect_parts( - s: &Arc>, + s: &Arc, _: DisconnectReason, ) -> Result> { extract_extension(s).map(Extension) } } - impl FromDisconnectParts for MaybeExtension { + impl FromDisconnectParts for MaybeExtension { type Error = Infallible; - fn from_disconnect_parts( - s: &Arc>, - _: DisconnectReason, - ) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(MaybeExtension(extract_extension(s).ok())) } } - impl FromMessageParts for Extension { + impl FromMessageParts for Extension { type Error = ExtensionNotFound; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -171,10 +165,10 @@ mod extensions_extract { extract_extension(s).map(Extension) } } - impl FromMessageParts for MaybeExtension { + impl FromMessageParts for MaybeExtension { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, diff --git a/socketioxide/src/extract/mod.rs b/socketioxide/src/extract/mod.rs index 6a90e6da..a3d1889b 100644 --- a/socketioxide/src/extract/mod.rs +++ b/socketioxide/src/extract/mod.rs @@ -56,9 +56,9 @@ //! } //! impl std::error::Error for UserIdNotFound {} //! -//! impl FromConnectParts for UserId { +//! impl FromConnectParts for UserId { //! type Error = Infallible; -//! fn from_connect_parts(s: &Arc>, _: &Option) -> Result { +//! fn from_connect_parts(s: &Arc, _: &Option) -> Result { //! // In a real app it would be better to parse the query params with a crate like `url` //! let uri = &s.req_parts().uri; //! let uid = uri @@ -72,11 +72,11 @@ //! //! // Here, if the user id is not found, the handler won't be called //! // and a tracing `error` log will be emitted (if the `tracing` feature is enabled) -//! impl FromMessageParts for UserId { +//! impl FromMessageParts for UserId { //! type Error = UserIdNotFound; //! //! fn from_message_parts( -//! s: &Arc>, +//! s: &Arc, //! _: &mut serde_json::Value, //! _: &mut Vec, //! _: &Option, diff --git a/socketioxide/src/extract/socket.rs b/socketioxide/src/extract/socket.rs index 914cd361..be4c99f6 100644 --- a/socketioxide/src/extract/socket.rs +++ b/socketioxide/src/extract/socket.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::{ - adapter::{Adapter, LocalAdapter}, errors::{DisconnectError, SendError}, packet::Packet, socket::{DisconnectReason, Socket}, @@ -14,18 +13,18 @@ use serde::Serialize; /// An Extractor that returns a reference to a [`Socket`]. #[derive(Debug)] -pub struct SocketRef(Arc>); +pub struct SocketRef(Arc); -impl FromConnectParts for SocketRef { +impl FromConnectParts for SocketRef { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(SocketRef(s.clone())) } } -impl FromMessageParts for SocketRef { +impl FromMessageParts for SocketRef { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -33,40 +32,40 @@ impl FromMessageParts for SocketRef { Ok(SocketRef(s.clone())) } } -impl FromDisconnectParts for SocketRef { +impl FromDisconnectParts for SocketRef { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(SocketRef(s.clone())) } } -impl std::ops::Deref for SocketRef { - type Target = Socket; +impl std::ops::Deref for SocketRef { + type Target = Socket; #[inline(always)] fn deref(&self) -> &Self::Target { &self.0 } } -impl PartialEq for SocketRef { +impl PartialEq for SocketRef { #[inline(always)] fn eq(&self, other: &Self) -> bool { self.0.id == other.0.id } } -impl From>> for SocketRef { +impl From> for SocketRef { #[inline(always)] - fn from(socket: Arc>) -> Self { + fn from(socket: Arc) -> Self { Self(socket) } } -impl Clone for SocketRef { +impl Clone for SocketRef { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl SocketRef { +impl SocketRef { /// Disconnect the socket from the current namespace, /// /// It will also call the disconnect handler if it is set. @@ -79,15 +78,15 @@ impl SocketRef { /// An Extractor to send an ack response corresponding to the current event. /// If the client sent a normal message without expecting an ack, the ack callback will do nothing. #[derive(Debug)] -pub struct AckSender { +pub struct AckSender { binary: Vec, - socket: Arc>, + socket: Arc, ack_id: Option, } -impl FromMessageParts for AckSender { +impl FromMessageParts for AckSender { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, ack_id: &Option, @@ -95,8 +94,8 @@ impl FromMessageParts for AckSender { Ok(Self::new(s.clone(), *ack_id)) } } -impl AckSender { - pub(crate) fn new(socket: Arc>, ack_id: Option) -> Self { +impl AckSender { + pub(crate) fn new(socket: Arc, ack_id: Option) -> Self { Self { binary: vec![], socket, @@ -137,16 +136,16 @@ impl AckSender { } } -impl FromConnectParts for crate::ProtocolVersion { +impl FromConnectParts for crate::ProtocolVersion { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(s.protocol()) } } -impl FromMessageParts for crate::ProtocolVersion { +impl FromMessageParts for crate::ProtocolVersion { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -154,23 +153,23 @@ impl FromMessageParts for crate::ProtocolVersion { Ok(s.protocol()) } } -impl FromDisconnectParts for crate::ProtocolVersion { +impl FromDisconnectParts for crate::ProtocolVersion { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(s.protocol()) } } -impl FromConnectParts for crate::TransportType { +impl FromConnectParts for crate::TransportType { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(s.transport_type()) } } -impl FromMessageParts for crate::TransportType { +impl FromMessageParts for crate::TransportType { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -178,35 +177,35 @@ impl FromMessageParts for crate::TransportType { Ok(s.transport_type()) } } -impl FromDisconnectParts for crate::TransportType { +impl FromDisconnectParts for crate::TransportType { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(s.transport_type()) } } -impl FromDisconnectParts for DisconnectReason { +impl FromDisconnectParts for DisconnectReason { type Error = Infallible; fn from_disconnect_parts( - _: &Arc>, + _: &Arc, reason: DisconnectReason, ) -> Result { Ok(reason) } } -impl FromConnectParts for SocketIo { +impl FromConnectParts for SocketIo { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(s.get_io().clone()) } } -impl FromMessageParts for SocketIo { +impl FromMessageParts for SocketIo { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -214,10 +213,10 @@ impl FromMessageParts for SocketIo { Ok(s.get_io().clone()) } } -impl FromDisconnectParts for SocketIo { +impl FromDisconnectParts for SocketIo { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(s.get_io().clone()) } } diff --git a/socketioxide/src/extract/state.rs b/socketioxide/src/extract/state.rs index 50886aec..db354036 100644 --- a/socketioxide/src/extract/state.rs +++ b/socketioxide/src/extract/state.rs @@ -2,7 +2,6 @@ use bytes::Bytes; use std::sync::Arc; -use crate::adapter::Adapter; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::socket::{DisconnectReason, Socket}; @@ -57,22 +56,19 @@ impl std::fmt::Debug for StateNotFound { } impl std::error::Error for StateNotFound {} -impl FromConnectParts for State { +impl FromConnectParts for State { type Error = StateNotFound; - fn from_connect_parts( - s: &Arc>, - _: &Option, - ) -> Result> { + fn from_connect_parts(s: &Arc, _: &Option) -> Result> { s.get_io() .get_state::() .map(State) .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl FromDisconnectParts for State { +impl FromDisconnectParts for State { type Error = StateNotFound; fn from_disconnect_parts( - s: &Arc>, + s: &Arc, _: DisconnectReason, ) -> Result> { s.get_io() @@ -81,10 +77,10 @@ impl FromDisconnectParts for St .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl FromMessageParts for State { +impl FromMessageParts for State { type Error = StateNotFound; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, diff --git a/socketioxide/src/handler/connect.rs b/socketioxide/src/handler/connect.rs index 634ac239..6eb077dd 100644 --- a/socketioxide/src/handler/connect.rs +++ b/socketioxide/src/handler/connect.rs @@ -115,26 +115,26 @@ use std::pin::Pin; use std::sync::Arc; -use crate::{adapter::Adapter, socket::Socket}; +use crate::socket::Socket; use futures_core::Future; use super::MakeErasedHandler; /// A Type Erased [`ConnectHandler`] so it can be stored in a HashMap -pub(crate) type BoxedConnectHandler = Box>; +pub(crate) type BoxedConnectHandler = Box; type MiddlewareRes = Result<(), Box>; type MiddlewareResFut<'a> = Pin + Send + 'a>>; -pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, auth: Option); +pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { + fn call(&self, s: Arc, auth: Option); fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> MiddlewareResFut<'a>; - fn boxed_clone(&self) -> BoxedConnectHandler; + fn boxed_clone(&self) -> BoxedConnectHandler; } /// A trait used to extract the arguments from the connect event. @@ -143,13 +143,13 @@ pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { /// /// * See the [`connect`](super::connect) module doc for more details on connect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait FromConnectParts: Sized { +pub trait FromConnectParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + Send + 'static; /// Extract the arguments from the connect event. /// If it fails, the handler is not called - fn from_connect_parts(s: &Arc>, auth: &Option) -> Result; + fn from_connect_parts(s: &Arc, auth: &Option) -> Result; } /// Define a middleware for the connect event. @@ -158,16 +158,16 @@ pub trait FromConnectParts: Sized { /// /// * See the [`connect`](super::connect) module doc for more details on connect middlewares. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait ConnectMiddleware: Sized + Clone + Send + Sync + 'static { +pub trait ConnectMiddleware: Sized + Clone + Send + Sync + 'static { /// Call the middleware with the given arguments. fn call<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> impl Future + Send; #[doc(hidden)] - fn phantom(&self) -> std::marker::PhantomData<(A, T)> { + fn phantom(&self) -> std::marker::PhantomData { std::marker::PhantomData } } @@ -177,14 +177,14 @@ pub trait ConnectMiddleware: Sized + Clone + Send + Sync + 'stati /// /// * See the [`connect`](super::connect) module doc for more details on connect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait ConnectHandler: Sized + Clone + Send + Sync + 'static { +pub trait ConnectHandler: Sized + Clone + Send + Sync + 'static { /// Call the handler with the given arguments. - fn call(&self, s: Arc>, auth: Option); + fn call(&self, s: Arc, auth: Option); /// Call the middleware with the given arguments. fn call_middleware<'a>( &'a self, - _: Arc>, + _: Arc, _: &'a Option, ) -> MiddlewareResFut<'a> { Box::pin(async move { Ok(()) }) @@ -234,9 +234,9 @@ pub trait ConnectHandler: Sized + Clone + Send + Sync + 'static { /// let (_, io) = SocketIo::new_layer(); /// io.ns("/", handler.with(middleware).with(other_middleware)); /// ``` - fn with(self, middleware: M) -> impl ConnectHandler + fn with(self, middleware: M) -> impl ConnectHandler where - M: ConnectMiddleware + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { @@ -252,10 +252,10 @@ pub trait ConnectHandler: Sized + Clone + Send + Sync + 'static { std::marker::PhantomData } } -struct LayeredConnectHandler { +struct LayeredConnectHandler { handler: H, middleware: M, - phantom: std::marker::PhantomData<(A, T, T1)>, + phantom: std::marker::PhantomData<(T, T1)>, } struct ConnectMiddlewareLayer { middleware: M, @@ -263,61 +263,60 @@ struct ConnectMiddlewareLayer { phantom: std::marker::PhantomData<(T, T1)>, } -impl MakeErasedHandler +impl MakeErasedHandler where - H: ConnectHandler + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - pub fn new_ns_boxed(inner: H) -> Box> { + pub fn new_ns_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedConnectHandler for MakeErasedHandler +impl ErasedConnectHandler for MakeErasedHandler where - H: ConnectHandler + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc, auth: Option) { self.handler.call(s, auth); } fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> MiddlewareResFut<'a> { self.handler.call_middleware(s, auth) } - fn boxed_clone(&self) -> BoxedConnectHandler { + fn boxed_clone(&self) -> BoxedConnectHandler { Box::new(self.clone()) } } -impl ConnectHandler for LayeredConnectHandler +impl ConnectHandler for LayeredConnectHandler where - A: Adapter, - H: ConnectHandler + Send + Sync + 'static, - M: ConnectMiddleware + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc, auth: Option) { self.handler.call(s, auth); } fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> MiddlewareResFut<'a> { Box::pin(async move { self.middleware.call(s, auth).await }) } - fn with(self, next: M2) -> impl ConnectHandler + fn with(self, next: M2) -> impl ConnectHandler where - M2: ConnectMiddleware + Send + Sync + 'static, + M2: ConnectMiddleware + Send + Sync + 'static, T2: Send + Sync + 'static, { LayeredConnectHandler { @@ -331,19 +330,18 @@ where } } } -impl ConnectMiddleware for LayeredConnectHandler +impl ConnectMiddleware for LayeredConnectHandler where - A: Adapter, - H: ConnectHandler + Send + Sync + 'static, - N: ConnectMiddleware + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, + N: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>(&'a self, s: Arc, auth: &'a Option) -> MiddlewareRes { self.middleware.call(s, auth).await } } -impl Clone for LayeredConnectHandler +impl Clone for LayeredConnectHandler where H: Clone, N: Clone, @@ -370,15 +368,14 @@ where } } -impl ConnectMiddleware for ConnectMiddlewareLayer +impl ConnectMiddleware for ConnectMiddlewareLayer where - A: Adapter, - M: ConnectMiddleware + Send + Sync + 'static, - N: ConnectMiddleware + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, + N: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>(&'a self, s: Arc, auth: &'a Option) -> MiddlewareRes { self.middleware.call(s.clone(), auth).await?; self.next.call(s, auth).await } @@ -396,14 +393,17 @@ macro_rules! impl_handler_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectHandler for F + impl ConnectHandler<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call( + &self, + s: Arc, + auth: Option, + ) { $( let $ty = match $ty::from_connect_parts(&s, &auth) { Ok(v) => v, @@ -428,13 +428,16 @@ macro_rules! impl_handler { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectHandler for F + impl ConnectHandler<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call( + &self, + s: Arc, + auth: Option, + ) { $( let $ty = match $ty::from_connect_parts(&s, &auth) { Ok(v) => v, @@ -457,17 +460,16 @@ macro_rules! impl_middleware_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectMiddleware for F + impl ConnectMiddleware<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future> + Send + 'static, - A: Adapter, E: std::fmt::Display + Send + 'static, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { async fn call<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> MiddlewareRes { $( @@ -499,16 +501,15 @@ macro_rules! impl_middleware { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectMiddleware for F + impl ConnectMiddleware<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Result<(), E> + Send + Sync + Clone + 'static, - A: Adapter, E: std::fmt::Display + Send + 'static, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { async fn call<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> MiddlewareRes { $( diff --git a/socketioxide/src/handler/disconnect.rs b/socketioxide/src/handler/disconnect.rs index b63adb15..86a1df94 100644 --- a/socketioxide/src/handler/disconnect.rs +++ b/socketioxide/src/handler/disconnect.rs @@ -58,36 +58,33 @@ use std::sync::Arc; use futures_core::Future; -use crate::{ - adapter::Adapter, - socket::{DisconnectReason, Socket}, -}; +use crate::socket::{DisconnectReason, Socket}; use super::MakeErasedHandler; /// A Type Erased [`DisconnectHandler`] so it can be stored in a HashMap -pub(crate) type BoxedDisconnectHandler = Box>; -pub(crate) trait ErasedDisconnectHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, reason: DisconnectReason); +pub(crate) type BoxedDisconnectHandler = Box; +pub(crate) trait ErasedDisconnectHandler: Send + Sync + 'static { + fn call(&self, s: Arc, reason: DisconnectReason); } -impl MakeErasedHandler +impl MakeErasedHandler where T: Send + Sync + 'static, - H: DisconnectHandler + Send + Sync + 'static, + H: DisconnectHandler + Send + Sync + 'static, { - pub fn new_disconnect_boxed(inner: H) -> Box> { + pub fn new_disconnect_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedDisconnectHandler for MakeErasedHandler +impl ErasedDisconnectHandler for MakeErasedHandler where - H: DisconnectHandler + Send + Sync + 'static, + H: DisconnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { #[inline(always)] - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { self.handler.call(s, reason); } } @@ -98,14 +95,14 @@ where /// /// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait FromDisconnectParts: Sized { +pub trait FromDisconnectParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the disconnect event. /// If it fails, the handler is not called fn from_disconnect_parts( - s: &Arc>, + s: &Arc, reason: DisconnectReason, ) -> Result; } @@ -115,9 +112,9 @@ pub trait FromDisconnectParts: Sized { /// /// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait DisconnectHandler: Send + Sync + 'static { +pub trait DisconnectHandler: Send + Sync + 'static { /// Call the handler with the given arguments. - fn call(&self, s: Arc>, reason: DisconnectReason); + fn call(&self, s: Arc, reason: DisconnectReason); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -137,14 +134,13 @@ macro_rules! impl_handler_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl DisconnectHandler for F + impl DisconnectHandler<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromDisconnectParts + Send, )* + $( $ty: FromDisconnectParts + Send, )* { - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { $( let $ty = match $ty::from_disconnect_parts(&s, reason) { Ok(v) => v, @@ -169,13 +165,12 @@ macro_rules! impl_handler { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl DisconnectHandler for F + impl DisconnectHandler<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromDisconnectParts + Send, )* + $( $ty: FromDisconnectParts + Send, )* { - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { $( let $ty = match $ty::from_disconnect_parts(&s, reason) { Ok(v) => v, diff --git a/socketioxide/src/handler/message.rs b/socketioxide/src/handler/message.rs index 5b4ab255..84bf3f80 100644 --- a/socketioxide/src/handler/message.rs +++ b/socketioxide/src/handler/message.rs @@ -77,16 +77,15 @@ use bytes::Bytes; use futures_core::Future; use serde_json::Value; -use crate::adapter::Adapter; use crate::socket::Socket; use super::MakeErasedHandler; /// A Type Erased [`MessageHandler`] so it can be stored in a HashMap -pub(crate) type BoxedMessageHandler = Box>; +pub(crate) type BoxedMessageHandler = Box; -pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option); +pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option); } /// Define a handler for the connect event. @@ -100,9 +99,9 @@ pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait MessageHandler: Send + Sync + 'static { +pub trait MessageHandler: Send + Sync + 'static { /// Call the handler with the given arguments - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option); + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -110,25 +109,23 @@ pub trait MessageHandler: Send + Sync + 'static { } } -impl MakeErasedHandler +impl MakeErasedHandler where T: Send + Sync + 'static, - H: MessageHandler, - A: Adapter, + H: MessageHandler, { - pub fn new_message_boxed(inner: H) -> Box> { + pub fn new_message_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedMessageHandler for MakeErasedHandler +impl ErasedMessageHandler for MakeErasedHandler where T: Send + Sync + 'static, - H: MessageHandler, - A: Adapter, + H: MessageHandler, { #[inline(always)] - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option) { + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option) { self.handler.call(s, v, p, ack_id); } } @@ -157,14 +154,14 @@ mod private { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait FromMessageParts: Sized { +pub trait FromMessageParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the message event. /// If it fails, the handler is not called. fn from_message_parts( - s: &Arc>, + s: &Arc, v: &mut Value, p: &mut Vec, ack_id: &Option, @@ -182,14 +179,14 @@ pub trait FromMessageParts: Sized { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait FromMessage: Sized { +pub trait FromMessage: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the message event. /// If it fails, the handler is not called fn from_message( - s: Arc>, + s: Arc, v: Value, p: Vec, ack_id: Option, @@ -197,14 +194,13 @@ pub trait FromMessage: Sized { } /// All the types that implement [`FromMessageParts`] also implement [`FromMessage`] -impl FromMessage for T +impl FromMessage for T where - T: FromMessageParts, - A: Adapter, + T: FromMessageParts, { type Error = T::Error; fn from_message( - s: Arc>, + s: Arc, mut v: Value, mut p: Vec, ack_id: Option, @@ -214,25 +210,23 @@ where } /// Empty Async handler -impl MessageHandler for F +impl MessageHandler<(private::Async,)> for F where F: FnOnce() -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec, _: Option) { + fn call(&self, _: Arc, _: Value, _: Vec, _: Option) { let fut = (self.clone())(); tokio::spawn(fut); } } /// Empty Sync handler -impl MessageHandler for F +impl MessageHandler<(private::Sync,)> for F where F: FnOnce() + Send + Sync + Clone + 'static, - A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec, _: Option) { + fn call(&self, _: Arc, _: Value, _: Vec, _: Option) { (self.clone())(); } } @@ -242,15 +236,14 @@ macro_rules! impl_async_handler { [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused)] - impl MessageHandler for F + impl MessageHandler<(private::Async, M, $($ty,)* $last,)> for F where F: FnOnce($($ty,)* $last,) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromMessageParts + Send, )* - $last: FromMessage + Send, + $( $ty: FromMessageParts + Send, )* + $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec, ack_id: Option) { + fn call(&self, s: Arc, mut v: Value, mut p: Vec, ack_id: Option) { $( let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { Ok(v) => v, @@ -281,14 +274,13 @@ macro_rules! impl_handler { [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused)] - impl MessageHandler for F + impl MessageHandler<(private::Sync, M, $($ty,)* $last,)> for F where F: FnOnce($($ty,)* $last,) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromMessageParts + Send, )* - $last: FromMessage + Send, + $( $ty: FromMessageParts + Send, )* + $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec, ack_id: Option) { + fn call(&self, s: Arc, mut v: Value, mut p: Vec, ack_id: Option) { $( let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { Ok(v) => v, diff --git a/socketioxide/src/handler/mod.rs b/socketioxide/src/handler/mod.rs index 9850fcc5..693cf992 100644 --- a/socketioxide/src/handler/mod.rs +++ b/socketioxide/src/handler/mod.rs @@ -12,25 +12,22 @@ pub use disconnect::{DisconnectHandler, FromDisconnectParts}; pub(crate) use message::BoxedMessageHandler; pub use message::{FromMessage, FromMessageParts, MessageHandler}; /// A struct used to erase the type of a [`ConnectHandler`] or [`MessageHandler`] so it can be stored in a map -pub(crate) struct MakeErasedHandler { +pub(crate) struct MakeErasedHandler { handler: H, - adapter: std::marker::PhantomData, type_: std::marker::PhantomData, } -impl MakeErasedHandler { +impl MakeErasedHandler { pub fn new(handler: H) -> Self { Self { handler, - adapter: std::marker::PhantomData, type_: std::marker::PhantomData, } } } -impl Clone for MakeErasedHandler { +impl Clone for MakeErasedHandler { fn clone(&self) -> Self { Self { handler: self.handler.clone(), - adapter: std::marker::PhantomData, type_: std::marker::PhantomData, } } diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index 115459e7..df2fac54 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -17,11 +17,11 @@ use crate::{ layer::SocketIoLayer, operators::{BroadcastOperators, RoomParam}, service::SocketIoService, - BroadcastError, DisconnectError, + AdapterError, BroadcastError, DisconnectError, }; /// Configuration for Socket.IO & Engine.IO -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SocketIoConfig { /// The inner Engine.IO config pub engine_config: EngineIoConfig, @@ -35,6 +35,9 @@ pub struct SocketIoConfig { /// /// Defaults to 45 seconds. pub connect_timeout: Duration, + + /// The adapter to use for this server. + pub adapter: Box, } impl Default for SocketIoConfig { @@ -46,6 +49,17 @@ impl Default for SocketIoConfig { }, ack_timeout: Duration::from_secs(5), connect_timeout: Duration::from_secs(45), + adapter: Box::new(LocalAdapter::new()), + } + } +} +impl Clone for SocketIoConfig { + fn clone(&self) -> Self { + Self { + engine_config: self.engine_config.clone(), + ack_timeout: self.ack_timeout, + connect_timeout: self.connect_timeout, + adapter: self.adapter.boxed_clone(), } } } @@ -53,21 +67,19 @@ impl Default for SocketIoConfig { /// A builder to create a [`SocketIo`] instance. /// It contains everything to configure the socket.io server with a [`SocketIoConfig`]. /// It can be used to build either a Tower [`Layer`](tower::layer::Layer) or a [`Service`](tower::Service). -pub struct SocketIoBuilder { +pub struct SocketIoBuilder { config: SocketIoConfig, engine_config_builder: EngineIoConfigBuilder, - adapter: std::marker::PhantomData, #[cfg(feature = "state")] state: state::TypeMap![Send + Sync], } -impl SocketIoBuilder { +impl SocketIoBuilder { /// Creates a new [`SocketIoBuilder`] with default config pub fn new() -> Self { Self { config: SocketIoConfig::default(), engine_config_builder: EngineIoConfigBuilder::new().req_path("/socket.io".to_string()), - adapter: std::marker::PhantomData, #[cfg(feature = "state")] state: std::default::Default::default(), } @@ -158,14 +170,9 @@ impl SocketIoBuilder { } /// Sets a custom [`Adapter`] for this [`SocketIoBuilder`] - pub fn with_adapter(self) -> SocketIoBuilder { - SocketIoBuilder { - config: self.config, - engine_config_builder: self.engine_config_builder, - adapter: std::marker::PhantomData, - #[cfg(feature = "state")] - state: self.state, - } + pub fn with_adapter(mut self, adapter: impl Adapter) -> SocketIoBuilder { + self.config.adapter = Box::new(adapter); + self } /// Add a custom global state for the [`SocketIo`] instance. @@ -183,7 +190,7 @@ impl SocketIoBuilder { /// Builds a [`SocketIoLayer`] and a [`SocketIo`] instance /// /// The layer can be used as a tower layer - pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { + pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); let (layer, client) = SocketIoLayer::from_config( @@ -235,9 +242,9 @@ impl Default for SocketIoBuilder { /// The [`SocketIo`] instance can be cheaply cloned and moved around everywhere in your program. /// It can be used as the main handle to access the whole socket.io context. #[derive(Debug)] -pub struct SocketIo(Arc>); +pub struct SocketIo(Arc); -impl SocketIo { +impl SocketIo { /// Creates a new [`SocketIoBuilder`] with a default config #[inline(always)] pub fn builder() -> SocketIoBuilder { @@ -267,7 +274,7 @@ impl SocketIo { } } -impl SocketIo { +impl SocketIo { /// Returns a reference to the [`SocketIoConfig`] used by this [`SocketIo`] instance #[inline] pub fn config(&self) -> &SocketIoConfig { @@ -358,7 +365,7 @@ impl SocketIo { #[inline] pub fn ns(&self, path: impl Into>, callback: C) where - C: ConnectHandler, + C: ConnectHandler, T: Send + Sync + 'static, { self.0.add_ns(path.into(), callback); @@ -403,7 +410,7 @@ impl SocketIo { callback: C, ) -> Result<(), crate::NsInsertError> where - C: ConnectHandler, + C: ConnectHandler, T: Send + Sync + 'static, { self.0.add_dyn_ns(path.into(), callback) @@ -451,7 +458,7 @@ impl SocketIo { /// } /// ``` #[inline] - pub fn of<'a>(&self, path: impl Into<&'a str>) -> Option> { + pub fn of<'a>(&self, path: impl Into<&'a str>) -> Option { self.get_op(path.into()) } @@ -477,7 +484,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().to(rooms) } @@ -505,7 +512,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().within(rooms) } @@ -538,7 +545,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().except(rooms) } @@ -565,7 +572,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn local(&self) -> BroadcastOperators { + pub fn local(&self) -> BroadcastOperators { self.get_default_op().local() } @@ -608,7 +615,7 @@ impl SocketIo { /// } /// }); #[inline] - pub fn timeout(&self, timeout: Duration) -> BroadcastOperators { + pub fn timeout(&self, timeout: Duration) -> BroadcastOperators { self.get_default_op().timeout(timeout) } @@ -637,7 +644,7 @@ impl SocketIo { /// .bin(vec![Bytes::from_static(&[1, 2, 3, 4])]) /// .emit("test", ()); #[inline] - pub fn bin(&self, binary: impl IntoIterator>) -> BroadcastOperators { + pub fn bin(&self, binary: impl IntoIterator>) -> BroadcastOperators { self.get_default_op().bin(binary) } @@ -764,7 +771,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn sockets(&self) -> Result>, A::Error> { + pub fn sockets(&self) -> Result, AdapterError> { self.get_default_op().sockets() } @@ -808,7 +815,7 @@ impl SocketIo { /// // Later in your code you can for example add all sockets on the root namespace to the room1 and room3 /// io.join(["room1", "room3"]).unwrap(); #[inline] - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.get_default_op().join(rooms) } @@ -828,7 +835,7 @@ impl SocketIo { /// let rooms = io.rooms().unwrap(); /// println!("All rooms on / namespace: {:?}", rooms); /// }); - pub fn rooms(&self) -> Result, A::Error> { + pub fn rooms(&self) -> Result, AdapterError> { self.get_default_op().rooms() } @@ -850,13 +857,13 @@ impl SocketIo { /// // Later in your code you can for example remove all sockets on the root namespace from the room1 and room3 /// io.leave(["room1", "room3"]).unwrap(); #[inline] - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.get_default_op().leave(rooms) } /// Gets a [`SocketRef`] by the specified [`Sid`]. #[inline] - pub fn get_socket(&self, sid: Sid) -> Option> { + pub fn get_socket(&self, sid: Sid) -> Option { self.get_default_op().get_socket(sid) } @@ -867,7 +874,7 @@ impl SocketIo { /// Returns a new operator on the given namespace #[inline(always)] - fn get_op(&self, path: &str) -> Option> { + fn get_op(&self, path: &str) -> Option { self.0 .get_ns(path) .map(|ns| BroadcastOperators::new(ns).broadcast()) @@ -879,24 +886,24 @@ impl SocketIo { /// /// If the **default namespace "/" is not found** this fn will panic! #[inline(always)] - fn get_default_op(&self) -> BroadcastOperators { + fn get_default_op(&self) -> BroadcastOperators { self.get_op("/").expect("default namespace not found") } } -impl Clone for SocketIo { +impl Clone for SocketIo { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl From>> for SocketIo { - fn from(client: Arc>) -> Self { +impl From> for SocketIo { + fn from(client: Arc) -> Self { SocketIo(client) } } #[cfg(any(test, socketioxide_test))] -impl SocketIo { +impl SocketIo { /// Create a dummy socket for testing purpose with a /// receiver to get the packets sent to the client pub async fn new_dummy_sock( diff --git a/socketioxide/src/layer.rs b/socketioxide/src/layer.rs index ac7e1371..405f515d 100644 --- a/socketioxide/src/layer.rs +++ b/socketioxide/src/layer.rs @@ -20,19 +20,14 @@ use std::sync::Arc; use tower::Layer; -use crate::{ - adapter::{Adapter, LocalAdapter}, - client::Client, - service::SocketIoService, - SocketIoConfig, -}; +use crate::{client::Client, service::SocketIoService, SocketIoConfig}; /// A [`Layer`] for [`SocketIoService`], acting as a middleware. -pub struct SocketIoLayer { - client: Arc>, +pub struct SocketIoLayer { + client: Arc, } -impl Clone for SocketIoLayer { +impl Clone for SocketIoLayer { fn clone(&self) -> Self { Self { client: self.client.clone(), @@ -40,11 +35,11 @@ impl Clone for SocketIoLayer { } } -impl SocketIoLayer { +impl SocketIoLayer { pub(crate) fn from_config( config: SocketIoConfig, #[cfg(feature = "state")] state: state::TypeMap![Send + Sync], - ) -> (Self, Arc>) { + ) -> (Self, Arc) { let client = Arc::new(Client::new( config, #[cfg(feature = "state")] @@ -57,8 +52,8 @@ impl SocketIoLayer { } } -impl Layer for SocketIoLayer { - type Service = SocketIoService; +impl Layer for SocketIoLayer { + type Service = SocketIoService; fn layer(&self, inner: S) -> Self::Service { SocketIoService::with_client(inner, self.client.clone()) diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index dde8d81c..fb228388 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -5,59 +5,65 @@ use std::{ use crate::{ adapter::Adapter, - errors::{ConnectFail, Error}, + client::SocketData, + errors::{AdapterError, ConnectFail, Error}, handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler}, packet::{Packet, PacketData}, socket::{DisconnectReason, Socket}, }; -use crate::{client::SocketData, errors::AdapterError}; use engineioxide::{sid::Sid, Str}; /// A [`Namespace`] constructor used for dynamic namespaces /// A namespace constructor only hold a common handler that will be cloned /// to the instantiated namespaces. -pub struct NamespaceCtr { - handler: BoxedConnectHandler, +pub struct NamespaceCtr { + handler: BoxedConnectHandler, } -pub struct Namespace { +pub struct Namespace { pub path: Str, - pub(crate) adapter: A, - handler: BoxedConnectHandler, - sockets: RwLock>>>, + pub(crate) adapter: Box, + handler: BoxedConnectHandler, + sockets: RwLock>>, } /// ===== impl NamespaceCtr ===== -impl NamespaceCtr { +impl NamespaceCtr { pub fn new(handler: C) -> Self where - C: ConnectHandler + Send + Sync + 'static, + C: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { Self { handler: MakeErasedHandler::new_ns_boxed(handler), } } - pub fn get_new_ns(&self, path: Str) -> Arc> { - Arc::new_cyclic(|ns| Namespace { - path, - handler: self.handler.boxed_clone(), - sockets: HashMap::new().into(), - adapter: A::new(ns.clone()), + pub fn get_new_ns(&self, path: Str, mut adapter: Box) -> Arc { + Arc::new_cyclic(|ns| { + adapter.init(ns.clone()).ok(); + Namespace { + path, + handler: self.handler.boxed_clone(), + sockets: HashMap::new().into(), + adapter, + } }) } } -impl Namespace { - pub fn new(path: Str, handler: C) -> Arc +impl Namespace { + pub fn new(path: Str, handler: C, mut adapter: Box) -> Arc where - C: ConnectHandler + Send + Sync + 'static, + C: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - Arc::new_cyclic(|ns| Self { - path, - handler: MakeErasedHandler::new_ns_boxed(handler), - sockets: HashMap::new().into(), - adapter: A::new(ns.clone()), + Arc::new_cyclic(move |ns: &std::sync::Weak<_>| { + adapter.init(ns.clone()).ok(); + Self { + path, + handler: MakeErasedHandler::new_ns_boxed(handler), + sockets: HashMap::new().into(), + adapter, + } }) } @@ -70,10 +76,10 @@ impl Namespace { pub(crate) async fn connect( self: Arc, sid: Sid, - esocket: Arc>>, + esocket: Arc>, auth: Option, ) -> Result<(), ConnectFail> { - let socket: Arc> = Socket::new(sid, self.clone(), esocket.clone()).into(); + let socket: Arc = Socket::new(sid, self.clone(), esocket.clone()).into(); if let Err(e) = self.handler.call_middleware(socket.clone(), &auth).await { #[cfg(feature = "tracing")] @@ -130,7 +136,7 @@ impl Namespace { } } - pub fn get_socket(&self, sid: Sid) -> Result>, Error> { + pub fn get_socket(&self, sid: Sid) -> Result, Error> { self.sockets .read() .unwrap() @@ -139,7 +145,7 @@ impl Namespace { .ok_or(Error::SocketGone(sid)) } - pub fn get_sockets(&self) -> Vec>> { + pub fn get_sockets(&self) -> Vec> { self.sockets.read().unwrap().values().cloned().collect() } @@ -183,9 +189,10 @@ impl Namespace { } #[cfg(any(test, socketioxide_test))] -impl Namespace { +impl Namespace { pub fn new_dummy(sockets: [Sid; S]) -> Arc { - let ns = Namespace::new("/".into(), || {}); + use crate::adapter::LocalAdapter; + let ns = Namespace::new("/".into(), || {}, Box::new(LocalAdapter::new())); for sid in sockets { ns.sockets .write() @@ -200,7 +207,7 @@ impl Namespace { } } -impl std::fmt::Debug for Namespace { +impl std::fmt::Debug for Namespace { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Namespace") .field("path", &self.path) @@ -211,7 +218,7 @@ impl std::fmt::Debug for Namespace { } #[cfg(feature = "tracing")] -impl Drop for Namespace { +impl Drop for Namespace { fn drop(&mut self) { #[cfg(feature = "tracing")] tracing::debug!("dropping namespace {}", self.path); diff --git a/socketioxide/src/operators.rs b/socketioxide/src/operators.rs index 22a4aeb2..32f6acb8 100644 --- a/socketioxide/src/operators.rs +++ b/socketioxide/src/operators.rs @@ -13,13 +13,11 @@ use bytes::Bytes; use engineioxide::sid::Sid; use crate::ack::{AckInnerStream, AckStream}; -use crate::adapter::LocalAdapter; -use crate::errors::{BroadcastError, DisconnectError}; +use crate::errors::{AdapterError, BroadcastError, DisconnectError, SendError}; use crate::extract::SocketRef; use crate::socket::Socket; -use crate::SendError; use crate::{ - adapter::{Adapter, BroadcastFlags, BroadcastOptions, Room}, + adapter::{BroadcastFlags, BroadcastOptions, Room}, ns::Namespace, packet::Packet, }; @@ -103,21 +101,21 @@ impl RoomParam for Sid { } /// Chainable operators to configure the message to be sent. -pub struct ConfOperators<'a, A: Adapter = LocalAdapter> { +pub struct ConfOperators<'a> { binary: Vec, timeout: Option, - socket: &'a Socket, + socket: &'a Socket, } /// Chainable operators to select sockets to send a message to and to configure the message to be sent. -pub struct BroadcastOperators { +pub struct BroadcastOperators { binary: Vec, timeout: Option, - ns: Arc>, + ns: Arc, opts: BroadcastOptions, } -impl From> for BroadcastOperators { - fn from(conf: ConfOperators<'_, A>) -> Self { +impl From> for BroadcastOperators { + fn from(conf: ConfOperators<'_>) -> Self { let opts = BroadcastOptions { sid: Some(conf.socket.id), ..Default::default() @@ -132,8 +130,8 @@ impl From> for BroadcastOperators { } // ==== impl ConfOperators operations ==== -impl<'a, A: Adapter> ConfOperators<'a, A> { - pub(crate) fn new(sender: &'a Socket) -> Self { +impl<'a> ConfOperators<'a> { + pub(crate) fn new(sender: &'a Socket) -> Self { Self { binary: vec![], timeout: None, @@ -161,7 +159,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// .emit("test", data); /// }); /// }); - pub fn to(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).to(rooms) } @@ -185,7 +183,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// .emit("test", data); /// }); /// }); - pub fn within(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).within(rooms) } @@ -208,7 +206,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.broadcast().except("room1").emit("test", data); /// }); /// }); - pub fn except(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).except(rooms) } @@ -225,7 +223,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.local().emit("test", data); /// }); /// }); - pub fn local(self) -> BroadcastOperators { + pub fn local(self) -> BroadcastOperators { BroadcastOperators::from(self).local() } @@ -241,7 +239,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.broadcast().emit("test", data); /// }); /// }); - pub fn broadcast(self) -> BroadcastOperators { + pub fn broadcast(self) -> BroadcastOperators { BroadcastOperators::from(self).broadcast() } @@ -304,7 +302,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { } // ==== impl ConfOperators consume fns ==== -impl ConfOperators<'_, A> { +impl ConfOperators<'_> { /// Emits a message to the client and apply the previous operators on the message. /// /// If you provide array-like data (tuple, vec, arrays), it will be considered as multiple arguments. @@ -454,7 +452,7 @@ impl ConfOperators<'_, A> { /// socket.within("room1").within("room3").join(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.socket.join(rooms) } @@ -470,12 +468,12 @@ impl ConfOperators<'_, A> { /// socket.within("room1").within("room3").leave(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.socket.leave(rooms) } /// Gets all room names for a given namespace - pub fn rooms(self) -> Result, A::Error> { + pub fn rooms(self) -> Result, AdapterError> { self.socket.rooms() } @@ -497,8 +495,8 @@ impl ConfOperators<'_, A> { } } -impl BroadcastOperators { - pub(crate) fn new(ns: Arc>) -> Self { +impl BroadcastOperators { + pub(crate) fn new(ns: Arc) -> Self { Self { binary: vec![], timeout: None, @@ -506,7 +504,7 @@ impl BroadcastOperators { opts: BroadcastOptions::default(), } } - pub(crate) fn from_sock(ns: Arc>, sid: Sid) -> Self { + pub(crate) fn from_sock(ns: Arc, sid: Sid) -> Self { Self { binary: vec![], timeout: None, @@ -686,7 +684,7 @@ impl BroadcastOperators { } // ==== impl BroadcastOperators consume fns ==== -impl BroadcastOperators { +impl BroadcastOperators { /// Emits a message to all sockets selected with the previous operators. /// /// If you provide array-like data (tuple, vec, arrays), it will be considered as multiple arguments. @@ -828,7 +826,7 @@ impl BroadcastOperators { /// } /// }); /// }); - pub fn sockets(self) -> Result>, A::Error> { + pub fn sockets(self) -> Result, AdapterError> { self.ns.adapter.fetch_sockets(self.opts) } @@ -860,8 +858,10 @@ impl BroadcastOperators { /// socket.within("room1").within("room3").join(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.add_sockets(self.opts, rooms) + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .add_sockets(self.opts, rooms.into_room_iter().collect()) } /// Makes all sockets selected with the previous operators leave the given room(s). @@ -876,17 +876,19 @@ impl BroadcastOperators { /// socket.within("room1").within("room3").leave(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.del_sockets(self.opts, rooms) + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .del_sockets(self.opts, rooms.into_room_iter().collect()) } /// Gets all room names for a given namespace - pub fn rooms(self) -> Result, A::Error> { + pub fn rooms(self) -> Result, AdapterError> { self.ns.adapter.rooms() } /// Gets a [`SocketRef`] by the specified [`Sid`]. - pub fn get_socket(&self, sid: Sid) -> Option> { + pub fn get_socket(&self, sid: Sid) -> Option { self.ns.get_socket(sid).map(SocketRef::from).ok() } diff --git a/socketioxide/src/service.rs b/socketioxide/src/service.rs index 6ce0ef33..6be8c58e 100644 --- a/socketioxide/src/service.rs +++ b/socketioxide/src/service.rs @@ -53,31 +53,26 @@ use std::{ }; use tower::Service as TowerSvc; -use crate::{ - adapter::{Adapter, LocalAdapter}, - client::Client, - SocketIoConfig, -}; +use crate::{client::Client, SocketIoConfig}; /// A [`Tower`](TowerSvc)/[`Hyper`](HyperSvc) Service that wraps [`EngineIoService`] and /// redirect every request to it -pub struct SocketIoService { - engine_svc: EngineIoService, S>, +pub struct SocketIoService { + engine_svc: EngineIoService, } /// Tower Service implementation. -impl TowerSvc> for SocketIoService +impl TowerSvc> for SocketIoService where ReqBody: Body + Send + Unpin + std::fmt::Debug + 'static, ::Error: std::fmt::Debug, ::Data: Send, ResBody: Body + Send + 'static, S: TowerSvc, Response = Response> + Clone, - A: Adapter, { - type Response = , S> as TowerSvc>>::Response; - type Error = , S> as TowerSvc>>::Error; - type Future = , S> as TowerSvc>>::Future; + type Response = as TowerSvc>>::Response; + type Error = as TowerSvc>>::Error; + type Future = as TowerSvc>>::Future; #[inline(always)] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -90,18 +85,17 @@ where } /// Hyper 1.0 Service implementation. -impl HyperSvc> for SocketIoService +impl HyperSvc> for SocketIoService where ReqBody: Body + Send + Unpin + std::fmt::Debug + 'static, ::Error: std::fmt::Debug, ::Data: Send, ResBody: Body + Send + 'static, S: HyperSvc, Response = Response> + Clone, - A: Adapter, { - type Response = , S> as HyperSvc>>::Response; - type Error = , S> as HyperSvc>>::Error; - type Future = , S> as HyperSvc>>::Future; + type Response = as HyperSvc>>::Response; + type Error = as HyperSvc>>::Error; + type Future = as HyperSvc>>::Future; #[inline(always)] fn call(&self, req: Request) -> Self::Future { @@ -109,10 +103,10 @@ where } } -impl SocketIoService { +impl SocketIoService { /// Creates a MakeService which can be used as a hyper service #[inline(always)] - pub fn into_make_service(self) -> MakeEngineIoService, S> { + pub fn into_make_service(self) -> MakeEngineIoService { self.engine_svc.into_make_service() } @@ -121,7 +115,7 @@ impl SocketIoService { inner: S, config: SocketIoConfig, #[cfg(feature = "state")] state: state::TypeMap![Send + Sync], - ) -> (Self, Arc>) { + ) -> (Self, Arc) { let engine_config = config.engine_config.clone(); let client = Arc::new(Client::new( config, @@ -134,14 +128,14 @@ impl SocketIoService { /// Creates a new [`EngineIoService`] with a custom inner service and an existing client /// It is mainly used with a [`SocketIoLayer`](crate::layer::SocketIoLayer) that owns the client - pub(crate) fn with_client(inner: S, client: Arc>) -> Self { + pub(crate) fn with_client(inner: S, client: Arc) -> Self { let engine_config = client.config.engine_config.clone(); let svc = EngineIoService::with_config_inner(inner, client, engine_config); Self { engine_svc: svc } } } -impl Clone for SocketIoService { +impl Clone for SocketIoService { fn clone(&self) -> Self { Self { engine_svc: self.engine_svc.clone(), diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index 459985ec..918c32b5 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -23,7 +23,7 @@ use crate::extensions::Extensions; use crate::{ ack::{AckInnerStream, AckResponse, AckResult, AckStream}, - adapter::{Adapter, LocalAdapter, Room}, + adapter::Room, errors::{DisconnectError, Error, SendError}, handler::{ BoxedDisconnectHandler, BoxedMessageHandler, DisconnectHandler, MakeErasedHandler, @@ -127,10 +127,10 @@ impl<'a> PermitExt<'a> for Permit<'a> { /// A Socket represents a client connected to a namespace. /// It is used to send and receive messages from the client, join and leave rooms, etc. /// The socket struct itself should not be used directly, but through a [`SocketRef`](crate::extract::SocketRef). -pub struct Socket { - pub(crate) ns: Arc>, - message_handlers: RwLock, BoxedMessageHandler>>, - disconnect_handler: Mutex>>, +pub struct Socket { + pub(crate) ns: Arc, + message_handlers: RwLock, BoxedMessageHandler>>, + disconnect_handler: Mutex>, ack_message: Mutex>>>, ack_counter: AtomicI64, connected: AtomicBool, @@ -145,14 +145,14 @@ pub struct Socket { #[cfg_attr(docsrs, doc(cfg(feature = "extensions")))] #[cfg(feature = "extensions")] pub extensions: Extensions, - esocket: Arc>>, + esocket: Arc>, } -impl Socket { +impl Socket { pub(crate) fn new( sid: Sid, - ns: Arc>, - esocket: Arc>>, + ns: Arc, + esocket: Arc>, ) -> Self { Self { ns, @@ -222,7 +222,7 @@ impl Socket { /// ``` pub fn on(&self, event: impl Into>, handler: H) where - H: MessageHandler, + H: MessageHandler, T: Send + Sync + 'static, { self.message_handlers @@ -256,7 +256,7 @@ impl Socket { /// }); pub fn on_disconnect(&self, callback: C) where - C: DisconnectHandler + Send + Sync + 'static, + C: DisconnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { let handler = MakeErasedHandler::new_disconnect_boxed(callback); @@ -409,8 +409,10 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn join(&self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.add_all(self.id, rooms) + pub fn join(&self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .add_all(self.id, rooms.into_room_iter().collect()) } /// Leaves the given rooms. @@ -419,15 +421,17 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn leave(&self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.del(self.id, rooms) + pub fn leave(&self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .del(self.id, rooms.into_room_iter().collect()) } /// Leaves all rooms where the socket is connected. /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn leave_all(&self) -> Result<(), A::Error> { + pub fn leave_all(&self) -> Result<(), AdapterError> { self.ns.adapter.del_all(self.id) } @@ -435,7 +439,7 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn rooms(&self) -> Result, A::Error> { + pub fn rooms(&self) -> Result, AdapterError> { self.ns.adapter.socket_rooms(self.id) } @@ -469,7 +473,7 @@ impl Socket { /// .emit("test", data); /// }); /// }); - pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).to(rooms) } @@ -493,7 +497,7 @@ impl Socket { /// .emit("test", data); /// }); /// }); - pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).within(rooms) } @@ -517,7 +521,7 @@ impl Socket { /// socket.broadcast().except("room1").emit("test", data); /// }); /// }); - pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).except(rooms) } @@ -535,7 +539,7 @@ impl Socket { /// socket.local().emit("test", data); /// }); /// }); - pub fn local(&self) -> BroadcastOperators { + pub fn local(&self) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).local() } @@ -574,7 +578,7 @@ impl Socket { /// }); /// }); /// - pub fn timeout(&self, timeout: Duration) -> ConfOperators<'_, A> { + pub fn timeout(&self, timeout: Duration) -> ConfOperators<'_> { ConfOperators::new(self).timeout(timeout) } @@ -591,7 +595,7 @@ impl Socket { /// socket.bin(bin).emit("test", data); /// }); /// }); - pub fn bin(&self, binary: impl IntoIterator>) -> ConfOperators<'_, A> { + pub fn bin(&self, binary: impl IntoIterator>) -> ConfOperators<'_> { ConfOperators::new(self).bin(binary) } @@ -608,7 +612,7 @@ impl Socket { /// socket.broadcast().emit("test", data); /// }); /// }); - pub fn broadcast(&self) -> BroadcastOperators { + pub fn broadcast(&self) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).broadcast() } @@ -617,7 +621,7 @@ impl Socket { /// # Panics /// Because [`SocketData::io`] should be immediately set at the creation of the socket. /// this should never panic. - pub(crate) fn get_io(&self) -> &SocketIo { + pub(crate) fn get_io(&self) -> &SocketIo { self.esocket.data.io.get().unwrap() } @@ -805,7 +809,7 @@ impl Socket { } } -impl Debug for Socket { +impl Debug for Socket { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Socket") .field("ns", &self.ns()) @@ -815,22 +819,22 @@ impl Debug for Socket { .finish() } } -impl PartialEq for Socket { +impl PartialEq for Socket { fn eq(&self, other: &Self) -> bool { self.id == other.id } } #[cfg(any(test, socketioxide_test))] -impl Socket { +impl Socket { /// Creates a dummy socket for testing purposes - pub fn new_dummy(sid: Sid, ns: Arc>) -> Socket { + pub fn new_dummy(sid: Sid, ns: Arc) -> Socket { use crate::client::Client; use crate::io::SocketIoConfig; let close_fn = Box::new(move |_, _| ()); let config = SocketIoConfig::default(); - let io = SocketIo::from(Arc::new(Client::::new( + let io = SocketIo::from(Arc::new(Client::new( config, std::default::Default::default(), ))); @@ -848,7 +852,7 @@ mod test { #[tokio::test] async fn send_with_ack_error() { let sid = Sid::new(); - let ns = Namespace::::new_dummy([sid]).into(); + let ns = Namespace::new_dummy([sid]).into(); let socket: Arc = Socket::new_dummy(sid, ns).into(); // Saturate the channel for _ in 0..1024 { diff --git a/socketioxide/tests/fixture.rs b/socketioxide/tests/fixture.rs index 04d23dc0..b384663a 100644 --- a/socketioxide/tests/fixture.rs +++ b/socketioxide/tests/fixture.rs @@ -17,7 +17,7 @@ use hyper_util::{ }; use serde::{Deserialize, Serialize}; -use socketioxide::{adapter::LocalAdapter, service::SocketIoService, SocketIo}; +use socketioxide::{service::SocketIoService, SocketIo}; use tokio::net::{TcpListener, TcpStream}; use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream}; @@ -111,7 +111,7 @@ pub async fn create_server(port: u16) -> SocketIo { io } -async fn spawn_server(port: u16, svc: SocketIoService) { +async fn spawn_server(port: u16, svc: SocketIoService) { let addr = &SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); let listener = TcpListener::bind(&addr).await.unwrap(); tokio::spawn(async move {