Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Rewrite implementation of top_k/bottom_k and fix a variety of bugs #16804

Merged
merged 7 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 6 additions & 22 deletions crates/polars-arrow/src/array/binview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
&self.views
}

pub fn into_views(self) -> Vec<View> {
self.views.make_mut()
}

pub fn try_new(
data_type: ArrowDataType,
views: Buffer<View>,
Expand Down Expand Up @@ -265,28 +269,8 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
/// Assumes that the `i < self.len`.
#[inline]
pub unsafe fn value_unchecked(&self, i: usize) -> &T {
let v = *self.views.get_unchecked_release(i);
let len = v.length;

// view layout:
// length: 4 bytes
// prefix: 4 bytes
// buffer_index: 4 bytes
// offset: 4 bytes

// inlined layout:
// length: 4 bytes
// data: 12 bytes

let bytes = if len <= 12 {
let ptr = self.views.as_ptr() as *const u8;
std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize)
} else {
let data = self.buffers.get_unchecked_release(v.buffer_idx as usize);
let offset = v.offset as usize;
data.get_unchecked_release(offset..offset + len as usize)
};
T::from_bytes_unchecked(bytes)
let v = self.views.get_unchecked_release(i);
T::from_bytes_unchecked(v.get_slice_unchecked(&self.buffers))
}

/// Returns an iterator of `Option<&T>` over every element of this array.
Expand Down
22 changes: 14 additions & 8 deletions crates/polars-arrow/src/array/binview/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,24 +343,30 @@ impl<T: ViewType + ?Sized> MutableBinaryViewArray<T> {
/// Assumes that the `i < self.len`.
#[inline]
pub unsafe fn value_unchecked(&self, i: usize) -> &T {
let v = *self.views.get_unchecked(i);
let len = v.length;
self.value_from_view_unchecked(self.views.get_unchecked(i))
}

// view layout:
/// Returns the element indicated by the given view.
///
/// # Safety
/// Assumes the View belongs to this MutableBinaryViewArray.
pub unsafe fn value_from_view_unchecked<'a>(&'a self, view: &'a View) -> &'a T {
// View layout:
// length: 4 bytes
// prefix: 4 bytes
// buffer_index: 4 bytes
// offset: 4 bytes

// inlined layout:
// Inlined layout:
// length: 4 bytes
// data: 12 bytes
let len = view.length;
let bytes = if len <= 12 {
let ptr = self.views.as_ptr() as *const u8;
std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize)
let ptr = view as *const View as *const u8;
std::slice::from_raw_parts(ptr.add(4), len as usize)
} else {
let buffer_idx = v.buffer_idx as usize;
let offset = v.offset;
let buffer_idx = view.buffer_idx as usize;
let offset = view.offset;

let data = if buffer_idx == self.completed_buffers.len() {
self.in_progress_buffer.as_slice()
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-arrow/src/array/binview/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,23 @@ impl View {
}
}
}

/// Constructs a byteslice from this view.
///
/// # Safety
/// Assumes that this view is valid for the given buffers.
pub unsafe fn get_slice_unchecked<'a>(&'a self, buffers: &'a [Buffer<u8>]) -> &'a [u8] {
unsafe {
if self.length <= 12 {
let ptr = self as *const View as *const u8;
std::slice::from_raw_parts(ptr.add(4), self.length as usize)
} else {
let data = buffers.get_unchecked_release(self.buffer_idx as usize);
let offset = self.offset as usize;
data.get_unchecked_release(offset..offset + self.length as usize)
}
}
}
}

