diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index af3c8fba30ff..97d9899a9fb0 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::{pin::Pin, task::Poll}; - use crate::{ decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, @@ -28,16 +26,15 @@ use crate::{ use arrow_schema::Schema; use bytes::Bytes; use futures::{ - channel::oneshot::{Receiver, Sender}, future::ready, - ready, stream::{self, BoxStream}, - FutureExt, Stream, StreamExt, TryStreamExt, + Stream, StreamExt, TryStreamExt, }; use prost::Message; use tonic::{metadata::MetadataMap, transport::Channel}; use crate::error::{FlightError, Result}; +use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream}; /// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client. /// @@ -674,103 +671,3 @@ impl FlightClient { request } } - -/// Wrapper around fallible stream such that when -/// it encounters an error it uses the oneshot sender to -/// notify the error and stop any further streaming. See `do_put` or -/// `do_exchange` for it's uses. -pub(crate) struct FallibleRequestStream { - /// sender to notify error - sender: Option>, - /// fallible stream - fallible_stream: Pin> + Send + 'static>>, -} - -impl FallibleRequestStream { - pub(crate) fn new( - sender: Sender, - fallible_stream: Pin> + Send + 'static>>, - ) -> Self { - Self { - sender: Some(sender), - fallible_stream, - } - } -} - -impl Stream for FallibleRequestStream { - type Item = T; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let pinned = self.get_mut(); - let mut request_streams = pinned.fallible_stream.as_mut(); - match ready!(request_streams.poll_next_unpin(cx)) { - Some(Ok(data)) => Poll::Ready(Some(data)), - Some(Err(e)) => { - // in theory this should only ever be called once - // as this stream should not be polled again after returning - // None, however we still check for None to be safe - if let Some(sender) = pinned.sender.take() { - // an error means the other end of the channel is not around - // to receive the error, so ignore it - let _ = sender.send(e); - } - Poll::Ready(None) - } - None => Poll::Ready(None), - } - } -} - -/// Wrapper for a tonic response stream that can produce a tonic -/// error. This is tied to a oneshot receiver which can be notified -/// of other errors. When it receives an error through receiver -/// end, it prioritises that error to be sent back. See `do_put` or -/// `do_exchange` for it's uses -struct FallibleTonicResponseStream { - /// Receiver for FlightError - receiver: Receiver, - /// Tonic response stream - response_stream: - Pin> + Send + 'static>>, -} - -impl FallibleTonicResponseStream { - fn new( - receiver: Receiver, - response_stream: Pin< - Box> + Send + 'static>, - >, - ) -> Self { - Self { - receiver, - response_stream, - } - } -} - -impl Stream for FallibleTonicResponseStream { - type Item = Result; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let pinned = self.get_mut(); - let receiver = &mut pinned.receiver; - // Prioritise sending the error that's been notified over - // polling the response_stream - if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { - return Poll::Ready(Some(Err(err))); - }; - - match ready!(pinned.response_stream.poll_next_unpin(cx)) { - Some(Ok(res)) => Poll::Ready(Some(Ok(res))), - Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))), - None => Poll::Ready(None), - } - } -} diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 8fa61b1d5719..1180264e5ddd 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -120,6 +120,7 @@ pub mod utils; #[cfg(feature = "flight-sql-experimental")] pub mod sql; +mod streams; use flight_descriptor::DescriptorType; diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 345254a63a3b..ef52aa27ef50 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -24,7 +24,6 @@ use std::collections::HashMap; use std::str::FromStr; use tonic::metadata::AsciiMetadataKey; -use crate::client::FallibleRequestStream; use crate::decode::FlightRecordBatchStream; use crate::encode::FlightDataEncoderBuilder; use crate::error::FlightError; @@ -43,6 +42,7 @@ use crate::sql::{ CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo, }; +use crate::streams::FallibleRequestStream; use crate::trailers::extract_lazy_trailers; use crate::{ Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, diff --git a/arrow-flight/src/streams.rs b/arrow-flight/src/streams.rs new file mode 100644 index 000000000000..e532a80e1ebb --- /dev/null +++ b/arrow-flight/src/streams.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`FallibleRequestStream`] and [`FallibleTonicResponseStream`] adapters + +use crate::error::FlightError; +use futures::{ + channel::oneshot::{Receiver, Sender}, + FutureExt, Stream, StreamExt, +}; +use std::pin::Pin; +use std::task::{ready, Poll}; + +/// Wrapper around a fallible stream (one that returns errors) that makes it infallible. +/// +/// Any errors encountered in the stream are ignored are sent to the provided +/// oneshot sender. +/// +/// This can be used to accept a stream of `Result<_>` from a client API and send +/// them to the remote server that wants only the successful results. +pub(crate) struct FallibleRequestStream { + /// sender to notify error + sender: Option>, + /// fallible stream + fallible_stream: Pin> + Send + 'static>>, +} + +impl FallibleRequestStream { + pub(crate) fn new( + sender: Sender, + fallible_stream: Pin> + Send + 'static>>, + ) -> Self { + Self { + sender: Some(sender), + fallible_stream, + } + } +} + +impl Stream for FallibleRequestStream { + type Item = T; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let pinned = self.get_mut(); + let mut request_streams = pinned.fallible_stream.as_mut(); + match ready!(request_streams.poll_next_unpin(cx)) { + Some(Ok(data)) => Poll::Ready(Some(data)), + Some(Err(e)) => { + // in theory this should only ever be called once + // as this stream should not be polled again after returning + // None, however we still check for None to be safe + if let Some(sender) = pinned.sender.take() { + // an error means the other end of the channel is not around + // to receive the error, so ignore it + let _ = sender.send(e); + } + Poll::Ready(None) + } + None => Poll::Ready(None), + } + } +} + +/// Wrapper for a tonic response stream that maps errors to `FlightError` and +/// returns errors from a oneshot channel into the stream. +/// +/// The user of this stream can inject an error into the response stream using +/// the one shot receiver. This is used to propagate errors in +/// [`FlightClient::do_put`] and [`FlightClient::do_exchange`] from the client +/// provided input stream to the response stream. +/// +/// # Error Priority +/// Error from the receiver are prioritised over the response stream. +/// +/// [`FlightClient::do_put`]: crate::FlightClient::do_put +/// [`FlightClient::do_exchange`]: crate::FlightClient::do_exchange +pub(crate) struct FallibleTonicResponseStream { + /// Receiver for FlightError + receiver: Receiver, + /// Tonic response stream + response_stream: Pin> + Send + 'static>>, +} + +impl FallibleTonicResponseStream { + pub(crate) fn new( + receiver: Receiver, + response_stream: Pin> + Send + 'static>>, + ) -> Self { + Self { + receiver, + response_stream, + } + } +} + +impl Stream for FallibleTonicResponseStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let pinned = self.get_mut(); + let receiver = &mut pinned.receiver; + // Prioritise sending the error that's been notified over + // polling the response_stream + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + }; + + match ready!(pinned.response_stream.poll_next_unpin(cx)) { + Some(Ok(res)) => Poll::Ready(Some(Ok(res))), + Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))), + None => Poll::Ready(None), + } + } +}