diff --git a/core/src/raw/oio/write/block_write.rs b/core/src/raw/oio/write/block_write.rs index fb021335bb67..aecc7c66c610 100644 --- a/core/src/raw/oio/write/block_write.rs +++ b/core/src/raw/oio/write/block_write.rs @@ -25,6 +25,7 @@ use async_trait::async_trait; use futures::Future; use futures::FutureExt; use futures::StreamExt; +use uuid::Uuid; use crate::raw::*; use crate::*; @@ -77,17 +78,22 @@ pub trait BlockWrite: Send + Sync + Unpin + 'static { /// order. /// /// - block_id is the id of the block. - async fn write_block(&self, size: u64, block_id: String, body: AsyncBody) -> Result<()>; + async fn write_block(&self, block_id: Uuid, size: u64, body: AsyncBody) -> Result<()>; /// complete_block will complete the block upload to build the final /// file. - async fn complete_block(&self, block_ids: Vec) -> Result<()>; + async fn complete_block(&self, block_ids: Vec) -> Result<()>; /// abort_block will cancel the block upload and purge all data. - async fn abort_block(&self, block_ids: Vec) -> Result<()>; + async fn abort_block(&self, block_ids: Vec) -> Result<()>; } -struct WriteBlockFuture(BoxedFuture>); +/// WriteBlockResult is the result returned by [`WriteBlockFuture`]. +/// +/// The error part will carries input `(block_id, bytes, err)` so caller can retry them. +type WriteBlockResult = std::result::Result; + +struct WriteBlockFuture(BoxedFuture); /// # Safety /// @@ -100,19 +106,38 @@ unsafe impl Send for WriteBlockFuture {} unsafe impl Sync for WriteBlockFuture {} impl Future for WriteBlockFuture { - type Output = Result<()>; + type Output = WriteBlockResult; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.get_mut().0.poll_unpin(cx) } } +impl WriteBlockFuture { + pub fn new(w: Arc, block_id: Uuid, bytes: oio::ChunkedBytes) -> Self { + let fut = async move { + w.write_block( + block_id, + bytes.len() as u64, + AsyncBody::ChunkedBytes(bytes.clone()), + ) + .await + // Return bytes while we got an error to allow retry. + .map_err(|err| (block_id, bytes, err)) + // Return the successful block id. + .map(|_| block_id) + }; + + WriteBlockFuture(Box::pin(fut)) + } +} + /// BlockWriter will implements [`Write`] based on block /// uploads. pub struct BlockWriter { state: State, w: Arc, - block_ids: Vec, + block_ids: Vec, cache: Option, futures: ConcurrentFutures, } @@ -168,20 +193,30 @@ where let size = self.fill_cache(bs); return Poll::Ready(Ok(size)); } + let cache = self.cache.take().expect("pending write must exist"); - let block_id = uuid::Uuid::new_v4().to_string(); - self.block_ids.push(block_id.clone()); - let w = self.w.clone(); - let size = cache.len(); - self.futures - .push_back(WriteBlockFuture(Box::pin(async move { - w.write_block(size as u64, block_id, AsyncBody::ChunkedBytes(cache)) - .await - }))); + self.futures.push_back(WriteBlockFuture::new( + self.w.clone(), + Uuid::new_v4(), + cache, + )); + let size = self.fill_cache(bs); return Poll::Ready(Ok(size)); } else if let Some(res) = ready!(self.futures.poll_next_unpin(cx)) { - res?; + match res { + Ok(block_id) => { + self.block_ids.push(block_id); + } + Err((block_id, bytes, err)) => { + self.futures.push_front(WriteBlockFuture::new( + self.w.clone(), + block_id, + bytes, + )); + return Poll::Ready(Err(err)); + } + } } } State::Close(_) => { @@ -198,53 +233,55 @@ where loop { match &mut self.state { State::Idle => { - let w = self.w.clone(); - let block_ids = self.block_ids.clone(); + // No write block has been sent. + if self.futures.is_empty() && self.block_ids.is_empty() { + let w = self.w.clone(); + let (size, body) = match self.cache.clone() { + Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)), + None => (0, AsyncBody::Empty), + }; + // Call write_once if there is no data in buffer and no location. + self.state = + State::Close(Box::pin( + async move { w.write_once(size as u64, body).await }, + )); + continue; + } - if self.block_ids.is_empty() { - match &self.cache { - Some(cache) => { - let w = self.w.clone(); - let bs = cache.clone(); - self.state = State::Close(Box::pin(async move { - let size = bs.len(); - w.write_once(size as u64, AsyncBody::ChunkedBytes(bs)).await - })); - } - None => { - let w = self.w.clone(); - // Call write_once if there is no data in cache. - self.state = State::Close(Box::pin(async move { - w.write_once(0, AsyncBody::Empty).await - })); + if self.futures.has_remaining() { + if let Some(cache) = self.cache.take() { + self.futures.push_back(WriteBlockFuture::new( + self.w.clone(), + Uuid::new_v4(), + cache, + )); + } + } + + if !self.futures.is_empty() { + while let Some(result) = ready!(self.futures.poll_next_unpin(cx)) { + match result { + Ok(block_id) => { + self.block_ids.push(block_id); + } + Err((block_id, bytes, err)) => { + self.futures.push_front(WriteBlockFuture::new( + self.w.clone(), + block_id, + bytes, + )); + return Poll::Ready(Err(err)); + } } } - } else if self.futures.is_empty() && self.cache.is_none() { + } else { + let w = self.w.clone(); + let block_ids = self.block_ids.clone(); self.state = State::Close(Box::pin( async move { w.complete_block(block_ids).await }, )); - } else { - if self.futures.has_remaining() { - if let Some(cache) = self.cache.take() { - let block_id = uuid::Uuid::new_v4().to_string(); - self.block_ids.push(block_id.clone()); - let size = cache.len(); - let w = self.w.clone(); - self.futures - .push_back(WriteBlockFuture(Box::pin(async move { - w.write_block( - size as u64, - block_id, - AsyncBody::ChunkedBytes(cache), - ) - .await - }))); - } - } - while let Some(res) = ready!(self.futures.poll_next_unpin(cx)) { - res?; - } + continue; } } State::Close(fut) => { @@ -270,6 +307,7 @@ where let w = self.w.clone(); let block_ids = self.block_ids.clone(); self.futures.clear(); + self.cache = None; self.state = State::Abort(Box::pin(async move { w.abort_block(block_ids).await })); } @@ -285,3 +323,118 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::raw::oio::{StreamExt, WriteBuf, WriteExt}; + use bytes::Bytes; + use pretty_assertions::assert_eq; + use rand::{thread_rng, Rng, RngCore}; + use std::collections::HashMap; + use std::sync::Mutex; + + struct TestWrite { + length: u64, + bytes: HashMap, + content: Option, + } + + impl TestWrite { + pub fn new() -> Arc> { + let v = Self { + length: 0, + bytes: HashMap::new(), + content: None, + }; + + Arc::new(Mutex::new(v)) + } + } + + #[cfg_attr(not(target_arch = "wasm32"), async_trait)] + #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] + impl BlockWrite for Arc> { + async fn write_once(&self, _: u64, _: AsyncBody) -> Result<()> { + Ok(()) + } + + async fn write_block(&self, block_id: Uuid, size: u64, body: AsyncBody) -> Result<()> { + // We will have 50% percent rate for write part to fail. + if thread_rng().gen_bool(5.0 / 10.0) { + return Err(Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!")); + } + + let bs = match body { + AsyncBody::Empty => Bytes::new(), + AsyncBody::Bytes(bs) => bs, + AsyncBody::ChunkedBytes(cb) => cb.bytes(cb.remaining()), + AsyncBody::Stream(s) => s.collect().await.unwrap(), + }; + + let mut this = self.lock().unwrap(); + this.length += size; + this.bytes.insert(block_id, bs); + + Ok(()) + } + + async fn complete_block(&self, block_ids: Vec) -> Result<()> { + let mut this = self.lock().unwrap(); + let mut bs = Vec::new(); + for id in block_ids { + bs.extend_from_slice(&this.bytes[&id]); + } + this.content = Some(bs.into()); + + Ok(()) + } + + async fn abort_block(&self, _: Vec) -> Result<()> { + Ok(()) + } + } + + #[tokio::test] + async fn test_block_writer_with_concurrent_errors() { + let mut rng = thread_rng(); + + let mut w = BlockWriter::new(TestWrite::new(), 8); + let mut total_size = 0u64; + let mut expected_content = Vec::new(); + + for _ in 0..1000 { + let size = rng.gen_range(1..1024); + total_size += size as u64; + + let mut bs = vec![0; size]; + rng.fill_bytes(&mut bs); + + expected_content.extend_from_slice(&bs); + + loop { + match w.write(&bs.as_slice()).await { + Ok(_) => break, + Err(_) => continue, + } + } + } + + loop { + match w.close().await { + Ok(_) => break, + Err(_) => continue, + } + } + + let inner = w.w.lock().unwrap(); + + assert_eq!(total_size, inner.length, "length must be the same"); + assert!(inner.content.is_some()); + assert_eq!( + expected_content, + inner.content.clone().unwrap(), + "content must be the same" + ); + } +} diff --git a/core/src/raw/oio/write/range_write.rs b/core/src/raw/oio/write/range_write.rs index d876962bb67f..e4210b5ab0d5 100644 --- a/core/src/raw/oio/write/range_write.rs +++ b/core/src/raw/oio/write/range_write.rs @@ -291,20 +291,18 @@ impl oio::Write for RangeWriter { } } } - None => match self.buffer.clone() { - Some(bs) => { - self.state = State::Complete(Box::pin(async move { - let size = bs.len(); - w.write_once(size as u64, AsyncBody::ChunkedBytes(bs)).await - })); - } - None => { - // Call write_once if there is no data in buffer and no location. - self.state = State::Complete(Box::pin(async move { - w.write_once(0, AsyncBody::Empty).await - })); - } - }, + None => { + let w = self.w.clone(); + let (size, body) = match self.buffer.clone() { + Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)), + None => (0, AsyncBody::Empty), + }; + // Call write_once if there is no data in buffer and no location. + + self.state = State::Complete(Box::pin(async move { + w.write_once(size as u64, body).await + })); + } } } State::Init(_) => { diff --git a/core/src/services/webhdfs/writer.rs b/core/src/services/webhdfs/writer.rs index f405708b742b..8cd935328fa3 100644 --- a/core/src/services/webhdfs/writer.rs +++ b/core/src/services/webhdfs/writer.rs @@ -17,6 +17,7 @@ use async_trait::async_trait; use http::StatusCode; +use uuid::Uuid; use super::backend::WebhdfsBackend; use super::error::parse_error; @@ -59,7 +60,7 @@ impl oio::BlockWrite for WebhdfsWriter { } } - async fn write_block(&self, size: u64, block_id: String, body: AsyncBody) -> Result<()> { + async fn write_block(&self, block_id: Uuid, size: u64, body: AsyncBody) -> Result<()> { let Some(ref atomic_write_dir) = self.backend.atomic_write_dir else { return Err(Error::new( ErrorKind::Unsupported, @@ -88,7 +89,7 @@ impl oio::BlockWrite for WebhdfsWriter { } } - async fn complete_block(&self, block_ids: Vec) -> Result<()> { + async fn complete_block(&self, block_ids: Vec) -> Result<()> { let Some(ref atomic_write_dir) = self.backend.atomic_write_dir else { return Err(Error::new( ErrorKind::Unsupported, @@ -138,9 +139,9 @@ impl oio::BlockWrite for WebhdfsWriter { } } - async fn abort_block(&self, block_ids: Vec) -> Result<()> { + async fn abort_block(&self, block_ids: Vec) -> Result<()> { for block_id in block_ids { - let resp = self.backend.webhdfs_delete(&block_id).await?; + let resp = self.backend.webhdfs_delete(&block_id.to_string()).await?; match resp.status() { StatusCode::OK => { resp.into_body().consume().await?;