impl IsNull for View {
Expand Down
45 changes: 21 additions & 24 deletions crates/polars-compute/src/filter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,53 +9,50 @@ mod avx512;
use arrow::array::growable::make_growable;
use arrow::array::{new_empty_array, Array, BinaryViewArray, BooleanArray, PrimitiveArray};
use arrow::bitmap::utils::SlicesIterator;
use arrow::datatypes::ArrowDataType;
use arrow::bitmap::Bitmap;
use arrow::with_match_primitive_type_full;
use polars_error::PolarsResult;

pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult<Box<dyn Array>> {
pub fn filter(array: &dyn Array, mask: &BooleanArray) -> Box<dyn Array> {
assert_eq!(array.len(), mask.len());

// Treat null mask values as false.
if let Some(validities) = mask.validity() {
let values = mask.values();
let new_values = values & validities;
let mask = BooleanArray::new(ArrowDataType::Boolean, new_values, None);
return filter(array, &mask);
let combined_mask = mask.values() & validities;
filter_with_bitmap(array, &combined_mask)
} else {
filter_with_bitmap(array, mask.values())
}
}

pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box<dyn Array> {
// Fast-path: completely empty or completely full mask.
let false_count = mask.values().unset_bits();
let false_count = mask.unset_bits();
if false_count == mask.len() {
return Ok(new_empty_array(array.data_type().clone()));
return new_empty_array(array.data_type().clone());
}
if false_count == 0 {
return Ok(array.to_boxed());
return array.to_boxed();
}

use arrow::datatypes::PhysicalType::*;
match array.data_type().to_physical_type() {
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
let array: &PrimitiveArray<$T> = array.as_any().downcast_ref().unwrap();
let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask.values());
Ok(Box::new(PrimitiveArray::from_vec(values).with_validity(validity)))
let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask);
Box::new(PrimitiveArray::from_vec(values).with_validity(validity))
}),
Boolean => {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
let (values, validity) = boolean::filter_bitmap_and_validity(
array.values(),
array.validity(),
mask.values(),
);
Ok(BooleanArray::new(array.data_type().clone(), values, validity).boxed())
let (values, validity) =
boolean::filter_bitmap_and_validity(array.values(), array.validity(), mask);
BooleanArray::new(array.data_type().clone(), values, validity).boxed()
},
BinaryView => {
let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
let views = array.views();
let validity = array.validity();
let (views, validity) =
primitive::filter_values_and_validity(views, validity, mask.values());
Ok(unsafe {
let (views, validity) = primitive::filter_values_and_validity(views, validity, mask);
unsafe {
BinaryViewArray::new_unchecked_unknown_md(
array.data_type().clone(),
views.into(),
Expand All @@ -64,19 +61,19 @@ pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult<Box<dyn Ar
Some(array.total_buffer_len()),
)
}
.boxed())
.boxed()
},
// Should go via BinaryView
Utf8View => {
unreachable!()
},
_ => {
let iter = SlicesIterator::new(mask.values());
let iter = SlicesIterator::new(mask);
let mut mutable = make_growable(&[array], false, iter.slots());
// SAFETY:
// we are in bounds
iter.for_each(|(start, len)| unsafe { mutable.extend(0, start, len) });
Ok(mutable.as_box())
mutable.as_box()
},
}
}
22 changes: 16 additions & 6 deletions crates/polars-core/src/chunked_array/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ where
unsafe { Self::from_chunks(name, vec![Box::new(arr)]) }
}

pub fn with_chunk_like<A>(ca: &Self, arr: A) -> Self
where
A: Array,
T: PolarsDataType<Array = A>,
{
Self::from_chunk_iter_like(ca, std::iter::once(arr))
}

pub fn from_chunk_iter<I>(name: &str, iter: I) -> Self
where
I: IntoIterator,
Expand Down Expand Up @@ -165,12 +173,14 @@ where
})
.collect();

ChunkedArray::new_with_dims(
field,
chunks,
length.try_into().expect(LENGTH_LIMIT_MSG),
null_count as IdxSize,
)
unsafe {
ChunkedArray::new_with_dims(
field,
chunks,
length.try_into().expect(LENGTH_LIMIT_MSG),
null_count as IdxSize,
)
}
}

/// Create a new [`ChunkedArray`] from existing chunks.
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/logical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl<K: PolarsDataType, T: PolarsDataType> DerefMut for Logical<K, T> {
}

