Skip to content

Commit

Permalink
refactor: Add concurrent error test for BlockWrite (#3968)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuanwo authored Jan 10, 2024
1 parent 2c77e1d commit 7144ab1
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 74 deletions.
265 changes: 209 additions & 56 deletions core/src/raw/oio/write/block_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use async_trait::async_trait;
use futures::Future;
use futures::FutureExt;
use futures::StreamExt;
use uuid::Uuid;

use crate::raw::*;
use crate::*;
Expand Down Expand Up @@ -77,17 +78,22 @@ pub trait BlockWrite: Send + Sync + Unpin + 'static {
/// order.
///
/// - block_id is the id of the block.
async fn write_block(&self, size: u64, block_id: String, body: AsyncBody) -> Result<()>;
async fn write_block(&self, block_id: Uuid, size: u64, body: AsyncBody) -> Result<()>;

/// complete_block will complete the block upload to build the final
/// file.
async fn complete_block(&self, block_ids: Vec<String>) -> Result<()>;
async fn complete_block(&self, block_ids: Vec<Uuid>) -> Result<()>;

/// abort_block will cancel the block upload and purge all data.
async fn abort_block(&self, block_ids: Vec<String>) -> Result<()>;
async fn abort_block(&self, block_ids: Vec<Uuid>) -> Result<()>;
}

struct WriteBlockFuture(BoxedFuture<Result<()>>);
/// WriteBlockResult is the result returned by [`WriteBlockFuture`].
///
/// The error part carries the input as `(block_id, bytes, err)` so the caller can retry the block.
type WriteBlockResult = std::result::Result<Uuid, (Uuid, oio::ChunkedBytes, Error)>;

struct WriteBlockFuture(BoxedFuture<WriteBlockResult>);

/// # Safety
///
Expand All @@ -100,19 +106,38 @@ unsafe impl Send for WriteBlockFuture {}
unsafe impl Sync for WriteBlockFuture {}

impl Future for WriteBlockFuture {
type Output = Result<()>;
type Output = WriteBlockResult;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.get_mut().0.poll_unpin(cx)
}
}

impl WriteBlockFuture {
    /// Build a future that uploads `bytes` as the block identified by `block_id`.
    ///
    /// The future resolves to `Ok(block_id)` when `write_block` succeeds, and to
    /// `Err((block_id, bytes, err))` when it fails, so the caller can submit the
    /// very same block again.
    pub fn new<W: BlockWrite>(w: Arc<W>, block_id: Uuid, bytes: oio::ChunkedBytes) -> Self {
        WriteBlockFuture(Box::pin(async move {
            let size = bytes.len() as u64;
            // Hand a clone to the writer: the original bytes must stay available
            // so they can be returned to the caller if the upload fails.
            let body = AsyncBody::ChunkedBytes(bytes.clone());

            match w.write_block(block_id, size, body).await {
                // Report which block has been written.
                Ok(()) => Ok(block_id),
                // Give the input back together with the error to allow a retry.
                Err(err) => Err((block_id, bytes, err)),
            }
        }))
    }
}

/// BlockWriter implements [`Write`] based on block
/// uploads.
pub struct BlockWriter<W: BlockWrite> {
state: State,
w: Arc<W>,

block_ids: Vec<String>,
block_ids: Vec<Uuid>,
cache: Option<oio::ChunkedBytes>,
futures: ConcurrentFutures<WriteBlockFuture>,
}
Expand Down Expand Up @@ -168,20 +193,30 @@ where
let size = self.fill_cache(bs);
return Poll::Ready(Ok(size));
}

let cache = self.cache.take().expect("pending write must exist");
let block_id = uuid::Uuid::new_v4().to_string();
self.block_ids.push(block_id.clone());
let w = self.w.clone();
let size = cache.len();
self.futures
.push_back(WriteBlockFuture(Box::pin(async move {
w.write_block(size as u64, block_id, AsyncBody::ChunkedBytes(cache))
.await
})));
self.futures.push_back(WriteBlockFuture::new(
self.w.clone(),
Uuid::new_v4(),
cache,
));

