From 273662149f67190712ae4598a3c086c0a095509f Mon Sep 17 00:00:00 2001 From: Roman Nikitin Date: Wed, 16 Oct 2024 17:51:16 +0300 Subject: [PATCH] feat: New quantile interpolation method & QUANTILE_DISC function in SQL (#19139) --- .../src/legacy/kernels/rolling/mod.rs | 2 +- .../legacy/kernels/rolling/no_nulls/mod.rs | 6 +- .../kernels/rolling/no_nulls/quantile.rs | 50 ++-- .../legacy/kernels/rolling/nulls/quantile.rs | 46 ++-- .../legacy/kernels/rolling/quantile_filter.rs | 32 +-- crates/polars-arrow/src/legacy/prelude.rs | 5 +- .../src/chunked_array/ops/aggregate/mod.rs | 227 ++++++++---------- .../chunked_array/ops/aggregate/quantile.rs | 112 ++++----- .../polars-core/src/chunked_array/ops/mod.rs | 6 +- crates/polars-core/src/frame/column/mod.rs | 4 +- .../frame/group_by/aggregations/dispatch.rs | 9 +- .../src/frame/group_by/aggregations/mod.rs | 53 ++-- crates/polars-core/src/frame/group_by/mod.rs | 16 +- .../src/series/implementations/datetime.rs | 6 +- .../src/series/implementations/decimal.rs | 8 +- .../src/series/implementations/duration.rs | 8 +- .../src/series/implementations/floats.rs | 4 +- .../src/series/implementations/mod.rs | 4 +- crates/polars-core/src/series/series_trait.rs | 6 +- .../src/expressions/aggregation.rs | 10 +- crates/polars-expr/src/planner.rs | 4 +- crates/polars-lazy/src/frame/mod.rs | 12 +- crates/polars-lazy/src/lib.rs | 4 +- crates/polars-ops/src/series/ops/cut.rs | 6 +- crates/polars-plan/src/dsl/expr.rs | 2 +- .../src/dsl/functions/syntactic_sugar.rs | 4 +- crates/polars-plan/src/dsl/mod.rs | 18 +- crates/polars-plan/src/plans/aexpr/mod.rs | 8 +- .../src/plans/conversion/dsl_to_ir.rs | 4 +- .../src/plans/conversion/expr_to_ir.rs | 4 +- .../src/plans/conversion/ir_to_dsl.rs | 4 +- crates/polars-plan/src/plans/functions/dsl.rs | 2 +- crates/polars-plan/src/plans/visitor/expr.rs | 2 +- crates/polars-python/src/conversion/mod.rs | 15 +- crates/polars-python/src/expr/general.rs | 2 +- crates/polars-python/src/expr/rolling.rs | 4 +- crates/polars-python/src/lazyframe/general.rs | 2 +- .../src/lazyframe/visitor/expr_nodes.rs | 15 +- .../polars-python/src/series/aggregation.rs | 6 +- crates/polars-sql/src/functions.rs | 40 ++- .../polars-sql/tests/functions_aggregate.rs | 70 +++++- crates/polars-time/src/group_by/dynamic.rs | 4 +- crates/polars/tests/it/lazy/aggregation.rs | 2 +- 43 files changed, 435 insertions(+), 413 deletions(-) diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs index 8f8004c570e8..51f3c95d2a56 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/mod.rs @@ -93,5 +93,5 @@ pub struct RollingVarParams { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingQuantileParams { pub prob: f64, - pub interpol: QuantileInterpolOptions, + pub method: QuantileMethod, } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index 1b9695358dcb..3277318e6807 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -71,15 +71,19 @@ where #[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum QuantileInterpolOptions { +pub enum QuantileMethod { #[default] Nearest, Lower, Higher, Midpoint, Linear, + Equiprobable, } +#[deprecated(note = "use QuantileMethod instead")] +pub type QuantileInterpolOptions = QuantileMethod; + pub(super) fn rolling_apply_weights( values: &[T], window_size: usize, diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs index ab3919b9aaaa..bf0ad01e79c3 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs @@ -2,13 +2,13 @@ use num_traits::ToPrimitive; use polars_error::polars_ensure; use polars_utils::slice::GetSaferUnchecked; -use super::QuantileInterpolOptions::*; +use super::QuantileMethod::*; use super::*; pub struct QuantileWindow<'a, T: NativeType> { sorted: SortedBuf<'a, T>, prob: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl< @@ -34,7 +34,7 @@ impl< Self { sorted: SortedBuf::new(slice, start, end), prob: params.prob, - interpol: params.interpol, + method: params.method, } } @@ -42,7 +42,7 @@ impl< let vals = self.sorted.update(start, end); let length = vals.len(); - let idx = match self.interpol { + let idx = match self.method { Linear => { // Maybe add a fast path for median case? They could branch depending on odd/even. let length_f = length as f64; @@ -92,6 +92,7 @@ impl< let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; std::cmp::min(idx, length - 1) }, + Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize, }; // SAFETY: @@ -134,7 +135,7 @@ where unreachable!("expected Quantile params"); }; let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>( - params.interpol, + params.method, min_periods, window_size, values, @@ -170,7 +171,7 @@ where Ok(rolling_apply_weighted_quantile( values, params.prob, - params.interpol, + params.method, window_size, min_periods, offset_fn, @@ -182,7 +183,7 @@ where } #[inline] -fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, interp: QuantileInterpolOptions) -> T +fn compute_wq(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T where T: Debug + NativeType + Mul + Sub + NumCast + ToPrimitive + Zero, { @@ -201,7 +202,7 @@ where (s_old, v_old, vk) = (s, vk, v); s += w; } - match (h == s_old, interp) { + match (h == s_old, method) { (true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter (_, Lower) => v_old, (_, Higher) => vk, @@ -212,6 +213,14 @@ where vk } }, + (_, Equiprobable) => { + let threshold = (wsum * p).ceil() - 1.0; + if s > threshold { + vk + } else { + v_old + } + }, (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(), // This is seemingly the canonical way to do it. (_, Linear) => { @@ -224,7 +233,7 @@ where fn rolling_apply_weighted_quantile( values: &[T], p: f64, - interpolation: QuantileInterpolOptions, + method: QuantileMethod, window_size: usize, min_periods: usize, det_offsets_fn: Fo, @@ -252,7 +261,7 @@ where .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w)); } buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0)); - compute_wq(&buf, p, wsum, interpolation) + compute_wq(&buf, p, wsum, method) }) .collect_trusted::>(); @@ -273,7 +282,7 @@ mod test { let values = &[1.0, 2.0, 3.0, 4.0]; let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.5, - interpol: Linear, + method: Linear, })); let out = rolling_quantile(values, 2, 2, false, None, med_pars.clone()).unwrap(); let out = out.as_any().downcast_ref::>().unwrap(); @@ -305,18 +314,19 @@ mod test { fn test_rolling_quantile_limits() { let values = &[1.0f64, 2.0, 3.0, 4.0]; - let interpol_options = vec![ - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { + for method in methods { let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.0, - interpol, + method, })); let out1 = rolling_min(values, 2, 2, false, None, None).unwrap(); let out1 = out1.as_any().downcast_ref::>().unwrap(); @@ -328,7 +338,7 @@ mod test { let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 1.0, - interpol, + method, })); let out1 = rolling_max(values, 2, 2, false, None, None).unwrap(); let out1 = out1.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs index 259316513fe5..3d5dd664bd34 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs @@ -6,7 +6,7 @@ use crate::array::MutablePrimitiveArray; pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> { sorted: SortedBufNulls<'a, T>, prob: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl< @@ -39,7 +39,7 @@ impl< Self { sorted: SortedBufNulls::new(slice, validity, start, end), prob: params.prob, - interpol: params.interpol, + method: params.method, } } @@ -53,21 +53,22 @@ impl< let values = &values[null_count..]; let length = values.len(); - let mut idx = match self.interpol { - QuantileInterpolOptions::Nearest => ((length as f64) * self.prob) as usize, - QuantileInterpolOptions::Lower - | QuantileInterpolOptions::Midpoint - | QuantileInterpolOptions::Linear => { + let mut idx = match self.method { + QuantileMethod::Nearest => ((length as f64) * self.prob) as usize, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { ((length as f64 - 1.0) * self.prob).floor() as usize }, - QuantileInterpolOptions::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize, + QuantileMethod::Equiprobable => { + ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize + }, }; idx = std::cmp::min(idx, length - 1); // we can unwrap because we sliced of the nulls - match self.interpol { - QuantileInterpolOptions::Midpoint => { + match self.method { + QuantileMethod::Midpoint => { let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; Some( (values.get_unchecked_release(idx).unwrap() @@ -75,7 +76,7 @@ impl< / T::from::(2.0f64).unwrap(), ) }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { let float_idx = (length as f64 - 1.0) * self.prob; let top_idx = f64::ceil(float_idx) as usize; @@ -136,7 +137,7 @@ where }; let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>( - params.interpol, + params.method, min_periods, window_size, arr.clone(), @@ -171,7 +172,7 @@ mod test { ); let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.5, - interpol: QuantileInterpolOptions::Linear, + method: QuantileMethod::Linear, })); let out = rolling_quantile(arr, 2, 2, false, None, med_pars.clone()); @@ -210,18 +211,19 @@ mod test { Some(Bitmap::from(&[true, false, false, true, true])), ); - let interpol_options = vec![ - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Nearest, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { + for method in methods { let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 0.0, - interpol, + method, })); let out1 = rolling_min(values, 2, 1, false, None, None); let out1 = out1.as_any().downcast_ref::>().unwrap(); @@ -233,7 +235,7 @@ mod test { let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: 1.0, - interpol, + method, })); let out1 = rolling_max(values, 2, 1, false, None, None); let out1 = out1.as_any().downcast_ref::>().unwrap(); diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs index 40a464e6f5bc..0b5fb4d97e86 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs @@ -11,7 +11,7 @@ use polars_utils::slice::{GetSaferUnchecked, SliceAble}; use polars_utils::sort::arg_sort_ascending; use polars_utils::total_ord::TotalOrd; -use crate::legacy::prelude::QuantileInterpolOptions; +use crate::legacy::prelude::QuantileMethod; use crate::pushable::Pushable; use crate::types::NativeType; @@ -573,7 +573,7 @@ struct QuantileUpdate { inner: M, quantile: f64, min_periods: usize, - interpol: QuantileInterpolOptions, + method: QuantileMethod, } impl QuantileUpdate @@ -581,12 +581,12 @@ where M: LenGet, ::Item: Default + IsNull + Copy + FinishLinear + Debug, { - fn new(interpol: QuantileInterpolOptions, min_periods: usize, quantile: f64, inner: M) -> Self { + fn new(method: QuantileMethod, min_periods: usize, quantile: f64, inner: M) -> Self { Self { min_periods, quantile, inner, - interpol, + method, } } @@ -602,8 +602,8 @@ where let valid_length_f = valid_length as f64; - use QuantileInterpolOptions::*; - match self.interpol { + use QuantileMethod::*; + match self.method { Linear => { let float_idx_top = (valid_length_f - 1.0) * self.quantile; let idx = float_idx_top.floor() as usize; @@ -623,6 +623,10 @@ where let idx = std::cmp::min(idx, valid_length - 1); self.inner.get(idx + null_count) }, + Equiprobable => { + let idx = ((valid_length_f * self.quantile).ceil() - 1.0).max(0.0) as usize; + self.inner.get(idx + null_count) + }, Midpoint => { let idx = (valid_length_f * self.quantile) as usize; let idx = std::cmp::min(idx, valid_length - 1); @@ -651,7 +655,7 @@ where } pub(super) fn rolling_quantile::Item>>( - interpol: QuantileInterpolOptions, + method: QuantileMethod, min_periods: usize, k: usize, values: A, @@ -709,7 +713,7 @@ where // SAFETY: bounded by capacity unsafe { block_left.undelete(i) }; - let mut mu = QuantileUpdate::new(interpol, min_periods, quantile, &mut block_left); + let mut mu = QuantileUpdate::new(method, min_periods, quantile, &mut block_left); out.push(mu.quantile()); } for i in 1..n_blocks + 1 { @@ -747,7 +751,7 @@ where let mut union = BlockUnion::new(&mut *ptr_left, &mut *ptr_right); union.set_state(j); let q: ::Item = - QuantileUpdate::new(interpol, min_periods, quantile, union).quantile(); + QuantileUpdate::new(method, min_periods, quantile, union).quantile(); out.push(q); } } @@ -1062,22 +1066,22 @@ mod test { 2.0, 8.0, 5.0, 9.0, 1.0, 2.0, 4.0, 2.0, 4.0, 8.1, -1.0, 2.9, 1.2, 23.0, ] .as_ref(); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 3, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 8.0, 5.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 2.9, 1.2, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 5, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 5, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 5.0, 5.0, 4.0, 2.0, 2.0, 4.0, 4.0, 2.9, 2.9, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 7, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 7, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 5.0, 3.5, 4.0, 4.0, 4.0, 4.0, 2.0, 2.9, 2.9, 2.9, ]; assert_eq!(out, expected); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 4, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 4, values, 0.5); let expected = [ 2.0, 5.0, 5.0, 6.5, 6.5, 3.5, 3.0, 2.0, 3.0, 4.0, 3.0, 3.45, 2.05, 2.05, ]; @@ -1087,7 +1091,7 @@ mod test { #[test] fn test_median_2() { let values = [10, 10, 15, 13, 9, 5, 3, 13, 19, 15, 19].as_ref(); - let out: Vec<_> = rolling_quantile(QuantileInterpolOptions::Linear, 0, 3, values, 0.5); + let out: Vec<_> = rolling_quantile(QuantileMethod::Linear, 0, 3, values, 0.5); let expected = [10, 10, 10, 13, 13, 9, 5, 5, 13, 15, 19]; assert_eq!(out, expected); } diff --git a/crates/polars-arrow/src/legacy/prelude.rs b/crates/polars-arrow/src/legacy/prelude.rs index 88b2dd48bbea..6afeb0c6c9be 100644 --- a/crates/polars-arrow/src/legacy/prelude.rs +++ b/crates/polars-arrow/src/legacy/prelude.rs @@ -2,7 +2,7 @@ use crate::array::{BinaryArray, ListArray, Utf8Array}; pub use crate::legacy::array::default_arrays::*; pub use crate::legacy::array::*; pub use crate::legacy::index::*; -pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions; +pub use crate::legacy::kernels::rolling::no_nulls::QuantileMethod; pub use crate::legacy::kernels::rolling::{ RollingFnParams, RollingQuantileParams, RollingVarParams, }; @@ -11,3 +11,6 @@ pub use crate::legacy::kernels::{Ambiguous, NonExistent}; pub type LargeStringArray = Utf8Array; pub type LargeBinaryArray = BinaryArray; pub type LargeListArray = ListArray; + +#[allow(deprecated)] +pub use crate::legacy::kernels::rolling::no_nulls::QuantileInterpolOptions; diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 071073460ff3..5b3c0b5f53d7 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -369,12 +369,8 @@ where ::Simd: Add::Simd> + compute::aggregate::Sum, { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float64, v.into())) } @@ -385,12 +381,8 @@ where } impl QuantileAggSeries for Float32Chunked { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float32, v.into())) } @@ -401,12 +393,8 @@ impl QuantileAggSeries for Float32Chunked { } impl QuantileAggSeries for Float64Chunked { - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.quantile(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.quantile(quantile, method)?; Ok(Scalar::new(DataType::Float64, v.into())) } @@ -735,19 +723,20 @@ mod test { let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_i32.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_f64.quantile(0.9, interpol).unwrap(), None); - assert_eq!(test_i64.quantile(0.9, interpol).unwrap(), None); + for method in methods { + assert_eq!(test_f32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i32.quantile(0.9, method).unwrap(), None); + assert_eq!(test_f64.quantile(0.9, method).unwrap(), None); + assert_eq!(test_i64.quantile(0.9, method).unwrap(), None); } } @@ -758,19 +747,20 @@ mod test { let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]); let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_i32.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_f64.quantile(0.5, interpol).unwrap(), Some(1.0)); - assert_eq!(test_i64.quantile(0.5, interpol).unwrap(), Some(1.0)); + for method in methods { + assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0)); + assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0)); } } @@ -793,37 +783,38 @@ mod test { &[None, Some(1i64), Some(5i64), Some(1i64)], ); - let interpol_options = vec![ - QuantileInterpolOptions::Nearest, - QuantileInterpolOptions::Lower, - QuantileInterpolOptions::Higher, - QuantileInterpolOptions::Midpoint, - QuantileInterpolOptions::Linear, + let methods = vec![ + QuantileMethod::Nearest, + QuantileMethod::Lower, + QuantileMethod::Higher, + QuantileMethod::Midpoint, + QuantileMethod::Linear, + QuantileMethod::Equiprobable, ]; - for interpol in interpol_options { - assert_eq!(test_f32.quantile(0.0, interpol).unwrap(), test_f32.min()); - assert_eq!(test_f32.quantile(1.0, interpol).unwrap(), test_f32.max()); + for method in methods { + assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min()); + assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max()); assert_eq!( - test_i32.quantile(0.0, interpol).unwrap().unwrap(), + test_i32.quantile(0.0, method).unwrap().unwrap(), test_i32.min().unwrap() as f64 ); assert_eq!( - test_i32.quantile(1.0, interpol).unwrap().unwrap(), + test_i32.quantile(1.0, method).unwrap().unwrap(), test_i32.max().unwrap() as f64 ); - assert_eq!(test_f64.quantile(0.0, interpol).unwrap(), test_f64.min()); - assert_eq!(test_f64.quantile(1.0, interpol).unwrap(), test_f64.max()); - assert_eq!(test_f64.quantile(0.5, interpol).unwrap(), test_f64.median()); + assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min()); + assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max()); + assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median()); assert_eq!( - test_i64.quantile(0.0, interpol).unwrap().unwrap(), + test_i64.quantile(0.0, method).unwrap().unwrap(), test_i64.min().unwrap() as f64 ); assert_eq!( - test_i64.quantile(1.0, interpol).unwrap().unwrap(), + test_i64.quantile(1.0, method).unwrap().unwrap(), test_i64.max().unwrap() as f64 ); } @@ -837,72 +828,56 @@ mod test { ); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), Some(5.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), Some(3.0) ); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Lower).unwrap(), - Some(1.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Lower).unwrap(), - Some(4.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Lower).unwrap(), - Some(3.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0)); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Higher).unwrap(), - Some(2.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Higher).unwrap(), - Some(5.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Higher).unwrap(), - Some(4.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0)); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), Some(1.5) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), Some(4.5) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), Some(3.5) ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert!( + (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001 + ); + assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Linear).unwrap(), - Some(1.4) + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Linear).unwrap(), - Some(4.6) + ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) ); - assert!( - (ca.quantile(0.6, QuantileInterpolOptions::Linear) - .unwrap() - .unwrap() - - 3.4) - .abs() - < 0.0000001 + assert_eq!( + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(3.0) ); let ca = UInt32Chunked::new( @@ -922,68 +897,54 @@ mod test { ); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.1, QuantileMethod::Nearest).unwrap(), Some(2.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.9, QuantileMethod::Nearest).unwrap(), Some(6.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Nearest).unwrap(), + ca.quantile(0.6, QuantileMethod::Nearest).unwrap(), Some(5.0) ); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Lower).unwrap(), - Some(1.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Lower).unwrap(), - Some(6.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Lower).unwrap(), - Some(4.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0)); - assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Higher).unwrap(), - Some(2.0) - ); - assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Higher).unwrap(), - Some(7.0) - ); - assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Higher).unwrap(), - Some(5.0) - ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0)); assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(), Some(1.5) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(), Some(6.5) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Midpoint).unwrap(), + ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(), Some(4.5) ); + assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6)); + assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4)); + assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6)); + assert_eq!( - ca.quantile(0.1, QuantileInterpolOptions::Linear).unwrap(), - Some(1.6) + ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(), + Some(1.0) ); assert_eq!( - ca.quantile(0.9, QuantileInterpolOptions::Linear).unwrap(), - Some(6.4) + ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(), + Some(2.0) ); assert_eq!( - ca.quantile(0.6, QuantileInterpolOptions::Linear).unwrap(), - Some(4.6) + ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(), + Some(5.0) ); } } diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs index d6218e81d463..f7716c864559 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs @@ -4,11 +4,7 @@ pub trait QuantileAggSeries { /// Get the median of the [`ChunkedArray`] as a new [`Series`] of length 1. fn median_reduce(&self) -> Scalar; /// Get the quantile of the [`ChunkedArray`] as a new [`Series`] of length 1. - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult; + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult; } /// helper @@ -16,18 +12,23 @@ fn quantile_idx( quantile: f64, length: usize, null_count: usize, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> (usize, f64, usize) { - let float_idx = ((length - null_count) as f64 - 1.0) * quantile + null_count as f64; - let mut base_idx = match interpol { - QuantileInterpolOptions::Nearest => { + let nonnull_count = (length - null_count) as f64; + let float_idx = (nonnull_count - 1.0) * quantile + null_count as f64; + let mut base_idx = match method { + QuantileMethod::Nearest => { let idx = float_idx.round() as usize; - return (float_idx.round() as usize, 0.0, idx); + return (idx, 0.0, idx); + }, + QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => { + float_idx as usize + }, + QuantileMethod::Higher => float_idx.ceil() as usize, + QuantileMethod::Equiprobable => { + let idx = ((nonnull_count * quantile).ceil() - 1.0).max(0.0) as usize + null_count; + return (idx, 0.0, idx); }, - QuantileInterpolOptions::Lower - | QuantileInterpolOptions::Midpoint - | QuantileInterpolOptions::Linear => float_idx as usize, - QuantileInterpolOptions::Higher => float_idx.ceil() as usize, }; base_idx = base_idx.clamp(0, length - 1); @@ -57,7 +58,7 @@ fn midpoint_interpol(lower: T, upper: T) -> T { fn quantile_slice( vals: &mut [T], quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { polars_ensure!((0.0..=1.0).contains(&quantile), ComputeError: "quantile should be between 0.0 and 1.0", @@ -68,21 +69,21 @@ fn quantile_slice( if vals.len() == 1 { return Ok(vals[0].to_f64()); } - let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, interpol); + let (idx, float_idx, top_idx) = quantile_idx(quantile, vals.len(), 0, method); let (_lhs, lower, rhs) = vals.select_nth_unstable_by(idx, TotalOrd::tot_cmp); if idx == top_idx { Ok(lower.to_f64()) } else { - match interpol { - QuantileInterpolOptions::Midpoint => { + match method { + QuantileMethod::Midpoint => { let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); Ok(Some(midpoint_interpol( lower.to_f64().unwrap(), upper.to_f64().unwrap(), ))) }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { let upper = rhs.iter().copied().min_by(TotalOrd::tot_cmp).unwrap(); Ok(linear_interpol( lower.to_f64().unwrap(), @@ -100,7 +101,7 @@ fn quantile_slice( fn generic_quantile( ca: ChunkedArray, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> where T: PolarsNumericType, @@ -117,12 +118,12 @@ where return Ok(None); } - let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, interpol); + let (idx, float_idx, top_idx) = quantile_idx(quantile, length, null_count, method); let sorted = ca.sort(false); let lower = sorted.get(idx).map(|v| v.to_f64().unwrap()); - let opt = match interpol { - QuantileInterpolOptions::Midpoint => { + let opt = match method { + QuantileMethod::Midpoint => { if top_idx == idx { lower } else { @@ -130,7 +131,7 @@ where midpoint_interpol(lower.unwrap(), upper.unwrap()).to_f64() } }, - QuantileInterpolOptions::Linear => { + QuantileMethod::Linear => { if top_idx == idx { lower } else { @@ -149,22 +150,18 @@ where T: PolarsIntegerType, T::Native: TotalOrd, { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) } } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } @@ -177,61 +174,52 @@ where pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol) + quantile_slice(slice, quantile, method) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } impl ChunkQuantile for Float32Chunked { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let out = if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) }; out.map(|v| v.map(|v| v as f32)) } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } impl ChunkQuantile for Float64Chunked { - fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route if let (Ok(slice), false) = (self.cont_slice(), self.is_sorted_ascending_flag()) { let mut owned = slice.to_vec(); - quantile_slice(&mut owned, quantile, interpol) + quantile_slice(&mut owned, quantile, method) } else { - generic_quantile(self.clone(), quantile, interpol) + generic_quantile(self.clone(), quantile, method) } } fn median(&self) -> Option { - self.quantile(0.5, QuantileInterpolOptions::Linear).unwrap() // unwrap fine since quantile in range + self.quantile(0.5, QuantileMethod::Linear).unwrap() // unwrap fine since quantile in range } } @@ -239,20 +227,19 @@ impl Float64Chunked { pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol) + quantile_slice(slice, quantile, method) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } @@ -260,20 +247,19 @@ impl Float32Chunked { pub(crate) fn quantile_faster( mut self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult> { // in case of sorted data, the sort is free, so don't take quickselect route let is_sorted = self.is_sorted_ascending_flag(); if let (Some(slice), false) = (self.cont_slice_mut(), is_sorted) { - quantile_slice(slice, quantile, interpol).map(|v| v.map(|v| v as f32)) + quantile_slice(slice, quantile, method).map(|v| v.map(|v| v as f32)) } else { - self.quantile(quantile, interpol) + self.quantile(quantile, method) } } pub(crate) fn median_faster(self) -> Option { - self.quantile_faster(0.5, QuantileInterpolOptions::Linear) - .unwrap() + self.quantile_faster(0.5, QuantileMethod::Linear).unwrap() } } diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 33f43d530e45..a3e7f04cc9e1 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -278,11 +278,7 @@ pub trait ChunkQuantile { } /// Aggregate a given quantile of the ChunkedArray. /// Returns `None` if the array is empty or only contains null values. - fn quantile( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult> { + fn quantile(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult> { Ok(None) } } diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index d83b91b78cff..e66c1ad12875 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -560,12 +560,12 @@ impl Column { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Self { // @scalar-opt unsafe { self.as_materialized_series() - .agg_quantile(groups, quantile, interpol) + .agg_quantile(groups, quantile, method) } .into() } diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index fe71148cd49b..aaf24a470969 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -236,7 +236,7 @@ impl Series { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { // Prevent a rechunk for every individual group. let s = if groups.len() > 1 { @@ -247,13 +247,12 @@ impl Series { use DataType::*; match s.dtype() { - Float32 => s.f32().unwrap().agg_quantile(groups, quantile, interpol), - Float64 => s.f64().unwrap().agg_quantile(groups, quantile, interpol), + Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method), + Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method), dt if dt.is_numeric() || dt.is_temporal() => { let ca = s.to_physical_repr(); let physical_type = ca.dtype(); - let s = - apply_method_physical_integer!(ca, agg_quantile, groups, quantile, interpol); + let s = apply_method_physical_integer!(ca, agg_quantile, groups, quantile, method); if dt.is_logical() { // back to physical and then // back to logical type diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 092d660fb4d2..19b8d5c2d061 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -13,7 +13,7 @@ use arrow::legacy::kernels::rolling::no_nulls::{ }; use arrow::legacy::kernels::rolling::nulls::RollingAggWindowNulls; use arrow::legacy::kernels::take_agg::*; -use arrow::legacy::prelude::QuantileInterpolOptions; +use arrow::legacy::prelude::QuantileMethod; use arrow::legacy::trusted_len::TrustedLenPush; use arrow::types::NativeType; use num_traits::pow::Pow; @@ -295,8 +295,7 @@ impl_take_extremum!(float: f64); /// This trait will ensure the specific dispatch works without complicating /// the trait bounds. trait QuantileDispatcher { - fn _quantile(self, quantile: f64, interpol: QuantileInterpolOptions) - -> PolarsResult>; + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult>; fn _median(self) -> Option; } @@ -307,12 +306,8 @@ where T::Native: Ord, ChunkedArray: IntoSeries, { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() @@ -320,24 +315,16 @@ where } impl QuantileDispatcher for Float32Chunked { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() } } impl QuantileDispatcher for Float64Chunked { - fn _quantile( - self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult> { - self.quantile_faster(quantile, interpol) + fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult> { + self.quantile_faster(quantile, method) } fn _median(self) -> Option { self.median_faster() @@ -348,7 +335,7 @@ unsafe fn agg_quantile_generic( ca: &ChunkedArray, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series where T: PolarsNumericType, @@ -371,7 +358,7 @@ where } let take = { ca.take_unchecked(idx) }; // checked with invalid quantile check - take._quantile(quantile, interpol).unwrap_unchecked() + take._quantile(quantile, method).unwrap_unchecked() }) }, GroupsProxy::Slice { groups, .. } => { @@ -390,7 +377,7 @@ where offset_iter, Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })), ), Some(validity) => { @@ -400,7 +387,7 @@ where offset_iter, Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })), ) }, @@ -418,7 +405,7 @@ where let arr_group = _slice_from_offsets(ca, first, len); // unwrap checked with invalid quantile check arr_group - ._quantile(quantile, interpol) + ._quantile(quantile, method) .unwrap_unchecked() .map(|flt| NumCast::from(flt).unwrap_unchecked()) }, @@ -450,7 +437,7 @@ where }) }, GroupsProxy::Slice { .. } => { - agg_quantile_generic::(ca, groups, 0.5, QuantileInterpolOptions::Linear) + agg_quantile_generic::(ca, groups, 0.5, QuantileMethod::Linear) }, } } @@ -977,9 +964,9 @@ impl Float32Chunked { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float32Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float32Type>(self, groups) @@ -990,9 +977,9 @@ impl Float64Chunked { &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float64Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float64Type>(self, groups) @@ -1184,9 +1171,9 @@ where &self, groups: &GroupsProxy, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Series { - agg_quantile_generic::<_, Float64Type>(self, groups, quantile, interpol) + agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method) } pub(crate) unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series { agg_median_generic::<_, Float64Type>(self, groups) diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 9c56f7c49122..31936a3a5906 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -594,18 +594,14 @@ impl<'df> GroupBy<'df> { /// /// ```rust /// # use polars_core::prelude::*; - /// # use arrow::legacy::prelude::QuantileInterpolOptions; + /// # use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> PolarsResult { - /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileInterpolOptions::default()) + /// df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default()) /// } /// ``` #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")] - pub fn quantile( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { + pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { polars_ensure!( (0.0..=1.0).contains(&quantile), ComputeError: "`quantile` should be within 0.0 and 1.0" @@ -614,9 +610,9 @@ impl<'df> GroupBy<'df> { for agg_col in agg_cols { let new_name = fmt_group_by_column( agg_col.name().as_str(), - GroupByMethod::Quantile(quantile, interpol), + GroupByMethod::Quantile(quantile, method), ); - let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, interpol) }; + let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) }; agg.rename(new_name); cols.push(agg); } @@ -868,7 +864,7 @@ pub enum GroupByMethod { Sum, Groups, NUnique, - Quantile(f64, QuantileInterpolOptions), + Quantile(f64, QuantileMethod), Count { include_nulls: bool, }, diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index b91df29a0a38..ace52993b8a1 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -358,11 +358,7 @@ impl SeriesTrait for SeriesWrap { Ok(Scalar::new(self.dtype().clone(), av)) } - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { Ok(Scalar::new(self.dtype().clone(), AnyValue::Null)) } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 30125ccc15b6..612505057eca 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -404,13 +404,9 @@ impl SeriesTrait for SeriesWrap { Ok(self.apply_scale(self.0.std_reduce(ddof))) } - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { self.0 - .quantile_reduce(quantile, interpol) + .quantile_reduce(quantile, method) .map(|v| self.apply_scale(v)) } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 13b121aee0ca..803ca813aa1c 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -501,12 +501,8 @@ impl SeriesTrait for SeriesWrap { v.as_duration(self.0.time_unit()), )) } - fn quantile_reduce( - &self, - quantile: f64, - interpol: QuantileInterpolOptions, - ) -> PolarsResult { - let v = self.0.quantile_reduce(quantile, interpol)?; + fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult { + let v = self.0.quantile_reduce(quantile, method)?; let to = self.dtype().to_physical(); let v = v.value().cast(&to); Ok(Scalar::new( diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 24be56671d69..846e326d35b2 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -365,9 +365,9 @@ macro_rules! impl_dyn_series { fn quantile_reduce( &self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult { - QuantileAggSeries::quantile_reduce(&self.0, quantile, interpol) + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) } #[cfg(feature = "bitwise")] fn and_reduce(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 9d8357a905bc..b2cb97e39b69 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -468,9 +468,9 @@ macro_rules! impl_dyn_series { fn quantile_reduce( &self, quantile: f64, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> PolarsResult { - QuantileAggSeries::quantile_reduce(&self.0, quantile, interpol) + QuantileAggSeries::quantile_reduce(&self.0, quantile, method) } #[cfg(feature = "bitwise")] diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 1ee69300fa92..14a0752eae1e 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -519,11 +519,7 @@ pub trait SeriesTrait: polars_bail!(opq = std, self._dtype()); } /// Get the quantile of the ChunkedArray as a new Series of length 1. - fn quantile_reduce( - &self, - _quantile: f64, - _interpol: QuantileInterpolOptions, - ) -> PolarsResult { + fn quantile_reduce(&self, _quantile: f64, _method: QuantileMethod) -> PolarsResult { polars_bail!(opq = quantile, self._dtype()); } /// Get the bitwise AND of the Series as a new Series of length 1, diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index af5383e83c83..f1cfa5251899 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -715,19 +715,19 @@ impl PartitionedAggregation for AggregationExpr { pub struct AggQuantileExpr { pub(crate) input: Arc, pub(crate) quantile: Arc, - pub(crate) interpol: QuantileInterpolOptions, + pub(crate) method: QuantileMethod, } impl AggQuantileExpr { pub fn new( input: Arc, quantile: Arc, - interpol: QuantileInterpolOptions, + method: QuantileMethod, ) -> Self { Self { input, quantile, - interpol, + method, } } @@ -750,7 +750,7 @@ impl PhysicalExpr for AggQuantileExpr { let input = self.input.evaluate(df, state)?; let quantile = self.get_quantile(df, state)?; input - .quantile_reduce(quantile, self.interpol) + .quantile_reduce(quantile, self.method) .map(|sc| sc.into_series(input.name().clone())) } #[allow(clippy::ptr_arg)] @@ -771,7 +771,7 @@ impl PhysicalExpr for AggQuantileExpr { let mut agg = unsafe { ac.flat_naive() .into_owned() - .agg_quantile(ac.groups(), quantile, self.interpol) + .agg_quantile(ac.groups(), quantile, self.method) }; agg.rename(keep_name); Ok(AggregationContext::from_agg_state( diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 620f8bf87089..c4006de0c8ec 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -402,7 +402,9 @@ fn create_physical_expr_inner( }, _ => { if let IRAggExpr::Quantile { - quantile, interpol, .. + quantile, + method: interpol, + .. } = agg { let quantile = diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index fd4334cad066..f073c6643f8c 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -1004,7 +1004,7 @@ impl LazyFrame { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// use arrow::legacy::prelude::QuantileInterpolOptions; + /// use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() @@ -1012,7 +1012,7 @@ impl LazyFrame { /// .agg([ /// col("rain").min().alias("min_rain"), /// col("rain").sum().alias("sum_rain"), - /// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), /// ]) /// } /// ``` @@ -1495,10 +1495,10 @@ impl LazyFrame { } /// Aggregate all the columns as their quantile values. - pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { self.map_private(DslFunction::Stats(StatsFunction::Quantile { quantile, - interpol, + method, })) } @@ -1885,7 +1885,7 @@ impl LazyGroupBy { /// ```rust /// use polars_core::prelude::*; /// use polars_lazy::prelude::*; - /// use arrow::legacy::prelude::QuantileInterpolOptions; + /// use arrow::legacy::prelude::QuantileMethod; /// /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() @@ -1893,7 +1893,7 @@ impl LazyGroupBy { /// .agg([ /// col("rain").min().alias("min_rain"), /// col("rain").sum().alias("sum_rain"), - /// col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), + /// col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), /// ]) /// } /// ``` diff --git a/crates/polars-lazy/src/lib.rs b/crates/polars-lazy/src/lib.rs index 3059384a1c8c..f3dff5710170 100644 --- a/crates/polars-lazy/src/lib.rs +++ b/crates/polars-lazy/src/lib.rs @@ -104,7 +104,7 @@ //! use polars_core::prelude::*; //! use polars_core::df; //! use polars_lazy::prelude::*; -//! use arrow::legacy::prelude::QuantileInterpolOptions; +//! use arrow::legacy::prelude::QuantileMethod; //! //! fn example() -> PolarsResult { //! let df = df!( @@ -118,7 +118,7 @@ //! .agg([ //! col("rain").min().alias("min_rain"), //! col("rain").sum().alias("sum_rain"), -//! col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), +//! col("rain").quantile(lit(0.5), QuantileMethod::Nearest).alias("median_rain"), //! ]) //! .sort(["date"], Default::default()) //! .collect() diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index c78aa67efca8..b7b8d3e9f179 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -144,11 +144,7 @@ pub fn qcut( let s2 = s.sort(SortOptions::default())?; let ca = s2.f64()?; - let f = |&p| { - ca.quantile(p, QuantileInterpolOptions::Linear) - .unwrap() - .unwrap() - }; + let f = |&p| ca.quantile(p, QuantileMethod::Linear).unwrap().unwrap(); let mut qbreaks: Vec<_> = probs.iter().map(f).collect(); qbreaks.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 33eb20e86da6..32fa45528d3a 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -33,7 +33,7 @@ pub enum AggExpr { Quantile { expr: Arc, quantile: Arc, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Sum(Arc), AggGroups(Arc), diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs index e1ef64ee02ec..8363f6baa2fa 100644 --- a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -33,8 +33,8 @@ pub fn median(name: &str) -> Expr { } /// Find a specific quantile of all the values in the column named `name`. -pub fn quantile(name: &str, quantile: Expr, interpol: QuantileInterpolOptions) -> Expr { - col(name).quantile(quantile, interpol) +pub fn quantile(name: &str, quantile: Expr, method: QuantileMethod) -> Expr { + col(name).quantile(quantile, method) } /// Negates a boolean column. diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 5d814877a977..a88ff858e6ee 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -45,7 +45,7 @@ use std::sync::Arc; pub use arity::*; #[cfg(feature = "dtype-array")] pub use array::*; -use arrow::legacy::prelude::QuantileInterpolOptions; +use arrow::legacy::prelude::QuantileMethod; pub use expr::*; pub use function_expr::schema::FieldsMapper; pub use function_expr::*; @@ -227,11 +227,11 @@ impl Expr { } /// Compute the quantile per group. - pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { + pub fn quantile(self, quantile: Expr, method: QuantileMethod) -> Self { AggExpr::Quantile { expr: Arc::new(self), quantile: Arc::new(quantile), - interpol, + method, } .into() } @@ -1358,13 +1358,13 @@ impl Expr { pub fn rolling_quantile_by( self, by: Expr, - interpol: QuantileInterpolOptions, + method: QuantileMethod, quantile: f64, mut options: RollingOptionsDynamicWindow, ) -> Expr { options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })); self.finish_rolling_by(by, options, RollingFunctionBy::QuantileBy) @@ -1385,7 +1385,7 @@ impl Expr { /// Apply a rolling median based on another column. #[cfg(feature = "rolling_window_by")] pub fn rolling_median_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { - self.rolling_quantile_by(by, QuantileInterpolOptions::Linear, 0.5, options) + self.rolling_quantile_by(by, QuantileMethod::Linear, 0.5, options) } /// Apply a rolling minimum. @@ -1425,7 +1425,7 @@ impl Expr { /// See: [`RollingAgg::rolling_median`] #[cfg(feature = "rolling_window")] pub fn rolling_median(self, options: RollingOptionsFixedWindow) -> Expr { - self.rolling_quantile(QuantileInterpolOptions::Linear, 0.5, options) + self.rolling_quantile(QuantileMethod::Linear, 0.5, options) } /// Apply a rolling quantile. @@ -1434,13 +1434,13 @@ impl Expr { #[cfg(feature = "rolling_window")] pub fn rolling_quantile( self, - interpol: QuantileInterpolOptions, + method: QuantileMethod, quantile: f64, mut options: RollingOptionsFixedWindow, ) -> Expr { options.fn_params = Some(RollingFnParams::Quantile(RollingQuantileParams { prob: quantile, - interpol, + method, })); self.finish_rolling(options, RollingFunction::Quantile) diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 565710c0dbaf..286ea86ac968 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -44,7 +44,7 @@ pub enum IRAggExpr { Quantile { expr: Node, quantile: Node, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Sum(Node), Count(Node, bool), @@ -62,7 +62,9 @@ impl Hash for IRAggExpr { Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => { propagate_nans.hash(state) }, - Self::Quantile { interpol, .. } => interpol.hash(state), + Self::Quantile { + method: interpol, .. + } => interpol.hash(state), Self::Std(_, v) | Self::Var(_, v) => v.hash(state), #[cfg(feature = "bitwise")] Self::Bitwise(_, f) => f.hash(state), @@ -92,7 +94,7 @@ impl IRAggExpr { propagate_nans: r, .. }, ) => l == r, - (Quantile { interpol: l, .. }, Quantile { interpol: r, .. }) => l == r, + (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r, (Std(_, l), Std(_, r)) => l == r, (Var(_, l), Var(_, r)) => l == r, #[cfg(feature = "bitwise")] diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index c0178f5b383c..7ee9c7f069d7 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -738,9 +738,9 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult |name| col(name.clone()).std(ddof), &input_schema, ), - StatsFunction::Quantile { quantile, interpol } => stats_helper( + StatsFunction::Quantile { quantile, method } => stats_helper( |dt| dt.is_numeric(), - |name| col(name.clone()).quantile(quantile.clone(), interpol), + |name| col(name.clone()).quantile(quantile.clone(), method), &input_schema, ), StatsFunction::Mean => stats_helper( diff --git a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs index 6873ad3f6851..d3e0c17f8098 100644 --- a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs @@ -237,11 +237,11 @@ pub(super) fn to_aexpr_impl( AggExpr::Quantile { expr, quantile, - interpol, + method, } => IRAggExpr::Quantile { expr: to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, quantile: to_aexpr_impl_materialized_lit(owned(quantile), arena, state)?, - interpol, + method, }, AggExpr::Sum(expr) => { IRAggExpr::Sum(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs index 5d2e4c373b30..160b70951962 100644 --- a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -129,14 +129,14 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { IRAggExpr::Quantile { expr, quantile, - interpol, + method, } => { let expr = node_to_expr(expr, expr_arena); let quantile = node_to_expr(quantile, expr_arena); AggExpr::Quantile { expr: Arc::new(expr), quantile: Arc::new(quantile), - interpol, + method, } .into() }, diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index e470bd3044bc..f1aa33a7e7dd 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -72,7 +72,7 @@ pub enum StatsFunction { }, Quantile { quantile: Expr, - interpol: QuantileInterpolOptions, + method: QuantileMethod, }, Median, Mean, diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs index 71b287d03b85..62a64319ae2e 100644 --- a/crates/polars-plan/src/plans/visitor/expr.rs +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -67,7 +67,7 @@ impl TreeWalker for Expr { Mean(x) => Mean(am(x, f)?), Implode(x) => Implode(am(x, f)?), Count(x, nulls) => Count(am(x, f)?, nulls), - Quantile { expr, quantile, interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, interpol }, + Quantile { expr, quantile, method: interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, method: interpol }, Sum(x) => Sum(am(x, f)?), AggGroups(x) => AggGroups(am(x, f)?), Std(x, ddf) => Std(am(x, f)?, ddf), diff --git a/crates/polars-python/src/conversion/mod.rs b/crates/polars-python/src/conversion/mod.rs index abde51745554..26bd02c6e540 100644 --- a/crates/polars-python/src/conversion/mod.rs +++ b/crates/polars-python/src/conversion/mod.rs @@ -986,17 +986,18 @@ impl<'py> FromPyObject<'py> for Wrap { } } -impl<'py> FromPyObject<'py> for Wrap { +impl<'py> FromPyObject<'py> for Wrap { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { let parsed = match &*ob.extract::()? { - "lower" => QuantileInterpolOptions::Lower, - "higher" => QuantileInterpolOptions::Higher, - "nearest" => QuantileInterpolOptions::Nearest, - "linear" => QuantileInterpolOptions::Linear, - "midpoint" => QuantileInterpolOptions::Midpoint, + "lower" => QuantileMethod::Lower, + "higher" => QuantileMethod::Higher, + "nearest" => QuantileMethod::Nearest, + "linear" => QuantileMethod::Linear, + "midpoint" => QuantileMethod::Midpoint, + "equiprobable" => QuantileMethod::Equiprobable, v => { return Err(PyValueError::new_err(format!( - "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint'}}, got {v}", + "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint', 'equiprobable'}}, got {v}", ))) } }; diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index d0b6d30c31e1..604049f62b66 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -149,7 +149,7 @@ impl PyExpr { fn implode(&self) -> Self { self.inner.clone().implode().into() } - fn quantile(&self, quantile: Self, interpolation: Wrap) -> Self { + fn quantile(&self, quantile: Self, interpolation: Wrap) -> Self { self.inner .clone() .quantile(quantile.inner, interpolation.0) diff --git a/crates/polars-python/src/expr/rolling.rs b/crates/polars-python/src/expr/rolling.rs index 629f1eab391d..a5ef9213128f 100644 --- a/crates/polars-python/src/expr/rolling.rs +++ b/crates/polars-python/src/expr/rolling.rs @@ -276,7 +276,7 @@ impl PyExpr { fn rolling_quantile( &self, quantile: f64, - interpolation: Wrap, + interpolation: Wrap, window_size: usize, weights: Option>, min_periods: Option, @@ -302,7 +302,7 @@ impl PyExpr { &self, by: PyExpr, quantile: f64, - interpolation: Wrap, + interpolation: Wrap, window_size: &str, min_periods: usize, closed: Wrap, diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 85200e339065..fa28a9f8e5ed 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -1051,7 +1051,7 @@ impl PyLazyFrame { out.into() } - fn quantile(&self, quantile: PyExpr, interpolation: Wrap) -> Self { + fn quantile(&self, quantile: PyExpr, interpolation: Wrap) -> Self { let ldf = self.ldf.clone(); let out = ldf.quantile(quantile.inner, interpolation.0); out.into() diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 67c25d755084..07d2f872437c 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -2,7 +2,7 @@ use polars::datatypes::TimeUnit; #[cfg(feature = "iejoin")] use polars::prelude::InequalityOperator; use polars::series::ops::NullBehavior; -use polars_core::prelude::{NonExistent, QuantileInterpolOptions}; +use polars_core::prelude::{NonExistent, QuantileMethod}; use polars_core::series::IsSorted; use polars_ops::prelude::ClosedInterval; use polars_ops::series::InterpolationMethod; @@ -700,16 +700,17 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { IRAggExpr::Quantile { expr, quantile, - interpol, + method: interpol, } => Agg { name: "quantile".to_object(py), arguments: vec![expr.0, quantile.0], options: match interpol { - QuantileInterpolOptions::Nearest => "nearest", - QuantileInterpolOptions::Lower => "lower", - QuantileInterpolOptions::Higher => "higher", - QuantileInterpolOptions::Midpoint => "midpoint", - QuantileInterpolOptions::Linear => "linear", + QuantileMethod::Nearest => "nearest", + QuantileMethod::Lower => "lower", + QuantileMethod::Higher => "higher", + QuantileMethod::Midpoint => "midpoint", + QuantileMethod::Linear => "linear", + QuantileMethod::Equiprobable => "equiprobable", } .to_object(py), }, diff --git a/crates/polars-python/src/series/aggregation.rs b/crates/polars-python/src/series/aggregation.rs index dbcbad59ddac..5aa8ee16639e 100644 --- a/crates/polars-python/src/series/aggregation.rs +++ b/crates/polars-python/src/series/aggregation.rs @@ -105,11 +105,7 @@ impl PySeries { .into_py(py)) } - fn quantile( - &self, - quantile: f64, - interpolation: Wrap, - ) -> PyResult { + fn quantile(&self, quantile: f64, interpolation: Wrap) -> PyResult { let bind = self.series.quantile_reduce(quantile, interpolation.0); let sc = bind.map_err(PyPolarsErr::from)?; diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 48011a323764..5f6c311a5e15 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -3,7 +3,7 @@ use std::ops::Sub; use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; use polars_core::export::regex; use polars_core::prelude::{ - polars_bail, polars_err, DataType, PolarsResult, QuantileInterpolOptions, Schema, TimeUnit, + polars_bail, polars_err, DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, }; use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] @@ -513,6 +513,13 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT QUANTILE_CONT(column_1) FROM df; /// ``` QuantileCont, + /// SQL 'quantile_disc' function + /// Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value, + /// and returns the value associated with the subinterval where the quantile value falls. + /// ```sql + /// SELECT QUANTILE_DISC(column_1) FROM df; + /// ``` + QuantileDisc, /// SQL 'min' function /// Returns the smallest (minimum) of all the elements in the grouping. /// ```sql @@ -688,6 +695,7 @@ impl PolarsSQLFunctions { "ltrim", "max", "median", + "quantile_disc", "min", "mod", "nullif", @@ -696,6 +704,7 @@ impl PolarsSQLFunctions { "pow", "power", "quantile_cont", + "quantile_disc", "radians", "regexp_like", "replace", @@ -829,6 +838,7 @@ impl PolarsSQLFunctions { "max" => Self::Max, "median" => Self::Median, "quantile_cont" => Self::QuantileCont, + "quantile_disc" => Self::QuantileDisc, "min" => Self::Min, "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev, "sum" => Self::Sum, @@ -1275,11 +1285,37 @@ impl SQLFunctionVisitor<'_> { }, _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1]) }; - Ok(e.quantile(value, QuantileInterpolOptions::Linear)) + Ok(e.quantile(value, QuantileMethod::Linear)) }), _ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()), } }, + QuantileDisc => { + let args = extract_args(function)?; + match args.len() { + 2 => self.try_visit_binary(|e, q| { + let value = match q { + Expr::Literal(LiteralValue::Float(f)) => { + if (0.0..=1.0).contains(&f) { + Expr::from(f) + } else { + polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1]) + } + }, + Expr::Literal(LiteralValue::Int(n)) => { + if (0..=1).contains(&n) { + Expr::from(n as f64) + } else { + polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1]) + } + }, + _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1]) + }; + Ok(e.quantile(value, QuantileMethod::Equiprobable)) + }), + _ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()), + } + }, Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min), StdDev => self.visit_unary(|e| e.std(1)), Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum), diff --git a/crates/polars-sql/tests/functions_aggregate.rs b/crates/polars-sql/tests/functions_aggregate.rs index 621ca18bd355..092a340f5f18 100644 --- a/crates/polars-sql/tests/functions_aggregate.rs +++ b/crates/polars-sql/tests/functions_aggregate.rs @@ -5,9 +5,7 @@ use polars_sql::*; fn create_df() -> LazyFrame { df! { - "Year" => [2018, 2018, 2019, 2019, 2020, 2020], - "Country" => ["US", "UK", "US", "UK", "US", "UK"], - "Sales" => [1000, 2000, 3000, 4000, 5000, 6000] + "Data" => [1000, 2000, 3000, 4000, 5000, 6000] } .unwrap() .lazy() @@ -41,9 +39,9 @@ fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) { #[test] fn test_median() { - let expr = col("Sales").median(); + let expr = col("Data").median(); - let sql_expr = "MEDIAN(Sales)"; + let sql_expr = "MEDIAN(Data)"; let (expected, actual) = create_expected(expr, sql_expr); assert!(expected.equals(&actual)) @@ -52,9 +50,9 @@ fn test_median() { #[test] fn test_quantile_cont() { for &q in &[0.25, 0.5, 0.75] { - let expr = col("Sales").quantile(lit(q), QuantileInterpolOptions::Linear); + let expr = col("Data").quantile(lit(q), QuantileMethod::Linear); - let sql_expr = format!("QUANTILE_CONT(Sales, {})", q); + let sql_expr = format!("QUANTILE_CONT(Data, {})", q); let (expected, actual) = create_expected(expr, &sql_expr); assert!( @@ -63,3 +61,61 @@ fn test_quantile_cont() { ) } } + +#[test] +fn test_quantile_disc() { + for &q in &[0.25, 0.5, 0.75] { + let expr = col("Data").quantile(lit(q), QuantileMethod::Equiprobable); + + let sql_expr = format!("QUANTILE_DISC(Data, {})", q); + let (expected, actual) = create_expected(expr, &sql_expr); + + assert!(expected.equals(&actual)) + } +} + +#[test] +fn test_quantile_out_of_range() { + for &q in &["-1", "2", "-0.01", "1.01"] { + for &func in &["QUANTILE_CONT", "QUANTILE_DISC"] { + let query = format!("SELECT {func}(Data, {q})"); + let mut ctx = SQLContext::new(); + ctx.register("df", create_df()); + let actual = ctx.execute(&query); + assert!(actual.is_err()) + } + } +} + +#[test] +fn test_quantile_disc_conformance() { + let expected = df![ + "q" => [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], + "Data" => [1000, 1000, 2000, 2000, 3000, 3000, 4000, 5000, 5000, 6000, 6000], + ] + .unwrap(); + + let mut ctx = SQLContext::new(); + ctx.register("df", create_df()); + + let mut actual: Option = None; + for &q in &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] { + let res = ctx + .execute(&format!( + "SELECT {q}::float as q, QUANTILE_DISC(Data, {q}) as Data FROM df" + )) + .unwrap() + .collect() + .unwrap(); + actual = if let Some(df) = actual { + Some(df.vstack(&res).unwrap()) + } else { + Some(res) + }; + } + + assert!( + expected.equals(actual.as_ref().unwrap()), + "expected {expected:?}, got {actual:?}" + ) +} diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 8a8d2312d580..3ff08ee4d308 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -789,12 +789,12 @@ mod test { let quantile = unsafe { a.as_materialized_series() - .agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) + .agg_quantile(&groups, 0.5, QuantileMethod::Linear) }; let expected = Series::new("".into(), [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]); assert_eq!(quantile, expected); - let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) }; + let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileMethod::Linear) }; let expected = Series::new("".into(), [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]); assert_eq!(quantile, expected); diff --git a/crates/polars/tests/it/lazy/aggregation.rs b/crates/polars/tests/it/lazy/aggregation.rs index ad043e698e2e..10c386037d17 100644 --- a/crates/polars/tests/it/lazy/aggregation.rs +++ b/crates/polars/tests/it/lazy/aggregation.rs @@ -26,7 +26,7 @@ fn test_lazy_agg() { col("rain").min().alias("min"), col("rain").sum().alias("sum"), col("rain") - .quantile(lit(0.5), QuantileInterpolOptions::default()) + .quantile(lit(0.5), QuantileMethod::default()) .alias("median_rain"), ]) .sort(["date"], Default::default());