diff --git a/src/io.rs b/src/io.rs index 746ddf8..317af9e 100644 --- a/src/io.rs +++ b/src/io.rs @@ -24,6 +24,8 @@ //! It should also be kept around even if no requests are sent, as dropping it is used to signal the //! [`IoCore`] to close the connection. +mod wait_queue; + use std::{ collections::VecDeque, fmt::{self, Display, Formatter}, @@ -49,12 +51,13 @@ use tokio::{ use crate::{ header::Header, protocol::{ - payload_is_multi_frame, CompletedRead, FrameIter, JulietProtocol, LocalProtocolViolation, - OutgoingFrame, OutgoingMessage, ProtocolBuilder, + CompletedRead, FrameIter, JulietProtocol, LocalProtocolViolation, OutgoingFrame, + OutgoingMessage, ProtocolBuilder, }, util::PayloadFormat, ChannelId, Id, Outcome, }; +use wait_queue::{PushOutcome, WaitQueue}; /// Maximum number of bytes to pre-allocate in buffers. const MAX_ALLOC: usize = 32 * 1024; // 32 KiB @@ -269,7 +272,7 @@ pub struct IoCore { /// Frames waiting to be sent. ready_queue: VecDeque, /// Messages that are not yet ready to be sent. - wait_queue: [VecDeque; N], + wait_queue: [WaitQueue; N], /// Receiver for new messages to be queued. receiver: UnboundedReceiver, /// Mapping for outgoing requests, mapping internal IDs to public ones. @@ -559,27 +562,29 @@ where let header_sent = frame_sent.header(); // If we finished the active multi frame send, clear it. + let mut cleared_multi_frame = false; if was_final { let channel_idx = header_sent.channel().get() as usize; if let Some(ref active_multi_frame) = self.active_multi_frame[channel_idx] { if header_sent == *active_multi_frame { self.active_multi_frame[channel_idx] = None; + cleared_multi_frame = true; } } - } + }; if header_sent.is_error() { // We finished sending an error frame, time to exit. return Err(CoreError::RemoteProtocolViolation(header_sent)); } - // TODO: We should restrict the dirty-queue processing here a little bit - // (only check when completing a multi-frame message). // A message has completed sending, process the wait queue in case we have // to start sending a multi-frame message like a response that was delayed // only because of the one-multi-frame-per-channel restriction. - self.process_wait_queue(header_sent.channel())?; + if cleared_multi_frame { + self.process_wait_queue(header_sent.channel())?; + } } else { #[cfg(feature = "tracing")] tracing::error!("current frame should not disappear"); @@ -719,17 +724,34 @@ where /// Handles a new item to send out that arrived through the incoming channel. fn handle_incoming_item(&mut self, item: QueuedItem) -> Result<(), LocalProtocolViolation> { - // Check if the item is sendable immediately. - if let Some(channel) = item_should_wait(&item, &self.juliet, &self.active_multi_frame)? { - #[cfg(feature = "tracing")] - tracing::debug!(%item, "postponing send"); - self.wait_queue[channel.get() as usize].push_back(item); - return Ok(()); - } + let channel = match &item { + QueuedItem::Request { channel, .. } | QueuedItem::Response { channel, .. } => *channel, + QueuedItem::RequestCancellation { .. } + | QueuedItem::ResponseCancellation { .. } + | QueuedItem::Error { .. } => { + // These variants always get send immediately. + #[cfg(feature = "tracing")] + tracing::debug!(%item, "ready to send"); + return self.send_to_ready_queue(item); + } + }; - #[cfg(feature = "tracing")] - tracing::debug!(%item, "ready to send"); - self.send_to_ready_queue(item) + // Process the wait queue to avoid this new item "jumping the queue". + self.process_wait_queue(channel)?; + + // Add the item to the wait queue, or send if the wait queue returns the item. + match self.wait_queue[channel.get() as usize].try_push_back( + item, + &self.juliet, + &self.active_multi_frame, + )? { + PushOutcome::Pushed => Ok(()), + PushOutcome::NotPushed(ready_item) => { + #[cfg(feature = "tracing")] + tracing::debug!(item = %ready_item, "ready to send"); + self.send_to_ready_queue(ready_item) + } + } } /// Sends an item directly to the ready queue, causing it to be sent out eventually. @@ -745,6 +767,7 @@ where let id = msg.header().id(); self.request_map.insert(io_id, (channel, id)); if msg.is_multi_frame(self.juliet.max_frame_size()) { + debug_assert!(self.active_multi_frame[channel.get() as usize].is_none()); self.active_multi_frame[channel.get() as usize] = Some(msg.header()); } self.ready_queue.push_back(msg.frames()); @@ -771,6 +794,7 @@ where } => { if let Some(msg) = self.juliet.create_response(channel, id, payload)? { if msg.is_multi_frame(self.juliet.max_frame_size()) { + debug_assert!(self.active_multi_frame[channel.get() as usize].is_none()); self.active_multi_frame[channel.get() as usize] = Some(msg.header()); } self.ready_queue.push_back(msg.frames()) @@ -827,77 +851,17 @@ where /// Process the wait queue of a given channel, promoting messages that are ready to be sent. fn process_wait_queue(&mut self, channel: ChannelId) -> Result<(), LocalProtocolViolation> { - let mut remaining = self.wait_queue[channel.get() as usize].len(); - - while let Some(item) = self.wait_queue[channel.get() as usize].pop_front() { - if item_should_wait(&item, &self.juliet, &self.active_multi_frame)?.is_some() { - // Put it right back into the queue. - self.wait_queue[channel.get() as usize].push_back(item); - } else { - self.send_to_ready_queue(item)?; - - // No need to look further if we have saturated the channel. - if !self.juliet.allowed_to_send_request(channel)? { - break; - } - } - - // Ensure we do not loop endlessly if we cannot find anything. - remaining -= 1; - if remaining == 0 { - break; - } + while let Some(item) = self.wait_queue[channel.get() as usize].next_item( + channel, + &self.juliet, + &self.active_multi_frame, + )? { + self.send_to_ready_queue(item)?; } - Ok(()) } } -/// Determines whether an item is ready to be moved from the wait queue to the ready queue. -/// -/// Returns `None` if the item does not need to wait. Otherwise, the item's channel ID is returned. -fn item_should_wait( - item: &QueuedItem, - juliet: &JulietProtocol, - active_multi_frame: &[Option
; N], -) -> Result, LocalProtocolViolation> { - let (payload, channel) = match item { - QueuedItem::Request { - channel, payload, .. - } => { - // Check if we cannot schedule due to the message exceeding the request limit. - if !juliet.allowed_to_send_request(*channel)? { - return Ok(Some(*channel)); - } - - (payload, channel) - } - QueuedItem::Response { - channel, payload, .. - } => (payload, channel), - - // Other messages are always ready. - QueuedItem::RequestCancellation { .. } - | QueuedItem::ResponseCancellation { .. } - | QueuedItem::Error { .. } => return Ok(None), - }; - - let active_multi_frame = active_multi_frame[channel.get() as usize]; - - // Check if we cannot schedule due to the message being multi-frame and there being a - // multi-frame send in progress: - if active_multi_frame.is_some() { - if let Some(payload) = payload { - if payload_is_multi_frame(juliet.max_frame_size(), payload.len()) { - return Ok(Some(*channel)); - } - } - } - - // Otherwise, this should be a legitimate add to the run queue. - Ok(None) -} - /// A handle to the input queue to the [`IoCore`] that allows sending requests and responses. /// /// The handle is roughly three pointers in size and can be cloned at will. Dropping the last handle diff --git a/src/io/wait_queue.rs b/src/io/wait_queue.rs new file mode 100644 index 0000000..30e1fc0 --- /dev/null +++ b/src/io/wait_queue.rs @@ -0,0 +1,371 @@ +use std::{cmp, collections::VecDeque}; + +use bytes::Bytes; +use tokio::sync::OwnedSemaphorePermit; +#[cfg(feature = "tracing")] +use tracing::debug; + +use super::{IoId, QueuedItem}; +use crate::{ + header::Header, + protocol::{payload_is_multi_frame, JulietProtocol, LocalProtocolViolation}, + ChannelId, Id, +}; + +/// A single-frame request in the wait queue, converted from a `QueuedItem`, pending conversion +/// back to a `QueuedItem` when being moved to the ready queue. +#[derive(Debug)] +struct SingleFrameRequest { + wait_index: u64, + channel: ChannelId, + io_id: IoId, + payload: Option, + permit: OwnedSemaphorePermit, +} + +impl From for QueuedItem { + fn from(sf_req: SingleFrameRequest) -> Self { + QueuedItem::Request { + channel: sf_req.channel, + io_id: sf_req.io_id, + payload: sf_req.payload, + permit: sf_req.permit, + } + } +} + +/// A multi-frame request in the wait queue, converted from a `QueuedItem`, pending conversion +/// back to a `QueuedItem` when being moved to the ready queue. +#[derive(Debug)] +struct MultiFrameRequest { + wait_index: u64, + channel: ChannelId, + io_id: IoId, + payload: Option, + permit: OwnedSemaphorePermit, +} + +impl From for QueuedItem { + fn from(mf_req: MultiFrameRequest) -> Self { + QueuedItem::Request { + channel: mf_req.channel, + io_id: mf_req.io_id, + payload: mf_req.payload, + permit: mf_req.permit, + } + } +} + +/// A multi-frame response in the wait queue, converted from a `QueuedItem`, pending conversion +/// back to a `QueuedItem` when being moved to the ready queue. +#[derive(Debug)] +struct MultiFrameResponse { + wait_index: u64, + channel: ChannelId, + id: Id, + payload: Option, +} + +impl From for QueuedItem { + fn from(mf_resp: MultiFrameResponse) -> Self { + QueuedItem::Response { + channel: mf_resp.channel, + id: mf_resp.id, + payload: mf_resp.payload, + } + } +} + +/// The outcome of trying to push a new item to the wait queue. +pub(super) enum PushOutcome { + Pushed, + NotPushed(QueuedItem), +} + +/// The wait queue: an ordered collection of items waiting for the ready queue to become available +/// for new items. +#[derive(Default, Debug)] +pub(super) struct WaitQueue { + single_frame_requests: VecDeque, + multi_frame_requests: VecDeque, + multi_frame_responses: VecDeque, +} + +impl WaitQueue { + /// Add the given item to the back of the wait queue. + pub(super) fn try_push_back( + &mut self, + item: QueuedItem, + juliet: &JulietProtocol, + active_multi_frame: &[Option
; N], + ) -> Result { + match item { + QueuedItem::Request { + channel, + io_id, + payload, + permit, + } => { + if !juliet.allowed_to_send_request(channel)? { + #[cfg(feature = "tracing")] + if payload_is_multi_frame( + juliet.max_frame_size(), + payload + .as_ref() + .map(|payld| payld.len()) + .unwrap_or_default(), + ) { + debug!(%channel, %io_id, "multi-frame request postponed: channel full"); + } else { + debug!(%channel, %io_id, "single-frame request postponed: channel full"); + } + self.push_request(channel, io_id, payload, permit, juliet); + return Ok(PushOutcome::Pushed); + } + + if self.has_active_multi_frame(channel, &payload, juliet, active_multi_frame) { + #[cfg(feature = "tracing")] + debug!(%channel, %io_id, "multi-frame request postponed: other in progress"); + self.push_multi_frame_request(channel, io_id, payload, permit); + return Ok(PushOutcome::Pushed); + } + + // We don't need to wait - rebuild the item and return it. + let item = QueuedItem::Request { + channel, + io_id, + payload, + permit, + }; + Ok(PushOutcome::NotPushed(item)) + } + QueuedItem::Response { + channel, + id, + payload, + } => { + if self.has_active_multi_frame(channel, &payload, juliet, active_multi_frame) { + #[cfg(feature = "tracing")] + debug!(%channel, %id, "multi-frame response postponed: other in progress"); + self.push_response(channel, id, payload); + return Ok(PushOutcome::Pushed); + } + + // We don't need to wait - rebuild the item and return it. + let item = QueuedItem::Response { + channel, + id, + payload, + }; + Ok(PushOutcome::NotPushed(item)) + } + QueuedItem::RequestCancellation { .. } + | QueuedItem::ResponseCancellation { .. } + | QueuedItem::Error { .. } => Ok(PushOutcome::NotPushed(item)), + } + } + + /// Returns the wait index to assign to a new item being added to the wait queue. + fn next_wait_index(&self) -> u64 { + let mut current_max = 0; + + if let Some(index) = self + .single_frame_requests + .back() + .map(|item| item.wait_index) + { + current_max = index; + } + + if let Some(index) = self.multi_frame_requests.back().map(|item| item.wait_index) { + current_max = cmp::max(current_max, index); + } + + if let Some(index) = self + .multi_frame_responses + .back() + .map(|item| item.wait_index) + { + current_max = cmp::max(current_max, index); + } + + current_max.wrapping_add(1) + } + + /// Pushes a request onto either the single-frame request queue or the multi-frame one. + fn push_request( + &mut self, + channel: ChannelId, + io_id: IoId, + payload: Option, + permit: OwnedSemaphorePermit, + juliet: &JulietProtocol, + ) { + if payload_is_multi_frame( + juliet.max_frame_size(), + payload + .as_ref() + .map(|payld| payld.len()) + .unwrap_or_default(), + ) { + self.push_multi_frame_request(channel, io_id, payload, permit); + } else { + let wait_index = self.next_wait_index(); + let sf_req = SingleFrameRequest { + wait_index, + channel, + io_id, + payload, + permit, + }; + self.single_frame_requests.push_back(sf_req); + } + } + + /// Pushes a request onto the multi-frame request queue. + fn push_multi_frame_request( + &mut self, + channel: ChannelId, + io_id: IoId, + payload: Option, + permit: OwnedSemaphorePermit, + ) { + let wait_index = self.next_wait_index(); + let mf_req = MultiFrameRequest { + wait_index, + channel, + io_id, + payload, + permit, + }; + self.multi_frame_requests.push_back(mf_req); + } + + /// Pushes a response onto the multi-frame response queue. + fn push_response(&mut self, channel: ChannelId, id: Id, payload: Option) { + let wait_index = self.next_wait_index(); + let mf_resp = MultiFrameResponse { + wait_index, + channel, + id, + payload, + }; + self.multi_frame_responses.push_back(mf_resp); + } + + /// Checks if we cannot schedule due to the message being multi-frame and there being a + /// multi-frame send in progress. + fn has_active_multi_frame( + &self, + channel: ChannelId, + payload: &Option, + juliet: &JulietProtocol, + active_multi_frame: &[Option
; N], + ) -> bool { + if active_multi_frame[channel.get() as usize].is_none() { + return false; + }; + + if let Some(payload) = payload { + return payload_is_multi_frame(juliet.max_frame_size(), payload.len()); + } + + false + } + + /// Removes and returns the next "sendable" item (i.e. the oldest item) from the wait queue, + /// where "sendable" is decided based on whether requests and/or multi-frame messages are + /// allowed. + pub(super) fn next_item( + &mut self, + channel: ChannelId, + juliet: &JulietProtocol, + active_multi_frame: &[Option
; N], + ) -> Result, LocalProtocolViolation> { + let request_allowed = juliet.allowed_to_send_request(channel)?; + let multi_frame_allowed = active_multi_frame[channel.get() as usize].is_none(); + let maybe_item = match (request_allowed, multi_frame_allowed) { + (false, false) => None, + (false, true) => self.multi_frame_responses.pop_front().map(QueuedItem::from), + (true, false) => self.single_frame_requests.pop_front().map(QueuedItem::from), + (true, true) => { + #[derive(Clone, Copy)] + enum QueueToPop { + SingleFrameRequests(u64), + MultiFrameRequests(u64), + MultiFrameResponses(u64), + } + + impl QueueToPop { + fn index(&self) -> u64 { + match self { + QueueToPop::SingleFrameRequests(index) + | QueueToPop::MultiFrameRequests(index) + | QueueToPop::MultiFrameResponses(index) => *index, + } + } + } + + let mut queue_to_pop: Option = None; + + if let Some(wait_index) = self + .single_frame_requests + .front() + .map(|sf_req| sf_req.wait_index) + { + queue_to_pop = Some(QueueToPop::SingleFrameRequests(wait_index)) + } + + if let Some(wait_index) = self + .multi_frame_requests + .front() + .map(|mf_req| mf_req.wait_index) + { + match queue_to_pop { + Some(qtp) => { + if wait_index < qtp.index() { + queue_to_pop = Some(QueueToPop::MultiFrameRequests(wait_index)); + } + } + None => queue_to_pop = Some(QueueToPop::MultiFrameRequests(wait_index)), + }; + } + + if let Some(index) = self + .multi_frame_responses + .front() + .map(|mf_resp| mf_resp.wait_index) + { + match queue_to_pop { + Some(qtp) => { + if index < qtp.index() { + queue_to_pop = Some(QueueToPop::MultiFrameResponses(index)); + } + } + None => queue_to_pop = Some(QueueToPop::MultiFrameResponses(index)), + }; + } + + match queue_to_pop { + Some(QueueToPop::SingleFrameRequests(_)) => { + self.single_frame_requests.pop_front().map(QueuedItem::from) + } + Some(QueueToPop::MultiFrameRequests(_)) => { + self.multi_frame_requests.pop_front().map(QueuedItem::from) + } + Some(QueueToPop::MultiFrameResponses(_)) => { + self.multi_frame_responses.pop_front().map(QueuedItem::from) + } + None => None, + } + } + }; + Ok(maybe_item) + } + + pub(super) fn clear(&mut self) { + self.single_frame_requests.clear(); + self.multi_frame_requests.clear(); + self.multi_frame_responses.clear(); + } +} diff --git a/src/protocol.rs b/src/protocol.rs index 82145a5..759ddc1 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -305,6 +305,11 @@ impl Channel { self.outgoing_requests.len() < self.config.request_limit as usize } + /// Returns the configured request limit for this channel. + pub fn request_limit(&self) -> u16 { + self.config.request_limit + } + /// Creates a new request, bypassing all client-side checks. /// /// Low-level function that does nothing but create a syntactically correct request and track @@ -474,7 +479,7 @@ impl Display for CompletedRead { } } -/// The caller of the this crate has violated the protocol. +/// The caller of this crate has violated the protocol. /// /// A correct implementation of a client should never encounter this, thus simply unwrapping every /// instance of this as part of a `Result<_, LocalProtocolViolation>` is usually a valid choice. @@ -487,24 +492,45 @@ pub enum LocalProtocolViolation { /// /// Wait for additional requests to be cancelled or answered. Calling /// [`JulietProtocol::allowed_to_send_request()`] beforehand is recommended. - #[error("sending would exceed request limit")] - WouldExceedRequestLimit, + #[error("sending would exceed request limit of {limit}")] + WouldExceedRequestLimit { + /// The configured limit for requests on the channel. + limit: u16, + }, /// The channel given does not exist. /// /// The given [`ChannelId`] exceeds `N` of [`JulietProtocol`]. - #[error("invalid channel")] - InvalidChannel(ChannelId), + #[error("channel {channel} not a member of configured {channel_count} channels")] + InvalidChannel { + /// The provided channel ID. + channel: ChannelId, + /// The configured number of channels. + channel_count: usize, + }, /// The given payload exceeds the configured limit. /// /// See [`ChannelConfiguration::with_max_request_payload_size()`] and /// [`ChannelConfiguration::with_max_response_payload_size()`] for details. - #[error("payload exceeds configured limit")] - PayloadExceedsLimit, + #[error("payload length of {payload_length} bytes exceeds configured limit of {limit}")] + PayloadExceedsLimit { + /// The payload length in bytes. + payload_length: usize, + /// The configured upper limit for payload length in bytes. + limit: usize, + }, /// The given error payload exceeds a single frame. /// /// Error payloads may not span multiple frames, shorten the payload or increase frame size. - #[error("error payload would be multi-frame")] - ErrorPayloadIsMultiFrame, + #[error( + "error payload of {payload_length} bytes exceeds a single frame with configured max size \ + of {max_frame_size})" + )] + ErrorPayloadIsMultiFrame { + /// The payload length in bytes. + payload_length: usize, + /// The configured maximum frame size in bytes. + max_frame_size: u32, + }, } macro_rules! log_frame { @@ -534,7 +560,10 @@ impl JulietProtocol { #[inline(always)] const fn lookup_channel(&self, channel: ChannelId) -> Result<&Channel, LocalProtocolViolation> { if channel.0 as usize >= N { - Err(LocalProtocolViolation::InvalidChannel(channel)) + Err(LocalProtocolViolation::InvalidChannel { + channel, + channel_count: N, + }) } else { Ok(&self.channels[channel.0 as usize]) } @@ -549,7 +578,10 @@ impl JulietProtocol { channel: ChannelId, ) -> Result<&mut Channel, LocalProtocolViolation> { if channel.0 as usize >= N { - Err(LocalProtocolViolation::InvalidChannel(channel)) + Err(LocalProtocolViolation::InvalidChannel { + channel, + channel_count: N, + }) } else { Ok(&mut self.channels[channel.0 as usize]) } @@ -595,12 +627,17 @@ impl JulietProtocol { if let Some(ref payload) = payload { if payload.len() > chan.config.max_request_payload_size as usize { - return Err(LocalProtocolViolation::PayloadExceedsLimit); + return Err(LocalProtocolViolation::PayloadExceedsLimit { + payload_length: payload.len(), + limit: chan.config.max_request_payload_size as usize, + }); } } if !chan.allowed_to_send_request() { - return Err(LocalProtocolViolation::WouldExceedRequestLimit); + return Err(LocalProtocolViolation::WouldExceedRequestLimit { + limit: chan.request_limit(), + }); } Ok(chan.create_unchecked_request(channel, payload)) @@ -637,7 +674,10 @@ impl JulietProtocol { if let Some(ref payload) = payload { if payload.len() > chan.config.max_response_payload_size as usize { - return Err(LocalProtocolViolation::PayloadExceedsLimit); + return Err(LocalProtocolViolation::PayloadExceedsLimit { + payload_length: payload.len(), + limit: chan.config.max_request_payload_size as usize, + }); } } @@ -712,11 +752,15 @@ impl JulietProtocol { id: Id, payload: Bytes, ) -> Result { - let header = Header::new_error(header::ErrorKind::Other, channel, id); + let header = Header::new_error(ErrorKind::Other, channel, id); + let payload_length = payload.len(); let msg = OutgoingMessage::new(header, Some(payload)); if msg.is_multi_frame(self.max_frame_size) { - Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame) + Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame { + payload_length, + max_frame_size: self.max_frame_size.0, + }) } else { Ok(msg) } @@ -1253,7 +1297,8 @@ mod tests { #[test] fn test_channel_lookups_work() { - let mut protocol: JulietProtocol<3> = ProtocolBuilder::new().build(); + const CHANNEL_COUNT: usize = 3; + let mut protocol: JulietProtocol = ProtocolBuilder::new().build(); // We mark channels by inserting an ID into them, that way we can ensure we're not getting // back the same channel every time. @@ -1274,15 +1319,24 @@ mod tests { .insert(Id::new(102)); assert!(matches!( protocol.lookup_channel_mut(ChannelId(3)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(3))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(3), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel_mut(ChannelId(4)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(4))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(4), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel_mut(ChannelId(255)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(255))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(255), + channel_count: CHANNEL_COUNT + }) )); // Now look up the channels and ensure they contain the right values @@ -1309,15 +1363,24 @@ mod tests { ); assert!(matches!( protocol.lookup_channel(ChannelId(3)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(3))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(3), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel(ChannelId(4)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(4))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(4), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel(ChannelId(255)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(255))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(255), + channel_count: CHANNEL_COUNT + }) )); } @@ -1442,7 +1505,10 @@ mod tests { // Try an invalid channel, should result in an error. assert!(matches!( protocol.create_request(ChannelId::new(2), payload.get()), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(2))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(2), + channel_count: 2 + }) )); assert!(protocol @@ -1454,7 +1520,7 @@ mod tests { assert!(matches!( protocol.create_request(channel, payload.get()), - Err(LocalProtocolViolation::WouldExceedRequestLimit) + Err(LocalProtocolViolation::WouldExceedRequestLimit { limit: 1 }) )); } } @@ -2202,7 +2268,10 @@ mod tests { .create_request(env.common_channel, payload.get()) .expect_err("should not be able to create too large request"); - assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit); + assert_matches!( + violation, + LocalProtocolViolation::PayloadExceedsLimit { .. } + ); // If we force the issue, Bob must refuse it instead. let bob_result = env.inject_and_send_request(Alice, payload.get()); @@ -2219,7 +2288,10 @@ mod tests { .bob .create_request(env.common_channel, payload.get()) .expect_err("should not be able to create too large response"); - assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit); + assert_matches!( + violation, + LocalProtocolViolation::PayloadExceedsLimit { .. } + ); // If we force the issue, Alice must refuse it. let alice_result = env.inject_and_send_response(Bob, id, payload.get()); diff --git a/src/protocol/multiframe.rs b/src/protocol/multiframe.rs index de6a913..75c908a 100644 --- a/src/protocol/multiframe.rs +++ b/src/protocol/multiframe.rs @@ -44,6 +44,7 @@ pub(super) enum MultiframeReceiver { /// The outcome of a multiframe acceptance. #[derive(Debug)] +#[allow(clippy::enum_variant_names)] pub(crate) enum CompletedFrame { /// A new multi-frame transfer was started. NewMultiFrame, diff --git a/src/rpc.rs b/src/rpc.rs index 1bd80a5..4f3e645 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -766,7 +766,7 @@ impl IncomingRequest { // Do nothing, just discard the response. } EnqueueError::BufferLimitHit(_) => { - // TODO: Add seperate type to avoid this. + // TODO: Add separate type to avoid this. unreachable!("cannot hit request limit when responding") } } @@ -851,7 +851,10 @@ mod tests { use bytes::Bytes; use futures::FutureExt; - use tokio::io::{DuplexStream, ReadHalf, WriteHalf}; + use tokio::{ + io::{DuplexStream, ReadHalf, WriteHalf}, + sync::mpsc, + }; use tracing::{error_span, info, span, Instrument, Level}; use crate::{ @@ -923,7 +926,7 @@ mod tests { async fn run_echo_client( mut rpc_server: JulietRpcServer, WriteHalf>, ) { - while let Some(inc) = rpc_server + if let Some(inc) = rpc_server .next_request() .await .expect("client rpc_server error") @@ -1330,7 +1333,7 @@ mod tests { large_volume_test::<1>(spec).await; } - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn run_large_volume_test_with_default_values_10_channels() { tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) @@ -1352,7 +1355,7 @@ mod tests { let (mut alice, mut bob) = LargeVolumeTestSpec::::default().mk_rpc(); - // Alice server. Will close the connection after enough bytes have been sent. + // Alice server. Will close the connection after enough bytes have been received. let mut remaining = spec.min_send_bytes; let alice_server = tokio::spawn( async move { @@ -1371,6 +1374,7 @@ mod tests { request.respond(None); remaining = remaining.saturating_sub(payload_size); + tracing::debug!("payload_size: {payload_size}, remaining: {remaining}"); if remaining == 0 { // We've reached the volume we were looking for, end test. break; @@ -1420,14 +1424,18 @@ mod tests { Err(guard) => { // Not ready, but we are not going to wait. - tokio::spawn(async move { - if let Err(err) = guard.wait_for_response().await { - match err { - RequestError::RemoteClosed(_) | RequestError::Shutdown => {} - err => panic!("{}", err), + tokio::spawn( + async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) + | RequestError::Shutdown => {} + err => panic!("{}", err), + } } } - }); + .in_current_span(), + ); } } } @@ -1437,10 +1445,11 @@ mod tests { .instrument(error_span!("alice_client")), ); - // Bob server. + // A channel to allow Bob's server to notify Bob's client to send a new request to Alice. + let (notify_tx, mut notify_rx) = mpsc::unbounded_channel(); + // Bob server. Will shut down once Alice closes the connection. let bob_server = tokio::spawn( async move { - let mut bob_counter = 0; while let Some(request) = bob .server .next_request() @@ -1459,7 +1468,19 @@ mod tests { let channel = request.channel(); // Just discard the message payload, but acknowledge receiving it. request.respond(None); + // Notify Bob client to send a new request to Alice. + notify_tx.send(channel).unwrap(); + } + info!("exiting"); + } + .instrument(error_span!("bob_server")), + ); + // Bob client. Will shut down once Alice closes the connection. + let bob_client = tokio::spawn( + async move { + let mut bob_counter = 0; + while let Some(channel) = notify_rx.recv().await { let payload_size = spec.gen_payload_size(bob_counter); let large_payload: Bytes = iter::repeat(0xFF) .take(payload_size) @@ -1470,11 +1491,11 @@ mod tests { let bobs_request: RequestGuard = bob .client .create_request(channel) - .with_payload(large_payload.clone()) + .with_payload(large_payload) .queue_for_sending() .await; - info!(bob_counter, "bob enqueued request"); + info!(bob_counter, payload_size, "bob enqueued request"); bob_counter += 1; match bobs_request.try_get_response() { @@ -1492,26 +1513,30 @@ mod tests { Err(guard) => { // Do not wait, instead attempt to retrieve next request. - tokio::spawn(async move { - if let Err(err) = guard.wait_for_response().await { - match err { - RequestError::RemoteClosed(_) | RequestError::Shutdown => {} - err => panic!("{}", err), + tokio::spawn( + async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) + | RequestError::Shutdown => {} + err => panic!("{}", err), + } } } - }); + .in_current_span(), + ); } } } - info!("exiting"); } - .instrument(error_span!("bob_server")), + .instrument(error_span!("bob_client")), ); alice_server.await.expect("failed to join alice server"); alice_client.await.expect("failed to join alice client"); bob_server.await.expect("failed to join bob server"); + bob_client.await.expect("failed to join bob client"); info!("all joined"); } @@ -1551,7 +1576,7 @@ mod tests { let mut bob = CompleteSetup::new(&rpc_builder, bob_stream); let alice_join_handle = tokio::spawn(async move { - while let Some(incoming_request) = alice + if let Some(incoming_request) = alice .server .next_request() .await