From 3949517a6ace7882a5c44307f987fc5bba9aff0c Mon Sep 17 00:00:00 2001 From: Thomas Scholtes Date: Mon, 29 Jun 2020 19:11:21 +0200 Subject: [PATCH] Fix panics and unsafe code This change fixes panics (#6) and unsafe code (#5). This comes at the cost of an additional copy of the data send through the pipe and having a buffer in the state. All unsafe code is removed and the need for a custom `Drop` implementation which makes the code overall easier. We also provide an implementation for traits from `futures` which is behind a feature flag. We also add tests. --- Cargo.toml | 9 +++- README.md | 5 +- examples/main.rs | 4 +- src/lib.rs | 76 ++++++++++++++++++++------ src/reader.rs | 107 ++++++++++++++++++------------------ src/state.rs | 14 ++--- src/writer.rs | 137 ++++++++++++++++++++++++++++------------------- 7 files changed, 210 insertions(+), 142 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 14893f0..ff53aa6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,16 @@ readme = "README.md" license = "MIT" edition = "2018" +[features] +default = ["tokio"] + [dependencies] -tokio = { version = "1", features= [] } +tokio = { version = "1", features= [], optional = true } log = "0.4" +futures = { version = "0.3", optional = true } [dev-dependencies] tokio = { version = "1", features = ["full"] } + +[package.metadata.docs.rs] +features = ["futures", "tokio"] diff --git a/README.md b/README.md index 5721578..ecfaebd 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,8 @@ [![Documentation](https://docs.rs/async-pipe/badge.svg)](https://docs.rs/async-pipe) [![MIT](https://img.shields.io/crates/l/async-pipe.svg)](./LICENSE) -Creates an asynchronous piped reader and writer pair using `tokio.rs`. +Creates an asynchronous piped reader and writer pair using `tokio.rs` or +`futures` [Docs](https://docs.rs/async-pipe) @@ -38,4 +39,4 @@ async fn main() { ## Contributing -Your PRs and stars are always welcome. \ No newline at end of file +Your PRs and stars are always welcome. diff --git a/examples/main.rs b/examples/main.rs index c336b5a..6b89d5d 100644 --- a/examples/main.rs +++ b/examples/main.rs @@ -5,9 +5,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; async fn main() { let (mut w, mut r) = async_pipe::pipe(); - tokio::spawn(async move { - w.write_all(b"hello world").await.unwrap(); - }); + let _ = w.write_all(b"hello world").await; let mut v = Vec::new(); r.read_to_end(&mut v).await.unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 5d1f6d4..1eac958 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -//! Creates an asynchronous piped reader and writer pair using `tokio.rs`. +//! Creates an asynchronous piped reader and writer pair using `tokio.rs` and `futures`. //! //! # Examples //! @@ -21,6 +21,11 @@ //! //! tokio::runtime::Runtime::new().unwrap().block_on(run()); //! ``` +//! +//! # Featues +//! +//! * `tokio` (default) Implement `AsyncWrite` and `AsyncRead` from `tokio::io`. +//! * `futures` Implement `AsyncWrite` and `AsyncRead` from `futures::io` use state::State; use std::sync::{Arc, Mutex}; @@ -37,11 +42,8 @@ pub fn pipe() -> (PipeWriter, PipeReader) { let shared_state = Arc::new(Mutex::new(State { reader_waker: None, writer_waker: None, - data: None, - done_reading: false, - read: 0, - done_cycle: true, closed: false, + buffer: Vec::new(), })); let w = PipeWriter { @@ -49,30 +51,72 @@ pub fn pipe() -> (PipeWriter, PipeReader) { }; let r = PipeReader { - state: shared_state.clone(), + state: shared_state, }; (w, r) } #[cfg(test)] -mod tests { +mod test { use super::*; + use std::io; use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[tokio::test] - async fn should_read_expected_text() { - const EXPECTED: &'static str = "hello world"; + async fn read_write() { + let (mut writer, mut reader) = pipe(); + let data = b"hello world"; + + let write_handle = tokio::spawn(async move { + writer.write_all(data).await.unwrap(); + }); + + let mut read_buf = Vec::new(); + reader.read_to_end(&mut read_buf).await.unwrap(); + write_handle.await.unwrap(); + + assert_eq!(&read_buf, data); + } + + #[tokio::test] + async fn eof_when_writer_is_shutdown() { + let (mut writer, mut reader) = pipe(); + writer.shutdown().await.unwrap(); + let mut buf = [0u8; 8]; + let bytes_read = reader.read(&mut buf).await.unwrap(); + assert_eq!(bytes_read, 0); + } + + #[tokio::test] + async fn broken_pipe_when_reader_is_dropped() { + let (mut writer, reader) = pipe(); + drop(reader); + let io_error = writer.write_all(&[0u8; 8]).await.unwrap_err(); + assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe); + } - let (mut w, mut r) = pipe(); + #[tokio::test] + async fn eof_when_writer_is_dropped() { + let (writer, mut reader) = pipe(); + drop(writer); + let mut buf = [0u8; 8]; + let bytes_read = reader.read(&mut buf).await.unwrap(); + assert_eq!(bytes_read, 0); + } + + #[tokio::test] + async fn drop_read_exact() { + let (mut writer, mut reader) = pipe(); + const BUF_SIZE: usize = 8; - tokio::spawn(async move { - w.write_all(EXPECTED.as_bytes()).await.unwrap(); + let write_handle = tokio::spawn(async move { + writer.write_all(&[0u8; BUF_SIZE]).await.unwrap(); }); - let mut v = Vec::new(); - r.read_to_end(&mut v).await.unwrap(); - let actual = String::from_utf8(v).unwrap(); - assert_eq!(EXPECTED, actual.as_str()); + let mut buf = [0u8; BUF_SIZE]; + reader.read_exact(&mut buf).await.unwrap(); + drop(reader); + write_handle.await.unwrap(); } } diff --git a/src/reader.rs b/src/reader.rs index a286cc1..ca5df0f 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,11 +1,17 @@ -use crate::state::{Data, State}; +use crate::state::State; +use std::io; use std::pin::Pin; -use std::ptr; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use tokio::io::{self, AsyncRead, ReadBuf}; -/// The read half of the pipe which implements [`AsyncRead`](https://docs.rs/tokio/0.2.15/tokio/io/trait.AsyncRead.html). +/// The read half of the pipe +/// +/// Implements [`tokio::io::AsyncRead`][tokio-async-read] when feature `tokio` is enabled (the +/// default). Implements [`futures::io::AsyncRead`][futures-async-read] when feature `futures` is +/// enabled. +/// +/// [futures-async-read]: https://docs.rs/futures/0.3.16/futures/io/trait.AsyncRead.html +/// [tokio-async-read]: https://docs.rs/tokio/1.9.0/tokio/io/trait.AsyncRead.html pub struct PipeReader { pub(crate) state: Arc>, } @@ -46,7 +52,7 @@ impl PipeReader { } }; - Ok(state.done_cycle) + Ok(state.buffer.is_empty()) } fn wake_writer_half(&self, state: &State) { @@ -55,36 +61,13 @@ impl PipeReader { } } - fn copy_data_into_buffer(&self, data: &Data, buf: &mut ReadBuf) -> usize { - let len = data.len.min(buf.capacity()); - unsafe { - ptr::copy_nonoverlapping(data.ptr, buf.initialize_unfilled().as_mut_ptr(), len); - } - len - } -} - -impl Drop for PipeReader { - fn drop(&mut self) { - if let Err(err) = self.close() { - log::warn!( - "{}: PipeReader: Failed to close the channel on drop: {}", - env!("CARGO_PKG_NAME"), - err - ); - } - } -} - -impl AsyncRead for PipeReader { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf, - ) -> Poll> { - let mut state; - match self.state.lock() { - Ok(s) => state = s, + buf: &mut [u8], + ) -> Poll> { + let mut state = match self.state.lock() { + Ok(s) => s, Err(err) => { return Poll::Ready(Err(io::Error::new( io::ErrorKind::Other, @@ -95,31 +78,49 @@ impl AsyncRead for PipeReader { ), ))) } - } - - if state.closed { - return Poll::Ready(Ok(())); - } - - return if state.done_cycle { - state.reader_waker = Some(cx.waker().clone()); - Poll::Pending - } else { - if let Some(ref data) = state.data { - let copied_bytes_len = self.copy_data_into_buffer(data, buf); - - state.data = None; - state.read = copied_bytes_len; - state.done_reading = true; - state.reader_waker = None; - - self.wake_writer_half(&*state); + }; - Poll::Ready(Ok(())) + if state.buffer.is_empty() { + if state.closed || Arc::strong_count(&self.state) == 1 { + Poll::Ready(Ok(0)) } else { + self.wake_writer_half(&*state); state.reader_waker = Some(cx.waker().clone()); Poll::Pending } - }; + } else { + self.wake_writer_half(&*state); + let size_to_read = state.buffer.len().min(buf.len()); + let (to_read, rest) = state.buffer.split_at(size_to_read); + buf[..size_to_read].copy_from_slice(to_read); + state.buffer = rest.to_vec(); + + Poll::Ready(Ok(size_to_read)) + } + } +} + +#[cfg(feature = "tokio")] +impl tokio::io::AsyncRead for PipeReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut tokio::io::ReadBuf, + ) -> Poll> { + let dst = buf.initialize_unfilled(); + self.poll_read(cx, dst).map_ok(|read| { + buf.advance(read); + }) + } +} + +#[cfg(feature = "futures")] +impl futures::io::AsyncRead for PipeReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + self.poll_read(cx, buf) } } diff --git a/src/state.rs b/src/state.rs index 7982fdf..fc37fbd 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,18 +1,10 @@ use std::task::Waker; +pub const BUFFER_SIZE: usize = 1024; + pub(crate) struct State { pub(crate) reader_waker: Option, pub(crate) writer_waker: Option, - pub(crate) data: Option, - pub(crate) done_reading: bool, - pub(crate) read: usize, - pub(crate) done_cycle: bool, pub(crate) closed: bool, + pub(crate) buffer: Vec, } - -pub(crate) struct Data { - pub(crate) ptr: *const u8, - pub(crate) len: usize, -} - -unsafe impl Send for Data {} diff --git a/src/writer.rs b/src/writer.rs index 119fcca..f32c3b2 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -1,11 +1,17 @@ -use crate::state::Data; -use crate::state::State; +use crate::state::{State, BUFFER_SIZE}; +use std::io; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use tokio::io::{self, AsyncWrite}; -/// The write half of the pipe which implements [`AsyncWrite`](https://docs.rs/tokio/0.2.16/tokio/io/trait.AsyncWrite.html). +/// The write half of the pipe +/// +/// Implements [`tokio::io::AsyncWrite`][tokio-async-write] when feature `tokio` is enabled (the +/// default). Implements [`futures::io::AsyncWrite`][futures-async-write] when feature `futures` is +/// enabled. +/// +/// [futures-async-write]: https://docs.rs/futures/0.3.16/futures/io/trait.AsyncWrite.html +/// [tokio-async-write]: https://docs.rs/tokio/1.9.0/tokio/io/trait.AsyncWrite.html pub struct PipeWriter { pub(crate) state: Arc>, } @@ -46,7 +52,7 @@ impl PipeWriter { } }; - Ok(state.done_cycle) + Ok(state.buffer.is_empty()) } fn wake_reader_half(&self, state: &State) { @@ -54,25 +60,20 @@ impl PipeWriter { waker.clone().wake(); } } -} -impl Drop for PipeWriter { - fn drop(&mut self) { - if let Err(err) = self.close() { - log::warn!( - "{}: PipeWriter: Failed to close the channel on drop: {}", - env!("CARGO_PKG_NAME"), - err - ); + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + if Arc::strong_count(&self.state) == 1 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + format!( + "{}: PipeWriter: The channel is closed", + env!("CARGO_PKG_NAME") + ), + ))); } - } -} -impl AsyncWrite for PipeWriter { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { - let mut state; - match self.state.lock() { - Ok(s) => state = s, + let mut state = match self.state.lock() { + Ok(s) => s, Err(err) => { return Poll::Ready(Err(io::Error::new( io::ErrorKind::Other, @@ -83,49 +84,43 @@ impl AsyncWrite for PipeWriter { ), ))) } - } + }; - if state.closed { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - format!( - "{}: PipeWriter: The channel is closed", - env!("CARGO_PKG_NAME") - ), - ))); - } + self.wake_reader_half(&*state); - return if state.done_cycle { - state.data = Some(Data { - ptr: buf.as_ptr(), - len: buf.len(), - }); - state.done_cycle = false; + let remaining = BUFFER_SIZE - state.buffer.len(); + if remaining == 0 { state.writer_waker = Some(cx.waker().clone()); - - self.wake_reader_half(&*state); - Poll::Pending } else { - if state.done_reading { - let read_bytes_len = state.read; - - state.done_cycle = true; - state.read = 0; - state.writer_waker = None; - state.data = None; - state.done_reading = false; - - Poll::Ready(Ok(read_bytes_len)) - } else { - state.writer_waker = Some(cx.waker().clone()); - Poll::Pending + let bytes_to_write = remaining.min(buf.len()); + state.buffer.extend_from_slice(&buf[..bytes_to_write]); + Poll::Ready(Ok(bytes_to_write)) + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let mut state = match self.state.lock() { + Ok(s) => s, + Err(err) => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + format!( + "{}: PipeWriter: Failed to lock the channel state: {}", + env!("CARGO_PKG_NAME"), + err + ), + ))) } }; - } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { - Poll::Ready(Ok(())) + if state.buffer.is_empty() { + Poll::Ready(Ok(())) + } else { + state.writer_waker = Some(cx.waker().clone()); + self.wake_reader_half(&*state); + Poll::Pending + } } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { @@ -142,3 +137,33 @@ impl AsyncWrite for PipeWriter { } } } + +#[cfg(feature = "tokio")] +impl tokio::io::AsyncWrite for PipeWriter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + self.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.poll_shutdown(cx) + } +} + +#[cfg(feature = "futures")] +impl futures::io::AsyncWrite for PipeWriter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + self.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.poll_shutdown(cx) + } +}