Skip to content

Commit

Permalink
perf: Special no null branch in arg-sort
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 7, 2024
1 parent a0a577a commit 0460731
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 77 deletions.
48 changes: 41 additions & 7 deletions crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
use polars_utils::iter::EnumerateIdxTrait;

use super::*;

// Reduce monomorphisation.
fn sort_impl<T>(vals: &mut [(IdxSize, T)], options: SortOptions)
where
T: TotalOrd + Send + Sync,
{
sort_by_branch(
vals,
options.descending,
|a, b| a.1.tot_cmp(&b.1),
options.multithreaded,
);
}

pub(super) fn arg_sort<I, J, T>(
name: &str,
iters: I,
Expand All @@ -12,7 +27,6 @@ where
J: IntoIterator<Item = Option<T>>,
T: TotalOrd + Send + Sync,
{
let descending = options.descending;
let nulls_last = options.nulls_last;

let mut vals = Vec::with_capacity(len - null_count);
Expand All @@ -37,12 +51,7 @@ where
vals.extend(iter);
}

sort_by_branch(
vals.as_mut_slice(),
descending,
|a, b| a.1.tot_cmp(&b.1),
options.multithreaded,
);
sort_impl(vals.as_mut_slice(), options);

let iter = vals.into_iter().map(|(idx, _v)| idx);
let idx = if nulls_last {
Expand All @@ -60,3 +69,28 @@ where

ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
}

pub(super) fn arg_sort_no_nulls<I, J, T>(
name: &str,
iters: I,
options: SortOptions,
len: usize,
) -> IdxCa
where
I: IntoIterator<Item = J>,
J: IntoIterator<Item = T>,
T: TotalOrd + Send + Sync,
{
let mut vals = Vec::with_capacity(len);

for arr_iter in iters {
vals.extend(arr_iter.into_iter().enumerate_idx());
}

sort_impl(vals.as_mut_slice(), options);

let iter = vals.into_iter().map(|(idx, _v)| idx);
let idx: Vec<_> = iter.collect_trusted();

ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
}
139 changes: 69 additions & 70 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ where
}
}

#[inline]
fn sort_unstable_by_branch<T, C>(slice: &mut [T], descending: bool, cmp: C, parallel: bool)
where
T: Send,
Expand All @@ -62,6 +61,19 @@ where
}
}

// Reduce monomorphisation.
fn sort_impl_unstable<T>(vals: &mut [T], options: SortOptions)
where
T: TotalOrd + Send + Sync,
{
sort_unstable_by_branch(
vals,
options.descending,
TotalOrd::tot_cmp,
options.multithreaded,
);
}

macro_rules! sort_with_fast_path {
($ca:ident, $options:expr) => {{
if $ca.is_empty() {
Expand Down Expand Up @@ -103,12 +115,7 @@ where
if ca.null_count() == 0 {
let mut vals = ca.to_vec_null_aware().left().unwrap();

sort_unstable_by_branch(
vals.as_mut_slice(),
options.descending,
TotalOrd::tot_cmp,
options.multithreaded,
);
sort_impl_unstable(vals.as_mut_slice(), options);

let mut ca = ChunkedArray::from_vec(ca.name(), vals);
let s = if options.descending {
Expand Down Expand Up @@ -139,12 +146,7 @@ where
&mut vals[null_count..]
};

sort_unstable_by_branch(
mut_slice,
options.descending,
TotalOrd::tot_cmp,
options.multithreaded,
);
sort_impl_unstable(mut_slice, options);

let mut validity = MutableBitmap::with_capacity(len);
if options.nulls_last {
Expand Down Expand Up @@ -176,31 +178,11 @@ fn arg_sort_numeric<T>(ca: &ChunkedArray<T>, options: SortOptions) -> IdxCa
where
T: PolarsNumericType,
{
let descending = options.descending;
if ca.null_count() == 0 {
let mut vals = Vec::with_capacity(ca.len());
let mut count: IdxSize = 0;
ca.downcast_iter().for_each(|arr| {
let values = arr.values();
let iter = values.iter().map(|&v| {
let i = count;
count += 1;
(i, v)
});
vals.extend_trusted_len(iter);
});

sort_by_branch(
vals.as_mut_slice(),
descending,
|a, b| a.1.tot_cmp(&b.1),
options.multithreaded,
);

let out: NoNull<IdxCa> = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
let mut out = out.into_inner();
out.rename(ca.name());
out
let iter = ca
.downcast_iter()
.map(|arr| arr.values().as_slice().iter().copied());
arg_sort::arg_sort_no_nulls(ca.name(), iter, options, ca.len())
} else {
let iter = ca
.downcast_iter()
Expand Down Expand Up @@ -337,12 +319,7 @@ impl ChunkSort<BinaryType> for BinaryChunked {
for arr in self.downcast_iter() {
v.extend(arr.non_null_values_iter());
}
sort_unstable_by_branch(
v.as_mut_slice(),
options.descending,
Ord::cmp,
options.multithreaded,
);
sort_impl_unstable(v.as_mut_slice(), options);

let len = self.len();
let null_count = self.null_count();
Expand Down Expand Up @@ -380,13 +357,22 @@ impl ChunkSort<BinaryType> for BinaryChunked {
}

fn arg_sort(&self, options: SortOptions) -> IdxCa {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
if self.null_count() == 0 {
arg_sort::arg_sort_no_nulls(
self.name(),
self.downcast_iter().map(|arr| arr.values_iter()),
options,
self.len(),
)
} else {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
}
}

fn arg_sort_multiple(
Expand Down Expand Up @@ -420,12 +406,7 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
v.extend(arr.non_null_values_iter());
}

sort_unstable_by_branch(
v.as_mut_slice(),
options.descending,
Ord::cmp,
options.multithreaded,
);
sort_impl_unstable(v.as_mut_slice(), options);

let mut values = Vec::<u8>::with_capacity(self.get_values_size());
let mut offsets = Vec::<i64>::with_capacity(self.len() + 1);
Expand Down Expand Up @@ -511,13 +492,22 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
}

fn arg_sort(&self, options: SortOptions) -> IdxCa {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
if self.null_count() == 0 {
arg_sort::arg_sort_no_nulls(
self.name(),
self.downcast_iter().map(|arr| arr.values_iter()),
options,
self.len(),
)
} else {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
}
}

/// # Panics
Expand Down Expand Up @@ -609,13 +599,22 @@ impl ChunkSort<BooleanType> for BooleanChunked {
}

fn arg_sort(&self, options: SortOptions) -> IdxCa {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
if self.null_count() == 0 {
arg_sort::arg_sort_no_nulls(
self.name(),
self.downcast_iter().map(|arr| arr.values_iter()),
options,
self.len(),
)
} else {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
}
}
fn arg_sort_multiple(
&self,
Expand Down

0 comments on commit 0460731

Please sign in to comment.