diff --git a/async-nats/src/jetstream/consumer/mod.rs b/async-nats/src/jetstream/consumer/mod.rs index 313e424f8..cc1d4c3c3 100644 --- a/async-nats/src/jetstream/consumer/mod.rs +++ b/async-nats/src/jetstream/consumer/mod.rs @@ -17,13 +17,15 @@ pub mod pull; pub mod push; #[cfg(feature = "server_2_10")] use std::collections::HashMap; +use std::future::IntoFuture; use std::time::Duration; use serde::{Deserialize, Serialize}; use serde_json::json; use time::serde::rfc3339; -use super::context::RequestError; +use super::context::{RequestError, RequestErrorKind}; +use super::response::Response; use super::stream::ClusterInfo; use super::Context; use crate::jetstream::consumer; @@ -76,14 +78,38 @@ impl Consumer { pub async fn info(&mut self) -> Result<&consumer::Info, RequestError> { let subject = format!("CONSUMER.INFO.{}.{}", self.info.stream_name, self.info.name); - let info = self.context.request(subject, &json!({})).await?; - self.info = info; - Ok(&self.info) + let response: Response = self + .context + .request(subject, &json!({})) + .into_future() + .await?; + + match response { + Response::Ok(info) => { + self.info = info; + Ok(&self.info) + } + Response::Err { error } => { + Err(RequestError::with_source(RequestErrorKind::Other, error)) + } + } } async fn fetch_info(&self) -> Result { let subject = format!("CONSUMER.INFO.{}.{}", self.info.stream_name, self.info.name); - self.context.request(subject, &json!({})).await + + let response: Response = self + .context + .request(subject, &json!({})) + .into_future() + .await?; + + match response { + Response::Ok(info) => Ok(info), + Response::Err { error } => { + Err(RequestError::with_source(RequestErrorKind::Other, error)) + } + } } /// Returns cached [Info] for the [Consumer]. diff --git a/async-nats/src/jetstream/context.rs b/async-nats/src/jetstream/context.rs index 32ca7d80c..cf7dc9238 100644 --- a/async-nats/src/jetstream/context.rs +++ b/async-nats/src/jetstream/context.rs @@ -28,6 +28,7 @@ use std::borrow::Borrow; use std::fmt::Display; use std::future::IntoFuture; use std::io::ErrorKind; +use std::marker::PhantomData; use std::pin::Pin; use std::str::from_utf8; use std::task::Poll; @@ -747,30 +748,12 @@ impl Context { /// # Ok(()) /// # } /// ``` - pub async fn request(&self, subject: String, payload: &T) -> Result + pub fn request(&self, subject: String, payload: T) -> Request where - T: ?Sized + Serialize, + T: Sized + Serialize, V: DeserializeOwned, { - let request = serde_json::to_vec(&payload) - .map(Bytes::from) - .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; - - debug!("JetStream request sent: {:?}", request); - - let message = self - .client - .request(format!("{}.{}", self.prefix, subject), request) - .await; - let message = message?; - debug!( - "JetStream request response: {:?}", - from_utf8(&message.payload) - ); - let response = serde_json::from_slice(message.payload.as_ref()) - .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; - - Ok(response) + Request::new(self.clone(), subject, payload) } /// Creates a new object store bucket. @@ -1254,6 +1237,67 @@ impl IntoFuture for Publish { } } +#[derive(Debug)] +pub struct Request { + context: Context, + subject: String, + payload: T, + timeout: Option, + response_type: PhantomData, +} + +impl Request { + pub fn new(context: Context, subject: String, payload: T) -> Self { + Self { + context, + subject, + payload, + timeout: None, + response_type: PhantomData, + } + } + + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } +} + +impl IntoFuture for Request { + type Output = Result, RequestError>; + + type IntoFuture = Pin, RequestError>> + Send>>; + + fn into_future(self) -> Self::IntoFuture { + let payload_result = serde_json::to_vec(&self.payload).map(Bytes::from); + + let prefix = self.context.prefix; + let client = self.context.client; + let subject = self.subject; + let timeout = self.timeout; + + Box::pin(std::future::IntoFuture::into_future(async move { + let payload = payload_result + .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; + + debug!("JetStream request sent: {:?}", payload); + + let request = client.request(format!("{}.{}", prefix, subject), payload); + let request = request.timeout(timeout); + let message = request.await?; + + debug!( + "JetStream request response: {:?}", + from_utf8(&message.payload) + ); + let response = serde_json::from_slice(message.payload.as_ref()) + .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; + + Ok(response) + })) + } +} + #[derive(Debug)] pub struct RequestError { kind: RequestErrorKind, diff --git a/async-nats/src/jetstream/stream.rs b/async-nats/src/jetstream/stream.rs index 8c67d2a5b..fc70972f3 100644 --- a/async-nats/src/jetstream/stream.rs +++ b/async-nats/src/jetstream/stream.rs @@ -590,6 +590,7 @@ impl Stream { let response: Response = self .context .request(subject, &payload) + .into_future() .map_err(|err| LastRawMessageError::with_source(LastRawMessageErrorKind::Other, err)) .await?; match response { @@ -640,6 +641,7 @@ impl Stream { let response: Response = self .context .request(subject, &payload) + .into_future() .map_err(|err| match err.kind() { RequestErrorKind::TimedOut => { DeleteMessageError::new(DeleteMessageErrorKind::TimedOut) @@ -1579,6 +1581,7 @@ where .stream .context .request(request_subject, &self.inner) + .into_future() .map_err(|err| match err.kind() { RequestErrorKind::TimedOut => PurgeError::new(PurgeErrorKind::TimedOut), _ => PurgeError::with_source(PurgeErrorKind::FailedRequest, err), diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index 92b4aa607..6d3f494e6 100644 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -311,6 +311,21 @@ mod jetstream { assert!(matches!(response, Response::Err { .. })); } + #[tokio::test] + async fn request_timeout() { + let server = nats_server::run_server("tests/configs/jetstream.conf"); + let client = async_nats::connect(server.client_url()).await.unwrap(); + let context = async_nats::jetstream::new(client); + + let response: Response = context + .request("INFO".to_string(), &()) + .timeout(Duration::from_secs(1)) + .await + .unwrap(); + + assert!(matches!(response, Response::Ok { .. })); + } + #[tokio::test] async fn create_stream() { let server = nats_server::run_server("tests/configs/jetstream.conf");