diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 5fa4f7438c5..59d5a765e2e 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -10,8 +10,8 @@ edition = "2018" tokio = { version = "0.3.0", path = "../tokio", features = ["full", "tracing"] } tracing = "0.1" tracing-subscriber = { version = "0.2.7", default-features = false, features = ["fmt", "ansi", "env-filter", "chrono", "tracing-log"] } -tokio-util = { version = "0.5.0", path = "../tokio-util", features = ["full"] } -bytes = "0.6" +tokio-util = { version = "0.4.0", path = "../tokio-util", features = ["full"] } +bytes = "0.5" futures = "0.3.0" http = "0.2" serde = "1.0" diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index 3c5b1bf97e4..11419951235 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -7,7 +7,7 @@ name = "tokio-util" # - Cargo.toml # - Update CHANGELOG.md. # - Create "v0.2.x" git tag. -version = "0.5.0" +version = "0.4.0" edition = "2018" authors = ["Tokio Contributors "] license = "MIT" @@ -27,15 +27,15 @@ default = [] full = ["codec", "compat", "io", "time"] compat = ["futures-io",] -codec = ["tokio/io-util", "tokio/stream"] +codec = ["tokio/stream"] time = ["tokio/time","slab"] -io = ["tokio/io-util"] +io = [] rt = ["tokio/rt"] [dependencies] tokio = { version = "0.3.0", path = "../tokio" } -bytes = "0.6.0" +bytes = "0.5.0" futures-core = "0.3.0" futures-sink = "0.3.0" futures-io = { version = "0.3.0", optional = true } diff --git a/tokio-util/src/codec/framed_impl.rs b/tokio-util/src/codec/framed_impl.rs index ccb8b3c8e32..c161808f66e 100644 --- a/tokio-util/src/codec/framed_impl.rs +++ b/tokio-util/src/codec/framed_impl.rs @@ -2,7 +2,7 @@ use crate::codec::decoder::Decoder; use crate::codec::encoder::Encoder; use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite}, + io::{AsyncRead, AsyncWrite}, stream::Stream, }; @@ -118,6 +118,8 @@ where type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use crate::util::poll_read_buf; + let mut pinned = self.project(); let state: &mut ReadFrame = pinned.state.borrow_mut(); loop { @@ -148,7 +150,7 @@ where // got room for at least one byte to read to ensure that we don't // get a spurious 0 that looks like EOF state.buffer.reserve(1); - let bytect = match pinned.inner.as_mut().poll_read_buf(&mut state.buffer, cx)? { + let bytect = match poll_read_buf(cx, pinned.inner.as_mut(), &mut state.buffer)? { Poll::Ready(ct) => ct, Poll::Pending => return Poll::Pending, }; diff --git a/tokio-util/src/io/mod.rs b/tokio-util/src/io/mod.rs index 53066c4e444..7cf25989647 100644 --- a/tokio-util/src/io/mod.rs +++ b/tokio-util/src/io/mod.rs @@ -6,8 +6,12 @@ //! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html //! [`AsyncRead`]: tokio::io::AsyncRead +mod poll_read_buf; +mod read_buf; mod reader_stream; mod stream_reader; +pub use self::poll_read_buf::poll_read_buf; +pub use self::read_buf::read_buf; pub use self::reader_stream::ReaderStream; pub use self::stream_reader::StreamReader; diff --git a/tokio-util/src/io/poll_read_buf.rs b/tokio-util/src/io/poll_read_buf.rs new file mode 100644 index 00000000000..efce7ced2bb --- /dev/null +++ b/tokio-util/src/io/poll_read_buf.rs @@ -0,0 +1,90 @@ +use bytes::BufMut; +use futures_core::ready; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; + +/// Try to read data from an `AsyncRead` into an implementer of the [`Buf`] trait. +/// +/// [`Buf`]: bytes::Buf +/// +/// # Example +/// +/// ``` +/// use bytes::{Bytes, BytesMut}; +/// use tokio::stream; +/// use tokio::io::Result; +/// use tokio_util::io::{StreamReader, poll_read_buf}; +/// use futures::future::poll_fn; +/// use std::pin::Pin; +/// # #[tokio::main] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a reader from an iterator. This particular reader will always be +/// // ready. +/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); +/// +/// let mut buf = BytesMut::new(); +/// let mut reads = 0; +/// +/// loop { +/// reads += 1; +/// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; +/// +/// if n == 0 { +/// break; +/// } +/// } +/// +/// // one or more reads might be necessary. +/// assert!(reads >= 1); +/// assert_eq!(&buf[..], &[0, 1, 2, 3]); +/// # Ok(()) +/// # } +/// ``` +pub fn poll_read_buf( + read: Pin<&mut R>, + cx: &mut Context<'_>, + buf: &mut B, +) -> Poll> +where + R: AsyncRead, + B: BufMut, +{ + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let mut buf = ReadBuf::uninit(buf.bytes_mut()); + let before = buf.filled().as_ptr(); + + ready!(read.poll_read(cx, &mut buf)?); + + // This prevents a malicious read implementation from swapping out the + // buffer being read, which would allow `filled` to be advanced without + // actually initializing the provided buffer. + // + // We avoid this by asserting that the `ReadBuf` instance wraps the same + // memory address both before and after the poll. Which will panic in + // case its swapped. + // + // See https://github.com/tokio-rs/tokio/issues/2827 for more info. + assert! { + std::ptr::eq(before, buf.filled().as_ptr()), + "Read buffer must not be changed during a read poll. \ + See https://github.com/tokio-rs/tokio/issues/2827 for more info." + }; + + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) +} diff --git a/tokio-util/src/io/read_buf.rs b/tokio-util/src/io/read_buf.rs new file mode 100644 index 00000000000..d617fa6f042 --- /dev/null +++ b/tokio-util/src/io/read_buf.rs @@ -0,0 +1,65 @@ +use bytes::BufMut; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; + +/// Read data from an `AsyncRead` into an implementer of the [`Buf`] trait. +/// +/// [`Buf`]: bytes::Buf +/// +/// # Example +/// +/// ``` +/// use bytes::{Bytes, BytesMut}; +/// use tokio::stream; +/// use tokio::io::Result; +/// use tokio_util::io::{StreamReader, read_buf}; +/// # #[tokio::main] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a reader from an iterator. This particular reader will always be +/// // ready. +/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); +/// +/// let mut buf = BytesMut::new(); +/// let mut reads = 0; +/// +/// loop { +/// reads += 1; +/// let n = read_buf(&mut read, &mut buf).await?; +/// +/// if n == 0 { +/// break; +/// } +/// } +/// +/// // one or more reads might be necessary. +/// assert!(reads >= 1); +/// assert_eq!(&buf[..], &[0, 1, 2, 3]); +/// # Ok(()) +/// # } +/// ``` +pub async fn read_buf(read: &mut R, buf: &mut B) -> io::Result +where + R: AsyncRead + Unpin, + B: BufMut, +{ + return ReadBufFn(read, buf).await; + + struct ReadBufFn<'a, R, B>(&'a mut R, &'a mut B); + + impl<'a, R, B> Future for ReadBufFn<'a, R, B> + where + R: AsyncRead + Unpin, + B: BufMut, + { + type Output = io::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = &mut *self; + super::poll_read_buf(Pin::new(this.0), cx, this.1) + } + } +} diff --git a/tokio-util/src/io/reader_stream.rs b/tokio-util/src/io/reader_stream.rs index 49288c45daa..ab0c22fba73 100644 --- a/tokio-util/src/io/reader_stream.rs +++ b/tokio-util/src/io/reader_stream.rs @@ -3,7 +3,7 @@ use futures_core::stream::Stream; use pin_project_lite::pin_project; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::io::AsyncRead; const CAPACITY: usize = 4096; @@ -70,9 +70,11 @@ impl ReaderStream { impl Stream for ReaderStream { type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use crate::util::poll_read_buf; + let mut this = self.as_mut().project(); - let mut reader = match this.reader.as_pin_mut() { + let reader = match this.reader.as_pin_mut() { Some(r) => r, None => return Poll::Ready(None), }; @@ -81,7 +83,7 @@ impl Stream for ReaderStream { this.buf.reserve(CAPACITY); } - match reader.poll_read_buf(&mut this.buf, cx) { + match poll_read_buf(cx, reader, &mut this.buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(err)) => { self.project().reader.set(None); diff --git a/tokio-util/src/lib.rs b/tokio-util/src/lib.rs index 1e4b9d40050..10b828ef9fd 100644 --- a/tokio-util/src/lib.rs +++ b/tokio-util/src/lib.rs @@ -57,3 +57,37 @@ pub mod either; #[cfg(feature = "time")] pub mod time; + +#[cfg(any(feature = "io", feature = "codec"))] +mod util { + use tokio::io::{AsyncRead, ReadBuf}; + + use bytes::BufMut; + use futures_core::ready; + use std::io; + use std::pin::Pin; + use std::task::{Context, Poll}; + + pub(crate) fn poll_read_buf( + cx: &mut Context<'_>, + io: Pin<&mut T>, + buf: &mut impl BufMut, + ) -> Poll> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let orig = buf.bytes_mut().as_ptr() as *const u8; + let mut b = ReadBuf::uninit(buf.bytes_mut()); + + ready!(io.poll_read(cx, &mut b))?; + let n = b.filled().len(); + + // Safety: we can assume `n` bytes were read, since they are in`filled`. + assert_eq!(orig, b.filled().as_ptr()); + unsafe { + buf.advance_mut(n); + } + Poll::Ready(Ok(n)) + } +} diff --git a/tokio/src/io/util/async_read_ext.rs b/tokio/src/io/util/async_read_ext.rs index 96a5f70d14e..1f918f1973f 100644 --- a/tokio/src/io/util/async_read_ext.rs +++ b/tokio/src/io/util/async_read_ext.rs @@ -1,6 +1,6 @@ use crate::io::util::chain::{chain, Chain}; use crate::io::util::read::{read, Read}; -use crate::io::util::read_buf::{poll_read_buf, read_buf, ReadBuf}; +use crate::io::util::read_buf::{read_buf, ReadBuf}; use crate::io::util::read_exact::{read_exact, ReadExact}; use crate::io::util::read_int::{ ReadI128, ReadI128Le, ReadI16, ReadI16Le, ReadI32, ReadI32Le, ReadI64, ReadI64Le, ReadI8, @@ -14,8 +14,6 @@ use crate::io::util::take::{take, Take}; use crate::io::AsyncRead; use bytes::BufMut; -use std::io; -use std::task::{Context, Poll}; cfg_io_util! { /// Defines numeric reader @@ -233,28 +231,6 @@ cfg_io_util! { read_buf(self, buf) } - /// Attempts to pull some bytes from this source into the specified buffer, - /// advancing the buffer's internal cursor if the underlying reader is ready. - /// - /// Usually, only a single `read` syscall is issued, even if there is - /// more space in the supplied buffer. - /// - /// # Return - /// - /// On a successful read, the number of read bytes is returned. If the - /// supplied buffer is not empty and the function returns `Ok(0)` then - /// the source has reached an "end-of-file" event. - /// - /// # Errors - /// - /// If this function encounters any form of I/O or other error, an error - /// variant will be returned. If an error is returned then it must be - /// guaranteed that no bytes were read. - /// ``` - fn poll_read_buf<'a, B>(&'a mut self, buf: &'a mut B, cx: &mut Context<'_>) -> Poll> where Self: Unpin + Sized, B: BufMut { - poll_read_buf(self, buf, cx) - } - /// Reads the exact number of bytes required to fill `buf`. /// /// Equivalent to: diff --git a/tokio/src/io/util/read_buf.rs b/tokio/src/io/util/read_buf.rs index 7df429d72ff..696deefd1e6 100644 --- a/tokio/src/io/util/read_buf.rs +++ b/tokio/src/io/util/read_buf.rs @@ -40,44 +40,33 @@ where type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut me = self.project(); - poll_read_buf(&mut me.reader, &mut me.buf, cx) - } -} + use crate::io::ReadBuf; + use std::mem::MaybeUninit; -pub(crate) fn poll_read_buf<'a, R, B>( - reader: &'a mut R, - buf: &'a mut B, - cx: &mut Context<'_>, -) -> Poll> -where - R: AsyncRead + Unpin, - B: BufMut, -{ - use crate::io::ReadBuf; - use std::mem::MaybeUninit; + let me = self.project(); - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } + if !me.buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } - let n = { - let dst = buf.bytes_mut(); - let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(Pin::new(reader).poll_read(cx, &mut buf)?); + let n = { + let dst = me.buf.bytes_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?); - // Ensure the pointer does not change from under us - assert_eq!(ptr, buf.filled().as_ptr()); - buf.filled().len() - }; + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by `ReadBuf::filled`. - unsafe { - buf.advance_mut(n); - } + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + me.buf.advance_mut(n); + } - Poll::Ready(Ok(n)) + Poll::Ready(Ok(n)) + } } diff --git a/tokio/tests/io_read_buf.rs b/tokio/tests/io_read_buf.rs index 35c12126f7a..0328168d7ab 100644 --- a/tokio/tests/io_read_buf.rs +++ b/tokio/tests/io_read_buf.rs @@ -4,7 +4,6 @@ use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; use tokio_test::assert_ok; -use futures::future::poll_fn; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; @@ -35,38 +34,3 @@ async fn read_buf() { assert_eq!(n, 11); assert_eq!(buf[..], b"hello world"[..]); } - -#[tokio::test] -async fn poll_read_buf() { - struct Rd { - cnt: usize, - } - - impl AsyncRead for Rd { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.cnt += 1; - buf.put_slice(b"hello world"); - Poll::Ready(Ok(())) - } - } - - let mut buf = vec![]; - let mut rd = Rd { cnt: 0 }; - - let res = tokio::spawn(async move { - poll_fn(|cx| { - let res = rd.poll_read_buf(&mut buf, cx); - assert_eq!(1, rd.cnt); - assert_eq!(buf[..], b"hello world"[..]); - res - }) - .await - }) - .await; - - assert!(matches!(res, Ok(Ok(11usize)))); -}