From 3748c5b148ae00f82feeb89ea9812cb783f35111 Mon Sep 17 00:00:00 2001 From: mendess Date: Mon, 16 Sep 2024 19:05:38 +0100 Subject: [PATCH 1/2] Add AckId --- socketio/src/asynchronous/client/ack.rs | 4 +-- socketio/src/asynchronous/client/client.rs | 4 +-- socketio/src/client/raw_client.rs | 6 ++-- socketio/src/lib.rs | 4 ++- socketio/src/packet.rs | 39 ++++++++++++---------- socketio/src/payload.rs | 25 ++++++++++++++ 6 files changed, 57 insertions(+), 25 deletions(-) diff --git a/socketio/src/asynchronous/client/ack.rs b/socketio/src/asynchronous/client/ack.rs index ef43a4bc..2250b6db 100644 --- a/socketio/src/asynchronous/client/ack.rs +++ b/socketio/src/asynchronous/client/ack.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use crate::asynchronous::client::callback::Callback; +use crate::{asynchronous::client::callback::Callback, AckId}; use tokio::time::Instant; use super::callback::DynAsyncCallback; @@ -11,7 +11,7 @@ use super::callback::DynAsyncCallback; /// won't contain data. #[derive(Debug)] pub(crate) struct Ack { - pub id: i32, + pub id: AckId, pub timeout: Duration, pub time_started: Instant, pub callback: Callback, diff --git a/socketio/src/asynchronous/client/client.rs b/socketio/src/asynchronous/client/client.rs index 67feb7db..72725263 100644 --- a/socketio/src/asynchronous/client/client.rs +++ b/socketio/src/asynchronous/client/client.rs @@ -19,7 +19,7 @@ use crate::{ asynchronous::socket::Socket as InnerSocket, error::{Error, Result}, packet::{Packet, PacketId}, - Event, Payload, + AckId, Event, Payload, }; #[derive(Default)] @@ -359,7 +359,7 @@ impl Client { E: Into, D: Into, { - let id = thread_rng().gen_range(0..999); + let id = AckId::new(thread_rng().gen_range(0..999)); let socket_packet = Packet::new_from_payload(data.into(), event.into(), &self.nsp, Some(id))?; diff --git a/socketio/src/client/raw_client.rs b/socketio/src/client/raw_client.rs index 0686683f..9fbb5ab3 100644 --- a/socketio/src/client/raw_client.rs +++ b/socketio/src/client/raw_client.rs @@ -1,7 +1,7 @@ use super::callback::Callback; use crate::packet::{Packet, PacketId}; -use crate::Error; pub(crate) use crate::{event::Event, payload::Payload}; +use crate::{AckId, Error}; use rand::{thread_rng, Rng}; use serde_json::Value; @@ -21,7 +21,7 @@ use crate::socket::Socket as InnerSocket; /// won't contain data. #[derive(Debug)] pub struct Ack { - pub id: i32, + pub id: AckId, timeout: Duration, time_started: Instant, callback: Callback, @@ -203,7 +203,7 @@ impl RawClient { E: Into, D: Into, { - let id = thread_rng().gen_range(0..999); + let id = AckId::new(thread_rng().gen_range(0..999)); let socket_packet = Packet::new_from_payload(data.into(), event.into(), &self.nsp, Some(id))?; diff --git a/socketio/src/lib.rs b/socketio/src/lib.rs index b913eb4d..080a7a0e 100644 --- a/socketio/src/lib.rs +++ b/socketio/src/lib.rs @@ -176,7 +176,7 @@ pub(crate) mod packet; /// Defines the types of payload (binary or string), that /// could be sent or received. pub mod payload; -pub(self) mod socket; +mod socket; /// Deprecated import since 0.3.0-alpha-2, use Error in the crate root instead. /// Contains the error type which will be returned with every result in this @@ -195,6 +195,8 @@ pub use {event::Event, payload::Payload}; pub use client::{ClientBuilder, RawClient, TransportType}; +pub use payload::AckId; + // TODO: 0.4.0 remove #[deprecated(since = "0.3.0-alpha-2", note = "Socket renamed to Client")] pub use client::{ClientBuilder as SocketBuilder, RawClient as Socket}; diff --git a/socketio/src/packet.rs b/socketio/src/packet.rs index e74dedb5..48cc6312 100644 --- a/socketio/src/packet.rs +++ b/socketio/src/packet.rs @@ -1,5 +1,5 @@ use crate::error::{Error, Result}; -use crate::{Event, Payload}; +use crate::{AckId, Event, Payload}; use bytes::Bytes; use serde::de::IgnoredAny; @@ -25,7 +25,7 @@ pub struct Packet { pub packet_type: PacketId, pub nsp: String, pub data: Option, - pub id: Option, + pub id: Option, pub attachment_count: u8, pub attachments: Option>, } @@ -38,7 +38,7 @@ impl Packet { payload: Payload, event: Event, nsp: &'a str, - id: Option, + id: Option, ) -> Result { match payload { Payload::Binary(bin_data) => Ok(Packet::new( @@ -132,7 +132,7 @@ impl Packet { packet_type: PacketId, nsp: String, data: Option, - id: Option, + id: Option, attachment_count: u8, attachments: Option>, ) -> Self { @@ -360,7 +360,7 @@ mod test { PacketId::Event, "/admin".to_owned(), Some(String::from("[\"project:delete\",123]")), - Some(456), + Some(AckId::new(10)), 0, None, ), @@ -376,7 +376,7 @@ mod test { PacketId::Ack, "/admin".to_owned(), Some(String::from("[]")), - Some(456), + Some(AckId::new(10)), 0, None, ), @@ -426,7 +426,7 @@ mod test { PacketId::BinaryEvent, "/admin".to_owned(), Some(String::from("\"project:delete\"")), - Some(456), + Some(AckId::new(10)), 1, None, ), @@ -442,7 +442,7 @@ mod test { PacketId::BinaryAck, "/admin".to_owned(), None, - Some(456), + Some(AckId::new(10)), 1, None, ), @@ -511,7 +511,7 @@ mod test { PacketId::Event, "/admin".to_owned(), Some(String::from("[\"project:delete\",123]")), - Some(456), + Some(AckId::new(10)), 0, None, ); @@ -527,7 +527,7 @@ mod test { PacketId::Ack, "/admin".to_owned(), Some(String::from("[]")), - Some(456), + Some(AckId::new(10)), 0, None, ); @@ -573,7 +573,7 @@ mod test { PacketId::BinaryEvent, "/admin".to_owned(), Some(String::from("\"project:delete\"")), - Some(456), + Some(AckId::new(10)), 1, Some(vec![Bytes::from_static(&[1, 2, 3])]), ); @@ -589,7 +589,7 @@ mod test { PacketId::BinaryAck, "/admin".to_owned(), None, - Some(456), + Some(AckId::new(10)), 1, Some(vec![Bytes::from_static(&[3, 2, 1])]), ); @@ -635,7 +635,7 @@ mod test { payload.clone(), "other_event".into(), "other_namespace", - Some(10), + Some(AckId::new(10)), ) .unwrap(); assert_eq!( @@ -644,7 +644,7 @@ mod test { packet_type: PacketId::Event, nsp: "other_namespace".to_owned(), data: Some("[\"other_event\",\"test\"]".to_owned()), - id: Some(10), + id: Some(AckId::new(10)), attachment_count: 0, attachments: None } @@ -657,15 +657,20 @@ mod test { serde_json::json!("String test"), serde_json::json!({"type":"object"}), ]); - let result = - Packet::new_from_payload(payload.clone(), "third_event".into(), "/", Some(10)).unwrap(); + let result = Packet::new_from_payload( + payload.clone(), + "third_event".into(), + "/", + Some(AckId::new(10)), + ) + .unwrap(); assert_eq!( result, Packet { packet_type: PacketId::Event, nsp: "/".to_owned(), data: Some("[\"third_event\",\"String test\",{\"type\":\"object\"}]".to_owned()), - id: Some(10), + id: Some(AckId::new(10)), attachment_count: 0, attachments: None } diff --git a/socketio/src/payload.rs b/socketio/src/payload.rs index 2fde2c1b..e56b966e 100644 --- a/socketio/src/payload.rs +++ b/socketio/src/payload.rs @@ -1,3 +1,5 @@ +use std::{fmt, str::FromStr}; + use bytes::Bytes; /// A type which represents a `payload` in the `socket.io` context. @@ -72,6 +74,29 @@ impl From for Payload { } } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +pub struct AckId(i32); + +impl AckId { + pub(crate) const fn new(id: i32) -> Self { + Self(id) + } +} + +impl fmt::Display for AckId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl FromStr for AckId { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + s.parse().map(Self) + } +} + #[cfg(test)] mod tests { use serde_json::json; From b5935fbef0497d1ea55f175d681e0a2071cfff9e Mon Sep 17 00:00:00 2001 From: mendess Date: Mon, 16 Sep 2024 20:17:29 +0100 Subject: [PATCH 2/2] Add ackable events --- socketio/src/asynchronous/client/builder.rs | 81 ++++++++++++++++++++ socketio/src/asynchronous/client/callback.rs | 71 +++++++++++++---- socketio/src/asynchronous/client/client.rs | 47 ++++++++---- socketio/src/client/builder.rs | 67 +++++++++++++++- socketio/src/client/callback.rs | 45 ++++++++--- socketio/src/client/raw_client.rs | 52 +++++++++---- socketio/src/packet.rs | 39 +++++++++- 7 files changed, 348 insertions(+), 54 deletions(-) diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 44710e19..9b457b7a 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -17,6 +17,7 @@ use super::{ client::{Client, ReconnectSettings}, }; use crate::asynchronous::socket::Socket as InnerSocket; +use crate::AckId; /// A builder class for a `socket.io` socket. This handles setting up the client and /// configuring the callback, the namespace and metadata of the socket. If no @@ -190,6 +191,51 @@ impl ClientBuilder { + 'static + Send + Sync, + { + self.on.insert( + event.into(), + Callback::::new_no_ack(callback), + ); + self + } + + /// Registers a new callback for a certain [`crate::event::Event`] that expects the client to + /// ack. The event could either be one of the common events like `message`, `error`, `open`, + /// `close` or a custom event defined by a string, e.g. `onPayment` or `foo`. + /// + /// # Example + /// ```rust + /// use rust_socketio::{asynchronous::{ClientBuilder, Client}, AckId, Payload}; + /// use futures_util::FutureExt; + /// + /// #[tokio::main] + /// async fn main() { + /// let socket = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on("test", |payload: Payload, client: Client, ack: AckId| { + /// async move { + /// match payload { + /// Payload::Text(values) => println!("Received: {:#?}", values), + /// Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data), + /// // This is deprecated, use Payload::Text instead + /// Payload::String(str) => println!("Received: {}", str), + /// } + /// client.ack(ack, "received").await; + /// } + /// .boxed() + /// }) + /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed()) + /// .connect() + /// .await; + /// } + /// + #[cfg(feature = "async-callbacks")] + pub fn on_with_ack, F>(mut self, event: T, callback: F) -> Self + where + F: for<'a> std::ops::FnMut(Payload, Client, AckId) -> BoxFuture<'static, ()> + + 'static + + Send + + Sync, { self.on .insert(event.into(), Callback::::new(callback)); @@ -257,6 +303,41 @@ impl ClientBuilder { pub fn on_any(mut self, callback: F) -> Self where F: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync, + { + self.on_any = Some(Callback::::new_no_ack(callback)); + self + } + + /// Registers a Callback for all [`crate::event::Event::Custom`] and + /// [`crate::event::Event::Message`] that expect the client to ack. + /// + /// # Example + /// ```rust + /// use rust_socketio::{asynchronous::ClientBuilder, Payload}; + /// use futures_util::future::FutureExt; + /// + /// #[tokio::main] + /// async fn main() { + /// let client = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on_any(|event, payload, client, ack| { + /// async { + /// if let Payload::String(str) = payload { + /// println!("{}: {}", String::from(event), str); + /// } + /// client.ack(ack, "received").await; + /// }.boxed() + /// }) + /// .connect() + /// .await; + /// } + /// ``` + pub fn on_any_with_ack(mut self, callback: F) -> Self + where + F: for<'a> FnMut(Event, Payload, Client, AckId) -> BoxFuture<'static, ()> + + 'static + + Send + + Sync, { self.on_any = Some(Callback::::new(callback)); self diff --git a/socketio/src/asynchronous/client/callback.rs b/socketio/src/asynchronous/client/callback.rs index 3188b175..47f027a0 100644 --- a/socketio/src/asynchronous/client/callback.rs +++ b/socketio/src/asynchronous/client/callback.rs @@ -1,19 +1,27 @@ -use futures_util::future::BoxFuture; +use futures_util::{future::BoxFuture, FutureExt}; use std::{ fmt::Debug, + future::Future, ops::{Deref, DerefMut}, }; -use crate::{Event, Payload}; +use crate::{AckId, Event, Payload}; use super::client::{Client, ReconnectSettings}; /// Internal type, provides a way to store futures and return them in a boxed manner. -pub(crate) type DynAsyncCallback = - Box FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync>; +pub(crate) type DynAsyncCallback = Box< + dyn for<'a> FnMut(Payload, Client, Option) -> BoxFuture<'static, ()> + + 'static + + Send + + Sync, +>; pub(crate) type DynAsyncAnyCallback = Box< - dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync, + dyn for<'a> FnMut(Event, Payload, Client, Option) -> BoxFuture<'static, ()> + + 'static + + Send + + Sync, >; pub(crate) type DynAsyncReconnectSettingsCallback = @@ -30,8 +38,10 @@ impl Debug for Callback { } impl Deref for Callback { - type Target = - dyn for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send; + type Target = dyn for<'a> FnMut(Payload, Client, Option) -> BoxFuture<'static, ()> + + 'static + + Sync + + Send; fn deref(&self) -> &Self::Target { self.inner.as_ref() @@ -45,19 +55,34 @@ impl DerefMut for Callback { } impl Callback { - pub(crate) fn new(callback: T) -> Self + pub(crate) fn new(mut callback: T) -> Self where - T: for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send, + T: for<'a> FnMut(Payload, Client, AckId) -> BoxFuture<'static, ()> + 'static + Sync + Send, { Callback { - inner: Box::new(callback), + inner: Box::new(move |p, c, a| match a { + Some(a) => callback(p, c, a).boxed(), + None => std::future::ready(()).boxed(), + }), + } + } + + pub(crate) fn new_no_ack(mut callback: T) -> Self + where + T: FnMut(Payload, Client) -> Fut + Sync + Send + 'static, + Fut: Future + 'static + Send, + { + Callback { + inner: Box::new(move |p, c, _a| callback(p, c).boxed()), } } } impl Deref for Callback { - type Target = - dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send; + type Target = dyn for<'a> FnMut(Event, Payload, Client, Option) -> BoxFuture<'static, ()> + + 'static + + Sync + + Send; fn deref(&self) -> &Self::Target { self.inner.as_ref() @@ -71,12 +96,28 @@ impl DerefMut for Callback { } impl Callback { - pub(crate) fn new(callback: T) -> Self + pub(crate) fn new(mut callback: T) -> Self where - T: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send, + T: for<'a> FnMut(Event, Payload, Client, AckId) -> BoxFuture<'static, ()> + + 'static + + Sync + + Send, { Callback { - inner: Box::new(callback), + inner: Box::new(move |e, p, c, a| match a { + Some(a) => callback(e, p, c, a).boxed(), + None => std::future::ready(()).boxed(), + }), + } + } + + pub(crate) fn new_no_ack(mut callback: T) -> Self + where + T: FnMut(Event, Payload, Client) -> Fut + Sync + Send + 'static, + Fut: Future + 'static + Send, + { + Callback { + inner: Box::new(move |e, p, c, _a| callback(e, p, c).boxed()), } } } diff --git a/socketio/src/asynchronous/client/client.rs b/socketio/src/asynchronous/client/client.rs index 72725263..01b44639 100644 --- a/socketio/src/asynchronous/client/client.rs +++ b/socketio/src/asynchronous/client/client.rs @@ -367,7 +367,7 @@ impl Client { id, time_started: Instant::now(), timeout, - callback: Callback::::new(callback), + callback: Callback::::new_no_ack(callback), }; // add the ack to the tuple of outstanding acks @@ -376,19 +376,33 @@ impl Client { self.socket.read().await.send(socket_packet).await } - async fn callback>(&self, event: &Event, payload: P) -> Result<()> { + pub async fn ack(&self, ack_id: AckId, data: D) -> Result<()> + where + D: Into, + { + let socket_packet = Packet::new_ack(data.into(), &self.nsp, ack_id); + + self.socket.read().await.send(socket_packet).await + } + + async fn callback>( + &self, + event: &Event, + payload: P, + ack_id: Option, + ) -> Result<()> { let mut builder = self.builder.write().await; let payload = payload.into(); if let Some(callback) = builder.on.get_mut(event) { - callback(payload.clone(), self.clone()).await; + callback(payload.clone(), self.clone(), ack_id).await; } // Call on_any for all common and custom events. match event { Event::Message | Event::Custom(_) => { if let Some(callback) = builder.on_any.as_mut() { - callback(event.clone(), payload, self.clone()).await; + callback(event.clone(), payload, self.clone(), ack_id).await; } } _ => (), @@ -411,6 +425,7 @@ impl Client { ack.callback.deref_mut()( Payload::from(payload.to_owned()), self.clone(), + None, ) .await; } @@ -419,6 +434,7 @@ impl Client { ack.callback.deref_mut()( Payload::Binary(payload.to_owned()), self.clone(), + None, ) .await; } @@ -446,8 +462,12 @@ impl Client { if let Some(attachments) = &packet.attachments { if let Some(binary_payload) = attachments.get(0) { - self.callback(&event, Payload::Binary(binary_payload.to_owned())) - .await?; + self.callback( + &event, + Payload::Binary(binary_payload.to_owned()), + packet.id, + ) + .await?; } } Ok(()) @@ -480,7 +500,7 @@ impl Client { }; // call the correct callback - self.callback(&event, payloads.to_vec()).await?; + self.callback(&event, payloads.to_vec(), packet.id).await?; } Ok(()) @@ -495,22 +515,22 @@ impl Client { match packet.packet_type { PacketId::Ack | PacketId::BinaryAck => { if let Err(err) = self.handle_ack(packet).await { - self.callback(&Event::Error, err.to_string()).await?; + self.callback(&Event::Error, err.to_string(), None).await?; return Err(err); } } PacketId::BinaryEvent => { if let Err(err) = self.handle_binary_event(packet).await { - self.callback(&Event::Error, err.to_string()).await?; + self.callback(&Event::Error, err.to_string(), None).await?; } } PacketId::Connect => { *(self.disconnect_reason.write().await) = DisconnectReason::default(); - self.callback(&Event::Connect, "").await?; + self.callback(&Event::Connect, "", None).await?; } PacketId::Disconnect => { *(self.disconnect_reason.write().await) = DisconnectReason::Server; - self.callback(&Event::Close, "").await?; + self.callback(&Event::Close, "", None).await?; } PacketId::ConnectError => { self.callback( @@ -520,12 +540,13 @@ impl Client { .data .as_ref() .unwrap_or(&String::from("\"No error message provided\"")), + None, ) .await?; } PacketId::Event => { if let Err(err) = self.handle_event(packet).await { - self.callback(&Event::Error, err.to_string()).await?; + self.callback(&Event::Error, err.to_string(), None).await?; } } } @@ -547,7 +568,7 @@ impl Client { None => None, Some(Err(err)) => { // call the error callback - match self.callback(&Event::Error, err.to_string()).await { + match self.callback(&Event::Error, err.to_string(), None).await { Err(callback_err) => Some((Err(callback_err), socket)), Ok(_) => Some((Err(err), socket)), } diff --git a/socketio/src/client/builder.rs b/socketio/src/client/builder.rs index 724971f0..7b6e53b9 100644 --- a/socketio/src/client/builder.rs +++ b/socketio/src/client/builder.rs @@ -1,7 +1,7 @@ use super::super::{event::Event, payload::Payload}; use super::callback::Callback; use super::client::Client; -use crate::RawClient; +use crate::{AckId, RawClient}; use native_tls::TlsConnector; use rust_engineio::client::ClientBuilder as EngineIoClientBuilder; use rust_engineio::header::{HeaderMap, HeaderValue}; @@ -173,6 +173,41 @@ impl ClientBuilder { pub fn on, F>(mut self, event: T, callback: F) -> Self where F: FnMut(Payload, RawClient) + 'static + Send, + { + let callback = Callback::::new_no_ack(callback); + // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held + self.on.lock().unwrap().insert(event.into(), callback); + self + } + + /// Registers a new callback for a certain [`crate::event::Event`] that expects the client to + /// ack. The event could either be one of the common events like `message`, `error`, `open`, + /// `close` or a custom event defined by a string, e.g. `onPayment` or `foo`. + /// + /// # Example + /// ```rust + /// use rust_socketio::{ClientBuilder, Payload}; + /// + /// let socket = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on("test", |payload: Payload, client, ack_id| { + /// match payload { + /// Payload::Text(values) => println!("Received: {:#?}", values), + /// Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data), + /// // This payload type is deprecated, use Payload::Text instead + /// Payload::String(str) => println!("Received: {}", str), + /// } + /// client.ack(ack_id, "received"); + /// }) + /// .on("error", |err, _| eprintln!("Error: {:#?}", err)) + /// .connect(); + /// + /// ``` + // While present implementation doesn't require mut, it's reasonable to require mutability. + #[allow(unused_mut)] + pub fn on_with_ack, F>(mut self, event: T, callback: F) -> Self + where + F: FnMut(Payload, RawClient, AckId) + 'static + Send, { let callback = Callback::::new(callback); // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held @@ -201,6 +236,36 @@ impl ClientBuilder { pub fn on_any(mut self, callback: F) -> Self where F: FnMut(Event, Payload, RawClient) + 'static + Send, + { + let callback = Some(Callback::::new_no_ack(callback)); + // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held + *self.on_any.lock().unwrap() = callback; + self + } + + /// Registers a Callback for all [`crate::event::Event::Custom`] and + /// [`crate::event::Event::Message`] that expects the client to ack. + /// + /// # Example + /// ```rust + /// use rust_socketio::{ClientBuilder, Payload}; + /// + /// let client = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on_any(|event, payload, client, ack_id| { + /// if let Payload::String(str) = payload { + /// println!("{} {}", String::from(event), str); + /// } + /// client.ack(ack_id, "received") + /// }) + /// .connect(); + /// + /// ``` + // While present implementation doesn't require mut, it's reasonable to require mutability. + #[allow(unused_mut)] + pub fn on_any_with_ack(mut self, callback: F) -> Self + where + F: FnMut(Event, Payload, RawClient, AckId) + 'static + Send, { let callback = Some(Callback::::new(callback)); // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held diff --git a/socketio/src/client/callback.rs b/socketio/src/client/callback.rs index 1015ec03..2a9d7a7b 100644 --- a/socketio/src/client/callback.rs +++ b/socketio/src/client/callback.rs @@ -4,10 +4,11 @@ use std::{ }; use super::RawClient; -use crate::{Event, Payload}; +use crate::{AckId, Event, Payload}; -pub(crate) type SocketCallback = Box; -pub(crate) type SocketAnyCallback = Box; +pub(crate) type SocketCallback = Box) + 'static + Send>; +pub(crate) type SocketAnyCallback = + Box) + 'static + Send>; pub(crate) struct Callback { inner: T, @@ -22,7 +23,7 @@ impl Debug for Callback { } impl Deref for Callback { - type Target = dyn FnMut(Payload, RawClient) + 'static + Send; + type Target = dyn FnMut(Payload, RawClient, Option) + 'static + Send; fn deref(&self) -> &Self::Target { self.inner.as_ref() @@ -36,12 +37,25 @@ impl DerefMut for Callback { } impl Callback { - pub(crate) fn new(callback: T) -> Self + pub(crate) fn new(mut callback: T) -> Self + where + T: FnMut(Payload, RawClient, AckId) + 'static + Send, + { + Callback { + inner: Box::new(move |p, c, a| { + if let Some(a) = a { + callback(p, c, a) + } + }), + } + } + + pub(crate) fn new_no_ack(mut callback: T) -> Self where T: FnMut(Payload, RawClient) + 'static + Send, { Callback { - inner: Box::new(callback), + inner: Box::new(move |p, c, _a| callback(p, c)), } } } @@ -55,7 +69,7 @@ impl Debug for Callback { } impl Deref for Callback { - type Target = dyn FnMut(Event, Payload, RawClient) + 'static + Send; + type Target = dyn FnMut(Event, Payload, RawClient, Option) + 'static + Send; fn deref(&self) -> &Self::Target { self.inner.as_ref() @@ -69,12 +83,25 @@ impl DerefMut for Callback { } impl Callback { - pub(crate) fn new(callback: T) -> Self + pub(crate) fn new(mut callback: T) -> Self + where + T: FnMut(Event, Payload, RawClient, AckId) + 'static + Send, + { + Callback { + inner: Box::new(move |e, p, c, a| { + if let Some(a) = a { + callback(e, p, c, a) + } + }), + } + } + + pub(crate) fn new_no_ack(mut callback: T) -> Self where T: FnMut(Event, Payload, RawClient) + 'static + Send, { Callback { - inner: Box::new(callback), + inner: Box::new(move |e, p, c, _a| callback(e, p, c)), } } } diff --git a/socketio/src/client/raw_client.rs b/socketio/src/client/raw_client.rs index 9fbb5ab3..65743299 100644 --- a/socketio/src/client/raw_client.rs +++ b/socketio/src/client/raw_client.rs @@ -149,7 +149,7 @@ impl RawClient { let _ = self.socket.send(disconnect_packet); self.socket.disconnect()?; - let _ = self.callback(&Event::Close, ""); // trigger on_close + let _ = self.callback(&Event::Close, "", None); // trigger on_close Ok(()) } @@ -211,7 +211,7 @@ impl RawClient { id, time_started: Instant::now(), timeout, - callback: Callback::::new(callback), + callback: Callback::::new_no_ack(callback), }; // add the ack to the tuple of outstanding acks @@ -221,11 +221,19 @@ impl RawClient { Ok(()) } + pub fn ack(&self, ack_id: AckId, data: D) -> Result<()> + where + D: Into, + { + let socket_packet = Packet::new_ack(data.into(), &self.nsp, ack_id); + self.socket.send(socket_packet) + } + pub(crate) fn poll(&self) -> Result> { loop { match self.socket.poll() { Err(err) => { - self.callback(&Event::Error, err.to_string())?; + self.callback(&Event::Error, err.to_string(), None)?; return Err(err); } Ok(Some(packet)) => { @@ -246,7 +254,12 @@ impl RawClient { Iter { socket: self } } - fn callback>(&self, event: &Event, payload: P) -> Result<()> { + fn callback>( + &self, + event: &Event, + payload: P, + ack_id: Option, + ) -> Result<()> { let mut on = self.on.lock()?; let mut on_any = self.on_any.lock()?; let lock = on.deref_mut(); @@ -255,12 +268,12 @@ impl RawClient { let payload = payload.into(); if let Some(callback) = lock.get_mut(event) { - callback(payload.clone(), self.clone()); + callback(payload.clone(), self.clone(), ack_id); } match event { Event::Message | Event::Custom(_) => { if let Some(callback) = on_any_lock { - callback(event.clone(), payload, self.clone()) + callback(event.clone(), payload, self.clone(), ack_id) } } _ => {} @@ -284,12 +297,16 @@ impl RawClient { if ack.time_started.elapsed() < ack.timeout { if let Some(ref payload) = socket_packet.data { - ack.callback.deref_mut()(Payload::from(payload.to_owned()), self.clone()); + ack.callback.deref_mut()(Payload::from(payload.to_owned()), self.clone(), None); } if let Some(ref attachments) = socket_packet.attachments { if let Some(payload) = attachments.first() { - ack.callback.deref_mut()(Payload::Binary(payload.to_owned()), self.clone()); + ack.callback.deref_mut()( + Payload::Binary(payload.to_owned()), + self.clone(), + None, + ); } } } @@ -312,7 +329,11 @@ impl RawClient { if let Some(attachments) = &packet.attachments { if let Some(binary_payload) = attachments.first() { - self.callback(&event, Payload::Binary(binary_payload.to_owned()))?; + self.callback( + &event, + Payload::Binary(binary_payload.to_owned()), + packet.id, + )?; } } Ok(()) @@ -344,7 +365,7 @@ impl RawClient { }; // call the correct callback - self.callback(&event, payloads.to_vec())?; + self.callback(&event, payloads.to_vec(), packet.id)?; } Ok(()) @@ -359,20 +380,20 @@ impl RawClient { match packet.packet_type { PacketId::Ack | PacketId::BinaryAck => { if let Err(err) = self.handle_ack(packet) { - self.callback(&Event::Error, err.to_string())?; + self.callback(&Event::Error, err.to_string(), None)?; return Err(err); } } PacketId::BinaryEvent => { if let Err(err) = self.handle_binary_event(packet) { - self.callback(&Event::Error, err.to_string())?; + self.callback(&Event::Error, err.to_string(), None)?; } } PacketId::Connect => { - self.callback(&Event::Connect, "")?; + self.callback(&Event::Connect, "", None)?; } PacketId::Disconnect => { - self.callback(&Event::Close, "")?; + self.callback(&Event::Close, "", None)?; } PacketId::ConnectError => { self.callback( @@ -382,11 +403,12 @@ impl RawClient { .clone() .data .unwrap_or_else(|| String::from("\"No error message provided\"")), + None, )?; } PacketId::Event => { if let Err(err) = self.handle_event(packet) { - self.callback(&Event::Error, err.to_string())?; + self.callback(&Event::Error, err.to_string(), None)?; } } } diff --git a/socketio/src/packet.rs b/socketio/src/packet.rs index 48cc6312..689a64bf 100644 --- a/socketio/src/packet.rs +++ b/socketio/src/packet.rs @@ -88,6 +88,43 @@ impl Packet { } } } + + pub(crate) fn new_ack(payload: Payload, nsp: &str, id: AckId) -> Self { + match payload { + Payload::Text(data) => Packet::new( + PacketId::Ack, + nsp.to_owned(), + Some(serde_json::Value::Array(data).to_string()), + Some(id), + 0, + None, + ), + #[allow(deprecated)] + Payload::String(str_data) => { + let payload = if serde_json::from_str::(&str_data).is_ok() { + format!("[{str_data}]") + } else { + format!("[{str_data:?}]") + }; + Packet::new( + PacketId::Ack, + nsp.to_owned(), + Some(payload), + Some(id), + 0, + None, + ) + } + Payload::Binary(data) => Packet::new( + PacketId::BinaryAck, + nsp.to_owned(), + None, + Some(id), + 1, + Some(vec![data]), + ), + } + } } impl Default for Packet { @@ -605,7 +642,7 @@ mod test { #[test] fn test_illegal_packet_id() { let _sut = PacketId::try_from(42).expect_err("error!"); - assert!(matches!(Error::InvalidPacketId(42 as char), _sut)) + assert!(matches!(Error::InvalidPacketId(42u8 as char), _sut)) } #[test]