Skip to content

Commit

Permalink
Fix panics and unsafe code
Browse files Browse the repository at this point in the history
This change fixes panics (routerify#6) and unsafe code (routerify#5). This comes at the
cost of an additional copy of the data send through the pipe and having
a buffer in the state.

Moreover all unsafe code is removed and the need for a custom `Drop`
implementation which makes the code overall easier.

We also add tests.
  • Loading branch information
Thomas Scholtes authored and ttiurani committed Oct 25, 2020
1 parent 0004c04 commit 18f0678
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 113 deletions.
69 changes: 65 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,8 @@ pub fn pipe() -> (PipeWriter, PipeReader) {
let shared_state = Arc::new(Mutex::new(State {
reader_waker: None,
writer_waker: None,
data: None,
done_reading: false,
read: 0,
done_cycle: true,
closed: false,
buffer: Vec::new(),
}));

let w = PipeWriter {
Expand All @@ -59,3 +56,67 @@ pub fn pipe() -> (PipeWriter, PipeReader) {

(w, r)
}

#[cfg(test)]
mod test {
use super::pipe;
use std::io;
use tokio::prelude::*;

#[tokio::test]
async fn read_write() {
let (mut writer, mut reader) = pipe();
let data = b"hello world";

let write_handle = tokio::spawn(async move {
writer.write_all(data).await.unwrap();
});

let mut read_buf = Vec::new();
reader.read_to_end(&mut read_buf).await.unwrap();
write_handle.await.unwrap();

assert_eq!(&read_buf, data);
}

#[tokio::test]
async fn eof_when_writer_is_shutdown() {
let (mut writer, mut reader) = pipe();
writer.shutdown().await.unwrap();
let mut buf = [0u8; 8];
let bytes_read = reader.read(&mut buf).await.unwrap();
assert_eq!(bytes_read, 0);
}

#[tokio::test]
async fn broken_pipe_when_reader_is_dropped() {
let (mut writer, reader) = pipe();
drop(reader);
let io_error = writer.write_all(&[0u8; 8]).await.unwrap_err();
assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe);
}

#[tokio::test]
async fn eof_when_writer_is_dropped() {
let (writer, mut reader) = pipe();
drop(writer);
let mut buf = [0u8; 8];
let bytes_read = reader.read(&mut buf).await.unwrap();
assert_eq!(bytes_read, 0);
}

#[tokio::test]
async fn drop_read_exact() {
let (mut writer, mut reader) = pipe();
const BUF_SIZE: usize = 8;

let write_handle = tokio::spawn(async move {
writer.write_all(&mut [0u8; BUF_SIZE]).await.unwrap();
});

let mut buf = [0u8; BUF_SIZE];
reader.read_exact(&mut buf).await.unwrap();
drop(reader);
write_handle.await.unwrap();
}
}
61 changes: 16 additions & 45 deletions src/reader.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::state::{Data, State};
use crate::state::State;
use std::io;
use std::pin::Pin;
use std::ptr;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

Expand Down Expand Up @@ -53,7 +52,7 @@ impl PipeReader {
}
};

Ok(state.done_cycle)
Ok(state.buffer.is_empty())
}

fn wake_writer_half(&self, state: &State) {
Expand All @@ -62,22 +61,13 @@ impl PipeReader {
}
}

fn copy_data_into_buffer(&self, data: &Data, buf: &mut [u8]) -> usize {
let len = data.len.min(buf.len());
unsafe {
ptr::copy_nonoverlapping(data.ptr, buf.as_mut_ptr(), len);
}
len
}

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut state;
match self.state.lock() {
Ok(s) => state = s,
let mut state = match self.state.lock() {
Ok(s) => s,
Err(err) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
Expand All @@ -88,43 +78,24 @@ impl PipeReader {
),
)))
}
}

if state.closed {
return Poll::Ready(Ok(0));
}

