Skip to content

Commit

Permalink
fix: Reset if next caller clones inner series (#16812)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 7, 2024
1 parent 1dce3f4 commit 38149d6
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 127 deletions.
90 changes: 57 additions & 33 deletions crates/polars-core/src/chunked_array/array/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use std::ptr::NonNull;

use super::*;
use crate::chunked_array::list::iterator::AmortizedListIter;
use crate::series::unstable::{ArrayBox, UnstableSeries};
use crate::series::unstable::{unstable_series_container_and_ptr, ArrayBox, UnstableSeries};

impl ArrayChunked {
/// This is an iterator over a [`ListChunked`] that save allocations.
/// This is an iterator over a [`ArrayChunked`] that save allocations.
/// A Series is:
/// 1. [`Arc<ChunkedArray>`]
/// ChunkedArray is:
Expand All @@ -21,11 +21,37 @@ impl ArrayChunked {
/// this function still needs precautions. The returned should never be cloned or taken longer
/// than a single iteration, as every call on `next` of the iterator will change the contents of
/// that Series.
pub fn amortized_iter(&self) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
///
/// # Safety
/// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive
/// longer than the iterator is UB.
pub unsafe fn amortized_iter(
&self,
) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
self.amortized_iter_with_name("")
}

pub fn amortized_iter_with_name(
/// This is an iterator over a [`ArrayChunked`] that save allocations.
/// A Series is:
/// 1. [`Arc<ChunkedArray>`]
/// ChunkedArray is:
/// 2. Vec< 3. ArrayRef>
///
/// The [`ArrayRef`] we indicated with 3. will be updated during iteration.
/// The Series will be pinned in memory, saving an allocation for
/// 1. Arc<..>
/// 2. Vec<...>
///
/// # Warning
/// Though memory safe in the sense that it will not read unowned memory, UB, or memory leaks
/// this function still needs precautions. The returned should never be cloned or taken longer
/// than a single iteration, as every call on `next` of the iterator will change the contents of
/// that Series.
///
/// # Safety
/// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive
/// longer than the iterator is UB.
pub unsafe fn amortized_iter_with_name(
&self,
name: &str,
) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
Expand All @@ -46,19 +72,12 @@ impl ArrayChunked {

// SAFETY:
// inner type passed as physical type
let series_container = unsafe {
Box::pin(Series::from_chunks_and_dtype_unchecked(
name,
vec![inner_values.clone()],
&iter_dtype,
))
};

let ptr = series_container.array_ref(0) as *const ArrayRef as *mut ArrayRef;
let (s, ptr) =
unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) };

AmortizedListIter::new(
self.len(),
series_container,
s,
NonNull::new(ptr).unwrap(),
self.downcast_iter().flat_map(|arr| arr.iter()),
inner_dtype.clone(),
Expand All @@ -79,22 +98,24 @@ impl ArrayChunked {
.clone());
}
let mut fast_explode = self.null_count() == 0;
let mut ca: ListChunked = self
.amortized_iter()
.map(|opt_v| {
opt_v
.map(|v| {
let out = f(v);
if let Ok(out) = &out {
if out.is_empty() {
fast_explode = false
}
};
out
})
.transpose()
})
.collect::<PolarsResult<_>>()?;
// SAFETY: lifetime of iterator is bound to this functions scope
let mut ca: ListChunked = unsafe {
self.amortized_iter()
.map(|opt_v| {
opt_v
.map(|v| {
let out = f(v);
if let Ok(out) = &out {
if out.is_empty() {
fast_explode = false
}
};
out
})
.transpose()
})
.collect::<PolarsResult<_>>()?
};
ca.rename(self.name());
if fast_explode {
ca.set_fast_explode();
Expand Down Expand Up @@ -181,7 +202,8 @@ impl ArrayChunked {
F: FnMut(Option<UnstableSeries<'a>>) -> Option<K> + Copy,
V::Array: ArrayFromIter<Option<K>>,
{
self.amortized_iter().map(f).collect_ca(self.name())
// SAFETY: lifetime of iterator is bound to this functions scope
unsafe { self.amortized_iter().map(f).collect_ca(self.name()) }
}

/// Try apply a closure `F` elementwise.
Expand All @@ -191,14 +213,16 @@ impl ArrayChunked {
F: FnMut(Option<UnstableSeries<'a>>) -> PolarsResult<Option<K>> + Copy,
V::Array: ArrayFromIter<Option<K>>,
{
self.amortized_iter().map(f).try_collect_ca(self.name())
// SAFETY: lifetime of iterator is bound to this functions scope
unsafe { self.amortized_iter().map(f).try_collect_ca(self.name()) }
}

pub fn for_each_amortized<'a, F>(&'a self, f: F)
where
F: FnMut(Option<UnstableSeries<'a>>),
{
self.amortized_iter().for_each(f)
// SAFETY: lifetime of iterator is bound to this functions scope
unsafe { self.amortized_iter().for_each(f) }
}
}

Expand Down
50 changes: 26 additions & 24 deletions crates/polars-core/src/chunked_array/list/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::ptr::NonNull;

use crate::prelude::*;
use crate::series::unstable::{ArrayBox, UnstableSeries};
use crate::series::unstable::{unstable_series_container_and_ptr, ArrayBox, UnstableSeries};

pub struct AmortizedListIter<'a, I: Iterator<Item = Option<ArrayBox>>> {
len: usize,
series_container: Pin<Box<Series>>,
series_container: Box<Series>,
inner: NonNull<ArrayRef>,
lifetime: PhantomData<&'a ArrayRef>,
iter: I,
Expand All @@ -19,14 +18,14 @@ pub struct AmortizedListIter<'a, I: Iterator<Item = Option<ArrayBox>>> {
impl<'a, I: Iterator<Item = Option<ArrayBox>>> AmortizedListIter<'a, I> {
pub(crate) fn new(
len: usize,
series_container: Pin<Box<Series>>,
series_container: Series,
inner: NonNull<ArrayRef>,
iter: I,
inner_dtype: DataType,
) -> Self {
Self {
len,
series_container,
series_container: Box::new(series_container),
inner,
lifetime: PhantomData,
iter,
Expand Down Expand Up @@ -64,14 +63,26 @@ impl<'a, I: Iterator<Item = Option<ArrayBox>>> Iterator for AmortizedListIter<'a
);
}
}
// The series is cloned, we make a new container.
if Arc::strong_count(&self.series_container.0) > 1 {
let (s, ptr) = unsafe {
unstable_series_container_and_ptr(
self.series_container.name(),
array_ref,
self.series_container.dtype(),
)
};
*self.series_container.as_mut() = s;
self.inner = NonNull::new(ptr).unwrap();
} else {
// update the inner state
unsafe { *self.inner.as_mut() = array_ref };

// update the inner state
unsafe { *self.inner.as_mut() = array_ref };

// last iteration could have set the sorted flag (e.g. in compute_len)
self.series_container.clear_flags();
// make sure that the length is correct
self.series_container._get_inner_mut().compute_len();
// last iteration could have set the sorted flag (e.g. in compute_len)
self.series_container.clear_flags();
// make sure that the length is correct
self.series_container._get_inner_mut().compute_len();
}

// SAFETY:
// we cannot control the lifetime of an iterators `next` method.
Expand Down Expand Up @@ -145,21 +156,12 @@ impl ListChunked {

// SAFETY:
// inner type passed as physical type
let series_container = unsafe {
let mut s = Series::from_chunks_and_dtype_unchecked(
name,
vec![inner_values.clone()],
&iter_dtype,
);
s.clear_flags();
Box::pin(s)
};

let ptr = series_container.array_ref(0) as *const ArrayRef as *mut ArrayRef;
let (s, ptr) =
unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) };

AmortizedListIter::new(
self.len(),
series_container,
s,
NonNull::new(ptr).unwrap(),
self.downcast_iter().flat_map(|arr| arr.iter()),
inner_dtype.clone(),
Expand Down
33 changes: 27 additions & 6 deletions crates/polars-core/src/series/unstable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ impl<'a> UnstableSeries<'a> {
#[inline]
/// Swaps inner state with the `array`. Prefer `UnstableSeries::with_array` as this
/// restores the state.
pub fn swap(&mut self, array: &mut ArrayRef) {
unsafe { std::mem::swap(self.inner.as_mut(), array) }
/// # Safety
/// This swaps an underlying pointer that might be hold by other cloned series.
pub unsafe fn swap(&mut self, array: &mut ArrayRef) {
std::mem::swap(self.inner.as_mut(), array);
// ensure lengths are correct.
self.as_mut()._get_inner_mut().compute_len();
}
Expand All @@ -84,9 +86,28 @@ impl<'a> UnstableSeries<'a> {
where
F: Fn(&UnstableSeries) -> T,
{
self.swap(array);
let out = f(self);
self.swap(array);
out
unsafe {
self.swap(array);
let out = f(self);
self.swap(array);
out
}
}
}

// SAFETY:
// type must be matching
pub(crate) unsafe fn unstable_series_container_and_ptr(
name: &str,
inner_values: ArrayRef,
iter_dtype: &DataType,
) -> (Series, *mut ArrayRef) {
let series_container = {
let mut s = Series::from_chunks_and_dtype_unchecked(name, vec![inner_values], iter_dtype);
s.clear_flags();
s
};

let ptr = series_container.array_ref(0) as *const ArrayRef as *mut ArrayRef;
(series_container, ptr)
}
2 changes: 1 addition & 1 deletion crates/polars-expr/src/expressions/group_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl<'a> Iterator for FlatIter<'a> {
} else {
if self.chunk_offset < self.current_array.len() {
let mut arr = unsafe { self.current_array.sliced_unchecked(self.chunk_offset, 1) };
self.item.swap(&mut arr);
unsafe { self.item.swap(&mut arr) };
} else {
match self.chunks.pop() {
Some(arr) => {
Expand Down
10 changes: 6 additions & 4 deletions crates/polars-ops/src/chunked_array/array/dispersion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ pub(super) fn std_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult<Series
out.into_duration(*tu).into_series()
},
_ => {
let out: Float64Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().std(ddof)))
.collect();
// SAFETY: lifetime of iterator bound to scope of function
let out: Float64Chunked = unsafe {
ca.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().std(ddof)))
.collect()
};
out.into_series()
},
};
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-ops/src/chunked_array/array/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ fn join_many(
let mut buf = String::new();
let mut builder = StringChunkedBuilder::new(ca.name(), ca.len());

ca.amortized_iter()
// SAFETY: lifetime of iterator bound to scope of function
unsafe { ca.amortized_iter() }
.zip(separator)
.for_each(|(opt_s, opt_sep)| match opt_sep {
Some(separator) => {
Expand Down
Loading

0 comments on commit 38149d6

Please sign in to comment.