Skip to content

Commit

Permalink
Merge pull request #7 from geigerzaehler/fix-panics-and-crashes
Browse files Browse the repository at this point in the history
Fix panics and unsafe code
  • Loading branch information
gsserge authored Jul 28, 2021
2 parents 2349e6e + 3949517 commit cd927c3
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 142 deletions.
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@ readme = "README.md"
license = "MIT"
edition = "2018"

[features]
default = ["tokio"]

[dependencies]
tokio = { version = "1", features= [] }
tokio = { version = "1", features= [], optional = true }
log = "0.4"
futures = { version = "0.3", optional = true }

[dev-dependencies]
tokio = { version = "1", features = ["full"] }

[package.metadata.docs.rs]
features = ["futures", "tokio"]
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
[![Documentation](https://docs.rs/async-pipe/badge.svg)](https://docs.rs/async-pipe)
[![MIT](https://img.shields.io/crates/l/async-pipe.svg)](./LICENSE)

Creates an asynchronous piped reader and writer pair using `tokio.rs`.
Creates an asynchronous piped reader and writer pair using `tokio.rs` or
`futures`

[Docs](https://docs.rs/async-pipe)

Expand Down Expand Up @@ -38,4 +39,4 @@ async fn main() {

## Contributing

Your PRs and stars are always welcome.
Your PRs and stars are always welcome.
4 changes: 1 addition & 3 deletions examples/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
async fn main() {
let (mut w, mut r) = async_pipe::pipe();

tokio::spawn(async move {
w.write_all(b"hello world").await.unwrap();
});
let _ = w.write_all(b"hello world").await;

let mut v = Vec::new();
r.read_to_end(&mut v).await.unwrap();
Expand Down
76 changes: 60 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! Creates an asynchronous piped reader and writer pair using `tokio.rs`.
//! Creates an asynchronous piped reader and writer pair using `tokio.rs` and `futures`.
//!
//! # Examples
//!
Expand All @@ -21,6 +21,11 @@
//!
//! tokio::runtime::Runtime::new().unwrap().block_on(run());
//! ```
//!
//! # Featues
//!
//! * `tokio` (default) Implement `AsyncWrite` and `AsyncRead` from `tokio::io`.
//! * `futures` Implement `AsyncWrite` and `AsyncRead` from `futures::io`
use state::State;
use std::sync::{Arc, Mutex};
Expand All @@ -37,42 +42,81 @@ pub fn pipe() -> (PipeWriter, PipeReader) {
let shared_state = Arc::new(Mutex::new(State {
reader_waker: None,
writer_waker: None,
data: None,
done_reading: false,
read: 0,
done_cycle: true,
closed: false,
buffer: Vec::new(),
}));

let w = PipeWriter {
state: shared_state.clone(),
};

let r = PipeReader {
state: shared_state.clone(),
state: shared_state,
};

(w, r)
}

#[cfg(test)]
mod tests {
mod test {
use super::*;
use std::io;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::test]
async fn should_read_expected_text() {
const EXPECTED: &'static str = "hello world";
async fn read_write() {
let (mut writer, mut reader) = pipe();
let data = b"hello world";

let write_handle = tokio::spawn(async move {
writer.write_all(data).await.unwrap();
});

let mut read_buf = Vec::new();
reader.read_to_end(&mut read_buf).await.unwrap();
write_handle.await.unwrap();

assert_eq!(&read_buf, data);
}

#[tokio::test]
async fn eof_when_writer_is_shutdown() {
let (mut writer, mut reader) = pipe();
writer.shutdown().await.unwrap();
let mut buf = [0u8; 8];
let bytes_read = reader.read(&mut buf).await.unwrap();
assert_eq!(bytes_read, 0);
}

#[tokio::test]
async fn broken_pipe_when_reader_is_dropped() {
let (mut writer, reader) = pipe();
drop(reader);
let io_error = writer.write_all(&[0u8; 8]).await.unwrap_err();
assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe);
}

let (mut w, mut r) = pipe();
#[tokio::test]
async fn eof_when_writer_is_dropped() {
let (writer, mut reader) = pipe();
drop(writer);
let mut buf = [0u8; 8];
let bytes_read = reader.read(&mut buf).await.unwrap();
assert_eq!(bytes_read, 0);
}

#[tokio::test]
async fn drop_read_exact() {
let (mut writer, mut reader) = pipe();
const BUF_SIZE: usize = 8;

tokio::spawn(async move {
w.write_all(EXPECTED.as_bytes()).await.unwrap();
let write_handle = tokio::spawn(async move {
writer.write_all(&[0u8; BUF_SIZE]).await.unwrap();
});

let mut v = Vec::new();
r.read_to_end(&mut v).await.unwrap();
let actual = String::from_utf8(v).unwrap();
assert_eq!(EXPECTED, actual.as_str());
let mut buf = [0u8; BUF_SIZE];
reader.read_exact(&mut buf).await.unwrap();
drop(reader);
write_handle.await.unwrap();
}
}
107 changes: 54 additions & 53 deletions src/reader.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use crate::state::{Data, State};
use crate::state::State;
use std::io;
use std::pin::Pin;
use std::ptr;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::io::{self, AsyncRead, ReadBuf};

/// The read half of the pipe which implements [`AsyncRead`](https://docs.rs/tokio/0.2.15/tokio/io/trait.AsyncRead.html).
/// The read half of the pipe
///
/// Implements [`tokio::io::AsyncRead`][tokio-async-read] when feature `tokio` is enabled (the
/// default). Implements [`futures::io::AsyncRead`][futures-async-read] when feature `futures` is
/// enabled.
///
/// [futures-async-read]: https://docs.rs/futures/0.3.16/futures/io/trait.AsyncRead.html
/// [tokio-async-read]: https://docs.rs/tokio/1.9.0/tokio/io/trait.AsyncRead.html
pub struct PipeReader {
pub(crate) state: Arc<Mutex<State>>,
}
Expand Down Expand Up @@ -46,7 +52,7 @@ impl PipeReader {
}
};

Ok(state.done_cycle)
Ok(state.buffer.is_empty())
}

fn wake_writer_half(&self, state: &State) {
Expand All @@ -55,36 +61,13 @@ impl PipeReader {
}
}

fn copy_data_into_buffer(&self, data: &Data, buf: &mut ReadBuf) -> usize {
let len = data.len.min(buf.capacity());
unsafe {
ptr::copy_nonoverlapping(data.ptr, buf.initialize_unfilled().as_mut_ptr(), len);
}
len
}
}

impl Drop for PipeReader {
fn drop(&mut self) {
if let Err(err) = self.close() {
log::warn!(
"{}: PipeReader: Failed to close the channel on drop: {}",
env!("CARGO_PKG_NAME"),
err
);
}
}
}

impl AsyncRead for PipeReader {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let mut state;
match self.state.lock() {
Ok(s) => state = s,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut state = match self.state.lock() {
Ok(s) => s,
Err(err) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
Expand All @@ -95,31 +78,49 @@ impl AsyncRead for PipeReader {
),
)))
}
}

if state.closed {
return Poll::Ready(Ok(()));
}

return if state.done_cycle {
state.reader_waker = Some(cx.waker().clone());
Poll::Pending
} else {
if let Some(ref data) = state.data {
let copied_bytes_len = self.copy_data_into_buffer(data, buf);

state.data = None;
state.read = copied_bytes_len;
state.done_reading = true;
state.reader_waker = None;

self.wake_writer_half(&*state);
};

Poll::Ready(Ok(()))
if state.buffer.is_empty() {
if state.closed || Arc::strong_count(&self.state) == 1 {
Poll::Ready(Ok(0))
} else {
self.wake_writer_half(&*state);
state.reader_waker = Some(cx.waker().clone());
Poll::Pending
}
};
} else {
self.wake_writer_half(&*state);
let size_to_read = state.buffer.len().min(buf.len());
let (to_read, rest) = state.buffer.split_at(size_to_read);
buf[..size_to_read].copy_from_slice(to_read);
state.buffer = rest.to_vec();

Poll::Ready(Ok(size_to_read))
}
}
}

#[cfg(feature = "tokio")]
impl tokio::io::AsyncRead for PipeReader {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut tokio::io::ReadBuf,
) -> Poll<io::Result<()>> {
let dst = buf.initialize_unfilled();
self.poll_read(cx, dst).map_ok(|read| {
buf.advance(read);
})
}
}

#[cfg(feature = "futures")]
impl futures::io::AsyncRead for PipeReader {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
self.poll_read(cx, buf)
}
}
14 changes: 3 additions & 11 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
use std::task::Waker;

pub const BUFFER_SIZE: usize = 1024;

pub(crate) struct State {
pub(crate) reader_waker: Option<Waker>,
pub(crate) writer_waker: Option<Waker>,
pub(crate) data: Option<Data>,
pub(crate) done_reading: bool,
pub(crate) read: usize,
pub(crate) done_cycle: bool,
pub(crate) closed: bool,
pub(crate) buffer: Vec<u8>,
}

pub(crate) struct Data {
pub(crate) ptr: *const u8,
pub(crate) len: usize,
}

unsafe impl Send for Data {}
Loading

0 comments on commit cd927c3

Please sign in to comment.