let size = self.fill_cache(bs);
return Poll::Ready(Ok(size));
} else if let Some(res) = ready!(self.futures.poll_next_unpin(cx)) {
res?;
match res {
Ok(block_id) => {
self.block_ids.push(block_id);
}
Err((block_id, bytes, err)) => {
self.futures.push_front(WriteBlockFuture::new(
self.w.clone(),
block_id,
bytes,
));
return Poll::Ready(Err(err));
}
}
}
}
State::Close(_) => {
Expand All @@ -198,53 +233,55 @@ where
loop {
match &mut self.state {
State::Idle => {
let w = self.w.clone();
let block_ids = self.block_ids.clone();
// No write block has been sent.
if self.futures.is_empty() && self.block_ids.is_empty() {
let w = self.w.clone();
let (size, body) = match self.cache.clone() {
Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)),
None => (0, AsyncBody::Empty),
};
// No block has been uploaded yet: send whatever is cached (possibly nothing) via write_once.
self.state =
State::Close(Box::pin(
async move { w.write_once(size as u64, body).await },
));
continue;
}

if self.block_ids.is_empty() {
match &self.cache {
Some(cache) => {
let w = self.w.clone();
let bs = cache.clone();
self.state = State::Close(Box::pin(async move {
let size = bs.len();
w.write_once(size as u64, AsyncBody::ChunkedBytes(bs)).await
}));
}
None => {
let w = self.w.clone();
// Call write_once if there is no data in cache.
self.state = State::Close(Box::pin(async move {
w.write_once(0, AsyncBody::Empty).await
}));
if self.futures.has_remaining() {
if let Some(cache) = self.cache.take() {
self.futures.push_back(WriteBlockFuture::new(
self.w.clone(),
Uuid::new_v4(),
cache,
));
}
}

if !self.futures.is_empty() {
while let Some(result) = ready!(self.futures.poll_next_unpin(cx)) {
match result {
Ok(block_id) => {
self.block_ids.push(block_id);
}
Err((block_id, bytes, err)) => {
self.futures.push_front(WriteBlockFuture::new(
self.w.clone(),
block_id,
bytes,
));
return Poll::Ready(Err(err));
}
}
}
} else if self.futures.is_empty() && self.cache.is_none() {
} else {
let w = self.w.clone();
let block_ids = self.block_ids.clone();
self.state =
State::Close(Box::pin(
async move { w.complete_block(block_ids).await },
));
} else {
if self.futures.has_remaining() {
if let Some(cache) = self.cache.take() {
let block_id = uuid::Uuid::new_v4().to_string();
self.block_ids.push(block_id.clone());
let size = cache.len();
let w = self.w.clone();
self.futures
.push_back(WriteBlockFuture(Box::pin(async move {
w.write_block(
size as u64,
block_id,
AsyncBody::ChunkedBytes(cache),
)
.await
})));
}
}
while let Some(res) = ready!(self.futures.poll_next_unpin(cx)) {
res?;
}
continue;
}
}
State::Close(fut) => {
Expand All @@ -270,6 +307,7 @@ where
let w = self.w.clone();
let block_ids = self.block_ids.clone();
self.futures.clear();
self.cache = None;
self.state =
State::Abort(Box::pin(async move { w.abort_block(block_ids).await }));
}
Expand All @@ -285,3 +323,118 @@ where
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::raw::oio::{StreamExt, WriteBuf, WriteExt};
use bytes::Bytes;
use pretty_assertions::assert_eq;
use rand::{thread_rng, Rng, RngCore};
use std::collections::HashMap;
use std::sync::Mutex;

/// State shared behind the `Arc<Mutex<TestWrite>>` implementation of
/// `BlockWrite` used by the tests in this module.
struct TestWrite {
/// Running sum of the `size` argument of every `write_block` call that succeeded.
length: u64,
/// Bytes received by each successful `write_block` call, keyed by block id.
bytes: HashMap<Uuid, Bytes>,
/// Blocks concatenated in the order passed to `complete_block`; `None` until
/// `complete_block` has run.
content: Option<Bytes>,
}

impl TestWrite {
    /// Create an empty `TestWrite` (zero length, no blocks, no content)
    /// wrapped in `Arc<Mutex<_>>`.
    pub fn new() -> Arc<Mutex<Self>> {
        Arc::new(Mutex::new(Self {
            length: 0,
            bytes: HashMap::new(),
            content: None,
        }))
    }
}

#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl BlockWrite for Arc<Mutex<TestWrite>> {
    /// Accepts anything and stores nothing.
    async fn write_once(&self, _: u64, _: AsyncBody) -> Result<()> {
        Ok(())
    }

    /// Store `body` under `block_id` and add `size` to the running length.
    ///
    /// Fails at random before touching any state, so a failed call leaves
    /// nothing recorded for the block.
    async fn write_block(&self, block_id: Uuid, size: u64, body: AsyncBody) -> Result<()> {
        // Each call fails with a probability of 50%.
        if thread_rng().gen_bool(0.5) {
            return Err(Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!"));
        }

        // Flatten whatever kind of body we were handed into contiguous bytes.
        let data = match body {
            AsyncBody::Empty => Bytes::new(),
            AsyncBody::Bytes(bs) => bs,
            AsyncBody::ChunkedBytes(cb) => cb.bytes(cb.remaining()),
            AsyncBody::Stream(s) => s.collect().await.unwrap(),
        };

        // The lock is taken only after the last await point above.
        let mut state = self.lock().unwrap();
        state.length += size;
        state.bytes.insert(block_id, data);

        Ok(())
    }

    /// Concatenate the stored blocks in the order given by `block_ids` and
    /// publish the result as the final content.
    ///
    /// Panics if an id was never stored by `write_block`.
    async fn complete_block(&self, block_ids: Vec<Uuid>) -> Result<()> {
        let mut state = self.lock().unwrap();

        let mut merged = Vec::new();
        for id in &block_ids {
            merged.extend_from_slice(&state.bytes[id]);
        }
        state.content = Some(Bytes::from(merged));

        Ok(())
    }

    /// Accepts anything and discards nothing.
    async fn abort_block(&self, _: Vec<Uuid>) -> Result<()> {
        Ok(())
    }
}

/// Writes 1000 random chunks through a `BlockWriter` with concurrency 8 on
/// top of a backend whose `write_block` fails at random, retrying every
/// failed `write`/`close`, and checks that the completed content and the
/// accumulated length match what was written.
#[tokio::test]
async fn test_block_writer_with_concurrent_errors() {
    let mut rng = thread_rng();

    let mut w = BlockWriter::new(TestWrite::new(), 8);
    let mut total_size = 0u64;
    let mut expected_content = Vec::new();

    for _ in 0..1000 {
        let size = rng.gen_range(1..1024);
        total_size += size as u64;

        let mut bs = vec![0; size];
        rng.fill_bytes(&mut bs);

        expected_content.extend_from_slice(&bs);

        // Keep re-submitting the same chunk until the writer accepts it.
        while w.write(&bs.as_slice()).await.is_err() {}
    }

    // Closing may surface errors of in-flight blocks as well; retry until done.
    while w.close().await.is_err() {}

    let inner = w.w.lock().unwrap();

    assert_eq!(total_size, inner.length, "length must be the same");
    assert!(inner.content.is_some());
    assert_eq!(
        expected_content,
        inner.content.clone().unwrap(),
        "content must be the same"
    );
}
}
26 changes: 12 additions & 14 deletions core/src/raw/oio/write/range_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,20 +291,18 @@ impl<W: RangeWrite> oio::Write for RangeWriter<W> {
}
}
}
None => match self.buffer.clone() {
Some(bs) => {
self.state = State::Complete(Box::pin(async move {
let size = bs.len();
w.write_once(size as u64, AsyncBody::ChunkedBytes(bs)).await
}));
}
None => {
// Call write_once if there is no data in buffer and no location.
self.state = State::Complete(Box::pin(async move {
w.write_once(0, AsyncBody::Empty).await
}));
}
},
None => {
let w = self.w.clone();
let (size, body) = match self.buffer.clone() {
Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)),
None => (0, AsyncBody::Empty),
};
// Call write_once if there is no data in buffer and no location.

self.state = State::Complete(Box::pin(async move {
w.write_once(size as u64, body).await
}));
}
}
}
State::Init(_) => {
Expand Down
Loading

0 comments on commit 7144ab1

Please sign in to comment.