impl<K: PolarsDataType, T: PolarsDataType> Logical<K, T> {
pub(crate) fn new_logical<J: PolarsDataType>(ca: ChunkedArray<T>) -> Logical<J, T> {
pub fn new_logical<J: PolarsDataType>(ca: ChunkedArray<T>) -> Logical<J, T> {
Logical(ca, PhantomData, None)
}
}
Expand Down
39 changes: 35 additions & 4 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;

use arrow::array::*;
use arrow::bitmap::Bitmap;
use polars_compute::filter::filter_with_bitmap;

use crate::prelude::*;

Expand Down Expand Up @@ -148,16 +149,21 @@ impl<T: PolarsDataType> ChunkedArray<T> {
/// If you want to explicitly the `length` and `null_count`, look at
/// [`ChunkedArray::new_with_dims`]
pub fn new_with_compute_len(field: Arc<Field>, chunks: Vec<ArrayRef>) -> Self {
let mut chunked_arr = Self::new_with_dims(field, chunks, 0, 0);
chunked_arr.compute_len();
chunked_arr
unsafe {
let mut chunked_arr = Self::new_with_dims(field, chunks, 0, 0);
chunked_arr.compute_len();
chunked_arr
}
}

/// Create a new [`ChunkedArray`] and explicitly set its `length` and `null_count`.
///
/// If you want to compute the `length` and `null_count`, look at
/// [`ChunkedArray::new_with_compute_len`]
pub fn new_with_dims(
///
/// # Safety
/// The length and null_count must be correct.
pub unsafe fn new_with_dims(
field: Arc<Field>,
chunks: Vec<ArrayRef>,
length: IdxSize,
Expand Down Expand Up @@ -424,6 +430,31 @@ impl<T: PolarsDataType> ChunkedArray<T> {
}
}

pub fn drop_nulls(&self) -> Self {
if self.null_count() == 0 {
self.clone()
} else {
let chunks = self
.downcast_iter()
.map(|arr| {
if arr.null_count() == 0 {
arr.to_boxed()
} else {
filter_with_bitmap(arr, arr.validity().unwrap())
}
})
.collect();
unsafe {
Self::new_with_dims(
self.field.clone(),
chunks,
(self.len() - self.null_count()) as IdxSize,
0,
)
}
}
}

/// Get the buffer of bits representing null values
#[inline]
#[allow(clippy::type_complexity)]
Expand Down
10 changes: 7 additions & 3 deletions crates/polars-core/src/chunked_array/object/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ where

self.field.dtype = get_object_type::<T>();

ChunkedArray::new_with_dims(Arc::new(self.field), vec![arr], len as IdxSize, null_count)
unsafe {
ChunkedArray::new_with_dims(Arc::new(self.field), vec![arr], len as IdxSize, null_count)
}
}
}

Expand Down Expand Up @@ -141,7 +143,7 @@ where
len,
});

ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, 0)
unsafe { ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, 0) }
}

pub fn new_from_vec_and_validity(name: &str, v: Vec<T>, validity: Bitmap) -> Self {
Expand All @@ -155,7 +157,9 @@ where
len,
});

ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, null_count as IdxSize)
unsafe {
ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, null_count as IdxSize)
}
}

pub fn new_empty(name: &str) -> Self {
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-core/src/chunked_array/ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ where
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand All @@ -53,7 +53,7 @@ impl ChunkFilter<BooleanType> for BooleanChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand Down Expand Up @@ -82,7 +82,7 @@ impl ChunkFilter<BinaryType> for BinaryChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand All @@ -104,7 +104,7 @@ impl ChunkFilter<BinaryOffsetType> for BinaryOffsetChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand All @@ -129,7 +129,7 @@ impl ChunkFilter<ListType> for ListChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand All @@ -155,7 +155,7 @@ impl ChunkFilter<FixedSizeListType> for ArrayChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand Down
Loading
Loading