Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into warn-user-on-fork-take-2
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Oct 16, 2024
2 parents cedb140 + 3255066 commit a65272f
Show file tree
Hide file tree
Showing 69 changed files with 1,216 additions and 611 deletions.
22 changes: 11 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions crates/polars-arrow/src/io/ipc/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use polars_utils::aliases::{InitHashMaps, PlHashMap};
use super::super::{ARROW_MAGIC_V1, ARROW_MAGIC_V2, CONTINUATION_MARKER};
use super::common::*;
use super::schema::fb_to_schema;
use super::{Dictionaries, OutOfSpecKind};
use super::{Dictionaries, OutOfSpecKind, SendableIterator};
use crate::array::Array;
use crate::datatypes::ArrowSchemaRef;
use crate::io::ipc::IpcSchema;
Expand Down Expand Up @@ -208,7 +208,7 @@ pub(super) fn deserialize_schema_ref_from_footer(
/// Get the IPC blocks from the footer containing record batches
pub(super) fn iter_recordbatch_blocks_from_footer(
footer: arrow_format::ipc::FooterRef,
) -> PolarsResult<impl Iterator<Item = PolarsResult<arrow_format::ipc::Block>> + '_> {
) -> PolarsResult<impl SendableIterator<Item = PolarsResult<arrow_format::ipc::Block>> + '_> {
let blocks = footer
.record_batches()
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))?
Expand All @@ -223,7 +223,8 @@ pub(super) fn iter_recordbatch_blocks_from_footer(

pub(super) fn iter_dictionary_blocks_from_footer(
footer: arrow_format::ipc::FooterRef,
) -> PolarsResult<Option<impl Iterator<Item = PolarsResult<arrow_format::ipc::Block>> + '_>> {
) -> PolarsResult<Option<impl SendableIterator<Item = PolarsResult<arrow_format::ipc::Block>> + '_>>
{
let dictionaries = footer
.dictionaries()
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferDictionaries(err)))?;
Expand Down
169 changes: 108 additions & 61 deletions crates/polars-arrow/src/io/ipc/read/flight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ use futures::{Stream, StreamExt};
use polars_error::{polars_bail, polars_err, PolarsResult};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};

use crate::datatypes::ArrowSchema;
use crate::io::ipc::read::common::read_record_batch;
use crate::io::ipc::read::file::{
decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer,
iter_recordbatch_blocks_from_footer,
};
use crate::io::ipc::read::schema::deserialize_stream_metadata;
use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, StreamMetadata};
use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata};
use crate::io::ipc::write::common::EncodedData;
use crate::mmap::{mmap_dictionary_from_batch, mmap_record};
use crate::record_batch::RecordBatch;
Expand Down Expand Up @@ -169,8 +171,8 @@ pub async fn into_flight_stream<R: AsyncRead + AsyncSeek + Unpin + Send>(
pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> {
footer: Option<*const FooterRef<'static>>,
footer_data: Vec<u8>,
dict_blocks: Option<Box<dyn Iterator<Item = PolarsResult<Block>>>>,
data_blocks: Option<Box<dyn Iterator<Item = PolarsResult<Block>>>>,
dict_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,
data_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,
reader: &'a mut R,
}

Expand All @@ -184,21 +186,23 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> Drop for FlightStreamProducer<
}
}

unsafe impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> Send for FlightStreamProducer<'a, R> {}

impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {
pub async fn new(reader: &'a mut R) -> PolarsResult<Self> {
pub async fn new(reader: &'a mut R) -> PolarsResult<Pin<Box<Self>>> {
let (_end, len) = read_footer_len(reader).await?;
let footer_data = read_footer(reader, len).await?;

Ok(Self {
Ok(Box::pin(Self {
footer: None,
footer_data,
dict_blocks: None,
data_blocks: None,
reader,
})
}))
}

pub fn init(self: &mut Pin<&mut Self>) -> PolarsResult<()> {
pub fn init(self: &mut Pin<Box<Self>>) -> PolarsResult<()> {
let footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data)
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;

Expand All @@ -210,16 +214,15 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {
self.footer = Some(ptr);
let footer = &unsafe { **self.footer.as_ref().unwrap() };

self.data_blocks =
Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?)
as Box<dyn Iterator<Item = _>>);
self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?)
as Box<dyn SendableIterator<Item = _>>);
self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)?
.map(|i| Box::new(i) as Box<dyn Iterator<Item = _>>);
.map(|i| Box::new(i) as Box<dyn SendableIterator<Item = _>>);

Ok(())
}

