diff --git a/Cargo.lock b/Cargo.lock index c3109a177997..081b40c1b320 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2169,7 +2169,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -3500,9 +3500,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" +checksum = "00e89ce2565d6044ca31a3eb79a334c3a79a841120a98f64eea9f579564cb691" dependencies = [ "cfg-if", "chrono", @@ -3520,9 +3520,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" +checksum = "d8afbaf3abd7325e08f35ffb8deb5892046fcb2608b703db6a583a5ba4cea01e" dependencies = [ "once_cell", "target-lexicon", @@ -3530,9 +3530,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" +checksum = "ec15a5ba277339d04763f4c23d85987a5b08cbb494860be141e6a10a8eb88022" dependencies = [ "libc", "pyo3-build-config", @@ -3540,9 +3540,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" +checksum = "15e0f01b5364bcfbb686a52fc4181d412b708a68ed20c330db9fc8d2c2bf5a43" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -3552,9 +3552,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.3" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" +checksum = "a09b550200e1e5ed9176976d0060cbc2ea82dc8515da07885e7b8153a85caacb" dependencies = [ "heck 0.5.0", "proc-macro2", diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs index e289c08742f5..a83e1b758d80 100644 --- a/crates/polars-arrow/src/io/ipc/read/file.rs +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -9,7 +9,7 @@ use polars_utils::aliases::{InitHashMaps, PlHashMap}; use super::super::{ARROW_MAGIC_V1, ARROW_MAGIC_V2, CONTINUATION_MARKER}; use super::common::*; use super::schema::fb_to_schema; -use super::{Dictionaries, OutOfSpecKind}; +use super::{Dictionaries, OutOfSpecKind, SendableIterator}; use crate::array::Array; use crate::datatypes::ArrowSchemaRef; use crate::io::ipc::IpcSchema; @@ -208,7 +208,7 @@ pub(super) fn deserialize_schema_ref_from_footer( /// Get the IPC blocks from the footer containing record batches pub(super) fn iter_recordbatch_blocks_from_footer( footer: arrow_format::ipc::FooterRef, -) -> PolarsResult> + '_> { +) -> PolarsResult> + '_> { let blocks = footer .record_batches() .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? @@ -223,7 +223,8 @@ pub(super) fn iter_recordbatch_blocks_from_footer( pub(super) fn iter_dictionary_blocks_from_footer( footer: arrow_format::ipc::FooterRef, -) -> PolarsResult> + '_>> { +) -> PolarsResult> + '_>> +{ let dictionaries = footer .dictionaries() .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferDictionaries(err)))?; diff --git a/crates/polars-arrow/src/io/ipc/read/flight.rs b/crates/polars-arrow/src/io/ipc/read/flight.rs index 4bb5fd023051..aecf816c2a9a 100644 --- a/crates/polars-arrow/src/io/ipc/read/flight.rs +++ b/crates/polars-arrow/src/io/ipc/read/flight.rs @@ -8,12 +8,14 @@ use futures::{Stream, StreamExt}; use polars_error::{polars_bail, polars_err, PolarsResult}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; +use crate::datatypes::ArrowSchema; +use crate::io::ipc::read::common::read_record_batch; use crate::io::ipc::read::file::{ decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer, iter_recordbatch_blocks_from_footer, }; use crate::io::ipc::read::schema::deserialize_stream_metadata; -use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, StreamMetadata}; +use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata}; use crate::io::ipc::write::common::EncodedData; use crate::mmap::{mmap_dictionary_from_batch, mmap_record}; use crate::record_batch::RecordBatch; @@ -169,8 +171,8 @@ pub async fn into_flight_stream( pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> { footer: Option<*const FooterRef<'static>>, footer_data: Vec, - dict_blocks: Option>>>, - data_blocks: Option>>>, + dict_blocks: Option>>>, + data_blocks: Option>>>, reader: &'a mut R, } @@ -184,21 +186,23 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> Drop for FlightStreamProducer< } } +unsafe impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> Send for FlightStreamProducer<'a, R> {} + impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> { - pub async fn new(reader: &'a mut R) -> PolarsResult { + pub async fn new(reader: &'a mut R) -> PolarsResult>> { let (_end, len) = read_footer_len(reader).await?; let footer_data = read_footer(reader, len).await?; - Ok(Self { + Ok(Box::pin(Self { footer: None, footer_data, dict_blocks: None, data_blocks: None, reader, - }) + })) } - pub fn init(self: &mut Pin<&mut Self>) -> PolarsResult<()> { + pub fn init(self: &mut Pin>) -> PolarsResult<()> { let footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data) .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; @@ -210,16 +214,15 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> { self.footer = Some(ptr); let footer = &unsafe { **self.footer.as_ref().unwrap() }; - self.data_blocks = - Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?) - as Box>); + self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?) + as Box>); self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)? - .map(|i| Box::new(i) as Box>); + .map(|i| Box::new(i) as Box>); Ok(()) } - pub fn get_schema(self: &Pin<&mut Self>) -> PolarsResult { + pub fn get_schema(self: &Pin>) -> PolarsResult { let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") }; let schema_ref = deserialize_schema_ref_from_footer(*footer)?; @@ -229,7 +232,7 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> { } pub async fn next_dict( - self: &mut Pin<&mut Self>, + self: &mut Pin>, encoded_data: &mut EncodedData, ) -> PolarsResult> { assert!(self.data_blocks.is_some(), "init must be called first"); @@ -250,7 +253,7 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> { } pub async fn next_data( - self: &mut Pin<&mut Self>, + self: &mut Pin>, encoded_data: &mut EncodedData, ) -> PolarsResult> { encoded_data.ipc_message.clear(); @@ -270,62 +273,78 @@ impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> { } } -pub struct FlightstreamConsumer> + Unpin> { +pub struct FlightConsumer { dictionaries: Dictionaries, md: StreamMetadata, - stream: S, + scratch: Vec, } -impl> + Unpin> FlightstreamConsumer { - pub async fn new(mut stream: S) -> PolarsResult { - let Some(first) = stream.next().await else { - polars_bail!(ComputeError: "expected the schema") - }; - let first = first?; - +impl FlightConsumer { + pub fn new(first: EncodedData) -> PolarsResult { let md = deserialize_stream_metadata(&first.ipc_message)?; - Ok(FlightstreamConsumer { + Ok(Self { dictionaries: Default::default(), md, - stream, + scratch: vec![], }) } - pub async fn next_batch(&mut self) -> PolarsResult> { - while let Some(msg) = self.stream.next().await { - let msg = msg?; + pub fn schema(&self) -> &ArrowSchema { + &self.md.schema + } - // Parse the header - let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message) - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; - - let header = message - .header() - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? - .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; - - // Needed to memory map. - let arrow_data = Arc::new(msg.arrow_data); - - // Either append to the dictionaries and return None or return Some(ArrowChunk) - match header { - MessageHeaderRef::Schema(_) => { - polars_bail!(ComputeError: "Unexpected schema message while parsing Stream"); - }, - // Add to dictionary state and continue iteration - MessageHeaderRef::DictionaryBatch(batch) => unsafe { - mmap_dictionary_from_batch( - &self.md.schema, - &self.md.ipc_schema.fields, - &arrow_data, + pub fn consume(&mut self, msg: EncodedData) -> PolarsResult> { + // Parse the header + let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?; + + // Either append to the dictionaries and return None or return Some(ArrowChunk) + match header { + MessageHeaderRef::Schema(_) => { + polars_bail!(ComputeError: "Unexpected schema message while parsing Stream"); + }, + // Add to dictionary state and continue iteration + MessageHeaderRef::DictionaryBatch(batch) => unsafe { + // Needed to memory map. + let arrow_data = Arc::new(msg.arrow_data); + mmap_dictionary_from_batch( + &self.md.schema, + &self.md.ipc_schema.fields, + &arrow_data, + batch, + &mut self.dictionaries, + 0, + ) + .map(|_| None) + }, + // Return Batch + MessageHeaderRef::RecordBatch(batch) => { + if batch.compression()?.is_some() { + let data_size = msg.arrow_data.len() as u64; + let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice()); + read_record_batch( batch, - &mut self.dictionaries, + &self.md.schema, + &self.md.ipc_schema, + None, + None, + &self.dictionaries, + self.md.version, + &mut reader, 0, - )? - }, - // Return Batch - MessageHeaderRef::RecordBatch(batch) => { - return unsafe { + data_size, + &mut self.scratch, + ) + .map(Some) + } else { + // Needed to memory map. + let arrow_data = Arc::new(msg.arrow_data); + unsafe { mmap_record( &self.md.schema, &self.md.ipc_schema.fields, @@ -336,8 +355,37 @@ impl> + Unpin> FlightstreamConsumer unimplemented!(), + } + }, + _ => unimplemented!(), + } + } +} + +pub struct FlightstreamConsumer> + Unpin> { + inner: FlightConsumer, + stream: S, +} + +impl> + Unpin> FlightstreamConsumer { + pub async fn new(mut stream: S) -> PolarsResult { + let Some(first) = stream.next().await else { + polars_bail!(ComputeError: "expected the schema") + }; + let first = first?; + + Ok(FlightstreamConsumer { + inner: FlightConsumer::new(first)?, + stream, + }) + } + + pub async fn next_batch(&mut self) -> PolarsResult> { + while let Some(msg) = self.stream.next().await { + let msg = msg?; + let option_recordbatch = self.inner.consume(msg)?; + if option_recordbatch.is_some() { + return Ok(option_recordbatch); } } Ok(None) @@ -355,7 +403,7 @@ mod test { fn get_file_path() -> PathBuf { let polars_arrow = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); - std::path::Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc") + Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc") } fn read_file(path: &Path) -> RecordBatch { @@ -384,7 +432,6 @@ mod test { let path = &get_file_path(); let mut file = File::open(path).await.unwrap(); let mut p = FlightStreamProducer::new(&mut file).await.unwrap(); - let mut p = std::pin::pin!(p); p.init().unwrap(); let mut batches = vec![]; diff --git a/crates/polars-arrow/src/io/ipc/read/mod.rs b/crates/polars-arrow/src/io/ipc/read/mod.rs index 2cbe0c0f1332..88411f9b905f 100644 --- a/crates/polars-arrow/src/io/ipc/read/mod.rs +++ b/crates/polars-arrow/src/io/ipc/read/mod.rs @@ -39,3 +39,7 @@ pub(crate) type Version = arrow_format::ipc::MetadataVersion; #[cfg(feature = "io_flight")] pub use flight::*; + +pub trait SendableIterator: Send + Iterator {} + +impl SendableIterator for T {} diff --git a/crates/polars-arrow/src/io/ipc/write/mod.rs b/crates/polars-arrow/src/io/ipc/write/mod.rs index 99f6fcc3f355..2291448d3012 100644 --- a/crates/polars-arrow/src/io/ipc/write/mod.rs +++ b/crates/polars-arrow/src/io/ipc/write/mod.rs @@ -5,7 +5,7 @@ mod serialize; mod stream; pub(crate) mod writer; -pub use common::{Compression, Record, WriteOptions}; +pub use common::{Compression, EncodedData, Record, WriteOptions}; pub use schema::schema_to_bytes; pub use serialize::write; use serialize::write_dictionary; diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs index 8f8004c570e8..51f3c95d2a56 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs @@ -93,5 +93,5 @@ pub struct RollingVarParams { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingQuantileParams { pub prob: f64, - pub interpol: QuantileInterpolOptions, + pub method: QuantileMethod, } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index 1b9695358dcb..3277318e6807 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -71,15 +71,19 @@ where #[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum QuantileInterpolOptions { +pub enum QuantileMethod { #[default] Nearest, Lower, Higher, Midpoint, Linear, + Equiprobable, } +#[deprecated(note = "use QuantileMethod instead")] +pub type QuantileInterpolOptions = QuantileMethod; + pub(super) fn rolling_apply_weights( values: &[T], window_size: usize, diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs index ab3919b9aaaa..bf0ad01e79c3 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs @@ -2,13 +2,13 @@ use num_traits::ToPrimitive; use polars_error::polars_ensure; use polars_utils::slice::GetSaferUnchecked; -use super::QuantileInterpolOptions::*; +use super::QuantileMethod::*; use super::*; pub struct QuantileWindow<'a, T: NativeType> { sorted: SortedBuf<'a, T>, prob: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl< @@ -34,7 +34,7 @@ impl< Self { sorted: SortedBuf::new(slice, start, end), prob: params.prob, - interpol: params.interpol, + method: params.method, } } @@ -42,7 +42,7 @@ impl< let vals = self.sorted.update(start, end); let length = vals.len(); - let idx = match self.interpol { + let idx = match self.method { Linear => { // Maybe add a fast path for median case? They could branch depending on odd/even. let length_f = length as f64; @@ -92,6 +92,7 @@ impl< let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; std::cmp::min(idx, length - 1) }, + Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize, }; // SAFETY: @@ -134,7 +135,7 @@ where unreachable!("expected Quantile params"); }; let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>( - params.interpol, + params.method, min_periods, window_size, values, @@ -170,7 +171,7 @@ where Ok(rolling_apply_weighted_quantile( values, params.prob, - params.interpol, + params.method, window_size, min_periods, offset_fn, @@ -182,7 +183,7 @@ where } #[inline] -fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, interp: QuantileInterpolOptions) -> T +fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T where T: Debug + NativeType + Mul + Sub + NumCast + ToPrimitive + Zero, { @@ -201,7 +202,7 @@ where (s_old, v_old, vk) = (s, vk, v); s += w; } - match (h == s_old, interp) { + match (h == s_old, method) { (true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter (_, Lower) => v_old, (_, Higher) => vk, @@ -212,6 +213,14 @@ where vk } }, + (_, Equiprobable) => { + let threshold = (wsum * p).ceil() - 1.0; + if s > threshold { + vk + } else { + v_old + } + }, (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(), // This is seemingly the canonical way to do it. (_, Linear) => { @@ -224,7 +233,7 @@ where fn rolling_apply_weighted_quantile( values: &[T], p: f64, - interpolation: QuantileInterpolOptions, + method: QuantileMethod, window_size: usize, min_periods: usize, det_offsets_fn: Fo, @@ -252,7 +261,7 @@ where .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w)); } buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0)); - compute_wq(&buf, p, wsum, interpolation) + compute_wq(&buf, p, wsum, method) }) .collect_trusted::>(); @@ -273,7 +282,7 @@ mod test { let values = &[1.0, 2.0, 3.0, 4.0]; let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.5, - interpol: Linear, + method: Linear, })); let out = rolling_quantile(values, 2, 2, false, None, med_pars.clone()).unwrap(); let out = out.as_any().downcast_ref::>().unwrap(); @@ -305,18 +314,19 @@ mod test { fn test_rolling_quantile_limits() { let values = &[1.0f64, 2.0, 3.0, 4.0]; - let interpol_options = vec![ - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { + for method in methods { let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.0, - interpol, + method, })); let out1 = rolling_min(values, 2, 2, false, None, None).unwrap(); let out1 = out1.as_any().downcast_ref::>().unwrap(); @@ -328,7 +338,7 @@ mod test { let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 1.0, - interpol, + method, })); let out1 = rolling_max(values, 2, 2, false, None, None).unwrap(); let out1 = out1.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs index 259316513fe5..3d5dd664bd34 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs @@ -6,7 +6,7 @@ use crate::array::MutablePrimitiveArray; pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> { sorted: SortedBufNulls<'a, T>, prob: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl< @@ -39,7 +39,7 @@ impl< Self { sorted: SortedBufNulls::new(slice, validity, start, end), prob: params.prob, - interpol: params.interpol, + method: params.method, } } @@ -53,21 +53,22 @@ impl< let values = &values[null_count..]; let length = values.len(); - let mut idx = match self.interpol { - QuantileInterpolOptions::Nearest => ((length as f64) * self.prob) as usize, - QuantileInterpolOptions::Lower - | QuantileInterpolOptions::Midpoint - | QuantileInterpolOptions::Linear => { + let mut idx = match self.method { + QuantileMethod::Nearest => ((length as f64) * self.prob) as usize, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { ((length as f64 - 1.0) * self.prob).floor() as usize }, - QuantileInterpolOptions::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Equiprobable => { + ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize + }, }; idx = std::cmp::min(idx, length - 1); // we can unwrap because we sliced of the nulls - match self.interpol { - QuantileInterpolOptions::Midpoint => { + match self.method { + QuantileMethod::Midpoint => { let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; Some( (values.get_unchecked_release(idx).unwrap() @@ -75,7 +76,7 @@ impl< / T::from::(2.0f64).unwrap(), ) }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { let float_idx = (length as f64 - 1.0) * self.prob; let top_idx = f64::ceil(float_idx) as usize; @@ -136,7 +137,7 @@ where }; let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>( - params.interpol, + params.method, min_periods, window_size, arr.clone(), @@ -171,7 +172,7 @@ mod test { ); let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.5, - interpol: QuantileInterpolOptions::Linear, + method: QuantileMethod::Linear, })); let out = rolling_quantile(arr, 2, 2, false, None, med_pars.clone()); @@ -210,18 +211,19 @@ mod test { Some(Bitmap::from(&[true, false, false, true, true])), ); - let interpol_options = vec![ - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { + for method in methods { let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.0, - interpol, + method, })); let out1 = rolling_min(values, 2, 1, false, None, None); let out1 = out1.as_any().downcast_ref::>().unwrap(); @@ -233,7 +235,7 @@ mod test { let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 1.0, - interpol, + method, })); let out1 = rolling_max(values, 2, 1, false, None, None); let out1 = out1.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs index 40a464e6f5bc..0b5fb4d97e86 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs @@ -11,7 +11,7 @@ use polars_utils::slice::{GetSaferUnchecked, SliceAble}; use polars_utils::sort::arg_sort_ascending; use polars_utils::total_ord::TotalOrd; -use crate::legacy::prelude::QuantileInterpolOptions; +use crate::legacy::prelude::QuantileMethod; use crate::pushable::Pushable; use crate::types::NativeType; @@ -573,7 +573,7 @@ struct QuantileUpdate { inner: M, quantile: f64, min_periods: usize, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl QuantileUpdate @@ -581,12 +581,12 @@ where M: LenGet, ::Item: Default + IsNull + Copy + FinishLinear + Debug, { - fn new(interpol: QuantileInterpolOptions, min_periods: usize, quantile: f64, inner: M) -> Self { + fn new(method: QuantileMethod, min_periods: usize, quantile: f64, inner: M) -> Self { Self { min_periods, quantile, inner, - interpol, + method, } } @@ -602,8 +602,8 @@ where let valid_length_f = valid_length as f64; - use QuantileInterpolOptions::*; - match self.interpol { + use QuantileMethod::*; + match self.method { Linear => { let float_idx_top = (valid_length_f - 1.0) * self.quantile; let idx = float_idx_top.floor() as usize; @@ -623,6 +623,10 @@ where let idx = std::cmp::min(idx, valid_length - 1); self.inner.get(idx + null_count) }, + Equiprobable => { + let idx = ((valid_length_f * self.quantile).ceil() - 1.0).max(0.0) as usize; + self.inner.get(idx + null_count) + }, Midpoint => { let idx = (valid_length_f * self.quantile) as usize; let idx = std::cmp::min(idx, valid_length - 1); @@ -651,7 +655,7 @@ where } pub(super) fn rolling_quantile::Item>>( - interpol: QuantileInterpolOptions, + method: QuantileMethod, min_periods: usize, k: usize, values: A, @@ -709,7 +713,7 @@ where // SAFETY: bounded by capacity unsafe { block_left.undelete(i) }; - let mut mu = QuantileUpdate::new(interpol, min_periods, quantile, &mut block_left); + let mut mu = QuantileUpdate::new(method, min_periods, quantile, &mut block_left); out.push(mu.quantile()); } for i in 1..n_blocks + 1 { @@ -747,7 +751,7 @@ where let mut union = BlockUnion::new(&mut *ptr_left, &mut *ptr_right); union.set_state(j); let q: ::Item = - QuantileUpdate::new(interpol, min_periods, quantile, union).quantile(); + QuantileUpdate::new(method, min_periods, quantile, union).quantile(); out.push(q); } } @@ -1062,22 +1066,22 @@ mod test { 2.0, 8.0, 5.0, 9.0, 1.0, 2.0, 4.0, 2.0, 4.0, 8.1, -1.0, 2.9, 1.2, 23.0, ] .as_ref(); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 3, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 8.0, 5.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 2.9, 1.2, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 5, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 5, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 5.0, 5.0, 4.0, 2.0, 2.0, 4.0, 4.0, 2.9, 2.9, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 7, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 7, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 5.0, 3.5, 4.0, 4.0, 4.0, 4.0, 2.0, 2.9, 2.9, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 4, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 4, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 6.5, 3.5, 3.0, 2.0, 3.0, 4.0, 3.0, 3.45, 2.05, 2.05, ]; @@ -1087,7 +1091,7 @@ mod test { #[test] fn test_median_2() { let values = [10, 10, 15, 13, 9, 5, 3, 13, 19, 15, 19].as_ref(); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 3, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); let expected = [10, 10, 10, 13, 13, 9, 5, 5, 13, 15, 19]; assert_eq!(out, expected); } diff --git a/crates/polars-arrow/src/legacy/prelude.rs b/crates/polars-arrow/src/legacy/prelude.rs index 88b2dd48bbea..6afeb0c6c9be 100644 --- a/crates/polars-arrow/src/legacy/prelude.rs +++ b/crates/polars-arrow/src/legacy/prelude.rs @@ -2,7 +2,7 @@ use crate::array::{BinaryArray, ListArray, Utf8Array}; pub use crate::legacy::array::default_arrays::*; pub use crate::legacy::array::*; pub use crate::legacy::index::*; -pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions; +pub use crate::legacy::kernels::rolling::no_nulls::QuantileMethod; pub use crate::legacy::kernels::rolling::{ RollingFnParams, RollingQuantileParams, RollingVarParams, }; @@ -11,3 +11,6 @@ pub use crate::legacy::kernels::{Ambiguous, NonExistent}; pub type LargeStringArray = Utf8Array; pub type LargeBinaryArray = BinaryArray; pub type LargeListArray = ListArray; + +#[allow(deprecated)] +pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions; diff --git a/crates/polars-arrow/src/types/native.rs b/crates/polars-arrow/src/types/native.rs index 6f869df32602..230fdde387d1 100644 --- a/crates/polars-arrow/src/types/native.rs +++ b/crates/polars-arrow/src/types/native.rs @@ -1,10 +1,11 @@ +use std::hash::{Hash, Hasher}; use std::ops::Neg; use std::panic::RefUnwindSafe; use bytemuck::{Pod, Zeroable}; use polars_utils::min_max::MinMax; use polars_utils::nulls::IsNull; -use polars_utils::total_ord::{TotalEq, TotalOrd}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrd, TotalOrdWrap}; use super::PrimitiveType; @@ -434,6 +435,44 @@ impl PartialEq for f16 { } } +/// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to +/// the same value. +#[inline] +pub fn canonical_f16(x: f16) -> f16 { + // zero out the sign bit if the f16 is zero. + let convert_zero = f16(x.0 & (0x7FFF | (u16::from(x.0 & 0x7FFF == 0) << 15))); + if convert_zero.is_nan() { + f16::from_bits(0x7c00) // Canonical quiet NaN. + } else { + convert_zero + } +} + +impl TotalHash for f16 { + #[inline(always)] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + canonical_f16(*self).to_bits().hash(state) + } +} + +impl ToTotalOrd for f16 { + type TotalOrdItem = TotalOrdWrap; + type SourceItem = f16; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(*self) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } +} + impl IsNull for f16 { const HAS_NULLS: bool = false; type Inner = f16; diff --git a/crates/polars-compute/src/cardinality.rs b/crates/polars-compute/src/cardinality.rs new file mode 100644 index 000000000000..d28efa9d051e --- /dev/null +++ b/crates/polars-compute/src/cardinality.rs @@ -0,0 +1,159 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeBinaryArray, PrimitiveArray, + Utf8Array, Utf8ViewArray, +}; +use arrow::datatypes::PhysicalType; +use arrow::types::Offset; +use arrow::with_match_primitive_type_full; +use polars_utils::total_ord::ToTotalOrd; + +use crate::hyperloglogplus::HyperLogLog; + +/// Get an estimate for the *cardinality* of the array (i.e. the number of unique values) +/// +/// This is not currently implemented for nested types. +pub fn estimate_cardinality(array: &dyn Array) -> usize { + if array.is_empty() { + return 0; + } + + if array.null_count() == array.len() { + return 1; + } + + // Estimate the cardinality with HyperLogLog + use PhysicalType as PT; + match array.dtype().to_physical_type() { + PT::Null => 1, + + PT::Boolean => { + let mut cardinality = 0; + + let array = array.as_any().downcast_ref::().unwrap(); + + cardinality += usize::from(array.has_nulls()); + + if let Some(unset_bits) = array.values().lazy_unset_bits() { + cardinality += 1 + usize::from(unset_bits != array.len()); + } else { + cardinality += 2; + } + + cardinality + }, + + PT::Primitive(primitive_type) => with_match_primitive_type_full!(primitive_type, |$T| { + let mut hll = HyperLogLog::new(); + + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.copied().unwrap_or_default(); + hll.add(&v.to_total_ord()); + } + } else { + for v in array.values_iter() { + hll.add(&v.to_total_ord()); + } + } + + hll.count() + }), + PT::FixedSizeBinary => { + let mut hll = HyperLogLog::new(); + + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() + }, + PT::Binary => { + binary_offset_array_estimate(array.as_any().downcast_ref::>().unwrap()) + }, + PT::LargeBinary => { + binary_offset_array_estimate(array.as_any().downcast_ref::>().unwrap()) + }, + PT::Utf8 => binary_offset_array_estimate( + &array + .as_any() + .downcast_ref::>() + .unwrap() + .to_binary(), + ), + PT::LargeUtf8 => binary_offset_array_estimate( + &array + .as_any() + .downcast_ref::>() + .unwrap() + .to_binary(), + ), + PT::BinaryView => { + binary_view_array_estimate(array.as_any().downcast_ref::().unwrap()) + }, + PT::Utf8View => binary_view_array_estimate( + &array + .as_any() + .downcast_ref::() + .unwrap() + .to_binview(), + ), + PT::List => unimplemented!(), + PT::FixedSizeList => unimplemented!(), + PT::LargeList => unimplemented!(), + PT::Struct => unimplemented!(), + PT::Union => unimplemented!(), + PT::Map => unimplemented!(), + PT::Dictionary(_) => unimplemented!(), + } +} + +fn binary_offset_array_estimate(array: &BinaryArray) -> usize { + let mut hll = HyperLogLog::new(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() +} + +fn binary_view_array_estimate(array: &BinaryViewArray) -> usize { + let mut hll = HyperLogLog::new(); + + if array.has_nulls() { + for v in array.iter() { + let v = v.unwrap_or_default(); + hll.add(v); + } + } else { + for v in array.values_iter() { + hll.add(v); + } + } + + hll.count() +} diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs index da56c65983db..30efdd59adc7 100644 --- a/crates/polars-compute/src/lib.rs +++ b/crates/polars-compute/src/lib.rs @@ -10,6 +10,8 @@ use arrow::types::NativeType; pub mod arithmetic; pub mod arity; pub mod bitwise; +#[cfg(feature = "approx_unique")] +pub mod cardinality; pub mod comparisons; pub mod filter; pub mod float_sum; diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 071073460ff3..5b3c0b5f53d7 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -369,12 +369,8 @@ where ::Simd: Add::Simd> + compute::aggregate::Sum, { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float64, v.into())) } @@ -385,12 +381,8 @@ where } impl QuantileAggSeries for Float32Chunked { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float32, v.into())) } @@ -401,12 +393,8 @@ impl QuantileAggSeries for Float32Chunked { } impl QuantileAggSeries for Float64Chunked { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float64, v.into())) } @@ -735,19 +723,20 @@ mod test { let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_i32.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_f64.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_i64.quantile(0.9, interpol).unwrap(), None); + for method in methods { + assert_eq!(test_f32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_f64.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i64.quantile(0.9, method).unwrap(), None); } } @@ -758,19 +747,20 @@ mod test { let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]); let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_i32.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_f64.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_i64.quantile(0.5, interpol).unwrap(), Some(1.0)); + for method in methods { + assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0)); } } @@ -793,37 +783,38 @@ mod test { &[None, Some(1i64), Some(5i64), Some(1i64)], ); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.0, interpol).unwrap(), test_f32.min()); - assert_eq!(test_f32.quantile(1.0, interpol).unwrap(), test_f32.max()); + for method in methods { + assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min()); + assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max()); assert_eq!( - test_i32.quantile(0.0, interpol).unwrap().unwrap(), + test_i32.quantile(0.0, method).unwrap().unwrap(), test_i32.min().unwrap() as f64 ); assert_eq!( - test_i32.quantile(1.0, interpol).unwrap().unwrap(), + test_i32.quantile(1.0, method).unwrap().unwrap(), test_i32.max().unwrap() as f64 ); - assert_eq!(test_f64.quantile(0.0, interpol).unwrap(), test_f64.min()); - assert_eq!(test_f64.quantile(1.0, interpol).unwrap(), test_f64.max()); - assert_eq!(test_f64.quantile(0.5, interpol).unwrap(), test_f64.median()); + assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min()); + assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max()); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median()); assert_eq!( - test_i64.quantile(0.0, interpol).unwrap().unwrap(), + test_i64.quantile(0.0, method).unwrap().unwrap(), test_i64.min().unwrap() as f64 ); assert_eq!( - test_i64.quantile(1.0, interpol).unwrap().unwrap(), + test_i64.quantile(1.0, method).unwrap().unwrap(), test_i64.max().unwrap() as f64 ); } @@ -837,72 +828,56 @@ mod test { ); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), Some(5.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), Some(3.0) ); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Lower).unwrap(), - Some(1.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Lower).unwrap(), - Some(4.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Lower).unwrap(), - Some(3.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0)); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Higher).unwrap(), - Some(2.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Higher).unwrap(), - Some(5.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Higher).unwrap(), - Some(4.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0)); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), Some(1.5) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), Some(4.5) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), Some(3.5) ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert!( + (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001 + ); + assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Linear).unwrap(), - Some(1.4) + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Linear).unwrap(), - Some(4.6) + ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) ); - assert!( - (ca.quantile(0.6, QuantileInterpolOptions::Linear) - .unwrap() - .unwrap() - - 3.4) - .abs() - < 0.0000001 + assert_eq!( + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(3.0) ); let ca = UInt32Chunked::new( @@ -922,68 +897,54 @@ mod test { ); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), Some(2.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), Some(6.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), Some(5.0) ); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Lower).unwrap(), - Some(1.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Lower).unwrap(), - Some(6.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Lower).unwrap(), - Some(4.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0)); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Higher).unwrap(), - Some(2.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Higher).unwrap(), - Some(7.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Higher).unwrap(), - Some(5.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0)); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), Some(1.5) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), Some(6.5) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), Some(4.5) ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Linear).unwrap(), - Some(1.6) + ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Linear).unwrap(), - Some(6.4) + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Linear).unwrap(), - Some(4.6) + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(5.0) ); } } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs index d6218e81d463..f7716c864559 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs @@ -4,11 +4,7 @@ pub trait QuantileAggSeries { /// Get the median of the [`ChunkedArray`] as a new [`Series`] of length 1. fn median_reduce(&self) -> Scalar; /// Get the quantile of the [`ChunkedArray`] as a new [`Series`] of length 1. - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult; + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult; } /// helper @@ -16,18 +12,23 @@ fn quantile_idx( quantile: f64, length: usize, null_count: usize, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> (usize, f64, usize) { - let float_idx = ((length - null_count) as f64 - 1.0) * quantile + null_count as f64; - let mut base_idx = match interpol { - QuantileInterpolOptions::Nearest => { + let nonnull_count = (length - null_count) as f64; + let float_idx = (nonnull_count - 1.0) * quantile + null_count as f64; + let mut base_idx = match method { + QuantileMethod::Nearest => { let idx = float_idx.round() as usize; - return (float_idx.round() as usize, 0.0, idx); + return (idx, 0.0, idx); + }, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { + float_idx as usize + }, + QuantileMethod::Higher => float_idx.ceil() as usize, + QuantileMethod::Equiprobable => { + let idx = ((nonnull_count * quantile).ceil() - 1.0).max(0.0) as usize + null_count; + return (idx, 0.0, idx); }, - QuantileInterpolOptions::Lower - | QuantileInterpolOptions::Midpoint - | QuantileInterpolOptions::Linear => float_idx as usize, - QuantileInterpolOptions::Higher => float_idx.ceil() as usize, }; base_idx = base_idx.clamp(0, length - 1); @@ -57,7 +58,7 @@ fn midpoint_interpol(lower: T, upper: T) -> T { fn quantile_slice( vals: &mut [T], quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { polars_ensure!((0.0..=1.0).contains(&quantile), ComputeError: "quantile should be between 0.0 and 1.0", @@ -68,21 +69,21 @@ fn quantile_slice( if vals.len() == 1 { return Ok(vals[0].to_f64()); } - let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, interpol); + let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, method); let (_lhs, lower, rhs) = vals.select_nth_unstable_by(idx, TotalOrd::tot_cmp); if idx == top_idx { Ok(lower.to_f64()) } else { - match interpol { - QuantileInterpolOptions::Midpoint => { + match method { + QuantileMethod::Midpoint => { let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); Ok(Some(midpoint_interpol( lower.to_f64().unwrap(), upper.to_f64().unwrap(), ))) }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); Ok(linear_interpol( lower.to_f64().unwrap(), @@ -100,7 +101,7 @@ fn quantile_slice( fn generic_quantile( ca: ChunkedArray, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> where T: PolarsNumericType, @@ -117,12 +118,12 @@ where return Ok(None); } - let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, interpol); + let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, method); let sorted = ca.sort(false); let lower = sorted.get(idx).map(|v| v.to_f64().unwrap()); - let opt = match interpol { - QuantileInterpolOptions::Midpoint => { + let opt = match method { + QuantileMethod::Midpoint => { if top_idx == idx { lower } else { @@ -130,7 +131,7 @@ where midpoint_interpol(lower.unwrap(), upper.unwrap()).to_f64() } }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { if top_idx == idx { lower } else { @@ -149,22 +150,18 @@ where T: PolarsIntegerType, T::Native: TotalOrd, { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) } } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } @@ -177,61 +174,52 @@ where pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol) + quantile_slice(slice, quantile, method) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } impl ChunkQuantile for Float32Chunked { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let out = if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) }; out.map(|v| v.map(|v| v as f32)) } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } impl ChunkQuantile for Float64Chunked { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) } } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } @@ -239,20 +227,19 @@ impl Float64Chunked { pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol) + quantile_slice(slice, quantile, method) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } @@ -260,20 +247,19 @@ impl Float32Chunked { pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol).map(|v| v.map(|v| v as f32)) + quantile_slice(slice, quantile, method).map(|v| v.map(|v| v as f32)) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 33f43d530e45..a3e7f04cc9e1 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -278,11 +278,7 @@ pub trait ChunkQuantile { } /// Aggregate a given quantile of the ChunkedArray. /// Returns `None` if the array is empty or only contains null values. - fn quantile( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult> { Ok(None) } } diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index d83b91b78cff..e66c1ad12875 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -560,12 +560,12 @@ impl Column { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Self { // @scalar-opt unsafe { self.as_materialized_series() - .agg_quantile(groups, quantile, interpol) + .agg_quantile(groups, quantile, method) } .into() } diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index fe71148cd49b..aaf24a470969 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -236,7 +236,7 @@ impl Series { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { @@ -247,13 +247,12 @@ impl Series { use DataType::*; match s.dtype() { - Float32 => s.f32().unwrap().agg_quantile(groups, quantile, interpol), - Float64 => s.f64().unwrap().agg_quantile(groups, quantile, interpol), + Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method), + Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method), dt if dt.is_numeric() || dt.is_temporal() => { let ca = s.to_physical_repr(); let physical_type = ca.dtype(); - let s = - apply_method_physical_integer!(ca, agg_quantile, groups, quantile, interpol); + let s = apply_method_physical_integer!(ca, agg_quantile, groups, quantile, method); if dt.is_logical() { // back to physical and then // back to logical type diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 092d660fb4d2..19b8d5c2d061 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -13,7 +13,7 @@ use arrow::legacy::kernels::rolling::no_nulls::{ }; use arrow::legacy::kernels::rolling::nulls::RollingAggWindowNulls; use arrow::legacy::kernels::take_agg::*; -use arrow::legacy::prelude::QuantileInterpolOptions; +use arrow::legacy::prelude::QuantileMethod; use arrow::legacy::trusted_len::TrustedLenPush; use arrow::types::NativeType; use num_traits::pow::Pow; @@ -295,8 +295,7 @@ impl_take_extremum!(float: f64); /// This trait will ensure the specific dispatch works without complicating /// the trait bounds. trait QuantileDispatcher { - fn _quantile(self, quantile: f64, interpol: QuantileInterpolOptions) - -> PolarsResult>; + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult>; fn _median(self) -> Option; } @@ -307,12 +306,8 @@ where T::Native: Ord, ChunkedArray: IntoSeries, { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() @@ -320,24 +315,16 @@ where } impl QuantileDispatcher for Float32Chunked { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() } } impl QuantileDispatcher for Float64Chunked { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() @@ -348,7 +335,7 @@ unsafe fn agg_quantile_generic( ca: &ChunkedArray, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series where T: PolarsNumericType, @@ -371,7 +358,7 @@ where } let take = { ca.take_unchecked(idx) }; // checked with invalid quantile check - take._quantile(quantile, interpol).unwrap_unchecked() + take._quantile(quantile, method).unwrap_unchecked() }) }, GroupsProxy::Slice { groups, .. } => { @@ -390,7 +377,7 @@ where offset_iter, Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })), ), Some(validity) => { @@ -400,7 +387,7 @@ where offset_iter, Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })), ) }, @@ -418,7 +405,7 @@ where let arr_group = _slice_from_offsets(ca, first, len); // unwrap checked with invalid quantile check arr_group - ._quantile(quantile, interpol) + ._quantile(quantile, method) .unwrap_unchecked() .map(|flt| NumCast::from(flt).unwrap_unchecked()) }, @@ -450,7 +437,7 @@ where }) }, GroupsProxy::Slice { .. } => { - agg_quantile_generic::(ca, groups, 0.5, QuantileInterpolOptions::Linear) + agg_quantile_generic::(ca, groups, 0.5, QuantileMethod::Linear) }, } } @@ -977,9 +964,9 @@ impl Float32Chunked { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float32Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float32Type>(self, groups) @@ -990,9 +977,9 @@ impl Float64Chunked { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float64Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float64Type>(self, groups) @@ -1184,9 +1171,9 @@ where &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float64Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float64Type>(self, groups) diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 9c56f7c49122..31936a3a5906 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -594,18 +594,14 @@ impl<'df> GroupBy<'df> { /// /// ```rust /// # use polars_core::prelude::*; - /// # use arrow::legacy::prelude::QuantileInterpolOptions; + /// # use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> PolarsResult { - /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileInterpolOptions::default()) + /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default()) /// } /// ``` #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] - pub fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { + pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { polars_ensure!( (0.0..=1.0).contains(&quantile), ComputeError: "`quantile` should be within 0.0 and 1.0" @@ -614,9 +610,9 @@ impl<'df> GroupBy<'df> { for agg_col in agg_cols { let new_name = fmt_group_by_column( agg_col.name().as_str(), - GroupByMethod::Quantile(quantile, interpol), + GroupByMethod::Quantile(quantile, method), ); - let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, interpol) }; + let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) }; agg.rename(new_name); cols.push(agg); } @@ -868,7 +864,7 @@ pub enum GroupByMethod { Sum, Groups, NUnique, - Quantile(f64, QuantileInterpolOptions), + Quantile(f64, QuantileMethod), Count { include_nulls: bool, }, diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index b91df29a0a38..ace52993b8a1 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -358,11 +358,7 @@ impl SeriesTrait for SeriesWrap { Ok(Scalar::new(self.dtype().clone(), av)) } - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { Ok(Scalar::new(self.dtype().clone(), AnyValue::Null)) } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 30125ccc15b6..612505057eca 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -404,13 +404,9 @@ impl SeriesTrait for SeriesWrap { Ok(self.apply_scale(self.0.std_reduce(ddof))) } - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { self.0 - .quantile_reduce(quantile, interpol) + .quantile_reduce(quantile, method) .map(|v| self.apply_scale(v)) } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 13b121aee0ca..803ca813aa1c 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -501,12 +501,8 @@ impl SeriesTrait for SeriesWrap { v.as_duration(self.0.time_unit()), )) } - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.0.quantile_reduce(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.0.quantile_reduce(quantile, method)?; let to = self.dtype().to_physical(); let v = v.value().cast(&to); Ok(Scalar::new( diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 24be56671d69..846e326d35b2 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -365,9 +365,9 @@ macro_rules! impl_dyn_series { fn quantile_reduce( &self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult { - QuantileAggSeries::quantile_reduce(&self.0, quantile, interpol) + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) } #[cfg(feature = "bitwise")] fn and_reduce(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 9d8357a905bc..b2cb97e39b69 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -468,9 +468,9 @@ macro_rules! impl_dyn_series { fn quantile_reduce( &self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult { - QuantileAggSeries::quantile_reduce(&self.0, quantile, interpol) + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) } #[cfg(feature = "bitwise")] diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 1ee69300fa92..14a0752eae1e 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -519,11 +519,7 @@ pub trait SeriesTrait: polars_bail!(opq = std, self._dtype()); } /// Get the quantile of the ChunkedArray as a new Series of length 1. - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { polars_bail!(opq = quantile, self._dtype()); } /// Get the bitwise AND of the Series as a new Series of length 1, diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index af5383e83c83..f1cfa5251899 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -715,19 +715,19 @@ impl PartitionedAggregation for AggregationExpr { pub struct AggQuantileExpr { pub(crate) input: Arc, pub(crate) quantile: Arc, - pub(crate) interpol: QuantileInterpolOptions, + pub(crate) method: QuantileMethod, } impl AggQuantileExpr { pub fn new( input: Arc, quantile: Arc, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Self { Self { input, quantile, - interpol, + method, } } @@ -750,7 +750,7 @@ impl PhysicalExpr for AggQuantileExpr { let input = self.input.evaluate(df, state)?; let quantile = self.get_quantile(df, state)?; input - .quantile_reduce(quantile, self.interpol) + .quantile_reduce(quantile, self.method) .map(|sc| sc.into_series(input.name().clone())) } #[allow(clippy::ptr_arg)] @@ -771,7 +771,7 @@ impl PhysicalExpr for AggQuantileExpr { let mut agg = unsafe { ac.flat_naive() .into_owned() - .agg_quantile(ac.groups(), quantile, self.interpol) + .agg_quantile(ac.groups(), quantile, self.method) }; agg.rename(keep_name); Ok(AggregationContext::from_agg_state( diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 8c71c90cd152..53579b763033 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -525,14 +525,6 @@ fn apply_multiple_elementwise<'a>( impl StatsEvaluator for ApplyExpr { fn should_read(&self, stats: &BatchStats) -> PolarsResult { let read = self.should_read_impl(stats)?; - if ExecutionState::new().verbose() { - if read { - eprintln!("parquet file must be read, statistics not sufficient for predicate.") - } else { - eprintln!("parquet file can be skipped, the statistics were sufficient to apply the predicate.") - } - } - Ok(read) } } diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index c1cb286e7104..23f50af45273 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -351,7 +351,7 @@ mod stats { use ChunkCompareIneq as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), - Operator::NotEq => apply_operator_stats_eq(min_max, literal), + Operator::NotEq => apply_operator_stats_neq(min_max, literal), Operator::Gt => { // Literal is bigger than max value, selection needs all rows. C::gt(literal, min_max).map(|ca| ca.any()).unwrap_or(false) @@ -454,10 +454,6 @@ mod stats { impl StatsEvaluator for BinaryExpr { fn should_read(&self, stats: &BatchStats) -> PolarsResult { - if std::env::var("POLARS_NO_PARQUET_STATISTICS").is_ok() { - return Ok(true); - } - use Operator::*; match ( self.left.as_stats_evaluator(), diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 620f8bf87089..c4006de0c8ec 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -402,7 +402,9 @@ fn create_physical_expr_inner( }, _ => { if let IRAggExpr::Quantile { - quantile, interpol, .. + quantile, + method: interpol, + .. } = agg { let quantile = diff --git a/crates/polars-io/src/parquet/read/mod.rs b/crates/polars-io/src/parquet/read/mod.rs index 1fec749af5ce..cc0020cc7857 100644 --- a/crates/polars-io/src/parquet/read/mod.rs +++ b/crates/polars-io/src/parquet/read/mod.rs @@ -33,6 +33,7 @@ or set 'streaming'", pub use options::{ParallelStrategy, ParquetOptions}; use polars_error::{ErrString, PolarsError}; +pub use read_impl::{create_sorting_map, try_set_sorted_flag}; #[cfg(feature = "cloud")] pub use reader::ParquetAsyncReader; pub use reader::{BatchedParquetReader, ParquetReader}; diff --git a/crates/polars-io/src/parquet/read/predicates.rs b/crates/polars-io/src/parquet/read/predicates.rs index eb8f7747f078..a3269341c1a3 100644 --- a/crates/polars-io/src/parquet/read/predicates.rs +++ b/crates/polars-io/src/parquet/read/predicates.rs @@ -1,3 +1,4 @@ +use polars_core::config; use polars_core::prelude::*; use polars_parquet::read::statistics::{deserialize, Statistics}; use polars_parquet::read::RowGroupMetadata; @@ -50,18 +51,38 @@ pub fn read_this_row_group( md: &RowGroupMetadata, schema: &ArrowSchema, ) -> PolarsResult { + if std::env::var("POLARS_NO_PARQUET_STATISTICS").is_ok() { + return Ok(true); + } + + let mut should_read = true; + if let Some(pred) = predicate { if let Some(pred) = pred.as_stats_evaluator() { if let Some(stats) = collect_statistics(md, schema)? { - let should_read = pred.should_read(&stats); + let pred_result = pred.should_read(&stats); + // a parquet file may not have statistics of all columns - if matches!(should_read, Ok(false)) { - return Ok(false); - } else if !matches!(should_read, Err(PolarsError::ColumnNotFound(_))) { - let _ = should_read?; + match pred_result { + Err(PolarsError::ColumnNotFound(errstr)) => { + return Err(PolarsError::ColumnNotFound(errstr)) + }, + Ok(false) => should_read = false, + _ => {}, } } } + + if config::verbose() { + if should_read { + eprintln!( + "parquet row group must be read, statistics not sufficient for predicate." + ); + } else { + eprintln!("parquet row group can be skipped, the statistics were sufficient to apply the predicate."); + } + } } - Ok(true) + + Ok(should_read) } diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs index 0389a73b5081..abe8f1887790 100644 --- a/crates/polars-io/src/parquet/read/read_impl.rs +++ b/crates/polars-io/src/parquet/read/read_impl.rs @@ -7,8 +7,9 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowSchemaRef; use polars_core::chunked_array::builder::NullChunkedBuilder; use polars_core::prelude::*; +use polars_core::series::IsSorted; use polars_core::utils::{accumulate_dataframes_vertical, split_df}; -use polars_core::POOL; +use polars_core::{config, POOL}; use polars_parquet::parquet::error::ParquetResult; use polars_parquet::parquet::statistics::Statistics; use polars_parquet::read::{ @@ -60,6 +61,57 @@ fn assert_dtypes(dtype: &ArrowDataType) { } } +fn should_copy_sortedness(dtype: &DataType) -> bool { + // @NOTE: For now, we are a bit conservative with this. + use DataType as D; + + matches!( + dtype, + D::Int8 | D::Int16 | D::Int32 | D::Int64 | D::UInt8 | D::UInt16 | D::UInt32 | D::UInt64 + ) +} + +pub fn try_set_sorted_flag( + series: &mut Series, + col_idx: usize, + sorting_map: &PlHashMap, +) { + if let Some(is_sorted) = sorting_map.get(&col_idx) { + if should_copy_sortedness(series.dtype()) { + if config::verbose() { + eprintln!( + "Parquet conserved SortingColumn for column chunk of '{}' to {is_sorted:?}", + series.name() + ); + } + + series.set_sorted_flag(*is_sorted); + } + } +} + +pub fn create_sorting_map(md: &RowGroupMetadata) -> PlHashMap { + let capacity = md.sorting_columns().map_or(0, |s| s.len()); + let mut sorting_map = PlHashMap::with_capacity(capacity); + + if let Some(sorting_columns) = md.sorting_columns() { + for sorting in sorting_columns { + let prev_value = sorting_map.insert( + sorting.column_idx as usize, + if sorting.descending { + IsSorted::Descending + } else { + IsSorted::Ascending + }, + ); + + debug_assert!(prev_value.is_none()); + } + } + + sorting_map +} + fn column_idx_to_series( column_i: usize, // The metadata belonging to this column @@ -68,6 +120,8 @@ fn column_idx_to_series( file_schema: &ArrowSchema, store: &mmap::ColumnStore, ) -> PolarsResult { + let did_filter = filter.is_some(); + let field = file_schema.get_at_index(column_i).unwrap().1; #[cfg(debug_assertions)] @@ -91,6 +145,11 @@ fn column_idx_to_series( _ => {}, } + // We cannot trust the statistics if we filtered the parquet already. + if did_filter { + return Ok(series); + } + // See if we can find some statistics for this series. If we cannot find anything just return // the series as is. let Ok(Some(stats)) = stats.map(|mut s| s.pop().flatten()) else { @@ -320,6 +379,8 @@ fn rg_to_dfs_prefiltered( } } + let sorting_map = create_sorting_map(md); + // Collect the data for the live columns let live_columns = (0..num_live_columns) .into_par_iter() @@ -338,8 +399,12 @@ fn rg_to_dfs_prefiltered( let part = iter.collect::>(); - column_idx_to_series(col_idx, part.as_slice(), None, schema, store) - .map(Column::from) + let mut series = + column_idx_to_series(col_idx, part.as_slice(), None, schema, store)?; + + try_set_sorted_flag(&mut series, col_idx, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; @@ -445,7 +510,7 @@ fn rg_to_dfs_prefiltered( array.filter(&mask_arr) }; - let array = if mask_setting.should_prefilter( + let mut series = if mask_setting.should_prefilter( prefilter_cost, &schema.get_at_index(col_idx).unwrap().1.dtype, ) { @@ -454,9 +519,11 @@ fn rg_to_dfs_prefiltered( post()? }; - debug_assert_eq!(array.len(), filter_mask.set_bits()); + debug_assert_eq!(series.len(), filter_mask.set_bits()); - Ok(array.into_column()) + try_set_sorted_flag(&mut series, col_idx, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; @@ -569,6 +636,8 @@ fn rg_to_dfs_optionally_par_over_columns( assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) } + let sorting_map = create_sorting_map(md); + let columns = if let ParallelStrategy::Columns = parallel { POOL.install(|| { projection @@ -586,14 +655,17 @@ fn rg_to_dfs_optionally_par_over_columns( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>() })? @@ -613,14 +685,17 @@ fn rg_to_dfs_optionally_par_over_columns( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()? }; @@ -705,6 +780,8 @@ fn rg_to_dfs_par_over_rg( assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) } + let sorting_map = create_sorting_map(md); + let columns = projection .iter() .map(|column_i| { @@ -720,14 +797,17 @@ fn rg_to_dfs_par_over_rg( let part = iter.collect::>(); - column_idx_to_series( + let mut series = column_idx_to_series( *column_i, part.as_slice(), Some(Filter::new_ranged(slice.0, slice.0 + slice.1)), schema, store, - ) - .map(Column::from) + )?; + + try_set_sorted_flag(&mut series, *column_i, &sorting_map); + + Ok(series.into_column()) }) .collect::>>()?; diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs index 2a70ef2c5046..25d1f51b098b 100644 --- a/crates/polars-io/src/parquet/read/reader.rs +++ b/crates/polars-io/src/parquet/read/reader.rs @@ -89,9 +89,15 @@ impl ParquetReader { projected_arrow_schema: Option<&ArrowSchema>, allow_missing_columns: bool, ) -> PolarsResult { + // `self.schema` gets overwritten if allow_missing_columns + let this_schema_width = self.schema()?.len(); + if allow_missing_columns { // Must check the dtypes - ensure_matching_dtypes_if_found(first_schema, self.schema()?.as_ref())?; + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema()?.as_ref(), + )?; self.schema.replace(first_schema.clone()); } @@ -104,7 +110,7 @@ impl ParquetReader { projected_arrow_schema, )?; } else { - if schema.len() > first_schema.len() { + if this_schema_width > first_schema.len() { polars_bail!( SchemaMismatch: "parquet file contained extra columns and no selection was given" @@ -328,9 +334,15 @@ impl ParquetAsyncReader { projected_arrow_schema: Option<&ArrowSchema>, allow_missing_columns: bool, ) -> PolarsResult { + // `self.schema` gets overwritten if allow_missing_columns + let this_schema_width = self.schema().await?.len(); + if allow_missing_columns { // Must check the dtypes - ensure_matching_dtypes_if_found(first_schema, self.schema().await?.as_ref())?; + ensure_matching_dtypes_if_found( + projected_arrow_schema.unwrap_or(first_schema.as_ref()), + self.schema().await?.as_ref(), + )?; self.schema.replace(first_schema.clone()); } @@ -343,7 +355,7 @@ impl ParquetAsyncReader { projected_arrow_schema, )?; } else { - if schema.len() > first_schema.len() { + if this_schema_width > first_schema.len() { polars_bail!( SchemaMismatch: "parquet file contained extra columns and no selection was given" diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index fd4334cad066..f073c6643f8c 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -1004,7 +1004,7 @@ impl LazyFrame { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// use arrow::legacy::prelude::QuantileInterpolOptions; + /// use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() @@ -1012,7 +1012,7 @@ impl LazyFrame { /// .agg([ /// col("rain").min().alias("min_rain"), /// col("rain").sum().alias("sum_rain"), - /// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), /// ]) /// } /// ``` @@ -1495,10 +1495,10 @@ impl LazyFrame { } /// Aggregate all the columns as their quantile values. - pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { self.map_private(DslFunction::Stats(StatsFunction::Quantile { quantile, - interpol, + method, })) } @@ -1885,7 +1885,7 @@ impl LazyGroupBy { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// use arrow::legacy::prelude::QuantileInterpolOptions; + /// use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() @@ -1893,7 +1893,7 @@ impl LazyGroupBy { /// .agg([ /// col("rain").min().alias("min_rain"), /// col("rain").sum().alias("sum_rain"), - /// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), /// ]) /// } /// ``` diff --git a/crates/polars-lazy/src/lib.rs b/crates/polars-lazy/src/lib.rs index 3059384a1c8c..f3dff5710170 100644 --- a/crates/polars-lazy/src/lib.rs +++ b/crates/polars-lazy/src/lib.rs @@ -104,7 +104,7 @@ //! use polars_core::prelude::*; //! use polars_core::df; //! use polars_lazy::prelude::*; -//! use arrow::legacy::prelude::QuantileInterpolOptions; +//! use arrow::legacy::prelude::QuantileMethod; //! //! fn example() -> PolarsResult { //! let df = df!( @@ -118,7 +118,7 @@ //! .agg([ //! col("rain").min().alias("min_rain"), //! col("rain").sum().alias("sum_rain"), -//! col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), +//! col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), //! ]) //! .sort(["date"], Default::default()) //! .collect() diff --git a/crates/polars-mem-engine/src/executors/scan/ndjson.rs b/crates/polars-mem-engine/src/executors/scan/ndjson.rs index 58862bd71f9e..1f90e07a72c1 100644 --- a/crates/polars-mem-engine/src/executors/scan/ndjson.rs +++ b/crates/polars-mem-engine/src/executors/scan/ndjson.rs @@ -1,5 +1,6 @@ use polars_core::config; use polars_core::utils::accumulate_dataframes_vertical; +use polars_io::prelude::{JsonLineReader, SerReader}; use polars_io::utils::compression::maybe_decompress_bytes; use super::*; diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index c78aa67efca8..b7b8d3e9f179 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -144,11 +144,7 @@ pub fn qcut( let s2 = s.sort(SortOptions::default())?; let ca = s2.f64()?; - let f = |&p| { - ca.quantile(p, QuantileInterpolOptions::Linear) - .unwrap() - .unwrap() - }; + let f = |&p| ca.quantile(p, QuantileMethod::Linear).unwrap().unwrap(); let mut qbreaks: Vec<_> = probs.iter().map(f).collect(); qbreaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); diff --git a/crates/polars-parquet/Cargo.toml b/crates/polars-parquet/Cargo.toml index 26a57b22e713..881c9a477398 100644 --- a/crates/polars-parquet/Cargo.toml +++ b/crates/polars-parquet/Cargo.toml @@ -22,7 +22,7 @@ fallible-streaming-iterator = { workspace = true, optional = true } futures = { workspace = true, optional = true } hashbrown = { workspace = true } num-traits = { workspace = true } -polars-compute = { workspace = true } +polars-compute = { workspace = true, features = ["approx_unique"] } polars-error = { workspace = true } polars-utils = { workspace = true, features = ["mmap"] } simdutf8 = { workspace = true } diff --git a/crates/polars-parquet/src/arrow/write/dictionary.rs b/crates/polars-parquet/src/arrow/write/dictionary.rs index 4e0d57302314..17527fc488f7 100644 --- a/crates/polars-parquet/src/arrow/write/dictionary.rs +++ b/crates/polars-parquet/src/arrow/write/dictionary.rs @@ -3,11 +3,12 @@ use arrow::array::{ }; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::buffer::Buffer; -use arrow::datatypes::{ArrowDataType, IntegerType}; +use arrow::datatypes::{ArrowDataType, IntegerType, PhysicalType}; +use arrow::legacy::utils::CustomIterTools; +use arrow::trusted_len::TrustMyLength; use arrow::types::NativeType; use polars_compute::min_max::MinMaxKernel; use polars_error::{polars_bail, PolarsResult}; -use polars_utils::unwrap::UnwrapUncheckedRelease; use super::binary::{ build_statistics as binary_build_statistics, encode_plain as binary_encode_plain, @@ -31,33 +32,51 @@ use crate::parquet::CowBuffer; use crate::write::DynIter; trait MinMaxThreshold { - const DELTA_THRESHOLD: Self; + const DELTA_THRESHOLD: usize; + const BITMASK_THRESHOLD: usize; + + fn from_start_and_offset(start: Self, offset: usize) -> Self; } macro_rules! minmaxthreshold_impls { - ($($t:ty => $threshold:literal,)+) => { + ($($signed:ty, $unsigned:ty => $threshold:literal, $bm_threshold:expr,)+) => { $( - impl MinMaxThreshold for $t { - const DELTA_THRESHOLD: Self = $threshold; + impl MinMaxThreshold for $signed { + const DELTA_THRESHOLD: usize = $threshold; + const BITMASK_THRESHOLD: usize = $bm_threshold; + + fn from_start_and_offset(start: Self, offset: usize) -> Self { + start + ((offset as $unsigned) as $signed) + } + } + impl MinMaxThreshold for $unsigned { + const DELTA_THRESHOLD: usize = $threshold; + const BITMASK_THRESHOLD: usize = $bm_threshold; + + fn from_start_and_offset(start: Self, offset: usize) -> Self { + start + (offset as $unsigned) + } } )+ }; } minmaxthreshold_impls! { - i8 => 16, - i16 => 256, - i32 => 512, - i64 => 2048, - u8 => 16, - u16 => 256, - u32 => 512, - u64 => 2048, + i8, u8 => 16, u8::MAX as usize, + i16, u16 => 256, u16::MAX as usize, + i32, u32 => 512, u16::MAX as usize, + i64, u64 => 2048, u16::MAX as usize, +} + +enum DictionaryDecision { + NotWorth, + TryAgain, + Found(DictionaryArray), } fn min_max_integer_encode_as_dictionary_optional<'a, E, T>( array: &'a dyn Array, -) -> Option> +) -> DictionaryDecision where E: std::fmt::Debug, T: NativeType @@ -65,26 +84,82 @@ where + std::cmp::Ord + TryInto + std::ops::Sub - + num_traits::CheckedSub, + + num_traits::CheckedSub + + num_traits::cast::AsPrimitive, std::ops::RangeInclusive: Iterator, PrimitiveArray: MinMaxKernel = T>, { - use ArrowDataType as DT; - let (min, max): (T, T) = as MinMaxKernel>::min_max_ignore_nan_kernel( + let min_max = as MinMaxKernel>::min_max_ignore_nan_kernel( array.as_any().downcast_ref().unwrap(), - )?; + ); + + let Some((min, max)) = min_max else { + return DictionaryDecision::TryAgain; + }; debug_assert!(max >= min, "{max} >= {min}"); - if !max - .checked_sub(&min) - .is_some_and(|v| v <= T::DELTA_THRESHOLD) - { - return None; + let Some(diff) = max.checked_sub(&min) else { + return DictionaryDecision::TryAgain; + }; + + let diff = diff.as_(); + + if diff > T::BITMASK_THRESHOLD { + return DictionaryDecision::TryAgain; + } + + let mut seen_mask = MutableBitmap::from_len_zeroed(diff + 1); + + let array = array.as_any().downcast_ref::>().unwrap(); + + if array.has_nulls() { + for v in array.non_null_values_iter() { + let offset = (v - min).as_(); + debug_assert!(offset <= diff); + + unsafe { + seen_mask.set_unchecked(offset, true); + } + } + } else { + for v in array.values_iter() { + let offset = (*v - min).as_(); + debug_assert!(offset <= diff); + + unsafe { + seen_mask.set_unchecked(offset, true); + } + } } - // @TODO: This currently overestimates the values, it might be interesting to use the unique - // kernel here. - let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), (min..=max).collect(), None); + let cardinality = seen_mask.set_bits(); + + let mut is_worth_it = false; + + is_worth_it |= cardinality <= T::DELTA_THRESHOLD; + is_worth_it |= (cardinality as f64) / (array.len() as f64) < 0.75; + + if !is_worth_it { + return DictionaryDecision::NotWorth; + } + + let seen_mask = seen_mask.freeze(); + + // SAFETY: We just did the calculation for this. + let indexes = seen_mask + .true_idx_iter() + .map(|idx| T::from_start_and_offset(min, idx)); + let indexes = unsafe { TrustMyLength::new(indexes, cardinality) }; + let indexes = indexes.collect_trusted::>(); + + let mut lookup = vec![0u16; diff + 1]; + + for (i, &idx) in indexes.iter().enumerate() { + lookup[(idx - min).as_()] = i as u16; + } + + use ArrowDataType as DT; + let values = PrimitiveArray::new(DT::from(T::PRIMITIVE), indexes.into(), None); let values = Box::new(values); let keys: Buffer = array @@ -93,20 +168,19 @@ where .unwrap() .values() .iter() - .map(|v| unsafe { + .map(|v| { // @NOTE: // Since the values might contain nulls which have a undefined value. We just // clamp the values to between the min and max value. This way, they will still - // be valid dictionary keys. This is mostly to make the - // unwrap_unchecked_release not produce any unsafety. - (*v.clamp(&min, &max) - min) - .try_into() - .unwrap_unchecked_release() + // be valid dictionary keys. + let idx = *v.clamp(&min, &max) - min; + let value = unsafe { lookup.get_unchecked(idx.as_()) }; + (*value).into() }) .collect(); let keys = PrimitiveArray::new(DT::UInt32, keys, array.validity().cloned()); - Some( + DictionaryDecision::Found( DictionaryArray::::try_new( ArrowDataType::Dictionary( IntegerType::UInt32, @@ -126,26 +200,15 @@ pub(crate) fn encode_as_dictionary_optional( type_: PrimitiveType, options: WriteOptions, ) -> Option>>> { - use ArrowDataType as DT; - let fast_dictionary = match array.dtype() { - DT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array), - DT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array), - DT::Int32 | DT::Date32 | DT::Time32(_) => { - min_max_integer_encode_as_dictionary_optional::<_, i32>(array) - }, - DT::Int64 | DT::Date64 | DT::Time64(_) | DT::Timestamp(_, _) | DT::Duration(_) => { - min_max_integer_encode_as_dictionary_optional::<_, i64>(array) - }, - DT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array), - DT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array), - DT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array), - DT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array), - _ => None, - }; + if array.is_empty() { + let array = DictionaryArray::::new_empty(ArrowDataType::Dictionary( + IntegerType::UInt32, + Box::new(array.dtype().clone()), + false, // @TODO: This might be able to be set to true? + )); - if let Some(fast_dictionary) = fast_dictionary { return Some(array_to_pages( - &fast_dictionary, + &array, type_, nested, options, @@ -153,9 +216,44 @@ pub(crate) fn encode_as_dictionary_optional( )); } + use arrow::types::PrimitiveType as PT; + let fast_dictionary = match array.dtype().to_physical_type() { + PhysicalType::Primitive(pt) => match pt { + PT::Int8 => min_max_integer_encode_as_dictionary_optional::<_, i8>(array), + PT::Int16 => min_max_integer_encode_as_dictionary_optional::<_, i16>(array), + PT::Int32 => min_max_integer_encode_as_dictionary_optional::<_, i32>(array), + PT::Int64 => min_max_integer_encode_as_dictionary_optional::<_, i64>(array), + PT::UInt8 => min_max_integer_encode_as_dictionary_optional::<_, u8>(array), + PT::UInt16 => min_max_integer_encode_as_dictionary_optional::<_, u16>(array), + PT::UInt32 => min_max_integer_encode_as_dictionary_optional::<_, u32>(array), + PT::UInt64 => min_max_integer_encode_as_dictionary_optional::<_, u64>(array), + _ => DictionaryDecision::TryAgain, + }, + _ => DictionaryDecision::TryAgain, + }; + + match fast_dictionary { + DictionaryDecision::NotWorth => return None, + DictionaryDecision::Found(dictionary_array) => { + return Some(array_to_pages( + &dictionary_array, + type_, + nested, + options, + Encoding::RleDictionary, + )) + }, + DictionaryDecision::TryAgain => {}, + } + let dtype = Box::new(array.dtype().clone()); - let len_before = array.len(); + let estimated_cardinality = polars_compute::cardinality::estimate_cardinality(array); + + if array.len() > 128 && (estimated_cardinality as f64) / (array.len() as f64) > 0.75 { + return None; + } + // This does the group by. let array = arrow::compute::cast::cast( array, @@ -169,10 +267,6 @@ pub(crate) fn encode_as_dictionary_optional( .downcast_ref::>() .unwrap(); - if (array.values().len() as f64) / (len_before as f64) > 0.75 { - return None; - } - Some(array_to_pages( array, type_, diff --git a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs index 9cca27553415..bf27bffb66ef 100644 --- a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use hashbrown::hash_map::RawEntryMut; -use parquet_format_safe::RowGroup; +use parquet_format_safe::{RowGroup, SortingColumn}; use polars_utils::aliases::{InitHashMaps, PlHashMap}; use polars_utils::idx_vec::UnitVec; use polars_utils::pl_str::PlSmallStr; @@ -41,6 +41,7 @@ pub struct RowGroupMetadata { num_rows: usize, total_byte_size: usize, full_byte_range: core::ops::Range, + sorting_columns: Option>, } impl RowGroupMetadata { @@ -59,6 +60,11 @@ impl RowGroupMetadata { .map(|x| x.iter().map(|&x| &self.columns[x])) } + /// Fetch all columns under this root name if it exists. + pub fn columns_idxs_under_root_iter<'a>(&'a self, root_name: &str) -> Option<&'a [usize]> { + self.column_lookup.get(root_name).map(|x| x.as_slice()) + } + /// Number of rows in this row group. pub fn num_rows(&self) -> usize { self.num_rows @@ -85,6 +91,10 @@ impl RowGroupMetadata { self.columns.iter().map(|x| x.byte_range()) } + pub fn sorting_columns(&self) -> Option<&[SortingColumn]> { + self.sorting_columns.as_deref() + } + /// Method to convert from Thrift. pub(crate) fn try_from_thrift( schema_descr: &SchemaDescriptor, @@ -106,6 +116,8 @@ impl RowGroupMetadata { 0..0 }; + let sorting_columns = rg.sorting_columns.clone(); + let columns = rg .columns .into_iter() @@ -131,6 +143,7 @@ impl RowGroupMetadata { num_rows, total_byte_size, full_byte_range, + sorting_columns, }) } } diff --git a/crates/polars-parquet/src/parquet/read/page/reader.rs b/crates/polars-parquet/src/parquet/read/page/reader.rs index cd23af0499d7..ad453a0ff50a 100644 --- a/crates/polars-parquet/src/parquet/read/page/reader.rs +++ b/crates/polars-parquet/src/parquet/read/page/reader.rs @@ -13,6 +13,7 @@ use crate::parquet::page::{ ParquetPageHeader, }; use crate::parquet::CowBuffer; +use crate::write::Encoding; /// This meta is a small part of [`ColumnChunkMetadata`]. #[derive(Debug, Clone, PartialEq, Eq)] @@ -251,7 +252,10 @@ pub(super) fn finish_page( })?; if do_verbose { - println!("DictPage ( )"); + eprintln!( + "Parquet DictPage ( num_values: {}, datatype: {:?} )", + dict_header.num_values, descriptor.primitive_type + ); } let is_sorted = dict_header.is_sorted.unwrap_or(false); @@ -275,9 +279,11 @@ pub(super) fn finish_page( })?; if do_verbose { - println!( - "DataPageV1 ( num_values: {}, datatype: {:?}, encoding: {:?} )", - header.num_values, descriptor.primitive_type, header.encoding + eprintln!( + "Parquet DataPageV1 ( num_values: {}, datatype: {:?}, encoding: {:?} )", + header.num_values, + descriptor.primitive_type, + Encoding::try_from(header.encoding).ok() ); } @@ -298,8 +304,10 @@ pub(super) fn finish_page( if do_verbose { println!( - "DataPageV2 ( num_values: {}, datatype: {:?}, encoding: {:?} )", - header.num_values, descriptor.primitive_type, header.encoding + "Parquet DataPageV2 ( num_values: {}, datatype: {:?}, encoding: {:?} )", + header.num_values, + descriptor.primitive_type, + Encoding::try_from(header.encoding).ok() ); } diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 33eb20e86da6..32fa45528d3a 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -33,7 +33,7 @@ pub enum AggExpr { Quantile { expr: Arc, quantile: Arc, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Sum(Arc), AggGroups(Arc), diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs index e1ef64ee02ec..8363f6baa2fa 100644 --- a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -33,8 +33,8 @@ pub fn median(name: &str) -> Expr { } /// Find a specific quantile of all the values in the column named `name`. -pub fn quantile(name: &str, quantile: Expr, interpol: QuantileInterpolOptions) -> Expr { - col(name).quantile(quantile, interpol) +pub fn quantile(name: &str, quantile: Expr, method: QuantileMethod) -> Expr { + col(name).quantile(quantile, method) } /// Negates a boolean column. diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 5d814877a977..a88ff858e6ee 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -45,7 +45,7 @@ use std::sync::Arc; pub use arity::*; #[cfg(feature = "dtype-array")] pub use array::*; -use arrow::legacy::prelude::QuantileInterpolOptions; +use arrow::legacy::prelude::QuantileMethod; pub use expr::*; pub use function_expr::schema::FieldsMapper; pub use function_expr::*; @@ -227,11 +227,11 @@ impl Expr { } /// Compute the quantile per group. - pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { AggExpr::Quantile { expr: Arc::new(self), quantile: Arc::new(quantile), - interpol, + method, } .into() } @@ -1358,13 +1358,13 @@ impl Expr { pub fn rolling_quantile_by( self, by: Expr, - interpol: QuantileInterpolOptions, + method: QuantileMethod, quantile: f64, mut options: RollingOptionsDynamicWindow, ) -> Expr { options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })); self.finish_rolling_by(by, options, RollingFunctionBy::QuantileBy) @@ -1385,7 +1385,7 @@ impl Expr { /// Apply a rolling median based on another column. #[cfg(feature = "rolling_window_by")] pub fn rolling_median_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { - self.rolling_quantile_by(by, QuantileInterpolOptions::Linear, 0.5, options) + self.rolling_quantile_by(by, QuantileMethod::Linear, 0.5, options) } /// Apply a rolling minimum. @@ -1425,7 +1425,7 @@ impl Expr { /// See: [`RollingAgg::rolling_median`] #[cfg(feature = "rolling_window")] pub fn rolling_median(self, options: RollingOptionsFixedWindow) -> Expr { - self.rolling_quantile(QuantileInterpolOptions::Linear, 0.5, options) + self.rolling_quantile(QuantileMethod::Linear, 0.5, options) } /// Apply a rolling quantile. @@ -1434,13 +1434,13 @@ impl Expr { #[cfg(feature = "rolling_window")] pub fn rolling_quantile( self, - interpol: QuantileInterpolOptions, + method: QuantileMethod, quantile: f64, mut options: RollingOptionsFixedWindow, ) -> Expr { options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })); self.finish_rolling(options, RollingFunction::Quantile) diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 565710c0dbaf..286ea86ac968 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -44,7 +44,7 @@ pub enum IRAggExpr { Quantile { expr: Node, quantile: Node, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Sum(Node), Count(Node, bool), @@ -62,7 +62,9 @@ impl Hash for IRAggExpr { Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => { propagate_nans.hash(state) }, - Self::Quantile { interpol, .. } => interpol.hash(state), + Self::Quantile { + method: interpol, .. + } => interpol.hash(state), Self::Std(_, v) | Self::Var(_, v) => v.hash(state), #[cfg(feature = "bitwise")] Self::Bitwise(_, f) => f.hash(state), @@ -92,7 +94,7 @@ impl IRAggExpr { propagate_nans: r, .. }, ) => l == r, - (Quantile { interpol: l, .. }, Quantile { interpol: r, .. }) => l == r, + (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r, (Std(_, l), Std(_, r)) => l == r, (Var(_, l), Var(_, r)) => l == r, #[cfg(feature = "bitwise")] diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index c0178f5b383c..7ee9c7f069d7 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -738,9 +738,9 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult |name| col(name.clone()).std(ddof), &input_schema, ), - StatsFunction::Quantile { quantile, interpol } => stats_helper( + StatsFunction::Quantile { quantile, method } => stats_helper( |dt| dt.is_numeric(), - |name| col(name.clone()).quantile(quantile.clone(), interpol), + |name| col(name.clone()).quantile(quantile.clone(), method), &input_schema, ), StatsFunction::Mean => stats_helper( diff --git a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs index 6873ad3f6851..d3e0c17f8098 100644 --- a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs @@ -237,11 +237,11 @@ pub(super) fn to_aexpr_impl( AggExpr::Quantile { expr, quantile, - interpol, + method, } => IRAggExpr::Quantile { expr: to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, quantile: to_aexpr_impl_materialized_lit(owned(quantile), arena, state)?, - interpol, + method, }, AggExpr::Sum(expr) => { IRAggExpr::Sum(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs index 5d2e4c373b30..160b70951962 100644 --- a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -129,14 +129,14 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { IRAggExpr::Quantile { expr, quantile, - interpol, + method, } => { let expr = node_to_expr(expr, expr_arena); let quantile = node_to_expr(quantile, expr_arena); AggExpr::Quantile { expr: Arc::new(expr), quantile: Arc::new(quantile), - interpol, + method, } .into() }, diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index e470bd3044bc..f1aa33a7e7dd 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -72,7 +72,7 @@ pub enum StatsFunction { }, Quantile { quantile: Expr, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Median, Mean, diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs index 71b287d03b85..62a64319ae2e 100644 --- a/crates/polars-plan/src/plans/visitor/expr.rs +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -67,7 +67,7 @@ impl TreeWalker for Expr { Mean(x) => Mean(am(x, f)?), Implode(x) => Implode(am(x, f)?), Count(x, nulls) => Count(am(x, f)?, nulls), - Quantile { expr, quantile, interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, interpol }, + Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol }, Sum(x) => Sum(am(x, f)?), AggGroups(x) => AggGroups(am(x, f)?), Std(x, ddf) => Std(am(x, f)?, ddf), diff --git a/crates/polars-python/src/conversion/mod.rs b/crates/polars-python/src/conversion/mod.rs index abde51745554..26bd02c6e540 100644 --- a/crates/polars-python/src/conversion/mod.rs +++ b/crates/polars-python/src/conversion/mod.rs @@ -986,17 +986,18 @@ impl<'py> FromPyObject<'py> for Wrap { } } -impl<'py> FromPyObject<'py> for Wrap { +impl<'py> FromPyObject<'py> for Wrap { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { let parsed = match &*ob.extract::()? { - "lower" => QuantileInterpolOptions::Lower, - "higher" => QuantileInterpolOptions::Higher, - "nearest" => QuantileInterpolOptions::Nearest, - "linear" => QuantileInterpolOptions::Linear, - "midpoint" => QuantileInterpolOptions::Midpoint, + "lower" => QuantileMethod::Lower, + "higher" => QuantileMethod::Higher, + "nearest" => QuantileMethod::Nearest, + "linear" => QuantileMethod::Linear, + "midpoint" => QuantileMethod::Midpoint, + "equiprobable" => QuantileMethod::Equiprobable, v => { return Err(PyValueError::new_err(format!( - "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint'}}, got {v}", + "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint', 'equiprobable'}}, got {v}", ))) } }; diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index d0b6d30c31e1..604049f62b66 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -149,7 +149,7 @@ impl PyExpr { fn implode(&self) -> Self { self.inner.clone().implode().into() } - fn quantile(&self, quantile: Self, interpolation: Wrap) -> Self { + fn quantile(&self, quantile: Self, interpolation: Wrap) -> Self { self.inner .clone() .quantile(quantile.inner, interpolation.0) diff --git a/crates/polars-python/src/expr/rolling.rs b/crates/polars-python/src/expr/rolling.rs index 629f1eab391d..a5ef9213128f 100644 --- a/crates/polars-python/src/expr/rolling.rs +++ b/crates/polars-python/src/expr/rolling.rs @@ -276,7 +276,7 @@ impl PyExpr { fn rolling_quantile( &self, quantile: f64, - interpolation: Wrap, + interpolation: Wrap, window_size: usize, weights: Option>, min_periods: Option, @@ -302,7 +302,7 @@ impl PyExpr { &self, by: PyExpr, quantile: f64, - interpolation: Wrap, + interpolation: Wrap, window_size: &str, min_periods: usize, closed: Wrap, diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 85200e339065..fa28a9f8e5ed 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -1051,7 +1051,7 @@ impl PyLazyFrame { out.into() } - fn quantile(&self, quantile: PyExpr, interpolation: Wrap) -> Self { + fn quantile(&self, quantile: PyExpr, interpolation: Wrap) -> Self { let ldf = self.ldf.clone(); let out = ldf.quantile(quantile.inner, interpolation.0); out.into() diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 67c25d755084..07d2f872437c 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -2,7 +2,7 @@ use polars::datatypes::TimeUnit; #[cfg(feature = "iejoin")] use polars::prelude::InequalityOperator; use polars::series::ops::NullBehavior; -use polars_core::prelude::{NonExistent, QuantileInterpolOptions}; +use polars_core::prelude::{NonExistent, QuantileMethod}; use polars_core::series::IsSorted; use polars_ops::prelude::ClosedInterval; use polars_ops::series::InterpolationMethod; @@ -700,16 +700,17 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { IRAggExpr::Quantile { expr, quantile, - interpol, + method: interpol, } => Agg { name: "quantile".to_object(py), arguments: vec![expr.0, quantile.0], options: match interpol { - QuantileInterpolOptions::Nearest => "nearest", - QuantileInterpolOptions::Lower => "lower", - QuantileInterpolOptions::Higher => "higher", - QuantileInterpolOptions::Midpoint => "midpoint", - QuantileInterpolOptions::Linear => "linear", + QuantileMethod::Nearest => "nearest", + QuantileMethod::Lower => "lower", + QuantileMethod::Higher => "higher", + QuantileMethod::Midpoint => "midpoint", + QuantileMethod::Linear => "linear", + QuantileMethod::Equiprobable => "equiprobable", } .to_object(py), }, diff --git a/crates/polars-python/src/series/aggregation.rs b/crates/polars-python/src/series/aggregation.rs index dbcbad59ddac..5aa8ee16639e 100644 --- a/crates/polars-python/src/series/aggregation.rs +++ b/crates/polars-python/src/series/aggregation.rs @@ -105,11 +105,7 @@ impl PySeries { .into_py(py)) } - fn quantile( - &self, - quantile: f64, - interpolation: Wrap, - ) -> PyResult { + fn quantile(&self, quantile: f64, interpolation: Wrap) -> PyResult { let bind = self.series.quantile_reduce(quantile, interpolation.0); let sc = bind.map_err(PyPolarsErr::from)?; diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 48011a323764..5f6c311a5e15 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -3,7 +3,7 @@ use std::ops::Sub; use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; use polars_core::export::regex; use polars_core::prelude::{ - polars_bail, polars_err, DataType, PolarsResult, QuantileInterpolOptions, Schema, TimeUnit, + polars_bail, polars_err, DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, }; use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] @@ -513,6 +513,13 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT QUANTILE_CONT(column_1) FROM df; /// ``` QuantileCont, + /// SQL 'quantile_disc' function + /// Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value, + /// and returns the value associated with the subinterval where the quantile value falls. + /// ```sql + /// SELECT QUANTILE_DISC(column_1) FROM df; + /// ``` + QuantileDisc, /// SQL 'min' function /// Returns the smallest (minimum) of all the elements in the grouping. /// ```sql @@ -688,6 +695,7 @@ impl PolarsSQLFunctions { "ltrim", "max", "median", + "quantile_disc", "min", "mod", "nullif", @@ -696,6 +704,7 @@ impl PolarsSQLFunctions { "pow", "power", "quantile_cont", + "quantile_disc", "radians", "regexp_like", "replace", @@ -829,6 +838,7 @@ impl PolarsSQLFunctions { "max" => Self::Max, "median" => Self::Median, "quantile_cont" => Self::QuantileCont, + "quantile_disc" => Self::QuantileDisc, "min" => Self::Min, "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev, "sum" => Self::Sum, @@ -1275,11 +1285,37 @@ impl SQLFunctionVisitor<'_> { }, _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1]) }; - Ok(e.quantile(value, QuantileInterpolOptions::Linear)) + Ok(e.quantile(value, QuantileMethod::Linear)) }), _ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()), } }, + QuantileDisc => { + let args = extract_args(function)?; + match args.len() { + 2 => self.try_visit_binary(|e, q| { + let value = match q { + Expr::Literal(LiteralValue::Float(f)) => { + if (0.0..=1.0).contains(&f) { + Expr::from(f) + } else { + polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1]) + } + }, + Expr::Literal(LiteralValue::Int(n)) => { + if (0..=1).contains(&n) { + Expr::from(n as f64) + } else { + polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1]) + } + }, + _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1]) + }; + Ok(e.quantile(value, QuantileMethod::Equiprobable)) + }), + _ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()), + } + }, Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min), StdDev => self.visit_unary(|e| e.std(1)), Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum), diff --git a/crates/polars-sql/tests/functions_aggregate.rs b/crates/polars-sql/tests/functions_aggregate.rs index 621ca18bd355..092a340f5f18 100644 --- a/crates/polars-sql/tests/functions_aggregate.rs +++ b/crates/polars-sql/tests/functions_aggregate.rs @@ -5,9 +5,7 @@ use polars_sql::*; fn create_df() -> LazyFrame { df! { - "Year" => [2018, 2018, 2019, 2019, 2020, 2020], - "Country" => ["US", "UK", "US", "UK", "US", "UK"], - "Sales" => [1000, 2000, 3000, 4000, 5000, 6000] + "Data" => [1000, 2000, 3000, 4000, 5000, 6000] } .unwrap() .lazy() @@ -41,9 +39,9 @@ fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) { #[test] fn test_median() { - let expr = col("Sales").median(); + let expr = col("Data").median(); - let sql_expr = "MEDIAN(Sales)"; + let sql_expr = "MEDIAN(Data)"; let (expected, actual) = create_expected(expr, sql_expr); assert!(expected.equals(&actual)) @@ -52,9 +50,9 @@ fn test_median() { #[test] fn test_quantile_cont() { for &q in &[0.25, 0.5, 0.75] { - let expr = col("Sales").quantile(lit(q), QuantileInterpolOptions::Linear); + let expr = col("Data").quantile(lit(q), QuantileMethod::Linear); - let sql_expr = format!("QUANTILE_CONT(Sales, {})", q); + let sql_expr = format!("QUANTILE_CONT(Data, {})", q); let (expected, actual) = create_expected(expr, &sql_expr); assert!( @@ -63,3 +61,61 @@ fn test_quantile_cont() { ) } } + +#[test] +fn test_quantile_disc() { + for &q in &[0.25, 0.5, 0.75] { + let expr = col("Data").quantile(lit(q), QuantileMethod::Equiprobable); + + let sql_expr = format!("QUANTILE_DISC(Data, {})", q); + let (expected, actual) = create_expected(expr, &sql_expr); + + assert!(expected.equals(&actual)) + } +} + +#[test] +fn test_quantile_out_of_range() { + for &q in &["-1", "2", "-0.01", "1.01"] { + for &func in &["QUANTILE_CONT", "QUANTILE_DISC"] { + let query = format!("SELECT {func}(Data, {q})"); + let mut ctx = SQLContext::new(); + ctx.register("df", create_df()); + let actual = ctx.execute(&query); + assert!(actual.is_err()) + } + } +} + +#[test] +fn test_quantile_disc_conformance() { + let expected = df![ + "q" => [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], + "Data" => [1000, 1000, 2000, 2000, 3000, 3000, 4000, 5000, 5000, 6000, 6000], + ] + .unwrap(); + + let mut ctx = SQLContext::new(); + ctx.register("df", create_df()); + + let mut actual: Option = None; + for &q in &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] { + let res = ctx + .execute(&format!( + "SELECT {q}::float as q, QUANTILE_DISC(Data, {q}) as Data FROM df" + )) + .unwrap() + .collect() + .unwrap(); + actual = if let Some(df) = actual { + Some(df.vstack(&res).unwrap()) + } else { + Some(res) + }; + } + + assert!( + expected.equals(actual.as_ref().unwrap()), + "expected {expected:?}, got {actual:?}" + ) +} diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs index 746c517ce744..e3377036b908 100644 --- a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs @@ -141,7 +141,7 @@ impl ParquetSourceNode { } if allow_missing_columns { - ensure_matching_dtypes_if_found(&first_schema, &schema)?; + ensure_matching_dtypes_if_found(projected_arrow_schema.as_ref(), &schema)?; } else { ensure_schema_has_projected_fields( &schema, diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs index dfa4b11e3b02..52d3003de7ea 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs @@ -2,11 +2,12 @@ use std::future::Future; use std::sync::Arc; use polars_core::prelude::{ArrowSchema, InitHashMaps, PlHashMap}; +use polars_core::series::IsSorted; use polars_core::utils::operation_exceeded_idxsize_msg; use polars_error::{polars_err, PolarsResult}; use polars_io::predicates::PhysicalIoExpr; -use polars_io::prelude::FileMetadata; use polars_io::prelude::_internal::read_this_row_group; +use polars_io::prelude::{create_sorting_map, FileMetadata}; use polars_io::utils::byte_source::{ByteSource, DynByteSource}; use polars_io::utils::slice::SplitSlicePosition; use polars_parquet::read::RowGroupMetadata; @@ -27,6 +28,7 @@ pub(super) struct RowGroupData { pub(super) slice: Option<(usize, usize)>, pub(super) file_max_row_group_height: usize, pub(super) row_group_metadata: RowGroupMetadata, + pub(super) sorting_map: PlHashMap, pub(super) shared_file_state: Arc>, } @@ -86,6 +88,7 @@ impl RowGroupDataFetcher { let current_row_group_idx = self.current_row_group_idx; let num_rows = row_group_metadata.num_rows(); + let sorting_map = create_sorting_map(&row_group_metadata); self.current_row_offset = current_row_offset.saturating_add(num_rows); self.current_row_group_idx += 1; @@ -246,6 +249,7 @@ impl RowGroupDataFetcher { slice, file_max_row_group_height: current_max_row_group_height, row_group_metadata, + sorting_map, shared_file_state: current_shared_file_state.clone(), }) }); diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs index 119345295686..975ff6de22cb 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs @@ -11,6 +11,7 @@ use polars_error::{polars_bail, PolarsResult}; use polars_io::predicates::PhysicalIoExpr; use polars_io::prelude::_internal::calc_prefilter_cost; pub use polars_io::prelude::_internal::PrefilterMaskSetting; +use polars_io::prelude::try_set_sorted_flag; use polars_io::RowIndex; use polars_plan::plans::hive::HivePartitions; use polars_plan::plans::ScanSources; @@ -367,11 +368,20 @@ fn decode_column( assert_eq!(array.len(), expected_num_rows); - let series = Series::try_from((arrow_field, array))?; + let mut series = Series::try_from((arrow_field, array))?; + + if let Some(col_idxs) = row_group_data + .row_group_metadata + .columns_idxs_under_root_iter(&arrow_field.name) + { + if col_idxs.len() == 1 { + try_set_sorted_flag(&mut series, col_idxs[0], &row_group_data.sorting_map); + } + } // TODO: Also load in the metadata. - Ok(series.into()) + Ok(series.into_column()) } /// # Safety @@ -652,17 +662,26 @@ fn decode_column_prefiltered( deserialize_filter, )?; - let column = Series::try_from((arrow_field, array))?.into_column(); + let mut series = Series::try_from((arrow_field, array))?; + + if let Some(col_idxs) = row_group_data + .row_group_metadata + .columns_idxs_under_root_iter(&arrow_field.name) + { + if col_idxs.len() == 1 { + try_set_sorted_flag(&mut series, col_idxs[0], &row_group_data.sorting_map); + } + } - let column = if !prefilter { - column.filter(mask)? + let series = if !prefilter { + series.filter(mask)? } else { - column + series }; - assert_eq!(column.len(), expected_num_rows); + assert_eq!(series.len(), expected_num_rows); - Ok(column) + Ok(series.into_column()) } mod tests { diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 8a8d2312d580..3ff08ee4d308 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -789,12 +789,12 @@ mod test { let quantile = unsafe { a.as_materialized_series() - .agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) + .agg_quantile(&groups, 0.5, QuantileMethod::Linear) }; let expected = Series::new("".into(), [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]); assert_eq!(quantile, expected); - let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) }; + let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileMethod::Linear) }; let expected = Series::new("".into(), [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]); assert_eq!(quantile, expected); diff --git a/crates/polars/tests/it/lazy/aggregation.rs b/crates/polars/tests/it/lazy/aggregation.rs index ad043e698e2e..10c386037d17 100644 --- a/crates/polars/tests/it/lazy/aggregation.rs +++ b/crates/polars/tests/it/lazy/aggregation.rs @@ -26,7 +26,7 @@ fn test_lazy_agg() { col("rain").min().alias("min"), col("rain").sum().alias("sum"), col("rain") - .quantile(lit(0.5), QuantileInterpolOptions::default()) + .quantile(lit(0.5), QuantileMethod::default()) .alias("median_rain"), ]) .sort(["date"], Default::default()); diff --git a/docs/source/user-guide/ecosystem.md b/docs/source/user-guide/ecosystem.md index 21f1dbc2ba60..d6fc8e0c9524 100644 --- a/docs/source/user-guide/ecosystem.md +++ b/docs/source/user-guide/ecosystem.md @@ -71,3 +71,7 @@ With [Great Tables](https://posit-dev.github.io/great-tables/articles/intro.html #### Mage [Mage](https://www.mage.ai) is an open-source data pipeline tool for transforming and integrating data. Learn about integration between Polars and Mage at [docs.mage.ai](https://docs.mage.ai/integrations/polars). + +#### marimo + +[marimo](https://marimo.io) is a reactive notebook for Python and SQL that models notebooks as dataflow graphs. It offers built-in support for Polars, allowing seamless integration of Polars dataframes in an interactive, reactive environment - such as displaying rich Polars tables, no-code transformations of Polars dataframes, or selecting points on a Polars-backed reactive chart. diff --git a/py-polars/polars/io/database/_executor.py b/py-polars/polars/io/database/_executor.py index 278e3e8e0738..3162c974a317 100644 --- a/py-polars/polars/io/database/_executor.py +++ b/py-polars/polars/io/database/_executor.py @@ -392,7 +392,7 @@ def _normalise_cursor(self, conn: Any) -> Cursor: return conn.engine.raw_connection().cursor() elif conn.engine.driver == "duckdb_engine": self.driver_name = "duckdb" - return conn.engine.raw_connection().driver_connection + return conn elif self._is_alchemy_engine(conn): # note: if we create it, we can close it self.can_close_cursor = True @@ -505,8 +505,11 @@ def execute( ) result = cursor_execute(query, *positional_options) - # note: some cursors execute in-place + # note: some cursors execute in-place, some access results via a property result = self.cursor if result is None else result + if self.driver_name == "duckdb": + result = result.cursor + self.result = result return self diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index 792ddc42b02a..49a842386ea2 100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -235,12 +235,12 @@ def test_parquet_is_in_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) captured = capfd.readouterr().err assert ( - "parquet file must be read, statistics not sufficient for predicate." + "parquet row group must be read, statistics not sufficient for predicate." in captured ) assert ( - "parquet file can be skipped, the statistics were sufficient" - " to apply the predicate." in captured + "parquet row group can be skipped, the statistics were sufficient to apply the predicate." + in captured ) @@ -710,10 +710,18 @@ def test_parquet_schema_arg( schema: dict[str, type[pl.DataType]] = {"a": pl.Int64} # type: ignore[no-redef] - lf = pl.scan_parquet(paths, parallel=parallel, schema=schema) + for allow_missing_columns in [True, False]: + lf = pl.scan_parquet( + paths, + parallel=parallel, + schema=schema, + allow_missing_columns=allow_missing_columns, + ) - with pytest.raises(pl.exceptions.SchemaError, match="file contained extra columns"): - lf.collect(streaming=streaming) + with pytest.raises( + pl.exceptions.SchemaError, match="file contained extra columns" + ): + lf.collect(streaming=streaming) lf = pl.scan_parquet(paths, parallel=parallel, schema=schema).select("a") @@ -731,3 +739,29 @@ def test_parquet_schema_arg( match="data type mismatch for column b: expected: i8, found: i64", ): lf.collect(streaming=streaming) + + +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize("allow_missing_columns", [True, False]) +@pytest.mark.write_disk +def test_scan_parquet_ignores_dtype_mismatch_for_non_projected_columns_19249( + tmp_path: Path, + allow_missing_columns: bool, + streaming: bool, +) -> None: + tmp_path.mkdir(exist_ok=True) + paths = [tmp_path / "1", tmp_path / "2"] + + pl.DataFrame({"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt8}).write_parquet( + paths[0] + ) + pl.DataFrame( + {"a": 1, "b": 1}, schema={"a": pl.Int32, "b": pl.UInt64} + ).write_parquet(paths[1]) + + assert_frame_equal( + pl.scan_parquet(paths, allow_missing_columns=allow_missing_columns) + .select("a") + .collect(streaming=streaming), + pl.DataFrame({"a": [1, 1]}, schema={"a": pl.Int32}), + ) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 850bf61d978b..71dfab8913d9 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -1990,3 +1990,52 @@ def test_nested_nonnullable_19158() -> None: f.seek(0) assert_frame_equal(pl.read_parquet(f), pl.DataFrame(tbl)) + + +@pytest.mark.parametrize("parallel", ["prefiltered", "columns", "row_groups", "auto"]) +def test_conserve_sortedness( + monkeypatch: Any, capfd: Any, parallel: pl.ParallelStrategy +) -> None: + f = io.BytesIO() + + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, None], + "b": [1.0, 2.0, 3.0, 4.0, 5.0, None], + "c": [None, 5, 4, 3, 2, 1], + "d": [None, 5.0, 4.0, 3.0, 2.0, 1.0], + "a_nosort": [1, 2, 3, 4, 5, None], + "f": range(6), + } + ) + + pq.write_table( + df.to_arrow(), + f, + sorting_columns=[ + pq.SortingColumn(0, False, False), + pq.SortingColumn(1, False, False), + pq.SortingColumn(2, True, True), + pq.SortingColumn(3, True, True), + ], + ) + + f.seek(0) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + + df = pl.scan_parquet(f, parallel=parallel).filter(pl.col.f > 1).collect() + + captured = capfd.readouterr().err + + # @NOTE: We don't conserve sortedness for anything except integers at the + # moment. + assert captured.count("Parquet conserved SortingColumn for column chunk of") == 2 + assert ( + "Parquet conserved SortingColumn for column chunk of 'a' to Ascending" + in captured + ) + assert ( + "Parquet conserved SortingColumn for column chunk of 'c' to Descending" + in captured + )