Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to handle events that expect an ack #463

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions socketio/src/asynchronous/client/ack.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::Duration;

use crate::asynchronous::client::callback::Callback;
use crate::{asynchronous::client::callback::Callback, AckId};
use tokio::time::Instant;

use super::callback::DynAsyncCallback;
Expand All @@ -11,7 +11,7 @@ use super::callback::DynAsyncCallback;
/// won't contain data.
#[derive(Debug)]
pub(crate) struct Ack {
pub id: i32,
pub id: AckId,
pub timeout: Duration,
pub time_started: Instant,
pub callback: Callback<DynAsyncCallback>,
Expand Down
81 changes: 81 additions & 0 deletions socketio/src/asynchronous/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use super::{
client::{Client, ReconnectSettings},
};
use crate::asynchronous::socket::Socket as InnerSocket;
use crate::AckId;

/// A builder class for a `socket.io` socket. This handles setting up the client and
/// configuring the callback, the namespace and metadata of the socket. If no
Expand Down Expand Up @@ -190,6 +191,51 @@ impl ClientBuilder {
+ 'static
+ Send
+ Sync,
{
self.on.insert(
event.into(),
Callback::<DynAsyncCallback>::new_no_ack(callback),
);
self
}

/// Registers a new callback for a certain [`crate::event::Event`] that expects the client to
/// ack. The event could either be one of the common events like `message`, `error`, `open`,
/// `close` or a custom event defined by a string, e.g. `onPayment` or `foo`.
///
/// # Example
/// ```rust
/// use rust_socketio::{asynchronous::{ClientBuilder, Client}, AckId, Payload};
/// use futures_util::FutureExt;
///
/// #[tokio::main]
/// async fn main() {
/// let socket = ClientBuilder::new("http://localhost:4200/")
/// .namespace("/admin")
/// .on("test", |payload: Payload, client: Client, ack: AckId| {
/// async move {
/// match payload {
/// Payload::Text(values) => println!("Received: {:#?}", values),
/// Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data),
/// // This is deprecated, use Payload::Text instead
/// Payload::String(str) => println!("Received: {}", str),
/// }
/// client.ack(ack, "received").await;
/// }
/// .boxed()
/// })
/// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
/// .connect()
/// .await;
/// }
///
#[cfg(feature = "async-callbacks")]
pub fn on_with_ack<T: Into<Event>, F>(mut self, event: T, callback: F) -> Self
where
F: for<'a> std::ops::FnMut(Payload, Client, AckId) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
{
self.on
.insert(event.into(), Callback::<DynAsyncCallback>::new(callback));
Expand Down Expand Up @@ -257,6 +303,41 @@ impl ClientBuilder {
pub fn on_any<F>(mut self, callback: F) -> Self
where
F: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync,
{
self.on_any = Some(Callback::<DynAsyncAnyCallback>::new_no_ack(callback));
self
}

/// Registers a Callback for all [`crate::event::Event::Custom`] and
/// [`crate::event::Event::Message`] that expect the client to ack.
///
/// # Example
/// ```rust
/// use rust_socketio::{asynchronous::ClientBuilder, Payload};
/// use futures_util::future::FutureExt;
///
/// #[tokio::main]
/// async fn main() {
/// let client = ClientBuilder::new("http://localhost:4200/")
/// .namespace("/admin")
/// .on_any(|event, payload, client, ack| {
/// async {
/// if let Payload::String(str) = payload {
/// println!("{}: {}", String::from(event), str);
/// }
/// client.ack(ack, "received").await;
/// }.boxed()
/// })
/// .connect()
/// .await;
/// }
/// ```
pub fn on_any_with_ack<F>(mut self, callback: F) -> Self
where
F: for<'a> FnMut(Event, Payload, Client, AckId) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
{
self.on_any = Some(Callback::<DynAsyncAnyCallback>::new(callback));
self
Expand Down
71 changes: 56 additions & 15 deletions socketio/src/asynchronous/client/callback.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
use futures_util::future::BoxFuture;
use futures_util::{future::BoxFuture, FutureExt};
use std::{
fmt::Debug,
future::Future,
ops::{Deref, DerefMut},
};

use crate::{Event, Payload};
use crate::{AckId, Event, Payload};

use super::client::{Client, ReconnectSettings};

/// Internal type, provides a way to store futures and return them in a boxed manner.
pub(crate) type DynAsyncCallback =
Box<dyn for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync>;
pub(crate) type DynAsyncCallback = Box<
dyn for<'a> FnMut(Payload, Client, Option<AckId>) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
>;

pub(crate) type DynAsyncAnyCallback = Box<
dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync,
dyn for<'a> FnMut(Event, Payload, Client, Option<AckId>) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
>;

pub(crate) type DynAsyncReconnectSettingsCallback =
Expand All @@ -30,8 +38,10 @@ impl<T> Debug for Callback<T> {
}

impl Deref for Callback<DynAsyncCallback> {
type Target =
dyn for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send;
type Target = dyn for<'a> FnMut(Payload, Client, Option<AckId>) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send;

fn deref(&self) -> &Self::Target {
self.inner.as_ref()
Expand All @@ -45,19 +55,34 @@ impl DerefMut for Callback<DynAsyncCallback> {
}

impl Callback<DynAsyncCallback> {
pub(crate) fn new<T>(callback: T) -> Self
pub(crate) fn new<T>(mut callback: T) -> Self
where
T: for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send,
T: for<'a> FnMut(Payload, Client, AckId) -> BoxFuture<'static, ()> + 'static + Sync + Send,
{
Callback {
inner: Box::new(callback),
inner: Box::new(move |p, c, a| match a {
Some(a) => callback(p, c, a).boxed(),
None => std::future::ready(()).boxed(),
}),
}
}

pub(crate) fn new_no_ack<T, Fut>(mut callback: T) -> Self
where
T: FnMut(Payload, Client) -> Fut + Sync + Send + 'static,
Fut: Future<Output = ()> + 'static + Send,
{
Callback {
inner: Box::new(move |p, c, _a| callback(p, c).boxed()),
}
}
}

impl Deref for Callback<DynAsyncAnyCallback> {
type Target =
dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send;
type Target = dyn for<'a> FnMut(Event, Payload, Client, Option<AckId>) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send;

fn deref(&self) -> &Self::Target {
self.inner.as_ref()
Expand All @@ -71,12 +96,28 @@ impl DerefMut for Callback<DynAsyncAnyCallback> {
}

impl Callback<DynAsyncAnyCallback> {
pub(crate) fn new<T>(callback: T) -> Self
pub(crate) fn new<T>(mut callback: T) -> Self
where
T: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send,
T: for<'a> FnMut(Event, Payload, Client, AckId) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send,
{
Callback {
inner: Box::new(callback),
inner: Box::new(move |e, p, c, a| match a {
Some(a) => callback(e, p, c, a).boxed(),
None => std::future::ready(()).boxed(),
}),
}
}

pub(crate) fn new_no_ack<T, Fut>(mut callback: T) -> Self
where
T: FnMut(Event, Payload, Client) -> Fut + Sync + Send + 'static,
Fut: Future<Output = ()> + 'static + Send,
{
Callback {
inner: Box::new(move |e, p, c, _a| callback(e, p, c).boxed()),
}
}
}
Expand Down
51 changes: 36 additions & 15 deletions socketio/src/asynchronous/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
asynchronous::socket::Socket as InnerSocket,
error::{Error, Result},
packet::{Packet, PacketId},
Event, Payload,
AckId, Event, Payload,
};

#[derive(Default)]
Expand Down Expand Up @@ -359,15 +359,15 @@ impl Client {
E: Into<Event>,
D: Into<Payload>,
{
let id = thread_rng().gen_range(0..999);
let id = AckId::new(thread_rng().gen_range(0..999));
let socket_packet =
Packet::new_from_payload(data.into(), event.into(), &self.nsp, Some(id))?;

let ack = Ack {
id,
time_started: Instant::now(),
timeout,
callback: Callback::<DynAsyncCallback>::new(callback),
callback: Callback::<DynAsyncCallback>::new_no_ack(callback),
};

// add the ack to the tuple of outstanding acks
Expand All @@ -376,19 +376,33 @@ impl Client {
self.socket.read().await.send(socket_packet).await
}

async fn callback<P: Into<Payload>>(&self, event: &Event, payload: P) -> Result<()> {
pub async fn ack<D>(&self, ack_id: AckId, data: D) -> Result<()>
where
D: Into<Payload>,
{
let socket_packet = Packet::new_ack(data.into(), &self.nsp, ack_id);

self.socket.read().await.send(socket_packet).await
}

async fn callback<P: Into<Payload>>(
&self,
event: &Event,
payload: P,
ack_id: Option<AckId>,
) -> Result<()> {
let mut builder = self.builder.write().await;
let payload = payload.into();

if let Some(callback) = builder.on.get_mut(event) {
callback(payload.clone(), self.clone()).await;
callback(payload.clone(), self.clone(), ack_id).await;
}

// Call on_any for all common and custom events.
match event {
Event::Message | Event::Custom(_) => {
if let Some(callback) = builder.on_any.as_mut() {
callback(event.clone(), payload, self.clone()).await;
callback(event.clone(), payload, self.clone(), ack_id).await;
}
}
_ => (),
Expand All @@ -411,6 +425,7 @@ impl Client {
ack.callback.deref_mut()(
Payload::from(payload.to_owned()),
self.clone(),
None,
)
.await;
}
Expand All @@ -419,6 +434,7 @@ impl Client {
ack.callback.deref_mut()(
Payload::Binary(payload.to_owned()),
self.clone(),
None,
)
.await;
}
Expand Down Expand Up @@ -446,8 +462,12 @@ impl Client {

if let Some(attachments) = &packet.attachments {
if let Some(binary_payload) = attachments.get(0) {
self.callback(&event, Payload::Binary(binary_payload.to_owned()))
.await?;
self.callback(
&event,
Payload::Binary(binary_payload.to_owned()),
packet.id,
)
.await?;
}
}
Ok(())
Expand Down Expand Up @@ -480,7 +500,7 @@ impl Client {
};

// call the correct callback
self.callback(&event, payloads.to_vec()).await?;
self.callback(&event, payloads.to_vec(), packet.id).await?;
}

Ok(())
Expand All @@ -495,22 +515,22 @@ impl Client {
match packet.packet_type {
PacketId::Ack | PacketId::BinaryAck => {
if let Err(err) = self.handle_ack(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
return Err(err);
}
}
PacketId::BinaryEvent => {
if let Err(err) = self.handle_binary_event(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
}
}
PacketId::Connect => {
*(self.disconnect_reason.write().await) = DisconnectReason::default();
self.callback(&Event::Connect, "").await?;
self.callback(&Event::Connect, "", None).await?;
}
PacketId::Disconnect => {
*(self.disconnect_reason.write().await) = DisconnectReason::Server;
self.callback(&Event::Close, "").await?;
self.callback(&Event::Close, "", None).await?;
}
PacketId::ConnectError => {
self.callback(
Expand All @@ -520,12 +540,13 @@ impl Client {
.data
.as_ref()
.unwrap_or(&String::from("\"No error message provided\"")),
None,
)
.await?;
}
PacketId::Event => {
if let Err(err) = self.handle_event(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
}
}
}
Expand All @@ -547,7 +568,7 @@ impl Client {
None => None,
Some(Err(err)) => {
// call the error callback
match self.callback(&Event::Error, err.to_string()).await {
match self.callback(&Event::Error, err.to_string(), None).await {
Err(callback_err) => Some((Err(callback_err), socket)),
Ok(_) => Some((Err(err), socket)),
}
Expand Down
Loading
Loading