diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 17952394cd645..e0f7f74ab8f2d 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -1613,9 +1613,18 @@ impl ChannelClient { pub fn request( &self, payload: T, + ) -> impl 'static + Future> { + self.request_internal(payload, true) + } + + fn request_internal( + &self, + payload: T, + use_buffer: bool, ) -> impl 'static + Future> { log::debug!("ssh request start. name:{}", T::NAME); - let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME); + let response = + self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer); async move { let response = response.await?; log::debug!("ssh request finish. name:{}", T::NAME); @@ -1627,7 +1636,9 @@ impl ChannelClient { pub async fn resync(&self, timeout: Duration) -> Result<()> { smol::future::or( async { - self.request(proto::FlushBufferedMessages {}).await?; + self.request_internal(proto::FlushBufferedMessages {}, false) + .await?; + for envelope in self.buffer.lock().iter() { self.outgoing_tx .lock() @@ -1663,10 +1674,11 @@ impl ChannelClient { self.send_dynamic(payload.into_envelope(0, None, None)) } - pub fn request_dynamic( + fn request_dynamic( &self, mut envelope: proto::Envelope, type_name: &'static str, + use_buffer: bool, ) -> impl 'static + Future> { envelope.id = self.next_message_id.fetch_add(1, SeqCst); let (tx, rx) = oneshot::channel(); @@ -1674,7 +1686,11 @@ impl ChannelClient { response_channels_lock.insert(MessageId(envelope.id), tx); drop(response_channels_lock); - let result = self.send_buffered(envelope); + let result = if use_buffer { + self.send_buffered(envelope) + } else { + self.send_unbuffered(envelope) + }; async move { if let Err(error) = &result { log::error!("failed to send message: {}", error); @@ -1694,7 +1710,7 @@ impl ChannelClient { self.send_buffered(envelope) } - pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> { + fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> { envelope.ack_id = Some(self.max_received.load(SeqCst)); self.buffer.lock().push_back(envelope.clone()); // ignore errors on send (happen while we're reconnecting) @@ -1702,6 +1718,12 @@ impl ChannelClient { self.outgoing_tx.lock().unbounded_send(envelope).ok(); Ok(()) } + + fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> { + envelope.ack_id = Some(self.max_received.load(SeqCst)); + self.outgoing_tx.lock().unbounded_send(envelope).ok(); + Ok(()) + } } impl ProtoClient for ChannelClient { @@ -1710,7 +1732,7 @@ impl ProtoClient for ChannelClient { envelope: proto::Envelope, request_type: &'static str, ) -> BoxFuture<'static, Result> { - self.request_dynamic(envelope, request_type).boxed() + self.request_dynamic(envelope, request_type, true).boxed() } fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {