diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index 0a96e44..50bcc8a 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -46,56 +46,6 @@ impl MqttClient { max_packet_size, } } - - /// This function is only here for you to use during testing of for example your handler - /// For a simple client look at [`MqttClient::test_client`] - #[cfg(feature = "test")] - pub fn test_custom_client(available_packet_ids_r: Receiver, to_network_s: Sender, max_packet_size: usize) -> Self { - Self { - available_packet_ids_r, - to_network_s, - max_packet_size, - } - } - - /// This function is only here for you to use during testing of for example your handler - /// For control over the input of this type look at [`MqttClient::test_custom_client`] - /// - /// The returned values should not be dropped otherwise the client won't be able to operate normally. - /// - /// # Example - /// ```ignore - /// let ( - /// client, // An instance of this client - /// ids, // Allows you to indicate which packet IDs have become available again. - /// network_receiver // Messages send through the `client` will be dispatched through this channel - /// ) = MqttClient::test_client(); - /// - /// // perform testing - /// - /// // Make sure to not drop these before the test is done! - /// std::hint::black_box((ids, network_receiver)); - /// ``` - #[cfg(feature = "test")] - pub fn test_client() -> (Self, crate::available_packet_ids::AvailablePacketIds, Receiver) { - use async_channel::unbounded; - - use crate::{available_packet_ids::AvailablePacketIds, util::constants::MAXIMUM_PACKET_SIZE}; - - let (available_packet_ids, available_packet_ids_r) = AvailablePacketIds::new(u16::MAX); - - let (s, r) = unbounded(); - - ( - Self { - available_packet_ids_r, - to_network_s: s, - max_packet_size: MAXIMUM_PACKET_SIZE as usize, - }, - available_packet_ids, - r, - ) - } } /// Async functions to perform MQTT operations @@ -106,7 +56,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -140,6 +90,8 @@ impl MqttClient { sub.validate(self.max_packet_size)?; self.to_network_s.send(Packet::Subscribe(sub)).await.map_err(|_| ClientError::NoNetworkChannel)?; + #[cfg(feature = "logs")] + info!("Send to network: Subscribe with ID {:?}", pkid); Ok(()) } @@ -150,7 +102,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -210,7 +162,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -264,7 +216,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -334,7 +286,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// // Unsubscribe from a single topic specified as a string: @@ -381,7 +333,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::UnsubscribeProperties; @@ -450,7 +402,7 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// mqtt_client.disconnect().await.unwrap(); @@ -476,7 +428,7 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::DisconnectProperties; @@ -512,7 +464,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// use mqrstt::packets::QoS; /// use mqrstt::packets::{SubscriptionOptions, RetainHandling}; @@ -556,7 +508,7 @@ impl MqttClient { /// This function blocks until the packet is queued for transmission /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// use mqrstt::packets::QoS; /// use mqrstt::packets::{SubscribeProperties, SubscriptionOptions, RetainHandling}; @@ -616,7 +568,7 @@ impl MqttClient { /// This function blocks until the packet is queued for transmission /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -672,7 +624,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -742,7 +694,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// // Unsubscribe from a single topic specified as a string: @@ -790,7 +742,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::UnsubscribeProperties; @@ -851,7 +803,7 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// mqtt_client.disconnect_blocking().unwrap(); @@ -877,7 +829,7 @@ impl MqttClient { /// /// ``` /// - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::DisconnectProperties; diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index ccd133c..9cfd25d 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -44,7 +44,7 @@ //! // To reconnect after a disconnect or error //! let (mut network, client) = NetworkBuilder //! ::new_from_client_id("mqrsttSmolExample") -//! .smol_sequential_network(); +//! .smol_network(); //! let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) //! .await //! .unwrap(); @@ -236,9 +236,9 @@ where /// ``` /// let (mut network, client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream> /// ::new_from_client_id("ExampleClient") - /// .smol_sequential_network(); + /// .smol_network(); /// ``` - pub fn smol_sequential_network(self) -> (smol::Network, MqttClient) { + pub fn smol_network(self) -> (smol::Network, MqttClient) { let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); let (apkids, apkids_r) = available_packet_ids::AvailablePacketIds::new(self.options.send_maximum()); @@ -253,38 +253,6 @@ where } } -#[cfg(feature = "todo")] -/// Creates a new [`sync::Network`] and [`MqttClient`] that can be connected to a broker. -/// S should implement [`std::io::Read`] and [`std::io::Write`]. -/// Additionally, S should be made non_blocking otherwise it will not progress. -/// -/// # Example -/// -/// ``` -/// use mqrstt::ConnectOptions; -/// -/// let options = ConnectOptions::new("ExampleClient"); -/// let (network, client) = mqrstt::new_sync::(options); -/// ``` -pub fn new_sync(options: ConnectOptions) -> (sync::Network, MqttClient) -where - S: std::io::Read + std::io::Write + Sized + Unpin, -{ - use available_packet_ids::AvailablePacketIds; - - let (to_network_s, to_network_r) = async_channel::bounded(100); - - let (apkids, apkids_r) = AvailablePacketIds::new(options.send_maximum()); - - let max_packet_size = options.maximum_packet_size(); - - let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); - - let network = sync::Network::new(options, to_network_r, apkids); - - (network, client) -} - #[cfg(test)] fn random_chars() -> String { rand::Rng::sample_iter(rand::thread_rng(), &rand::distributions::Alphanumeric).take(7).map(char::from).collect() @@ -310,7 +278,7 @@ mod smol_lib_test { let address = "broker.emqx.io"; let port = 1883; - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingpong = PingPong::new(client.clone()); @@ -347,7 +315,7 @@ mod smol_lib_test { let address = "broker.emqx.io"; let port = 1883; - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); @@ -393,7 +361,7 @@ mod smol_lib_test { let (n, _) = futures::join!( async { - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); network.connect(stream, &mut pingresp).await diff --git a/mqrstt/src/state_handler.rs b/mqrstt/src/state_handler.rs index d0eeeb9..ecf7fd6 100644 --- a/mqrstt/src/state_handler.rs +++ b/mqrstt/src/state_handler.rs @@ -1,7 +1,6 @@ use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::HandlerError; -use crate::packets::{PubRecReasonCode, PubAckReasonCode, ConnAckReasonCode}; use crate::packets::PubComp; use crate::packets::PubRec; use crate::packets::PubRel; @@ -12,6 +11,7 @@ use crate::packets::Subscribe; use crate::packets::UnsubAck; use crate::packets::Unsubscribe; use crate::packets::{ConnAck, Disconnect}; +use crate::packets::{ConnAckReasonCode, PubAckReasonCode, PubRecReasonCode}; use crate::packets::{Packet, PacketType}; use crate::packets::{PubAck, PubAckProperties}; use crate::state::State; @@ -188,6 +188,7 @@ impl StateHandler { _a => { #[cfg(test)] unreachable!("Was given unexpected packet {:?} ", _a); + #[cfg(not(test))] Ok(()) } } @@ -247,13 +248,8 @@ mod handler_tests { use crate::{ available_packet_ids::AvailablePacketIds, packets::{ - Packet, - PubComp, PubCompProperties, PubCompReasonCode, - PubRec, PubRecProperties, PubRecReasonCode, - PubRel, PubRelProperties, PubRelReasonCode, - QoS, - SubAck, SubAckProperties, SubAckReasonCode, - UnsubAck, UnsubAckProperties, UnsubAckReasonCode + Packet, PubComp, PubCompProperties, PubCompReasonCode, PubRec, PubRecProperties, PubRecReasonCode, PubRel, PubRelProperties, PubRelReasonCode, QoS, SubAck, SubAckProperties, + SubAckReasonCode, UnsubAck, UnsubAckProperties, UnsubAckReasonCode, }, tests::test_packets::{create_connack_packet, create_puback_packet, create_publish_packet, create_subscribe_packet, create_unsubscribe_packet}, ConnectOptions, StateHandler, diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index 84efd01..3caa0e4 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -13,7 +13,7 @@ use crate::packets::{Disconnect, Packet, PacketType}; use crate::{AsyncEventHandler, NetworkStatus, StateHandler}; -use super::stream::Stream; +use super::stream::StreamExt; /// [`Network`] reads and writes to the network based on tokios [`::tokio::io::AsyncReadExt`] [`::tokio::io::AsyncWriteExt`]. /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. @@ -21,7 +21,7 @@ use super::stream::Stream; /// (i.e. you need to reconnect after any expected or unexpected disconnect). pub struct Network { handler: PhantomData, - network: Option>, + network: Option, /// Options of the current mqtt connection options: ConnectOptions, @@ -55,8 +55,8 @@ where S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { /// Initializes an MQTT connection with the provided configuration an stream - pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { - let (mut network, conn_ack) = Stream::connect(&self.options, stream).await?; + pub async fn connect(&mut self, mut stream: S, handler: &mut H) -> Result<(), ConnectionError> { + let conn_ack = stream.connect(&self.options).await?; self.last_network_action = Instant::now(); if let Some(keep_alive_interval) = conn_ack.connack_properties.server_keep_alive { @@ -68,12 +68,12 @@ where let packets = self.state_handler.handle_incoming_connack(&conn_ack)?; handler.handle(Packet::ConnAck(conn_ack)).await; - if let Some(mut packets) = packets { - network.write_all(&mut packets).await?; + if let Some(packets) = packets { + stream.write_packets(&packets).await?; self.last_network_action = Instant::now(); } - self.network = Some(network); + self.network = Some(stream); Ok(()) } @@ -117,7 +117,6 @@ where } = self; let mut await_pingresp = None; - // let mut outgoing_packet_buffer = Vec::new(); loop { let sleep; @@ -129,7 +128,10 @@ where if let Some(stream) = network { tokio::select! { - res = stream.read() => { + res = stream.read_packet() => { + #[cfg(feature = "logs")] + tracing::trace!("Received incoming packet {:?}", &res); + let packet = res?; match packet{ Packet::PingResp => { @@ -145,12 +147,12 @@ where (maybe_reply_packet, true) => { handler.handle(packet).await; if let Some(reply_packet) = maybe_reply_packet { - stream.write(&reply_packet).await?; + stream.write_packet(&reply_packet).await?; *last_network_action = Instant::now(); } }, (Some(reply_packet), false) => { - stream.write(&reply_packet).await?; + stream.write_packet(&reply_packet).await?; *last_network_action = Instant::now(); }, (None, false) => (), @@ -159,8 +161,15 @@ where } }, outgoing = to_network_r.recv() => { + #[cfg(feature = "logs")] + tracing::trace!("Received outgoing item {:?}", &outgoing); + let packet = outgoing?; - stream.write(&packet).await?; + + #[cfg(feature = "logs")] + tracing::trace!("Sending packet {}", packet); + + stream.write_packet(&packet).await?; let disconnect = packet.packet_type() == PacketType::Disconnect; state_handler.handle_outgoing_packet(packet)?; @@ -173,13 +182,13 @@ where }, _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { let packet = Packet::PingReq; - stream.write(&packet).await?; + stream.write_packet(&packet).await?; *last_network_action = Instant::now(); await_pingresp = Some(Instant::now()); }, _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; - stream.write(&Packet::Disconnect(disconnect)).await?; + stream.write_packet(&Packet::Disconnect(disconnect)).await?; return Ok(NetworkStatus::KeepAliveTimeout); } } @@ -188,4 +197,86 @@ where } } } + + // async fn concurrent_tokio_select(&mut self, handler: &mut H) -> Result { + // let Network { + // network, + // options, + // last_network_action, + // perform_keep_alive, + // to_network_r, + // handler: _, + // state_handler, + // } = self; + + // let mut await_pingresp = None; + + // loop { + // let sleep; + // if let Some(instant) = await_pingresp { + // sleep = instant + options.get_keep_alive_interval() - Instant::now(); + // } else { + // sleep = *last_network_action + options.get_keep_alive_interval() - Instant::now(); + // } + + // if let Some(stream) = network { + // tokio::select! { + // res = stream.read_packet() => { + // let packet = res?; + // match packet{ + // Packet::PingResp => { + // handler.handle(packet).await; + // await_pingresp = None; + // }, + // Packet::Disconnect(_) => { + // handler.handle(packet).await; + // return Ok(NetworkStatus::IncomingDisconnect); + // } + // packet => { + // match state_handler.handle_incoming_packet(&packet)? { + // (maybe_reply_packet, true) => { + // handler.handle(packet).await; + // if let Some(reply_packet) = maybe_reply_packet { + // stream.write_packet(&reply_packet).await?; + // *last_network_action = Instant::now(); + // } + // }, + // (Some(reply_packet), false) => { + // stream.write_packet(&reply_packet).await?; + // *last_network_action = Instant::now(); + // }, + // (None, false) => (), + // } + // } + // } + // }, + // outgoing = to_network_r.recv() => { + // let packet = outgoing?; + // stream.write_packet(&packet).await?; + // let disconnect = packet.packet_type() == PacketType::Disconnect; + + // state_handler.handle_outgoing_packet(packet)?; + // *last_network_action = Instant::now(); + + // if disconnect{ + // return Ok(NetworkStatus::OutgoingDisconnect); + // } + // }, + // _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { + // let packet = Packet::PingReq; + // stream.write_packet(&packet).await?; + // *last_network_action = Instant::now(); + // await_pingresp = Some(Instant::now()); + // }, + // _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { + // let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; + // stream.write_packet(&Packet::Disconnect(disconnect)).await?; + // return Ok(NetworkStatus::KeepAliveTimeout); + // } + // } + // } else { + // return Err(ConnectionError::NoNetwork); + // } + // } + // } } diff --git a/mqrstt/src/tokio/stream.rs b/mqrstt/src/tokio/stream.rs index 081d720..a5647e7 100644 --- a/mqrstt/src/tokio/stream.rs +++ b/mqrstt/src/tokio/stream.rs @@ -7,65 +7,76 @@ use crate::packets::ConnAck; use crate::packets::{ConnAckReasonCode, Packet}; use crate::{connect_options::ConnectOptions, error::ConnectionError}; -#[derive(Debug)] -pub struct Stream { - stream: S, +pub(crate) trait StreamExt { + fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future>; + fn read_packet(&mut self) -> impl std::future::Future>; + fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future>; + fn write_packets(&mut self, packets: &[Packet]) -> impl std::future::Future>; + fn flush_packets(&mut self) -> impl std::future::Future>; } -impl Stream +impl StreamExt for S where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, { - pub async fn connect(options: &ConnectOptions, stream: S) -> Result<(Self, ConnAck), ConnectionError> { - let mut s = Self { stream }; + fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future> { + async move { + let connect = options.create_connect_from_options(); - let connect = options.create_connect_from_options(); + self.write_packet(&connect).await?; - s.write(&connect).await?; - - let packet = Packet::async_read(&mut s.stream).await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - #[cfg(feature = "logs")] - trace!("Connected to server"); - Ok((s, con)) + let packet = Packet::async_read(self).await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + #[cfg(feature = "logs")] + trace!("Connected to server"); + Ok(con) + } else { + Err(ConnectionError::ConnectionRefused(con.reason_code)) + } } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) + Err(ConnectionError::NotConnAck(packet)) } - } else { - Err(ConnectionError::NotConnAck(packet)) } } - pub async fn read(&mut self) -> Result { - Ok(Packet::async_read(&mut self.stream).await?) + fn read_packet(&mut self) -> impl std::future::Future> { + async move { Ok(Packet::async_read(self).await?) } } - pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - match packet.async_write(&mut self.stream).await { - Ok(_) => (), - Err(err) => { - return match err { - crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), - crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), + fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future> { + async move { + match packet.async_write(self).await { + Ok(_) => (), + Err(err) => { + return match err { + crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), + crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), + } } } - } - self.stream.flush().await?; - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); + #[cfg(feature = "logs")] + trace!("Sending packet {}", packet); + + self.flush().await?; + // self.flush_packets().await?; - Ok(()) + Ok(()) + } } - pub async fn write_all(&mut self, packets: &mut Vec) -> Result<(), ConnectionError> { + async fn write_packets(&mut self, packets: &[Packet]) -> Result<(), ConnectionError> { for packet in packets { - let _ = packet.async_write(&mut self.stream).await; + let _ = packet.async_write(self).await; #[cfg(feature = "logs")] trace!("Sending packet {}", packet); } - self.stream.flush().await?; + self.flush_packets().await?; Ok(()) } + + fn flush_packets(&mut self) -> impl std::future::Future> { + tokio::io::AsyncWriteExt::flush(self) + } }