From 919ca81dc59b239b99b7545175e71a1e3b8feaa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9odore=20Pr=C3=A9vot?= Date: Fri, 24 May 2024 15:38:53 +0000 Subject: [PATCH] feat(adapter): remove adapter generic param and use boxed --- socketioxide/src/ack.rs | 10 +- socketioxide/src/adapter.rs | 219 ++++++++++++++----------- socketioxide/src/client.rs | 19 +-- socketioxide/src/errors.rs | 1 - socketioxide/src/extract/data.rs | 24 ++- socketioxide/src/extract/extensions.rs | 58 +++---- socketioxide/src/extract/mod.rs | 8 +- socketioxide/src/extract/socket.rs | 69 ++++---- socketioxide/src/extract/state.rs | 15 +- socketioxide/src/handler/connect.rs | 107 ++++++------ socketioxide/src/handler/disconnect.rs | 40 +++-- socketioxide/src/handler/message.rs | 67 ++++---- socketioxide/src/handler/mod.rs | 6 +- socketioxide/src/io.rs | 75 +++++---- socketioxide/src/layer.rs | 15 +- socketioxide/src/ns.rs | 51 +++--- socketioxide/src/operators.rs | 64 ++++---- socketioxide/src/service.rs | 37 ++--- socketioxide/src/socket.rs | 58 ++++--- socketioxide/tests/fixture.rs | 2 +- 20 files changed, 473 insertions(+), 472 deletions(-) diff --git a/socketioxide/src/ack.rs b/socketioxide/src/ack.rs index 699f8f74..14f8ff8a 100644 --- a/socketioxide/src/ack.rs +++ b/socketioxide/src/ack.rs @@ -145,9 +145,9 @@ impl AckInnerStream { /// /// The [`AckInnerStream`] will wait for the default timeout specified in the config /// (5s by default) if no custom timeout is specified. - pub fn broadcast( + pub fn broadcast( packet: Packet<'static>, - sockets: Vec>, + sockets: Vec, duration: Option, ) -> Self { let rxs = FuturesUnordered::new(); @@ -311,13 +311,13 @@ mod test { use engineioxide::sid::Sid; use futures_util::StreamExt; - use crate::{adapter::LocalAdapter, ns::Namespace, socket::Socket}; + use crate::{ns::Namespace, socket::Socket}; use super::*; - fn create_socket() -> Arc> { + fn create_socket() -> Arc { let sid = Sid::new(); - let ns = Namespace::::new_dummy([sid]).into(); + let ns = Namespace::new_dummy([sid]).into(); let socket = Socket::new_dummy(sid, ns); socket.into() } diff --git a/socketioxide/src/adapter.rs b/socketioxide/src/adapter.rs index 0e875b19..2ed1c5a8 100644 --- a/socketioxide/src/adapter.rs +++ b/socketioxide/src/adapter.rs @@ -6,7 +6,6 @@ use std::{ borrow::Cow, collections::{HashMap, HashSet}, - convert::Infallible, sync::{RwLock, Weak}, time::Duration, }; @@ -48,32 +47,28 @@ pub struct BroadcastOptions { pub sid: Option, } //TODO: Make an AsyncAdapter trait -/// An adapter is responsible for managing the state of the server. +/// An adapter is responsible for managing the state of the server. There is one adapter per namespace. /// This adapter can be implemented to share the state between multiple servers. /// The default adapter is the [`LocalAdapter`], which stores the state in memory. pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { - /// An error that can occur when using the adapter. The default [`LocalAdapter`] has an [`Infallible`] error. - type Error: std::error::Error + Into + Send + Sync + 'static; - - /// Create a new adapter and give the namespace ref to retrieve sockets. - fn new(ns: Weak>) -> Self - where - Self: Sized; + /// Returns a boxed clone of the adapter. + /// It is used to create a new empty instance of the adapter for a new namespace. + fn boxed_clone(&self) -> Box; /// Initializes the adapter. - fn init(&self) -> Result<(), Self::Error>; + fn init(&mut self, ns: Weak) -> Result<(), AdapterError>; /// Closes the adapter. - fn close(&self) -> Result<(), Self::Error>; + fn close(&self) -> Result<(), AdapterError>; /// Returns the number of servers. - fn server_count(&self) -> Result; + fn server_count(&self) -> Result; /// Adds the socket to all the rooms. - fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; + fn add_all(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError>; /// Removes the socket from the rooms. - fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; + fn del(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError>; /// Removes the socket from all the rooms. - fn del_all(&self, sid: Sid) -> Result<(), Self::Error>; + fn del_all(&self, sid: Sid) -> Result<(), AdapterError>; /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`]. fn broadcast(&self, packet: Packet<'_>, opts: BroadcastOptions) -> Result<(), BroadcastError>; @@ -87,28 +82,24 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { ) -> AckInnerStream; /// Returns the sockets ids that match the [`BroadcastOptions`]. - fn sockets(&self, rooms: impl RoomParam) -> Result, Self::Error>; + fn sockets(&self, rooms: Vec) -> Result, AdapterError>; /// Returns the rooms of the socket. - fn socket_rooms(&self, sid: Sid) -> Result, Self::Error>; + fn socket_rooms(&self, sid: Sid) -> Result, AdapterError>; /// Returns the sockets that match the [`BroadcastOptions`]. - fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>, Self::Error> - where - Self: Sized; + fn fetch_sockets(&self, opts: BroadcastOptions) -> Result, AdapterError>; /// Adds the sockets that match the [`BroadcastOptions`] to the rooms. - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) - -> Result<(), Self::Error>; + fn add_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError>; /// Removes the sockets that match the [`BroadcastOptions`] from the rooms. - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) - -> Result<(), Self::Error>; + fn del_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError>; /// Disconnects the sockets that match the [`BroadcastOptions`]. fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec>; /// Returns all the rooms for this adapter. - fn rooms(&self) -> Result, Self::Error>; + fn rooms(&self) -> Result, AdapterError>; //TODO: implement // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result; @@ -117,33 +108,23 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { } /// The default adapter. Store the state in memory. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct LocalAdapter { rooms: RwLock>>, - ns: Weak>, -} - -impl From for AdapterError { - fn from(_: Infallible) -> AdapterError { - unreachable!() - } + ns: Weak, } impl Adapter for LocalAdapter { - type Error = Infallible; - - fn new(ns: Weak>) -> Self { - Self { - rooms: HashMap::new().into(), - ns, - } + fn boxed_clone(&self) -> Box { + Box::new(Self::default()) } - fn init(&self) -> Result<(), Infallible> { + fn init(&mut self, ns: Weak) -> Result<(), AdapterError> { + self.ns = ns; Ok(()) } - fn close(&self) -> Result<(), Infallible> { + fn close(&self) -> Result<(), AdapterError> { #[cfg(feature = "tracing")] tracing::debug!("closing local adapter: {}", self.ns.upgrade().unwrap().path); let mut rooms = self.rooms.write().unwrap(); @@ -152,21 +133,21 @@ impl Adapter for LocalAdapter { Ok(()) } - fn server_count(&self) -> Result { + fn server_count(&self) -> Result { Ok(1) } - fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { + fn add_all(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms.into_room_iter() { + for room in rooms { rooms_map.entry(room).or_default().insert(sid); } Ok(()) } - fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { + fn del(&self, sid: Sid, rooms: Vec) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms.into_room_iter() { + for room in rooms { if let Some(room) = rooms_map.get_mut(&room) { room.remove(&sid); } @@ -174,7 +155,7 @@ impl Adapter for LocalAdapter { Ok(()) } - fn del_all(&self, sid: Sid) -> Result<(), Infallible> { + fn del_all(&self, sid: Sid) -> Result<(), AdapterError> { let mut rooms_map = self.rooms.write().unwrap(); for room in rooms_map.values_mut() { room.remove(&sid); @@ -214,7 +195,7 @@ impl Adapter for LocalAdapter { AckInnerStream::broadcast(packet, sockets, timeout) } - fn sockets(&self, rooms: impl RoomParam) -> Result, Infallible> { + fn sockets(&self, rooms: Vec) -> Result, AdapterError> { let mut opts = BroadcastOptions::default(); opts.rooms.extend(rooms.into_room_iter()); Ok(self @@ -225,7 +206,7 @@ impl Adapter for LocalAdapter { } //TODO: make this operation O(1) - fn socket_rooms(&self, sid: Sid) -> Result>, Infallible> { + fn socket_rooms(&self, sid: Sid) -> Result>, AdapterError> { let rooms_map = self.rooms.read().unwrap(); Ok(rooms_map .iter() @@ -234,11 +215,11 @@ impl Adapter for LocalAdapter { .collect()) } - fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>, Infallible> { + fn fetch_sockets(&self, opts: BroadcastOptions) -> Result, AdapterError> { Ok(self.apply_opts(opts)) } - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { + fn add_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError> { let rooms: Vec = rooms.into_room_iter().collect(); for socket in self.apply_opts(opts) { self.add_all(socket.id, rooms.clone()).unwrap(); @@ -246,7 +227,7 @@ impl Adapter for LocalAdapter { Ok(()) } - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { + fn del_sockets(&self, opts: BroadcastOptions, rooms: Vec) -> Result<(), AdapterError> { let rooms: Vec = rooms.into_room_iter().collect(); for socket in self.apply_opts(opts) { self.del(socket.id, rooms.clone()).unwrap(); @@ -270,14 +251,17 @@ impl Adapter for LocalAdapter { } } - fn rooms(&self) -> Result, Self::Error> { + fn rooms(&self) -> Result, AdapterError> { Ok(self.rooms.read().unwrap().keys().cloned().collect()) } } impl LocalAdapter { + fn new() -> Self { + Self::default() + } /// Applies the given `opts` and return the sockets that match. - fn apply_opts(&self, opts: BroadcastOptions) -> Vec> { + fn apply_opts(&self, opts: BroadcastOptions) -> Vec { let rooms = opts.rooms; let except = self.get_except_sids(&opts.except); @@ -335,19 +319,26 @@ mod test { }; } + fn local_adapter(ns: &Arc) -> LocalAdapter { + let mut adapter = LocalAdapter::new(); + adapter.init(Arc::downgrade(ns)).unwrap(); + adapter + } #[tokio::test] async fn test_server_count() { - let ns = Namespace::new_dummy([]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([]).into(); + let adapter = local_adapter(&ns); assert_eq!(adapter.server_count().unwrap(), 1); } #[tokio::test] async fn test_add_all() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 1); @@ -357,10 +348,12 @@ mod test { #[tokio::test] async fn test_del() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); - adapter.del(socket, "room1").unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter.del(socket, vec!["room1".into()]).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); assert_eq!(rooms_map.get("room1").unwrap().len(), 0); @@ -370,9 +363,11 @@ mod test { #[tokio::test] async fn test_del_all() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket, vec!["room1".into(), "room2".into()]) + .unwrap(); adapter.del_all(socket).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); @@ -385,11 +380,13 @@ mod test { let sid1 = Sid::new(); let sid2 = Sid::new(); let sid3 = Sid::new(); - let ns = Namespace::new_dummy([sid1, sid2, sid3]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(sid1, ["room1", "room2"]).unwrap(); - adapter.add_all(sid2, ["room1"]).unwrap(); - adapter.add_all(sid3, ["room2"]).unwrap(); + let ns = Namespace::new_dummy([sid1, sid2, sid3]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(sid1, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter.add_all(sid2, vec!["room1".into()]).unwrap(); + adapter.add_all(sid3, vec!["room2".into()]).unwrap(); assert!(adapter .socket_rooms(sid1) .unwrap() @@ -405,16 +402,16 @@ mod test { #[tokio::test] async fn test_add_socket() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter.add_all(socket, vec!["room1".into()]).unwrap(); let mut opts = BroadcastOptions { sid: Some(socket), ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.add_sockets(opts, "room2").unwrap(); + adapter.add_sockets(opts, vec!["room2".into()]).unwrap(); let rooms_map = adapter.rooms.read().unwrap(); assert_eq!(rooms_map.len(), 2); @@ -425,16 +422,16 @@ mod test { #[tokio::test] async fn test_del_socket() { let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]).unwrap(); + let ns = Namespace::new_dummy([socket]).into(); + let adapter = local_adapter(&ns); + adapter.add_all(socket, vec!["room1".into()]).unwrap(); let mut opts = BroadcastOptions { sid: Some(socket), ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.add_sockets(opts, "room2").unwrap(); + adapter.add_sockets(opts, vec!["room2".into()]).unwrap(); { let rooms_map = adapter.rooms.read().unwrap(); @@ -449,7 +446,7 @@ mod test { ..Default::default() }; opts.rooms = hash_set!["room1".into()]; - adapter.del_sockets(opts, "room2").unwrap(); + adapter.del_sockets(opts, vec!["room2".into()]).unwrap(); { let rooms_map = adapter.rooms.read().unwrap(); @@ -465,23 +462,29 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket0, ["room1", "room2"]).unwrap(); - adapter.add_all(socket1, ["room1", "room3"]).unwrap(); - adapter.add_all(socket2, ["room2", "room3"]).unwrap(); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); + adapter + .add_all(socket0, vec!["room1".into(), "room2".into()]) + .unwrap(); + adapter + .add_all(socket1, vec!["room1".into(), "room3".into()]) + .unwrap(); + adapter + .add_all(socket2, vec!["room2".into(), "room3".into()]) + .unwrap(); - let sockets = adapter.sockets("room1").unwrap(); + let sockets = adapter.sockets(vec!["room1".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket0)); assert!(sockets.contains(&socket1)); - let sockets = adapter.sockets("room2").unwrap(); + let sockets = adapter.sockets(vec!["room2".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket0)); assert!(sockets.contains(&socket2)); - let sockets = adapter.sockets("room3").unwrap(); + let sockets = adapter.sockets(vec!["room3".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket1)); assert!(sockets.contains(&socket2)); @@ -492,16 +495,25 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); adapter - .add_all(socket0, ["room1", "room2", "room4"]) + .add_all( + socket0, + vec!["room1".into(), "room2".into(), "room4".into()], + ) .unwrap(); adapter - .add_all(socket1, ["room1", "room3", "room5"]) + .add_all( + socket1, + vec!["room1".into(), "room3".into(), "room5".into()], + ) .unwrap(); adapter - .add_all(socket2, ["room2", "room3", "room6"]) + .add_all( + socket2, + vec!["room2".into(), "room3".into(), "room6".into()], + ) .unwrap(); let mut opts = BroadcastOptions { @@ -511,7 +523,7 @@ mod test { opts.rooms = hash_set!["room5".into()]; adapter.disconnect_socket(opts).unwrap(); - let sockets = adapter.sockets("room2").unwrap(); + let sockets = adapter.sockets(vec!["room2".into()]).unwrap(); assert_eq!(sockets.len(), 2); assert!(sockets.contains(&socket2)); assert!(sockets.contains(&socket0)); @@ -521,15 +533,22 @@ mod test { let socket0 = Sid::new(); let socket1 = Sid::new(); let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); + let ns = Namespace::new_dummy([socket0, socket1, socket2]).into(); + let adapter = local_adapter(&ns); // Add socket 0 to room1 and room2 - adapter.add_all(socket0, ["room1", "room2"]).unwrap(); + adapter + .add_all(socket0, vec!["room1".into(), "room2".into()]) + .unwrap(); // Add socket 1 to room1 and room3 - adapter.add_all(socket1, ["room1", "room3"]).unwrap(); + adapter + .add_all(socket1, vec!["room1".into(), "room3".into()]) + .unwrap(); // Add socket 2 to room2 and room3 adapter - .add_all(socket2, ["room1", "room2", "room3"]) + .add_all( + socket2, + vec!["room1".into(), "room2".into(), "room3".into()], + ) .unwrap(); // socket 2 is the sender diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index ef60cf7e..d5a8da8a 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -11,7 +11,6 @@ use futures_util::{FutureExt, TryFutureExt}; use engineioxide::sid::Sid; use tokio::sync::oneshot; -use crate::adapter::Adapter; use crate::handler::ConnectHandler; use crate::socket::DisconnectReason; use crate::ProtocolVersion; @@ -23,12 +22,12 @@ use crate::{ }; #[derive(Debug)] -pub struct Client { +pub struct Client { pub(crate) config: Arc, - ns: RwLock, Arc>>>, + ns: RwLock, Arc>>, } -impl Client { +impl Client { pub fn new(config: Arc) -> Self { #[cfg(feature = "state")] crate::state::freeze_state(); @@ -113,12 +112,12 @@ impl Client { /// Adds a new namespace handler pub fn add_ns(&self, path: Cow<'static, str>, callback: C) where - C: ConnectHandler, + C: ConnectHandler, T: Send + Sync + 'static, { #[cfg(feature = "tracing")] tracing::debug!("adding namespace {}", path); - let ns = Namespace::new(path.clone(), callback); + let ns = Namespace::new(path.clone(), callback, self.config.adapter.boxed_clone()); self.ns.write().unwrap().insert(path, ns); } @@ -137,7 +136,7 @@ impl Client { } } - pub fn get_ns(&self, path: &str) -> Option>> { + pub fn get_ns(&self, path: &str) -> Option> { self.ns.read().unwrap().get(path).cloned() } @@ -215,7 +214,7 @@ pub struct SocketData { pub connect_recv_tx: Mutex>>, } -impl EngineIoHandler for Client { +impl EngineIoHandler for Client { type Data = SocketData; #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, socket), fields(sid = socket.id.to_string())))] @@ -363,12 +362,12 @@ mod test { use crate::adapter::LocalAdapter; const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10); - fn create_client() -> super::Client { + fn create_client() -> super::Client { let config = crate::SocketIoConfig { connect_timeout: CONNECT_TIMEOUT, ..Default::default() }; - let client = Client::::new(std::sync::Arc::new(config)); + let client = Client::new(std::sync::Arc::new(config)); client.add_ns("/".into(), || {}); client } diff --git a/socketioxide/src/errors.rs b/socketioxide/src/errors.rs index 63e446f8..ead60030 100644 --- a/socketioxide/src/errors.rs +++ b/socketioxide/src/errors.rs @@ -1,7 +1,6 @@ use engineioxide::{sid::Sid, socket::DisconnectReason as EIoDisconnectReason}; use std::fmt::{Debug, Display}; use tokio::{sync::mpsc::error::TrySendError, time::error::Elapsed}; - /// Error type for socketio #[derive(thiserror::Error, Debug)] pub enum Error { diff --git a/socketioxide/src/extract/data.rs b/socketioxide/src/extract/data.rs index 51ba5292..691b283f 100644 --- a/socketioxide/src/extract/data.rs +++ b/socketioxide/src/extract/data.rs @@ -21,27 +21,25 @@ fn upwrap_array(v: &mut Value) { /// If a deserialization error occurs, the [`ConnectHandler`](crate::handler::ConnectHandler) won't be called /// and an error log will be print if the `tracing` feature is enabled. pub struct Data(pub T); -impl FromConnectParts for Data +impl FromConnectParts for Data where T: DeserializeOwned, - A: Adapter, { type Error = serde_json::Error; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts(_: &Arc, auth: &Option) -> Result { auth.as_ref() .map(|a| serde_json::from_str::(a)) .unwrap_or(serde_json::from_str::("{}")) .map(Data) } } -impl FromMessageParts for Data +impl FromMessageParts for Data where T: DeserializeOwned, - A: Adapter, { type Error = serde_json::Error; fn from_message_parts( - _: &Arc>, + _: &Arc, v: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -54,13 +52,12 @@ where /// An Extractor that returns the deserialized data related to the event. pub struct TryData(pub Result); -impl FromConnectParts for TryData +impl FromConnectParts for TryData where T: DeserializeOwned, - A: Adapter, { type Error = Infallible; - fn from_connect_parts(_: &Arc>, auth: &Option) -> Result { + fn from_connect_parts(_: &Arc, auth: &Option) -> Result { let v: Result = auth .as_ref() .map(|a| serde_json::from_str(a)) @@ -68,14 +65,13 @@ where Ok(TryData(v)) } } -impl FromMessageParts for TryData +impl FromMessageParts for TryData where T: DeserializeOwned, - A: Adapter, { type Error = Infallible; fn from_message_parts( - _: &Arc>, + _: &Arc, v: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -88,10 +84,10 @@ where /// An Extractor that returns the binary data of the message. /// If there is no binary data, it will contain an empty vec. pub struct Bin(pub Vec); -impl FromMessage for Bin { +impl FromMessage for Bin { type Error = Infallible; fn from_message( - _: Arc>, + _: Arc, _: serde_json::Value, bin: Vec, _: Option, diff --git a/socketioxide/src/extract/extensions.rs b/socketioxide/src/extract/extensions.rs index 9a898503..c8d81e3a 100644 --- a/socketioxide/src/extract/extensions.rs +++ b/socketioxide/src/extract/extensions.rs @@ -1,7 +1,6 @@ use std::convert::Infallible; use std::sync::Arc; -use crate::adapter::Adapter; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::socket::{DisconnectReason, Socket}; use bytes::Bytes; @@ -30,7 +29,7 @@ impl std::fmt::Debug for ExtensionNotFound { impl std::error::Error for ExtensionNotFound {} fn extract_http_extension( - s: &Arc>, + s: &Arc, ) -> Result> { s.req_parts() .extensions @@ -44,45 +43,43 @@ pub struct HttpExtension(pub T); /// An Extractor that returns a clone extension from the request parts if it exists. pub struct MaybeHttpExtension(pub Option); -impl FromConnectParts for HttpExtension { +impl FromConnectParts for HttpExtension { type Error = ExtensionNotFound; fn from_connect_parts( - s: &Arc>, + s: &Arc, _: &Option, ) -> Result> { extract_http_extension(s).map(HttpExtension) } } -impl FromConnectParts for MaybeHttpExtension { +impl FromConnectParts for MaybeHttpExtension { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(MaybeHttpExtension(extract_http_extension(s).ok())) } } -impl FromDisconnectParts for HttpExtension { +impl FromDisconnectParts for HttpExtension { type Error = ExtensionNotFound; fn from_disconnect_parts( - s: &Arc>, + s: &Arc, _: DisconnectReason, ) -> Result> { extract_http_extension(s).map(HttpExtension) } } -impl FromDisconnectParts - for MaybeHttpExtension -{ +impl FromDisconnectParts for MaybeHttpExtension { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(MaybeHttpExtension(extract_http_extension(s).ok())) } } -impl FromMessageParts for HttpExtension { +impl FromMessageParts for HttpExtension { type Error = ExtensionNotFound; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -90,10 +87,10 @@ impl FromMessageParts for HttpE extract_http_extension(s).map(HttpExtension) } } -impl FromMessageParts for MaybeHttpExtension { +impl FromMessageParts for MaybeHttpExtension { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -110,7 +107,7 @@ mod extensions_extract { use super::*; fn extract_extension( - s: &Arc>, + s: &Arc, ) -> Result> { s.extensions .get::() @@ -125,43 +122,40 @@ mod extensions_extract { /// An Extractor that returns the extension of the given type if it exists or `None` otherwise. pub struct MaybeExtension(pub Option); - impl FromConnectParts for Extension { + impl FromConnectParts for Extension { type Error = ExtensionNotFound; fn from_connect_parts( - s: &Arc>, + s: &Arc, _: &Option, ) -> Result> { extract_extension(s).map(Extension) } } - impl FromConnectParts for MaybeExtension { + impl FromConnectParts for MaybeExtension { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(MaybeExtension(extract_extension(s).ok())) } } - impl FromDisconnectParts for Extension { + impl FromDisconnectParts for Extension { type Error = ExtensionNotFound; fn from_disconnect_parts( - s: &Arc>, + s: &Arc, _: DisconnectReason, ) -> Result> { extract_extension(s).map(Extension) } } - impl FromDisconnectParts for MaybeExtension { + impl FromDisconnectParts for MaybeExtension { type Error = Infallible; - fn from_disconnect_parts( - s: &Arc>, - _: DisconnectReason, - ) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(MaybeExtension(extract_extension(s).ok())) } } - impl FromMessageParts for Extension { + impl FromMessageParts for Extension { type Error = ExtensionNotFound; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -169,10 +163,10 @@ mod extensions_extract { extract_extension(s).map(Extension) } } - impl FromMessageParts for MaybeExtension { + impl FromMessageParts for MaybeExtension { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, diff --git a/socketioxide/src/extract/mod.rs b/socketioxide/src/extract/mod.rs index b7adcb3e..9f5ac96f 100644 --- a/socketioxide/src/extract/mod.rs +++ b/socketioxide/src/extract/mod.rs @@ -56,9 +56,9 @@ //! } //! impl std::error::Error for UserIdNotFound {} //! -//! impl FromConnectParts for UserId { +//! impl FromConnectParts for UserId { //! type Error = Infallible; -//! fn from_connect_parts(s: &Arc>, _: &Option) -> Result { +//! fn from_connect_parts(s: &Arc, _: &Option) -> Result { //! // In a real app it would be better to parse the query params with a crate like `url` //! let uri = &s.req_parts().uri; //! let uid = uri @@ -72,11 +72,11 @@ //! //! // Here, if the user id is not found, the handler won't be called //! // and a tracing `error` log will be printed (if the `tracing` feature is enabled) -//! impl FromMessageParts for UserId { +//! impl FromMessageParts for UserId { //! type Error = UserIdNotFound; //! //! fn from_message_parts( -//! s: &Arc>, +//! s: &Arc, //! _: &mut serde_json::Value, //! _: &mut Vec, //! _: &Option, diff --git a/socketioxide/src/extract/socket.rs b/socketioxide/src/extract/socket.rs index 0a08ba27..aeefff4f 100644 --- a/socketioxide/src/extract/socket.rs +++ b/socketioxide/src/extract/socket.rs @@ -5,7 +5,6 @@ use crate::errors::{DisconnectError, SendError}; use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::socket::DisconnectReason; use crate::{ - adapter::{Adapter, LocalAdapter}, packet::Packet, socket::Socket, }; @@ -14,18 +13,18 @@ use serde::Serialize; /// An Extractor that returns a reference to a [`Socket`]. #[derive(Debug)] -pub struct SocketRef(Arc>); +pub struct SocketRef(Arc); -impl FromConnectParts for SocketRef { +impl FromConnectParts for SocketRef { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(SocketRef(s.clone())) } } -impl FromMessageParts for SocketRef { +impl FromMessageParts for SocketRef { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -33,40 +32,40 @@ impl FromMessageParts for SocketRef { Ok(SocketRef(s.clone())) } } -impl FromDisconnectParts for SocketRef { +impl FromDisconnectParts for SocketRef { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(SocketRef(s.clone())) } } -impl std::ops::Deref for SocketRef { - type Target = Socket; +impl std::ops::Deref for SocketRef { + type Target = Socket; #[inline(always)] fn deref(&self) -> &Self::Target { &self.0 } } -impl PartialEq for SocketRef { +impl PartialEq for SocketRef { #[inline(always)] fn eq(&self, other: &Self) -> bool { self.0.id == other.0.id } } -impl From>> for SocketRef { +impl From> for SocketRef { #[inline(always)] - fn from(socket: Arc>) -> Self { + fn from(socket: Arc) -> Self { Self(socket) } } -impl Clone for SocketRef { +impl Clone for SocketRef { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl SocketRef { +impl SocketRef { /// Disconnect the socket from the current namespace, /// /// It will also call the disconnect handler if it is set. @@ -79,15 +78,15 @@ impl SocketRef { /// An Extractor to send an ack response corresponding to the current event. /// If the client sent a normal message without expecting an ack, the ack callback will do nothing. #[derive(Debug)] -pub struct AckSender { +pub struct AckSender { binary: Vec, - socket: Arc>, + socket: Arc, ack_id: Option, } -impl FromMessageParts for AckSender { +impl FromMessageParts for AckSender { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, ack_id: &Option, @@ -95,8 +94,8 @@ impl FromMessageParts for AckSender { Ok(Self::new(s.clone(), *ack_id)) } } -impl AckSender { - pub(crate) fn new(socket: Arc>, ack_id: Option) -> Self { +impl AckSender { + pub(crate) fn new(socket: Arc, ack_id: Option) -> Self { Self { binary: vec![], socket, @@ -137,16 +136,16 @@ impl AckSender { } } -impl FromConnectParts for crate::ProtocolVersion { +impl FromConnectParts for crate::ProtocolVersion { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(s.protocol()) } } -impl FromMessageParts for crate::ProtocolVersion { +impl FromMessageParts for crate::ProtocolVersion { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -154,23 +153,23 @@ impl FromMessageParts for crate::ProtocolVersion { Ok(s.protocol()) } } -impl FromDisconnectParts for crate::ProtocolVersion { +impl FromDisconnectParts for crate::ProtocolVersion { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(s.protocol()) } } -impl FromConnectParts for crate::TransportType { +impl FromConnectParts for crate::TransportType { type Error = Infallible; - fn from_connect_parts(s: &Arc>, _: &Option) -> Result { + fn from_connect_parts(s: &Arc, _: &Option) -> Result { Ok(s.transport_type()) } } -impl FromMessageParts for crate::TransportType { +impl FromMessageParts for crate::TransportType { type Error = Infallible; fn from_message_parts( - s: &Arc>, + s: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, @@ -178,17 +177,17 @@ impl FromMessageParts for crate::TransportType { Ok(s.transport_type()) } } -impl FromDisconnectParts for crate::TransportType { +impl FromDisconnectParts for crate::TransportType { type Error = Infallible; - fn from_disconnect_parts(s: &Arc>, _: DisconnectReason) -> Result { + fn from_disconnect_parts(s: &Arc, _: DisconnectReason) -> Result { Ok(s.transport_type()) } } -impl FromDisconnectParts for DisconnectReason { +impl FromDisconnectParts for DisconnectReason { type Error = Infallible; fn from_disconnect_parts( - _: &Arc>, + _: &Arc, reason: DisconnectReason, ) -> Result { Ok(reason) diff --git a/socketioxide/src/extract/state.rs b/socketioxide/src/extract/state.rs index 4c7f67ad..2b8485fd 100644 --- a/socketioxide/src/extract/state.rs +++ b/socketioxide/src/extract/state.rs @@ -60,21 +60,18 @@ impl std::fmt::Debug for StateNotFound { } impl std::error::Error for StateNotFound {} -impl FromConnectParts for State { +impl FromConnectParts for State { type Error = StateNotFound; - fn from_connect_parts( - _: &Arc>, - _: &Option, - ) -> Result> { + fn from_connect_parts(_: &Arc, _: &Option) -> Result> { get_state::() .map(State) .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl FromDisconnectParts for State { +impl FromDisconnectParts for State { type Error = StateNotFound; fn from_disconnect_parts( - _: &Arc>, + _: &Arc, _: DisconnectReason, ) -> Result> { get_state::() @@ -82,10 +79,10 @@ impl FromDisconnectParts for State { .ok_or(StateNotFound(std::marker::PhantomData)) } } -impl FromMessageParts for State { +impl FromMessageParts for State { type Error = StateNotFound; fn from_message_parts( - _: &Arc>, + _: &Arc, _: &mut serde_json::Value, _: &mut Vec, _: &Option, diff --git a/socketioxide/src/handler/connect.rs b/socketioxide/src/handler/connect.rs index 888a3b30..6bbf3206 100644 --- a/socketioxide/src/handler/connect.rs +++ b/socketioxide/src/handler/connect.rs @@ -122,16 +122,16 @@ use crate::{adapter::Adapter, socket::Socket}; use super::MakeErasedHandler; /// A Type Erased [`ConnectHandler`] so it can be stored in a HashMap -pub(crate) type BoxedConnectHandler = Box>; +pub(crate) type BoxedConnectHandler = Box; type MiddlewareRes = Result<(), Box>; type MiddlewareResFut<'a> = Pin + Send + 'a>>; -pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, auth: Option); +pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { + fn call(&self, s: Arc, auth: Option); fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> MiddlewareResFut<'a>; } @@ -142,13 +142,13 @@ pub(crate) trait ErasedConnectHandler: Send + Sync + 'static { /// /// * See the [`connect`](super::connect) module doc for more details on connect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait FromConnectParts: Sized { +pub trait FromConnectParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + Send + 'static; /// Extract the arguments from the connect event. /// If it fails, the handler is not called - fn from_connect_parts(s: &Arc>, auth: &Option) -> Result; + fn from_connect_parts(s: &Arc, auth: &Option) -> Result; } /// Define a middleware for the connect event. @@ -157,16 +157,16 @@ pub trait FromConnectParts: Sized { /// /// * See the [`connect`](super::connect) module doc for more details on connect middlewares. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait ConnectMiddleware: Send + Sync + 'static { +pub trait ConnectMiddleware: Send + Sync + 'static { /// Call the middleware with the given arguments. fn call<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> impl Future + Send; #[doc(hidden)] - fn phantom(&self) -> std::marker::PhantomData<(A, T)> { + fn phantom(&self) -> std::marker::PhantomData { std::marker::PhantomData } } @@ -176,14 +176,14 @@ pub trait ConnectMiddleware: Send + Sync + 'static { /// /// * See the [`connect`](super::connect) module doc for more details on connect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait ConnectHandler: Send + Sync + 'static { +pub trait ConnectHandler: Send + Sync + 'static { /// Call the handler with the given arguments. - fn call(&self, s: Arc>, auth: Option); + fn call(&self, s: Arc, auth: Option); /// Call the middleware with the given arguments. fn call_middleware<'a>( &'a self, - _: Arc>, + _: Arc, _: &'a Option, ) -> MiddlewareResFut<'a> { Box::pin(async move { Ok(()) }) @@ -233,10 +233,10 @@ pub trait ConnectHandler: Send + Sync + 'static { /// let (_, io) = SocketIo::new_layer(); /// io.ns("/", handler.with(middleware).with(other_middleware)); /// ``` - fn with(self, middleware: M) -> impl ConnectHandler + fn with(self, middleware: M) -> impl ConnectHandler where Self: Sized, - M: ConnectMiddleware + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { @@ -252,10 +252,10 @@ pub trait ConnectHandler: Send + Sync + 'static { std::marker::PhantomData } } -struct LayeredConnectHandler { +struct LayeredConnectHandler { handler: H, middleware: M, - phantom: std::marker::PhantomData<(A, T, T1)>, + phantom: std::marker::PhantomData<(T, T1)>, } struct ConnectMiddlewareLayer { middleware: M, @@ -263,57 +263,56 @@ struct ConnectMiddlewareLayer { phantom: std::marker::PhantomData<(T, T1)>, } -impl MakeErasedHandler +impl MakeErasedHandler where - H: ConnectHandler + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - pub fn new_ns_boxed(inner: H) -> Box> { + pub fn new_ns_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedConnectHandler for MakeErasedHandler +impl ErasedConnectHandler for MakeErasedHandler where - H: ConnectHandler + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc, auth: Option) { self.handler.call(s, auth); } fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> MiddlewareResFut<'a> { self.handler.call_middleware(s, auth) } } -impl ConnectHandler for LayeredConnectHandler +impl ConnectHandler for LayeredConnectHandler where - A: Adapter, - H: ConnectHandler + Send + Sync + 'static, - M: ConnectMiddleware + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc, auth: Option) { self.handler.call(s, auth); } fn call_middleware<'a>( &'a self, - s: Arc>, + s: Arc, auth: &'a Option, ) -> MiddlewareResFut<'a> { Box::pin(async move { self.middleware.call(s, auth).await }) } - fn with(self, next: M2) -> impl ConnectHandler + fn with(self, next: M2) -> impl ConnectHandler where - M2: ConnectMiddleware + Send + Sync + 'static, + M2: ConnectMiddleware + Send + Sync + 'static, T2: Send + Sync + 'static, { LayeredConnectHandler { @@ -327,28 +326,26 @@ where } } } -impl ConnectMiddleware for LayeredConnectHandler +impl ConnectMiddleware for LayeredConnectHandler where - A: Adapter, - H: ConnectHandler + Send + Sync + 'static, - N: ConnectMiddleware + Send + Sync + 'static, + H: ConnectHandler + Send + Sync + 'static, + N: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>(&'a self, s: Arc, auth: &'a Option) -> MiddlewareRes { self.middleware.call(s, auth).await } } -impl ConnectMiddleware for ConnectMiddlewareLayer +impl ConnectMiddleware for ConnectMiddlewareLayer where - A: Adapter, - M: ConnectMiddleware + Send + Sync + 'static, - N: ConnectMiddleware + Send + Sync + 'static, + M: ConnectMiddleware + Send + Sync + 'static, + N: ConnectMiddleware + Send + Sync + 'static, T: Send + Sync + 'static, T1: Send + Sync + 'static, { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>(&'a self, s: Arc, auth: &'a Option) -> MiddlewareRes { self.middleware.call(s.clone(), auth).await?; self.next.call(s, auth).await } @@ -366,14 +363,13 @@ macro_rules! impl_handler_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectHandler for F + impl ConnectHandler<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc, auth: Option) { $( let $ty = match $ty::from_connect_parts(&s, &auth) { Ok(v) => v, @@ -398,13 +394,12 @@ macro_rules! impl_handler { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectHandler for F + impl ConnectHandler<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - fn call(&self, s: Arc>, auth: Option) { + fn call(&self, s: Arc, auth: Option) { $( let $ty = match $ty::from_connect_parts(&s, &auth) { Ok(v) => v, @@ -427,15 +422,14 @@ macro_rules! impl_middleware_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectMiddleware for F + impl ConnectMiddleware<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future> + Send + 'static, - A: Adapter, E: std::fmt::Display + Send + 'static, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>(&'a self, s: Arc, auth: &'a Option) -> MiddlewareRes { $( let $ty = match $ty::from_connect_parts(&s, &auth) { Ok(v) => v, @@ -465,14 +459,13 @@ macro_rules! impl_middleware { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl ConnectMiddleware for F + impl ConnectMiddleware<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Result<(), E> + Send + Sync + Clone + 'static, - A: Adapter, E: std::fmt::Display + Send + 'static, - $( $ty: FromConnectParts + Send, )* + $( $ty: FromConnectParts + Send, )* { - async fn call<'a>(&'a self, s: Arc>, auth: &'a Option) -> MiddlewareRes { + async fn call<'a>(&'a self, s: Arc, auth: &'a Option) -> MiddlewareRes { $( let $ty = match $ty::from_connect_parts(&s, &auth) { Ok(v) => v, diff --git a/socketioxide/src/handler/disconnect.rs b/socketioxide/src/handler/disconnect.rs index b63adb15..442d8d15 100644 --- a/socketioxide/src/handler/disconnect.rs +++ b/socketioxide/src/handler/disconnect.rs @@ -66,28 +66,28 @@ use crate::{ use super::MakeErasedHandler; /// A Type Erased [`DisconnectHandler`] so it can be stored in a HashMap -pub(crate) type BoxedDisconnectHandler = Box>; -pub(crate) trait ErasedDisconnectHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, reason: DisconnectReason); +pub(crate) type BoxedDisconnectHandler = Box; +pub(crate) trait ErasedDisconnectHandler: Send + Sync + 'static { + fn call(&self, s: Arc, reason: DisconnectReason); } -impl MakeErasedHandler +impl MakeErasedHandler where T: Send + Sync + 'static, - H: DisconnectHandler + Send + Sync + 'static, + H: DisconnectHandler + Send + Sync + 'static, { - pub fn new_disconnect_boxed(inner: H) -> Box> { + pub fn new_disconnect_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedDisconnectHandler for MakeErasedHandler +impl ErasedDisconnectHandler for MakeErasedHandler where - H: DisconnectHandler + Send + Sync + 'static, + H: DisconnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { #[inline(always)] - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { self.handler.call(s, reason); } } @@ -98,14 +98,14 @@ where /// /// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait FromDisconnectParts: Sized { +pub trait FromDisconnectParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the disconnect event. /// If it fails, the handler is not called fn from_disconnect_parts( - s: &Arc>, + s: &Arc, reason: DisconnectReason, ) -> Result; } @@ -115,9 +115,9 @@ pub trait FromDisconnectParts: Sized { /// /// * See the [`disconnect`](super::disconnect) module doc for more details on disconnect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. -pub trait DisconnectHandler: Send + Sync + 'static { +pub trait DisconnectHandler: Send + Sync + 'static { /// Call the handler with the given arguments. - fn call(&self, s: Arc>, reason: DisconnectReason); + fn call(&self, s: Arc, reason: DisconnectReason); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -137,14 +137,13 @@ macro_rules! impl_handler_async { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl DisconnectHandler for F + impl DisconnectHandler<(private::Async, $($ty,)*)> for F where F: FnOnce($($ty,)*) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromDisconnectParts + Send, )* + $( $ty: FromDisconnectParts + Send, )* { - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { $( let $ty = match $ty::from_disconnect_parts(&s, reason) { Ok(v) => v, @@ -169,13 +168,12 @@ macro_rules! impl_handler { [$($ty:ident),*] ) => { #[allow(non_snake_case, unused)] - impl DisconnectHandler for F + impl DisconnectHandler<(private::Sync, $($ty,)*)> for F where F: FnOnce($($ty,)*) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromDisconnectParts + Send, )* + $( $ty: FromDisconnectParts + Send, )* { - fn call(&self, s: Arc>, reason: DisconnectReason) { + fn call(&self, s: Arc, reason: DisconnectReason) { $( let $ty = match $ty::from_disconnect_parts(&s, reason) { Ok(v) => v, diff --git a/socketioxide/src/handler/message.rs b/socketioxide/src/handler/message.rs index 5b4ab255..4d7e715b 100644 --- a/socketioxide/src/handler/message.rs +++ b/socketioxide/src/handler/message.rs @@ -83,10 +83,10 @@ use crate::socket::Socket; use super::MakeErasedHandler; /// A Type Erased [`MessageHandler`] so it can be stored in a HashMap -pub(crate) type BoxedMessageHandler = Box>; +pub(crate) type BoxedMessageHandler = Box; -pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option); +pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option); } /// Define a handler for the connect event. @@ -100,9 +100,9 @@ pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait MessageHandler: Send + Sync + 'static { +pub trait MessageHandler: Send + Sync + 'static { /// Call the handler with the given arguments - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option); + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option); #[doc(hidden)] fn phantom(&self) -> std::marker::PhantomData { @@ -110,25 +110,23 @@ pub trait MessageHandler: Send + Sync + 'static { } } -impl MakeErasedHandler +impl MakeErasedHandler where T: Send + Sync + 'static, - H: MessageHandler, - A: Adapter, + H: MessageHandler, { - pub fn new_message_boxed(inner: H) -> Box> { + pub fn new_message_boxed(inner: H) -> Box { Box::new(MakeErasedHandler::new(inner)) } } -impl ErasedMessageHandler for MakeErasedHandler +impl ErasedMessageHandler for MakeErasedHandler where T: Send + Sync + 'static, - H: MessageHandler, - A: Adapter, + H: MessageHandler, { #[inline(always)] - fn call(&self, s: Arc>, v: Value, p: Vec, ack_id: Option) { + fn call(&self, s: Arc, v: Value, p: Vec, ack_id: Option) { self.handler.call(s, v, p, ack_id); } } @@ -157,14 +155,14 @@ mod private { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait FromMessageParts: Sized { +pub trait FromMessageParts: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the message event. /// If it fails, the handler is not called. fn from_message_parts( - s: &Arc>, + s: &Arc, v: &mut Value, p: &mut Vec, ack_id: &Option, @@ -182,14 +180,14 @@ pub trait FromMessageParts: Sized { note = "Function argument is not a valid socketio extractor. \nSee `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details", ) )] -pub trait FromMessage: Sized { +pub trait FromMessage: Sized { /// The error type returned by the extractor type Error: std::error::Error + 'static; /// Extract the arguments from the message event. /// If it fails, the handler is not called fn from_message( - s: Arc>, + s: Arc, v: Value, p: Vec, ack_id: Option, @@ -197,14 +195,13 @@ pub trait FromMessage: Sized { } /// All the types that implement [`FromMessageParts`] also implement [`FromMessage`] -impl FromMessage for T +impl FromMessage for T where - T: FromMessageParts, - A: Adapter, + T: FromMessageParts, { type Error = T::Error; fn from_message( - s: Arc>, + s: Arc, mut v: Value, mut p: Vec, ack_id: Option, @@ -214,25 +211,23 @@ where } /// Empty Async handler -impl MessageHandler for F +impl MessageHandler<(private::Async,)> for F where F: FnOnce() -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec, _: Option) { + fn call(&self, _: Arc, _: Value, _: Vec, _: Option) { let fut = (self.clone())(); tokio::spawn(fut); } } /// Empty Sync handler -impl MessageHandler for F +impl MessageHandler<(private::Sync,)> for F where F: FnOnce() + Send + Sync + Clone + 'static, - A: Adapter, { - fn call(&self, _: Arc>, _: Value, _: Vec, _: Option) { + fn call(&self, _: Arc, _: Value, _: Vec, _: Option) { (self.clone())(); } } @@ -242,15 +237,14 @@ macro_rules! impl_async_handler { [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused)] - impl MessageHandler for F + impl MessageHandler<(private::Async, M, $($ty,)* $last,)> for F where F: FnOnce($($ty,)* $last,) -> Fut + Send + Sync + Clone + 'static, Fut: Future + Send + 'static, - A: Adapter, - $( $ty: FromMessageParts + Send, )* - $last: FromMessage + Send, + $( $ty: FromMessageParts + Send, )* + $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec, ack_id: Option) { + fn call(&self, s: Arc, mut v: Value, mut p: Vec, ack_id: Option) { $( let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { Ok(v) => v, @@ -281,14 +275,13 @@ macro_rules! impl_handler { [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused)] - impl MessageHandler for F + impl MessageHandler<(private::Sync, M, $($ty,)* $last,)> for F where F: FnOnce($($ty,)* $last,) + Send + Sync + Clone + 'static, - A: Adapter, - $( $ty: FromMessageParts + Send, )* - $last: FromMessage + Send, + $( $ty: FromMessageParts + Send, )* + $last: FromMessage + Send, { - fn call(&self, s: Arc>, mut v: Value, mut p: Vec, ack_id: Option) { + fn call(&self, s: Arc, mut v: Value, mut p: Vec, ack_id: Option) { $( let $ty = match $ty::from_message_parts(&s, &mut v, &mut p, &ack_id) { Ok(v) => v, diff --git a/socketioxide/src/handler/mod.rs b/socketioxide/src/handler/mod.rs index 9f51a6dd..2fbdc4fc 100644 --- a/socketioxide/src/handler/mod.rs +++ b/socketioxide/src/handler/mod.rs @@ -12,16 +12,14 @@ pub use disconnect::{DisconnectHandler, FromDisconnectParts}; pub(crate) use message::BoxedMessageHandler; pub use message::{FromMessage, FromMessageParts, MessageHandler}; /// A struct used to erase the type of a [`ConnectHandler`] or [`MessageHandler`] so it can be stored in a map -pub(crate) struct MakeErasedHandler { +pub(crate) struct MakeErasedHandler { handler: H, - adapter: std::marker::PhantomData, type_: std::marker::PhantomData, } -impl MakeErasedHandler { +impl MakeErasedHandler { pub fn new(handler: H) -> Self { Self { handler, - adapter: std::marker::PhantomData, type_: std::marker::PhantomData, } } diff --git a/socketioxide/src/io.rs b/socketioxide/src/io.rs index 3f4ef4c4..260e486c 100644 --- a/socketioxide/src/io.rs +++ b/socketioxide/src/io.rs @@ -17,11 +17,11 @@ use crate::{ layer::SocketIoLayer, operators::{BroadcastOperators, RoomParam}, service::SocketIoService, - BroadcastError, DisconnectError, + AdapterError, BroadcastError, DisconnectError, }; /// Configuration for Socket.IO & Engine.IO -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SocketIoConfig { /// The inner Engine.IO config pub engine_config: EngineIoConfig, @@ -35,6 +35,9 @@ pub struct SocketIoConfig { /// /// Defaults to 45 seconds. pub connect_timeout: Duration, + + /// The adapter to use for this server. + pub adapter: Box, } impl Default for SocketIoConfig { @@ -46,6 +49,17 @@ impl Default for SocketIoConfig { }, ack_timeout: Duration::from_secs(5), connect_timeout: Duration::from_secs(45), + adapter: Box::new(LocalAdapter::default()), + } + } +} +impl Clone for SocketIoConfig { + fn clone(&self) -> Self { + Self { + engine_config: self.engine_config.clone(), + ack_timeout: self.ack_timeout, + connect_timeout: self.connect_timeout, + adapter: self.adapter.boxed_clone(), } } } @@ -53,19 +67,17 @@ impl Default for SocketIoConfig { /// A builder to create a [`SocketIo`] instance. /// It contains everything to configure the socket.io server with a [`SocketIoConfig`]. /// It can be used to build either a Tower [`Layer`](tower::layer::Layer) or a [`Service`](tower::Service). -pub struct SocketIoBuilder { +pub struct SocketIoBuilder { config: SocketIoConfig, engine_config_builder: EngineIoConfigBuilder, - adapter: std::marker::PhantomData, } -impl SocketIoBuilder { +impl SocketIoBuilder { /// Creates a new [`SocketIoBuilder`] with default config pub fn new() -> Self { Self { config: SocketIoConfig::default(), engine_config_builder: EngineIoConfigBuilder::new().req_path("/socket.io".to_string()), - adapter: std::marker::PhantomData, } } @@ -154,12 +166,9 @@ impl SocketIoBuilder { } /// Sets a custom [`Adapter`] for this [`SocketIoBuilder`] - pub fn with_adapter(self) -> SocketIoBuilder { - SocketIoBuilder { - config: self.config, - engine_config_builder: self.engine_config_builder, - adapter: std::marker::PhantomData, - } + pub fn with_adapter(mut self, adapter: impl Adapter) -> SocketIoBuilder { + self.config.adapter = Box::new(adapter); + self } /// Add a custom global state for the [`SocketIo`] instance. @@ -176,7 +185,7 @@ impl SocketIoBuilder { /// Builds a [`SocketIoLayer`] and a [`SocketIo`] instance /// /// The layer can be used as a tower layer - pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { + pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); let (layer, client) = SocketIoLayer::from_config(Arc::new(self.config)); @@ -215,9 +224,9 @@ impl Default for SocketIoBuilder { /// The [`SocketIo`] instance can be cheaply cloned and moved around everywhere in your program. /// It can be used as the main handle to access the whole socket.io context. #[derive(Debug)] -pub struct SocketIo(Arc>); +pub struct SocketIo(Arc); -impl SocketIo { +impl SocketIo { /// Creates a new [`SocketIoBuilder`] with a default config #[inline(always)] pub fn builder() -> SocketIoBuilder { @@ -247,7 +256,7 @@ impl SocketIo { } } -impl SocketIo { +impl SocketIo { /// Returns a reference to the [`SocketIoConfig`] used by this [`SocketIo`] instance #[inline] pub fn config(&self) -> &SocketIoConfig { @@ -336,7 +345,7 @@ impl SocketIo { #[inline] pub fn ns(&self, path: impl Into>, callback: C) where - C: ConnectHandler, + C: ConnectHandler, T: Send + Sync + 'static, { self.0.add_ns(path.into(), callback); @@ -382,7 +391,7 @@ impl SocketIo { /// println!("found socket on /custom_ns namespace with id: {}", socket.id); /// } #[inline] - pub fn of<'a>(&self, path: impl Into<&'a str>) -> Option> { + pub fn of<'a>(&self, path: impl Into<&'a str>) -> Option { self.get_op(path.into()) } @@ -408,7 +417,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().to(rooms) } @@ -436,7 +445,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().within(rooms) } @@ -469,7 +478,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { self.get_default_op().except(rooms) } @@ -496,7 +505,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn local(&self) -> BroadcastOperators { + pub fn local(&self) -> BroadcastOperators { self.get_default_op().local() } @@ -539,7 +548,7 @@ impl SocketIo { /// } /// }); #[inline] - pub fn timeout(&self, timeout: Duration) -> BroadcastOperators { + pub fn timeout(&self, timeout: Duration) -> BroadcastOperators { self.get_default_op().timeout(timeout) } @@ -568,7 +577,7 @@ impl SocketIo { /// .bin(vec![Bytes::from_static(&[1, 2, 3, 4])]) /// .emit("test", ()); #[inline] - pub fn bin(&self, binary: impl IntoIterator>) -> BroadcastOperators { + pub fn bin(&self, binary: impl IntoIterator>) -> BroadcastOperators { self.get_default_op().bin(binary) } @@ -695,7 +704,7 @@ impl SocketIo { /// println!("found socket on / ns in room1 with id: {}", socket.id); /// } #[inline] - pub fn sockets(&self) -> Result>, A::Error> { + pub fn sockets(&self) -> Result, AdapterError> { self.get_default_op().sockets() } @@ -739,7 +748,7 @@ impl SocketIo { /// // Later in your code you can for example add all sockets on the root namespace to the room1 and room3 /// io.join(["room1", "room3"]).unwrap(); #[inline] - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.get_default_op().join(rooms) } @@ -760,7 +769,7 @@ impl SocketIo { /// let rooms = io2.rooms().unwrap(); /// println!("All rooms on / namespace: {:?}", rooms); /// }); - pub fn rooms(&self) -> Result, A::Error> { + pub fn rooms(&self) -> Result, AdapterError> { self.get_default_op().rooms() } @@ -782,19 +791,19 @@ impl SocketIo { /// // Later in your code you can for example remove all sockets on the root namespace from the room1 and room3 /// io.leave(["room1", "room3"]).unwrap(); #[inline] - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.get_default_op().leave(rooms) } /// Gets a [`SocketRef`] by the specified [`Sid`]. #[inline] - pub fn get_socket(&self, sid: Sid) -> Option> { + pub fn get_socket(&self, sid: Sid) -> Option { self.get_default_op().get_socket(sid) } /// Returns a new operator on the given namespace #[inline(always)] - fn get_op(&self, path: &str) -> Option> { + fn get_op(&self, path: &str) -> Option { self.0 .get_ns(path) .map(|ns| BroadcastOperators::new(ns).broadcast()) @@ -806,19 +815,19 @@ impl SocketIo { /// /// If the **default namespace "/" is not found** this fn will panic! #[inline(always)] - fn get_default_op(&self) -> BroadcastOperators { + fn get_default_op(&self) -> BroadcastOperators { self.get_op("/").expect("default namespace not found") } } -impl Clone for SocketIo { +impl Clone for SocketIo { fn clone(&self) -> Self { Self(self.0.clone()) } } #[cfg(any(test, socketioxide_test))] -impl SocketIo { +impl SocketIo { /// Create a dummy socket for testing purpose with a /// receiver to get the packets sent to the client pub async fn new_dummy_sock( diff --git a/socketioxide/src/layer.rs b/socketioxide/src/layer.rs index 1a656766..581d9fe6 100644 --- a/socketioxide/src/layer.rs +++ b/socketioxide/src/layer.rs @@ -21,18 +21,17 @@ use std::sync::Arc; use tower::Layer; use crate::{ - adapter::{Adapter, LocalAdapter}, client::Client, service::SocketIoService, SocketIoConfig, }; /// A [`Layer`] for [`SocketIoService`], acting as a middleware. -pub struct SocketIoLayer { - client: Arc>, +pub struct SocketIoLayer { + client: Arc, } -impl Clone for SocketIoLayer { +impl Clone for SocketIoLayer { fn clone(&self) -> Self { Self { client: self.client.clone(), @@ -40,8 +39,8 @@ impl Clone for SocketIoLayer { } } -impl SocketIoLayer { - pub(crate) fn from_config(config: Arc) -> (Self, Arc>) { +impl SocketIoLayer { + pub(crate) fn from_config(config: Arc) -> (Self, Arc) { let client = Arc::new(Client::new(config.clone())); let layer = Self { client: client.clone(), @@ -50,8 +49,8 @@ impl SocketIoLayer { } } -impl Layer for SocketIoLayer { - type Service = SocketIoService; +impl Layer for SocketIoLayer { + type Service = SocketIoService; fn layer(&self, inner: S) -> Self::Service { SocketIoService::with_client(inner, self.client.clone()) diff --git a/socketioxide/src/ns.rs b/socketioxide/src/ns.rs index 7f810b8a..8744d951 100644 --- a/socketioxide/src/ns.rs +++ b/socketioxide/src/ns.rs @@ -5,34 +5,41 @@ use std::{ }; use crate::{ - adapter::Adapter, - errors::{ConnectFail, Error}, + adapter::{Adapter, LocalAdapter}, + client::SocketData, + errors::{AdapterError, ConnectFail, Error}, handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler}, packet::{Packet, PacketData}, socket::{DisconnectReason, Socket}, SocketIoConfig, }; -use crate::{client::SocketData, errors::AdapterError}; use engineioxide::sid::Sid; -pub struct Namespace { +pub struct Namespace { pub path: Cow<'static, str>, - pub(crate) adapter: A, - handler: BoxedConnectHandler, - sockets: RwLock>>>, + pub(crate) adapter: Box, + handler: BoxedConnectHandler, + sockets: RwLock>>, } -impl Namespace { - pub fn new(path: Cow<'static, str>, handler: C) -> Arc +impl Namespace { + pub fn new( + path: Cow<'static, str>, + handler: C, + mut adapter: Box, + ) -> Arc where - C: ConnectHandler + Send + Sync + 'static, + C: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { - Arc::new_cyclic(|ns| Self { - path, - handler: MakeErasedHandler::new_ns_boxed(handler), - sockets: HashMap::new().into(), - adapter: A::new(ns.clone()), + Arc::new_cyclic(move |ns: &std::sync::Weak<_>| { + adapter.init(ns.clone()); + Self { + path, + handler: MakeErasedHandler::new_ns_boxed(handler), + sockets: HashMap::new().into(), + adapter, + } }) } @@ -49,7 +56,7 @@ impl Namespace { auth: Option, config: Arc, ) -> Result<(), ConnectFail> { - let socket: Arc> = Socket::new(sid, self.clone(), esocket.clone(), config).into(); + let socket: Arc = Socket::new(sid, self.clone(), esocket.clone(), config).into(); if let Err(e) = self.handler.call_middleware(socket.clone(), &auth).await { #[cfg(feature = "tracing")] @@ -106,7 +113,7 @@ impl Namespace { } } - pub fn get_socket(&self, sid: Sid) -> Result>, Error> { + pub fn get_socket(&self, sid: Sid) -> Result, Error> { self.sockets .read() .unwrap() @@ -115,7 +122,7 @@ impl Namespace { .ok_or(Error::SocketGone(sid)) } - pub fn get_sockets(&self) -> Vec>> { + pub fn get_sockets(&self) -> Vec> { self.sockets.read().unwrap().values().cloned().collect() } @@ -159,9 +166,9 @@ impl Namespace { } #[cfg(any(test, socketioxide_test))] -impl Namespace { +impl Namespace { pub fn new_dummy(sockets: [Sid; S]) -> Arc { - let ns = Namespace::new(Cow::Borrowed("/"), || {}); + let ns = Namespace::new(Cow::Borrowed("/"), || {}, Box::new(LocalAdapter::default())); for sid in sockets { ns.sockets .write() @@ -176,7 +183,7 @@ impl Namespace { } } -impl std::fmt::Debug for Namespace { +impl std::fmt::Debug for Namespace { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Namespace") .field("path", &self.path) @@ -187,7 +194,7 @@ impl std::fmt::Debug for Namespace { } #[cfg(feature = "tracing")] -impl Drop for Namespace { +impl Drop for Namespace { fn drop(&mut self) { #[cfg(feature = "tracing")] tracing::debug!("dropping namespace {}", self.path); diff --git a/socketioxide/src/operators.rs b/socketioxide/src/operators.rs index 46a9754f..cc43a6e7 100644 --- a/socketioxide/src/operators.rs +++ b/socketioxide/src/operators.rs @@ -13,11 +13,9 @@ use bytes::Bytes; use engineioxide::sid::Sid; use crate::ack::{AckInnerStream, AckStream}; -use crate::adapter::LocalAdapter; -use crate::errors::{BroadcastError, DisconnectError}; +use crate::errors::{AdapterError, BroadcastError, DisconnectError, SendError}; use crate::extract::SocketRef; use crate::socket::Socket; -use crate::SendError; use crate::{ adapter::{Adapter, BroadcastFlags, BroadcastOptions, Room}, ns::Namespace, @@ -103,21 +101,21 @@ impl RoomParam for Sid { } /// Chainable operators to configure the message to be sent. -pub struct ConfOperators<'a, A: Adapter = LocalAdapter> { +pub struct ConfOperators<'a> { binary: Vec, timeout: Option, - socket: &'a Socket, + socket: &'a Socket, } /// Chainable operators to select sockets to send a message to and to configure the message to be sent. -pub struct BroadcastOperators { +pub struct BroadcastOperators { binary: Vec, timeout: Option, - ns: Arc>, + ns: Arc, opts: BroadcastOptions, } -impl From> for BroadcastOperators { - fn from(conf: ConfOperators<'_, A>) -> Self { +impl From> for BroadcastOperators { + fn from(conf: ConfOperators<'_>) -> Self { let opts = BroadcastOptions { sid: Some(conf.socket.id), ..Default::default() @@ -132,8 +130,8 @@ impl From> for BroadcastOperators { } // ==== impl ConfOperators operations ==== -impl<'a, A: Adapter> ConfOperators<'a, A> { - pub(crate) fn new(sender: &'a Socket) -> Self { +impl<'a> ConfOperators<'a> { + pub(crate) fn new(sender: &'a Socket) -> Self { Self { binary: vec![], timeout: None, @@ -161,7 +159,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// .emit("test", data); /// }); /// }); - pub fn to(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).to(rooms) } @@ -185,7 +183,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// .emit("test", data); /// }); /// }); - pub fn within(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).within(rooms) } @@ -208,7 +206,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.broadcast().except("room1").emit("test", data); /// }); /// }); - pub fn except(self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from(self).except(rooms) } @@ -225,7 +223,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.local().emit("test", data); /// }); /// }); - pub fn local(self) -> BroadcastOperators { + pub fn local(self) -> BroadcastOperators { BroadcastOperators::from(self).local() } @@ -241,7 +239,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { /// socket.broadcast().emit("test", data); /// }); /// }); - pub fn broadcast(self) -> BroadcastOperators { + pub fn broadcast(self) -> BroadcastOperators { BroadcastOperators::from(self).broadcast() } @@ -304,7 +302,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { } // ==== impl ConfOperators consume fns ==== -impl ConfOperators<'_, A> { +impl ConfOperators<'_> { /// Emits a message to the client and apply the previous operators on the message. /// /// If you provide array-like data (tuple, vec, arrays), it will be considered as multiple arguments. @@ -452,7 +450,7 @@ impl ConfOperators<'_, A> { /// socket.within("room1").within("room3").join(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.socket.join(rooms) } @@ -468,12 +466,12 @@ impl ConfOperators<'_, A> { /// socket.within("room1").within("room3").leave(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { self.socket.leave(rooms) } /// Gets all room names for a given namespace - pub fn rooms(self) -> Result, A::Error> { + pub fn rooms(self) -> Result, AdapterError> { self.socket.rooms() } @@ -495,8 +493,8 @@ impl ConfOperators<'_, A> { } } -impl BroadcastOperators { - pub(crate) fn new(ns: Arc>) -> Self { +impl BroadcastOperators { + pub(crate) fn new(ns: Arc) -> Self { Self { binary: vec![], timeout: None, @@ -504,7 +502,7 @@ impl BroadcastOperators { opts: BroadcastOptions::default(), } } - pub(crate) fn from_sock(ns: Arc>, sid: Sid) -> Self { + pub(crate) fn from_sock(ns: Arc, sid: Sid) -> Self { Self { binary: vec![], timeout: None, @@ -684,7 +682,7 @@ impl BroadcastOperators { } // ==== impl BroadcastOperators consume fns ==== -impl BroadcastOperators { +impl BroadcastOperators { /// Emits a message to all sockets selected with the previous operators. /// /// If you provide array-like data (tuple, vec, arrays), it will be considered as multiple arguments. @@ -826,7 +824,7 @@ impl BroadcastOperators { /// } /// }); /// }); - pub fn sockets(self) -> Result>, A::Error> { + pub fn sockets(self) -> Result, AdapterError> { self.ns.adapter.fetch_sockets(self.opts) } @@ -858,8 +856,10 @@ impl BroadcastOperators { /// socket.within("room1").within("room3").join(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.add_sockets(self.opts, rooms) + pub fn join(self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .add_sockets(self.opts, rooms.into_room_iter().collect()) } /// Makes all sockets selected with the previous operators leave the given room(s). @@ -874,17 +874,19 @@ impl BroadcastOperators { /// socket.within("room1").within("room3").leave(["room4", "room5"]).unwrap(); /// }); /// }); - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.del_sockets(self.opts, rooms) + pub fn leave(self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .del_sockets(self.opts, rooms.into_room_iter().collect()) } /// Gets all room names for a given namespace - pub fn rooms(self) -> Result, A::Error> { + pub fn rooms(self) -> Result, AdapterError> { self.ns.adapter.rooms() } /// Gets a [`SocketRef`] by the specified [`Sid`]. - pub fn get_socket(&self, sid: Sid) -> Option> { + pub fn get_socket(&self, sid: Sid) -> Option { self.ns.get_socket(sid).map(SocketRef::from).ok() } diff --git a/socketioxide/src/service.rs b/socketioxide/src/service.rs index 91bfda74..51d861c7 100644 --- a/socketioxide/src/service.rs +++ b/socketioxide/src/service.rs @@ -54,30 +54,29 @@ use std::{ use tower::Service as TowerSvc; use crate::{ - adapter::{Adapter, LocalAdapter}, + adapter::{Adapter}, client::Client, SocketIoConfig, }; /// A [`Tower`](TowerSvc)/[`Hyper`](HyperSvc) Service that wraps [`EngineIoService`] and /// redirect every request to it -pub struct SocketIoService { - engine_svc: EngineIoService>, S>, +pub struct SocketIoService { + engine_svc: EngineIoService, S>, } /// Tower Service implementation. -impl TowerSvc> for SocketIoService +impl TowerSvc> for SocketIoService where ReqBody: Body + Send + Unpin + std::fmt::Debug + 'static, ::Error: std::fmt::Debug, ::Data: Send, ResBody: Body + Send + 'static, S: TowerSvc, Response = Response> + Clone, - A: Adapter, { - type Response = >, S> as TowerSvc>>::Response; - type Error = >, S> as TowerSvc>>::Error; - type Future = >, S> as TowerSvc>>::Future; + type Response = , S> as TowerSvc>>::Response; + type Error = , S> as TowerSvc>>::Error; + type Future = , S> as TowerSvc>>::Future; #[inline(always)] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -90,18 +89,17 @@ where } /// Hyper 1.0 Service implementation. -impl HyperSvc> for SocketIoService +impl HyperSvc> for SocketIoService where ReqBody: Body + Send + Unpin + std::fmt::Debug + 'static, ::Error: std::fmt::Debug, ::Data: Send, ResBody: Body + Send + 'static, S: HyperSvc, Response = Response> + Clone, - A: Adapter, { - type Response = >, S> as HyperSvc>>::Response; - type Error = >, S> as HyperSvc>>::Error; - type Future = >, S> as HyperSvc>>::Future; + type Response = , S> as HyperSvc>>::Response; + type Error = , S> as HyperSvc>>::Error; + type Future = , S> as HyperSvc>>::Future; #[inline(always)] fn call(&self, req: Request) -> Self::Future { @@ -109,18 +107,15 @@ where } } -impl SocketIoService { +impl SocketIoService { /// Creates a MakeService which can be used as a hyper service #[inline(always)] - pub fn into_make_service(self) -> MakeEngineIoService>, S> { + pub fn into_make_service(self) -> MakeEngineIoService, S> { self.engine_svc.into_make_service() } /// Creates a new [`EngineIoService`] with a custom inner service and a custom config. - pub(crate) fn with_config_inner( - inner: S, - config: Arc, - ) -> (Self, Arc>) { + pub(crate) fn with_config_inner(inner: S, config: Arc) -> (Self, Arc) { let engine_config = config.engine_config.clone(); let client = Arc::new(Client::new(config)); let svc = EngineIoService::with_config_inner(inner, client.clone(), engine_config); @@ -129,14 +124,14 @@ impl SocketIoService { /// Creates a new [`EngineIoService`] with a custom inner service and an existing client /// It is mainly used with a [`SocketIoLayer`](crate::layer::SocketIoLayer) that owns the client - pub(crate) fn with_client(inner: S, client: Arc>) -> Self { + pub(crate) fn with_client(inner: S, client: Arc) -> Self { let engine_config = client.config.engine_config.clone(); let svc = EngineIoService::with_config_inner(inner, client, engine_config); Self { engine_svc: svc } } } -impl Clone for SocketIoService { +impl Clone for SocketIoService { fn clone(&self) -> Self { Self { engine_svc: self.engine_svc.clone(), diff --git a/socketioxide/src/socket.rs b/socketioxide/src/socket.rs index b082d8c0..2ac37eff 100644 --- a/socketioxide/src/socket.rs +++ b/socketioxide/src/socket.rs @@ -23,7 +23,7 @@ use crate::extensions::Extensions; use crate::{ ack::{AckInnerStream, AckResponse, AckResult, AckStream}, - adapter::{Adapter, LocalAdapter, Room}, + adapter::{Adapter, Room}, errors::{DisconnectError, Error, SendError}, handler::{ BoxedDisconnectHandler, BoxedMessageHandler, DisconnectHandler, MakeErasedHandler, @@ -127,11 +127,11 @@ impl<'a> PermitExt<'a> for Permit<'a> { /// A Socket represents a client connected to a namespace. /// It is used to send and receive messages from the client, join and leave rooms, etc. /// The socket struct itself should not be used directly, but through a [`SocketRef`](crate::extract::SocketRef). -pub struct Socket { +pub struct Socket { pub(crate) config: Arc, - pub(crate) ns: Arc>, - message_handlers: RwLock, BoxedMessageHandler>>, - disconnect_handler: Mutex>>, + pub(crate) ns: Arc, + message_handlers: RwLock, BoxedMessageHandler>>, + disconnect_handler: Mutex>, ack_message: Mutex>>>, ack_counter: AtomicI64, connected: AtomicBool, @@ -149,10 +149,10 @@ pub struct Socket { esocket: Arc>, } -impl Socket { +impl Socket { pub(crate) fn new( sid: Sid, - ns: Arc>, + ns: Arc, esocket: Arc>, config: Arc, ) -> Self { @@ -225,7 +225,7 @@ impl Socket { /// ``` pub fn on(&self, event: impl Into>, handler: H) where - H: MessageHandler, + H: MessageHandler, T: Send + Sync + 'static, { self.message_handlers @@ -259,7 +259,7 @@ impl Socket { /// }); pub fn on_disconnect(&self, callback: C) where - C: DisconnectHandler + Send + Sync + 'static, + C: DisconnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { let handler = MakeErasedHandler::new_disconnect_boxed(callback); @@ -411,8 +411,10 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn join(&self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.add_all(self.id, rooms) + pub fn join(&self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .add_all(self.id, rooms.into_room_iter().collect()) } /// Leaves the given rooms. @@ -421,15 +423,17 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn leave(&self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.del(self.id, rooms) + pub fn leave(&self, rooms: impl RoomParam) -> Result<(), AdapterError> { + self.ns + .adapter + .del(self.id, rooms.into_room_iter().collect()) } /// Leaves all rooms where the socket is connected. /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn leave_all(&self) -> Result<(), A::Error> { + pub fn leave_all(&self) -> Result<(), AdapterError> { self.ns.adapter.del_all(self.id) } @@ -437,7 +441,7 @@ impl Socket { /// ## Errors /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn rooms(&self) -> Result, A::Error> { + pub fn rooms(&self) -> Result, AdapterError> { self.ns.adapter.socket_rooms(self.id) } @@ -471,7 +475,7 @@ impl Socket { /// .emit("test", data); /// }); /// }); - pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).to(rooms) } @@ -495,7 +499,7 @@ impl Socket { /// .emit("test", data); /// }); /// }); - pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).within(rooms) } @@ -519,7 +523,7 @@ impl Socket { /// socket.broadcast().except("room1").emit("test", data); /// }); /// }); - pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { + pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).except(rooms) } @@ -537,7 +541,7 @@ impl Socket { /// socket.local().emit("test", data); /// }); /// }); - pub fn local(&self) -> BroadcastOperators { + pub fn local(&self) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).local() } @@ -576,7 +580,7 @@ impl Socket { /// }); /// }); /// - pub fn timeout(&self, timeout: Duration) -> ConfOperators<'_, A> { + pub fn timeout(&self, timeout: Duration) -> ConfOperators<'_> { ConfOperators::new(self).timeout(timeout) } @@ -593,7 +597,7 @@ impl Socket { /// socket.bin(bin).emit("test", data); /// }); /// }); - pub fn bin(&self, binary: impl IntoIterator>) -> ConfOperators<'_, A> { + pub fn bin(&self, binary: impl IntoIterator>) -> ConfOperators<'_> { ConfOperators::new(self).bin(binary) } @@ -610,7 +614,7 @@ impl Socket { /// socket.broadcast().emit("test", data); /// }); /// }); - pub fn broadcast(&self) -> BroadcastOperators { + pub fn broadcast(&self) -> BroadcastOperators { BroadcastOperators::from_sock(self.ns.clone(), self.id).broadcast() } @@ -798,7 +802,7 @@ impl Socket { } } -impl Debug for Socket { +impl Debug for Socket { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Socket") .field("ns", &self.ns()) @@ -808,16 +812,16 @@ impl Debug for Socket { .finish() } } -impl PartialEq for Socket { +impl PartialEq for Socket { fn eq(&self, other: &Self) -> bool { self.id == other.id } } #[cfg(any(test, socketioxide_test))] -impl Socket { +impl Socket { /// Creates a dummy socket for testing purposes - pub fn new_dummy(sid: Sid, ns: Arc>) -> Socket { + pub fn new_dummy(sid: Sid, ns: Arc) -> Socket { let close_fn = Box::new(move |_, _| ()); let s = Socket::new( sid, @@ -837,7 +841,7 @@ mod test { #[tokio::test] async fn send_with_ack_error() { let sid = Sid::new(); - let ns = Namespace::::new_dummy([sid]).into(); + let ns = Namespace::new_dummy([sid]).into(); let socket: Arc = Socket::new_dummy(sid, ns).into(); // Saturate the channel for _ in 0..1024 { diff --git a/socketioxide/tests/fixture.rs b/socketioxide/tests/fixture.rs index 04d23dc0..7d127a12 100644 --- a/socketioxide/tests/fixture.rs +++ b/socketioxide/tests/fixture.rs @@ -111,7 +111,7 @@ pub async fn create_server(port: u16) -> SocketIo { io } -async fn spawn_server(port: u16, svc: SocketIoService) { +async fn spawn_server(port: u16, svc: SocketIoService) { let addr = &SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); let listener = TcpListener::bind(&addr).await.unwrap(); tokio::spawn(async move {