feat: DataFrame supports explode by array column #13958

Merged · 1 commit · Jan 26, 2024
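For context, here is a minimal usage sketch of the behavior this PR enables: exploding a DataFrame by a fixed-size pl.Array column. The data and column names below are illustrative, not taken from the PR; see the new test at the bottom for the exact cases covered.

import polars as pl

# Hypothetical example frame: one Array column, one ordinary column.
df = pl.DataFrame(
    {"arr": [["a", "b"], None, ["c", "d"]], "val": [1, 2, 3]},
    schema={"arr": pl.Array(pl.String, 2), "val": pl.Int64},
)

# With this change, explode() accepts Array columns as well as List columns.
# A null array row explodes to a single null, and "val" is repeated to match.
print(df.explode("arr"))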
76 changes: 58 additions & 18 deletions crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs
@@ -87,27 +87,63 @@ impl ChunkExplode for ListChunked {
 #[cfg(feature = "dtype-array")]
 impl ChunkExplode for ArrayChunked {
     fn offsets(&self) -> PolarsResult<OffsetsBuffer<i64>> {
-        let width = self.width() as i64;
-        let offsets = (0..self.len() + 1)
-            .map(|i| {
-                let i = i as i64;
-                i * width
-            })
-            .collect::<Vec<_>>();
+        // fast-path for non-null array.
+        if self.null_count() == 0 {
+            let width = self.width() as i64;
+            let offsets = (0..self.len() + 1)
+                .map(|i| {
+                    let i = i as i64;
+                    i * width
+                })
+                .collect::<Vec<_>>();
+            // safety: monotonically increasing
+            let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
+
+            return Ok(offsets);
+        }
+
+        let ca = self.rechunk();
+        let arr = ca.downcast_iter().next().unwrap();
+        // we have already ensure that validity is not none.
+        let validity = arr.validity().unwrap();
+        let width = arr.size();
+
+        let mut current_offset = 0i64;
+        let offsets = (0..=arr.len())
+            .map(|i| {
+                if i == 0 {
+                    return current_offset;
+                }
+                // Safety: we are within bounds
+                if unsafe { validity.get_bit_unchecked(i - 1) } {
+                    current_offset += width as i64
+                }
+                current_offset
+            })
+            .collect::<Vec<_>>();
         // safety: monotonically increasing
         let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
 
         Ok(offsets)
     }
 
-    fn explode(&self) -> PolarsResult<Series> {
+    fn explode_and_offsets(&self) -> PolarsResult<(Series, OffsetsBuffer<i64>)> {
         let ca = self.rechunk();
         let arr = ca.downcast_iter().next().unwrap();
         // fast-path for non-null array.
         if arr.null_count() == 0 {
-            return Series::try_from((self.name(), arr.values().clone()))
+            let s = Series::try_from((self.name(), arr.values().clone()))
                 .unwrap()
-                .cast(&ca.inner_dtype());
+                .cast(&ca.inner_dtype())?;
+            let width = self.width() as i64;
+            let offsets = (0..self.len() + 1)
+                .map(|i| {
+                    let i = i as i64;
+                    i * width
+                })
+                .collect::<Vec<_>>();
+            // safety: monotonically increasing
+            let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
+            return Ok((s, offsets));
         }
 
         // we have already ensure that validity is not none.
@@ -118,30 +154,34 @@ impl ChunkExplode for ArrayChunked {
         let mut indices = MutablePrimitiveArray::<IdxSize>::with_capacity(
             values.len() - arr.null_count() * (width - 1),
         );
+        let mut offsets = Vec::with_capacity(arr.len() + 1);
+        let mut current_offset = 0i64;
+        offsets.push(current_offset);
         (0..arr.len()).for_each(|i| {
             // Safety: we are within bounds
             if unsafe { validity.get_bit_unchecked(i) } {
                 let start = (i * width) as IdxSize;
                 let end = start + width as IdxSize;
                 indices.extend_trusted_len_values(start..end);
+                current_offset += width as i64;
             } else {
                 indices.push_null();
             }
+            offsets.push(current_offset);
         });
 
         // Safety: the indices we generate are in bounds
         let chunk = unsafe { take_unchecked(&**values, &indices.into()) };
+        // safety: monotonically increasing
+        let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
 
-        // Safety: inner_dtype should be correct
-        Ok(unsafe {
-            Series::from_chunks_and_dtype_unchecked(ca.name(), vec![chunk], &ca.inner_dtype())
-        })
-    }
-
-    fn explode_and_offsets(&self) -> PolarsResult<(Series, OffsetsBuffer<i64>)> {
-        let s = self.explode().unwrap();
-
-        Ok((s, self.offsets()?))
+        Ok((
+            // Safety: inner_dtype should be correct
+            unsafe {
+                Series::from_chunks_and_dtype_unchecked(ca.name(), vec![chunk], &ca.inner_dtype())
+            },
+            offsets,
+        ))
     }
 }
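To make the null handling above easier to follow, here is a minimal Python sketch (illustrative only, not part of the PR) of the offset logic in ArrayChunked::offsets(): with no nulls the offsets are simply multiples of the array width, while a null row contributes a zero-length slot. The helper name array_explode_offsets is made up for this sketch.

from typing import Optional, Sequence


def array_explode_offsets(rows: Sequence[Optional[Sequence[object]]], width: int) -> list[int]:
    # Fast path: no null rows -> offsets are 0, width, 2*width, ...
    if all(row is not None for row in rows):
        return [i * width for i in range(len(rows) + 1)]

    # Null-aware path: only valid rows advance the offset by `width`.
    offsets = [0]
    current = 0
    for row in rows:
        if row is not None:
            current += width
        offsets.append(current)
    return offsets


# For the rows used in the new test below: [["a", "b"], ["c", None], None, ["d", "e"]]
print(array_explode_offsets([["a", "b"], ["c", None], None, ["d", "e"]], width=2))
# [0, 2, 4, 4, 6]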

64 changes: 64 additions & 0 deletions py-polars/tests/unit/operations/test_explode.py
@@ -352,3 +352,67 @@ def test_explode_null_struct() -> None:
{"field1": None, "field2": "some", "field3": "value"},
]
}


def test_df_explode_with_array() -> None:
df = pl.DataFrame(
{
"arr": [["a", "b"], ["c", None], None, ["d", "e"]],
"list": [[1, 2], [3], [4, None], None],
"val": ["x", "y", "z", "q"],
},
schema={
"arr": pl.Array(pl.String, 2),
"list": pl.List(pl.Int64),
"val": pl.String,
},
)

expected_by_arr = pl.DataFrame(
{
"arr": ["a", "b", "c", None, None, "d", "e"],
"list": [[1, 2], [1, 2], [3], [3], [4, None], None, None],
"val": ["x", "x", "y", "y", "z", "q", "q"],
}
)
assert_frame_equal(df.explode(pl.col("arr")), expected_by_arr)

expected_by_list = pl.DataFrame(
{
"arr": [["a", "b"], ["a", "b"], ["c", None], None, None, ["d", "e"]],
"list": [1, 2, 3, 4, None, None],
"val": ["x", "x", "y", "z", "z", "q"],
},
schema={
"arr": pl.Array(pl.String, 2),
"list": pl.Int64,
"val": pl.String,
},
)
assert_frame_equal(df.explode(pl.col("list")), expected_by_list)

df = pl.DataFrame(
{
"arr": [["a", "b"], ["c", None], None, ["d", "e"]],
"list": [[1, 2], [3, 4], None, [5, None]],
"val": [None, 1, 2, None],
},
schema={
"arr": pl.Array(pl.String, 2),
"list": pl.List(pl.Int64),
"val": pl.Int64,
},
)
expected_by_arr_and_list = pl.DataFrame(
{
"arr": ["a", "b", "c", None, None, "d", "e"],
"list": [1, 2, 3, 4, None, 5, None],
"val": [None, None, 1, 1, 2, None, None],
},
schema={
"arr": pl.String,
"list": pl.Int64,
"val": pl.Int64,
},
)
assert_frame_equal(df.explode("arr", "list"), expected_by_arr_and_list)