pub fn get_schema(self: &Pin<&mut Self>) -> PolarsResult<EncodedData> {
pub fn get_schema(self: &Pin<Box<Self>>) -> PolarsResult<EncodedData> {
let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") };

let schema_ref = deserialize_schema_ref_from_footer(*footer)?;
Expand All @@ -229,7 +232,7 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {
}

pub async fn next_dict(
self: &mut Pin<&mut Self>,
self: &mut Pin<Box<Self>>,
encoded_data: &mut EncodedData,
) -> PolarsResult<Option<()>> {
assert!(self.data_blocks.is_some(), "init must be called first");
Expand All @@ -250,7 +253,7 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {
}

pub async fn next_data(
self: &mut Pin<&mut Self>,
self: &mut Pin<Box<Self>>,
encoded_data: &mut EncodedData,
) -> PolarsResult<Option<()>> {
encoded_data.ipc_message.clear();
Expand All @@ -270,62 +273,78 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {
}
}

pub struct FlightstreamConsumer<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> {
pub struct FlightConsumer {
dictionaries: Dictionaries,
md: StreamMetadata,
stream: S,
scratch: Vec<u8>,
}

impl<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> FlightstreamConsumer<S> {
pub async fn new(mut stream: S) -> PolarsResult<Self> {
let Some(first) = stream.next().await else {
polars_bail!(ComputeError: "expected the schema")
};
let first = first?;

impl FlightConsumer {
pub fn new(first: EncodedData) -> PolarsResult<Self> {
let md = deserialize_stream_metadata(&first.ipc_message)?;
Ok(FlightstreamConsumer {
Ok(Self {
dictionaries: Default::default(),
md,
stream,
scratch: vec![],
})
}

pub async fn next_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {
while let Some(msg) = self.stream.next().await {
let msg = msg?;
pub fn schema(&self) -> &ArrowSchema {
&self.md.schema
}

// Parse the header
let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message)
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?;

let header = message
.header()
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))?
.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?;

// Needed to memory map.
let arrow_data = Arc::new(msg.arrow_data);

// Either append to the dictionaries and return None or return Some(ArrowChunk)
match header {
MessageHeaderRef::Schema(_) => {
polars_bail!(ComputeError: "Unexpected schema message while parsing Stream");
},
// Add to dictionary state and continue iteration
MessageHeaderRef::DictionaryBatch(batch) => unsafe {
mmap_dictionary_from_batch(
&self.md.schema,
&self.md.ipc_schema.fields,
&arrow_data,
pub fn consume(&mut self, msg: EncodedData) -> PolarsResult<Option<RecordBatch>> {
// Parse the header
let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message)
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?;

let header = message
.header()
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))?
.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?;

// Either append to the dictionaries and return None or return Some(ArrowChunk)
match header {
MessageHeaderRef::Schema(_) => {
polars_bail!(ComputeError: "Unexpected schema message while parsing Stream");
},
// Add to dictionary state and continue iteration
MessageHeaderRef::DictionaryBatch(batch) => unsafe {
// Needed to memory map.
let arrow_data = Arc::new(msg.arrow_data);
mmap_dictionary_from_batch(
&self.md.schema,
&self.md.ipc_schema.fields,
&arrow_data,
batch,
&mut self.dictionaries,
0,
)
.map(|_| None)
},
// Return Batch
MessageHeaderRef::RecordBatch(batch) => {
if batch.compression()?.is_some() {
let data_size = msg.arrow_data.len() as u64;
let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice());
read_record_batch(
batch,
&mut self.dictionaries,
&self.md.schema,
&self.md.ipc_schema,
None,
None,
&self.dictionaries,
self.md.version,
&mut reader,
0,
)?
},
// Return Batch
MessageHeaderRef::RecordBatch(batch) => {
return unsafe {
data_size,
&mut self.scratch,
)
.map(Some)
} else {
// Needed to memory map.
let arrow_data = Arc::new(msg.arrow_data);
unsafe {
mmap_record(
&self.md.schema,
&self.md.ipc_schema.fields,
Expand All @@ -336,8 +355,37 @@ impl<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> FlightstreamConsumer<S
)
.map(Some)
}
},
_ => unimplemented!(),
}
},
_ => unimplemented!(),
}
}
}

pub struct FlightstreamConsumer<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> {
inner: FlightConsumer,
stream: S,
}

impl<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> FlightstreamConsumer<S> {
pub async fn new(mut stream: S) -> PolarsResult<Self> {
let Some(first) = stream.next().await else {
polars_bail!(ComputeError: "expected the schema")
};
let first = first?;

Ok(FlightstreamConsumer {
inner: FlightConsumer::new(first)?,
stream,
})
}

pub async fn next_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {
while let Some(msg) = self.stream.next().await {
let msg = msg?;
let option_recordbatch = self.inner.consume(msg)?;
if option_recordbatch.is_some() {
return Ok(option_recordbatch);
}
}
Ok(None)
Expand All @@ -355,7 +403,7 @@ mod test {

fn get_file_path() -> PathBuf {
let polars_arrow = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
std::path::Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc")
Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc")
}

fn read_file(path: &Path) -> RecordBatch {
Expand Down Expand Up @@ -384,7 +432,6 @@ mod test {
let path = &get_file_path();
let mut file = File::open(path).await.unwrap();
let mut p = FlightStreamProducer::new(&mut file).await.unwrap();
let mut p = std::pin::pin!(p);
p.init().unwrap();

let mut batches = vec![];
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-arrow/src/io/ipc/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ pub(crate) type Version = arrow_format::ipc::MetadataVersion;

#[cfg(feature = "io_flight")]
pub use flight::*;

pub trait SendableIterator: Send + Iterator {}

impl<T: Iterator + Send> SendableIterator for T {}
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/io/ipc/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod serialize;
mod stream;
pub(crate) mod writer;

pub use common::{Compression, Record, WriteOptions};
pub use common::{Compression, EncodedData, Record, WriteOptions};
pub use schema::schema_to_bytes;
pub use serialize::write;
use serialize::write_dictionary;
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/legacy/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,5 @@ pub struct RollingVarParams {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingQuantileParams {
pub prob: f64,
pub interpol: QuantileInterpolOptions,
pub method: QuantileMethod,
}
Loading

0 comments on commit a65272f

Please sign in to comment.