fix(rust): Fix unsoundness in group_tuples_perfect #19359

Merged · 4 commits · Oct 22, 2024
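Context for the diff below: in the multithreaded path, the old code obtained its `groups` and `first` buffers from the unsafe `aligned_vec` helper (removed at the bottom of this diff), which returns a vector whose elements are never initialized, and then built mutable slices over that memory and pushed into the uninitialized `IdxVec`s. The new code initializes every slot up front with `resize_with` and instead marks `group_tuples_perfect` as `unsafe` with an explicit caller contract. A minimal sketch of the hazard and of the `resize_with` replacement, with plain `Vec<u32>` standing in for polars' `IdxVec` (sizes and capacities are illustrative, not taken from the PR):

```rust
// Minimal sketch (not the polars code): why pushing through a logically
// uninitialized Vec is undefined behavior, and the `resize_with`
// pattern the diff switches to.

fn main() {
    let len = 8;

    // Unsound pattern enabled by the removed `aligned_vec` helper: a Vec
    // with capacity but length 0, later exposed as an initialized
    // `&mut [Vec<u32>]` of length `len`. Each element's ptr/len/cap are
    // garbage, so pushing into one is UB. Kept as a comment on purpose:
    //
    //   let mut groups: Vec<Vec<u32>> = Vec::with_capacity(len);
    //   let slice = unsafe {
    //       std::slice::from_raw_parts_mut(groups.as_mut_ptr(), len)
    //   };
    //   slice[0].push(1); // UB: reads an uninitialized Vec header
    //
    // Sound replacement: actually construct every element first.
    let mut groups: Vec<Vec<u32>> = Vec::new();
    groups.resize_with(len, || Vec::with_capacity(4));

    groups[0].push(1); // fine: slot 0 holds a real, empty Vec
    assert_eq!(groups[0], [1]);
    println!("initialized {} group buffers", groups.len());
}
```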
214 changes: 79 additions & 135 deletions crates/polars-core/src/frame/group_by/perfect.rs
@@ -1,4 +1,5 @@
use std::fmt::Debug;
use std::mem::MaybeUninit;

use num_traits::{FromPrimitive, ToPrimitive};
use polars_utils::idx_vec::IdxVec;
@@ -17,162 +18,128 @@ where
T: PolarsIntegerType,
T::Native: ToPrimitive + FromPrimitive + Debug,
{
// Use the indexes as perfect groups
pub fn group_tuples_perfect(
/// Use the indexes as perfect groups.
///
/// # Safety
/// This ChunkedArray must contain each value in [0..num_groups) at least
/// once, and nothing outside this range.
pub unsafe fn group_tuples_perfect(
&self,
max: usize,
num_groups: usize,
mut multithreaded: bool,
group_capacity: usize,
) -> GroupsProxy {
multithreaded &= POOL.current_num_threads() > 1;
// The latest index will be used for the null sentinel.
let len = if self.null_count() > 0 {
// we add one to store the null sentinel group
max + 2
num_groups + 2
} else {
max + 1
num_groups + 1
};

// the latest index will be used for the null sentinel
let null_idx = len.saturating_sub(1);
let n_threads = POOL.current_num_threads();

let n_threads = POOL.current_num_threads();
let chunk_size = len / n_threads;

let (groups, first) = if multithreaded && chunk_size > 1 {
let mut groups: Vec<IdxVec> = unsafe { aligned_vec(len) };
let mut groups: Vec<IdxVec> = Vec::new();
groups.resize_with(len, || IdxVec::with_capacity(group_capacity));
let mut first: Vec<IdxSize> = unsafe { aligned_vec(len) };

// ensure we keep aligned to cache lines
let chunk_size = (chunk_size * size_of::<T::Native>()).next_multiple_of(64);
let chunk_size = chunk_size / size_of::<T::Native>();

let mut cache_line_offsets = Vec::with_capacity(n_threads + 1);
cache_line_offsets.push(0);
let mut current_offset = chunk_size;

while current_offset <= len {
cache_line_offsets.push(current_offset);
current_offset += chunk_size;
let mut first: Vec<IdxSize> = Vec::with_capacity(len);

// Round up offsets to nearest cache line for groups to reduce false sharing.
let groups_start = groups.as_ptr();
let mut per_thread_offsets = Vec::with_capacity(n_threads + 1);
per_thread_offsets.push(0);
for t in 0..n_threads {
let ideal_offset = (t + 1) * chunk_size;
let cache_aligned_offset =
ideal_offset + groups_start.wrapping_add(ideal_offset).align_offset(128);
per_thread_offsets.push(std::cmp::min(cache_aligned_offset, len));
}
cache_line_offsets.push(current_offset);

let groups_ptr = unsafe { SyncPtr::new(groups.as_mut_ptr()) };
let first_ptr = unsafe { SyncPtr::new(first.as_mut_ptr()) };

// The number of threads is dependent on the number of categoricals/ unique values
// as every at least writes to a single cache line
// lower bound per thread:
// 32bit: 16
// 64bit: 8
POOL.install(|| {
(0..cache_line_offsets.len() - 1)
.into_par_iter()
.for_each(|thread_no| {
let mut row_nr = 0 as IdxSize;
let start = cache_line_offsets[thread_no];
let start = T::Native::from_usize(start).unwrap();
let end = cache_line_offsets[thread_no + 1];
let end = T::Native::from_usize(end).unwrap();

// SAFETY: we don't alias
let groups =
unsafe { std::slice::from_raw_parts_mut(groups_ptr.get(), len) };
let first = unsafe { std::slice::from_raw_parts_mut(first_ptr.get(), len) };

for arr in self.downcast_iter() {
if arr.null_count() == 0 {
for &cat in arr.values().as_slice() {
(0..n_threads).into_par_iter().for_each(|thread_no| {
// We use raw pointers because the slices would overlap.
// However, each thread has its own range it is responsible for.
let groups = groups_ptr.get();
let first = first_ptr.get();
let start = per_thread_offsets[thread_no];
let start = T::Native::from_usize(start).unwrap();
let end = per_thread_offsets[thread_no + 1];
let end = T::Native::from_usize(end).unwrap();

let push_to_group = |cat, row_nr| unsafe {
debug_assert!(cat < len);
let buf = &mut *groups.add(cat);
buf.push(row_nr);
if buf.len() == 1 {
*first.add(cat) = row_nr;
}
};

let mut row_nr = 0 as IdxSize;
for arr in self.downcast_iter() {
if arr.null_count() == 0 {
for &cat in arr.values().as_slice() {
if cat >= start && cat < end {
push_to_group(cat.to_usize().unwrap(), row_nr);
}

row_nr += 1;
}
} else {
for opt_cat in arr.iter() {
if let Some(&cat) = opt_cat {
if cat >= start && cat < end {
let cat = cat.to_usize().unwrap();
let buf = unsafe { groups.get_unchecked_release_mut(cat) };
buf.push(row_nr);

unsafe {
if buf.len() == 1 {
// SAFETY: we just pushed
let first_value = buf.get_unchecked(0);
*first.get_unchecked_release_mut(cat) = *first_value
}
}
push_to_group(cat.to_usize().unwrap(), row_nr);
}
row_nr += 1;
} else if thread_no == n_threads - 1 {
// Last thread handles null values.
push_to_group(null_idx, row_nr);
}
} else {
for opt_cat in arr.iter() {
if let Some(&cat) = opt_cat {
// cannot factor out due to bchk
if cat >= start && cat < end {
let cat = cat.to_usize().unwrap();
let buf =
unsafe { groups.get_unchecked_release_mut(cat) };
buf.push(row_nr);

unsafe {
if buf.len() == 1 {
// SAFETY: we just pushed
let first_value = buf.get_unchecked(0);
*first.get_unchecked_release_mut(cat) =
*first_value
}
}
}
}
// last thread handles null values
else if thread_no == cache_line_offsets.len() - 2 {
let buf =
unsafe { groups.get_unchecked_release_mut(null_idx) };
buf.push(row_nr);
unsafe {
if buf.len() == 1 {
let first_value = buf.get_unchecked(0);
*first.get_unchecked_release_mut(null_idx) =
*first_value
}
}
}

row_nr += 1;
}
row_nr += 1;
}
}
});
}
});
});
unsafe {
groups.set_len(len);
first.set_len(len);
}
(groups, first)
} else {
let mut groups = Vec::with_capacity(len);
let mut first = vec![IdxSize::MAX; len];
let mut first = Vec::with_capacity(len);
let first_out = first.spare_capacity_mut();
groups.resize_with(len, || IdxVec::with_capacity(group_capacity));

let mut push_to_group = |cat, row_nr| unsafe {
let buf: &mut IdxVec = groups.get_unchecked_release_mut(cat);
buf.push(row_nr);
if buf.len() == 1 {
*first_out.get_unchecked_release_mut(cat) = MaybeUninit::new(row_nr);
}
};

let mut row_nr = 0 as IdxSize;
for arr in self.downcast_iter() {
for opt_cat in arr.iter() {
if let Some(cat) = opt_cat {
let group_id = cat.to_usize().unwrap();
let buf = unsafe { groups.get_unchecked_release_mut(group_id) };
buf.push(row_nr);

unsafe {
if buf.len() == 1 {
*first.get_unchecked_release_mut(group_id) = row_nr;
}
}
push_to_group(cat.to_usize().unwrap(), row_nr);
} else {
let buf = unsafe { groups.get_unchecked_release_mut(null_idx) };
buf.push(row_nr);
unsafe {
let first_value = buf.get_unchecked(0);
*first.get_unchecked_release_mut(null_idx) = *first_value
}
push_to_group(null_idx, row_nr);
}

row_nr += 1;
}
}
unsafe {
first.set_len(len);
}
(groups, first)
};

@@ -201,7 +168,7 @@ impl CategoricalChunked {
}
// on relative small tables this isn't much faster than the default strategy
// but on huge tables, this can be > 2x faster
cats.group_tuples_perfect(cached.len() - 1, multithreaded, 0)
unsafe { cats.group_tuples_perfect(cached.len() - 1, multithreaded, 0) }
} else {
self.physical().group_tuples(multithreaded, sorted).unwrap()
}
@@ -220,26 +187,3 @@ impl CategoricalChunked {
out
}
}

#[repr(C, align(64))]
struct AlignTo64([u8; 64]);

/// There are no guarantees that the [`Vec<T>`] will remain aligned if you reallocate the data.
/// This means that you cannot reallocate so you will need to know how big to allocate up front.
unsafe fn aligned_vec<T>(n: usize) -> Vec<T> {
assert!(align_of::<T>() <= 64);
let n_units = (n * size_of::<T>() / size_of::<AlignTo64>()) + 1;

let mut aligned: Vec<AlignTo64> = Vec::with_capacity(n_units);

let ptr = aligned.as_mut_ptr();
let cap_units = aligned.capacity();

std::mem::forget(aligned);

Vec::from_raw_parts(
ptr as *mut T,
0,
cap_units * size_of::<AlignTo64>() / size_of::<T>(),
)
}
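In place of the aligned allocation removed above, the new multithreaded path rounds each per-thread partition boundary up to the next 128-byte-aligned element address inside the `groups` buffer ("Round up offsets to nearest cache line for groups to reduce false sharing"). A standalone sketch of that boundary computation; the element type, thread count, and pinning of the last boundary to `len` are choices made for this example rather than taken from the diff:

```rust
// Sketch of the cache-line offset rounding used in the new multithreaded
// path: each per-thread boundary is bumped to the next element whose
// address is 128-byte aligned, so neighbouring threads don't write to
// the same cache line. The last boundary is pinned to `len` here so the
// ranges always cover the whole buffer.

fn cache_aligned_offsets<T>(buf: &[T], n_threads: usize) -> Vec<usize> {
    let len = buf.len();
    let chunk_size = len / n_threads;
    let start = buf.as_ptr();

    let mut offsets = Vec::with_capacity(n_threads + 1);
    offsets.push(0);
    for t in 0..n_threads - 1 {
        let ideal = (t + 1) * chunk_size;
        // `align_offset(128)` returns how many elements to advance until
        // the pointer is 128-byte aligned.
        let aligned = ideal + start.wrapping_add(ideal).align_offset(128);
        offsets.push(aligned.min(len));
    }
    offsets.push(len);
    offsets
}

fn main() {
    // 64-bit elements stand in for the per-group index buffers.
    let groups = vec![0u64; 1000];
    let offsets = cache_aligned_offsets(&groups, 4);
    // Consecutive pairs are the half-open ranges of group ids per thread.
    println!("{offsets:?}");
    assert_eq!(offsets.first(), Some(&0));
    assert_eq!(offsets.last(), Some(&groups.len()));
}
```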