return if state.done_cycle {
state.reader_waker = Some(cx.waker().clone());
Poll::Pending
} else {
if let Some(ref data) = state.data {
let copied_bytes_len = self.copy_data_into_buffer(data, buf);

state.data = None;
state.read = copied_bytes_len;
state.done_reading = true;
state.reader_waker = None;

self.wake_writer_half(&*state);
};

Poll::Ready(Ok(copied_bytes_len))
if state.buffer.is_empty() {
if state.closed || Arc::strong_count(&self.state) == 1 {
Poll::Ready(Ok(0))
} else {
self.wake_writer_half(&*state);
state.reader_waker = Some(cx.waker().clone());
Poll::Pending
}
};
}
}
} else {
self.wake_writer_half(&*state);
let size_to_read = state.buffer.len().min(buf.len());
let (to_read, rest) = state.buffer.split_at(size_to_read);
buf[..size_to_read].copy_from_slice(to_read);
state.buffer = rest.to_vec();

impl Drop for PipeReader {
fn drop(&mut self) {
if let Err(err) = self.close() {
log::warn!(
"{}: PipeReader: Failed to close the channel on drop: {}",
env!("CARGO_PKG_NAME"),
err
);
Poll::Ready(Ok(size_to_read))
}
}
}
Expand Down
14 changes: 3 additions & 11 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
use std::task::Waker;

pub const BUFFER_SIZE: usize = 1024;

pub(crate) struct State {
pub(crate) reader_waker: Option<Waker>,
pub(crate) writer_waker: Option<Waker>,
pub(crate) data: Option<Data>,
pub(crate) done_reading: bool,
pub(crate) read: usize,
pub(crate) done_cycle: bool,
pub(crate) closed: bool,
pub(crate) buffer: Vec<u8>,
}

pub(crate) struct Data {
pub(crate) ptr: *const u8,
pub(crate) len: usize,
}

unsafe impl Send for Data {}
96 changes: 43 additions & 53 deletions src/writer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::state::Data;
use crate::state::State;
use crate::state::{State, BUFFER_SIZE};
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
Expand Down Expand Up @@ -53,7 +52,7 @@ impl PipeWriter {
}
};

Ok(state.done_cycle)
Ok(state.buffer.is_empty())
}

fn wake_reader_half(&self, state: &State) {
Expand All @@ -63,9 +62,18 @@ impl PipeWriter {
}

fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let mut state;
match self.state.lock() {
Ok(s) => state = s,
if Arc::strong_count(&self.state) == 1 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
format!(
"{}: PipeWriter: The channel is closed",
env!("CARGO_PKG_NAME")
),
)));
}

let mut state = match self.state.lock() {
Ok(s) => s,
Err(err) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
Expand All @@ -76,49 +84,43 @@ impl PipeWriter {
),
)))
}
}
};

if state.closed {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
format!(
"{}: PipeWriter: The channel is closed",
env!("CARGO_PKG_NAME")
),
)));
}
self.wake_reader_half(&*state);

return if state.done_cycle {
state.data = Some(Data {
ptr: buf.as_ptr(),
len: buf.len(),
});
state.done_cycle = false;
let remaining = BUFFER_SIZE - state.buffer.len();
if remaining == 0 {
state.writer_waker = Some(cx.waker().clone());

self.wake_reader_half(&*state);

Poll::Pending
} else {
if state.done_reading {
let read_bytes_len = state.read;

state.done_cycle = true;
state.read = 0;
state.writer_waker = None;
state.data = None;
state.done_reading = false;

Poll::Ready(Ok(read_bytes_len))
} else {
state.writer_waker = Some(cx.waker().clone());
Poll::Pending
let bytes_to_write = remaining.min(buf.len());
state.buffer.extend_from_slice(&buf[..bytes_to_write]);
Poll::Ready(Ok(bytes_to_write))
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let mut state = match self.state.lock() {
Ok(s) => s,
Err(err) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
format!(
"{}: PipeWriter: Failed to lock the channel state: {}",
env!("CARGO_PKG_NAME"),
err
),
)))
}
};
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
if state.buffer.is_empty() {
Poll::Ready(Ok(()))
} else {
state.writer_waker = Some(cx.waker().clone());
self.wake_reader_half(&*state);
Poll::Pending
}
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Expand All @@ -136,18 +138,6 @@ impl PipeWriter {
}
}

impl Drop for PipeWriter {
fn drop(&mut self) {
if let Err(err) = self.close() {
log::warn!(
"{}: PipeWriter: Failed to close the channel on drop: {}",
env!("CARGO_PKG_NAME"),
err
);
}
}
}

#[cfg(feature = "tokio")]
impl tokio::io::AsyncWrite for PipeWriter {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
Expand Down

0 comments on commit 18f0678

Please sign in to comment.