Skip to content

Commit

Permalink
Simplify AsyncWrite protocol. Add implementation for shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
dmzmk committed May 12, 2023
1 parent a505ab7 commit 4494fc5
Showing 1 changed file with 34 additions and 49 deletions.
83 changes: 34 additions & 49 deletions journal/src/async_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::{Error as JournalError, Journal, Protocol, Stream as JournalStream};
use serde_sqlite::de;
use tokio::sync::mpsc::error::TrySendError;
use std::io::{BufRead, Read, Write};
use std::path::PathBuf;
use std::pin::Pin;
Expand Down Expand Up @@ -143,26 +144,24 @@ impl AsyncRead for AsyncReadJournalStreamHandle {

#[derive(Debug)]
enum AsyncWriteProto {
W(Waker),
B(Vec<u8>),
WriteBuf(Vec<u8>, Waker),
Shutdown(Waker),
}

pub struct ReadReceiver {
buf: Vec<u8>,
buf_pos: usize,
waker: Option<Waker>,
rx: Receiver<AsyncWriteProto>,
tx: Sender<()>,
}

impl ReadReceiver {
fn new(rx: Receiver<AsyncWriteProto>, tx: Sender<()>) -> Self {
fn new(rx: Receiver<AsyncWriteProto>) -> Self {
Self {
buf: vec![],
buf_pos: 0,
waker: None,
rx,
tx,
}
}
}
Expand All @@ -175,23 +174,19 @@ impl BufRead for ReadReceiver {
self.buf_pos = 0;
self.buf.clear();
}
// buffer request
self.tx.try_send(()).ok();

loop {
// wake up future
if let Some(waker) = self.waker.take() {
waker.wake()
}
match self.rx.blocking_recv() {
Some(AsyncWriteProto::W(waker)) => {
self.waker = Some(waker);
}
Some(AsyncWriteProto::B(buf)) => {
Some(AsyncWriteProto::WriteBuf(buf, waker)) => {
waker.wake();
self.buf = buf;
self.buf_pos = 0;
break;
}
},
Some(AsyncWriteProto::Shutdown(waker)) => {
self.waker = Some(waker);
break;
},
None => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
Expand Down Expand Up @@ -237,11 +232,12 @@ impl Drop for ReadReceiver {
fn drop(&mut self) {
self.rx.close();
if let Some(waker) = self.waker.take() {
waker.wake()
waker.wake();
}
while let Ok(message) = self.rx.try_recv() {
if let AsyncWriteProto::W(waker) = message {
waker.wake()
match message {
AsyncWriteProto::WriteBuf(_buf, waker) => waker.wake(),
AsyncWriteProto::Shutdown(waker) => waker.wake(),
}
}
}
Expand All @@ -259,15 +255,10 @@ impl AsyncWriteJournalStream {
}

pub fn spawn(mut self) -> AsyncWriteJournalStreamHandle {
let (buf_tx, buf_rx) = channel(2); // enough space to store waker and buf
let (req_tx, req_rx) = channel(1);
let read_receiver = ReadReceiver::new(buf_rx, req_tx);
let (tx, rx) = channel(1); // enough space to store waker and buf
let read_receiver = ReadReceiver::new(rx);
let join_handle = tokio::task::spawn_blocking(move || self.enter_loop(read_receiver));
AsyncWriteJournalStreamHandle {
tx: buf_tx,
rx: req_rx,
join_handle,
}
AsyncWriteJournalStreamHandle { tx, join_handle }
}

pub fn enter_loop(&mut self, mut read_receiver: ReadReceiver) -> Result<(), JournalError> {
Expand Down Expand Up @@ -305,6 +296,7 @@ impl AsyncWriteJournalStream {
}
Protocol::EndOfStream(_) => {
journal.commit().map_err(to_err)?;
drop(journal);
return Ok(());
}
msg => {
Expand All @@ -322,7 +314,6 @@ impl AsyncWriteJournalStream {
#[derive(Debug)]
pub struct AsyncWriteJournalStreamHandle {
tx: Sender<AsyncWriteProto>,
rx: Receiver<()>,
join_handle: tokio::task::JoinHandle<Result<(), JournalError>>,
}

Expand All @@ -339,33 +330,27 @@ impl AsyncWrite for AsyncWriteJournalStreamHandle {
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let me = self.get_mut();
match me.rx.try_recv() {
Err(TryRecvError::Empty) => {
match me.tx.try_send(AsyncWriteProto::W(ctx.waker().clone())) {
Ok(_) => (),
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => (),
Err(e) => return Poll::Ready(Err(to_err(e))),
}
Poll::Pending
}
Err(e @ TryRecvError::Disconnected) => Poll::Ready(Err(to_err(e))),
Ok(_) => {
// eh
let len = buf.len();
let buf: Vec<u8> = buf.into();
match me.tx.try_send(AsyncWriteProto::B(buf)) {
Ok(_) => Poll::Ready(Ok(len)),
Err(e) => Poll::Ready(Err(to_err(e))),
}
}
match me.tx.try_send(AsyncWriteProto::WriteBuf(buf.into(), ctx.waker().clone())) {
Ok(_) => Poll::Ready(Ok(buf.len())),
Err(TrySendError::Full(_)) => Poll::Pending,
Err(e@TrySendError::Closed(_)) => Poll::Ready(Err(to_err(e))),
}
}

fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let me = self.get_mut();
match me.tx.try_send(AsyncWriteProto::Shutdown(ctx.waker().clone())) {
Ok(_) => Poll::Pending,
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
Poll::Pending
},
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
Poll::Ready(Ok(()))
}
}
}
}

0 comments on commit 4494fc5

Please sign in to comment.