Skip to content

Commit

Permalink
feat: New quantile interpolation method & QUANTILE_DISC function in SQL (#19139)
Browse files Browse the repository at this point in the history
  • Loading branch information
pomo-mondreganto authored Oct 16, 2024
1 parent 780430f commit 2736621
Show file tree
Hide file tree
Showing 43 changed files with 435 additions and 413 deletions.
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/legacy/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,5 @@ pub struct RollingVarParams {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingQuantileParams {
pub prob: f64,
pub interpol: QuantileInterpolOptions,
pub method: QuantileMethod,
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,19 @@ where

#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum QuantileInterpolOptions {
pub enum QuantileMethod {
#[default]
Nearest,
Lower,
Higher,
Midpoint,
Linear,
Equiprobable,
}

/// Alias that keeps the old `QuantileInterpolOptions` name usable.
///
/// New code should refer to [`QuantileMethod`] directly; using this alias
/// triggers a deprecation warning.
#[deprecated(note = "use QuantileMethod instead")]
pub type QuantileInterpolOptions = QuantileMethod;

pub(super) fn rolling_apply_weights<T, Fo, Fa>(
values: &[T],
window_size: usize,
Expand Down
50 changes: 30 additions & 20 deletions crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use num_traits::ToPrimitive;
use polars_error::polars_ensure;
use polars_utils::slice::GetSaferUnchecked;

use super::QuantileInterpolOptions::*;
use super::QuantileMethod::*;
use super::*;

pub struct QuantileWindow<'a, T: NativeType> {
sorted: SortedBuf<'a, T>,
prob: f64,
interpol: QuantileInterpolOptions,
method: QuantileMethod,
}

impl<
Expand All @@ -34,15 +34,15 @@ impl<
Self {
sorted: SortedBuf::new(slice, start, end),
prob: params.prob,
interpol: params.interpol,
method: params.method,
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let vals = self.sorted.update(start, end);
let length = vals.len();

let idx = match self.interpol {
let idx = match self.method {
Linear => {
// Maybe add a fast path for median case? They could branch depending on odd/even.
let length_f = length as f64;
Expand Down Expand Up @@ -92,6 +92,7 @@ impl<
let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
std::cmp::min(idx, length - 1)
},
Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
};

// SAFETY:
Expand Down Expand Up @@ -134,7 +135,7 @@ where
unreachable!("expected Quantile params");
};
let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
params.interpol,
params.method,
min_periods,
window_size,
values,
Expand Down Expand Up @@ -170,7 +171,7 @@ where
Ok(rolling_apply_weighted_quantile(
values,
params.prob,
params.interpol,
params.method,
window_size,
min_periods,
offset_fn,
Expand All @@ -182,7 +183,7 @@ where
}

#[inline]
fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, interp: QuantileInterpolOptions) -> T
fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
where
T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
{
Expand All @@ -201,7 +202,7 @@ where
(s_old, v_old, vk) = (s, vk, v);
s += w;
}
match (h == s_old, interp) {
match (h == s_old, method) {
(true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter
(_, Lower) => v_old,
(_, Higher) => vk,
Expand All @@ -212,6 +213,14 @@ where
vk
}
},
(_, Equiprobable) => {
let threshold = (wsum * p).ceil() - 1.0;
if s > threshold {
vk
} else {
v_old
}
},
(_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
// This is seemingly the canonical way to do it.
(_, Linear) => {
Expand All @@ -224,7 +233,7 @@ where
fn rolling_apply_weighted_quantile<T, Fo>(
values: &[T],
p: f64,
interpolation: QuantileInterpolOptions,
method: QuantileMethod,
window_size: usize,
min_periods: usize,
det_offsets_fn: Fo,
Expand Down Expand Up @@ -252,7 +261,7 @@ where
.for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
}
buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
compute_wq(&buf, p, wsum, interpolation)
compute_wq(&buf, p, wsum, method)
})
.collect_trusted::<Vec<T>>();

Expand All @@ -273,7 +282,7 @@ mod test {
let values = &[1.0, 2.0, 3.0, 4.0];
let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.5,
interpol: Linear,
method: Linear,
}));
let out = rolling_quantile(values, 2, 2, false, None, med_pars.clone()).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand Down Expand Up @@ -305,18 +314,19 @@ mod test {
fn test_rolling_quantile_limits() {
let values = &[1.0f64, 2.0, 3.0, 4.0];

let interpol_options = vec![
QuantileInterpolOptions::Lower,
QuantileInterpolOptions::Higher,
QuantileInterpolOptions::Nearest,
QuantileInterpolOptions::Midpoint,
QuantileInterpolOptions::Linear,
let methods = vec![
QuantileMethod::Lower,
QuantileMethod::Higher,
QuantileMethod::Nearest,
QuantileMethod::Midpoint,
QuantileMethod::Linear,
QuantileMethod::Equiprobable,
];

for interpol in interpol_options {
for method in methods {
let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.0,
interpol,
method,
}));
let out1 = rolling_min(values, 2, 2, false, None, None).unwrap();
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand All @@ -328,7 +338,7 @@ mod test {

let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 1.0,
interpol,
method,
}));
let out1 = rolling_max(values, 2, 2, false, None, None).unwrap();
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand Down
46 changes: 24 additions & 22 deletions crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::array::MutablePrimitiveArray;
pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
sorted: SortedBufNulls<'a, T>,
prob: f64,
interpol: QuantileInterpolOptions,
method: QuantileMethod,
}

impl<
Expand Down Expand Up @@ -39,7 +39,7 @@ impl<
Self {
sorted: SortedBufNulls::new(slice, validity, start, end),
prob: params.prob,
interpol: params.interpol,
method: params.method,
}
}

Expand All @@ -53,29 +53,30 @@ impl<
let values = &values[null_count..];
let length = values.len();

let mut idx = match self.interpol {
QuantileInterpolOptions::Nearest => ((length as f64) * self.prob) as usize,
QuantileInterpolOptions::Lower
| QuantileInterpolOptions::Midpoint
| QuantileInterpolOptions::Linear => {
let mut idx = match self.method {
QuantileMethod::Nearest => ((length as f64) * self.prob) as usize,
QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
((length as f64 - 1.0) * self.prob).floor() as usize
},
QuantileInterpolOptions::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
QuantileMethod::Equiprobable => {
((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
},
};

idx = std::cmp::min(idx, length - 1);

// we can unwrap because we sliced off the nulls
match self.interpol {
QuantileInterpolOptions::Midpoint => {
match self.method {
QuantileMethod::Midpoint => {
let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
Some(
(values.get_unchecked_release(idx).unwrap()
+ values.get_unchecked_release(top_idx).unwrap())
/ T::from::<f64>(2.0f64).unwrap(),
)
},
QuantileInterpolOptions::Linear => {
QuantileMethod::Linear => {
let float_idx = (length as f64 - 1.0) * self.prob;
let top_idx = f64::ceil(float_idx) as usize;

Expand Down Expand Up @@ -136,7 +137,7 @@ where
};

let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
params.interpol,
params.method,
min_periods,
window_size,
arr.clone(),
Expand Down Expand Up @@ -171,7 +172,7 @@ mod test {
);
let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.5,
interpol: QuantileInterpolOptions::Linear,
method: QuantileMethod::Linear,
}));

let out = rolling_quantile(arr, 2, 2, false, None, med_pars.clone());
Expand Down Expand Up @@ -210,18 +211,19 @@ mod test {
Some(Bitmap::from(&[true, false, false, true, true])),
);

let interpol_options = vec![
QuantileInterpolOptions::Lower,
QuantileInterpolOptions::Higher,
QuantileInterpolOptions::Nearest,
QuantileInterpolOptions::Midpoint,
QuantileInterpolOptions::Linear,
let methods = vec![
QuantileMethod::Lower,
QuantileMethod::Higher,
QuantileMethod::Nearest,
QuantileMethod::Midpoint,
QuantileMethod::Linear,
QuantileMethod::Equiprobable,
];

for interpol in interpol_options {
for method in methods {
let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.0,
interpol,
method,
}));
let out1 = rolling_min(values, 2, 1, false, None, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand All @@ -233,7 +235,7 @@ mod test {

let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 1.0,
interpol,
method,
}));
let out1 = rolling_max(values, 2, 1, false, None, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand Down
Loading

0 comments on commit 2736621

Please sign in to comment.