diff --git a/Cargo.toml b/Cargo.toml index 1667aa59..162f5a93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,3 +3,7 @@ members = [ "tower-http", "examples/*", ] + +[patch.crates-io] +# for `Frame::map_data` +http-body = { git = "https://github.com/hyperium/http-body", rev = "7bf321acbb422214c89933c103417bcfe3892aed" } diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 9ef26363..0d5ac2b7 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -18,7 +18,8 @@ bytes = "1" futures-core = "0.3" futures-util = { version = "0.3.14", default_features = false, features = [] } http = "0.2.2" -http-body = "0.4.5" +http-body = "1.0.0-rc.2" +http-body-util = "0.1.0-rc.2" pin-project-lite = "0.2.7" tower-layer = "0.3" tower-service = "0.3" @@ -39,6 +40,7 @@ httpdate = { version = "1.0", optional = true } uuid = { version = "1.0", features = ["v4"], optional = true } [dev-dependencies] +async-trait = "0.1" bytes = "1" flate2 = "1.0" brotli = "3" diff --git a/tower-http/src/catch_panic.rs b/tower-http/src/catch_panic.rs index b547b32d..caa3a496 100644 --- a/tower-http/src/catch_panic.rs +++ b/tower-http/src/catch_panic.rs @@ -86,7 +86,8 @@ use bytes::Bytes; use futures_core::ready; use futures_util::future::{CatchUnwind, FutureExt}; use http::{HeaderValue, Request, Response, StatusCode}; -use http_body::{combinators::UnsyncBoxBody, Body, Full}; +use http_body::Body; +use http_body_util::{combinators::UnsyncBoxBody, BodyExt, Full}; use pin_project_lite::pin_project; use std::{ any::Any, diff --git a/tower-http/src/compression/body.rs b/tower-http/src/compression/body.rs index eeb798ba..486a14bb 100644 --- a/tower-http/src/compression/body.rs +++ b/tower-http/src/compression/body.rs @@ -247,46 +247,29 @@ where type Data = Bytes; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.project() { #[cfg(feature = "compression-gzip")] - BodyInnerProj::Gzip { inner } => inner.poll_data(cx), + BodyInnerProj::Gzip { inner } => inner.poll_frame(cx), #[cfg(feature = "compression-deflate")] - BodyInnerProj::Deflate { inner } => inner.poll_data(cx), + BodyInnerProj::Deflate { inner } => inner.poll_frame(cx), #[cfg(feature = "compression-br")] - BodyInnerProj::Brotli { inner } => inner.poll_data(cx), + BodyInnerProj::Brotli { inner } => inner.poll_frame(cx), #[cfg(feature = "compression-zstd")] - BodyInnerProj::Zstd { inner } => inner.poll_data(cx), - BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) { - Some(Ok(mut buf)) => { - let bytes = buf.copy_to_bytes(buf.remaining()); - Poll::Ready(Some(Ok(bytes))) + BodyInnerProj::Zstd { inner } => inner.poll_frame(cx), + BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) { + Some(Ok(frame)) => { + let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())); + Poll::Ready(Some(Ok(frame))) } Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), }, } } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.project() { - #[cfg(feature = "compression-gzip")] - BodyInnerProj::Gzip { inner } => inner.poll_trailers(cx), - #[cfg(feature = "compression-deflate")] - BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx), - #[cfg(feature = "compression-br")] - BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx), - #[cfg(feature = "compression-zstd")] - BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx), - BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into), - } - } } #[cfg(feature = "compression-gzip")] diff --git a/tower-http/src/compression/layer.rs b/tower-http/src/compression/layer.rs index 359385e7..56854da2 100644 --- a/tower-http/src/compression/layer.rs +++ b/tower-http/src/compression/layer.rs @@ -123,7 +123,7 @@ impl CompressionLayer { #[cfg(test)] mod tests { use super::*; - use crate::test_helpers::Body; + use crate::test_helpers::{Body, TowerHttpBodyExt}; use http::{header::ACCEPT_ENCODING, Request, Response}; use http_body::Body as _; use tokio::fs::File; diff --git a/tower-http/src/compression/mod.rs b/tower-http/src/compression/mod.rs index d9d4434b..1fdf7aa2 100644 --- a/tower-http/src/compression/mod.rs +++ b/tower-http/src/compression/mod.rs @@ -87,7 +87,7 @@ mod tests { use crate::compression::predicate::SizeAbove; use super::*; - use crate::test_helpers::Body; + use crate::test_helpers::{Body, TowerHttpBodyExt}; use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder}; use bytes::BytesMut; use flate2::read::GzDecoder; diff --git a/tower-http/src/compression_utils.rs b/tower-http/src/compression_utils.rs index 7b289371..626543dd 100644 --- a/tower-http/src/compression_utils.rs +++ b/tower-http/src/compression_utils.rs @@ -188,51 +188,61 @@ where type Data = Bytes; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - let mut this = self.project(); - let mut buf = BytesMut::new(); - - let read = match ready!(poll_read_buf(this.read.as_mut(), cx, &mut buf)) { - Ok(read) => read, - Err(err) => { - let body_error: Option = M::get_pin_mut(this.read) - .get_pin_mut() - .project() - .error - .take(); - - if let Some(body_error) = body_error { - return Poll::Ready(Some(Err(body_error.into()))); - } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) { - // SENTINEL_ERROR_CODE only gets used when storing an underlying body error - unreachable!() - } else { - return Poll::Ready(Some(Err(err.into()))); - } - } - }; + ) -> Poll, Self::Error>>> { + // I'm not sure our previous body wrapping setup works. It assumes we can poll data and + // trailers separately, but we can't anymore - if read == 0 { - Poll::Ready(None) - } else { - Poll::Ready(Some(Ok(buf.freeze()))) - } + todo!() } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let this = self.project(); - let body = M::get_pin_mut(this.read) - .get_pin_mut() - .get_pin_mut() - .get_pin_mut(); - body.poll_trailers(cx).map_err(Into::into) - } + // fn poll_data( + // self: Pin<&mut Self>, + // cx: &mut Context<'_>, + // ) -> Poll>> { + // let mut this = self.project(); + // let mut buf = BytesMut::new(); + + // let read = match ready!(poll_read_buf(this.read.as_mut(), cx, &mut buf)) { + // Ok(read) => read, + // Err(err) => { + // let body_error: Option = M::get_pin_mut(this.read) + // .get_pin_mut() + // .project() + // .error + // .take(); + + // if let Some(body_error) = body_error { + // return Poll::Ready(Some(Err(body_error.into()))); + // } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) { + // // SENTINEL_ERROR_CODE only gets used when storing an underlying body error + // unreachable!() + // } else { + // return Poll::Ready(Some(Err(err.into()))); + // } + // } + // }; + + // if read == 0 { + // Poll::Ready(None) + // } else { + // Poll::Ready(Some(Ok(buf.freeze()))) + // } + // } + + // fn poll_trailers( + // self: Pin<&mut Self>, + // cx: &mut Context<'_>, + // ) -> Poll, Self::Error>> { + // let this = self.project(); + // let body = M::get_pin_mut(this.read) + // .get_pin_mut() + // .get_pin_mut() + // .get_pin_mut(); + // body.poll_trailers(cx).map_err(Into::into) + // } } pin_project! { @@ -276,8 +286,17 @@ where { type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().body.poll_data(cx) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match futures_util::ready!(self.as_mut().project().body.poll_frame(cx)) { + Some(Ok(frame)) => match frame.into_data() { + Ok(data) => return Poll::Ready(Some(Ok(data))), + Err(_frame) => {} + }, + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + None => return Poll::Ready(None), + } + } } } diff --git a/tower-http/src/decompression/body.rs b/tower-http/src/decompression/body.rs index 279dc283..e70033d2 100644 --- a/tower-http/src/decompression/body.rs +++ b/tower-http/src/decompression/body.rs @@ -277,23 +277,23 @@ where type Data = Bytes; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.project() { #[cfg(feature = "decompression-gzip")] - BodyInnerProj::Gzip { inner } => inner.poll_data(cx), + BodyInnerProj::Gzip { inner } => inner.poll_frame(cx), #[cfg(feature = "decompression-deflate")] - BodyInnerProj::Deflate { inner } => inner.poll_data(cx), + BodyInnerProj::Deflate { inner } => inner.poll_frame(cx), #[cfg(feature = "decompression-br")] - BodyInnerProj::Brotli { inner } => inner.poll_data(cx), + BodyInnerProj::Brotli { inner } => inner.poll_frame(cx), #[cfg(feature = "decompression-zstd")] - BodyInnerProj::Zstd { inner } => inner.poll_data(cx), - BodyInnerProj::Identity { inner } => match ready!(inner.poll_data(cx)) { - Some(Ok(mut buf)) => { - let bytes = buf.copy_to_bytes(buf.remaining()); - Poll::Ready(Some(Ok(bytes))) + BodyInnerProj::Zstd { inner } => inner.poll_frame(cx), + BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) { + Some(Ok(frame)) => { + let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining())); + Poll::Ready(Some(Ok(frame))) } Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), @@ -309,32 +309,6 @@ where BodyInnerProj::Zstd { inner } => match inner.0 {}, } } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.project() { - #[cfg(feature = "decompression-gzip")] - BodyInnerProj::Gzip { inner } => inner.poll_trailers(cx), - #[cfg(feature = "decompression-deflate")] - BodyInnerProj::Deflate { inner } => inner.poll_trailers(cx), - #[cfg(feature = "decompression-br")] - BodyInnerProj::Brotli { inner } => inner.poll_trailers(cx), - #[cfg(feature = "decompression-zstd")] - BodyInnerProj::Zstd { inner } => inner.poll_trailers(cx), - BodyInnerProj::Identity { inner } => inner.poll_trailers(cx).map_err(Into::into), - - #[cfg(not(feature = "decompression-gzip"))] - BodyInnerProj::Gzip { inner } => match inner.0 {}, - #[cfg(not(feature = "decompression-deflate"))] - BodyInnerProj::Deflate { inner } => match inner.0 {}, - #[cfg(not(feature = "decompression-br"))] - BodyInnerProj::Brotli { inner } => match inner.0 {}, - #[cfg(not(feature = "decompression-zstd"))] - BodyInnerProj::Zstd { inner } => match inner.0 {}, - } - } } #[cfg(feature = "decompression-gzip")] diff --git a/tower-http/src/decompression/mod.rs b/tower-http/src/decompression/mod.rs index 61aab330..7bad2ca3 100644 --- a/tower-http/src/decompression/mod.rs +++ b/tower-http/src/decompression/mod.rs @@ -120,6 +120,7 @@ mod tests { use super::*; use crate::compression::Compression; use crate::test_helpers::Body; + use crate::test_helpers::TowerHttpBodyExt; use bytes::BytesMut; use http::Request; use http::Response; diff --git a/tower-http/src/decompression/request/future.rs b/tower-http/src/decompression/request/future.rs index ce3b04ad..82ad412b 100644 --- a/tower-http/src/decompression/request/future.rs +++ b/tower-http/src/decompression/request/future.rs @@ -2,7 +2,10 @@ use crate::compression_utils::AcceptEncoding; use crate::BoxError; use bytes::Buf; use http::{header, HeaderValue, Response, StatusCode}; -use http_body::{combinators::UnsyncBoxBody, Body, Empty}; +use http_body::Body; +use http_body_util::combinators::UnsyncBoxBody; +use http_body_util::BodyExt; +use http_body_util::Empty; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; diff --git a/tower-http/src/decompression/request/mod.rs b/tower-http/src/decompression/request/mod.rs index 6dbbf85f..5f9c2438 100644 --- a/tower-http/src/decompression/request/mod.rs +++ b/tower-http/src/decompression/request/mod.rs @@ -5,8 +5,8 @@ pub(super) mod service; #[cfg(test)] mod tests { use super::service::RequestDecompression; - use crate::decompression::DecompressionBody; use crate::test_helpers::Body; + use crate::{decompression::DecompressionBody, test_helpers::TowerHttpBodyExt}; use bytes::BytesMut; use flate2::{write::GzEncoder, Compression}; use http::{header, Request, Response, StatusCode}; diff --git a/tower-http/src/decompression/request/service.rs b/tower-http/src/decompression/request/service.rs index 6d507366..dd383b6d 100644 --- a/tower-http/src/decompression/request/service.rs +++ b/tower-http/src/decompression/request/service.rs @@ -7,7 +7,8 @@ use crate::{ }; use bytes::Buf; use http::{header, Request, Response}; -use http_body::{combinators::UnsyncBoxBody, Body}; +use http_body::Body; +use http_body_util::combinators::UnsyncBoxBody; use std::task::{Context, Poll}; use tower_service::Service; diff --git a/tower-http/src/limit/body.rs b/tower-http/src/limit/body.rs index 4e746a5d..4e540f8b 100644 --- a/tower-http/src/limit/body.rs +++ b/tower-http/src/limit/body.rs @@ -1,6 +1,7 @@ use bytes::Bytes; -use http::{HeaderMap, HeaderValue, Response, StatusCode}; -use http_body::{Body, Full, SizeHint}; +use http::{HeaderValue, Response, StatusCode}; +use http_body::{Body, SizeHint}; +use http_body_util::Full; use pin_project_lite::pin_project; use std::pin::Pin; use std::task::{Context, Poll}; @@ -52,25 +53,13 @@ where type Data = Bytes; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.project() { - BodyProj::PayloadTooLarge { body } => body.poll_data(cx).map_err(|err| match err {}), - BodyProj::Body { body } => body.poll_data(cx), - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.project() { - BodyProj::PayloadTooLarge { body } => { - body.poll_trailers(cx).map_err(|err| match err {}) - } - BodyProj::Body { body } => body.poll_trailers(cx), + BodyProj::PayloadTooLarge { body } => body.poll_frame(cx).map_err(|err| match err {}), + BodyProj::Body { body } => body.poll_frame(cx), } } diff --git a/tower-http/src/limit/service.rs b/tower-http/src/limit/service.rs index 66ae41fe..e057379c 100644 --- a/tower-http/src/limit/service.rs +++ b/tower-http/src/limit/service.rs @@ -1,6 +1,7 @@ use super::{RequestBodyLimitLayer, ResponseBody, ResponseFuture}; use http::{Request, Response}; -use http_body::{Body, Limited}; +use http_body::Body; +use http_body_util::Limited; use std::task::{Context, Poll}; use tower_service::Service; diff --git a/tower-http/src/macros.rs b/tower-http/src/macros.rs index 6641199b..556ce815 100644 --- a/tower-http/src/macros.rs +++ b/tower-http/src/macros.rs @@ -46,19 +46,11 @@ macro_rules! opaque_body { type Error = <$actual as http_body::Body>::Error; #[inline] - fn poll_data( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { - self.project().inner.poll_data(cx) - } - - #[inline] - fn poll_trailers( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) + fn poll_frame( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + self.project().inner.poll_frame(cx) } #[inline] diff --git a/tower-http/src/metrics/in_flight_requests.rs b/tower-http/src/metrics/in_flight_requests.rs index c5ee157b..c6fc55ae 100644 --- a/tower-http/src/metrics/in_flight_requests.rs +++ b/tower-http/src/metrics/in_flight_requests.rs @@ -267,19 +267,11 @@ where type Error = B::Error; #[inline] - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - self.project().inner.poll_data(cx) - } - - #[inline] - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) + ) -> Poll, Self::Error>>> { + self.project().inner.poll_frame(cx) } #[inline] diff --git a/tower-http/src/services/fs/mod.rs b/tower-http/src/services/fs/mod.rs index ce6ef463..937e1f2a 100644 --- a/tower-http/src/services/fs/mod.rs +++ b/tower-http/src/services/fs/mod.rs @@ -2,8 +2,7 @@ use bytes::Bytes; use futures_util::Stream; -use http::HeaderMap; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project_lite::pin_project; use std::{ io, @@ -67,17 +66,14 @@ where type Data = Bytes; type Error = io::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - self.project().reader.poll_next(cx) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + ) -> Poll, Self::Error>>> { + match futures_util::ready!(self.project().reader.poll_next(cx)) { + Some(Ok(chunk)) => Poll::Ready(Some(Ok(Frame::data(chunk)))), + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } } } diff --git a/tower-http/src/services/fs/serve_dir/future.rs b/tower-http/src/services/fs/serve_dir/future.rs index 1d7eed9e..dc2ae2e4 100644 --- a/tower-http/src/services/fs/serve_dir/future.rs +++ b/tower-http/src/services/fs/serve_dir/future.rs @@ -12,7 +12,7 @@ use http::{ header::{self, ALLOW}, HeaderValue, Request, Response, StatusCode, }; -use http_body::{Body, Empty, Full}; +use http_body_util::{BodyExt, Empty, Full}; use pin_project_lite::pin_project; use std::{ convert::Infallible, diff --git a/tower-http/src/services/fs/serve_dir/mod.rs b/tower-http/src/services/fs/serve_dir/mod.rs index 9d1e5638..19060412 100644 --- a/tower-http/src/services/fs/serve_dir/mod.rs +++ b/tower-http/src/services/fs/serve_dir/mod.rs @@ -6,7 +6,8 @@ use crate::{ use bytes::Bytes; use futures_util::FutureExt; use http::{header, HeaderValue, Method, Request, Response, StatusCode}; -use http_body::{combinators::UnsyncBoxBody, Body, Empty}; +use http_body::Body; +use http_body_util::{combinators::UnsyncBoxBody, BodyExt, Empty}; use percent_encoding::percent_decode; use std::{ convert::Infallible, diff --git a/tower-http/src/services/fs/serve_dir/open_file.rs b/tower-http/src/services/fs/serve_dir/open_file.rs index a24aa088..401d34de 100644 --- a/tower-http/src/services/fs/serve_dir/open_file.rs +++ b/tower-http/src/services/fs/serve_dir/open_file.rs @@ -5,7 +5,7 @@ use super::{ use crate::content_encoding::{Encoding, QValue}; use bytes::Bytes; use http::{header, HeaderValue, Method, Request, Uri}; -use http_body::Empty; +use http_body_util::Empty; use http_range_header::RangeUnsatisfiableError; use std::{ ffi::OsStr, diff --git a/tower-http/src/services/fs/serve_dir/tests.rs b/tower-http/src/services/fs/serve_dir/tests.rs index 1cbb0a02..e78b3064 100644 --- a/tower-http/src/services/fs/serve_dir/tests.rs +++ b/tower-http/src/services/fs/serve_dir/tests.rs @@ -1,5 +1,5 @@ use crate::services::{ServeDir, ServeFile}; -use crate::test_helpers::{to_bytes, Body}; +use crate::test_helpers::{to_bytes, Body, TowerHttpBodyExt}; use brotli::BrotliDecompress; use bytes::Bytes; use flate2::bufread::{DeflateDecoder, GzDecoder}; diff --git a/tower-http/src/services/fs/serve_file.rs b/tower-http/src/services/fs/serve_file.rs index 2eb277bf..7d519fb6 100644 --- a/tower-http/src/services/fs/serve_file.rs +++ b/tower-http/src/services/fs/serve_file.rs @@ -129,6 +129,7 @@ where mod tests { use crate::services::ServeFile; use crate::test_helpers::Body; + use crate::test_helpers::TowerHttpBodyExt; use brotli::BrotliDecompress; use flate2::bufread::DeflateDecoder; use flate2::bufread::GzDecoder; diff --git a/tower-http/src/test_helpers.rs b/tower-http/src/test_helpers.rs index 4604b1c5..ac9d862c 100644 --- a/tower-http/src/test_helpers.rs +++ b/tower-http/src/test_helpers.rs @@ -1,17 +1,19 @@ use std::{ + future::Future, pin::Pin, task::{Context, Poll}, }; -use bytes::{Buf, BufMut, Bytes}; +use async_trait::async_trait; +use bytes::Bytes; use futures::TryStream; -use http::HeaderMap; -use http_body::Body as _; +use http_body::{Body as _, Frame}; +use http_body_util::BodyExt; use pin_project_lite::pin_project; use sync_wrapper::SyncWrapper; use tower::BoxError; -type BoxBody = http_body::combinators::UnsyncBoxBody; +type BoxBody = http_body_util::combinators::UnsyncBoxBody; #[derive(Debug)] pub(crate) struct Body(BoxBody); @@ -26,7 +28,7 @@ impl Body { } pub(crate) fn empty() -> Self { - Self::new(http_body::Empty::new()) + Self::new(http_body_util::Empty::new()) } pub(crate) fn from_stream(stream: S) -> Self @@ -51,7 +53,7 @@ macro_rules! body_from_impl { ($ty:ty) => { impl From<$ty> for Body { fn from(buf: $ty) -> Self { - Self::new(http_body::Full::from(buf)) + Self::new(http_body_util::Full::from(buf)) } } }; @@ -71,18 +73,11 @@ impl http_body::Body for Body { type Data = Bytes; type Error = BoxError; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> std::task::Poll>> { - Pin::new(&mut self.0).poll_data(cx) - } - - fn poll_trailers( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> std::task::Poll, Self::Error>> { - Pin::new(&mut self.0).poll_trailers(cx) + ) -> Poll, Self::Error>>> { + Pin::new(&mut self.0).poll_frame(cx) } fn size_hint(&self) -> http_body::SizeHint { @@ -110,24 +105,17 @@ where type Data = Bytes; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let stream = self.project().stream.get_pin_mut(); match futures_util::ready!(stream.try_poll_next(cx)) { - Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk.into()))), + Some(Ok(chunk)) => Poll::Ready(Some(Ok(Frame::data(chunk.into())))), Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), None => Poll::Ready(None), } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } // copied from hyper @@ -136,29 +124,39 @@ where T: http_body::Body, { futures_util::pin_mut!(body); + Ok(body.collect().await?.to_bytes()) +} - // If there's only 1 chunk, we can just return Buf::to_bytes() - let mut first = if let Some(buf) = body.data().await { - buf? - } else { - return Ok(Bytes::new()); - }; +pub(crate) trait TowerHttpBodyExt: http_body::Body + Unpin { + /// Returns future that resolves to next data chunk, if any. + fn data(&mut self) -> Data<'_, Self> + where + Self: Unpin + Sized, + { + Data(self) + } +} - let second = if let Some(buf) = body.data().await { - buf? - } else { - return Ok(first.copy_to_bytes(first.remaining())); - }; +impl TowerHttpBodyExt for B where B: http_body::Body + Unpin {} - // With more than 1 buf, we gotta flatten into a Vec first. - let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; - let mut vec = Vec::with_capacity(cap); - vec.put(first); - vec.put(second); +pub(crate) struct Data<'a, T>(pub(crate) &'a mut T); - while let Some(buf) = body.data().await { - vec.put(buf?); +impl<'a, T> Future for Data<'a, T> +where + T: http_body::Body + Unpin, +{ + type Output = Option>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match futures_util::ready!(Pin::new(&mut self.0).poll_frame(cx)) { + Some(Ok(frame)) => match frame.into_data() { + Ok(data) => return Poll::Ready(Some(Ok(data))), + Err(_frame) => {} + }, + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + None => return Poll::Ready(None), + } + } } - - Ok(vec.into()) } diff --git a/tower-http/src/timeout/body.rs b/tower-http/src/timeout/body.rs index 79712efd..8a74ae21 100644 --- a/tower-http/src/timeout/body.rs +++ b/tower-http/src/timeout/body.rs @@ -50,13 +50,8 @@ pin_project! { /// Wrapper around a [`http_body::Body`] to time out if data is not ready within the specified duration. pub struct TimeoutBody { timeout: Duration, - // In http-body 1.0, `poll_*` will be merged into `poll_frame`. - // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`. - // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958 #[pin] - sleep_data: Option, - #[pin] - sleep_trailers: Option, + sleep: Option, #[pin] body: B, } @@ -67,8 +62,7 @@ impl TimeoutBody { pub fn new(timeout: Duration, body: B) -> Self { TimeoutBody { timeout, - sleep_data: None, - sleep_trailers: None, + sleep: None, body, } } @@ -82,18 +76,18 @@ where type Data = B::Data; type Error = Box; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let mut this = self.project(); // Start the `Sleep` if not active. - let sleep_pinned = if let Some(some) = this.sleep_data.as_mut().as_pin_mut() { + let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() { some } else { - this.sleep_data.set(Some(sleep(*this.timeout))); - this.sleep_data.as_mut().as_pin_mut().unwrap() + this.sleep.set(Some(sleep(*this.timeout))); + this.sleep.as_mut().as_pin_mut().unwrap() }; // Error if the timeout has expired. @@ -102,36 +96,11 @@ where } // Check for body data. - let data = ready!(this.body.poll_data(cx)); - // Some data is ready. Reset the `Sleep`... - this.sleep_data.set(None); + let frame = ready!(this.body.poll_frame(cx)); + // A frame is ready. Reset the `Sleep`... + this.sleep.set(None); - Poll::Ready(data.transpose().map_err(Into::into).transpose()) - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let mut this = self.project(); - - // In http-body 1.0, `poll_*` will be merged into `poll_frame`. - // Merge the two `sleep_data` and `sleep_trailers` into one `sleep`. - // See: https://github.com/tower-rs/tower-http/pull/303#discussion_r1004834958 - - let sleep_pinned = if let Some(some) = this.sleep_trailers.as_mut().as_pin_mut() { - some - } else { - this.sleep_trailers.set(Some(sleep(*this.timeout))); - this.sleep_trailers.as_mut().as_pin_mut().unwrap() - }; - - // Error if the timeout has expired. - if let Poll::Ready(()) = sleep_pinned.poll(cx) { - return Poll::Ready(Err(Box::new(TimeoutError(())))); - } - - this.body.poll_trailers(cx).map_err(Into::into) + Poll::Ready(frame.transpose().map_err(Into::into).transpose()) } } @@ -148,9 +117,13 @@ impl std::fmt::Display for TimeoutError { } #[cfg(test)] mod tests { + use crate::test_helpers::TowerHttpBodyExt; + use super::*; use bytes::Bytes; + use http_body::Frame; + use http_body_util::BodyExt; use pin_project_lite::pin_project; use std::{error::Error, fmt::Display}; @@ -175,19 +148,14 @@ mod tests { type Data = Bytes; type Error = MockError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let this = self.project(); - this.sleep.poll(cx).map(|_| Some(Ok(vec![].into()))) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - todo!() + this.sleep + .poll(cx) + .map(|_| Some(Ok(Frame::data(vec![].into())))) } } @@ -201,7 +169,7 @@ mod tests { }; let timeout_body = TimeoutBody::new(timeout_sleep, mock_body); - assert!(timeout_body.boxed().data().await.unwrap().is_ok()); + assert!(timeout_body.boxed().data().await.expect("no data").is_ok()); } #[tokio::test] diff --git a/tower-http/src/trace/body.rs b/tower-http/src/trace/body.rs index d38770d5..39579575 100644 --- a/tower-http/src/trace/body.rs +++ b/tower-http/src/trace/body.rs @@ -2,7 +2,7 @@ use super::{OnBodyChunk, OnEos, OnFailure}; use crate::classify::ClassifyEos; use futures_core::ready; use http::HeaderMap; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project_lite::pin_project; use std::{ fmt, @@ -41,70 +41,57 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let this = self.project(); let _guard = this.span.enter(); - - let result = if let Some(result) = ready!(this.inner.poll_data(cx)) { - result - } else { - return Poll::Ready(None); - }; + let result = ready!(this.inner.poll_frame(cx)); let latency = this.start.elapsed(); *this.start = Instant::now(); - match &result { - Ok(chunk) => { - this.on_body_chunk.on_body_chunk(chunk, latency, this.span); + match result { + Some(Ok(frame)) => { + let frame = match frame.into_data() { + Ok(chunk) => { + this.on_body_chunk.on_body_chunk(&chunk, latency, this.span); + Frame::data(chunk) + } + Err(frame) => frame, + }; + + let frame = match frame.into_trailers() { + Ok(trailers) => { + if let Some((on_eos, stream_start)) = this.on_eos.take() { + on_eos.on_eos(Some(&trailers), stream_start.elapsed(), this.span); + } + Frame::trailers(trailers) + } + Err(frame) => frame, + }; + + Poll::Ready(Some(Ok(frame))) } - Err(err) => { + Some(Err(err)) => { if let Some((classify_eos, mut on_failure)) = this.classify_eos.take().zip(this.on_failure.take()) { - let failure_class = classify_eos.classify_error(err); + let failure_class = classify_eos.classify_error(&err); on_failure.on_failure(failure_class, latency, this.span); } - } - } - - Poll::Ready(Some(result)) - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let this = self.project(); - let _guard = this.span.enter(); - let result = ready!(this.inner.poll_trailers(cx)); - - let latency = this.start.elapsed(); - - if let Some((classify_eos, mut on_failure)) = - this.classify_eos.take().zip(this.on_failure.take()) - { - match &result { - Ok(trailers) => { - if let Err(failure_class) = classify_eos.classify_eos(trailers.as_ref()) { - on_failure.on_failure(failure_class, latency, this.span); - } - if let Some((on_eos, stream_start)) = this.on_eos.take() { - on_eos.on_eos(trailers.as_ref(), stream_start.elapsed(), this.span); - } - } - Err(err) => { - let failure_class = classify_eos.classify_error(err); - on_failure.on_failure(failure_class, latency, this.span); + Poll::Ready(Some(Err(err))) + } + None => { + if let Some((on_eos, stream_start)) = this.on_eos.take() { + on_eos.on_eos(None, stream_start.elapsed(), this.span); } + + Poll::Ready(None) } } - - Poll::Ready(result) } fn is_end_stream(&self) -> bool {